add chain group

This commit is contained in:
ginuerzh 2022-09-02 10:54:42 +08:00
parent e77908a89e
commit b88ab3acdc
7 changed files with 286 additions and 143 deletions

View File

@ -4,17 +4,17 @@ type Admission interface {
Admit(addr string) bool Admit(addr string) bool
} }
type admissionList struct { type admissionGroup struct {
admissions []Admission admissions []Admission
} }
func AdmissionList(admissions ...Admission) Admission { func AdmissionGroup(admissions ...Admission) Admission {
return &admissionList{ return &admissionGroup{
admissions: admissions, admissions: admissions,
} }
} }
func (p *admissionList) Admit(addr string) bool { func (p *admissionGroup) Admit(addr string) bool {
for _, admission := range p.admissions { for _, admission := range p.admissions {
if admission != nil && !admission.Admit(addr) { if admission != nil && !admission.Admit(addr) {
return false return false

View File

@ -5,17 +5,17 @@ type Authenticator interface {
Authenticate(user, password string) bool Authenticate(user, password string) bool
} }
type authenticatorList struct { type authenticatorGroup struct {
authers []Authenticator authers []Authenticator
} }
func AuthenticatorList(authers ...Authenticator) Authenticator { func AuthenticatorGroup(authers ...Authenticator) Authenticator {
return &authenticatorList{ return &authenticatorGroup{
authers: authers, authers: authers,
} }
} }
func (p *authenticatorList) Authenticate(user, password string) bool { func (p *authenticatorGroup) Authenticate(user, password string) bool {
if len(p.authers) == 0 { if len(p.authers) == 0 {
return true return true
} }

View File

@ -6,17 +6,17 @@ type Bypass interface {
Contains(addr string) bool Contains(addr string) bool
} }
type bypassList struct { type bypassGroup struct {
bypasses []Bypass bypasses []Bypass
} }
func BypassList(bypasses ...Bypass) Bypass { func BypassGroup(bypasses ...Bypass) Bypass {
return &bypassList{ return &bypassGroup{
bypasses: bypasses, bypasses: bypasses,
} }
} }
func (p *bypassList) Contains(addr string) bool { func (p *bypassGroup) Contains(addr string) bool {
for _, bypass := range p.bypasses { for _, bypass := range p.bypasses {
if bypass != nil && bypass.Contains(addr) { if bypass != nil && bypass.Contains(addr) {
return true return true

View File

@ -1,18 +1,30 @@
package chain package chain
import (
"github.com/go-gost/core/metadata"
)
type Chainer interface { type Chainer interface {
Route(network, address string) *Route Route(network, address string) *Route
} }
type SelectableChainer interface {
Chainer
Selectable
}
type Chain struct { type Chain struct {
name string name string
groups []*NodeGroup groups []*NodeGroup
marker Marker
metadata metadata.Metadata
} }
func NewChain(name string, groups ...*NodeGroup) *Chain { func NewChain(name string, groups ...*NodeGroup) *Chain {
return &Chain{ return &Chain{
name: name, name: name,
groups: groups, groups: groups,
marker: NewFailMarker(),
} }
} }
@ -20,6 +32,18 @@ func (c *Chain) AddNodeGroup(group *NodeGroup) {
c.groups = append(c.groups, group) c.groups = append(c.groups, group)
} }
func (c *Chain) WithMetadata(md metadata.Metadata) {
c.metadata = md
}
func (c *Chain) Metadata() metadata.Metadata {
return c.metadata
}
func (c *Chain) Marker() Marker {
return c.marker
}
func (c *Chain) Route(network, address string) (r *Route) { func (c *Chain) Route(network, address string) (r *Route) {
if c == nil || len(c.groups) == 0 { if c == nil || len(c.groups) == 0 {
return return
@ -38,11 +62,11 @@ func (c *Chain) Route(network, address string) (r *Route) {
if node == nil { if node == nil {
return return
} }
if node.Transport.Multiplex() { if node.transport.Multiplex() {
tr := node.Transport.Copy(). tr := node.transport.Copy().
WithRoute(r) WithRoute(r)
node = node.Copy() node = node.Copy()
node.Transport = tr node.transport = tr
r = &Route{} r = &Route{}
} }
@ -50,3 +74,27 @@ func (c *Chain) Route(network, address string) (r *Route) {
} }
return r return r
} }
type ChainGroup struct {
Chains []SelectableChainer
Selector Selector[SelectableChainer]
}
func (p *ChainGroup) Route(network, address string) *Route {
if chain := p.next(); chain != nil {
return chain.Route(network, address)
}
return nil
}
func (p *ChainGroup) next() Chainer {
if p == nil || len(p.Chains) == 0 {
return nil
}
s := p.Selector
if s == nil {
s = DefaultChainSelector
}
return s.Select(p.Chains...)
}

View File

@ -1,22 +1,62 @@
package chain package chain
import ( import (
"sync/atomic"
"time"
"github.com/go-gost/core/bypass" "github.com/go-gost/core/bypass"
"github.com/go-gost/core/hosts" "github.com/go-gost/core/hosts"
"github.com/go-gost/core/metadata"
"github.com/go-gost/core/resolver" "github.com/go-gost/core/resolver"
) )
type Node struct { type Node struct {
Name string Name string
Addr string Addr string
Transport *Transport transport *Transport
Bypass bypass.Bypass bypass bypass.Bypass
Resolver resolver.Resolver resolver resolver.Resolver
Hosts hosts.HostMapper hostMapper hosts.HostMapper
Marker *FailMarker marker Marker
metadata metadata.Metadata
}
func NewNode(name, addr string) *Node {
return &Node{
Name: name,
Addr: addr,
marker: NewFailMarker(),
}
}
func (node *Node) WithTransport(tr *Transport) *Node {
node.transport = tr
return node
}
func (node *Node) WithBypass(bypass bypass.Bypass) *Node {
node.bypass = bypass
return node
}
func (node *Node) WithResolver(reslv resolver.Resolver) *Node {
node.resolver = reslv
return node
}
func (node *Node) WithHostMapper(m hosts.HostMapper) *Node {
node.hostMapper = m
return node
}
func (node *Node) WithMetadata(md metadata.Metadata) *Node {
node.metadata = md
return node
}
func (node *Node) Marker() Marker {
return node.marker
}
func (node *Node) Metadata() metadata.Metadata {
return node.metadata
} }
func (node *Node) Copy() *Node { func (node *Node) Copy() *Node {
@ -27,7 +67,7 @@ func (node *Node) Copy() *Node {
type NodeGroup struct { type NodeGroup struct {
nodes []*Node nodes []*Node
selector Selector selector Selector[*Node]
bypass bypass.Bypass bypass bypass.Bypass
} }
@ -45,7 +85,7 @@ func (g *NodeGroup) Nodes() []*Node {
return g.nodes return g.nodes
} }
func (g *NodeGroup) WithSelector(selector Selector) *NodeGroup { func (g *NodeGroup) WithSelector(selector Selector[*Node]) *NodeGroup {
g.selector = selector g.selector = selector
return g return g
} }
@ -58,7 +98,7 @@ func (g *NodeGroup) WithBypass(bypass bypass.Bypass) *NodeGroup {
func (g *NodeGroup) FilterAddr(addr string) *NodeGroup { func (g *NodeGroup) FilterAddr(addr string) *NodeGroup {
var nodes []*Node var nodes []*Node
for _, node := range g.nodes { for _, node := range g.nodes {
if node.Bypass == nil || !node.Bypass.Contains(addr) { if node.bypass == nil || !node.bypass.Contains(addr) {
nodes = append(nodes, node) nodes = append(nodes, node)
} }
} }
@ -76,46 +116,8 @@ func (g *NodeGroup) Next() *Node {
s := g.selector s := g.selector
if s == nil { if s == nil {
s = DefaultSelector s = DefaultNodeSelector
} }
return s.Select(g.nodes...) return s.Select(g.nodes...)
} }
type FailMarker struct {
failTime int64
failCount int64
}
func (m *FailMarker) FailTime() int64 {
if m == nil {
return 0
}
return atomic.LoadInt64(&m.failTime)
}
func (m *FailMarker) FailCount() int64 {
if m == nil {
return 0
}
return atomic.LoadInt64(&m.failCount)
}
func (m *FailMarker) Mark() {
if m == nil {
return
}
atomic.AddInt64(&m.failCount, 1)
atomic.StoreInt64(&m.failTime, time.Now().Unix())
}
func (m *FailMarker) Reset() {
if m == nil {
return
}
atomic.StoreInt64(&m.failCount, 0)
}

View File

@ -54,7 +54,7 @@ func (r *Route) Dial(ctx context.Context, network, address string, opts ...DialO
return nil, err return nil, err
} }
cc, err := r.GetNode(r.Len()-1).Transport.Connect(ctx, conn, network, address) cc, err := r.GetNode(r.Len()-1).transport.Connect(ctx, conn, network, address)
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
@ -72,7 +72,7 @@ func (r *Route) Bind(ctx context.Context, network, address string, opts ...conne
return nil, err return nil, err
} }
ln, err := r.GetNode(r.Len()-1).Transport.Bind(ctx, conn, network, address, opts...) ln, err := r.GetNode(r.Len()-1).transport.Bind(ctx, conn, network, address, opts...)
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
@ -90,34 +90,54 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) {
node := r.nodes[0] node := r.nodes[0]
defer func() { defer func() {
if err != nil && r.chain != nil { if r.chain != nil {
if v := metrics.GetCounter(metrics.MetricChainErrorsCounter, marker := r.chain.Marker()
metrics.Labels{"chain": r.chain.name, "node": node.Name}); v != nil { // chain error
v.Inc() if err != nil {
if marker != nil {
marker.Mark()
}
if v := metrics.GetCounter(metrics.MetricChainErrorsCounter,
metrics.Labels{"chain": r.chain.name, "node": node.Name}); v != nil {
v.Inc()
}
} else {
if marker != nil {
marker.Reset()
}
} }
} }
}() }()
addr, err := resolve(ctx, network, node.Addr, node.Resolver, node.Hosts, r.logger) addr, err := resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, r.logger)
marker := node.Marker()
if err != nil { if err != nil {
node.Marker.Mark() if marker != nil {
marker.Mark()
}
return return
} }
start := time.Now() start := time.Now()
cc, err := node.Transport.Dial(ctx, addr) cc, err := node.transport.Dial(ctx, addr)
if err != nil { if err != nil {
node.Marker.Mark() if marker != nil {
marker.Mark()
}
return return
} }
cn, err := node.Transport.Handshake(ctx, cc) cn, err := node.transport.Handshake(ctx, cc)
if err != nil { if err != nil {
cc.Close() cc.Close()
node.Marker.Mark() if marker != nil {
marker.Mark()
}
return return
} }
node.Marker.Reset() if marker != nil {
marker.Reset()
}
if r.chain != nil { if r.chain != nil {
if v := metrics.GetObserver(metrics.MetricNodeConnectDurationObserver, if v := metrics.GetObserver(metrics.MetricNodeConnectDurationObserver,
@ -128,25 +148,34 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) {
preNode := node preNode := node
for _, node := range r.nodes[1:] { for _, node := range r.nodes[1:] {
addr, err = resolve(ctx, network, node.Addr, node.Resolver, node.Hosts, r.logger) marker := node.Marker()
addr, err = resolve(ctx, network, node.Addr, node.resolver, node.hostMapper, r.logger)
if err != nil { if err != nil {
cn.Close() cn.Close()
node.Marker.Mark() if marker != nil {
marker.Mark()
}
return return
} }
cc, err = preNode.Transport.Connect(ctx, cn, "tcp", addr) cc, err = preNode.transport.Connect(ctx, cn, "tcp", addr)
if err != nil { if err != nil {
cn.Close() cn.Close()
node.Marker.Mark() if marker != nil {
marker.Mark()
}
return return
} }
cc, err = node.Transport.Handshake(ctx, cc) cc, err = node.transport.Handshake(ctx, cc)
if err != nil { if err != nil {
cn.Close() cn.Close()
node.Marker.Mark() if marker != nil {
marker.Mark()
}
return return
} }
node.Marker.Reset() if marker != nil {
marker.Reset()
}
cn = cc cn = cc
preNode = node preNode = node
@ -176,8 +205,8 @@ func (r *Route) Path() (path []*Node) {
} }
for _, node := range r.nodes { for _, node := range r.nodes {
if node.Transport != nil && node.Transport.route != nil { if node.transport != nil && node.transport.route != nil {
path = append(path, node.Transport.route.Path()...) path = append(path, node.transport.route.Path()...)
} }
path = append(path, node) path = append(path, node)
} }

View File

@ -5,6 +5,8 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/go-gost/core/metadata"
) )
// default options for FailFilter // default options for FailFilter
@ -13,77 +15,86 @@ const (
) )
var ( var (
DefaultSelector = NewSelector( DefaultNodeSelector = NewSelector(
RoundRobinStrategy(), RoundRobinStrategy[*Node](),
FailFilter(1, DefaultFailTimeout), // FailFilter[*Node](1, DefaultFailTimeout),
)
DefaultChainSelector = NewSelector(
RoundRobinStrategy[SelectableChainer](),
// FailFilter[SelectableChainer](1, DefaultFailTimeout),
) )
) )
type Selector interface { type Selectable interface {
Select(nodes ...*Node) *Node Marker() Marker
Metadata() metadata.Metadata
} }
type selector struct { type Selector[T any] interface {
strategy Strategy Select(...T) T
filters []Filter
} }
func NewSelector(strategy Strategy, filters ...Filter) Selector { type selector[T Selectable] struct {
return &selector{ strategy Strategy[T]
filters []Filter[T]
}
func NewSelector[T Selectable](strategy Strategy[T], filters ...Filter[T]) Selector[T] {
return &selector[T]{
filters: filters, filters: filters,
strategy: strategy, strategy: strategy,
} }
} }
func (s *selector) Select(nodes ...*Node) *Node { func (s *selector[T]) Select(vs ...T) (v T) {
for _, filter := range s.filters { for _, filter := range s.filters {
nodes = filter.Filter(nodes...) vs = filter.Filter(vs...)
} }
if len(nodes) == 0 { if len(vs) == 0 {
return nil return
} }
return s.strategy.Apply(nodes...) return s.strategy.Apply(vs...)
} }
type Strategy interface { type Strategy[T Selectable] interface {
Apply(nodes ...*Node) *Node Apply(...T) T
} }
type roundRobinStrategy struct { type roundRobinStrategy[T Selectable] struct {
counter uint64 counter uint64
} }
// RoundRobinStrategy is a strategy for node selector. // RoundRobinStrategy is a strategy for node selector.
// The node will be selected by round-robin algorithm. // The node will be selected by round-robin algorithm.
func RoundRobinStrategy() Strategy { func RoundRobinStrategy[T Selectable]() Strategy[T] {
return &roundRobinStrategy{} return &roundRobinStrategy[T]{}
} }
func (s *roundRobinStrategy) Apply(nodes ...*Node) *Node { func (s *roundRobinStrategy[T]) Apply(vs ...T) (v T) {
if len(nodes) == 0 { if len(vs) == 0 {
return nil return
} }
n := atomic.AddUint64(&s.counter, 1) - 1 n := atomic.AddUint64(&s.counter, 1) - 1
return nodes[int(n%uint64(len(nodes)))] return vs[int(n%uint64(len(vs)))]
} }
type randomStrategy struct { type randomStrategy[T Selectable] struct {
rand *rand.Rand rand *rand.Rand
mux sync.Mutex mux sync.Mutex
} }
// RandomStrategy is a strategy for node selector. // RandomStrategy is a strategy for node selector.
// The node will be selected randomly. // The node will be selected randomly.
func RandomStrategy() Strategy { func RandomStrategy[T Selectable]() Strategy[T] {
return &randomStrategy{ return &randomStrategy[T]{
rand: rand.New(rand.NewSource(time.Now().UnixNano())), rand: rand.New(rand.NewSource(time.Now().UnixNano())),
} }
} }
func (s *randomStrategy) Apply(nodes ...*Node) *Node { func (s *randomStrategy[T]) Apply(vs ...T) (v T) {
if len(nodes) == 0 { if len(vs) == 0 {
return nil return
} }
s.mux.Lock() s.mux.Lock()
@ -91,61 +102,114 @@ func (s *randomStrategy) Apply(nodes ...*Node) *Node {
r := s.rand.Int() r := s.rand.Int()
return nodes[r%len(nodes)] return vs[r%len(vs)]
} }
type fifoStrategy struct{} type fifoStrategy[T Selectable] struct{}
// FIFOStrategy is a strategy for node selector. // FIFOStrategy is a strategy for node selector.
// The node will be selected from first to last, // The node will be selected from first to last,
// and will stick to the selected node until it is failed. // and will stick to the selected node until it is failed.
func FIFOStrategy() Strategy { func FIFOStrategy[T Selectable]() Strategy[T] {
return &fifoStrategy{} return &fifoStrategy[T]{}
} }
// Apply applies the fifo strategy for the nodes. // Apply applies the fifo strategy for the nodes.
func (s *fifoStrategy) Apply(nodes ...*Node) *Node { func (s *fifoStrategy[T]) Apply(vs ...T) (v T) {
if len(nodes) == 0 { if len(vs) == 0 {
return nil return
} }
return nodes[0] return vs[0]
} }
type Filter interface { type Filter[T Selectable] interface {
Filter(nodes ...*Node) []*Node Filter(...T) []T
} }
type failFilter struct { type failFilter[T Selectable] struct {
maxFails int maxFails int
failTimeout time.Duration failTimeout time.Duration
} }
// FailFilter filters the dead node. // FailFilter filters the dead node.
// A node is marked as dead if its failed count is greater than MaxFails. // A node is marked as dead if its failed count is greater than MaxFails.
func FailFilter(maxFails int, timeout time.Duration) Filter { func FailFilter[T Selectable](maxFails int, timeout time.Duration) Filter[T] {
return &failFilter{ return &failFilter[T]{
maxFails: maxFails, maxFails: maxFails,
failTimeout: timeout, failTimeout: timeout,
} }
} }
// Filter filters dead nodes. // Filter filters dead nodes.
func (f *failFilter) Filter(nodes ...*Node) []*Node { func (f *failFilter[T]) Filter(vs ...T) []T {
maxFails := f.maxFails maxFails := f.maxFails
failTimeout := f.failTimeout failTimeout := f.failTimeout
if failTimeout == 0 { if failTimeout == 0 {
failTimeout = DefaultFailTimeout failTimeout = DefaultFailTimeout
} }
if len(nodes) <= 1 || maxFails <= 0 { if len(vs) <= 1 || maxFails <= 0 {
return nodes return vs
} }
var nl []*Node var l []T
for _, node := range nodes { for _, v := range vs {
if node.Marker.FailCount() < int64(maxFails) || if marker := v.Marker(); marker != nil {
time.Since(time.Unix(node.Marker.FailTime(), 0)) >= failTimeout { if marker.Count() < int64(maxFails) ||
nl = append(nl, node) time.Since(marker.Time()) >= failTimeout {
l = append(l, v)
}
} else {
l = append(l, v)
} }
} }
return nl return l
}
type Marker interface {
Time() time.Time
Count() int64
Mark()
Reset()
}
type failMarker struct {
failTime int64
failCount int64
}
func NewFailMarker() Marker {
return &failMarker{}
}
func (m *failMarker) Time() time.Time {
if m == nil {
return time.Time{}
}
return time.Unix(atomic.LoadInt64(&m.failTime), 0)
}
func (m *failMarker) Count() int64 {
if m == nil {
return 0
}
return atomic.LoadInt64(&m.failCount)
}
func (m *failMarker) Mark() {
if m == nil {
return
}
atomic.AddInt64(&m.failCount, 1)
atomic.StoreInt64(&m.failTime, time.Now().Unix())
}
func (m *failMarker) Reset() {
if m == nil {
return
}
atomic.StoreInt64(&m.failCount, 0)
} }