add context for selector

This commit is contained in:
ginuerzh 2022-09-02 17:23:28 +08:00
parent 2835a5d44a
commit 7136710673
4 changed files with 25 additions and 16 deletions

View File

@ -1,12 +1,14 @@
package chain package chain
import ( import (
"context"
"github.com/go-gost/core/metadata" "github.com/go-gost/core/metadata"
"github.com/go-gost/core/selector" "github.com/go-gost/core/selector"
) )
type Chainer interface { type Chainer interface {
Route(network, address string) Route Route(ctx context.Context, network, address string) Route
} }
type SelectableChainer interface { type SelectableChainer interface {
@ -45,7 +47,7 @@ func (c *Chain) Marker() selector.Marker {
return c.marker return c.marker
} }
func (c *Chain) Route(network, address string) Route { func (c *Chain) Route(ctx context.Context, network, address string) Route {
if c == nil || len(c.groups) == 0 { if c == nil || len(c.groups) == 0 {
return nil return nil
} }
@ -57,7 +59,7 @@ func (c *Chain) Route(network, address string) Route {
break break
} }
node := group.FilterAddr(address).Next() node := group.FilterAddr(address).Next(ctx)
if node == nil { if node == nil {
return rt return rt
} }
@ -89,17 +91,17 @@ func (p *ChainGroup) WithSelector(s selector.Selector[SelectableChainer]) *Chain
return p return p
} }
func (p *ChainGroup) Route(network, address string) Route { func (p *ChainGroup) Route(ctx context.Context, network, address string) Route {
if chain := p.next(); chain != nil { if chain := p.next(ctx); chain != nil {
return chain.Route(network, address) return chain.Route(ctx, network, address)
} }
return nil return nil
} }
func (p *ChainGroup) next() Chainer { func (p *ChainGroup) next(ctx context.Context) Chainer {
if p == nil || len(p.chains) == 0 { if p == nil || len(p.chains) == 0 {
return nil return nil
} }
return p.selector.Select(p.chains...) return p.selector.Select(ctx, p.chains...)
} }

View File

@ -1,6 +1,8 @@
package chain package chain
import ( import (
"context"
"github.com/go-gost/core/bypass" "github.com/go-gost/core/bypass"
"github.com/go-gost/core/hosts" "github.com/go-gost/core/hosts"
"github.com/go-gost/core/metadata" "github.com/go-gost/core/metadata"
@ -110,10 +112,10 @@ func (g *NodeGroup) FilterAddr(addr string) *NodeGroup {
} }
} }
func (g *NodeGroup) Next() *Node { func (g *NodeGroup) Next(ctx context.Context) *Node {
if g == nil || len(g.nodes) == 0 { if g == nil || len(g.nodes) == 0 {
return nil return nil
} }
return g.selector.Select(g.nodes...) return g.selector.Select(ctx, g.nodes...)
} }

View File

@ -129,12 +129,16 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
var route Route var route Route
if r.chain != nil { if r.chain != nil {
route = r.chain.Route(network, address) route = r.chain.Route(ctx, network, address)
} }
if r.logger.IsLevelEnabled(logger.DebugLevel) { if r.logger.IsLevelEnabled(logger.DebugLevel) {
buf := bytes.Buffer{} buf := bytes.Buffer{}
for _, node := range route.Path() { var path []*Node
if route != nil {
path = route.Path()
}
for _, node := range path {
fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr) fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr)
} }
fmt.Fprintf(&buf, "%s", address) fmt.Fprintf(&buf, "%s", address)
@ -174,7 +178,7 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...Bind
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
var route Route var route Route
if r.chain != nil { if r.chain != nil {
route = r.chain.Route(network, address) route = r.chain.Route(ctx, network, address)
if route.Len() == 0 { if route.Len() == 0 {
err = ErrEmptyRoute err = ErrEmptyRoute
return return

View File

@ -1,6 +1,7 @@
package selector package selector
import ( import (
"context"
"sync/atomic" "sync/atomic"
"time" "time"
@ -13,15 +14,15 @@ type Selectable interface {
} }
type Selector[T any] interface { type Selector[T any] interface {
Select(...T) T Select(context.Context, ...T) T
} }
type Strategy[T Selectable] interface { type Strategy[T Selectable] interface {
Apply(...T) T Apply(context.Context, ...T) T
} }
type Filter[T Selectable] interface { type Filter[T Selectable] interface {
Filter(...T) []T Filter(context.Context, ...T) []T
} }
type Marker interface { type Marker interface {