diff --git a/cmd/gost/cmd.go b/cmd/gost/cmd.go index a70badf..b29788a 100644 --- a/cmd/gost/cmd.go +++ b/cmd/gost/cmd.go @@ -145,6 +145,11 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { md.Del("hosts") } + if v := metadata.GetString(md, "interface"); v != "" { + hopConfig.Interface = v + md.Del("interface") + } + chain.Hops = append(chain.Hops, hopConfig) } @@ -336,6 +341,10 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) { if v := metadata.GetString(md, "dns"); v != "" { md.Set("dns", strings.Split(v, ",")) } + if v := metadata.GetString(md, "interface"); v != "" { + svc.Interface = v + md.Del("interface") + } if svc.Forwarder != nil { svc.Forwarder.Selector = parseSelector(md) diff --git a/pkg/chain/route.go b/pkg/chain/route.go index a3a5223..5df063b 100644 --- a/pkg/chain/route.go +++ b/pkg/chain/route.go @@ -5,9 +5,11 @@ import ( "errors" "fmt" "net" + "time" "github.com/go-gost/gost/pkg/common/util/udp" "github.com/go-gost/gost/pkg/connector" + "github.com/go-gost/gost/pkg/dialer" "github.com/go-gost/gost/pkg/logger" ) @@ -16,8 +18,9 @@ var ( ) type Route struct { - nodes []*Node - logger logger.Logger + nodes []*Node + ifceName string + logger logger.Logger } func (r *Route) addNode(node *Node) { @@ -26,7 +29,13 @@ func (r *Route) addNode(node *Node) { func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, error) { if r.Len() == 0 { - return r.dialDirect(ctx, network, address) + netd := dialer.NetDialer{ + Timeout: 30 * time.Second, + } + if r != nil { + netd.Interface = r.ifceName + } + return netd.Dial(ctx, network, address) } conn, err := r.connect(ctx) @@ -42,19 +51,6 @@ func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, er return cc, nil } -func (r *Route) dialDirect(ctx context.Context, network, address string) (net.Conn, error) { - switch network { - case "udp", "udp4", "udp6": - if address == "" { - return net.ListenUDP(network, nil) - } - default: - } - - d := net.Dialer{} - return d.DialContext(ctx, network, address) -} - func (r *Route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { if r.Len() == 0 { return r.bindLocal(ctx, network, address, opts...) diff --git a/pkg/chain/router.go b/pkg/chain/router.go index da6f716..bb7dac3 100644 --- a/pkg/chain/router.go +++ b/pkg/chain/router.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net" + "time" "github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/hosts" @@ -13,11 +14,55 @@ import ( ) type Router struct { - Retries int - Chain Chainer - Hosts hosts.HostMapper - Resolver resolver.Resolver - Logger logger.Logger + timeout time.Duration + retries int + ifceName string + chain Chainer + resolver resolver.Resolver + hosts hosts.HostMapper + logger logger.Logger +} + +func (r *Router) WithTimeout(timeout time.Duration) *Router { + r.timeout = timeout + return r +} + +func (r *Router) WithRetries(retries int) *Router { + r.retries = retries + return r +} + +func (r *Router) WithInterface(ifceName string) *Router { + r.ifceName = ifceName + return r +} + +func (r *Router) WithChain(chain Chainer) *Router { + r.chain = chain + return r +} + +func (r *Router) WithResolver(resolver resolver.Resolver) *Router { + r.resolver = resolver + return r +} + +func (r *Router) WithHosts(hosts hosts.HostMapper) *Router { + r.hosts = hosts + return r +} + +func (r *Router) Hosts() hosts.HostMapper { + if r != nil { + return r.hosts + } + return nil +} + +func (r *Router) WithLogger(logger logger.Logger) *Router { + r.logger = logger + return r } func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) { @@ -34,74 +79,76 @@ func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Co } func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { - count := r.Retries + 1 + count := r.retries + 1 if count <= 0 { count = 1 } - r.Logger.Debugf("dial %s/%s", address, network) + r.logger.Debugf("dial %s/%s", address, network) for i := 0; i < count; i++ { var route *Route - if r.Chain != nil { - route = r.Chain.Route(network, address) + if r.chain != nil { + route = r.chain.Route(network, address) } - if r.Logger.IsLevelEnabled(logger.DebugLevel) { + if r.logger.IsLevelEnabled(logger.DebugLevel) { buf := bytes.Buffer{} for _, node := range route.Path() { fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr) } fmt.Fprintf(&buf, "%s", address) - r.Logger.Debugf("route(retry=%d) %s", i, buf.String()) + r.logger.Debugf("route(retry=%d) %s", i, buf.String()) } - address, err = resolve(ctx, "ip", address, r.Resolver, r.Hosts, r.Logger) + address, err = resolve(ctx, "ip", address, r.resolver, r.hosts, r.logger) if err != nil { - r.Logger.Error(err) + r.logger.Error(err) break } - if route != nil { - route.logger = r.Logger + if route == nil { + route = &Route{} } + route.ifceName = r.ifceName + route.logger = r.logger conn, err = route.Dial(ctx, network, address) if err == nil { break } - r.Logger.Errorf("route(retry=%d) %s", i, err) + r.logger.Errorf("route(retry=%d) %s", i, err) } return } func (r *Router) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (ln net.Listener, err error) { - count := r.Retries + 1 + count := r.retries + 1 if count <= 0 { count = 1 } - r.Logger.Debugf("bind on %s/%s", address, network) + r.logger.Debugf("bind on %s/%s", address, network) for i := 0; i < count; i++ { var route *Route - if r.Chain != nil { - route = r.Chain.Route(network, address) + if r.chain != nil { + route = r.chain.Route(network, address) } - if r.Logger.IsLevelEnabled(logger.DebugLevel) { + if r.logger.IsLevelEnabled(logger.DebugLevel) { buf := bytes.Buffer{} for _, node := range route.Path() { fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr) } fmt.Fprintf(&buf, "%s", address) - r.Logger.Debugf("route(retry=%d) %s", i, buf.String()) + r.logger.Debugf("route(retry=%d) %s", i, buf.String()) } ln, err = route.Bind(ctx, network, address, opts...) if err == nil { break } - r.Logger.Errorf("route(retry=%d) %s", i, err) + r.logger.Errorf("route(retry=%d) %s", i, err) } return diff --git a/pkg/chain/transport.go b/pkg/chain/transport.go index bf32351..2734b7f 100644 --- a/pkg/chain/transport.go +++ b/pkg/chain/transport.go @@ -3,6 +3,7 @@ package chain import ( "context" "net" + "time" "github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/dialer" @@ -10,6 +11,7 @@ import ( type Transport struct { addr string + ifceName string route *Route dialer dialer.Dialer connector connector.Connector @@ -21,6 +23,11 @@ func (tr *Transport) Copy() *Transport { return tr } +func (tr *Transport) WithInterface(ifceName string) *Transport { + tr.ifceName = ifceName + return tr +} + func (tr *Transport) WithDialer(dialer dialer.Dialer) *Transport { tr.dialer = dialer return tr @@ -32,23 +39,20 @@ func (tr *Transport) WithConnector(connector connector.Connector) *Transport { } func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) { - return tr.dialer.Dial(ctx, addr, tr.dialOptions()...) -} - -func (tr *Transport) dialOptions() []dialer.DialOption { - opts := []dialer.DialOption{ - dialer.HostDialOption(tr.addr), + netd := &dialer.NetDialer{ + Interface: tr.ifceName, + Timeout: 30 * time.Second, } if tr.route.Len() > 0 { - opts = append(opts, - dialer.DialFuncDialOption( - func(ctx context.Context, addr string) (net.Conn, error) { - return tr.route.Dial(ctx, "tcp", addr) - }, - ), - ) + netd.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { + return tr.route.Dial(ctx, network, addr) + } } - return opts + opts := []dialer.DialOption{ + dialer.HostDialOption(tr.addr), + dialer.NetDialerDialOption(netd), + } + return tr.dialer.Dial(ctx, addr, opts...) } func (tr *Transport) Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) { diff --git a/pkg/config/config.go b/pkg/config/config.go index 6d77b2b..7f86960 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -180,6 +180,7 @@ type ConnectorConfig struct { type ServiceConfig struct { Name string `json:"name"` Addr string `yaml:",omitempty" json:"addr,omitempty"` + Interface string `yaml:",omitempty" json:"interface,omitempty"` Admission string `yaml:",omitempty" json:"admission,omitempty"` Bypass string `yaml:",omitempty" json:"bypass,omitempty"` Resolver string `yaml:",omitempty" json:"resolver,omitempty"` @@ -196,17 +197,19 @@ type ChainConfig struct { } type HopConfig struct { - Name string `json:"name"` - Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"` - Bypass string `yaml:",omitempty" json:"bypass,omitempty"` - Resolver string `yaml:",omitempty" json:"resolver,omitempty"` - Hosts string `yaml:",omitempty" json:"hosts,omitempty"` - Nodes []*NodeConfig `json:"nodes"` + Name string `json:"name"` + Interface string `yaml:",omitempty" json:"interface,omitempty"` + Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"` + Bypass string `yaml:",omitempty" json:"bypass,omitempty"` + Resolver string `yaml:",omitempty" json:"resolver,omitempty"` + Hosts string `yaml:",omitempty" json:"hosts,omitempty"` + Nodes []*NodeConfig `json:"nodes"` } type NodeConfig struct { Name string `json:"name"` Addr string `yaml:",omitempty" json:"addr,omitempty"` + Interface string `yaml:",omitempty" json:"interface,omitempty"` Bypass string `yaml:",omitempty" json:"bypass,omitempty"` Resolver string `yaml:",omitempty" json:"resolver,omitempty"` Hosts string `yaml:",omitempty" json:"hosts,omitempty"` diff --git a/pkg/config/parsing/chain.go b/pkg/config/parsing/chain.go index f99fd58..0a1c646 100644 --- a/pkg/config/parsing/chain.go +++ b/pkg/config/parsing/chain.go @@ -93,11 +93,6 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { return nil, err } - tr := (&chain.Transport{}). - WithConnector(cr). - WithDialer(d). - WithAddr(v.Addr) - if v.Bypass == "" { v.Bypass = hop.Bypass } @@ -107,15 +102,24 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { if v.Hosts == "" { v.Hosts = hop.Hosts } + if v.Interface == "" { + v.Interface = hop.Interface + } + + tr := (&chain.Transport{}). + WithConnector(cr). + WithDialer(d). + WithAddr(v.Addr). + WithInterface(v.Interface) node := &chain.Node{ Name: v.Name, Addr: v.Addr, - Transport: tr, Bypass: registry.BypassRegistry().Get(v.Bypass), Resolver: registry.ResolverRegistry().Get(v.Resolver), Hosts: registry.HostsRegistry().Get(v.Hosts), Marker: &chain.FailMarker{}, + Transport: tr, } group.AddNode(node) } diff --git a/pkg/config/parsing/service.go b/pkg/config/parsing/service.go index feb1f43..b46d94d 100644 --- a/pkg/config/parsing/service.go +++ b/pkg/config/parsing/service.go @@ -88,14 +88,21 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { if cfg.Handler.Auther != "" { auther = registry.AutherRegistry().Get(cfg.Handler.Auther) } + + router := (&chain.Router{}). + WithRetries(cfg.Handler.Retries). + // WithTimeout(timeout time.Duration). + WithInterface(cfg.Interface). + WithChain(registry.ChainRegistry().Get(cfg.Handler.Chain)). + WithResolver(registry.ResolverRegistry().Get(cfg.Resolver)). + WithHosts(registry.HostsRegistry().Get(cfg.Hosts)). + WithLogger(handlerLogger) + h := registry.HandlerRegistry().Get(cfg.Handler.Type)( + handler.RouterOption(router), handler.AutherOption(auther), handler.AuthOption(parseAuth(cfg.Handler.Auth)), - handler.RetriesOption(cfg.Handler.Retries), - handler.ChainOption(registry.ChainRegistry().Get(cfg.Handler.Chain)), handler.BypassOption(registry.BypassRegistry().Get(cfg.Bypass)), - handler.ResolverOption(registry.ResolverRegistry().Get(cfg.Resolver)), - handler.HostsOption(registry.HostsRegistry().Get(cfg.Hosts)), handler.TLSConfigOption(tlsConfig), handler.LoggerOption(handlerLogger), ) diff --git a/pkg/dialer/grpc/dialer.go b/pkg/dialer/grpc/dialer.go index 6620154..f1871cd 100644 --- a/pkg/dialer/grpc/dialer.go +++ b/pkg/dialer/grpc/dialer.go @@ -75,7 +75,11 @@ func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO grpcOpts := []grpc.DialOption{ // grpc.WithBlock(), grpc.WithContextDialer(func(c context.Context, s string) (net.Conn, error) { - return d.dial(ctx, "tcp", s, &options) + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + return netd.Dial(c, "tcp", s) }), grpc.WithAuthority(host), grpc.WithConnectParams(grpc.ConnectParams{ @@ -111,31 +115,3 @@ func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO closed: make(chan struct{}), }, nil } - -func (d *grpcDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) { - dial := opts.DialFunc - if dial != nil { - conn, err := dial(ctx, addr) - if err != nil { - d.options.Logger.Error(err) - } else { - d.options.Logger.WithFields(map[string]any{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debug("dial with dial func") - } - return conn, err - } - - var netd net.Dialer - conn, err := netd.DialContext(ctx, network, addr) - if err != nil { - d.options.Logger.Error(err) - } else { - d.options.Logger.WithFields(map[string]any{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debugf("dial direct %s/%s", addr, network) - } - return conn, err -} diff --git a/pkg/dialer/http2/dialer.go b/pkg/dialer/http2/dialer.go index 623feba..735012a 100644 --- a/pkg/dialer/http2/dialer.go +++ b/pkg/dialer/http2/dialer.go @@ -73,7 +73,11 @@ func (d *http2Dialer) Dial(ctx context.Context, address string, opts ...dialer.D Transport: &http.Transport{ TLSClientConfig: d.options.TLSConfig, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return d.dial(ctx, network, addr, &options) + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + return netd.Dial(ctx, network, addr) }, ForceAttemptHTTP2: true, MaxIdleConns: 100, @@ -94,31 +98,3 @@ func (d *http2Dialer) Dial(ctx context.Context, address string, opts ...dialer.D delete(d.clients, address) }), nil } - -func (d *http2Dialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) { - dial := opts.DialFunc - if dial != nil { - conn, err := dial(ctx, addr) - if err != nil { - d.logger.Error(err) - } else { - d.logger.WithFields(map[string]any{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debug("dial with dial func") - } - return conn, err - } - - var netd net.Dialer - conn, err := netd.DialContext(ctx, network, addr) - if err != nil { - d.logger.Error(err) - } else { - d.logger.WithFields(map[string]any{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debugf("dial direct %s/%s", addr, network) - } - return conn, err -} diff --git a/pkg/dialer/http2/h2/dialer.go b/pkg/dialer/http2/h2/dialer.go index 34b1b6b..c6ba8b8 100644 --- a/pkg/dialer/http2/h2/dialer.go +++ b/pkg/dialer/http2/h2/dialer.go @@ -93,14 +93,22 @@ func (d *h2Dialer) Dial(ctx context.Context, address string, opts ...dialer.Dial if d.h2c { client.Transport = &http2.Transport{ DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - return d.dial(ctx, network, addr, options) + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + return netd.Dial(ctx, network, addr) }, } } else { client.Transport = &http.Transport{ TLSClientConfig: d.options.TLSConfig, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return d.dial(ctx, network, addr, options) + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + return netd.Dial(ctx, network, addr) }, ForceAttemptHTTP2: true, MaxIdleConns: 100, @@ -163,31 +171,3 @@ func (d *h2Dialer) Dial(ctx context.Context, address string, opts ...dialer.Dial } return conn, nil } - -func (d *h2Dialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) { - dial := opts.DialFunc - if dial != nil { - conn, err := dial(ctx, addr) - if err != nil { - d.logger.Error(err) - } else { - d.logger.WithFields(map[string]any{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debug("dial with dial func") - } - return conn, err - } - - var netd net.Dialer - conn, err := netd.DialContext(ctx, network, addr) - if err != nil { - d.logger.Error(err) - } else { - d.logger.WithFields(map[string]any{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debugf("dial direct %s/%s", addr, network) - } - return conn, err -} diff --git a/pkg/dialer/kcp/dialer.go b/pkg/dialer/kcp/dialer.go index e44ccc2..1667ab5 100644 --- a/pkg/dialer/kcp/dialer.go +++ b/pkg/dialer/kcp/dialer.go @@ -56,11 +56,6 @@ func (d *kcpDialer) Multiplex() bool { } func (d *kcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) { - var options dialer.DialOptions - for _, opt := range opts { - opt(&options) - } - d.sessionMutex.Lock() defer d.sessionMutex.Unlock() @@ -70,12 +65,17 @@ func (d *kcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp ok = false } if !ok { - raddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) } if d.md.config.TCP { + raddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + pc, err := tcpraw.Dial("tcp", addr) if err != nil { return nil, err @@ -85,7 +85,11 @@ func (d *kcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp PacketConn: pc, } } else { - conn, err = net.ListenUDP("udp", nil) + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + conn, err = netd.Dial(ctx, "udp", addr) if err != nil { return nil, err } diff --git a/pkg/dialer/net.go b/pkg/dialer/net.go new file mode 100644 index 0000000..482a8ea --- /dev/null +++ b/pkg/dialer/net.go @@ -0,0 +1,84 @@ +package dialer + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/go-gost/gost/pkg/logger" +) + +var ( + DefaultNetDialer = &NetDialer{ + Timeout: 30 * time.Second, + } +) + +type NetDialer struct { + Interface string + Timeout time.Duration + DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) +} + +func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) { + ifAddr, err := parseInterfaceAddr(d.Interface, network) + if err != nil { + return nil, err + } + if d.DialFunc != nil { + return d.DialFunc(ctx, network, addr) + } + logger.Default().Infof("interface: %s %s %v", d.Interface, network, ifAddr) + + switch network { + case "udp", "udp4", "udp6": + if addr == "" { + var laddr *net.UDPAddr + if ifAddr != nil { + laddr, _ = ifAddr.(*net.UDPAddr) + } + + return net.ListenUDP(network, laddr) + } + case "tcp", "tcp4", "tcp6": + default: + return nil, fmt.Errorf("dial: unsupported network %s", network) + } + netd := net.Dialer{ + Timeout: d.Timeout, + LocalAddr: ifAddr, + } + return netd.DialContext(ctx, network, addr) +} + +func parseInterfaceAddr(ifceName, network string) (net.Addr, error) { + if ifceName == "" { + return nil, nil + } + + ip := net.ParseIP(ifceName) + if ip == nil { + ifce, err := net.InterfaceByName(ifceName) + if err != nil { + return nil, err + } + addrs, err := ifce.Addrs() + if err != nil { + return nil, err + } + if len(addrs) == 0 { + return nil, fmt.Errorf("addr not found for interface %s", ifceName) + } + ip = addrs[0].(*net.IPNet).IP + } + + switch network { + case "tcp", "tcp4", "tcp6": + return &net.TCPAddr{IP: ip}, nil + case "udp", "udp4", "udp6": + return &net.UDPAddr{IP: ip}, nil + default: + return &net.IPAddr{IP: ip}, nil + } +} diff --git a/pkg/dialer/obfs/http/dialer.go b/pkg/dialer/obfs/http/dialer.go index e8ad601..d2f4550 100644 --- a/pkg/dialer/obfs/http/dialer.go +++ b/pkg/dialer/obfs/http/dialer.go @@ -35,8 +35,16 @@ func (d *obfsHTTPDialer) Init(md md.Metadata) (err error) { } func (d *obfsHTTPDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { - var netd net.Dialer - conn, err := netd.DialContext(ctx, "tcp", addr) + options := &dialer.DialOptions{} + for _, opt := range opts { + opt(options) + } + + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + conn, err := netd.Dial(ctx, "tcp", addr) if err != nil { d.logger.Error(err) } diff --git a/pkg/dialer/obfs/tls/dialer.go b/pkg/dialer/obfs/tls/dialer.go index dd8b202..47c7207 100644 --- a/pkg/dialer/obfs/tls/dialer.go +++ b/pkg/dialer/obfs/tls/dialer.go @@ -35,8 +35,16 @@ func (d *obfsTLSDialer) Init(md md.Metadata) (err error) { } func (d *obfsTLSDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { - var netd net.Dialer - conn, err := netd.DialContext(ctx, "tcp", addr) + options := &dialer.DialOptions{} + for _, opt := range opts { + opt(options) + } + + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + conn, err := netd.Dial(ctx, "tcp", addr) if err != nil { d.logger.Error(err) } diff --git a/pkg/dialer/option.go b/pkg/dialer/option.go index a0131d3..3cdb880 100644 --- a/pkg/dialer/option.go +++ b/pkg/dialer/option.go @@ -1,9 +1,7 @@ package dialer import ( - "context" "crypto/tls" - "net" "net/url" "github.com/go-gost/gost/pkg/logger" @@ -36,8 +34,8 @@ func LoggerOption(logger logger.Logger) Option { } type DialOptions struct { - Host string - DialFunc func(ctx context.Context, addr string) (net.Conn, error) + Host string + NetDialer *NetDialer } type DialOption func(opts *DialOptions) @@ -48,9 +46,9 @@ func HostDialOption(host string) DialOption { } } -func DialFuncDialOption(dialf func(ctx context.Context, addr string) (net.Conn, error)) DialOption { +func NetDialerDialOption(netd *NetDialer) DialOption { return func(opts *DialOptions) { - opts.DialFunc = dialf + opts.NetDialer = netd } } diff --git a/pkg/dialer/quic/dialer.go b/pkg/dialer/quic/dialer.go index 3d5005f..d3cce09 100644 --- a/pkg/dialer/quic/dialer.go +++ b/pkg/dialer/quic/dialer.go @@ -54,25 +54,27 @@ func (d *quicDialer) Multiplex() bool { } func (d *quicDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) { - var options dialer.DialOptions - for _, opt := range opts { - opt(&options) - } - d.sessionMutex.Lock() defer d.sessionMutex.Unlock() session, ok := d.sessions[addr] if !ok { - var cc *net.UDPConn - cc, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - return + options := &dialer.DialOptions{} + for _, opt := range opts { + opt(options) + } + + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + conn, err = netd.Dial(ctx, "udp", "") + if err != nil { + return nil, err } - conn = cc if d.md.cipherKey != nil { - conn = quic_util.CipherConn(cc, d.md.cipherKey) + conn = quic_util.CipherConn(conn.(*net.UDPConn), d.md.cipherKey) } session = &quicSession{conn: conn} diff --git a/pkg/dialer/ssh/dialer.go b/pkg/dialer/ssh/dialer.go index 4646310..30c2a3a 100644 --- a/pkg/dialer/ssh/dialer.go +++ b/pkg/dialer/ssh/dialer.go @@ -52,11 +52,6 @@ func (d *sshDialer) Multiplex() bool { } func (d *sshDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) { - var options dialer.DialOptions - for _, opt := range opts { - opt(&options) - } - d.sessionMutex.Lock() defer d.sessionMutex.Unlock() @@ -66,7 +61,16 @@ func (d *sshDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp ok = false } if !ok { - conn, err = d.dial(ctx, "tcp", addr, &options) + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) + } + + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + conn, err = netd.Dial(ctx, "tcp", addr) if err != nil { return } @@ -134,34 +138,6 @@ func (d *sshDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia return ssh_util.NewConn(conn, channel), nil } -func (d *sshDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) { - dial := opts.DialFunc - if dial != nil { - conn, err := dial(ctx, addr) - if err != nil { - d.logger.Error(err) - } else { - d.logger.WithFields(map[string]any{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debug("dial with dial func") - } - return conn, err - } - - var netd net.Dialer - conn, err := netd.DialContext(ctx, network, addr) - if err != nil { - d.logger.Error(err) - } else { - d.logger.WithFields(map[string]any{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debugf("dial direct %s/%s", addr, network) - } - return conn, err -} - func (d *sshDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*sshSession, error) { config := ssh.ClientConfig{ Timeout: 30 * time.Second, diff --git a/pkg/dialer/sshd/dialer.go b/pkg/dialer/sshd/dialer.go index fdef35b..c45d438 100644 --- a/pkg/dialer/sshd/dialer.go +++ b/pkg/dialer/sshd/dialer.go @@ -51,11 +51,6 @@ func (d *sshdDialer) Multiplex() bool { } func (d *sshdDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) { - var options dialer.DialOptions - for _, opt := range opts { - opt(&options) - } - d.sessionMutex.Lock() defer d.sessionMutex.Unlock() @@ -65,7 +60,16 @@ func (d *sshdDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO ok = false } if !ok { - conn, err = d.dial(ctx, "tcp", addr, &options) + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) + } + + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + conn, err = netd.Dial(ctx, "tcp", addr) if err != nil { return } @@ -129,36 +133,6 @@ func (d *sshdDialer) Handshake(ctx context.Context, conn net.Conn, options ...di return ssh_util.NewClientConn(session.conn, session.client), nil } -func (d *sshdDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) { - log := d.options.Logger - - dial := opts.DialFunc - if dial != nil { - conn, err := dial(ctx, addr) - if err != nil { - log.Error(err) - } else { - log.WithFields(map[string]any{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debug("dial with dial func") - } - return conn, err - } - - var netd net.Dialer - conn, err := netd.DialContext(ctx, network, addr) - if err != nil { - log.Error(err) - } else { - log.WithFields(map[string]any{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debugf("dial direct %s/%s", addr, network) - } - return conn, err -} - func (d *sshdDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*sshSession, error) { config := ssh.ClientConfig{ // Timeout: timeout, diff --git a/pkg/dialer/tcp/dialer.go b/pkg/dialer/tcp/dialer.go index 27cdf2d..28be9e9 100644 --- a/pkg/dialer/tcp/dialer.go +++ b/pkg/dialer/tcp/dialer.go @@ -40,8 +40,11 @@ func (d *tcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp opt(&options) } - var netd net.Dialer - conn, err := netd.DialContext(ctx, "tcp", addr) + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + conn, err := netd.Dial(ctx, "tcp", addr) if err != nil { d.logger.Error(err) } diff --git a/pkg/dialer/tls/dialer.go b/pkg/dialer/tls/dialer.go index 2a1d48b..0e63a2c 100644 --- a/pkg/dialer/tls/dialer.go +++ b/pkg/dialer/tls/dialer.go @@ -44,8 +44,11 @@ func (d *tlsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp opt(&options) } - var netd net.Dialer - conn, err := netd.DialContext(ctx, "tcp", addr) + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + conn, err := netd.Dial(ctx, "tcp", addr) if err != nil { d.logger.Error(err) } diff --git a/pkg/dialer/tls/mux/dialer.go b/pkg/dialer/tls/mux/dialer.go index 0bdb410..9879740 100644 --- a/pkg/dialer/tls/mux/dialer.go +++ b/pkg/dialer/tls/mux/dialer.go @@ -54,11 +54,6 @@ func (d *mtlsDialer) Multiplex() bool { } func (d *mtlsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) { - var options dialer.DialOptions - for _, opt := range opts { - opt(&options) - } - d.sessionMutex.Lock() defer d.sessionMutex.Unlock() @@ -68,7 +63,16 @@ func (d *mtlsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO ok = false } if !ok { - conn, err = d.dial(ctx, "tcp", addr, &options) + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) + } + + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + conn, err = netd.Dial(ctx, "tcp", addr) if err != nil { return } @@ -122,34 +126,6 @@ func (d *mtlsDialer) Handshake(ctx context.Context, conn net.Conn, options ...di return cc, nil } -func (d *mtlsDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) { - dial := opts.DialFunc - if dial != nil { - conn, err := dial(ctx, addr) - if err != nil { - d.logger.Error(err) - } else { - d.logger.WithFields(map[string]any{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debug("dial with dial func") - } - return conn, err - } - - var netd net.Dialer - conn, err := netd.DialContext(ctx, network, addr) - if err != nil { - d.logger.Error(err) - } else { - d.logger.WithFields(map[string]any{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debugf("dial direct %s/%s", addr, network) - } - return conn, err -} - func (d *mtlsDialer) initSession(ctx context.Context, conn net.Conn) (*muxSession, error) { tlsConn := tls.Client(conn, d.options.TLSConfig) if err := tlsConn.HandshakeContext(ctx); err != nil { diff --git a/pkg/dialer/udp/dialer.go b/pkg/dialer/udp/dialer.go index 0f9e264..4c84896 100644 --- a/pkg/dialer/udp/dialer.go +++ b/pkg/dialer/udp/dialer.go @@ -35,16 +35,20 @@ func (d *udpDialer) Init(md md.Metadata) (err error) { } func (d *udpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { - taddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) } - c, err := net.DialUDP("udp", nil, taddr) + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + c, err := netd.Dial(ctx, "udp", addr) if err != nil { return nil, err } return &conn{ - UDPConn: c, + UDPConn: c.(*net.UDPConn), }, nil } diff --git a/pkg/dialer/ws/dialer.go b/pkg/dialer/ws/dialer.go index 7f77b74..ab6cce1 100644 --- a/pkg/dialer/ws/dialer.go +++ b/pkg/dialer/ws/dialer.go @@ -61,8 +61,11 @@ func (d *wsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOpt opt(&options) } - var netd net.Dialer - conn, err := netd.DialContext(ctx, "tcp", addr) + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + conn, err := netd.Dial(ctx, "tcp", addr) if err != nil { d.logger.Error(err) } diff --git a/pkg/dialer/ws/mux/dialer.go b/pkg/dialer/ws/mux/dialer.go index cc29cbb..e965e17 100644 --- a/pkg/dialer/ws/mux/dialer.go +++ b/pkg/dialer/ws/mux/dialer.go @@ -71,11 +71,6 @@ func (d *mwsDialer) Multiplex() bool { } func (d *mwsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) { - var options dialer.DialOptions - for _, opt := range opts { - opt(&options) - } - d.sessionMutex.Lock() defer d.sessionMutex.Unlock() @@ -85,7 +80,16 @@ func (d *mwsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp ok = false } if !ok { - conn, err = d.dial(ctx, "tcp", addr, &options) + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) + } + + netd := options.NetDialer + if netd == nil { + netd = dialer.DefaultNetDialer + } + conn, err = netd.Dial(ctx, "tcp", addr) if err != nil { return } @@ -143,34 +147,6 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia return cc, nil } -func (d *mwsDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) { - dial := opts.DialFunc - if dial != nil { - conn, err := dial(ctx, addr) - if err != nil { - d.logger.Error(err) - } else { - d.logger.WithFields(map[string]any{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debug("dial with dial func") - } - return conn, err - } - - var netd net.Dialer - conn, err := netd.DialContext(ctx, network, addr) - if err != nil { - d.logger.Error(err) - } else { - d.logger.WithFields(map[string]any{ - "src": conn.LocalAddr().String(), - "dst": addr, - }).Debugf("dial direct %s/%s", addr, network) - } - return conn, err -} - func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) (*muxSession, error) { dialer := websocket.Dialer{ HandshakeTimeout: d.md.handshakeTimeout, diff --git a/pkg/handler/dns/handler.go b/pkg/handler/dns/handler.go index 5856296..c474f9a 100644 --- a/pkg/handler/dns/handler.go +++ b/pkg/handler/dns/handler.go @@ -13,6 +13,7 @@ import ( "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/common/bufpool" "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/hosts" resolver_util "github.com/go-gost/gost/pkg/internal/util/resolver" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -33,6 +34,7 @@ type dnsHandler struct { exchangers []exchanger.Exchanger cache *resolver_util.Cache router *chain.Router + hosts hosts.HostMapper md metadata options handler.Options } @@ -55,13 +57,12 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { log := h.options.Logger h.cache = resolver_util.NewCache().WithLogger(log) - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - // Hosts: h.options.Hosts, - Logger: log, + + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(log) } + h.hosts = h.router.Hosts() for _, server := range h.md.dns { server = strings.TrimSpace(server) @@ -218,7 +219,7 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger // lookup host mapper func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) { - if h.options.Hosts == nil || + if h.hosts == nil || r.Question[0].Qclass != dns.ClassINET || (r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA) { return nil @@ -231,7 +232,7 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) { switch r.Question[0].Qtype { case dns.TypeA: - ips, _ := h.options.Hosts.Lookup("ip4", host) + ips, _ := h.hosts.Lookup("ip4", host) if len(ips) == 0 { return nil } @@ -247,7 +248,7 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) { } case dns.TypeAAAA: - ips, _ := h.options.Hosts.Lookup("ip6", host) + ips, _ := h.hosts.Lookup("ip6", host) if len(ips) == 0 { return nil } diff --git a/pkg/handler/forward/local/handler.go b/pkg/handler/forward/local/handler.go index eded2aa..edacab9 100644 --- a/pkg/handler/forward/local/handler.go +++ b/pkg/handler/forward/local/handler.go @@ -46,12 +46,9 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { h.group = chain.NewNodeGroup(&chain.Node{Name: "dummy", Addr: ":0"}) } - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) } return diff --git a/pkg/handler/forward/remote/handler.go b/pkg/handler/forward/remote/handler.go index cce2b95..fead5c8 100644 --- a/pkg/handler/forward/remote/handler.go +++ b/pkg/handler/forward/remote/handler.go @@ -40,12 +40,9 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { return } - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) } return diff --git a/pkg/handler/http/handler.go b/pkg/handler/http/handler.go index 9f37ddf..daf3528 100644 --- a/pkg/handler/http/handler.go +++ b/pkg/handler/http/handler.go @@ -49,12 +49,9 @@ func (h *httpHandler) Init(md md.Metadata) error { return err } - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) } return nil diff --git a/pkg/handler/http2/handler.go b/pkg/handler/http2/handler.go index 7b8be15..c6bc572 100644 --- a/pkg/handler/http2/handler.go +++ b/pkg/handler/http2/handler.go @@ -52,13 +52,11 @@ func (h *http2Handler) Init(md md.Metadata) error { return err } - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) } + return nil } diff --git a/pkg/handler/option.go b/pkg/handler/option.go index cedce4c..6581c80 100644 --- a/pkg/handler/option.go +++ b/pkg/handler/option.go @@ -7,17 +7,12 @@ import ( "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" - "github.com/go-gost/gost/pkg/hosts" "github.com/go-gost/gost/pkg/logger" - "github.com/go-gost/gost/pkg/resolver" ) type Options struct { - Retries int - Chain chain.Chainer - Resolver resolver.Resolver - Hosts hosts.HostMapper Bypass bypass.Bypass + Router *chain.Router Auth *url.Userinfo Auther auth.Authenticator TLSConfig *tls.Config @@ -26,41 +21,24 @@ type Options struct { type Option func(opts *Options) -func RetriesOption(retries int) Option { - return func(opts *Options) { - opts.Retries = retries - } -} - -func ChainOption(chain chain.Chainer) Option { - return func(opts *Options) { - opts.Chain = chain - } -} - -func ResolverOption(resolver resolver.Resolver) Option { - return func(opts *Options) { - opts.Resolver = resolver - } -} - -func HostsOption(hosts hosts.HostMapper) Option { - return func(opts *Options) { - opts.Hosts = hosts - } -} - func BypassOption(bypass bypass.Bypass) Option { return func(opts *Options) { opts.Bypass = bypass } } +func RouterOption(router *chain.Router) Option { + return func(opts *Options) { + opts.Router = router + } +} + func AuthOption(auth *url.Userinfo) Option { return func(opts *Options) { opts.Auth = auth } } + func AutherOption(auther auth.Authenticator) Option { return func(opts *Options) { opts.Auther = auther diff --git a/pkg/handler/redirect/handler.go b/pkg/handler/redirect/handler.go index 1646abe..f4fd7b1 100644 --- a/pkg/handler/redirect/handler.go +++ b/pkg/handler/redirect/handler.go @@ -41,12 +41,9 @@ func (h *redirectHandler) Init(md md.Metadata) (err error) { return } - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) } return diff --git a/pkg/handler/relay/handler.go b/pkg/handler/relay/handler.go index 4013a17..96a75f9 100644 --- a/pkg/handler/relay/handler.go +++ b/pkg/handler/relay/handler.go @@ -40,13 +40,11 @@ func (h *relayHandler) Init(md md.Metadata) (err error) { return err } - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) } + return nil } diff --git a/pkg/handler/sni/handler.go b/pkg/handler/sni/handler.go index ac67986..d40be0c 100644 --- a/pkg/handler/sni/handler.go +++ b/pkg/handler/sni/handler.go @@ -62,12 +62,9 @@ func (h *sniHandler) Init(md md.Metadata) (err error) { } } - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) } return nil diff --git a/pkg/handler/socks/v4/handler.go b/pkg/handler/socks/v4/handler.go index 5305368..b370c26 100644 --- a/pkg/handler/socks/v4/handler.go +++ b/pkg/handler/socks/v4/handler.go @@ -40,12 +40,9 @@ func (h *socks4Handler) Init(md md.Metadata) (err error) { return err } - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) } return nil diff --git a/pkg/handler/socks/v5/handler.go b/pkg/handler/socks/v5/handler.go index d948bdc..05918c8 100644 --- a/pkg/handler/socks/v5/handler.go +++ b/pkg/handler/socks/v5/handler.go @@ -41,12 +41,9 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) { return } - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) } h.selector = &serverSelector{ diff --git a/pkg/handler/ss/handler.go b/pkg/handler/ss/handler.go index a2c592e..5880f1e 100644 --- a/pkg/handler/ss/handler.go +++ b/pkg/handler/ss/handler.go @@ -51,12 +51,9 @@ func (h *ssHandler) Init(md md.Metadata) (err error) { } } - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) } return diff --git a/pkg/handler/ss/udp/handler.go b/pkg/handler/ss/udp/handler.go index eefe3b7..dcedacc 100644 --- a/pkg/handler/ss/udp/handler.go +++ b/pkg/handler/ss/udp/handler.go @@ -52,12 +52,9 @@ func (h *ssuHandler) Init(md md.Metadata) (err error) { } } - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) } return diff --git a/pkg/handler/sshd/handler.go b/pkg/handler/sshd/handler.go index d8f0e06..2f72db7 100644 --- a/pkg/handler/sshd/handler.go +++ b/pkg/handler/sshd/handler.go @@ -48,12 +48,9 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { return } - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) } return nil diff --git a/pkg/handler/tap/handler.go b/pkg/handler/tap/handler.go index e61a6a7..26ea1fd 100644 --- a/pkg/handler/tap/handler.go +++ b/pkg/handler/tap/handler.go @@ -63,12 +63,9 @@ func (h *tapHandler) Init(md md.Metadata) (err error) { } } - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) } return diff --git a/pkg/handler/tun/handler.go b/pkg/handler/tun/handler.go index ad0064b..3dac462 100644 --- a/pkg/handler/tun/handler.go +++ b/pkg/handler/tun/handler.go @@ -65,12 +65,9 @@ func (h *tunHandler) Init(md md.Metadata) (err error) { } } - h.router = &chain.Router{ - Retries: h.options.Retries, - Chain: h.options.Chain, - Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + h.router = h.options.Router + if h.router == nil { + h.router = (&chain.Router{}).WithLogger(h.options.Logger) } return diff --git a/pkg/listener/rtcp/listener.go b/pkg/listener/rtcp/listener.go index 81e6762..3def291 100644 --- a/pkg/listener/rtcp/listener.go +++ b/pkg/listener/rtcp/listener.go @@ -49,10 +49,9 @@ func (l *rtcpListener) Init(md md.Metadata) (err error) { } l.laddr = laddr - l.router = &chain.Router{ - Chain: l.options.Chain, - Logger: l.logger, - } + l.router = (&chain.Router{}). + WithChain(l.options.Chain). + WithLogger(l.logger) return } diff --git a/pkg/listener/rudp/listener.go b/pkg/listener/rudp/listener.go index 21cb24a..9f0f2a5 100644 --- a/pkg/listener/rudp/listener.go +++ b/pkg/listener/rudp/listener.go @@ -49,10 +49,9 @@ func (l *rudpListener) Init(md md.Metadata) (err error) { } l.laddr = laddr - l.router = &chain.Router{ - Chain: l.options.Chain, - Logger: l.logger, - } + l.router = (&chain.Router{}). + WithChain(l.options.Chain). + WithLogger(l.logger) return } diff --git a/pkg/resolver/exchanger/exchanger.go b/pkg/resolver/exchanger/exchanger.go index 0a5933b..f3121d4 100644 --- a/pkg/resolver/exchanger/exchanger.go +++ b/pkg/resolver/exchanger/exchanger.go @@ -102,9 +102,7 @@ func NewExchanger(addr string, opts ...Option) (Exchanger, error) { ex.addr = net.JoinHostPort(ex.addr, "53") } if ex.router == nil { - ex.router = &chain.Router{ - Logger: options.logger, - } + ex.router = (&chain.Router{}).WithLogger(options.logger) } switch ex.network { diff --git a/pkg/resolver/impl/resolver.go b/pkg/resolver/impl/resolver.go index 7e8096c..23f4633 100644 --- a/pkg/resolver/impl/resolver.go +++ b/pkg/resolver/impl/resolver.go @@ -64,10 +64,11 @@ func NewResolver(nameservers []NameServer, opts ...ResolverOption) (resolverpkg. } ex, err := exchanger.NewExchanger( addr, - exchanger.RouterOption(&chain.Router{ - Chain: server.Chain, - Logger: options.logger, - }), + exchanger.RouterOption( + (&chain.Router{}). + WithChain(server.Chain). + WithLogger(options.logger), + ), exchanger.TimeoutOption(server.Timeout), exchanger.LoggerOption(options.logger), )