From 1a1c038fd71f1c8bb852acb6e3d68321733d6f9d Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Tue, 20 Sep 2022 11:48:51 +0800 Subject: [PATCH] update chain route --- chain/chain.go | 107 ++++++++++++++ chain/hop.go | 87 +++++++++++ chain/route.go | 231 ++++++++++++++++++++++++++++++ config/config.go | 13 +- config/parsing/chain.go | 51 +++---- config/parsing/service.go | 39 ++--- handler/dns/handler.go | 39 +++-- handler/forward/local/handler.go | 17 ++- handler/forward/remote/handler.go | 13 +- handler/http/handler.go | 2 +- handler/http2/handler.go | 2 +- handler/redirect/tcp/handler.go | 2 +- handler/redirect/udp/handler.go | 2 +- handler/relay/forward.go | 2 +- handler/relay/handler.go | 10 +- handler/sni/handler.go | 2 +- handler/socks/v4/handler.go | 2 +- handler/socks/v5/handler.go | 2 +- handler/ss/handler.go | 2 +- handler/ss/udp/handler.go | 2 +- handler/sshd/handler.go | 2 +- handler/tap/handler.go | 13 +- handler/tun/handler.go | 13 +- listener/rtcp/listener.go | 7 +- listener/rudp/listener.go | 7 +- resolver/exchanger/exchanger.go | 2 +- resolver/resolver.go | 7 +- 27 files changed, 565 insertions(+), 113 deletions(-) create mode 100644 chain/chain.go create mode 100644 chain/hop.go create mode 100644 chain/route.go diff --git a/chain/chain.go b/chain/chain.go new file mode 100644 index 0000000..b572197 --- /dev/null +++ b/chain/chain.go @@ -0,0 +1,107 @@ +package chain + +import ( + "context" + + "github.com/go-gost/core/chain" + "github.com/go-gost/core/metadata" + "github.com/go-gost/core/selector" +) + +var ( + _ chain.Chainer = (*chainGroup)(nil) +) + +type chainNamer interface { + Name() string +} + +type Chain struct { + name string + hops []chain.Hop + marker selector.Marker + metadata metadata.Metadata +} + +func NewChain(name string, hops ...chain.Hop) *Chain { + return &Chain{ + name: name, + hops: hops, + marker: selector.NewFailMarker(), + } +} + +func (c *Chain) AddHop(hop chain.Hop) { + c.hops = append(c.hops, hop) +} + +func (c *Chain) WithMetadata(md metadata.Metadata) { + c.metadata = md +} + +// Metadata implements metadata.Metadatable interface. +func (c *Chain) Metadata() metadata.Metadata { + return c.metadata +} + +// Marker implements selector.Markable interface. +func (c *Chain) Marker() selector.Marker { + return c.marker +} + +func (c *Chain) Name() string { + return c.name +} + +func (c *Chain) Route(ctx context.Context, network, address string) chain.Route { + if c == nil || len(c.hops) == 0 { + return nil + } + + rt := NewRoute(ChainRouteOption(c)) + for _, hop := range c.hops { + node := hop.Select(ctx, chain.AddrSelectOption(address)) + if node == nil { + return rt + } + if node.Options().Transport.Multiplex() { + tr := node.Options().Transport.Copy() + tr.Options().Route = rt + node = node.Copy() + node.Options().Transport = tr + rt = NewRoute() + } + + rt.addNode(node) + } + return rt +} + +type chainGroup struct { + chains []chain.Chainer + selector selector.Selector[chain.Chainer] +} + +func NewChainGroup(chains ...chain.Chainer) *chainGroup { + return &chainGroup{chains: chains} +} + +func (p *chainGroup) WithSelector(s selector.Selector[chain.Chainer]) *chainGroup { + p.selector = s + return p +} + +func (p *chainGroup) Route(ctx context.Context, network, address string) chain.Route { + if chain := p.next(ctx); chain != nil { + return chain.Route(ctx, network, address) + } + return nil +} + +func (p *chainGroup) next(ctx context.Context) chain.Chainer { + if p == nil || len(p.chains) == 0 { + return nil + } + + return p.selector.Select(ctx, p.chains...) +} diff --git a/chain/hop.go b/chain/hop.go new file mode 100644 index 0000000..72816c0 --- /dev/null +++ b/chain/hop.go @@ -0,0 +1,87 @@ +package chain + +import ( + "context" + + "github.com/go-gost/core/bypass" + "github.com/go-gost/core/chain" + "github.com/go-gost/core/selector" +) + +type HopOptions struct { + bypass bypass.Bypass + selector selector.Selector[*chain.Node] +} + +type HopOption func(*HopOptions) + +func BypassHopOption(bp bypass.Bypass) HopOption { + return func(o *HopOptions) { + o.bypass = bp + } +} + +func SelectorHopOption(s selector.Selector[*chain.Node]) HopOption { + return func(o *HopOptions) { + o.selector = s + } +} + +type chainHop struct { + nodes []*chain.Node + options HopOptions +} + +func NewChainHop(nodes []*chain.Node, opts ...HopOption) chain.Hop { + var options HopOptions + for _, opt := range opts { + if opt != nil { + opt(&options) + } + } + + return &chainHop{ + nodes: nodes, + options: options, + } +} + +func (p *chainHop) Nodes() []*chain.Node { + return p.nodes +} + +func (p *chainHop) Select(ctx context.Context, opts ...chain.SelectOption) *chain.Node { + var options chain.SelectOptions + for _, opt := range opts { + opt(&options) + } + + if p == nil || len(p.nodes) == 0 { + return nil + } + + // hop level bypass + if p.options.bypass != nil && p.options.bypass.Contains(options.Addr) { + return nil + } + + var nodes []*chain.Node + for _, node := range p.nodes { + if node == nil { + continue + } + // node level bypass + if node.Options().Bypass != nil && node.Options().Bypass.Contains(options.Addr) { + continue + } + nodes = append(nodes, node) + } + if len(nodes) == 0 { + return nil + } + + if s := p.options.selector; s != nil { + return s.Select(ctx, nodes...) + } + return nodes[0] +} diff --git a/chain/route.go b/chain/route.go new file mode 100644 index 0000000..9d1f117 --- /dev/null +++ b/chain/route.go @@ -0,0 +1,231 @@ +package chain + +import ( + "context" + "net" + "time" + + "github.com/go-gost/core/chain" + "github.com/go-gost/core/connector" + "github.com/go-gost/core/logger" + "github.com/go-gost/core/metrics" + "github.com/go-gost/core/selector" +) + +type RouteOptions struct { + Chain chain.Chainer +} + +type RouteOption func(*RouteOptions) + +func ChainRouteOption(c chain.Chainer) RouteOption { + return func(o *RouteOptions) { + o.Chain = c + } +} + +type route struct { + nodes []*chain.Node + options RouteOptions +} + +func NewRoute(opts ...RouteOption) *route { + var options RouteOptions + for _, opt := range opts { + if opt != nil { + opt(&options) + } + } + + return &route{ + options: options, + } +} + +func (r *route) addNode(nodes ...*chain.Node) { + r.nodes = append(r.nodes, nodes...) +} + +func (r *route) Dial(ctx context.Context, network, address string, opts ...chain.DialOption) (net.Conn, error) { + if len(r.Nodes()) == 0 { + return chain.DefaultRoute.Dial(ctx, network, address, opts...) + } + + var options chain.DialOptions + for _, opt := range opts { + if opt != nil { + opt(&options) + } + } + conn, err := r.connect(ctx, options.Logger) + if err != nil { + return nil, err + } + + cc, err := r.getNode(len(r.Nodes())-1).Options().Transport.Connect(ctx, conn, network, address) + if err != nil { + if conn != nil { + conn.Close() + } + return nil, err + } + return cc, nil +} + +func (r *route) Bind(ctx context.Context, network, address string, opts ...chain.BindOption) (net.Listener, error) { + if len(r.Nodes()) == 0 { + return chain.DefaultRoute.Bind(ctx, network, address, opts...) + } + + var options chain.BindOptions + for _, opt := range opts { + if opt != nil { + opt(&options) + } + } + + conn, err := r.connect(ctx, options.Logger) + if err != nil { + return nil, err + } + + ln, err := r.getNode(len(r.Nodes())-1).Options().Transport.Bind(ctx, + conn, network, address, + connector.BacklogBindOption(options.Backlog), + connector.MuxBindOption(options.Mux), + connector.UDPConnTTLBindOption(options.UDPConnTTL), + connector.UDPDataBufferSizeBindOption(options.UDPDataBufferSize), + connector.UDPDataQueueSizeBindOption(options.UDPDataQueueSize), + ) + if err != nil { + conn.Close() + return nil, err + } + + return ln, nil +} + +func (r *route) connect(ctx context.Context, logger logger.Logger) (conn net.Conn, err error) { + network := "ip" + node := r.nodes[0] + + defer func() { + if r.options.Chain != nil { + var marker selector.Marker + if m, ok := r.options.Chain.(selector.Markable); ok && m != nil { + marker = m.Marker() + } + var name string + if cn, _ := r.options.Chain.(chainNamer); cn != nil { + name = cn.Name() + } + // chain error + if err != nil { + if marker != nil { + marker.Mark() + } + if v := metrics.GetCounter(metrics.MetricChainErrorsCounter, + metrics.Labels{"chain": name, "node": node.Name}); v != nil { + v.Inc() + } + } else { + if marker != nil { + marker.Reset() + } + } + } + }() + + addr, err := chain.Resolve(ctx, network, node.Addr, node.Options().Resolver, node.Options().HostMapper, logger) + marker := node.Marker() + if err != nil { + if marker != nil { + marker.Mark() + } + return + } + + start := time.Now() + cc, err := node.Options().Transport.Dial(ctx, addr) + if err != nil { + if marker != nil { + marker.Mark() + } + return + } + + cn, err := node.Options().Transport.Handshake(ctx, cc) + if err != nil { + cc.Close() + if marker != nil { + marker.Mark() + } + return + } + if marker != nil { + marker.Reset() + } + + if r.options.Chain != nil { + var name string + if cn, _ := r.options.Chain.(chainNamer); cn != nil { + name = cn.Name() + } + if v := metrics.GetObserver(metrics.MetricNodeConnectDurationObserver, + metrics.Labels{"chain": name, "node": node.Name}); v != nil { + v.Observe(time.Since(start).Seconds()) + } + } + + preNode := node + for _, node := range r.nodes[1:] { + marker := node.Marker() + addr, err = chain.Resolve(ctx, network, node.Addr, node.Options().Resolver, node.Options().HostMapper, logger) + if err != nil { + cn.Close() + if marker != nil { + marker.Mark() + } + return + } + cc, err = preNode.Options().Transport.Connect(ctx, cn, "tcp", addr) + if err != nil { + cn.Close() + if marker != nil { + marker.Mark() + } + return + } + cc, err = node.Options().Transport.Handshake(ctx, cc) + if err != nil { + cn.Close() + if marker != nil { + marker.Mark() + } + return + } + if marker != nil { + marker.Reset() + } + + cn = cc + preNode = node + } + + conn = cn + return +} + +func (r *route) getNode(index int) *chain.Node { + if r == nil || len(r.Nodes()) == 0 || index < 0 || index >= len(r.Nodes()) { + return nil + } + return r.nodes[index] +} + +func (r *route) Nodes() []*chain.Node { + if r != nil { + return r.nodes + } + return nil +} diff --git a/config/config.go b/config/config.go index 451d2b2..a63c353 100644 --- a/config/config.go +++ b/config/config.go @@ -217,9 +217,16 @@ type HandlerConfig struct { type ForwarderConfig struct { // DEPRECATED by nodes since beta.4 - Targets []string `yaml:",omitempty" json:"targets,omitempty"` - Nodes []*NodeConfig `json:"nodes"` - Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"` + Targets []string `yaml:",omitempty" json:"targets,omitempty"` + Nodes []*ForwardNodeConfig `json:"nodes"` + Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"` +} + +type ForwardNodeConfig struct { + Name string `yaml:",omitempty" json:"name,omitempty"` + Addr string `yaml:",omitempty" json:"addr,omitempty"` + Bypass string `yaml:",omitempty" json:"bypass,omitempty"` + Bypasses []string `yaml:",omitempty" json:"bypasses,omitempty"` } type DialerConfig struct { diff --git a/config/parsing/chain.go b/config/parsing/chain.go index 240735f..52e34c0 100644 --- a/config/parsing/chain.go +++ b/config/parsing/chain.go @@ -11,6 +11,7 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" + xchain "github.com/go-gost/x/chain" "github.com/go-gost/x/config" tls_util "github.com/go-gost/x/internal/util/tls" mdx "github.com/go-gost/x/metadata" @@ -27,14 +28,14 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { "chain": cfg.Name, }) - c := chain.NewChain(cfg.Name) + c := xchain.NewChain(cfg.Name) if cfg.Metadata != nil { c.WithMetadata(mdx.NewMetadata(cfg.Metadata)) } - selector := parseNodeSelector(cfg.Selector) + sel := parseNodeSelector(cfg.Selector) for _, hop := range cfg.Hops { - group := &chain.NodeGroup{} + var nodes []*chain.Node for _, v := range hop.Nodes { nodeLogger := chainLogger.WithFields(map[string]any{ "kind": "node", @@ -144,35 +145,35 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { } } - tr := (&chain.Transport{}). - WithConnector(cr). - WithDialer(d). - WithAddr(v.Addr). - WithInterface(v.Interface). - WithSockOpts(sockOpts). - WithTimeout(10 * time.Second) + tr := chain.NewTransport(d, cr, + chain.AddrTransportOption(v.Addr), + chain.InterfaceTransportOption(v.Interface), + chain.SockOptsTransportOption(sockOpts), + chain.TimeoutTransportOption(10*time.Second), + ) - node := chain.NewNode(v.Name, v.Addr). - WithTransport(tr). - WithBypass(bypass.BypassGroup(bypassList(v.Bypass, v.Bypasses...)...)). - WithResolver(registry.ResolverRegistry().Get(v.Resolver)). - WithHostMapper(registry.HostsRegistry().Get(v.Hosts)). - WithMetadata(nm) - - group.AddNode(node) + node := chain.NewNode(v.Name, v.Addr, + chain.TransportNodeOption(tr), + chain.BypassNodeOption(bypass.BypassGroup(bypassList(v.Bypass, v.Bypasses...)...)), + chain.ResoloverNodeOption(registry.ResolverRegistry().Get(v.Resolver)), + chain.HostMapperNodeOption(registry.HostsRegistry().Get(v.Hosts)), + chain.MetadataNodeOption(nm), + ) + nodes = append(nodes, node) } - sel := selector + sl := sel if s := parseNodeSelector(hop.Selector); s != nil { - sel = s + sl = s } - if sel == nil { - sel = defaultNodeSelector() + if sl == nil { + sl = defaultNodeSelector() } - group.WithSelector(sel). - WithBypass(bypass.BypassGroup(bypassList(hop.Bypass, hop.Bypasses...)...)) - c.AddNodeGroup(group) + c.AddHop(xchain.NewChainHop(nodes, + xchain.SelectorHopOption(sl), + xchain.BypassHopOption(bypass.BypassGroup(bypassList(hop.Bypass, hop.Bypasses...)...))), + ) } return c, nil diff --git a/config/parsing/service.go b/config/parsing/service.go index 3eb7b4f..2c0e46d 100644 --- a/config/parsing/service.go +++ b/config/parsing/service.go @@ -15,6 +15,7 @@ import ( "github.com/go-gost/core/recorder" "github.com/go-gost/core/selector" "github.com/go-gost/core/service" + xchain "github.com/go-gost/x/chain" "github.com/go-gost/x/config" tls_util "github.com/go-gost/x/internal/util/tls" "github.com/go-gost/x/metadata" @@ -156,16 +157,17 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { Record: r.Record, }) } - router := (&chain.Router{}). - WithRetries(cfg.Handler.Retries). - // WithTimeout(timeout time.Duration). - WithInterface(ifce). - WithSockOpts(sockOpts). - WithChain(chainGroup(cfg.Handler.Chain, cfg.Handler.ChainGroup)). - WithResolver(registry.ResolverRegistry().Get(cfg.Resolver)). - WithHosts(registry.HostsRegistry().Get(cfg.Hosts)). - WithRecorder(recorders...). - WithLogger(handlerLogger) + router := chain.NewRouter( + chain.RetriesRouterOption(cfg.Handler.Retries), + // chain.TimeoutRouterOption(10*time.Second), + chain.InterfaceRouterOption(ifce), + chain.SockOptsRouterOption(sockOpts), + chain.ChainRouterOption(chainGroup(cfg.Handler.Chain, cfg.Handler.ChainGroup)), + chain.ResolverRouterOption(registry.ResolverRegistry().Get(cfg.Resolver)), + chain.HostMapperRouterOption(registry.HostsRegistry().Get(cfg.Hosts)), + chain.RecordersRouterOption(recorders...), + chain.LoggerRouterOption(handlerLogger), + ) var h handler.Handler if rf := registry.HandlerRegistry().Get(cfg.Handler.Type); rf != nil { @@ -203,24 +205,27 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { return s, nil } -func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup { +func parseForwarder(cfg *config.ForwarderConfig) chain.Hop { if cfg == nil || (len(cfg.Targets) == 0 && len(cfg.Nodes) == 0) { return nil } - group := &chain.NodeGroup{} + var nodes []*chain.Node if len(cfg.Nodes) > 0 { for _, node := range cfg.Nodes { if node != nil { - group.AddNode(chain.NewNode(node.Name, node.Addr). - WithBypass(bypass.BypassGroup(bypassList(node.Bypass, node.Bypasses...)...))) + nodes = append(nodes, + chain.NewNode(node.Name, node.Addr, + chain.BypassNodeOption(bypass.BypassGroup(bypassList(node.Bypass, node.Bypasses...)...)), + ), + ) } } } else { for _, target := range cfg.Targets { if v := strings.TrimSpace(target); v != "" { - group.AddNode(chain.NewNode(target, target)) + nodes = append(nodes, chain.NewNode(target, target)) } } } @@ -229,7 +234,7 @@ func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup { if sel == nil { sel = defaultNodeSelector() } - return group.WithSelector(sel) + return xchain.NewChainHop(nodes, xchain.SelectorHopOption(sel)) } func bypassList(name string, names ...string) []bypass.Bypass { @@ -292,6 +297,6 @@ func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer { sel = defaultChainSelector() } - return chain.NewChainGroup(chains...). + return xchain.NewChainGroup(chains...). WithSelector(sel) } diff --git a/handler/dns/handler.go b/handler/dns/handler.go index d227e4c..d223355 100644 --- a/handler/dns/handler.go +++ b/handler/dns/handler.go @@ -14,6 +14,7 @@ import ( "github.com/go-gost/core/hosts" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" + xchain "github.com/go-gost/x/chain" resolver_util "github.com/go-gost/x/internal/util/resolver" "github.com/go-gost/x/registry" "github.com/go-gost/x/resolver/exchanger" @@ -29,11 +30,11 @@ func init() { } type dnsHandler struct { - group *chain.NodeGroup + hop chain.Hop exchangers map[string]exchanger.Exchanger cache *resolver_util.Cache router *chain.Router - hosts hosts.HostMapper + hostMapper hosts.HostMapper md metadata options handler.Options } @@ -60,21 +61,19 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(log) + h.router = chain.NewRouter(chain.LoggerRouterOption(log)) } - h.hosts = h.router.Hosts() + h.hostMapper = h.router.Options().HostMapper - if h.group == nil { - h.group = &chain.NodeGroup{} + if h.hop == nil { + var nodes []*chain.Node for i, addr := range h.md.dns { - addr = strings.TrimSpace(addr) - if addr == "" { - continue - } - h.group.AddNode(chain.NewNode(fmt.Sprintf("target-%d", i), addr)) + nodes = append(nodes, chain.NewNode(fmt.Sprintf("target-%d", i), addr)) } + h.hop = xchain.NewChainHop(nodes) } - for _, node := range h.group.Nodes() { + + for _, node := range h.hop.Nodes() { addr := strings.TrimSpace(node.Addr) if addr == "" { continue @@ -99,7 +98,7 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { exchanger.TimeoutOption(h.md.timeout), exchanger.LoggerOption(log), ) - log.Warnf("resolver not found, default to %s", defaultNameserver) + log.Warnf("resolver not found, use default %s", defaultNameserver) if err != nil { return err } @@ -110,8 +109,8 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { } // Forward implements handler.Forwarder. -func (h *dnsHandler) Forward(group *chain.NodeGroup) { - h.group = group +func (h *dnsHandler) Forward(hop chain.Hop) { + h.hop = hop } func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { @@ -261,7 +260,7 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger // lookup host mapper func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) { - if h.hosts == nil || + if h.hostMapper == nil || r.Question[0].Qclass != dns.ClassINET || (r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA) { return nil @@ -274,7 +273,7 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) { switch r.Question[0].Qtype { case dns.TypeA: - ips, _ := h.hosts.Lookup("ip4", host) + ips, _ := h.hostMapper.Lookup("ip4", host) if len(ips) == 0 { return nil } @@ -290,7 +289,7 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) { } case dns.TypeAAAA: - ips, _ := h.hosts.Lookup("ip6", host) + ips, _ := h.hostMapper.Lookup("ip6", host) if len(ips) == 0 { return nil } @@ -310,10 +309,10 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) { } func (h *dnsHandler) selectExchanger(ctx context.Context, addr string) exchanger.Exchanger { - if h.group == nil { + if h.hop == nil { return nil } - node := h.group.FilterAddr(addr).Next(ctx) + node := h.hop.Select(ctx, chain.AddrSelectOption(addr)) if node == nil { return nil } diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index 1c90b55..42f72dc 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -10,6 +10,7 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/handler" md "github.com/go-gost/core/metadata" + xchain "github.com/go-gost/x/chain" netpkg "github.com/go-gost/x/internal/net" "github.com/go-gost/x/registry" ) @@ -21,7 +22,7 @@ func init() { } type forwardHandler struct { - group *chain.NodeGroup + hop chain.Hop router *chain.Router md metadata options handler.Options @@ -43,22 +44,24 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { return } - if h.group == nil { + if h.hop == nil { // dummy node used by relay connector. - h.group = chain.NewNodeGroup(&chain.Node{Name: "dummy", Addr: ":0"}) + h.hop = xchain.NewChainHop([]*chain.Node{ + {Name: "dummy", Addr: ":0"}, + }) } h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } return } // Forward implements handler.Forwarder. -func (h *forwardHandler) Forward(group *chain.NodeGroup) { - h.group = group +func (h *forwardHandler) Forward(hop chain.Hop) { + h.hop = hop } func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { @@ -81,7 +84,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand return nil } - target := h.group.Next(ctx) + target := h.hop.Select(ctx) if target == nil { err := errors.New("target not available") log.Error(err) diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index 449b73a..ec082eb 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -20,7 +20,7 @@ func init() { } type forwardHandler struct { - group *chain.NodeGroup + hop chain.Hop router *chain.Router md metadata options handler.Options @@ -44,15 +44,15 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } return } // Forward implements handler.Forwarder. -func (h *forwardHandler) Forward(group *chain.NodeGroup) { - h.group = group +func (h *forwardHandler) Forward(hop chain.Hop) { + h.hop = hop } func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { @@ -75,7 +75,10 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand return nil } - target := h.group.Next(ctx) + var target *chain.Node + if h.hop != nil { + target = h.hop.Select(ctx) + } if target == nil { err := errors.New("target not available") log.Error(err) diff --git a/handler/http/handler.go b/handler/http/handler.go index 9241c82..64fc9a3 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -52,7 +52,7 @@ func (h *httpHandler) Init(md md.Metadata) error { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } return nil diff --git a/handler/http2/handler.go b/handler/http2/handler.go index e6f356c..6babdd4 100644 --- a/handler/http2/handler.go +++ b/handler/http2/handler.go @@ -54,7 +54,7 @@ func (h *http2Handler) Init(md md.Metadata) error { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } return nil diff --git a/handler/redirect/tcp/handler.go b/handler/redirect/tcp/handler.go index 55caa62..bad7049 100644 --- a/handler/redirect/tcp/handler.go +++ b/handler/redirect/tcp/handler.go @@ -52,7 +52,7 @@ func (h *redirectHandler) Init(md md.Metadata) (err error) { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } return diff --git a/handler/redirect/udp/handler.go b/handler/redirect/udp/handler.go index 0d13f6a..28e9911 100644 --- a/handler/redirect/udp/handler.go +++ b/handler/redirect/udp/handler.go @@ -41,7 +41,7 @@ func (h *redirectHandler) Init(md md.Metadata) (err error) { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } return diff --git a/handler/relay/forward.go b/handler/relay/forward.go index 4221a0d..62b1a24 100644 --- a/handler/relay/forward.go +++ b/handler/relay/forward.go @@ -17,7 +17,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network Version: relay.Version1, Status: relay.StatusOK, } - target := h.group.Next(ctx) + target := h.hop.Select(ctx) if target == nil { resp.Status = relay.StatusServiceUnavailable resp.WriteTo(conn) diff --git a/handler/relay/handler.go b/handler/relay/handler.go index e0c7298..5ef92af 100644 --- a/handler/relay/handler.go +++ b/handler/relay/handler.go @@ -24,7 +24,7 @@ func init() { } type relayHandler struct { - group *chain.NodeGroup + hop chain.Hop router *chain.Router md metadata options handler.Options @@ -48,15 +48,15 @@ func (h *relayHandler) Init(md md.Metadata) (err error) { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } return nil } // Forward implements handler.Forwarder. -func (h *relayHandler) Forward(group *chain.NodeGroup) { - h.group = group +func (h *relayHandler) Forward(hop chain.Hop) { + h.hop = hop } func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { @@ -130,7 +130,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle network = "udp" } - if h.group != nil { + if h.hop != nil { if address != "" { resp.Status = relay.StatusForbidden log.Error("forward mode, connect is forbidden") diff --git a/handler/sni/handler.go b/handler/sni/handler.go index 01bd14d..74b91ed 100644 --- a/handler/sni/handler.go +++ b/handler/sni/handler.go @@ -54,7 +54,7 @@ func (h *sniHandler) Init(md md.Metadata) (err error) { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } return nil diff --git a/handler/socks/v4/handler.go b/handler/socks/v4/handler.go index e7e60b3..b7b6050 100644 --- a/handler/socks/v4/handler.go +++ b/handler/socks/v4/handler.go @@ -49,7 +49,7 @@ func (h *socks4Handler) Init(md md.Metadata) (err error) { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } return nil diff --git a/handler/socks/v5/handler.go b/handler/socks/v5/handler.go index 65266d3..a6597c8 100644 --- a/handler/socks/v5/handler.go +++ b/handler/socks/v5/handler.go @@ -48,7 +48,7 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } h.selector = &serverSelector{ diff --git a/handler/ss/handler.go b/handler/ss/handler.go index 2825c04..fd206bc 100644 --- a/handler/ss/handler.go +++ b/handler/ss/handler.go @@ -54,7 +54,7 @@ func (h *ssHandler) Init(md md.Metadata) (err error) { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } return diff --git a/handler/ss/udp/handler.go b/handler/ss/udp/handler.go index 24d451b..f043dea 100644 --- a/handler/ss/udp/handler.go +++ b/handler/ss/udp/handler.go @@ -55,7 +55,7 @@ func (h *ssuHandler) Init(md md.Metadata) (err error) { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } return diff --git a/handler/sshd/handler.go b/handler/sshd/handler.go index 813699c..49ecf31 100644 --- a/handler/sshd/handler.go +++ b/handler/sshd/handler.go @@ -52,7 +52,7 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } return nil diff --git a/handler/tap/handler.go b/handler/tap/handler.go index ef3234c..ea1dac3 100644 --- a/handler/tap/handler.go +++ b/handler/tap/handler.go @@ -28,7 +28,7 @@ func init() { } type tapHandler struct { - group *chain.NodeGroup + hop chain.Hop routes sync.Map exit chan struct{} cipher core.Cipher @@ -65,15 +65,15 @@ func (h *tapHandler) Init(md md.Metadata) (err error) { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } return } // Forward implements handler.Forwarder. -func (h *tapHandler) Forward(group *chain.NodeGroup) { - h.group = group +func (h *tapHandler) Forward(hop chain.Hop) { + h.hop = hop } func (h *tapHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { @@ -105,7 +105,10 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. var raddr net.Addr var err error - target := h.group.Next(ctx) + var target *chain.Node + if h.hop != nil { + target = h.hop.Select(ctx) + } if target != nil { raddr, err = net.ResolveUDPAddr(network, target.Addr) if err != nil { diff --git a/handler/tun/handler.go b/handler/tun/handler.go index 794cf82..77e971f 100644 --- a/handler/tun/handler.go +++ b/handler/tun/handler.go @@ -21,7 +21,7 @@ func init() { } type tunHandler struct { - group *chain.NodeGroup + hop chain.Hop routes sync.Map router *chain.Router md metadata @@ -46,15 +46,15 @@ func (h *tunHandler) Init(md md.Metadata) (err error) { h.router = h.options.Router if h.router == nil { - h.router = (&chain.Router{}).WithLogger(h.options.Logger) + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } return } // Forward implements handler.Forwarder. -func (h *tunHandler) Forward(group *chain.NodeGroup) { - h.group = group +func (h *tunHandler) Forward(hop chain.Hop) { + h.hop = hop } func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { @@ -87,7 +87,10 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. var raddr net.Addr var err error - target := h.group.Next(ctx) + var target *chain.Node + if h.hop != nil { + target = h.hop.Select(ctx) + } if target != nil { raddr, err = net.ResolveUDPAddr(network, target.Addr) if err != nil { diff --git a/listener/rtcp/listener.go b/listener/rtcp/listener.go index e614f7b..ef1d187 100644 --- a/listener/rtcp/listener.go +++ b/listener/rtcp/listener.go @@ -56,9 +56,10 @@ func (l *rtcpListener) Init(md md.Metadata) (err error) { } l.laddr = laddr - l.router = (&chain.Router{}). - WithChain(l.options.Chain). - WithLogger(l.logger) + l.router = chain.NewRouter( + chain.ChainRouterOption(l.options.Chain), + chain.LoggerRouterOption(l.logger), + ) return } diff --git a/listener/rudp/listener.go b/listener/rudp/listener.go index 143b90a..e1943d8 100644 --- a/listener/rudp/listener.go +++ b/listener/rudp/listener.go @@ -56,9 +56,10 @@ func (l *rudpListener) Init(md md.Metadata) (err error) { } l.laddr = laddr - l.router = (&chain.Router{}). - WithChain(l.options.Chain). - WithLogger(l.logger) + l.router = chain.NewRouter( + chain.ChainRouterOption(l.options.Chain), + chain.LoggerRouterOption(l.logger), + ) return } diff --git a/resolver/exchanger/exchanger.go b/resolver/exchanger/exchanger.go index b109f95..89ad775 100644 --- a/resolver/exchanger/exchanger.go +++ b/resolver/exchanger/exchanger.go @@ -102,7 +102,7 @@ func NewExchanger(addr string, opts ...Option) (Exchanger, error) { ex.addr = net.JoinHostPort(ex.addr, "53") } if ex.router == nil { - ex.router = (&chain.Router{}).WithLogger(options.logger) + ex.router = chain.NewRouter(chain.LoggerRouterOption(options.logger)) } switch ex.network { diff --git a/resolver/resolver.go b/resolver/resolver.go index d317026..536fa58 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -65,9 +65,10 @@ func NewResolver(nameservers []NameServer, opts ...ResolverOption) (resolverpkg. ex, err := exchanger.NewExchanger( addr, exchanger.RouterOption( - (&chain.Router{}). - WithChain(server.Chain). - WithLogger(options.logger), + chain.NewRouter( + chain.ChainRouterOption(server.Chain), + chain.LoggerRouterOption(options.logger), + ), ), exchanger.TimeoutOption(server.Timeout), exchanger.LoggerOption(options.logger),