From 9b3d7e1110752d03ffe385a29743cb677c3f2251 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 29 Dec 2021 23:45:58 +0800 Subject: [PATCH] add dns handler --- cmd/gost/register.go | 1 + go.mod | 2 +- go.sum | 2 + gost.yml | 13 ++ pkg/chain/router.go | 27 ++++ pkg/handler/dns/handler.go | 199 ++++++++++++++++++++++++++ pkg/handler/dns/metadata.go | 43 ++++++ pkg/handler/tap/conn.go | 17 --- pkg/handler/tap/handler.go | 8 +- pkg/handler/tun/conn.go | 17 --- pkg/handler/tun/handler.go | 8 +- pkg/listener/dns/server.go | 6 +- pkg/resolver/exchanger/exchanger.go | 213 ++++++++++++++++++++++++++++ pkg/resolver/ns.go | 13 ++ pkg/resolver/resolver.go | 11 ++ 15 files changed, 541 insertions(+), 39 deletions(-) create mode 100644 pkg/handler/dns/handler.go create mode 100644 pkg/handler/dns/metadata.go delete mode 100644 pkg/handler/tap/conn.go delete mode 100644 pkg/handler/tun/conn.go create mode 100644 pkg/resolver/exchanger/exchanger.go create mode 100644 pkg/resolver/ns.go create mode 100644 pkg/resolver/resolver.go diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 764e55b..c752e21 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -32,6 +32,7 @@ import ( // Register handlers _ "github.com/go-gost/gost/pkg/handler/auto" + _ "github.com/go-gost/gost/pkg/handler/dns" _ "github.com/go-gost/gost/pkg/handler/forward/local" _ "github.com/go-gost/gost/pkg/handler/forward/remote" _ "github.com/go-gost/gost/pkg/handler/forward/ssh" diff --git a/go.mod b/go.mod index 634b54e..d16cafc 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/magiconair/properties v1.8.5 // indirect github.com/marten-seemann/qtls-go1-16 v0.1.4 // indirect github.com/marten-seemann/qtls-go1-17 v0.1.0 // indirect - github.com/miekg/dns v1.1.44 + github.com/miekg/dns v1.1.45 github.com/milosgajdos/tenus v0.0.3 github.com/mitchellh/mapstructure v1.4.2 // indirect github.com/mmcloughlin/avo v0.0.0-20200803215136-443f81d77104 // indirect diff --git a/go.sum b/go.sum index e09a15e..4852ccc 100644 --- a/go.sum +++ b/go.sum @@ -285,6 +285,8 @@ github.com/miekg/dns v1.1.26 h1:gPxPSwALAeHJSjarOs00QjVdV9QoBvc1D2ujQUr5BzU= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/miekg/dns v1.1.44 h1:4rpqcegYPVkvIeVhITrKP1sRR3KjfRc1nrOPMUZmLyc= github.com/miekg/dns v1.1.44/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= +github.com/miekg/dns v1.1.45 h1:g5fRIhm9nx7g8osrAvgb16QJfmyMsyOCb+J7LSv+Qzk= +github.com/miekg/dns v1.1.45/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= github.com/milosgajdos/tenus v0.0.3 h1:jmaJzwaY1DUyYVD0lM4U+uvP2kkEg1VahDqRFxIkVBE= github.com/milosgajdos/tenus v0.0.3/go.mod h1:eIjx29vNeDOYWJuCnaHY2r4fq5egetV26ry3on7p8qY= github.com/mitchellh/cli v1.1.0/go.mod h1:xcISNoH86gajksDmfB23e/pu+B+GeFRMYmoHXxx3xhI= diff --git a/gost.yml b/gost.yml index bf725ba..87cc83c 100644 --- a/gost.yml +++ b/gost.yml @@ -12,6 +12,19 @@ profiling: # key: "key.pem" # ca: "root.ca" +resolvers: +- name: resolver-0 + ttl: 60s + prefer: ipv4 + clientIP: 1.2.3.4 + nameServers: + - addr: udp://8.8.8.8:53 + timeout: 5s + - addr: tcp://1.1.1.1:53 + - addr: tls://1.1.1.1:853 + - addr: https://1.0.0.1/dns-query + domain: cloudflare-dns.com + services: - name: http+tcp url: "http://gost:gost@:8000" diff --git a/pkg/chain/router.go b/pkg/chain/router.go index f6158d3..af0a4f8 100644 --- a/pkg/chain/router.go +++ b/pkg/chain/router.go @@ -32,6 +32,19 @@ func (r *Router) WithLogger(logger logger.Logger) *Router { } func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) { + conn, err = r.dial(ctx, network, address) + if err != nil { + return + } + if network == "udp" || network == "udp4" || network == "udp6" { + if _, ok := conn.(net.PacketConn); !ok { + return &packetConn{conn}, nil + } + } + return +} + +func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { count := r.retries + 1 if count <= 0 { count = 1 @@ -88,3 +101,17 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...conn return } + +type packetConn struct { + net.Conn +} + +func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(b) + addr = c.Conn.RemoteAddr() + return +} + +func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + return c.Write(b) +} diff --git a/pkg/handler/dns/handler.go b/pkg/handler/dns/handler.go new file mode 100644 index 0000000..1ac65b1 --- /dev/null +++ b/pkg/handler/dns/handler.go @@ -0,0 +1,199 @@ +package dns + +import ( + "bytes" + "context" + "errors" + "net" + "strconv" + "time" + + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/common/bufpool" + "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "github.com/go-gost/gost/pkg/resolver/exchanger" + "github.com/miekg/dns" +) + +func init() { + registry.RegisterHandler("dns", NewHandler) +} + +type dnsHandler struct { + chain *chain.Chain + bypass bypass.Bypass + exchangers []exchanger.Exchanger + logger logger.Logger + md metadata +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := &handler.Options{} + for _, opt := range opts { + opt(options) + } + + return &dnsHandler{ + bypass: options.Bypass, + logger: options.Logger, + } +} + +func (h *dnsHandler) Init(md md.Metadata) (err error) { + if err = h.parseMetadata(md); err != nil { + return + } + + for _, server := range h.md.servers { + ex, err := exchanger.NewExchanger( + server, + exchanger.ChainOption(h.chain), + exchanger.LoggerOption(h.logger), + ) + if err != nil { + h.logger.Warnf("parse %s: %v", server, err) + continue + } + h.exchangers = append(h.exchangers, ex) + } + if len(h.exchangers) == 0 { + ex, _ := exchanger.NewExchanger( + "udp://127.0.0.53:53", + exchanger.ChainOption(h.chain), + exchanger.LoggerOption(h.logger), + ) + if ex != nil { + h.exchangers = append(h.exchangers, ex) + } + } + return +} + +// implements chain.Chainable interface +func (h *dnsHandler) WithChain(chain *chain.Chain) { + h.chain = chain +} + +func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) { + defer conn.Close() + + start := time.Now() + h.logger = h.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + b := bufpool.Get(4096) + defer bufpool.Put(b) + + n, err := conn.Read(b) + if err != nil { + h.logger.Error(err) + return + } + h.logger.Info("read data: ", n) + + reply, err := h.exchange(ctx, b[:n]) + if err != nil { + h.logger.Error(err) + return + } + + if _, err = conn.Write(reply); err != nil { + h.logger.Error(err) + } +} + +func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) { + mq := dns.Msg{} + if err := mq.Unpack(msg); err != nil { + return nil, err + } + + if len(mq.Question) == 0 { + return nil, errors.New("msg: empty question") + } + + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(mq.String()) + } else { + h.logger.Info(h.dumpMsgHeader(&mq)) + } + + var mr *dns.Msg + // Only cache for single question. + /* + if len(mq.Question) == 1 { + key := newResolverCacheKey(&mq.Question[0]) + mr = r.cache.loadCache(key) + if mr != nil { + log.Logf("[dns] exchange message %d (cached): %s", mq.Id, mq.Question[0].String()) + mr.Id = mq.Id + return mr.Pack() + } + + defer func() { + if mr != nil { + r.cache.storeCache(key, mr, r.TTL()) + } + }() + } + */ + + // r.addSubnetOpt(mq) + + query, err := mq.Pack() + if err != nil { + h.logger.Error(err) + return nil, err + } + + var reply []byte + for _, ex := range h.exchangers { + h.logger.Infof("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String()) + reply, err = ex.Exchange(ctx, query) + if err == nil { + break + } + h.logger.Error(err) + } + if err != nil { + h.logger.Error(err) + return nil, err + } + + mr = &dns.Msg{} + if err = mr.Unpack(reply); err != nil { + h.logger.Error(err) + return nil, err + } + + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(mr.String()) + } else { + h.logger.Info(h.dumpMsgHeader(mr)) + } + + return reply, nil +} + +func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string { + buf := new(bytes.Buffer) + buf.WriteString(m.MsgHdr.String() + " ") + buf.WriteString("QUERY: " + strconv.Itoa(len(m.Question)) + ", ") + buf.WriteString("ANSWER: " + strconv.Itoa(len(m.Answer)) + ", ") + buf.WriteString("AUTHORITY: " + strconv.Itoa(len(m.Ns)) + ", ") + buf.WriteString("ADDITIONAL: " + strconv.Itoa(len(m.Extra))) + return buf.String() +} diff --git a/pkg/handler/dns/metadata.go b/pkg/handler/dns/metadata.go new file mode 100644 index 0000000..c077d7d --- /dev/null +++ b/pkg/handler/dns/metadata.go @@ -0,0 +1,43 @@ +package dns + +import ( + "time" + + mdata "github.com/go-gost/gost/pkg/metadata" +) + +type metadata struct { + readTimeout time.Duration + retryCount int + ttl time.Duration + timeout time.Duration + prefer string + clientIP string + // nameservers + servers []string +} + +func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) { + const ( + readTimeout = "readTimeout" + retryCount = "retry" + ttl = "ttl" + timeout = "timeout" + prefer = "prefer" + clientIP = "clientIP" + servers = "servers" + ) + + h.md.readTimeout = mdata.GetDuration(md, readTimeout) + h.md.retryCount = mdata.GetInt(md, retryCount) + h.md.ttl = mdata.GetDuration(md, ttl) + h.md.timeout = mdata.GetDuration(md, timeout) + if h.md.timeout <= 0 { + h.md.timeout = 5 * time.Second + } + h.md.prefer = mdata.GetString(md, prefer) + h.md.clientIP = mdata.GetString(md, clientIP) + h.md.servers = mdata.GetStrings(md, servers) + + return +} diff --git a/pkg/handler/tap/conn.go b/pkg/handler/tap/conn.go deleted file mode 100644 index 19ace7d..0000000 --- a/pkg/handler/tap/conn.go +++ /dev/null @@ -1,17 +0,0 @@ -package tap - -import "net" - -type packetConn struct { - net.Conn -} - -func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { - n, err = c.Read(b) - addr = c.Conn.RemoteAddr() - return -} - -func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - return c.Write(b) -} diff --git a/pkg/handler/tap/handler.go b/pkg/handler/tap/handler.go index d59474c..58667da 100644 --- a/pkg/handler/tap/handler.go +++ b/pkg/handler/tap/handler.go @@ -2,6 +2,7 @@ package tap import ( "context" + "errors" "fmt" "io" "net" @@ -122,7 +123,12 @@ func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add if err != nil { return err } - pc = &packetConn{cc} + + var ok bool + pc, ok = cc.(net.PacketConn) + if !ok { + return errors.New("invalid connection") + } } else { if h.md.tcpMode { if addr != nil { diff --git a/pkg/handler/tun/conn.go b/pkg/handler/tun/conn.go deleted file mode 100644 index c16bd7f..0000000 --- a/pkg/handler/tun/conn.go +++ /dev/null @@ -1,17 +0,0 @@ -package tun - -import "net" - -type packetConn struct { - net.Conn -} - -func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { - n, err = c.Read(b) - addr = c.Conn.RemoteAddr() - return -} - -func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - return c.Write(b) -} diff --git a/pkg/handler/tun/handler.go b/pkg/handler/tun/handler.go index f84debb..ba7d04e 100644 --- a/pkg/handler/tun/handler.go +++ b/pkg/handler/tun/handler.go @@ -2,6 +2,7 @@ package tun import ( "context" + "errors" "fmt" "io" "net" @@ -124,7 +125,12 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add if err != nil { return err } - pc = &packetConn{cc} + + var ok bool + pc, ok = cc.(net.PacketConn) + if !ok { + return errors.New("invalid connnection") + } } else { if h.md.tcpMode { if addr != nil { diff --git a/pkg/listener/dns/server.go b/pkg/listener/dns/server.go index 7b62891..5717438 100644 --- a/pkg/listener/dns/server.go +++ b/pkg/listener/dns/server.go @@ -60,8 +60,9 @@ func (c *serverConn) Read(b []byte) (n int, err error) { case <-c.closed: err = io.ErrClosedPipe return + default: + return c.r.Read(b) } - return c.r.Read(b) } func (c *serverConn) Write(b []byte) (n int, err error) { @@ -69,8 +70,9 @@ func (c *serverConn) Write(b []byte) (n int, err error) { case <-c.closed: err = io.ErrClosedPipe return + default: + return c.w.Write(b) } - return c.w.Write(b) } func (c *serverConn) Close() error { diff --git a/pkg/resolver/exchanger/exchanger.go b/pkg/resolver/exchanger/exchanger.go new file mode 100644 index 0000000..feba5e0 --- /dev/null +++ b/pkg/resolver/exchanger/exchanger.go @@ -0,0 +1,213 @@ +package exchanger + +import ( + "bytes" + "context" + "crypto/tls" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/logger" + "github.com/miekg/dns" +) + +type Options struct { + chain *chain.Chain + tlsConfig *tls.Config + timeout time.Duration + logger logger.Logger +} + +// Option allows a common way to set Exchanger options. +type Option func(opts *Options) + +// ChainOption sets the chain for Exchanger. +func ChainOption(chain *chain.Chain) Option { + return func(opts *Options) { + opts.chain = chain + } +} + +// TLSConfigOption sets the TLS config for Exchanger. +func TLSConfigOption(cfg *tls.Config) Option { + return func(opts *Options) { + opts.tlsConfig = cfg + } +} + +// LoggerOption sets the logger for Exchanger. +func LoggerOption(logger logger.Logger) Option { + return func(opts *Options) { + opts.logger = logger + } +} + +// TimeoutOption sets the timeout for Exchanger. +func TimeoutOption(timeout time.Duration) Option { + return func(opts *Options) { + opts.timeout = timeout + } +} + +// Exchanger is an interface for DNS synchronous query. +type Exchanger interface { + Exchange(ctx context.Context, msg []byte) ([]byte, error) + String() string +} + +type exchanger struct { + network string + addr string + rawAddr string + router *chain.Router + client *http.Client + options Options +} + +// NewExchanger create an Exchanger. +func NewExchanger(addr string, opts ...Option) (Exchanger, error) { + var options Options + for _, opt := range opts { + opt(&options) + } + + if !strings.Contains(addr, "://") { + addr = "udp://" + addr + } + u, err := url.Parse(addr) + if err != nil { + return nil, err + } + + ex := &exchanger{ + network: u.Scheme, + addr: u.Host, + rawAddr: addr, + options: options, + } + ex.router = (&chain.Router{}). + WithChain(options.chain). + WithLogger(options.logger) + if _, port, _ := net.SplitHostPort(ex.addr); port == "" { + ex.addr = net.JoinHostPort(ex.addr, "53") + } + + switch ex.network { + case "tcp": + case "dot", "tls": + if ex.options.tlsConfig == nil { + ex.options.tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + ex.network = "tcp" + case "doh": + ex.addr = addr + if ex.options.tlsConfig == nil { + ex.options.tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + ex.client = &http.Client{ + Timeout: options.timeout, + Transport: &http.Transport{ + TLSClientConfig: options.tlsConfig, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: options.timeout, + ExpectContinueTimeout: 1 * time.Second, + DialContext: ex.dial, + }, + } + default: + ex.network = "udp" + } + + return ex, nil +} + +func (ex *exchanger) Exchange(ctx context.Context, msg []byte) ([]byte, error) { + if ex.network == "doh" { + return ex.dohExchange(ctx, msg) + } + return ex.exchange(ctx, msg) +} + +func (ex *exchanger) dohExchange(ctx context.Context, msg []byte) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, "POST", ex.addr, bytes.NewBuffer(msg)) + if err != nil { + return nil, fmt.Errorf("failed to create an HTTPS request: %w", err) + } + + // req.Header.Add("Content-Type", "application/dns-udpwireformat") + req.Header.Add("Content-Type", "application/dns-message") + + client := ex.client + if client == nil { + client = http.DefaultClient + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to perform an HTTPS request: %w", err) + } + + // Check response status code + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("returned status code %d", resp.StatusCode) + } + + // Read wireformat response from the body + buf, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read the response body: %w", err) + } + + return buf, nil +} + +func (ex *exchanger) exchange(ctx context.Context, msg []byte) ([]byte, error) { + if ex.options.timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, ex.options.timeout) + defer cancel() + } + + c, err := ex.dial(ctx, ex.network, ex.addr) + if err != nil { + return nil, err + } + defer c.Close() + + if ex.options.tlsConfig != nil { + c = tls.Client(c, ex.options.tlsConfig) + } + + conn := &dns.Conn{Conn: c} + + if _, err = conn.Write(msg); err != nil { + return nil, err + } + + mr, err := conn.ReadMsg() + if err != nil { + return nil, err + } + + return mr.Pack() +} + +func (ex *exchanger) dial(ctx context.Context, network, address string) (net.Conn, error) { + return ex.router.Dial(ctx, network, address) +} + +func (ex *exchanger) String() string { + return ex.rawAddr +} diff --git a/pkg/resolver/ns.go b/pkg/resolver/ns.go new file mode 100644 index 0000000..4e1c8a8 --- /dev/null +++ b/pkg/resolver/ns.go @@ -0,0 +1,13 @@ +package resolver + +import ( + "time" +) + +type NameServer struct { + Addr string + Protocol string + Hostname string // for TLS handshake verification + Exchanger Exchanger + Timeout time.Duration +} diff --git a/pkg/resolver/resolver.go b/pkg/resolver/resolver.go new file mode 100644 index 0000000..6b6108f --- /dev/null +++ b/pkg/resolver/resolver.go @@ -0,0 +1,11 @@ +package resolver + +import ( + "context" + "net" +) + +type Resolver interface { + // Resolve returns a slice of the host's IPv4 and IPv6 addresses. + Resolve(ctx context.Context, host string) ([]net.IP, error) +}