diff --git a/admission/admission.go b/admission/admission.go index 0187d2a..5d29a6b 100644 --- a/admission/admission.go +++ b/admission/admission.go @@ -4,17 +4,17 @@ type Admission interface { Admit(addr string) bool } -type admissionList struct { +type admissionGroup struct { admissions []Admission } -func AdmissionList(admissions ...Admission) Admission { - return &admissionList{ +func AdmissionGroup(admissions ...Admission) Admission { + return &admissionGroup{ admissions: admissions, } } -func (p *admissionList) Admit(addr string) bool { +func (p *admissionGroup) Admit(addr string) bool { for _, admission := range p.admissions { if admission != nil && !admission.Admit(addr) { return false diff --git a/auth/auth.go b/auth/auth.go index 7b88c35..50450f6 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -5,17 +5,17 @@ type Authenticator interface { Authenticate(user, password string) bool } -type authenticatorList struct { +type authenticatorGroup struct { authers []Authenticator } -func AuthenticatorList(authers ...Authenticator) Authenticator { - return &authenticatorList{ +func AuthenticatorGroup(authers ...Authenticator) Authenticator { + return &authenticatorGroup{ authers: authers, } } -func (p *authenticatorList) Authenticate(user, password string) bool { +func (p *authenticatorGroup) Authenticate(user, password string) bool { if len(p.authers) == 0 { return true } diff --git a/bypass/bypass.go b/bypass/bypass.go index babfd16..f35fc30 100644 --- a/bypass/bypass.go +++ b/bypass/bypass.go @@ -6,17 +6,17 @@ type Bypass interface { Contains(addr string) bool } -type bypassList struct { +type bypassGroup struct { bypasses []Bypass } -func BypassList(bypasses ...Bypass) Bypass { - return &bypassList{ +func BypassGroup(bypasses ...Bypass) Bypass { + return &bypassGroup{ bypasses: bypasses, } } -func (p *bypassList) Contains(addr string) bool { +func (p *bypassGroup) Contains(addr string) bool { for _, bypass := range p.bypasses { if bypass != nil && bypass.Contains(addr) { return true diff --git a/chain/chain.go b/chain/chain.go index 7fb6859..8230eaf 100644 --- a/chain/chain.go +++ b/chain/chain.go @@ -1,18 +1,30 @@ package chain +import ( + "github.com/go-gost/core/metadata" +) + type Chainer interface { Route(network, address string) *Route } +type SelectableChainer interface { + Chainer + Selectable +} + type Chain struct { - name string - groups []*NodeGroup + name string + groups []*NodeGroup + marker Marker + metadata metadata.Metadata } func NewChain(name string, groups ...*NodeGroup) *Chain { return &Chain{ name: name, groups: groups, + marker: NewFailMarker(), } } @@ -20,6 +32,18 @@ func (c *Chain) AddNodeGroup(group *NodeGroup) { c.groups = append(c.groups, group) } +func (c *Chain) WithMetadata(md metadata.Metadata) { + c.metadata = md +} + +func (c *Chain) Metadata() metadata.Metadata { + return c.metadata +} + +func (c *Chain) Marker() Marker { + return c.marker +} + func (c *Chain) Route(network, address string) (r *Route) { if c == nil || len(c.groups) == 0 { return @@ -38,11 +62,11 @@ func (c *Chain) Route(network, address string) (r *Route) { if node == nil { return } - if node.Transport.Multiplex() { - tr := node.Transport.Copy(). + if node.transport.Multiplex() { + tr := node.transport.Copy(). WithRoute(r) node = node.Copy() - node.Transport = tr + node.transport = tr r = &Route{} } @@ -50,3 +74,27 @@ func (c *Chain) Route(network, address string) (r *Route) { } return r } + +type ChainGroup struct { + Chains []SelectableChainer + Selector Selector[SelectableChainer] +} + +func (p *ChainGroup) Route(network, address string) *Route { + if chain := p.next(); chain != nil { + return chain.Route(network, address) + } + return nil +} + +func (p *ChainGroup) next() Chainer { + if p == nil || len(p.Chains) == 0 { + return nil + } + + s := p.Selector + if s == nil { + s = DefaultChainSelector + } + return s.Select(p.Chains...) +} diff --git a/chain/node.go b/chain/node.go index 957db27..b8fbc93 100644 --- a/chain/node.go +++ b/chain/node.go @@ -1,22 +1,62 @@ package chain import ( - "sync/atomic" - "time" - "github.com/go-gost/core/bypass" "github.com/go-gost/core/hosts" + "github.com/go-gost/core/metadata" "github.com/go-gost/core/resolver" ) type Node struct { - Name string - Addr string - Transport *Transport - Bypass bypass.Bypass - Resolver resolver.Resolver - Hosts hosts.HostMapper - Marker *FailMarker + Name string + Addr string + transport *Transport + bypass bypass.Bypass + resolver resolver.Resolver + hostMapper hosts.HostMapper + marker Marker + metadata metadata.Metadata +} + +func NewNode(name, addr string) *Node { + return &Node{ + Name: name, + Addr: addr, + marker: NewFailMarker(), + } +} + +func (node *Node) WithTransport(tr *Transport) *Node { + node.transport = tr + return node +} + +func (node *Node) WithBypass(bypass bypass.Bypass) *Node { + node.bypass = bypass + return node +} + +func (node *Node) WithResolver(reslv resolver.Resolver) *Node { + node.resolver = reslv + return node +} + +func (node *Node) WithHostMapper(m hosts.HostMapper) *Node { + node.hostMapper = m + return node +} + +func (node *Node) WithMetadata(md metadata.Metadata) *Node { + node.metadata = md + return node +} + +func (node *Node) Marker() Marker { + return node.marker +} + +func (node *Node) Metadata() metadata.Metadata { + return node.metadata } func (node *Node) Copy() *Node { @@ -27,7 +67,7 @@ func (node *Node) Copy() *Node { type NodeGroup struct { nodes []*Node - selector Selector + selector Selector[*Node] bypass bypass.Bypass } @@ -45,7 +85,7 @@ func (g *NodeGroup) Nodes() []*Node { return g.nodes } -func (g *NodeGroup) WithSelector(selector Selector) *NodeGroup { +func (g *NodeGroup) WithSelector(selector Selector[*Node]) *NodeGroup { g.selector = selector return g } @@ -58,7 +98,7 @@ func (g *NodeGroup) WithBypass(bypass bypass.Bypass) *NodeGroup { func (g *NodeGroup) FilterAddr(addr string) *NodeGroup { var nodes []*Node for _, node := range g.nodes { - if node.Bypass == nil || !node.Bypass.Contains(addr) { + if node.bypass == nil || !node.bypass.Contains(addr) { nodes = append(nodes, node) } } @@ -76,46 +116,8 @@ func (g *NodeGroup) Next() *Node { s := g.selector if s == nil { - s = DefaultSelector + s = DefaultNodeSelector } return s.Select(g.nodes...) } - -type FailMarker struct { - failTime int64 - failCount int64 -} - -func (m *FailMarker) FailTime() int64 { - if m == nil { - return 0 - } - - return atomic.LoadInt64(&m.failTime) -} - -func (m *FailMarker) FailCount() int64 { - if m == nil { - return 0 - } - - return atomic.LoadInt64(&m.failCount) -} - -func (m *FailMarker) Mark() { - if m == nil { - return - } - - atomic.AddInt64(&m.failCount, 1) - atomic.StoreInt64(&m.failTime, time.Now().Unix()) -} - -func (m *FailMarker) Reset() { - if m == nil { - return - } - - atomic.StoreInt64(&m.failCount, 0) -} diff --git a/chain/route.go b/chain/route.go index 7e6692c..4816a3a 100644 --- a/chain/route.go +++ b/chain/route.go @@ -54,7 +54,7 @@ func (r *Route) Dial(ctx context.Context, network, address string, opts ...DialO return nil, err } - cc, err := r.GetNode(r.Len()-1).Transport.Connect(ctx, conn, network, address) + cc, err := r.GetNode(r.Len()-1).transport.Connect(ctx, conn, network, address) if err != nil { conn.Close() return nil, err @@ -72,7 +72,7 @@ func (r *Route) Bind(ctx context.Context, network, address string, opts ...conne return nil, err } - ln, err := r.GetNode(r.Len()-1).Transport.Bind(ctx, conn, network, address, opts...) + ln, err := r.GetNode(r.Len()-1).transport.Bind(ctx, conn, network, address, opts...) if err != nil { conn.Close() return nil, err @@ -90,34 +90,54 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) { node := r.nodes[0] defer func() { - if err != nil && r.chain != nil { - if v := metrics.GetCounter(metrics.MetricChainErrorsCounter, - metrics.Labels{"chain": r.chain.name, "node": node.Name}); v != nil { - v.Inc() + if r.chain != nil { + marker := r.chain.Marker() + // chain error + if err != nil { + if marker != nil { + marker.Mark() + } + if v := metrics.GetCounter(metrics.MetricChainErrorsCounter, + metrics.Labels{"chain": r.chain.name, "node": node.Name}); v != nil { + v.Inc() + } + } else { + if marker != nil { + marker.Reset() + } } } }() - addr, err := resolve(ctx, network, node.Addr, node.Resolver, node.Hosts, r.logger) + addr, err := resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, r.logger) + marker := node.Marker() if err != nil { - node.Marker.Mark() + if marker != nil { + marker.Mark() + } return } start := time.Now() - cc, err := node.Transport.Dial(ctx, addr) + cc, err := node.transport.Dial(ctx, addr) if err != nil { - node.Marker.Mark() + if marker != nil { + marker.Mark() + } return } - cn, err := node.Transport.Handshake(ctx, cc) + cn, err := node.transport.Handshake(ctx, cc) if err != nil { cc.Close() - node.Marker.Mark() + if marker != nil { + marker.Mark() + } return } - node.Marker.Reset() + if marker != nil { + marker.Reset() + } if r.chain != nil { if v := metrics.GetObserver(metrics.MetricNodeConnectDurationObserver, @@ -128,25 +148,34 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) { preNode := node for _, node := range r.nodes[1:] { - addr, err = resolve(ctx, network, node.Addr, node.Resolver, node.Hosts, r.logger) + marker := node.Marker() + addr, err = resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, r.logger) if err != nil { cn.Close() - node.Marker.Mark() + if marker != nil { + marker.Mark() + } return } - cc, err = preNode.Transport.Connect(ctx, cn, "tcp", addr) + cc, err = preNode.transport.Connect(ctx, cn, "tcp", addr) if err != nil { cn.Close() - node.Marker.Mark() + if marker != nil { + marker.Mark() + } return } - cc, err = node.Transport.Handshake(ctx, cc) + cc, err = node.transport.Handshake(ctx, cc) if err != nil { cn.Close() - node.Marker.Mark() + if marker != nil { + marker.Mark() + } return } - node.Marker.Reset() + if marker != nil { + marker.Reset() + } cn = cc preNode = node @@ -176,8 +205,8 @@ func (r *Route) Path() (path []*Node) { } for _, node := range r.nodes { - if node.Transport != nil && node.Transport.route != nil { - path = append(path, node.Transport.route.Path()...) + if node.transport != nil && node.transport.route != nil { + path = append(path, node.transport.route.Path()...) } path = append(path, node) } diff --git a/chain/selector.go b/chain/selector.go index b6e0414..1ca334a 100644 --- a/chain/selector.go +++ b/chain/selector.go @@ -5,6 +5,8 @@ import ( "sync" "sync/atomic" "time" + + "github.com/go-gost/core/metadata" ) // default options for FailFilter @@ -13,77 +15,86 @@ const ( ) var ( - DefaultSelector = NewSelector( - RoundRobinStrategy(), - FailFilter(1, DefaultFailTimeout), + DefaultNodeSelector = NewSelector( + RoundRobinStrategy[*Node](), + // FailFilter[*Node](1, DefaultFailTimeout), + ) + DefaultChainSelector = NewSelector( + RoundRobinStrategy[SelectableChainer](), + // FailFilter[SelectableChainer](1, DefaultFailTimeout), ) ) -type Selector interface { - Select(nodes ...*Node) *Node +type Selectable interface { + Marker() Marker + Metadata() metadata.Metadata } -type selector struct { - strategy Strategy - filters []Filter +type Selector[T any] interface { + Select(...T) T } -func NewSelector(strategy Strategy, filters ...Filter) Selector { - return &selector{ +type selector[T Selectable] struct { + strategy Strategy[T] + filters []Filter[T] +} + +func NewSelector[T Selectable](strategy Strategy[T], filters ...Filter[T]) Selector[T] { + return &selector[T]{ filters: filters, strategy: strategy, } } -func (s *selector) Select(nodes ...*Node) *Node { +func (s *selector[T]) Select(vs ...T) (v T) { for _, filter := range s.filters { - nodes = filter.Filter(nodes...) + vs = filter.Filter(vs...) } - if len(nodes) == 0 { - return nil + if len(vs) == 0 { + return } - return s.strategy.Apply(nodes...) + return s.strategy.Apply(vs...) } -type Strategy interface { - Apply(nodes ...*Node) *Node +type Strategy[T Selectable] interface { + Apply(...T) T } -type roundRobinStrategy struct { +type roundRobinStrategy[T Selectable] struct { counter uint64 } // RoundRobinStrategy is a strategy for node selector. // The node will be selected by round-robin algorithm. -func RoundRobinStrategy() Strategy { - return &roundRobinStrategy{} +func RoundRobinStrategy[T Selectable]() Strategy[T] { + return &roundRobinStrategy[T]{} } -func (s *roundRobinStrategy) Apply(nodes ...*Node) *Node { - if len(nodes) == 0 { - return nil +func (s *roundRobinStrategy[T]) Apply(vs ...T) (v T) { + if len(vs) == 0 { + return } n := atomic.AddUint64(&s.counter, 1) - 1 - return nodes[int(n%uint64(len(nodes)))] + return vs[int(n%uint64(len(vs)))] } -type randomStrategy struct { +type randomStrategy[T Selectable] struct { rand *rand.Rand mux sync.Mutex } // RandomStrategy is a strategy for node selector. // The node will be selected randomly. -func RandomStrategy() Strategy { - return &randomStrategy{ +func RandomStrategy[T Selectable]() Strategy[T] { + return &randomStrategy[T]{ rand: rand.New(rand.NewSource(time.Now().UnixNano())), } } -func (s *randomStrategy) Apply(nodes ...*Node) *Node { - if len(nodes) == 0 { - return nil +func (s *randomStrategy[T]) Apply(vs ...T) (v T) { + if len(vs) == 0 { + return } s.mux.Lock() @@ -91,61 +102,114 @@ func (s *randomStrategy) Apply(nodes ...*Node) *Node { r := s.rand.Int() - return nodes[r%len(nodes)] + return vs[r%len(vs)] } -type fifoStrategy struct{} +type fifoStrategy[T Selectable] struct{} // FIFOStrategy is a strategy for node selector. // The node will be selected from first to last, // and will stick to the selected node until it is failed. -func FIFOStrategy() Strategy { - return &fifoStrategy{} +func FIFOStrategy[T Selectable]() Strategy[T] { + return &fifoStrategy[T]{} } // Apply applies the fifo strategy for the nodes. -func (s *fifoStrategy) Apply(nodes ...*Node) *Node { - if len(nodes) == 0 { - return nil +func (s *fifoStrategy[T]) Apply(vs ...T) (v T) { + if len(vs) == 0 { + return } - return nodes[0] + return vs[0] } -type Filter interface { - Filter(nodes ...*Node) []*Node +type Filter[T Selectable] interface { + Filter(...T) []T } -type failFilter struct { +type failFilter[T Selectable] struct { maxFails int failTimeout time.Duration } // FailFilter filters the dead node. // A node is marked as dead if its failed count is greater than MaxFails. -func FailFilter(maxFails int, timeout time.Duration) Filter { - return &failFilter{ +func FailFilter[T Selectable](maxFails int, timeout time.Duration) Filter[T] { + return &failFilter[T]{ maxFails: maxFails, failTimeout: timeout, } } // Filter filters dead nodes. -func (f *failFilter) Filter(nodes ...*Node) []*Node { +func (f *failFilter[T]) Filter(vs ...T) []T { maxFails := f.maxFails failTimeout := f.failTimeout if failTimeout == 0 { failTimeout = DefaultFailTimeout } - if len(nodes) <= 1 || maxFails <= 0 { - return nodes + if len(vs) <= 1 || maxFails <= 0 { + return vs } - var nl []*Node - for _, node := range nodes { - if node.Marker.FailCount() < int64(maxFails) || - time.Since(time.Unix(node.Marker.FailTime(), 0)) >= failTimeout { - nl = append(nl, node) + var l []T + for _, v := range vs { + if marker := v.Marker(); marker != nil { + if marker.Count() < int64(maxFails) || + time.Since(marker.Time()) >= failTimeout { + l = append(l, v) + } + } else { + l = append(l, v) } } - return nl + return l +} + +type Marker interface { + Time() time.Time + Count() int64 + Mark() + Reset() +} + +type failMarker struct { + failTime int64 + failCount int64 +} + +func NewFailMarker() Marker { + return &failMarker{} +} + +func (m *failMarker) Time() time.Time { + if m == nil { + return time.Time{} + } + + return time.Unix(atomic.LoadInt64(&m.failTime), 0) +} + +func (m *failMarker) Count() int64 { + if m == nil { + return 0 + } + + return atomic.LoadInt64(&m.failCount) +} + +func (m *failMarker) Mark() { + if m == nil { + return + } + + atomic.AddInt64(&m.failCount, 1) + atomic.StoreInt64(&m.failTime, time.Now().Unix()) +} + +func (m *failMarker) Reset() { + if m == nil { + return + } + + atomic.StoreInt64(&m.failCount, 0) }