From 41ff9835a66d71b49efcd2bc1d9a9e5ac39a07d9 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Tue, 20 Sep 2022 11:48:30 +0800 Subject: [PATCH] add chain hop --- chain/chain.go | 95 ------------------ chain/hop.go | 20 ++++ chain/node.go | 145 +++++++++++----------------- chain/resovle.go | 2 +- chain/route.go | 235 +++++---------------------------------------- chain/router.go | 201 ++++++++++++++++++++++---------------- chain/transport.go | 120 +++++++++++++---------- handler/handler.go | 2 +- 8 files changed, 293 insertions(+), 527 deletions(-) create mode 100644 chain/hop.go diff --git a/chain/chain.go b/chain/chain.go index 1013a84..68b88d2 100644 --- a/chain/chain.go +++ b/chain/chain.go @@ -2,103 +2,8 @@ package chain import ( "context" - - "github.com/go-gost/core/metadata" - "github.com/go-gost/core/selector" ) type Chainer interface { Route(ctx context.Context, network, address string) Route } - -type Chain struct { - name string - groups []*NodeGroup - marker selector.Marker - metadata metadata.Metadata -} - -func NewChain(name string, groups ...*NodeGroup) *Chain { - return &Chain{ - name: name, - groups: groups, - marker: selector.NewFailMarker(), - } -} - -func (c *Chain) AddNodeGroup(group *NodeGroup) { - c.groups = append(c.groups, group) -} - -func (c *Chain) WithMetadata(md metadata.Metadata) { - c.metadata = md -} - -// Metadata implements metadata.Metadatable interface. -func (c *Chain) Metadata() metadata.Metadata { - return c.metadata -} - -// Marker implements selector.Markable interface. -func (c *Chain) Marker() selector.Marker { - return c.marker -} - -func (c *Chain) Route(ctx context.Context, network, address string) Route { - if c == nil || len(c.groups) == 0 { - return nil - } - - rt := newRoute().WithChain(c) - for _, group := range c.groups { - // hop level bypass test - if group.bypass != nil && group.bypass.Contains(address) { - break - } - - node := group.FilterAddr(address).Next(ctx) - if node == nil { - return rt - } - if node.transport.Multiplex() { - tr := node.transport. - Copy(). - WithRoute(rt) - node = node.Copy() - node.transport = tr - rt = newRoute() - } - - rt.addNode(node) - } - return rt -} - -type ChainGroup struct { - chains []Chainer - selector selector.Selector[Chainer] -} - -func NewChainGroup(chains ...Chainer) *ChainGroup { - return &ChainGroup{chains: chains} -} - -func (p *ChainGroup) WithSelector(s selector.Selector[Chainer]) *ChainGroup { - p.selector = s - return p -} - -func (p *ChainGroup) Route(ctx context.Context, network, address string) Route { - if chain := p.next(ctx); chain != nil { - return chain.Route(ctx, network, address) - } - return nil -} - -func (p *ChainGroup) next(ctx context.Context) Chainer { - if p == nil || len(p.chains) == 0 { - return nil - } - - return p.selector.Select(ctx, p.chains...) -} diff --git a/chain/hop.go b/chain/hop.go new file mode 100644 index 0000000..dfd58f9 --- /dev/null +++ b/chain/hop.go @@ -0,0 +1,20 @@ +package chain + +import "context" + +type SelectOptions struct { + Addr string +} + +type SelectOption func(*SelectOptions) + +func AddrSelectOption(addr string) SelectOption { + return func(o *SelectOptions) { + o.Addr = addr + } +} + +type Hop interface { + Nodes() []*Node + Select(ctx context.Context, opts ...SelectOption) *Node +} diff --git a/chain/node.go b/chain/node.go index 9e136cc..a5ebcb8 100644 --- a/chain/node.go +++ b/chain/node.go @@ -1,8 +1,6 @@ package chain import ( - "context" - "github.com/go-gost/core/bypass" "github.com/go-gost/core/hosts" "github.com/go-gost/core/metadata" @@ -10,48 +8,76 @@ import ( "github.com/go-gost/core/selector" ) -type Node struct { - Name string - Addr string - transport *Transport - bypass bypass.Bypass - resolver resolver.Resolver - hostMapper hosts.HostMapper - marker selector.Marker - metadata metadata.Metadata +type NodeOptions struct { + Transport *Transport + Bypass bypass.Bypass + Resolver resolver.Resolver + HostMapper hosts.HostMapper + Metadata metadata.Metadata } -func NewNode(name, addr string) *Node { - return &Node{ - Name: name, - Addr: addr, - marker: selector.NewFailMarker(), +type NodeOption func(*NodeOptions) + +func TransportNodeOption(tr *Transport) NodeOption { + return func(o *NodeOptions) { + o.Transport = tr } } -func (node *Node) WithTransport(tr *Transport) *Node { - node.transport = tr - return node +func BypassNodeOption(bp bypass.Bypass) NodeOption { + return func(o *NodeOptions) { + o.Bypass = bp + } } -func (node *Node) WithBypass(bypass bypass.Bypass) *Node { - node.bypass = bypass - return node +func ResoloverNodeOption(resolver resolver.Resolver) NodeOption { + return func(o *NodeOptions) { + o.Resolver = resolver + } } -func (node *Node) WithResolver(reslv resolver.Resolver) *Node { - node.resolver = reslv - return node +func HostMapperNodeOption(m hosts.HostMapper) NodeOption { + return func(o *NodeOptions) { + o.HostMapper = m + } } -func (node *Node) WithHostMapper(m hosts.HostMapper) *Node { - node.hostMapper = m - return node +func MetadataNodeOption(md metadata.Metadata) NodeOption { + return func(o *NodeOptions) { + o.Metadata = md + } } -func (node *Node) WithMetadata(md metadata.Metadata) *Node { - node.metadata = md - return node +type Node struct { + Name string + Addr string + marker selector.Marker + options NodeOptions +} + +func NewNode(name string, addr string, opts ...NodeOption) *Node { + var options NodeOptions + for _, opt := range opts { + if opt != nil { + opt(&options) + } + } + + return &Node{ + Name: name, + Addr: addr, + marker: selector.NewFailMarker(), + options: options, + } +} + +func (node *Node) Options() *NodeOptions { + return &node.options +} + +// Metadata implements metadadta.Metadatable interface. +func (node *Node) Metadata() metadata.Metadata { + return node.options.Metadata } // Marker implements selector.Markable interface. @@ -59,65 +85,8 @@ func (node *Node) Marker() selector.Marker { return node.marker } -// Metadata implements metadadta.Metadatable interface. -func (node *Node) Metadata() metadata.Metadata { - return node.metadata -} - func (node *Node) Copy() *Node { n := &Node{} *n = *node return n } - -type NodeGroup struct { - nodes []*Node - selector selector.Selector[*Node] - bypass bypass.Bypass -} - -func NewNodeGroup(nodes ...*Node) *NodeGroup { - return &NodeGroup{ - nodes: nodes, - } -} - -func (g *NodeGroup) AddNode(node *Node) { - g.nodes = append(g.nodes, node) -} - -func (g *NodeGroup) Nodes() []*Node { - return g.nodes -} - -func (g *NodeGroup) WithSelector(selector selector.Selector[*Node]) *NodeGroup { - g.selector = selector - return g -} - -func (g *NodeGroup) WithBypass(bypass bypass.Bypass) *NodeGroup { - g.bypass = bypass - return g -} - -func (g *NodeGroup) FilterAddr(addr string) *NodeGroup { - var nodes []*Node - for _, node := range g.nodes { - if node.bypass == nil || !node.bypass.Contains(addr) { - nodes = append(nodes, node) - } - } - return &NodeGroup{ - nodes: nodes, - selector: g.selector, - bypass: g.bypass, - } -} - -func (g *NodeGroup) Next(ctx context.Context) *Node { - if g == nil || len(g.nodes) == 0 { - return nil - } - - return g.selector.Select(ctx, g.nodes...) -} diff --git a/chain/resovle.go b/chain/resovle.go index 60b8fd8..2d23e2f 100644 --- a/chain/resovle.go +++ b/chain/resovle.go @@ -10,7 +10,7 @@ import ( "github.com/go-gost/core/resolver" ) -func resolve(ctx context.Context, network, addr string, r resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) { +func Resolve(ctx context.Context, network, addr string, r resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) { if addr == "" { return addr, nil } diff --git a/chain/route.go b/chain/route.go index 7052ec0..c4a4425 100644 --- a/chain/route.go +++ b/chain/route.go @@ -9,240 +9,49 @@ import ( "github.com/go-gost/core/common/net/dialer" "github.com/go-gost/core/common/net/udp" - "github.com/go-gost/core/connector" "github.com/go-gost/core/logger" - "github.com/go-gost/core/metrics" ) var ( ErrEmptyRoute = errors.New("empty route") ) +var ( + DefaultRoute Route = &route{} +) + type Route interface { Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) - Len() int - Path() []*Node + Nodes() []*Node } -type route struct { - chain *Chain - nodes []*Node -} +type route struct{} -func newRoute() *route { - return &route{} -} - -func (r *route) addNode(node *Node) { - r.nodes = append(r.nodes, node) -} - -func (r *route) WithChain(chain *Chain) *route { - r.chain = chain - return r -} - -func (r *route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) { +func (*route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) { var options DialOptions for _, opt := range opts { opt(&options) } - if r.Len() == 0 { - netd := dialer.NetDialer{ - Timeout: options.Timeout, - Interface: options.Interface, - } - if options.SockOpts != nil { - netd.Mark = options.SockOpts.Mark - } - if r != nil { - netd.Logger = options.Logger - } - - return netd.Dial(ctx, network, address) + netd := dialer.NetDialer{ + Timeout: options.Timeout, + Interface: options.Interface, + Logger: options.Logger, + } + if options.SockOpts != nil { + netd.Mark = options.SockOpts.Mark } - conn, err := r.connect(ctx, options.Logger) - if err != nil { - return nil, err - } - - cc, err := r.GetNode(r.Len()-1).transport.Connect(ctx, conn, network, address) - if err != nil { - if conn != nil { - conn.Close() - } - return nil, err - } - return cc, nil + return netd.Dial(ctx, network, address) } -func (r *route) Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) { +func (*route) Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) { var options BindOptions for _, opt := range opts { opt(&options) } - if r.Len() == 0 { - return r.bindLocal(ctx, network, address, &options) - } - - conn, err := r.connect(ctx, options.Logger) - if err != nil { - return nil, err - } - - ln, err := r.GetNode(r.Len()-1).transport.Bind(ctx, - conn, network, address, - connector.BacklogBindOption(options.Backlog), - connector.MuxBindOption(options.Mux), - connector.UDPConnTTLBindOption(options.UDPConnTTL), - connector.UDPDataBufferSizeBindOption(options.UDPDataBufferSize), - connector.UDPDataQueueSizeBindOption(options.UDPDataQueueSize), - ) - if err != nil { - conn.Close() - return nil, err - } - - return ln, nil -} - -func (r *route) connect(ctx context.Context, logger logger.Logger) (conn net.Conn, err error) { - if r.Len() == 0 { - return nil, ErrEmptyRoute - } - - network := "ip" - node := r.nodes[0] - - defer func() { - if r.chain != nil { - marker := r.chain.Marker() - // chain error - if err != nil { - if marker != nil { - marker.Mark() - } - if v := metrics.GetCounter(metrics.MetricChainErrorsCounter, - metrics.Labels{"chain": r.chain.name, "node": node.Name}); v != nil { - v.Inc() - } - } else { - if marker != nil { - marker.Reset() - } - } - } - }() - - addr, err := resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, logger) - marker := node.Marker() - if err != nil { - if marker != nil { - marker.Mark() - } - return - } - - start := time.Now() - cc, err := node.transport.Dial(ctx, addr) - if err != nil { - if marker != nil { - marker.Mark() - } - return - } - - cn, err := node.transport.Handshake(ctx, cc) - if err != nil { - cc.Close() - if marker != nil { - marker.Mark() - } - return - } - if marker != nil { - marker.Reset() - } - - if r.chain != nil { - if v := metrics.GetObserver(metrics.MetricNodeConnectDurationObserver, - metrics.Labels{"chain": r.chain.name, "node": node.Name}); v != nil { - v.Observe(time.Since(start).Seconds()) - } - } - - preNode := node - for _, node := range r.nodes[1:] { - marker := node.Marker() - addr, err = resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, logger) - if err != nil { - cn.Close() - if marker != nil { - marker.Mark() - } - return - } - cc, err = preNode.transport.Connect(ctx, cn, "tcp", addr) - if err != nil { - cn.Close() - if marker != nil { - marker.Mark() - } - return - } - cc, err = node.transport.Handshake(ctx, cc) - if err != nil { - cn.Close() - if marker != nil { - marker.Mark() - } - return - } - if marker != nil { - marker.Reset() - } - - cn = cc - preNode = node - } - - conn = cn - return -} - -func (r *route) Len() int { - if r == nil { - return 0 - } - return len(r.nodes) -} - -func (r *route) GetNode(index int) *Node { - if r.Len() == 0 || index < 0 || index >= len(r.nodes) { - return nil - } - return r.nodes[index] -} - -func (r *route) Path() (path []*Node) { - if r == nil || len(r.nodes) == 0 { - return nil - } - - for _, node := range r.nodes { - if node.transport != nil && node.transport.route != nil { - path = append(path, node.transport.route.Path()...) - } - path = append(path, node) - } - return -} - -func (r *route) bindLocal(ctx context.Context, network, address string, opts *BindOptions) (net.Listener, error) { switch network { case "tcp", "tcp4", "tcp6": addr, err := net.ResolveTCPAddr(network, address) @@ -264,10 +73,10 @@ func (r *route) bindLocal(ctx context.Context, network, address string, opts *Bi "address": address, }) ln := udp.NewListener(conn, &udp.ListenConfig{ - Backlog: opts.Backlog, - ReadQueueSize: opts.UDPDataQueueSize, - ReadBufferSize: opts.UDPDataBufferSize, - TTL: opts.UDPConnTTL, + Backlog: options.Backlog, + ReadQueueSize: options.UDPDataQueueSize, + ReadBufferSize: options.UDPDataBufferSize, + TTL: options.UDPConnTTL, KeepAlive: true, Logger: logger, }) @@ -278,6 +87,10 @@ func (r *route) bindLocal(ctx context.Context, network, address string, opts *Bi } } +func (r *route) Nodes() []*Node { + return nil +} + type DialOptions struct { Timeout time.Duration Interface string diff --git a/chain/router.go b/chain/router.go index 93a80c2..b334060 100644 --- a/chain/router.go +++ b/chain/router.go @@ -17,68 +17,96 @@ type SockOpts struct { Mark int } -type Router struct { - ifceName string - sockOpts *SockOpts - timeout time.Duration - retries int - chain Chainer - resolver resolver.Resolver - hosts hosts.HostMapper - recorders []recorder.RecorderObject - logger logger.Logger +type RouterOptions struct { + IfceName string + SockOpts *SockOpts + Timeout time.Duration + Retries int + Chain Chainer + Resolver resolver.Resolver + HostMapper hosts.HostMapper + Recorders []recorder.RecorderObject + Logger logger.Logger } -func (r *Router) WithTimeout(timeout time.Duration) *Router { - r.timeout = timeout - return r -} +type RouterOption func(*RouterOptions) -func (r *Router) WithRetries(retries int) *Router { - r.retries = retries - return r -} - -func (r *Router) WithInterface(ifceName string) *Router { - r.ifceName = ifceName - return r -} - -func (r *Router) WithSockOpts(so *SockOpts) *Router { - r.sockOpts = so - return r -} - -func (r *Router) WithChain(chain Chainer) *Router { - r.chain = chain - return r -} - -func (r *Router) WithResolver(resolver resolver.Resolver) *Router { - r.resolver = resolver - return r -} - -func (r *Router) WithHosts(hosts hosts.HostMapper) *Router { - r.hosts = hosts - return r -} - -func (r *Router) Hosts() hosts.HostMapper { - if r != nil { - return r.hosts +func InterfaceRouterOption(ifceName string) RouterOption { + return func(o *RouterOptions) { + o.IfceName = ifceName } - return nil } -func (r *Router) WithRecorder(recorders ...recorder.RecorderObject) *Router { - r.recorders = recorders +func SockOptsRouterOption(so *SockOpts) RouterOption { + return func(o *RouterOptions) { + o.SockOpts = so + } +} + +func TimeoutRouterOption(timeout time.Duration) RouterOption { + return func(o *RouterOptions) { + o.Timeout = timeout + } +} + +func RetriesRouterOption(retries int) RouterOption { + return func(o *RouterOptions) { + o.Retries = retries + } +} + +func ChainRouterOption(chain Chainer) RouterOption { + return func(o *RouterOptions) { + o.Chain = chain + } +} + +func ResolverRouterOption(resolver resolver.Resolver) RouterOption { + return func(o *RouterOptions) { + o.Resolver = resolver + } +} + +func HostMapperRouterOption(m hosts.HostMapper) RouterOption { + return func(o *RouterOptions) { + o.HostMapper = m + } +} + +func RecordersRouterOption(recorders ...recorder.RecorderObject) RouterOption { + return func(o *RouterOptions) { + o.Recorders = recorders + } +} + +func LoggerRouterOption(logger logger.Logger) RouterOption { + return func(o *RouterOptions) { + o.Logger = logger + } +} + +type Router struct { + options RouterOptions +} + +func NewRouter(opts ...RouterOption) *Router { + r := &Router{} + for _, opt := range opts { + if opt != nil { + opt(&r.options) + } + } + if r.options.Logger == nil { + r.options.Logger = logger.Default().WithFields(map[string]any{"kind": "router"}) + } return r } -func (r *Router) WithLogger(logger logger.Logger) *Router { - r.logger = logger - return r +func (r *Router) Options() *RouterOptions { + if r == nil { + return nil + } + return &r.options } func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) { @@ -107,11 +135,11 @@ func (r *Router) record(ctx context.Context, name string, data []byte) error { return nil } - for _, rec := range r.recorders { + for _, rec := range r.options.Recorders { if rec.Record == name { err := rec.Recorder.Record(ctx, data) if err != nil { - r.logger.Errorf("record %s: %v", name, err) + r.options.Logger.Errorf("record %s: %v", name, err) } return err } @@ -120,90 +148,99 @@ func (r *Router) record(ctx context.Context, name string, data []byte) error { } func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { - count := r.retries + 1 + count := r.options.Retries + 1 if count <= 0 { count = 1 } - r.logger.Debugf("dial %s/%s", address, network) + r.options.Logger.Debugf("dial %s/%s", address, network) for i := 0; i < count; i++ { var route Route - if r.chain != nil { - route = r.chain.Route(ctx, network, address) + if r.options.Chain != nil { + route = r.options.Chain.Route(ctx, network, address) } - if r.logger.IsLevelEnabled(logger.DebugLevel) { + if r.options.Logger.IsLevelEnabled(logger.DebugLevel) { buf := bytes.Buffer{} - var path []*Node - if route != nil { - path = route.Path() - } - for _, node := range path { + for _, node := range routePath(route) { fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr) } fmt.Fprintf(&buf, "%s", address) - r.logger.Debugf("route(retry=%d) %s", i, buf.String()) + r.options.Logger.Debugf("route(retry=%d) %s", i, buf.String()) } - address, err = resolve(ctx, "ip", address, r.resolver, r.hosts, r.logger) + address, err = Resolve(ctx, "ip", address, r.options.Resolver, r.options.HostMapper, r.options.Logger) if err != nil { - r.logger.Error(err) + r.options.Logger.Error(err) break } if route == nil { - route = newRoute() + route = DefaultRoute } conn, err = route.Dial(ctx, network, address, - InterfaceDialOption(r.ifceName), - SockOptsDialOption(r.sockOpts), - LoggerDialOption(r.logger), + InterfaceDialOption(r.options.IfceName), + SockOptsDialOption(r.options.SockOpts), + LoggerDialOption(r.options.Logger), ) if err == nil { break } - r.logger.Errorf("route(retry=%d) %s", i, err) + r.options.Logger.Errorf("route(retry=%d) %s", i, err) } return } func (r *Router) Bind(ctx context.Context, network, address string, opts ...BindOption) (ln net.Listener, err error) { - count := r.retries + 1 + count := r.options.Retries + 1 if count <= 0 { count = 1 } - r.logger.Debugf("bind on %s/%s", address, network) + r.options.Logger.Debugf("bind on %s/%s", address, network) for i := 0; i < count; i++ { var route Route - if r.chain != nil { - route = r.chain.Route(ctx, network, address) - if route.Len() == 0 { + if r.options.Chain != nil { + route = r.options.Chain.Route(ctx, network, address) + if len(route.Nodes()) == 0 { err = ErrEmptyRoute return } } - if r.logger.IsLevelEnabled(logger.DebugLevel) { + if r.options.Logger.IsLevelEnabled(logger.DebugLevel) { buf := bytes.Buffer{} - for _, node := range route.Path() { + for _, node := range routePath(route) { fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr) } fmt.Fprintf(&buf, "%s", address) - r.logger.Debugf("route(retry=%d) %s", i, buf.String()) + r.options.Logger.Debugf("route(retry=%d) %s", i, buf.String()) } ln, err = route.Bind(ctx, network, address, opts...) if err == nil { break } - r.logger.Errorf("route(retry=%d) %s", i, err) + r.options.Logger.Errorf("route(retry=%d) %s", i, err) } return } +func routePath(route Route) (path []*Node) { + if route == nil { + return + } + for _, node := range route.Nodes() { + if tr := node.Options().Transport; tr != nil { + path = append(path, routePath(tr.Options().Route)...) + } + path = append(path, node) + } + return +} + type packetConn struct { net.Conn } diff --git a/chain/transport.go b/chain/transport.go index a469510..0b3c232 100644 --- a/chain/transport.go +++ b/chain/transport.go @@ -10,62 +10,81 @@ import ( "github.com/go-gost/core/dialer" ) +type TransportOptions struct { + Addr string + IfceName string + SockOpts *SockOpts + Route Route + Timeout time.Duration +} + +type TransportOption func(*TransportOptions) + +func AddrTransportOption(addr string) TransportOption { + return func(o *TransportOptions) { + o.Addr = addr + } +} + +func InterfaceTransportOption(ifceName string) TransportOption { + return func(o *TransportOptions) { + o.IfceName = ifceName + } +} + +func SockOptsTransportOption(so *SockOpts) TransportOption { + return func(o *TransportOptions) { + o.SockOpts = so + } +} + +func RouteTransportOption(route Route) TransportOption { + return func(o *TransportOptions) { + o.Route = route + } +} + +func TimeoutTransportOption(timeout time.Duration) TransportOption { + return func(o *TransportOptions) { + o.Timeout = timeout + } +} + type Transport struct { - addr string - ifceName string - sockOpts *SockOpts - route Route dialer dialer.Dialer connector connector.Connector - timeout time.Duration + options TransportOptions } -func (tr *Transport) Copy() *Transport { - tr2 := &Transport{} - *tr2 = *tr - return tr -} +func NewTransport(d dialer.Dialer, c connector.Connector, opts ...TransportOption) *Transport { + tr := &Transport{ + dialer: d, + connector: c, + } + for _, opt := range opts { + if opt != nil { + opt(&tr.options) + } + } -func (tr *Transport) WithInterface(ifceName string) *Transport { - tr.ifceName = ifceName - return tr -} - -func (tr *Transport) WithSockOpts(so *SockOpts) *Transport { - tr.sockOpts = so - return tr -} - -func (tr *Transport) WithDialer(dialer dialer.Dialer) *Transport { - tr.dialer = dialer - return tr -} - -func (tr *Transport) WithConnector(connector connector.Connector) *Transport { - tr.connector = connector - return tr -} - -func (tr *Transport) WithTimeout(d time.Duration) *Transport { - tr.timeout = d return tr } func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) { netd := &net_dialer.NetDialer{ - Interface: tr.ifceName, - Timeout: tr.timeout, + Interface: tr.options.IfceName, + Timeout: tr.options.Timeout, } - if tr.sockOpts != nil { - netd.Mark = tr.sockOpts.Mark + if tr.options.SockOpts != nil { + netd.Mark = tr.options.SockOpts.Mark } - if tr.route != nil && tr.route.Len() > 0 { + if tr.options.Route != nil && len(tr.options.Route.Nodes()) > 0 { netd.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { - return tr.route.Dial(ctx, network, addr) + return tr.options.Route.Dial(ctx, network, addr) } } opts := []dialer.DialOption{ - dialer.HostDialOption(tr.addr), + dialer.HostDialOption(tr.options.Addr), dialer.NetDialerDialOption(netd), } return tr.dialer.Dial(ctx, addr, opts...) @@ -75,7 +94,7 @@ func (tr *Transport) Handshake(ctx context.Context, conn net.Conn) (net.Conn, er var err error if hs, ok := tr.dialer.(dialer.Handshaker); ok { conn, err = hs.Handshake(ctx, conn, - dialer.AddrHandshakeOption(tr.addr)) + dialer.AddrHandshakeOption(tr.options.Addr)) if err != nil { return nil, err } @@ -88,11 +107,11 @@ func (tr *Transport) Handshake(ctx context.Context, conn net.Conn) (net.Conn, er func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) { netd := &net_dialer.NetDialer{ - Interface: tr.ifceName, - Timeout: tr.timeout, + Interface: tr.options.IfceName, + Timeout: tr.options.Timeout, } - if tr.sockOpts != nil { - netd.Mark = tr.sockOpts.Mark + if tr.options.SockOpts != nil { + netd.Mark = tr.options.SockOpts.Mark } return tr.connector.Connect(ctx, conn, network, address, connector.NetDialerConnectOption(netd), @@ -113,12 +132,15 @@ func (tr *Transport) Multiplex() bool { return false } -func (tr *Transport) WithRoute(r Route) *Transport { - tr.route = r - return tr +func (tr *Transport) Options() *TransportOptions { + if tr != nil { + return &tr.options + } + return nil } -func (tr *Transport) WithAddr(addr string) *Transport { - tr.addr = addr +func (tr *Transport) Copy() *Transport { + tr2 := &Transport{} + *tr2 = *tr return tr } diff --git a/handler/handler.go b/handler/handler.go index c1ce5f4..84d1a69 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -14,5 +14,5 @@ type Handler interface { } type Forwarder interface { - Forward(*chain.NodeGroup) + Forward(chain.Hop) }