diff --git a/chain/chain.go b/chain/chain.go index 8230eaf..3b63d39 100644 --- a/chain/chain.go +++ b/chain/chain.go @@ -2,21 +2,22 @@ package chain import ( "github.com/go-gost/core/metadata" + "github.com/go-gost/core/selector" ) type Chainer interface { - Route(network, address string) *Route + Route(network, address string) Route } type SelectableChainer interface { Chainer - Selectable + selector.Selectable } type Chain struct { name string groups []*NodeGroup - marker Marker + marker selector.Marker metadata metadata.Metadata } @@ -24,7 +25,7 @@ func NewChain(name string, groups ...*NodeGroup) *Chain { return &Chain{ name: name, groups: groups, - marker: NewFailMarker(), + marker: selector.NewFailMarker(), } } @@ -40,18 +41,16 @@ func (c *Chain) Metadata() metadata.Metadata { return c.metadata } -func (c *Chain) Marker() Marker { +func (c *Chain) Marker() selector.Marker { return c.marker } -func (c *Chain) Route(network, address string) (r *Route) { +func (c *Chain) Route(network, address string) Route { if c == nil || len(c.groups) == 0 { - return + return nil } - r = &Route{ - chain: c, - } + rt := newRoute().WithChain(c) for _, group := range c.groups { // hop level bypass test if group.bypass != nil && group.bypass.Contains(address) { @@ -60,27 +59,37 @@ func (c *Chain) Route(network, address string) (r *Route) { node := group.FilterAddr(address).Next() if node == nil { - return + return rt } if node.transport.Multiplex() { - tr := node.transport.Copy(). - WithRoute(r) + tr := node.transport. + Copy(). + WithRoute(rt) node = node.Copy() node.transport = tr - r = &Route{} + rt = newRoute() } - r.addNode(node) + rt.addNode(node) } - return r + return rt } type ChainGroup struct { - Chains []SelectableChainer - Selector Selector[SelectableChainer] + chains []SelectableChainer + selector selector.Selector[SelectableChainer] } -func (p *ChainGroup) Route(network, address string) *Route { +func NewChainGroup(chains ...SelectableChainer) *ChainGroup { + return &ChainGroup{chains: chains} +} + +func (p *ChainGroup) WithSelector(s selector.Selector[SelectableChainer]) *ChainGroup { + p.selector = s + return p +} + +func (p *ChainGroup) Route(network, address string) Route { if chain := p.next(); chain != nil { return chain.Route(network, address) } @@ -88,13 +97,9 @@ func (p *ChainGroup) Route(network, address string) *Route { } func (p *ChainGroup) next() Chainer { - if p == nil || len(p.Chains) == 0 { + if p == nil || len(p.chains) == 0 { return nil } - s := p.Selector - if s == nil { - s = DefaultChainSelector - } - return s.Select(p.Chains...) + return p.selector.Select(p.chains...) } diff --git a/chain/node.go b/chain/node.go index b8fbc93..6e92765 100644 --- a/chain/node.go +++ b/chain/node.go @@ -5,6 +5,7 @@ import ( "github.com/go-gost/core/hosts" "github.com/go-gost/core/metadata" "github.com/go-gost/core/resolver" + "github.com/go-gost/core/selector" ) type Node struct { @@ -14,7 +15,7 @@ type Node struct { bypass bypass.Bypass resolver resolver.Resolver hostMapper hosts.HostMapper - marker Marker + marker selector.Marker metadata metadata.Metadata } @@ -22,7 +23,7 @@ func NewNode(name, addr string) *Node { return &Node{ Name: name, Addr: addr, - marker: NewFailMarker(), + marker: selector.NewFailMarker(), } } @@ -51,7 +52,7 @@ func (node *Node) WithMetadata(md metadata.Metadata) *Node { return node } -func (node *Node) Marker() Marker { +func (node *Node) Marker() selector.Marker { return node.marker } @@ -67,7 +68,7 @@ func (node *Node) Copy() *Node { type NodeGroup struct { nodes []*Node - selector Selector[*Node] + selector selector.Selector[*Node] bypass bypass.Bypass } @@ -85,7 +86,7 @@ func (g *NodeGroup) Nodes() []*Node { return g.nodes } -func (g *NodeGroup) WithSelector(selector Selector[*Node]) *NodeGroup { +func (g *NodeGroup) WithSelector(selector selector.Selector[*Node]) *NodeGroup { g.selector = selector return g } @@ -114,10 +115,5 @@ func (g *NodeGroup) Next() *Node { return nil } - s := g.selector - if s == nil { - s = DefaultNodeSelector - } - - return s.Select(g.nodes...) + return g.selector.Select(g.nodes...) } diff --git a/chain/route.go b/chain/route.go index 4816a3a..73dbcd4 100644 --- a/chain/route.go +++ b/chain/route.go @@ -18,17 +18,32 @@ var ( ErrEmptyRoute = errors.New("empty route") ) -type Route struct { - chain *Chain - nodes []*Node - logger logger.Logger +type Route interface { + Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) + Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) + Len() int + Path() []*Node } -func (r *Route) addNode(node *Node) { +type route struct { + chain *Chain + nodes []*Node +} + +func newRoute() *route { + return &route{} +} + +func (r *route) addNode(node *Node) { r.nodes = append(r.nodes, node) } -func (r *Route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) { +func (r *route) WithChain(chain *Chain) *route { + r.chain = chain + return r +} + +func (r *route) Dial(ctx context.Context, network, address string, opts ...DialOption) (net.Conn, error) { var options DialOptions for _, opt := range opts { opt(&options) @@ -43,13 +58,13 @@ func (r *Route) Dial(ctx context.Context, network, address string, opts ...DialO netd.Mark = options.SockOpts.Mark } if r != nil { - netd.Logger = r.logger + netd.Logger = options.Logger } return netd.Dial(ctx, network, address) } - conn, err := r.connect(ctx) + conn, err := r.connect(ctx, options.Logger) if err != nil { return nil, err } @@ -62,17 +77,29 @@ func (r *Route) Dial(ctx context.Context, network, address string, opts ...DialO return cc, nil } -func (r *Route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { - if r.Len() == 0 { - return r.bindLocal(ctx, network, address, opts...) +func (r *route) Bind(ctx context.Context, network, address string, opts ...BindOption) (net.Listener, error) { + var options BindOptions + for _, opt := range opts { + opt(&options) } - conn, err := r.connect(ctx) + if r.Len() == 0 { + return r.bindLocal(ctx, network, address, &options) + } + + conn, err := r.connect(ctx, options.Logger) if err != nil { return nil, err } - ln, err := r.GetNode(r.Len()-1).transport.Bind(ctx, conn, network, address, opts...) + ln, err := r.GetNode(r.Len()-1).transport.Bind(ctx, + conn, network, address, + connector.BacklogBindOption(options.Backlog), + connector.MuxBindOption(options.Mux), + connector.UDPConnTTLBindOption(options.UDPConnTTL), + connector.UDPDataBufferSizeBindOption(options.UDPDataBufferSize), + connector.UDPDataQueueSizeBindOption(options.UDPDataQueueSize), + ) if err != nil { conn.Close() return nil, err @@ -81,7 +108,7 @@ func (r *Route) Bind(ctx context.Context, network, address string, opts ...conne return ln, nil } -func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) { +func (r *route) connect(ctx context.Context, logger logger.Logger) (conn net.Conn, err error) { if r.Len() == 0 { return nil, ErrEmptyRoute } @@ -109,7 +136,7 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) { } }() - addr, err := resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, r.logger) + addr, err := resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, logger) marker := node.Marker() if err != nil { if marker != nil { @@ -149,7 +176,7 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) { preNode := node for _, node := range r.nodes[1:] { marker := node.Marker() - addr, err = resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, r.logger) + addr, err = resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, logger) if err != nil { cn.Close() if marker != nil { @@ -185,21 +212,21 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) { return } -func (r *Route) Len() int { +func (r *route) Len() int { if r == nil { return 0 } return len(r.nodes) } -func (r *Route) GetNode(index int) *Node { +func (r *route) GetNode(index int) *Node { if r.Len() == 0 || index < 0 || index >= len(r.nodes) { return nil } return r.nodes[index] } -func (r *Route) Path() (path []*Node) { +func (r *route) Path() (path []*Node) { if r == nil || len(r.nodes) == 0 { return nil } @@ -213,12 +240,7 @@ func (r *Route) Path() (path []*Node) { return } -func (r *Route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { - options := connector.BindOptions{} - for _, opt := range opts { - opt(&options) - } - +func (r *route) bindLocal(ctx context.Context, network, address string, opts *BindOptions) (net.Listener, error) { switch network { case "tcp", "tcp4", "tcp6": addr, err := net.ResolveTCPAddr(network, address) @@ -240,10 +262,10 @@ func (r *Route) bindLocal(ctx context.Context, network, address string, opts ... "address": address, }) ln := udp.NewListener(conn, &udp.ListenConfig{ - Backlog: options.Backlog, - ReadQueueSize: options.UDPDataQueueSize, - ReadBufferSize: options.UDPDataBufferSize, - TTL: options.UDPConnTTL, + Backlog: opts.Backlog, + ReadQueueSize: opts.UDPDataQueueSize, + ReadBufferSize: opts.UDPDataBufferSize, + TTL: opts.UDPConnTTL, KeepAlive: true, Logger: logger, }) @@ -258,6 +280,7 @@ type DialOptions struct { Timeout time.Duration Interface string SockOpts *SockOpts + Logger logger.Logger } type DialOption func(opts *DialOptions) @@ -279,3 +302,56 @@ func SockOptsDialOption(so *SockOpts) DialOption { opts.SockOpts = so } } + +func LoggerDialOption(logger logger.Logger) DialOption { + return func(opts *DialOptions) { + opts.Logger = logger + } +} + +type BindOptions struct { + Mux bool + Backlog int + UDPDataQueueSize int + UDPDataBufferSize int + UDPConnTTL time.Duration + Logger logger.Logger +} + +type BindOption func(opts *BindOptions) + +func MuxBindOption(mux bool) BindOption { + return func(opts *BindOptions) { + opts.Mux = mux + } +} + +func BacklogBindOption(backlog int) BindOption { + return func(opts *BindOptions) { + opts.Backlog = backlog + } +} + +func UDPDataQueueSizeBindOption(size int) BindOption { + return func(opts *BindOptions) { + opts.UDPDataQueueSize = size + } +} + +func UDPDataBufferSizeBindOption(size int) BindOption { + return func(opts *BindOptions) { + opts.UDPDataBufferSize = size + } +} + +func UDPConnTTLBindOption(ttl time.Duration) BindOption { + return func(opts *BindOptions) { + opts.UDPConnTTL = ttl + } +} + +func LoggerBindOption(logger logger.Logger) BindOption { + return func(opts *BindOptions) { + opts.Logger = logger + } +} diff --git a/chain/router.go b/chain/router.go index ba1c81c..34d5027 100644 --- a/chain/router.go +++ b/chain/router.go @@ -7,7 +7,6 @@ import ( "net" "time" - "github.com/go-gost/core/connector" "github.com/go-gost/core/hosts" "github.com/go-gost/core/logger" "github.com/go-gost/core/recorder" @@ -128,7 +127,7 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co r.logger.Debugf("dial %s/%s", address, network) for i := 0; i < count; i++ { - var route *Route + var route Route if r.chain != nil { route = r.chain.Route(network, address) } @@ -149,12 +148,12 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co } if route == nil { - route = &Route{} + route = newRoute() } - route.logger = r.logger conn, err = route.Dial(ctx, network, address, InterfaceDialOption(r.ifceName), SockOptsDialOption(r.sockOpts), + LoggerDialOption(r.logger), ) if err == nil { break @@ -165,7 +164,7 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co return } -func (r *Router) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (ln net.Listener, err error) { +func (r *Router) Bind(ctx context.Context, network, address string, opts ...BindOption) (ln net.Listener, err error) { count := r.retries + 1 if count <= 0 { count = 1 @@ -173,7 +172,7 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...conn r.logger.Debugf("bind on %s/%s", address, network) for i := 0; i < count; i++ { - var route *Route + var route Route if r.chain != nil { route = r.chain.Route(network, address) if route.Len() == 0 { diff --git a/chain/selector.go b/chain/selector.go deleted file mode 100644 index 6021c11..0000000 --- a/chain/selector.go +++ /dev/null @@ -1,245 +0,0 @@ -package chain - -import ( - "math/rand" - "sync" - "sync/atomic" - "time" - - "github.com/go-gost/core/metadata" - mdutil "github.com/go-gost/core/metadata/util" -) - -// default options for FailFilter -const ( - DefaultFailTimeout = 30 * time.Second -) - -var ( - DefaultNodeSelector = NewSelector( - RoundRobinStrategy[*Node](), - // FailFilter[*Node](1, DefaultFailTimeout), - ) - DefaultChainSelector = NewSelector( - RoundRobinStrategy[SelectableChainer](), - // FailFilter[SelectableChainer](1, DefaultFailTimeout), - ) -) - -type Selectable interface { - Marker() Marker - Metadata() metadata.Metadata -} - -type Selector[T any] interface { - Select(...T) T -} - -type selector[T Selectable] struct { - strategy Strategy[T] - filters []Filter[T] -} - -func NewSelector[T Selectable](strategy Strategy[T], filters ...Filter[T]) Selector[T] { - return &selector[T]{ - filters: filters, - strategy: strategy, - } -} - -func (s *selector[T]) Select(vs ...T) (v T) { - for _, filter := range s.filters { - vs = filter.Filter(vs...) - } - if len(vs) == 0 { - return - } - return s.strategy.Apply(vs...) -} - -type Strategy[T Selectable] interface { - Apply(...T) T -} - -type roundRobinStrategy[T Selectable] struct { - counter uint64 -} - -// RoundRobinStrategy is a strategy for node selector. -// The node will be selected by round-robin algorithm. -func RoundRobinStrategy[T Selectable]() Strategy[T] { - return &roundRobinStrategy[T]{} -} - -func (s *roundRobinStrategy[T]) Apply(vs ...T) (v T) { - if len(vs) == 0 { - return - } - - n := atomic.AddUint64(&s.counter, 1) - 1 - return vs[int(n%uint64(len(vs)))] -} - -type randomStrategy[T Selectable] struct { - rand *rand.Rand - mux sync.Mutex -} - -// RandomStrategy is a strategy for node selector. -// The node will be selected randomly. -func RandomStrategy[T Selectable]() Strategy[T] { - return &randomStrategy[T]{ - rand: rand.New(rand.NewSource(time.Now().UnixNano())), - } -} - -func (s *randomStrategy[T]) Apply(vs ...T) (v T) { - if len(vs) == 0 { - return - } - - s.mux.Lock() - defer s.mux.Unlock() - - r := s.rand.Int() - - return vs[r%len(vs)] -} - -type fifoStrategy[T Selectable] struct{} - -// FIFOStrategy is a strategy for node selector. -// The node will be selected from first to last, -// and will stick to the selected node until it is failed. -func FIFOStrategy[T Selectable]() Strategy[T] { - return &fifoStrategy[T]{} -} - -// Apply applies the fifo strategy for the nodes. -func (s *fifoStrategy[T]) Apply(vs ...T) (v T) { - if len(vs) == 0 { - return - } - return vs[0] -} - -type Filter[T Selectable] interface { - Filter(...T) []T -} - -type failFilter[T Selectable] struct { - maxFails int - failTimeout time.Duration -} - -// FailFilter filters the dead objects. -// An object is marked as dead if its failed count is greater than MaxFails. -func FailFilter[T Selectable](maxFails int, timeout time.Duration) Filter[T] { - return &failFilter[T]{ - maxFails: maxFails, - failTimeout: timeout, - } -} - -// Filter filters dead objects. -func (f *failFilter[T]) Filter(vs ...T) []T { - maxFails := f.maxFails - failTimeout := f.failTimeout - if failTimeout == 0 { - failTimeout = DefaultFailTimeout - } - - if len(vs) <= 1 || maxFails <= 0 { - return vs - } - var l []T - for _, v := range vs { - if marker := v.Marker(); marker != nil { - if marker.Count() < int64(maxFails) || - time.Since(marker.Time()) >= failTimeout { - l = append(l, v) - } - } else { - l = append(l, v) - } - } - return l -} - -type backupFilter[T Selectable] struct{} - -// BackupFilter filters the backup objects. -// An object is marked as backup if its metadata has backup flag. -func BackupFilter[T Selectable]() Filter[T] { - return &backupFilter[T]{} -} - -// Filter filters backup objects. -func (f *backupFilter[T]) Filter(vs ...T) []T { - if len(vs) <= 1 { - return vs - } - - var l, backups []T - for _, v := range vs { - if mdutil.GetBool(v.Metadata(), "backup") { - backups = append(backups, v) - } else { - l = append(l, v) - } - } - - if len(l) == 0 { - return backups - } - return l -} - -type Marker interface { - Time() time.Time - Count() int64 - Mark() - Reset() -} - -type failMarker struct { - failTime int64 - failCount int64 -} - -func NewFailMarker() Marker { - return &failMarker{} -} - -func (m *failMarker) Time() time.Time { - if m == nil { - return time.Time{} - } - - return time.Unix(atomic.LoadInt64(&m.failTime), 0) -} - -func (m *failMarker) Count() int64 { - if m == nil { - return 0 - } - - return atomic.LoadInt64(&m.failCount) -} - -func (m *failMarker) Mark() { - if m == nil { - return - } - - atomic.AddInt64(&m.failCount, 1) - atomic.StoreInt64(&m.failTime, time.Now().Unix()) -} - -func (m *failMarker) Reset() { - if m == nil { - return - } - - atomic.StoreInt64(&m.failCount, 0) -} diff --git a/chain/transport.go b/chain/transport.go index 0efe5cf..15a4c3f 100644 --- a/chain/transport.go +++ b/chain/transport.go @@ -14,7 +14,7 @@ type Transport struct { addr string ifceName string sockOpts *SockOpts - route *Route + route Route dialer dialer.Dialer connector connector.Connector } @@ -53,7 +53,7 @@ func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) { if tr.sockOpts != nil { netd.Mark = tr.sockOpts.Mark } - if tr.route.Len() > 0 { + if tr.route != nil && tr.route.Len() > 0 { netd.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { return tr.route.Dial(ctx, network, addr) } @@ -98,7 +98,7 @@ func (tr *Transport) Multiplex() bool { return false } -func (tr *Transport) WithRoute(r *Route) *Transport { +func (tr *Transport) WithRoute(r Route) *Transport { tr.route = r return tr } diff --git a/selector/selector.go b/selector/selector.go new file mode 100644 index 0000000..cc1f2b3 --- /dev/null +++ b/selector/selector.go @@ -0,0 +1,74 @@ +package selector + +import ( + "sync/atomic" + "time" + + "github.com/go-gost/core/metadata" +) + +type Selectable interface { + Marker() Marker + Metadata() metadata.Metadata +} + +type Selector[T any] interface { + Select(...T) T +} + +type Strategy[T Selectable] interface { + Apply(...T) T +} + +type Filter[T Selectable] interface { + Filter(...T) []T +} + +type Marker interface { + Time() time.Time + Count() int64 + Mark() + Reset() +} + +type failMarker struct { + failTime int64 + failCount int64 +} + +func NewFailMarker() Marker { + return &failMarker{} +} + +func (m *failMarker) Time() time.Time { + if m == nil { + return time.Time{} + } + + return time.Unix(atomic.LoadInt64(&m.failTime), 0) +} + +func (m *failMarker) Count() int64 { + if m == nil { + return 0 + } + + return atomic.LoadInt64(&m.failCount) +} + +func (m *failMarker) Mark() { + if m == nil { + return + } + + atomic.AddInt64(&m.failCount, 1) + atomic.StoreInt64(&m.failTime, time.Now().Unix()) +} + +func (m *failMarker) Reset() { + if m == nil { + return + } + + atomic.StoreInt64(&m.failCount, 0) +}