diff --git a/cmd/gost/config.go b/cmd/gost/config.go index b024d02..6f47815 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -71,12 +71,12 @@ func buildService(cfg *config.Config) (services []*service.Service) { h := registry.GetHandler(svc.Handler.Type)( handler.BypassOption(bypasses[svc.Bypass]), handler.LoggerOption(handlerLogger), + handler.RouterOption(&chain.Router{ + Chain: chains[svc.Chain], + Logger: handlerLogger, + }), ) - if chainable, ok := h.(chain.Chainable); ok { - chainable.WithChain(chains[svc.Chain]) - } - if forwarder, ok := h.(handler.Forwarder); ok { forwarder.Forward(forwarderFromConfig(svc.Forwarder)) } diff --git a/pkg/chain/router.go b/pkg/chain/router.go index d35c153..693ca6f 100644 --- a/pkg/chain/router.go +++ b/pkg/chain/router.go @@ -13,30 +13,10 @@ import ( ) type Router struct { - retries int - chain *Chain - resolver resolver.Resolver - logger logger.Logger -} - -func (r *Router) WithChain(chain *Chain) *Router { - r.chain = chain - return r -} - -func (r *Router) WithResolver(resolver resolver.Resolver) *Router { - r.resolver = resolver - return r -} - -func (r *Router) WithRetry(retries int) *Router { - r.retries = retries - return r -} - -func (r *Router) WithLogger(logger logger.Logger) *Router { - r.logger = logger - return r + Retries int + Chain *Chain + Resolver resolver.Resolver + Logger logger.Logger } func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) { @@ -53,27 +33,27 @@ func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Co } func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { - count := r.retries + 1 + count := r.Retries + 1 if count <= 0 { count = 1 } - r.logger.Debugf("dial %s/%s", address, network) + r.Logger.Debugf("dial %s/%s", address, network) for i := 0; i < count; i++ { - route := r.chain.GetRouteFor(network, address) + route := r.Chain.GetRouteFor(network, address) - if r.logger.IsLevelEnabled(logger.DebugLevel) { + if r.Logger.IsLevelEnabled(logger.DebugLevel) { buf := bytes.Buffer{} for _, node := range route.Path() { fmt.Fprintf(&buf, "%s@%s > ", node.Name(), node.Addr()) } fmt.Fprintf(&buf, "%s", address) - r.logger.Debugf("route(retry=%d) %s", i, buf.String()) + r.Logger.Debugf("route(retry=%d) %s", i, buf.String()) } address, err = r.resolve(ctx, address) if err != nil { - r.logger.Error(err) + r.Logger.Error(err) break } @@ -81,15 +61,19 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co if err == nil { break } - r.logger.Errorf("route(retry=%d) %s", i, err) + r.Logger.Errorf("route(retry=%d) %s", i, err) } return } func (r *Router) resolve(ctx context.Context, addr string) (string, error) { + if addr == "" { + return addr, nil + } + host, port, err := net.SplitHostPort(addr) - if err != nil { + if err != nil || host == "" { return "", err } @@ -99,10 +83,10 @@ func (r *Router) resolve(ctx context.Context, addr string) (string, error) { } */ - if r.resolver != nil { - ips, err := r.resolver.Resolve(ctx, host) + if r.Resolver != nil { + ips, err := r.Resolver.Resolve(ctx, host) if err != nil { - r.logger.Error(err) + r.Logger.Error(err) } if len(ips) == 0 { return "", errors.New("domain not exists") @@ -113,29 +97,29 @@ func (r *Router) resolve(ctx context.Context, addr string) (string, error) { } func (r *Router) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (ln net.Listener, err error) { - count := r.retries + 1 + count := r.Retries + 1 if count <= 0 { count = 1 } - r.logger.Debugf("bind on %s/%s", address, network) + r.Logger.Debugf("bind on %s/%s", address, network) for i := 0; i < count; i++ { - route := r.chain.GetRouteFor(network, address) + route := r.Chain.GetRouteFor(network, address) - if r.logger.IsLevelEnabled(logger.DebugLevel) { + if r.Logger.IsLevelEnabled(logger.DebugLevel) { buf := bytes.Buffer{} for _, node := range route.Path() { fmt.Fprintf(&buf, "%s@%s > ", node.Name(), node.Addr()) } fmt.Fprintf(&buf, "%s", address) - r.logger.Debugf("route(retry=%d) %s", i, buf.String()) + r.Logger.Debugf("route(retry=%d) %s", i, buf.String()) } ln, err = route.Bind(ctx, network, address, opts...) if err == nil { break } - r.logger.Errorf("route(retry=%d) %s", i, err) + r.Logger.Errorf("route(retry=%d) %s", i, err) } return diff --git a/pkg/handler/auto/handler.go b/pkg/handler/auto/handler.go index 0280ce4..5b5b6a0 100644 --- a/pkg/handler/auto/handler.go +++ b/pkg/handler/auto/handler.go @@ -8,7 +8,6 @@ import ( "github.com/go-gost/gosocks4" "github.com/go-gost/gosocks5" - "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -21,7 +20,6 @@ func init() { } type autoHandler struct { - chain *chain.Chain httpHandler handler.Handler socks4Handler handler.Handler socks5Handler handler.Handler @@ -68,43 +66,23 @@ func NewHandler(opts ...handler.Option) handler.Handler { return h } -func (h *autoHandler) WithChain(chain *chain.Chain) { - h.chain = chain -} - func (h *autoHandler) Init(md md.Metadata) error { if h.httpHandler != nil { - if chainable, ok := h.httpHandler.(chain.Chainable); ok { - chainable.WithChain(h.chain) - } - if err := h.httpHandler.Init(md); err != nil { return err } } if h.socks4Handler != nil { - if chainable, ok := h.socks4Handler.(chain.Chainable); ok { - chainable.WithChain(h.chain) - } - if err := h.socks4Handler.Init(md); err != nil { return err } } if h.socks5Handler != nil { - if chainable, ok := h.socks5Handler.(chain.Chainable); ok { - chainable.WithChain(h.chain) - } - if err := h.socks5Handler.Init(md); err != nil { return err } } if h.relayHandler != nil { - if chainable, ok := h.relayHandler.(chain.Chainable); ok { - chainable.WithChain(h.chain) - } - if err := h.relayHandler.Init(md); err != nil { return err } diff --git a/pkg/handler/dns/handler.go b/pkg/handler/dns/handler.go index 632dd16..c823373 100644 --- a/pkg/handler/dns/handler.go +++ b/pkg/handler/dns/handler.go @@ -21,13 +21,17 @@ import ( "github.com/miekg/dns" ) +const ( + defaultNameserver = "udp://127.0.0.1:53" +) + func init() { registry.RegisterHandler("dns", NewHandler) } type dnsHandler struct { - chain *chain.Chain bypass bypass.Bypass + router *chain.Router exchangers []exchanger.Exchanger cache *resolver_util.Cache logger logger.Logger @@ -44,6 +48,7 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &dnsHandler{ bypass: options.Bypass, + router: options.Router, cache: cache, logger: options.Logger, } @@ -61,7 +66,7 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { } ex, err := exchanger.NewExchanger( server, - exchanger.ChainOption(h.chain), + exchanger.RouterOption(h.router), exchanger.TimeoutOption(h.md.timeout), exchanger.LoggerOption(h.logger), ) @@ -72,14 +77,13 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { h.exchangers = append(h.exchangers, ex) } if len(h.exchangers) == 0 { - addr := "udp://127.0.0.1:53" ex, err := exchanger.NewExchanger( - addr, - exchanger.ChainOption(h.chain), + defaultNameserver, + exchanger.RouterOption(h.router), exchanger.TimeoutOption(h.md.timeout), exchanger.LoggerOption(h.logger), ) - h.logger.Warnf("resolver not found, default to %s", addr) + h.logger.Warnf("resolver not found, default to %s", defaultNameserver) if err != nil { return err } @@ -88,11 +92,6 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { return } -// implements chain.Chainable interface -func (h *dnsHandler) WithChain(chain *chain.Chain) { - h.chain = chain -} - func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() diff --git a/pkg/handler/dns/metadata.go b/pkg/handler/dns/metadata.go index 26ef5e2..4433edd 100644 --- a/pkg/handler/dns/metadata.go +++ b/pkg/handler/dns/metadata.go @@ -10,7 +10,6 @@ import ( type metadata struct { readTimeout time.Duration - retryCount int ttl time.Duration timeout time.Duration clientIP net.IP @@ -22,7 +21,6 @@ type metadata struct { func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) { const ( readTimeout = "readTimeout" - retryCount = "retry" ttl = "ttl" timeout = "timeout" clientIP = "clientIP" @@ -31,7 +29,6 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) { ) h.md.readTimeout = mdata.GetDuration(md, readTimeout) - h.md.retryCount = mdata.GetInt(md, retryCount) h.md.ttl = mdata.GetDuration(md, ttl) h.md.timeout = mdata.GetDuration(md, timeout) if h.md.timeout <= 0 { diff --git a/pkg/handler/forward/local/handler.go b/pkg/handler/forward/local/handler.go index 8342061..df42666 100644 --- a/pkg/handler/forward/local/handler.go +++ b/pkg/handler/forward/local/handler.go @@ -22,8 +22,8 @@ func init() { type forwardHandler struct { group *chain.NodeGroup - chain *chain.Chain bypass bypass.Bypass + router *chain.Router logger logger.Logger md metadata } @@ -36,6 +36,7 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &forwardHandler{ bypass: options.Bypass, + router: options.Router, logger: options.Logger, } } @@ -49,12 +50,8 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { // dummy node used by relay connector. h.group = chain.NewNodeGroup(chain.NewNode("dummy", ":0")) } - return nil -} -// WithChain implements chain.Chainable interface -func (h *forwardHandler) WithChain(chain *chain.Chain) { - h.chain = chain + return nil } // Forward implements handler.Forwarder. @@ -95,12 +92,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - - cc, err := r.Dial(ctx, network, target.Addr()) + cc, err := h.router.Dial(ctx, network, target.Addr()) if err != nil { h.logger.Error(err) // TODO: the router itself may be failed due to the failed node in the router, diff --git a/pkg/handler/forward/local/metadata.go b/pkg/handler/forward/local/metadata.go index f66b1ad..773778b 100644 --- a/pkg/handler/forward/local/metadata.go +++ b/pkg/handler/forward/local/metadata.go @@ -8,16 +8,13 @@ import ( type metadata struct { readTimeout time.Duration - retryCount int } func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { const ( readTimeout = "readTimeout" - retryCount = "retry" ) h.md.readTimeout = mdata.GetDuration(md, readTimeout) - h.md.retryCount = mdata.GetInt(md, retryCount) return } diff --git a/pkg/handler/forward/remote/handler.go b/pkg/handler/forward/remote/handler.go index 8bf529a..cdb2ced 100644 --- a/pkg/handler/forward/remote/handler.go +++ b/pkg/handler/forward/remote/handler.go @@ -22,6 +22,7 @@ func init() { type forwardHandler struct { group *chain.NodeGroup bypass bypass.Bypass + router *chain.Router logger logger.Logger md metadata } @@ -34,12 +35,20 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &forwardHandler{ bypass: options.Bypass, + router: &chain.Router{ + Retries: options.Router.Retries, + Resolver: options.Resolver, + Logger: options.Logger, + }, logger: options.Logger, } } func (h *forwardHandler) Init(md md.Metadata) (err error) { - return h.parseMetadata(md) + if err = h.parseMetadata(md); err != nil { + return + } + return } // Forward implements handler.Forwarder. @@ -80,12 +89,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) - // without chain - r := (&chain.Router{}). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - - cc, err := r.Dial(ctx, network, target.Addr()) + cc, err := h.router.Dial(ctx, network, target.Addr()) if err != nil { h.logger.Error(err) // TODO: the router itself may be failed due to the failed node in the router, diff --git a/pkg/handler/forward/remote/metadata.go b/pkg/handler/forward/remote/metadata.go index 26bd723..2176deb 100644 --- a/pkg/handler/forward/remote/metadata.go +++ b/pkg/handler/forward/remote/metadata.go @@ -8,16 +8,13 @@ import ( type metadata struct { readTimeout time.Duration - retryCount int } func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { const ( readTimeout = "readTimeout" - retryCount = "retry" ) h.md.readTimeout = mdata.GetDuration(md, readTimeout) - h.md.retryCount = mdata.GetInt(md, retryCount) return } diff --git a/pkg/handler/forward/ssh/handler.go b/pkg/handler/forward/ssh/handler.go index 9ed9e76..c06d9bc 100644 --- a/pkg/handler/forward/ssh/handler.go +++ b/pkg/handler/forward/ssh/handler.go @@ -31,9 +31,9 @@ func init() { } type forwardHandler struct { - chain *chain.Chain bypass bypass.Bypass config *ssh.ServerConfig + router *chain.Router logger logger.Logger md metadata } @@ -46,6 +46,7 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &forwardHandler{ bypass: options.Bypass, + router: options.Router, logger: options.Logger, } } @@ -71,11 +72,6 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { return nil } -// WithChain implements chain.Chainable interface -func (h *forwardHandler) WithChain(chain *chain.Chain) { - h.chain = chain -} - func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() @@ -168,11 +164,7 @@ func (h *forwardHandler) directPortForwardChannel(ctx context.Context, channel s return } - r := (&chain.Router{}). - WithChain(h.chain). - // WithRetry(h.md.retryCount). - WithLogger(h.logger) - conn, err := r.Dial(ctx, "tcp", raddr) + conn, err := h.router.Dial(ctx, "tcp", raddr) if err != nil { return } diff --git a/pkg/handler/http/handler.go b/pkg/handler/http/handler.go index d28f55f..5ca7cae 100644 --- a/pkg/handler/http/handler.go +++ b/pkg/handler/http/handler.go @@ -43,9 +43,7 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &httpHandler{ bypass: options.Bypass, - router: (&chain.Router{}). - WithLogger(options.Logger). - WithResolver(options.Resolver), + router: options.Router, logger: options.Logger, } } @@ -55,16 +53,9 @@ func (h *httpHandler) Init(md md.Metadata) error { return err } - h.router.WithRetry(h.md.retryCount) - return nil } -// implements chain.Chainable interface -func (h *httpHandler) WithChain(chain *chain.Chain) { - h.router.WithChain(chain) -} - func (h *httpHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() diff --git a/pkg/handler/http/metadata.go b/pkg/handler/http/metadata.go index c7d38d0..f6aecd2 100644 --- a/pkg/handler/http/metadata.go +++ b/pkg/handler/http/metadata.go @@ -9,7 +9,6 @@ import ( ) type metadata struct { - retryCount int authenticator auth.Authenticator probeResist *probeResist sni bool @@ -23,7 +22,6 @@ func (h *httpHandler) parseMetadata(md mdata.Metadata) error { users = "users" probeResistKey = "probeResist" knock = "knock" - retryCount = "retry" sni = "sni" enableUDP = "udp" ) @@ -58,7 +56,6 @@ func (h *httpHandler) parseMetadata(md mdata.Metadata) error { } } } - h.md.retryCount = mdata.GetInt(md, retryCount) h.md.sni = mdata.GetBool(md, sni) h.md.enableUDP = mdata.GetBool(md, enableUDP) diff --git a/pkg/handler/http2/handler.go b/pkg/handler/http2/handler.go index ca407b6..315aefa 100644 --- a/pkg/handler/http2/handler.go +++ b/pkg/handler/http2/handler.go @@ -43,9 +43,7 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &http2Handler{ bypass: options.Bypass, - router: (&chain.Router{}). - WithLogger(options.Logger). - WithResolver(options.Resolver), + router: options.Router, logger: options.Logger, } } @@ -55,16 +53,9 @@ func (h *http2Handler) Init(md md.Metadata) error { return err } - h.router.WithRetry(h.md.retryCount) - return nil } -// implements chain.Chainable interface -func (h *http2Handler) WithChain(chain *chain.Chain) { - h.router.WithChain(chain) -} - func (h *http2Handler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() diff --git a/pkg/handler/http2/metadata.go b/pkg/handler/http2/metadata.go index 5112784..5093a78 100644 --- a/pkg/handler/http2/metadata.go +++ b/pkg/handler/http2/metadata.go @@ -10,7 +10,6 @@ import ( type metadata struct { authenticator auth.Authenticator proxyAgent string - retryCount int probeResist *probeResist sni bool enableUDP bool @@ -22,7 +21,6 @@ func (h *http2Handler) parseMetadata(md mdata.Metadata) error { users = "users" probeResistKey = "probeResist" knock = "knock" - retryCount = "retry" sni = "sni" enableUDP = "udp" ) @@ -51,7 +49,6 @@ func (h *http2Handler) parseMetadata(md mdata.Metadata) error { } } } - h.md.retryCount = mdata.GetInt(md, retryCount) h.md.sni = mdata.GetBool(md, sni) h.md.enableUDP = mdata.GetBool(md, enableUDP) diff --git a/pkg/handler/option.go b/pkg/handler/option.go index 4d12715..2ef067e 100644 --- a/pkg/handler/option.go +++ b/pkg/handler/option.go @@ -2,11 +2,13 @@ package handler import ( "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/resolver" ) type Options struct { + Router *chain.Router Bypass bypass.Bypass Resolver resolver.Resolver Logger logger.Logger @@ -14,9 +16,9 @@ type Options struct { type Option func(opts *Options) -func LoggerOption(logger logger.Logger) Option { +func RouterOption(router *chain.Router) Option { return func(opts *Options) { - opts.Logger = logger + opts.Router = router } } @@ -26,8 +28,8 @@ func BypassOption(bypass bypass.Bypass) Option { } } -func ResolverOption(resolver resolver.Resolver) Option { +func LoggerOption(logger logger.Logger) Option { return func(opts *Options) { - opts.Resolver = resolver + opts.Logger = logger } } diff --git a/pkg/handler/redirect/handler.go b/pkg/handler/redirect/handler.go index b319079..bc9c6bd 100644 --- a/pkg/handler/redirect/handler.go +++ b/pkg/handler/redirect/handler.go @@ -22,8 +22,8 @@ func init() { } type redirectHandler struct { - chain *chain.Chain bypass bypass.Bypass + router *chain.Router logger logger.Logger md metadata } @@ -36,6 +36,7 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &redirectHandler{ bypass: options.Bypass, + router: options.Router, logger: options.Logger, } } @@ -44,11 +45,6 @@ func (h *redirectHandler) Init(md md.Metadata) (err error) { return h.parseMetadata(md) } -// WithChain implements chain.Chainable interface -func (h *redirectHandler) WithChain(chain *chain.Chain) { - h.chain = chain -} - func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() @@ -93,12 +89,7 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn) { return } - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - - cc, err := r.Dial(ctx, network, dstAddr.String()) + cc, err := h.router.Dial(ctx, network, dstAddr.String()) if err != nil { h.logger.Error(err) return diff --git a/pkg/handler/relay/handler.go b/pkg/handler/relay/handler.go index 751711b..475bc98 100644 --- a/pkg/handler/relay/handler.go +++ b/pkg/handler/relay/handler.go @@ -35,9 +35,7 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &relayHandler{ bypass: options.Bypass, - router: (&chain.Router{}). - WithLogger(options.Logger). - WithResolver(options.Resolver), + router: options.Router, logger: options.Logger, } } @@ -47,16 +45,9 @@ func (h *relayHandler) Init(md md.Metadata) (err error) { return err } - h.router.WithRetry(h.md.retryCount) - return nil } -// WithChain implements chain.Chainable interface -func (h *relayHandler) WithChain(chain *chain.Chain) { - h.router.WithChain(chain) -} - // Forward implements handler.Forwarder. func (h *relayHandler) Forward(group *chain.NodeGroup) { h.group = group diff --git a/pkg/handler/relay/metadata.go b/pkg/handler/relay/metadata.go index c4f3a0f..3564a0f 100644 --- a/pkg/handler/relay/metadata.go +++ b/pkg/handler/relay/metadata.go @@ -12,7 +12,6 @@ import ( type metadata struct { authenticator auth.Authenticator readTimeout time.Duration - retryCount int enableBind bool udpBufferSize int noDelay bool @@ -22,7 +21,6 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) { const ( users = "users" readTimeout = "readTimeout" - retryCount = "retry" enableBind = "bind" udpBufferSize = "udpBufferSize" noDelay = "nodelay" @@ -42,7 +40,6 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) { } h.md.readTimeout = mdata.GetDuration(md, readTimeout) - h.md.retryCount = mdata.GetInt(md, retryCount) h.md.enableBind = mdata.GetBool(md, enableBind) h.md.noDelay = mdata.GetBool(md, noDelay) diff --git a/pkg/handler/sni/handler.go b/pkg/handler/sni/handler.go index 8c1e64d..22c8845 100644 --- a/pkg/handler/sni/handler.go +++ b/pkg/handler/sni/handler.go @@ -46,9 +46,7 @@ func NewHandler(opts ...handler.Option) handler.Handler { h := &sniHandler{ bypass: options.Bypass, - router: (&chain.Router{}). - WithLogger(options.Logger). - WithResolver(options.Resolver), + router: options.Router, logger: log, } @@ -74,16 +72,9 @@ func (h *sniHandler) Init(md md.Metadata) (err error) { } } - h.router.WithRetry(h.md.retryCount) - return nil } -// WithChain implements chain.Chainable interface -func (h *sniHandler) WithChain(chain *chain.Chain) { - h.router.WithChain(chain) -} - func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() diff --git a/pkg/handler/sni/metadata.go b/pkg/handler/sni/metadata.go index 4fdbdd7..7b44ff8 100644 --- a/pkg/handler/sni/metadata.go +++ b/pkg/handler/sni/metadata.go @@ -8,16 +8,13 @@ import ( type metadata struct { readTimeout time.Duration - retryCount int } func (h *sniHandler) parseMetadata(md mdata.Metadata) (err error) { const ( readTimeout = "readTimeout" - retryCount = "retry" ) h.md.readTimeout = mdata.GetDuration(md, readTimeout) - h.md.retryCount = mdata.GetInt(md, retryCount) return } diff --git a/pkg/handler/socks/v4/handler.go b/pkg/handler/socks/v4/handler.go index 72c3b01..db46b2c 100644 --- a/pkg/handler/socks/v4/handler.go +++ b/pkg/handler/socks/v4/handler.go @@ -34,9 +34,7 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &socks4Handler{ bypass: options.Bypass, - router: (&chain.Router{}). - WithLogger(options.Logger). - WithResolver(options.Resolver), + router: options.Router, logger: options.Logger, } } @@ -46,16 +44,9 @@ func (h *socks4Handler) Init(md md.Metadata) (err error) { return err } - h.router.WithRetry(h.md.retryCount) - return nil } -// implements chain.Chainable interface -func (h *socks4Handler) WithChain(chain *chain.Chain) { - h.router.WithChain(chain) -} - func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() diff --git a/pkg/handler/socks/v4/metadata.go b/pkg/handler/socks/v4/metadata.go index 2842c01..d6ab966 100644 --- a/pkg/handler/socks/v4/metadata.go +++ b/pkg/handler/socks/v4/metadata.go @@ -10,14 +10,12 @@ import ( type metadata struct { authenticator auth.Authenticator readTimeout time.Duration - retryCount int } func (h *socks4Handler) parseMetadata(md mdata.Metadata) (err error) { const ( users = "users" readTimeout = "readTimeout" - retryCount = "retry" ) if auths := mdata.GetStrings(md, users); len(auths) > 0 { @@ -31,6 +29,5 @@ func (h *socks4Handler) parseMetadata(md mdata.Metadata) (err error) { } h.md.readTimeout = mdata.GetDuration(md, readTimeout) - h.md.retryCount = mdata.GetInt(md, retryCount) return } diff --git a/pkg/handler/socks/v5/handler.go b/pkg/handler/socks/v5/handler.go index fb4a35e..84980c3 100644 --- a/pkg/handler/socks/v5/handler.go +++ b/pkg/handler/socks/v5/handler.go @@ -36,9 +36,7 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &socks5Handler{ bypass: options.Bypass, - router: (&chain.Router{}). - WithLogger(options.Logger). - WithResolver(options.Resolver), + router: options.Router, logger: options.Logger, } } @@ -55,16 +53,9 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) { noTLS: h.md.noTLS, } - h.router.WithRetry(h.md.retryCount) - return } -// implements chain.Chainable interface -func (h *socks5Handler) WithChain(chain *chain.Chain) { - h.router.WithChain(chain) -} - func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() diff --git a/pkg/handler/socks/v5/metadata.go b/pkg/handler/socks/v5/metadata.go index 47a13d0..a49c939 100644 --- a/pkg/handler/socks/v5/metadata.go +++ b/pkg/handler/socks/v5/metadata.go @@ -16,7 +16,6 @@ type metadata struct { authenticator auth.Authenticator timeout time.Duration readTimeout time.Duration - retryCount int noTLS bool enableBind bool enableUDP bool @@ -32,7 +31,6 @@ func (h *socks5Handler) parseMetadata(md mdata.Metadata) (err error) { users = "users" readTimeout = "readTimeout" timeout = "timeout" - retryCount = "retry" noTLS = "notls" enableBind = "bind" enableUDP = "udp" @@ -64,7 +62,6 @@ func (h *socks5Handler) parseMetadata(md mdata.Metadata) (err error) { h.md.readTimeout = mdata.GetDuration(md, readTimeout) h.md.timeout = mdata.GetDuration(md, timeout) - h.md.retryCount = mdata.GetInt(md, retryCount) h.md.noTLS = mdata.GetBool(md, noTLS) h.md.enableBind = mdata.GetBool(md, enableBind) h.md.enableUDP = mdata.GetBool(md, enableUDP) diff --git a/pkg/handler/ss/handler.go b/pkg/handler/ss/handler.go index 9fde568..0921eec 100644 --- a/pkg/handler/ss/handler.go +++ b/pkg/handler/ss/handler.go @@ -36,9 +36,7 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &ssHandler{ bypass: options.Bypass, - router: (&chain.Router{}). - WithLogger(options.Logger). - WithResolver(options.Resolver), + router: options.Router, logger: options.Logger, } } @@ -48,16 +46,9 @@ func (h *ssHandler) Init(md md.Metadata) (err error) { return err } - h.router.WithRetry(h.md.retryCount) - return nil } -// implements chain.Chainable interface -func (h *ssHandler) WithChain(chain *chain.Chain) { - h.router.WithChain(chain) -} - func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() diff --git a/pkg/handler/ss/metadata.go b/pkg/handler/ss/metadata.go index f841cb3..6833bbf 100644 --- a/pkg/handler/ss/metadata.go +++ b/pkg/handler/ss/metadata.go @@ -12,7 +12,6 @@ import ( type metadata struct { cipher core.Cipher readTimeout time.Duration - retryCount int } func (h *ssHandler) parseMetadata(md mdata.Metadata) (err error) { @@ -20,7 +19,6 @@ func (h *ssHandler) parseMetadata(md mdata.Metadata) (err error) { users = "users" key = "key" readTimeout = "readTimeout" - retryCount = "retry" ) var method, password string @@ -39,7 +37,6 @@ func (h *ssHandler) parseMetadata(md mdata.Metadata) (err error) { } h.md.readTimeout = mdata.GetDuration(md, readTimeout) - h.md.retryCount = mdata.GetInt(md, retryCount) return } diff --git a/pkg/handler/ss/udp/handler.go b/pkg/handler/ss/udp/handler.go index 79ac23d..b12f50d 100644 --- a/pkg/handler/ss/udp/handler.go +++ b/pkg/handler/ss/udp/handler.go @@ -35,9 +35,7 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &ssuHandler{ bypass: options.Bypass, - router: (&chain.Router{}). - WithLogger(options.Logger). - WithResolver(options.Resolver), + router: options.Router, logger: options.Logger, } } @@ -47,16 +45,9 @@ func (h *ssuHandler) Init(md md.Metadata) (err error) { return err } - h.router.WithRetry(h.md.retryCount) - return nil } -// WithChain implements chain.Chainable interface -func (h *ssuHandler) WithChain(chain *chain.Chain) { - h.router.WithChain(chain) -} - func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() diff --git a/pkg/handler/ss/udp/metadata.go b/pkg/handler/ss/udp/metadata.go index 24d1908..42a6ec1 100644 --- a/pkg/handler/ss/udp/metadata.go +++ b/pkg/handler/ss/udp/metadata.go @@ -13,7 +13,6 @@ import ( type metadata struct { cipher core.Cipher readTimeout time.Duration - retryCount int bufferSize int } @@ -22,7 +21,6 @@ func (h *ssuHandler) parseMetadata(md mdata.Metadata) (err error) { users = "users" key = "key" readTimeout = "readTimeout" - retryCount = "retry" bufferSize = "bufferSize" ) @@ -42,7 +40,6 @@ func (h *ssuHandler) parseMetadata(md mdata.Metadata) (err error) { } h.md.readTimeout = mdata.GetDuration(md, readTimeout) - h.md.retryCount = mdata.GetInt(md, retryCount) if bs := mdata.GetInt(md, bufferSize); bs > 0 { h.md.bufferSize = int(math.Min(math.Max(float64(bs), 512), 64*1024)) diff --git a/pkg/handler/tap/handler.go b/pkg/handler/tap/handler.go index ebc24f6..a3b613a 100644 --- a/pkg/handler/tap/handler.go +++ b/pkg/handler/tap/handler.go @@ -44,9 +44,7 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &tapHandler{ bypass: options.Bypass, - router: (&chain.Router{}). - WithLogger(options.Logger). - WithResolver(options.Resolver), + router: options.Router, logger: options.Logger, exit: make(chan struct{}, 1), } @@ -57,16 +55,9 @@ func (h *tapHandler) Init(md md.Metadata) (err error) { return err } - h.router.WithRetry(h.md.retryCount) - return nil } -// implements chain.Chainable interface -func (h *tapHandler) WithChain(chain *chain.Chain) { - h.router.WithChain(chain) -} - // Forward implements handler.Forwarder. func (h *tapHandler) Forward(group *chain.NodeGroup) { h.group = group diff --git a/pkg/handler/tap/metadata.go b/pkg/handler/tap/metadata.go index a7ebd70..f15a0f5 100644 --- a/pkg/handler/tap/metadata.go +++ b/pkg/handler/tap/metadata.go @@ -19,7 +19,6 @@ func (h *tapHandler) parseMetadata(md mdata.Metadata) (err error) { users = "users" key = "key" readTimeout = "readTimeout" - retryCount = "retry" bufferSize = "bufferSize" ) @@ -37,7 +36,6 @@ func (h *tapHandler) parseMetadata(md mdata.Metadata) (err error) { if err != nil { return } - h.md.retryCount = mdata.GetInt(md, retryCount) h.md.bufferSize = mdata.GetInt(md, bufferSize) if h.md.bufferSize <= 0 { diff --git a/pkg/handler/tun/handler.go b/pkg/handler/tun/handler.go index 60bbf24..b732104 100644 --- a/pkg/handler/tun/handler.go +++ b/pkg/handler/tun/handler.go @@ -46,9 +46,7 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &tunHandler{ bypass: options.Bypass, - router: (&chain.Router{}). - WithLogger(options.Logger). - WithResolver(options.Resolver), + router: options.Router, logger: options.Logger, exit: make(chan struct{}, 1), } @@ -59,16 +57,9 @@ func (h *tunHandler) Init(md md.Metadata) (err error) { return err } - h.router.WithRetry(h.md.retryCount) - return nil } -// implements chain.Chainable interface -func (h *tunHandler) WithChain(chain *chain.Chain) { - h.router.WithChain(chain) -} - // Forward implements handler.Forwarder. func (h *tunHandler) Forward(group *chain.NodeGroup) { h.group = group diff --git a/pkg/handler/tun/metadata.go b/pkg/handler/tun/metadata.go index e6ac54a..2aa19d8 100644 --- a/pkg/handler/tun/metadata.go +++ b/pkg/handler/tun/metadata.go @@ -10,7 +10,6 @@ import ( type metadata struct { cipher core.Cipher - retryCount int bufferSize int } @@ -19,7 +18,6 @@ func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) { users = "users" key = "key" readTimeout = "readTimeout" - retryCount = "retry" bufferSize = "bufferSize" ) @@ -37,7 +35,6 @@ func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) { if err != nil { return } - h.md.retryCount = mdata.GetInt(md, retryCount) h.md.bufferSize = mdata.GetInt(md, bufferSize) if h.md.bufferSize <= 0 { diff --git a/pkg/listener/rtcp/listener.go b/pkg/listener/rtcp/listener.go index 36681a0..55ffd54 100644 --- a/pkg/listener/rtcp/listener.go +++ b/pkg/listener/rtcp/listener.go @@ -22,6 +22,7 @@ type rtcpListener struct { chain *chain.Chain ln net.Listener md metadata + router *chain.Router logger logger.Logger closed chan struct{} } @@ -34,13 +35,16 @@ func NewListener(opts ...listener.Option) listener.Listener { return &rtcpListener{ addr: options.Addr, closed: make(chan struct{}), + router: &chain.Router{ + Logger: options.Logger, + }, logger: options.Logger, } } // implements chain.Chainable interface func (l *rtcpListener) WithChain(chain *chain.Chain) { - l.chain = chain + l.router.Chain = chain } func (l *rtcpListener) Init(md md.Metadata) (err error) { @@ -66,11 +70,7 @@ func (l *rtcpListener) Accept() (conn net.Conn, err error) { } if l.ln == nil { - r := (&chain.Router{}). - WithChain(l.chain). - WithRetry(l.md.retryCount). - WithLogger(l.logger) - l.ln, err = r.Bind(context.Background(), "tcp", l.laddr.String(), + l.ln, err = l.router.Bind(context.Background(), "tcp", l.laddr.String(), connector.MuxBindOption(true), ) if err != nil { diff --git a/pkg/listener/rudp/listener.go b/pkg/listener/rudp/listener.go index 4377b93..7d03e65 100644 --- a/pkg/listener/rudp/listener.go +++ b/pkg/listener/rudp/listener.go @@ -22,6 +22,7 @@ type rudpListener struct { chain *chain.Chain ln net.Listener md metadata + router *chain.Router logger logger.Logger closed chan struct{} } @@ -34,13 +35,16 @@ func NewListener(opts ...listener.Option) listener.Listener { return &rudpListener{ addr: options.Addr, closed: make(chan struct{}), + router: &chain.Router{ + Logger: options.Logger, + }, logger: options.Logger, } } // implements chain.Chainable interface func (l *rudpListener) WithChain(chain *chain.Chain) { - l.chain = chain + l.router.Chain = chain } func (l *rudpListener) Init(md md.Metadata) (err error) { @@ -66,11 +70,7 @@ func (l *rudpListener) Accept() (conn net.Conn, err error) { } if l.ln == nil { - r := (&chain.Router{}). - WithChain(l.chain). - WithRetry(l.md.retryCount). - WithLogger(l.logger) - l.ln, err = r.Bind(context.Background(), "udp", l.laddr.String(), + l.ln, err = l.router.Bind(context.Background(), "udp", l.laddr.String(), connector.BacklogBindOption(l.md.backlog), connector.UDPConnTTLBindOption(l.md.ttl), connector.UDPDataBufferSizeBindOption(l.md.readBufferSize), diff --git a/pkg/resolver/exchanger/exchanger.go b/pkg/resolver/exchanger/exchanger.go index c0c6d94..b592b77 100644 --- a/pkg/resolver/exchanger/exchanger.go +++ b/pkg/resolver/exchanger/exchanger.go @@ -18,7 +18,7 @@ import ( ) type Options struct { - chain *chain.Chain + router *chain.Router tlsConfig *tls.Config timeout time.Duration logger logger.Logger @@ -27,10 +27,10 @@ type Options struct { // Option allows a common way to set Exchanger options. type Option func(opts *Options) -// ChainOption sets the chain for Exchanger. -func ChainOption(chain *chain.Chain) Option { +// RouterOption sets the router for Exchanger. +func RouterOption(router *chain.Router) Option { return func(opts *Options) { - opts.chain = chain + opts.router = router } } @@ -89,14 +89,17 @@ func NewExchanger(addr string, opts ...Option) (Exchanger, error) { network: u.Scheme, addr: u.Host, rawAddr: addr, + router: options.router, options: options, } - ex.router = (&chain.Router{}). - WithChain(options.chain). - WithLogger(options.logger) if _, port, _ := net.SplitHostPort(ex.addr); port == "" { ex.addr = net.JoinHostPort(ex.addr, "53") } + if ex.router == nil { + ex.router = &chain.Router{ + Logger: options.logger, + } + } switch ex.network { case "tcp":