From a430384bba32dbd4c6cde3aaba63c5bb574a7eb7 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Thu, 30 Dec 2021 23:07:05 +0800 Subject: [PATCH] add resolver for router --- pkg/chain/router.go | 45 +++++++- pkg/handler/http/handler.go | 21 ++-- pkg/handler/http/udp.go | 7 +- pkg/handler/http2/handler.go | 27 ++--- pkg/handler/option.go | 12 ++- pkg/handler/relay/connect.go | 7 +- pkg/handler/relay/forward.go | 8 +- pkg/handler/relay/handler.go | 15 ++- pkg/handler/sni/handler.go | 15 +-- pkg/handler/socks/v4/handler.go | 21 ++-- pkg/handler/socks/v5/connect.go | 7 +- pkg/handler/socks/v5/handler.go | 9 +- pkg/handler/socks/v5/udp.go | 7 +- pkg/handler/socks/v5/udp_tun.go | 7 +- pkg/handler/ss/handler.go | 21 ++-- pkg/handler/ss/udp/handler.go | 21 ++-- pkg/handler/tap/handler.go | 40 ++++---- pkg/handler/tap/metadata.go | 3 - pkg/handler/tun/handler.go | 39 +++---- pkg/handler/tun/metadata.go | 3 - pkg/resolver/impl/resolver.go | 176 ++++++++++++++++++++++++++++++++ pkg/resolver/resolver.go | 169 ------------------------------ 22 files changed, 362 insertions(+), 318 deletions(-) create mode 100644 pkg/resolver/impl/resolver.go diff --git a/pkg/chain/router.go b/pkg/chain/router.go index af0a4f8..d35c153 100644 --- a/pkg/chain/router.go +++ b/pkg/chain/router.go @@ -3,17 +3,20 @@ package chain import ( "bytes" "context" + "errors" "fmt" "net" "github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/logger" + "github.com/go-gost/gost/pkg/resolver" ) type Router struct { - chain *Chain - retries int - logger logger.Logger + retries int + chain *Chain + resolver resolver.Resolver + logger logger.Logger } func (r *Router) WithChain(chain *Chain) *Router { @@ -21,6 +24,11 @@ func (r *Router) WithChain(chain *Chain) *Router { return r } +func (r *Router) WithResolver(resolver resolver.Resolver) *Router { + r.resolver = resolver + return r +} + func (r *Router) WithRetry(retries int) *Router { r.retries = retries return r @@ -63,6 +71,12 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co r.logger.Debugf("route(retry=%d) %s", i, buf.String()) } + address, err = r.resolve(ctx, address) + if err != nil { + r.logger.Error(err) + break + } + conn, err = route.Dial(ctx, network, address) if err == nil { break @@ -73,6 +87,31 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co return } +func (r *Router) resolve(ctx context.Context, addr string) (string, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return "", err + } + + /* + if ip := hosts.Lookup(host); ip != nil { + return net.JoinHostPort(ip.String(), port) + } + */ + + if r.resolver != nil { + ips, err := r.resolver.Resolve(ctx, host) + if err != nil { + r.logger.Error(err) + } + if len(ips) == 0 { + return "", errors.New("domain not exists") + } + return net.JoinHostPort(ips[0].String(), port), nil + } + return addr, nil +} + func (r *Router) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (ln net.Listener, err error) { count := r.retries + 1 if count <= 0 { diff --git a/pkg/handler/http/handler.go b/pkg/handler/http/handler.go index ab3dfe6..d28f55f 100644 --- a/pkg/handler/http/handler.go +++ b/pkg/handler/http/handler.go @@ -29,8 +29,8 @@ func init() { } type httpHandler struct { - chain *chain.Chain bypass bypass.Bypass + router *chain.Router logger logger.Logger md metadata } @@ -43,17 +43,26 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &httpHandler{ bypass: options.Bypass, + router: (&chain.Router{}). + WithLogger(options.Logger). + WithResolver(options.Resolver), logger: options.Logger, } } func (h *httpHandler) Init(md md.Metadata) error { - return h.parseMetadata(md) + if err := h.parseMetadata(md); err != nil { + return err + } + + h.router.WithRetry(h.md.retryCount) + + return nil } // implements chain.Chainable interface func (h *httpHandler) WithChain(chain *chain.Chain) { - h.chain = chain + h.router.WithChain(chain) } func (h *httpHandler) Handle(ctx context.Context, conn net.Conn) { @@ -192,11 +201,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt req.Header.Del("Proxy-Authorization") - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - cc, err := r.Dial(ctx, network, addr) + cc, err := h.router.Dial(ctx, network, addr) if err != nil { resp.StatusCode = http.StatusServiceUnavailable resp.Write(conn) diff --git a/pkg/handler/http/udp.go b/pkg/handler/http/udp.go index 12a7d8b..582c750 100644 --- a/pkg/handler/http/udp.go +++ b/pkg/handler/http/udp.go @@ -7,7 +7,6 @@ import ( "net/http/httputil" "time" - "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/logger" @@ -51,11 +50,7 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, add } // obtain a udp connection - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - c, err := r.Dial(ctx, "udp", "") // UDP association + c, err := h.router.Dial(ctx, "udp", "") // UDP association if err != nil { h.logger.Error(err) return diff --git a/pkg/handler/http2/handler.go b/pkg/handler/http2/handler.go index 1a4deac..ca407b6 100644 --- a/pkg/handler/http2/handler.go +++ b/pkg/handler/http2/handler.go @@ -29,8 +29,8 @@ func init() { } type http2Handler struct { - chain *chain.Chain bypass bypass.Bypass + router *chain.Router logger logger.Logger md metadata } @@ -43,17 +43,26 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &http2Handler{ bypass: options.Bypass, + router: (&chain.Router{}). + WithLogger(options.Logger). + WithResolver(options.Resolver), logger: options.Logger, } } func (h *http2Handler) Init(md md.Metadata) error { - return h.parseMetadata(md) + if err := h.parseMetadata(md); err != nil { + return err + } + + h.router.WithRetry(h.md.retryCount) + + return nil } // implements chain.Chainable interface func (h *http2Handler) WithChain(chain *chain.Chain) { - h.chain = chain + h.router.WithChain(chain) } func (h *http2Handler) Handle(ctx context.Context, conn net.Conn) { @@ -154,11 +163,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req req.Header.Del("Proxy-Authorization") req.Header.Del("Proxy-Connection") - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - cc, err := r.Dial(ctx, "tcp", addr) + cc, err := h.router.Dial(ctx, "tcp", addr) if err != nil { h.logger.Error(err) w.WriteHeader(http.StatusServiceUnavailable) @@ -312,11 +317,7 @@ func (h *http2Handler) handleRequest(ctx context.Context, conn net.Conn, req *ht req.Header.Del("Proxy-Authorization") - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - cc, err := r.Dial(ctx, network, addr) + cc, err := h.router.Dial(ctx, network, addr) if err != nil { resp.StatusCode = http.StatusServiceUnavailable resp.Write(conn) diff --git a/pkg/handler/option.go b/pkg/handler/option.go index b65443b..4d12715 100644 --- a/pkg/handler/option.go +++ b/pkg/handler/option.go @@ -3,11 +3,13 @@ package handler import ( "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/logger" + "github.com/go-gost/gost/pkg/resolver" ) type Options struct { - Bypass bypass.Bypass - Logger logger.Logger + Bypass bypass.Bypass + Resolver resolver.Resolver + Logger logger.Logger } type Option func(opts *Options) @@ -23,3 +25,9 @@ func BypassOption(bypass bypass.Bypass) Option { opts.Bypass = bypass } } + +func ResolverOption(resolver resolver.Resolver) Option { + return func(opts *Options) { + opts.Resolver = resolver + } +} diff --git a/pkg/handler/relay/connect.go b/pkg/handler/relay/connect.go index 648cadb..cf97351 100644 --- a/pkg/handler/relay/connect.go +++ b/pkg/handler/relay/connect.go @@ -6,7 +6,6 @@ import ( "net" "time" - "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/relay" ) @@ -38,11 +37,7 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network return } - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - cc, err := r.Dial(ctx, network, address) + cc, err := h.router.Dial(ctx, network, address) if err != nil { resp.Status = relay.StatusNetworkUnreachable resp.WriteTo(conn) diff --git a/pkg/handler/relay/forward.go b/pkg/handler/relay/forward.go index 96389a5..aa44e84 100644 --- a/pkg/handler/relay/forward.go +++ b/pkg/handler/relay/forward.go @@ -6,7 +6,6 @@ import ( "net" "time" - "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/relay" ) @@ -30,12 +29,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - - cc, err := r.Dial(ctx, network, target.Addr()) + cc, err := h.router.Dial(ctx, network, target.Addr()) if err != nil { // TODO: the router itself may be failed due to the failed node in the router, // the dead marker may be a wrong operation. diff --git a/pkg/handler/relay/handler.go b/pkg/handler/relay/handler.go index 03dac36..751711b 100644 --- a/pkg/handler/relay/handler.go +++ b/pkg/handler/relay/handler.go @@ -21,8 +21,8 @@ func init() { type relayHandler struct { group *chain.NodeGroup - chain *chain.Chain bypass bypass.Bypass + router *chain.Router logger logger.Logger md metadata } @@ -35,17 +35,26 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &relayHandler{ bypass: options.Bypass, + router: (&chain.Router{}). + WithLogger(options.Logger). + WithResolver(options.Resolver), logger: options.Logger, } } func (h *relayHandler) Init(md md.Metadata) (err error) { - return h.parseMetadata(md) + if err := h.parseMetadata(md); err != nil { + return err + } + + h.router.WithRetry(h.md.retryCount) + + return nil } // WithChain implements chain.Chainable interface func (h *relayHandler) WithChain(chain *chain.Chain) { - h.chain = chain + h.router.WithChain(chain) } // Forward implements handler.Forwarder. diff --git a/pkg/handler/sni/handler.go b/pkg/handler/sni/handler.go index 2421b12..8c1e64d 100644 --- a/pkg/handler/sni/handler.go +++ b/pkg/handler/sni/handler.go @@ -27,8 +27,8 @@ func init() { type sniHandler struct { httpHandler handler.Handler - chain *chain.Chain bypass bypass.Bypass + router *chain.Router logger logger.Logger md metadata } @@ -46,6 +46,9 @@ func NewHandler(opts ...handler.Option) handler.Handler { h := &sniHandler{ bypass: options.Bypass, + router: (&chain.Router{}). + WithLogger(options.Logger). + WithResolver(options.Resolver), logger: log, } @@ -71,12 +74,14 @@ func (h *sniHandler) Init(md md.Metadata) (err error) { } } + h.router.WithRetry(h.md.retryCount) + return nil } // WithChain implements chain.Chainable interface func (h *sniHandler) WithChain(chain *chain.Chain) { - h.chain = chain + h.router.WithChain(chain) } func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) { @@ -141,11 +146,7 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) { return } - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - cc, err := r.Dial(ctx, "tcp", target) + cc, err := h.router.Dial(ctx, "tcp", target) if err != nil { return } diff --git a/pkg/handler/socks/v4/handler.go b/pkg/handler/socks/v4/handler.go index 53eca12..72c3b01 100644 --- a/pkg/handler/socks/v4/handler.go +++ b/pkg/handler/socks/v4/handler.go @@ -20,8 +20,8 @@ func init() { } type socks4Handler struct { - chain *chain.Chain bypass bypass.Bypass + router *chain.Router logger logger.Logger md metadata } @@ -34,17 +34,26 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &socks4Handler{ bypass: options.Bypass, + router: (&chain.Router{}). + WithLogger(options.Logger). + WithResolver(options.Resolver), logger: options.Logger, } } func (h *socks4Handler) Init(md md.Metadata) (err error) { - return h.parseMetadata(md) + if err := h.parseMetadata(md); err != nil { + return err + } + + h.router.WithRetry(h.md.retryCount) + + return nil } // implements chain.Chainable interface func (h *socks4Handler) WithChain(chain *chain.Chain) { - h.chain = chain + h.router.WithChain(chain) } func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) { @@ -111,11 +120,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g return } - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - cc, err := r.Dial(ctx, "tcp", addr) + cc, err := h.router.Dial(ctx, "tcp", addr) if err != nil { resp := gosocks4.NewReply(gosocks4.Failed, nil) resp.Write(conn) diff --git a/pkg/handler/socks/v5/connect.go b/pkg/handler/socks/v5/connect.go index fce9572..6792358 100644 --- a/pkg/handler/socks/v5/connect.go +++ b/pkg/handler/socks/v5/connect.go @@ -7,7 +7,6 @@ import ( "time" "github.com/go-gost/gosocks5" - "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" ) @@ -26,11 +25,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ return } - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - cc, err := r.Dial(ctx, network, address) + cc, err := h.router.Dial(ctx, network, address) if err != nil { resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) resp.Write(conn) diff --git a/pkg/handler/socks/v5/handler.go b/pkg/handler/socks/v5/handler.go index d1d4dd8..fb4a35e 100644 --- a/pkg/handler/socks/v5/handler.go +++ b/pkg/handler/socks/v5/handler.go @@ -22,8 +22,8 @@ func init() { type socks5Handler struct { selector gosocks5.Selector - chain *chain.Chain bypass bypass.Bypass + router *chain.Router logger logger.Logger md metadata } @@ -36,6 +36,9 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &socks5Handler{ bypass: options.Bypass, + router: (&chain.Router{}). + WithLogger(options.Logger). + WithResolver(options.Resolver), logger: options.Logger, } } @@ -52,12 +55,14 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) { noTLS: h.md.noTLS, } + h.router.WithRetry(h.md.retryCount) + return } // implements chain.Chainable interface func (h *socks5Handler) WithChain(chain *chain.Chain) { - h.chain = chain + h.router.WithChain(chain) } func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) { diff --git a/pkg/handler/socks/v5/udp.go b/pkg/handler/socks/v5/udp.go index 3a155c6..59564ae 100644 --- a/pkg/handler/socks/v5/udp.go +++ b/pkg/handler/socks/v5/udp.go @@ -9,7 +9,6 @@ import ( "time" "github.com/go-gost/gosocks5" - "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/handler" ) @@ -54,11 +53,7 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) { h.logger.Debugf("bind on %s OK", cc.LocalAddr()) // obtain a udp connection - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - c, err := r.Dial(ctx, "udp", "") // UDP association + c, err := h.router.Dial(ctx, "udp", "") // UDP association if err != nil { h.logger.Error(err) return diff --git a/pkg/handler/socks/v5/udp_tun.go b/pkg/handler/socks/v5/udp_tun.go index 01ba438..cfaca1f 100644 --- a/pkg/handler/socks/v5/udp_tun.go +++ b/pkg/handler/socks/v5/udp_tun.go @@ -6,7 +6,6 @@ import ( "time" "github.com/go-gost/gosocks5" - "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/handler" ) @@ -33,11 +32,7 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network h.logger.Debug(reply) // obtain a udp connection - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - c, err := r.Dial(ctx, "udp", "") // UDP association + c, err := h.router.Dial(ctx, "udp", "") // UDP association if err != nil { h.logger.Error(err) return diff --git a/pkg/handler/ss/handler.go b/pkg/handler/ss/handler.go index 85d688b..9fde568 100644 --- a/pkg/handler/ss/handler.go +++ b/pkg/handler/ss/handler.go @@ -22,8 +22,8 @@ func init() { } type ssHandler struct { - chain *chain.Chain bypass bypass.Bypass + router *chain.Router logger logger.Logger md metadata } @@ -36,17 +36,26 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &ssHandler{ bypass: options.Bypass, + router: (&chain.Router{}). + WithLogger(options.Logger). + WithResolver(options.Resolver), logger: options.Logger, } } func (h *ssHandler) Init(md md.Metadata) (err error) { - return h.parseMetadata(md) + if err := h.parseMetadata(md); err != nil { + return err + } + + h.router.WithRetry(h.md.retryCount) + + return nil } // implements chain.Chainable interface func (h *ssHandler) WithChain(chain *chain.Chain) { - h.chain = chain + h.router.WithChain(chain) } func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { @@ -91,11 +100,7 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { return } - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - cc, err := r.Dial(ctx, "tcp", addr.String()) + cc, err := h.router.Dial(ctx, "tcp", addr.String()) if err != nil { return } diff --git a/pkg/handler/ss/udp/handler.go b/pkg/handler/ss/udp/handler.go index bd736b5..79ac23d 100644 --- a/pkg/handler/ss/udp/handler.go +++ b/pkg/handler/ss/udp/handler.go @@ -21,8 +21,8 @@ func init() { } type ssuHandler struct { - chain *chain.Chain bypass bypass.Bypass + router *chain.Router logger logger.Logger md metadata } @@ -35,17 +35,26 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &ssuHandler{ bypass: options.Bypass, + router: (&chain.Router{}). + WithLogger(options.Logger). + WithResolver(options.Resolver), logger: options.Logger, } } func (h *ssuHandler) Init(md md.Metadata) (err error) { - return h.parseMetadata(md) + if err := h.parseMetadata(md); err != nil { + return err + } + + h.router.WithRetry(h.md.retryCount) + + return nil } // WithChain implements chain.Chainable interface func (h *ssuHandler) WithChain(chain *chain.Chain) { - h.chain = chain + h.router.WithChain(chain) } func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { @@ -80,11 +89,7 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { } // obtain a udp connection - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - c, err := r.Dial(ctx, "udp", "") // UDP association + c, err := h.router.Dial(ctx, "udp", "") // UDP association if err != nil { h.logger.Error(err) return diff --git a/pkg/handler/tap/handler.go b/pkg/handler/tap/handler.go index 58667da..ebc24f6 100644 --- a/pkg/handler/tap/handler.go +++ b/pkg/handler/tap/handler.go @@ -20,7 +20,6 @@ import ( "github.com/go-gost/gost/pkg/registry" "github.com/shadowsocks/go-shadowsocks2/shadowaead" "github.com/songgao/water/waterutil" - "github.com/xtaci/tcpraw" ) func init() { @@ -29,10 +28,10 @@ func init() { type tapHandler struct { group *chain.NodeGroup - chain *chain.Chain bypass bypass.Bypass routes sync.Map exit chan struct{} + router *chain.Router logger logger.Logger md metadata } @@ -45,18 +44,27 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &tapHandler{ bypass: options.Bypass, - exit: make(chan struct{}, 1), + router: (&chain.Router{}). + WithLogger(options.Logger). + WithResolver(options.Resolver), logger: options.Logger, + exit: make(chan struct{}, 1), } } func (h *tapHandler) Init(md md.Metadata) (err error) { - return h.parseMetadata(md) + if err := h.parseMetadata(md); err != nil { + return err + } + + h.router.WithRetry(h.md.retryCount) + + return nil } // implements chain.Chainable interface func (h *tapHandler) WithChain(chain *chain.Chain) { - h.chain = chain + h.router.WithChain(chain) } // Forward implements handler.Forwarder. @@ -113,13 +121,9 @@ func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add err := func() error { var err error var pc net.PacketConn - // fake tcp mode will be ignored when the client specifies a chain. - if addr != nil && !h.chain.IsEmpty() { - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - cc, err := r.Dial(ctx, addr.Network(), addr.String()) + + if addr != nil { + cc, err := h.router.Dial(ctx, addr.Network(), addr.String()) if err != nil { return err } @@ -130,16 +134,8 @@ func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add return errors.New("invalid connection") } } else { - if h.md.tcpMode { - if addr != nil { - pc, err = tcpraw.Dial("tcp", addr.String()) - } else { - pc, err = tcpraw.Listen("tcp", conn.LocalAddr().String()) - } - } else { - laddr, _ := net.ResolveUDPAddr("udp", conn.LocalAddr().String()) - pc, err = net.ListenUDP("udp", laddr) - } + laddr, _ := net.ResolveUDPAddr("udp", conn.LocalAddr().String()) + pc, err = net.ListenUDP("udp", laddr) } if err != nil { return err diff --git a/pkg/handler/tap/metadata.go b/pkg/handler/tap/metadata.go index 5a97380..a7ebd70 100644 --- a/pkg/handler/tap/metadata.go +++ b/pkg/handler/tap/metadata.go @@ -11,7 +11,6 @@ import ( type metadata struct { cipher core.Cipher retryCount int - tcpMode bool bufferSize int } @@ -21,7 +20,6 @@ func (h *tapHandler) parseMetadata(md mdata.Metadata) (err error) { key = "key" readTimeout = "readTimeout" retryCount = "retry" - tcpMode = "tcp" bufferSize = "bufferSize" ) @@ -40,7 +38,6 @@ func (h *tapHandler) parseMetadata(md mdata.Metadata) (err error) { return } h.md.retryCount = mdata.GetInt(md, retryCount) - h.md.tcpMode = mdata.GetBool(md, tcpMode) h.md.bufferSize = mdata.GetInt(md, bufferSize) if h.md.bufferSize <= 0 { diff --git a/pkg/handler/tun/handler.go b/pkg/handler/tun/handler.go index ba7d04e..60bbf24 100644 --- a/pkg/handler/tun/handler.go +++ b/pkg/handler/tun/handler.go @@ -20,7 +20,6 @@ import ( "github.com/go-gost/gost/pkg/registry" "github.com/shadowsocks/go-shadowsocks2/shadowaead" "github.com/songgao/water/waterutil" - "github.com/xtaci/tcpraw" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) @@ -31,10 +30,10 @@ func init() { type tunHandler struct { group *chain.NodeGroup - chain *chain.Chain bypass bypass.Bypass routes sync.Map exit chan struct{} + router *chain.Router logger logger.Logger md metadata } @@ -47,18 +46,27 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &tunHandler{ bypass: options.Bypass, - exit: make(chan struct{}, 1), + router: (&chain.Router{}). + WithLogger(options.Logger). + WithResolver(options.Resolver), logger: options.Logger, + exit: make(chan struct{}, 1), } } func (h *tunHandler) Init(md md.Metadata) (err error) { - return h.parseMetadata(md) + if err := h.parseMetadata(md); err != nil { + return err + } + + h.router.WithRetry(h.md.retryCount) + + return nil } // implements chain.Chainable interface func (h *tunHandler) WithChain(chain *chain.Chain) { - h.chain = chain + h.router.WithChain(chain) } // Forward implements handler.Forwarder. @@ -115,13 +123,8 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add err := func() error { var err error var pc net.PacketConn - // fake tcp mode will be ignored when the client specifies a chain. - if addr != nil && !h.chain.IsEmpty() { - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - cc, err := r.Dial(ctx, addr.Network(), addr.String()) + if addr != nil { + cc, err := h.router.Dial(ctx, addr.Network(), addr.String()) if err != nil { return err } @@ -132,16 +135,8 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add return errors.New("invalid connnection") } } else { - if h.md.tcpMode { - if addr != nil { - pc, err = tcpraw.Dial("tcp", addr.String()) - } else { - pc, err = tcpraw.Listen("tcp", conn.LocalAddr().String()) - } - } else { - laddr, _ := net.ResolveUDPAddr("udp", conn.LocalAddr().String()) - pc, err = net.ListenUDP("udp", laddr) - } + laddr, _ := net.ResolveUDPAddr("udp", conn.LocalAddr().String()) + pc, err = net.ListenUDP("udp", laddr) } if err != nil { return err diff --git a/pkg/handler/tun/metadata.go b/pkg/handler/tun/metadata.go index 2d8437c..e6ac54a 100644 --- a/pkg/handler/tun/metadata.go +++ b/pkg/handler/tun/metadata.go @@ -11,7 +11,6 @@ import ( type metadata struct { cipher core.Cipher retryCount int - tcpMode bool bufferSize int } @@ -21,7 +20,6 @@ func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) { key = "key" readTimeout = "readTimeout" retryCount = "retry" - tcpMode = "tcp" bufferSize = "bufferSize" ) @@ -40,7 +38,6 @@ func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) { return } h.md.retryCount = mdata.GetInt(md, retryCount) - h.md.tcpMode = mdata.GetBool(md, tcpMode) h.md.bufferSize = mdata.GetInt(md, bufferSize) if h.md.bufferSize <= 0 { diff --git a/pkg/resolver/impl/resolver.go b/pkg/resolver/impl/resolver.go new file mode 100644 index 0000000..5458cf3 --- /dev/null +++ b/pkg/resolver/impl/resolver.go @@ -0,0 +1,176 @@ +package impl + +import ( + "context" + "net" + "strings" + "time" + + "github.com/go-gost/gost/pkg/chain" + resolver_util "github.com/go-gost/gost/pkg/internal/util/resolver" + "github.com/go-gost/gost/pkg/logger" + resolverpkg "github.com/go-gost/gost/pkg/resolver" + "github.com/go-gost/gost/pkg/resolver/exchanger" + "github.com/miekg/dns" +) + +type NameServer struct { + Addr string + Chain *chain.Chain + TTL time.Duration + Timeout time.Duration + ClientIP net.IP + Prefer string + Hostname string // for TLS handshake verification + exchanger exchanger.Exchanger +} + +type resolverOptions struct { + domain string + logger logger.Logger +} + +type ResolverOption func(opts *resolverOptions) + +func DomainResolverOption(domain string) ResolverOption { + return func(opts *resolverOptions) { + opts.domain = domain + } +} + +func LoggerResolverOption(logger logger.Logger) ResolverOption { + return func(opts *resolverOptions) { + opts.logger = logger + } +} + +type resolver struct { + servers []NameServer + cache *resolver_util.Cache + options resolverOptions + logger logger.Logger +} + +func NewResolver(nameservers []NameServer, opts ...ResolverOption) (resolverpkg.Resolver, error) { + options := resolverOptions{} + for _, opt := range opts { + opt(&options) + } + + var servers []NameServer + for _, server := range nameservers { + addr := strings.TrimSpace(server.Addr) + if addr == "" { + continue + } + ex, err := exchanger.NewExchanger( + addr, + exchanger.ChainOption(server.Chain), + exchanger.TimeoutOption(server.Timeout), + exchanger.LoggerOption(options.logger), + ) + if err != nil { + options.logger.Warnf("parse %s: %v", server, err) + continue + } + + server.exchanger = ex + servers = append(servers, server) + } + cache := resolver_util.NewCache(). + WithLogger(options.logger) + + return &resolver{ + servers: servers, + cache: cache, + options: options, + logger: options.logger, + }, nil +} + +func (r *resolver) Resolve(ctx context.Context, host string) (ips []net.IP, err error) { + if ip := net.ParseIP(host); ip != nil { + return []net.IP{ip}, nil + } + + if r.options.domain != "" && + !strings.Contains(host, ".") { + host = host + "." + r.options.domain + } + + for _, server := range r.servers { + ips, err = r.resolve(ctx, &server, host) + if err != nil { + r.logger.Error(err) + continue + } + + r.logger.Debugf("resolve %s via %s: %v", host, server.exchanger.String(), ips) + + if len(ips) > 0 { + break + } + } + + return +} + +func (r *resolver) resolve(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) { + if server == nil { + return + } + + if server.Prefer == "ipv6" { // prefer ipv6 + mq := dns.Msg{} + mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) + ips, err = r.resolveIPs(ctx, server, &mq) + if err != nil || len(ips) > 0 { + return + } + } + + // fallback to ipv4 + mq := dns.Msg{} + mq.SetQuestion(dns.Fqdn(host), dns.TypeA) + return r.resolveIPs(ctx, server, &mq) +} + +func (r *resolver) resolveIPs(ctx context.Context, server *NameServer, mq *dns.Msg) (ips []net.IP, err error) { + key := resolver_util.NewCacheKey(&mq.Question[0]) + mr := r.cache.Load(key) + if mr == nil { + resolver_util.AddSubnetOpt(mq, server.ClientIP) + mr, err = r.exchange(ctx, server.exchanger, mq) + if err != nil { + return + } + r.cache.Store(key, mr, server.TTL) + } + + for _, ans := range mr.Answer { + if ar, _ := ans.(*dns.AAAA); ar != nil { + ips = append(ips, ar.AAAA) + } + if ar, _ := ans.(*dns.A); ar != nil { + ips = append(ips, ar.A) + } + } + + return +} + +func (r *resolver) exchange(ctx context.Context, ex exchanger.Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { + query, err := mq.Pack() + if err != nil { + return + } + reply, err := ex.Exchange(ctx, query) + if err != nil { + return + } + + mr = &dns.Msg{} + err = mr.Unpack(reply) + + return +} diff --git a/pkg/resolver/resolver.go b/pkg/resolver/resolver.go index 548868a..6b6108f 100644 --- a/pkg/resolver/resolver.go +++ b/pkg/resolver/resolver.go @@ -3,178 +3,9 @@ package resolver import ( "context" "net" - "strings" - "time" - - "github.com/go-gost/gost/pkg/chain" - resolver_util "github.com/go-gost/gost/pkg/internal/util/resolver" - "github.com/go-gost/gost/pkg/logger" - "github.com/go-gost/gost/pkg/resolver/exchanger" - "github.com/miekg/dns" ) type Resolver interface { // Resolve returns a slice of the host's IPv4 and IPv6 addresses. Resolve(ctx context.Context, host string) ([]net.IP, error) } - -type NameServer struct { - Addr string - Chain *chain.Chain - TTL time.Duration - Timeout time.Duration - ClientIP net.IP - Prefer string - Hostname string // for TLS handshake verification - exchanger exchanger.Exchanger -} - -type resolverOptions struct { - domain string - logger logger.Logger -} - -type ResolverOption func(opts *resolverOptions) - -func DomainResolverOption(domain string) ResolverOption { - return func(opts *resolverOptions) { - opts.domain = domain - } -} - -func LoggerResolverOption(logger logger.Logger) ResolverOption { - return func(opts *resolverOptions) { - opts.logger = logger - } -} - -type resolver struct { - servers []NameServer - cache *resolver_util.Cache - options resolverOptions - logger logger.Logger -} - -func NewResolver(nameservers []NameServer, opts ...ResolverOption) (Resolver, error) { - options := resolverOptions{} - for _, opt := range opts { - opt(&options) - } - - var servers []NameServer - for _, server := range nameservers { - addr := strings.TrimSpace(server.Addr) - if addr == "" { - continue - } - ex, err := exchanger.NewExchanger( - addr, - exchanger.ChainOption(server.Chain), - exchanger.TimeoutOption(server.Timeout), - exchanger.LoggerOption(options.logger), - ) - if err != nil { - options.logger.Warnf("parse %s: %v", server, err) - continue - } - - server.exchanger = ex - servers = append(servers, server) - } - cache := resolver_util.NewCache(). - WithLogger(options.logger) - - return &resolver{ - servers: servers, - cache: cache, - options: options, - logger: options.logger, - }, nil -} - -func (r *resolver) Resolve(ctx context.Context, host string) (ips []net.IP, err error) { - if ip := net.ParseIP(host); ip != nil { - return []net.IP{ip}, nil - } - - if r.options.domain != "" && - !strings.Contains(host, ".") { - host = host + "." + r.options.domain - } - - for _, server := range r.servers { - ips, err = r.resolve(ctx, &server, host) - if err != nil { - r.logger.Error(err) - continue - } - - r.logger.Debugf("resolve %s via %s: %v", host, server.exchanger.String(), ips) - - if len(ips) > 0 { - break - } - } - - return -} - -func (r *resolver) resolve(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) { - if server == nil { - return - } - - if server.Prefer == "ipv6" { // prefer ipv6 - mq := dns.Msg{} - mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) - ips, err = r.resolveIPs(ctx, server, &mq) - if err != nil || len(ips) > 0 { - return - } - } - - // fallback to ipv4 - mq := dns.Msg{} - mq.SetQuestion(dns.Fqdn(host), dns.TypeA) - return r.resolveIPs(ctx, server, &mq) -} - -func (r *resolver) resolveIPs(ctx context.Context, server *NameServer, mq *dns.Msg) (ips []net.IP, err error) { - key := resolver_util.NewCacheKey(&mq.Question[0]) - mr := r.cache.Load(key) - if mr == nil { - resolver_util.AddSubnetOpt(mq, server.ClientIP) - mr, err = r.exchange(ctx, server.exchanger, mq) - if err != nil { - return - } - r.cache.Store(key, mr, server.TTL) - } - - for _, ans := range mr.Answer { - if ar, _ := ans.(*dns.AAAA); ar != nil { - ips = append(ips, ar.AAAA) - } - if ar, _ := ans.(*dns.A); ar != nil { - ips = append(ips, ar.A) - } - } - - return -} - -func (r *resolver) exchange(ctx context.Context, ex exchanger.Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { - query, err := mq.Pack() - if err != nil { - return - } - reply, err := ex.Exchange(ctx, query) - if err != nil { - return - } - - mr = &dns.Msg{} - err = mr.Unpack(reply) - - return -}