From 7136710673690b0223f7cd76edff64598303d2c6 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Fri, 2 Sep 2022 17:23:28 +0800 Subject: [PATCH] add context for selector --- chain/chain.go | 18 ++++++++++-------- chain/node.go | 6 ++++-- chain/router.go | 10 +++++++--- selector/selector.go | 7 ++++--- 4 files changed, 25 insertions(+), 16 deletions(-) diff --git a/chain/chain.go b/chain/chain.go index 3b63d39..786408d 100644 --- a/chain/chain.go +++ b/chain/chain.go @@ -1,12 +1,14 @@ package chain import ( + "context" + "github.com/go-gost/core/metadata" "github.com/go-gost/core/selector" ) type Chainer interface { - Route(network, address string) Route + Route(ctx context.Context, network, address string) Route } type SelectableChainer interface { @@ -45,7 +47,7 @@ func (c *Chain) Marker() selector.Marker { return c.marker } -func (c *Chain) Route(network, address string) Route { +func (c *Chain) Route(ctx context.Context, network, address string) Route { if c == nil || len(c.groups) == 0 { return nil } @@ -57,7 +59,7 @@ func (c *Chain) Route(network, address string) Route { break } - node := group.FilterAddr(address).Next() + node := group.FilterAddr(address).Next(ctx) if node == nil { return rt } @@ -89,17 +91,17 @@ func (p *ChainGroup) WithSelector(s selector.Selector[SelectableChainer]) *Chain return p } -func (p *ChainGroup) Route(network, address string) Route { - if chain := p.next(); chain != nil { - return chain.Route(network, address) +func (p *ChainGroup) Route(ctx context.Context, network, address string) Route { + if chain := p.next(ctx); chain != nil { + return chain.Route(ctx, network, address) } return nil } -func (p *ChainGroup) next() Chainer { +func (p *ChainGroup) next(ctx context.Context) Chainer { if p == nil || len(p.chains) == 0 { return nil } - return p.selector.Select(p.chains...) + return p.selector.Select(ctx, p.chains...) } diff --git a/chain/node.go b/chain/node.go index 6e92765..336b7dd 100644 --- a/chain/node.go +++ b/chain/node.go @@ -1,6 +1,8 @@ package chain import ( + "context" + "github.com/go-gost/core/bypass" "github.com/go-gost/core/hosts" "github.com/go-gost/core/metadata" @@ -110,10 +112,10 @@ func (g *NodeGroup) FilterAddr(addr string) *NodeGroup { } } -func (g *NodeGroup) Next() *Node { +func (g *NodeGroup) Next(ctx context.Context) *Node { if g == nil || len(g.nodes) == 0 { return nil } - return g.selector.Select(g.nodes...) + return g.selector.Select(ctx, g.nodes...) } diff --git a/chain/router.go b/chain/router.go index 34d5027..93a80c2 100644 --- a/chain/router.go +++ b/chain/router.go @@ -129,12 +129,16 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co for i := 0; i < count; i++ { var route Route if r.chain != nil { - route = r.chain.Route(network, address) + route = r.chain.Route(ctx, network, address) } if r.logger.IsLevelEnabled(logger.DebugLevel) { buf := bytes.Buffer{} - for _, node := range route.Path() { + var path []*Node + if route != nil { + path = route.Path() + } + for _, node := range path { fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr) } fmt.Fprintf(&buf, "%s", address) @@ -174,7 +178,7 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...Bind for i := 0; i < count; i++ { var route Route if r.chain != nil { - route = r.chain.Route(network, address) + route = r.chain.Route(ctx, network, address) if route.Len() == 0 { err = ErrEmptyRoute return diff --git a/selector/selector.go b/selector/selector.go index cc1f2b3..0aa38ae 100644 --- a/selector/selector.go +++ b/selector/selector.go @@ -1,6 +1,7 @@ package selector import ( + "context" "sync/atomic" "time" @@ -13,15 +14,15 @@ type Selectable interface { } type Selector[T any] interface { - Select(...T) T + Select(context.Context, ...T) T } type Strategy[T Selectable] interface { - Apply(...T) T + Apply(context.Context, ...T) T } type Filter[T Selectable] interface { - Filter(...T) []T + Filter(context.Context, ...T) []T } type Marker interface {