From c643014e120592e517f13f2b7c0412e5229ef48d Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Fri, 2 Sep 2022 15:00:07 +0800 Subject: [PATCH] add selector --- config/parsing/chain.go | 4 ++ config/parsing/parse.go | 38 +++++++++-------- config/parsing/service.go | 25 ++++++++--- listener/rtcp/listener.go | 3 +- listener/rudp/listener.go | 9 ++-- registry/chain.go | 5 ++- selector/filter.go | 87 +++++++++++++++++++++++++++++++++++++++ selector/selector.go | 54 ++++++++++++++++++++++++ selector/strategy.go | 72 ++++++++++++++++++++++++++++++++ 9 files changed, 264 insertions(+), 33 deletions(-) create mode 100644 selector/filter.go create mode 100644 selector/selector.go create mode 100644 selector/strategy.go diff --git a/config/parsing/chain.go b/config/parsing/chain.go index 5a53806..e8d26ce 100644 --- a/config/parsing/chain.go +++ b/config/parsing/chain.go @@ -10,6 +10,7 @@ import ( tls_util "github.com/go-gost/x/internal/util/tls" "github.com/go-gost/x/metadata" "github.com/go-gost/x/registry" + xs "github.com/go-gost/x/selector" ) func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) { @@ -140,6 +141,9 @@ func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) { if s := parseNodeSelector(hop.Selector); s != nil { sel = s } + if sel == nil { + sel = xs.DefaultNodeSelector + } group.WithSelector(sel). WithBypass(bypass.BypassGroup(bypassList(hop.Bypass, hop.Bypasses...)...)) diff --git a/config/parsing/parse.go b/config/parsing/parse.go index 0a8e2e8..a7ad3db 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -12,6 +12,7 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/core/recorder" "github.com/go-gost/core/resolver" + "github.com/go-gost/core/selector" admission_impl "github.com/go-gost/x/admission" auth_impl "github.com/go-gost/x/auth" bypass_impl "github.com/go-gost/x/bypass" @@ -21,6 +22,7 @@ import ( recorder_impl "github.com/go-gost/x/recorder" "github.com/go-gost/x/registry" resolver_impl "github.com/go-gost/x/resolver" + xs "github.com/go-gost/x/selector" ) func ParseAuther(cfg *config.AutherConfig) auth.Authenticator { @@ -83,50 +85,50 @@ func parseAuth(cfg *config.AuthConfig) *url.Userinfo { return url.UserPassword(cfg.Username, cfg.Password) } -func parseChainSelector(cfg *config.SelectorConfig) chain.Selector[chain.SelectableChainer] { +func parseChainSelector(cfg *config.SelectorConfig) selector.Selector[chain.SelectableChainer] { if cfg == nil { return nil } - var strategy chain.Strategy[chain.SelectableChainer] + var strategy selector.Strategy[chain.SelectableChainer] switch cfg.Strategy { case "round", "rr": - strategy = chain.RoundRobinStrategy[chain.SelectableChainer]() + strategy = xs.RoundRobinStrategy[chain.SelectableChainer]() case "random", "rand": - strategy = chain.RandomStrategy[chain.SelectableChainer]() + strategy = xs.RandomStrategy[chain.SelectableChainer]() case "fifo", "ha": - strategy = chain.FIFOStrategy[chain.SelectableChainer]() + strategy = xs.FIFOStrategy[chain.SelectableChainer]() default: - strategy = chain.RoundRobinStrategy[chain.SelectableChainer]() + strategy = xs.RoundRobinStrategy[chain.SelectableChainer]() } - return chain.NewSelector( + return xs.NewSelector( strategy, - chain.FailFilter[chain.SelectableChainer](cfg.MaxFails, cfg.FailTimeout), - chain.BackupFilter[chain.SelectableChainer](), + xs.FailFilter[chain.SelectableChainer](cfg.MaxFails, cfg.FailTimeout), + xs.BackupFilter[chain.SelectableChainer](), ) } -func parseNodeSelector(cfg *config.SelectorConfig) chain.Selector[*chain.Node] { +func parseNodeSelector(cfg *config.SelectorConfig) selector.Selector[*chain.Node] { if cfg == nil { return nil } - var strategy chain.Strategy[*chain.Node] + var strategy selector.Strategy[*chain.Node] switch cfg.Strategy { case "round", "rr": - strategy = chain.RoundRobinStrategy[*chain.Node]() + strategy = xs.RoundRobinStrategy[*chain.Node]() case "random", "rand": - strategy = chain.RandomStrategy[*chain.Node]() + strategy = xs.RandomStrategy[*chain.Node]() case "fifo", "ha": - strategy = chain.FIFOStrategy[*chain.Node]() + strategy = xs.FIFOStrategy[*chain.Node]() default: - strategy = chain.RoundRobinStrategy[*chain.Node]() + strategy = xs.RoundRobinStrategy[*chain.Node]() } - return chain.NewSelector( + return xs.NewSelector( strategy, - chain.FailFilter[*chain.Node](cfg.MaxFails, cfg.FailTimeout), - chain.BackupFilter[*chain.Node](), + xs.FailFilter[*chain.Node](cfg.MaxFails, cfg.FailTimeout), + xs.BackupFilter[*chain.Node](), ) } diff --git a/config/parsing/service.go b/config/parsing/service.go index 593436d..11510c1 100644 --- a/config/parsing/service.go +++ b/config/parsing/service.go @@ -11,11 +11,13 @@ import ( "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" "github.com/go-gost/core/recorder" + "github.com/go-gost/core/selector" "github.com/go-gost/core/service" "github.com/go-gost/x/config" tls_util "github.com/go-gost/x/internal/util/tls" "github.com/go-gost/x/metadata" "github.com/go-gost/x/registry" + xs "github.com/go-gost/x/selector" ) func ParseService(cfg *config.ServiceConfig) (service.Service, error) { @@ -184,7 +186,11 @@ func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup { } } - return group.WithSelector(parseNodeSelector(cfg.Selector)) + sel := parseNodeSelector(cfg.Selector) + if sel == nil { + sel = xs.DefaultNodeSelector + } + return group.WithSelector(sel) } func bypassList(name string, names ...string) []bypass.Bypass { @@ -228,18 +234,25 @@ func admissionList(name string, names ...string) []admission.Admission { } func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer { - cg := &chain.ChainGroup{} + var chains []chain.SelectableChainer + var sel selector.Selector[chain.SelectableChainer] + if c := registry.ChainRegistry().Get(name); c != nil { - cg.Chains = append(cg.Chains, c) + chains = append(chains, c) } if group != nil { for _, s := range group.Chains { if c := registry.ChainRegistry().Get(s); c != nil { - cg.Chains = append(cg.Chains, c) + chains = append(chains, c) } } - cg.Selector = parseChainSelector(group.Selector) + sel = parseChainSelector(group.Selector) } - return cg + if sel == nil { + sel = xs.DefaultChainSelector + } + + return chain.NewChainGroup(chains...). + WithSelector(sel) } diff --git a/listener/rtcp/listener.go b/listener/rtcp/listener.go index cd7eb29..b327dd8 100644 --- a/listener/rtcp/listener.go +++ b/listener/rtcp/listener.go @@ -5,7 +5,6 @@ import ( "net" "github.com/go-gost/core/chain" - "github.com/go-gost/core/connector" "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" @@ -71,7 +70,7 @@ func (l *rtcpListener) Accept() (conn net.Conn, err error) { if l.ln == nil { l.ln, err = l.router.Bind( context.Background(), "tcp", l.laddr.String(), - connector.MuxBindOption(true), + chain.MuxBindOption(true), ) if err != nil { return nil, listener.NewAcceptError(err) diff --git a/listener/rudp/listener.go b/listener/rudp/listener.go index 75898a9..65c8cb2 100644 --- a/listener/rudp/listener.go +++ b/listener/rudp/listener.go @@ -5,7 +5,6 @@ import ( "net" "github.com/go-gost/core/chain" - "github.com/go-gost/core/connector" "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" @@ -72,10 +71,10 @@ func (l *rudpListener) Accept() (conn net.Conn, err error) { if l.ln == nil { l.ln, err = l.router.Bind( context.Background(), "udp", l.laddr.String(), - connector.BacklogBindOption(l.md.backlog), - connector.UDPConnTTLBindOption(l.md.ttl), - connector.UDPDataBufferSizeBindOption(l.md.readBufferSize), - connector.UDPDataQueueSizeBindOption(l.md.readQueueSize), + chain.BacklogBindOption(l.md.backlog), + chain.UDPConnTTLBindOption(l.md.ttl), + chain.UDPDataBufferSizeBindOption(l.md.readBufferSize), + chain.UDPDataQueueSizeBindOption(l.md.readQueueSize), ) if err != nil { return nil, listener.NewAcceptError(err) diff --git a/registry/chain.go b/registry/chain.go index c71c7eb..cf440f0 100644 --- a/registry/chain.go +++ b/registry/chain.go @@ -3,6 +3,7 @@ package registry import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/metadata" + "github.com/go-gost/core/selector" ) type chainRegistry struct { @@ -32,7 +33,7 @@ type chainWrapper struct { r *chainRegistry } -func (w *chainWrapper) Marker() chain.Marker { +func (w *chainWrapper) Marker() selector.Marker { v := w.r.get(w.name) if v == nil { return nil @@ -48,7 +49,7 @@ func (w *chainWrapper) Metadata() metadata.Metadata { return v.Metadata() } -func (w *chainWrapper) Route(network, address string) *chain.Route { +func (w *chainWrapper) Route(network, address string) chain.Route { v := w.r.get(w.name) if v == nil { return nil diff --git a/selector/filter.go b/selector/filter.go new file mode 100644 index 0000000..bd0bb15 --- /dev/null +++ b/selector/filter.go @@ -0,0 +1,87 @@ +package selector + +import ( + "time" + + mdutil "github.com/go-gost/core/metadata/util" + "github.com/go-gost/core/selector" +) + +type failFilter[T selector.Selectable] struct { + maxFails int + failTimeout time.Duration +} + +// FailFilter filters the dead objects. +// An object is marked as dead if its failed count is greater than MaxFails. +func FailFilter[T selector.Selectable](maxFails int, timeout time.Duration) selector.Filter[T] { + return &failFilter[T]{ + maxFails: maxFails, + failTimeout: timeout, + } +} + +// Filter filters dead objects. +func (f *failFilter[T]) Filter(vs ...T) []T { + if len(vs) <= 1 { + return vs + } + var l []T + for _, v := range vs { + maxFails := f.maxFails + failTimeout := f.failTimeout + if md := v.Metadata(); md != nil { + if md.IsExists(labelMaxFails) { + maxFails = mdutil.GetInt(md, labelMaxFails) + } + if md.IsExists(labelFailTimeout) { + failTimeout = mdutil.GetDuration(md, labelFailTimeout) + } + } + if maxFails <= 0 { + maxFails = 1 + } + if failTimeout <= 0 { + failTimeout = DefaultFailTimeout + } + + if marker := v.Marker(); marker != nil { + if marker.Count() < int64(maxFails) || + time.Since(marker.Time()) >= failTimeout { + l = append(l, v) + } + } else { + l = append(l, v) + } + } + return l +} + +type backupFilter[T selector.Selectable] struct{} + +// BackupFilter filters the backup objects. +// An object is marked as backup if its metadata has backup flag. +func BackupFilter[T selector.Selectable]() selector.Filter[T] { + return &backupFilter[T]{} +} + +// Filter filters backup objects. +func (f *backupFilter[T]) Filter(vs ...T) []T { + if len(vs) <= 1 { + return vs + } + + var l, backups []T + for _, v := range vs { + if mdutil.GetBool(v.Metadata(), labelBackup) { + backups = append(backups, v) + } else { + l = append(l, v) + } + } + + if len(l) == 0 { + return backups + } + return l +} diff --git a/selector/selector.go b/selector/selector.go new file mode 100644 index 0000000..415b6a1 --- /dev/null +++ b/selector/selector.go @@ -0,0 +1,54 @@ +package selector + +import ( + "time" + + "github.com/go-gost/core/chain" + "github.com/go-gost/core/selector" +) + +// default options for FailFilter +const ( + DefaultMaxFails = 1 + DefaultFailTimeout = 10 * time.Second +) + +const ( + labelWeight = "weight" + labelBackup = "backup" + labelMaxFails = "maxFails" + labelFailTimeout = "failTimeout" +) + +var ( + DefaultNodeSelector = NewSelector( + RoundRobinStrategy[*chain.Node](), + // FailFilter[*Node](1, DefaultFailTimeout), + ) + DefaultChainSelector = NewSelector( + RoundRobinStrategy[chain.SelectableChainer](), + // FailFilter[SelectableChainer](1, DefaultFailTimeout), + ) +) + +type defaultSelector[T selector.Selectable] struct { + strategy selector.Strategy[T] + filters []selector.Filter[T] +} + +func NewSelector[T selector.Selectable](strategy selector.Strategy[T], filters ...selector.Filter[T]) selector.Selector[T] { + return &defaultSelector[T]{ + filters: filters, + strategy: strategy, + } +} + +func (s *defaultSelector[T]) Select(vs ...T) (v T) { + for _, filter := range s.filters { + vs = filter.Filter(vs...) + } + if len(vs) == 0 { + return + } + return s.strategy.Apply(vs...) +} diff --git a/selector/strategy.go b/selector/strategy.go new file mode 100644 index 0000000..a271717 --- /dev/null +++ b/selector/strategy.go @@ -0,0 +1,72 @@ +package selector + +import ( + "math/rand" + "sync" + "sync/atomic" + "time" + + "github.com/go-gost/core/selector" +) + +type roundRobinStrategy[T selector.Selectable] struct { + counter uint64 +} + +// RoundRobinStrategy is a strategy for node selector. +// The node will be selected by round-robin algorithm. +func RoundRobinStrategy[T selector.Selectable]() selector.Strategy[T] { + return &roundRobinStrategy[T]{} +} + +func (s *roundRobinStrategy[T]) Apply(vs ...T) (v T) { + if len(vs) == 0 { + return + } + + n := atomic.AddUint64(&s.counter, 1) - 1 + return vs[int(n%uint64(len(vs)))] +} + +type randomStrategy[T selector.Selectable] struct { + rand *rand.Rand + mux sync.Mutex +} + +// RandomStrategy is a strategy for node selector. +// The node will be selected randomly. +func RandomStrategy[T selector.Selectable]() selector.Strategy[T] { + return &randomStrategy[T]{ + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + } +} + +func (s *randomStrategy[T]) Apply(vs ...T) (v T) { + if len(vs) == 0 { + return + } + + s.mux.Lock() + defer s.mux.Unlock() + + r := s.rand.Int() + + return vs[r%len(vs)] +} + +type fifoStrategy[T selector.Selectable] struct{} + +// FIFOStrategy is a strategy for node selector. +// The node will be selected from first to last, +// and will stick to the selected node until it is failed. +func FIFOStrategy[T selector.Selectable]() selector.Strategy[T] { + return &fifoStrategy[T]{} +} + +// Apply applies the fifo strategy for the nodes. +func (s *fifoStrategy[T]) Apply(vs ...T) (v T) { + if len(vs) == 0 { + return + } + return vs[0] +}