From a04c8b45f30f9ddb70bd89fb7ad3ad73d625b626 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sun, 4 Sep 2022 13:24:32 +0800 Subject: [PATCH] update selector --- config/parsing/chain.go | 2 +- config/parsing/parse.go | 24 +++++++++--------- config/parsing/service.go | 4 +-- connector/http2/connector.go | 2 +- dialer/http2/conn.go | 4 +-- go.mod | 2 +- go.sum | 4 +-- handler/http2/handler.go | 2 +- handler/tap/handler.go | 2 +- handler/tun/handler.go | 2 +- listener/http2/conn.go | 4 +-- listener/tap/conn.go | 4 +-- listener/tun/conn.go | 4 +-- registry/chain.go | 19 ++++++++++----- registry/registry.go | 18 +++++++------- selector/filter.go | 47 +++++++++++++++++++++--------------- selector/selector.go | 4 +-- selector/strategy.go | 22 ++++++++++------- selector/weighted.go | 8 +++--- 19 files changed, 97 insertions(+), 81 deletions(-) diff --git a/config/parsing/chain.go b/config/parsing/chain.go index 6d8601f..2b5c5a0 100644 --- a/config/parsing/chain.go +++ b/config/parsing/chain.go @@ -12,7 +12,7 @@ import ( "github.com/go-gost/x/registry" ) -func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) { +func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { if cfg == nil { return nil, nil } diff --git a/config/parsing/parse.go b/config/parsing/parse.go index 7153de1..e255b0e 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -89,26 +89,26 @@ func parseAuth(cfg *config.AuthConfig) *url.Userinfo { return url.UserPassword(cfg.Username, cfg.Password) } -func parseChainSelector(cfg *config.SelectorConfig) selector.Selector[chain.SelectableChainer] { +func parseChainSelector(cfg *config.SelectorConfig) selector.Selector[chain.Chainer] { if cfg == nil { return nil } - var strategy selector.Strategy[chain.SelectableChainer] + var strategy selector.Strategy[chain.Chainer] switch cfg.Strategy { case "round", "rr": - strategy = xs.RoundRobinStrategy[chain.SelectableChainer]() + strategy = xs.RoundRobinStrategy[chain.Chainer]() case "random", "rand": - strategy = xs.RandomStrategy[chain.SelectableChainer]() + strategy = xs.RandomStrategy[chain.Chainer]() case "fifo", "ha": - strategy = xs.FIFOStrategy[chain.SelectableChainer]() + strategy = xs.FIFOStrategy[chain.Chainer]() default: - strategy = xs.RoundRobinStrategy[chain.SelectableChainer]() + strategy = xs.RoundRobinStrategy[chain.Chainer]() } return xs.NewSelector( strategy, - xs.FailFilter[chain.SelectableChainer](cfg.MaxFails, cfg.FailTimeout), - xs.BackupFilter[chain.SelectableChainer](), + xs.FailFilter[chain.Chainer](cfg.MaxFails, cfg.FailTimeout), + xs.BackupFilter[chain.Chainer](), ) } @@ -311,10 +311,10 @@ func defaultNodeSelector() selector.Selector[*chain.Node] { ) } -func defaultChainSelector() selector.Selector[chain.SelectableChainer] { +func defaultChainSelector() selector.Selector[chain.Chainer] { return xs.NewSelector( - xs.RoundRobinStrategy[chain.SelectableChainer](), - xs.FailFilter[chain.SelectableChainer](xs.DefaultMaxFails, xs.DefaultFailTimeout), - xs.BackupFilter[chain.SelectableChainer](), + xs.RoundRobinStrategy[chain.Chainer](), + xs.FailFilter[chain.Chainer](xs.DefaultMaxFails, xs.DefaultFailTimeout), + xs.BackupFilter[chain.Chainer](), ) } diff --git a/config/parsing/service.go b/config/parsing/service.go index 9eacab3..7c63d44 100644 --- a/config/parsing/service.go +++ b/config/parsing/service.go @@ -242,8 +242,8 @@ func admissionList(name string, names ...string) []admission.Admission { } func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer { - var chains []chain.SelectableChainer - var sel selector.Selector[chain.SelectableChainer] + var chains []chain.Chainer + var sel selector.Selector[chain.Chainer] if c := registry.ChainRegistry().Get(name); c != nil { chains = append(chains, c) diff --git a/connector/http2/connector.go b/connector/http2/connector.go index 76dbff6..4a31dfb 100644 --- a/connector/http2/connector.go +++ b/connector/http2/connector.go @@ -90,7 +90,7 @@ func (c *http2Connector) Connect(ctx context.Context, conn net.Conn, network, ad defer conn.SetDeadline(time.Time{}) } - client := v.GetMetadata().Get("client").(*http.Client) + client := v.Metadata().Get("client").(*http.Client) resp, err := client.Do(req) if err != nil { log.Error(err) diff --git a/dialer/http2/conn.go b/dialer/http2/conn.go index 76a95cc..2d58e10 100644 --- a/dialer/http2/conn.go +++ b/dialer/http2/conn.go @@ -51,7 +51,7 @@ func (c *conn) SetWriteDeadline(t time.Time) error { return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } -// GetMetadata implements metadata.Metadatable interface. -func (c *conn) GetMetadata() mdata.Metadata { +// Metadata implements metadata.Metadatable interface. +func (c *conn) Metadata() mdata.Metadata { return c.md } diff --git a/go.mod b/go.mod index 939674f..b648b35 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.7.7 - github.com/go-gost/core v0.0.0-20220902092328-713671067369 + github.com/go-gost/core v0.0.0-20220904052234-99adf4bb0692 github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 diff --git a/go.sum b/go.sum index df6283b..1eb7ba3 100644 --- a/go.sum +++ b/go.sum @@ -96,8 +96,8 @@ github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gost/core v0.0.0-20220902092328-713671067369 h1:qPZgaT7p3WP06X0uVGv5bVxD2DUP7x+RiMaYxJyQuwI= -github.com/go-gost/core v0.0.0-20220902092328-713671067369/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= +github.com/go-gost/core v0.0.0-20220904052234-99adf4bb0692 h1:exs+esWEKuK/ZtmaIiUGxHmC1FG2YZSUZOLls0t2O4I= +github.com/go-gost/core v0.0.0-20220904052234-99adf4bb0692/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= diff --git a/handler/http2/handler.go b/handler/http2/handler.go index 42c3f57..5dc974a 100644 --- a/handler/http2/handler.go +++ b/handler/http2/handler.go @@ -81,7 +81,7 @@ func (h *http2Handler) Handle(ctx context.Context, conn net.Conn, opts ...handle log.Error(err) return err } - md := v.GetMetadata() + md := v.Metadata() return h.roundTrip(ctx, md.Get("w").(http.ResponseWriter), md.Get("r").(*http.Request), diff --git a/handler/tap/handler.go b/handler/tap/handler.go index 962e033..ef3234c 100644 --- a/handler/tap/handler.go +++ b/handler/tap/handler.go @@ -118,7 +118,7 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. log.Debugf("%s >> %s", conn.RemoteAddr(), target.Addr) } - config := v.GetMetadata().Get("config").(*tap_util.Config) + config := v.Metadata().Get("config").(*tap_util.Config) h.handleLoop(ctx, conn, raddr, config, log) return nil } diff --git a/handler/tun/handler.go b/handler/tun/handler.go index a7b052e..794cf82 100644 --- a/handler/tun/handler.go +++ b/handler/tun/handler.go @@ -68,7 +68,7 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. log.Error(err) return err } - config := v.GetMetadata().Get("config").(*tun_util.Config) + config := v.Metadata().Get("config").(*tun_util.Config) start := time.Now() log = log.WithFields(map[string]any{ diff --git a/listener/http2/conn.go b/listener/http2/conn.go index 775ec9b..15f56f2 100644 --- a/listener/http2/conn.go +++ b/listener/http2/conn.go @@ -60,7 +60,7 @@ func (c *conn) Done() <-chan struct{} { return c.closed } -// GetMetadata implements metadata.Metadatable interface. -func (c *conn) GetMetadata() mdata.Metadata { +// Metadata implements metadata.Metadatable interface. +func (c *conn) Metadata() mdata.Metadata { return c.md } diff --git a/listener/tap/conn.go b/listener/tap/conn.go index 97057ed..7faeb31 100644 --- a/listener/tap/conn.go +++ b/listener/tap/conn.go @@ -52,8 +52,8 @@ type metadataConn struct { md mdata.Metadata } -// GetMetadata implements metadata.Metadatable interface. -func (c *metadataConn) GetMetadata() mdata.Metadata { +// Metadata implements metadata.Metadatable interface. +func (c *metadataConn) Metadata() mdata.Metadata { return c.md } diff --git a/listener/tun/conn.go b/listener/tun/conn.go index 7bed81b..20c0185 100644 --- a/listener/tun/conn.go +++ b/listener/tun/conn.go @@ -57,8 +57,8 @@ type metadataConn struct { md mdata.Metadata } -// GetMetadata implements metadata.Metadatable interface. -func (c *metadataConn) GetMetadata() mdata.Metadata { +// Metadata implements metadata.Metadatable interface. +func (c *metadataConn) Metadata() mdata.Metadata { return c.md } diff --git a/registry/chain.go b/registry/chain.go index 1062927..ac699b7 100644 --- a/registry/chain.go +++ b/registry/chain.go @@ -12,20 +12,20 @@ type chainRegistry struct { registry } -func (r *chainRegistry) Register(name string, v chain.SelectableChainer) error { +func (r *chainRegistry) Register(name string, v chain.Chainer) error { return r.registry.Register(name, v) } -func (r *chainRegistry) Get(name string) chain.SelectableChainer { +func (r *chainRegistry) Get(name string) chain.Chainer { if name != "" { return &chainWrapper{name: name, r: r} } return nil } -func (r *chainRegistry) get(name string) chain.SelectableChainer { +func (r *chainRegistry) get(name string) chain.Chainer { if v := r.registry.Get(name); v != nil { - return v.(chain.SelectableChainer) + return v.(chain.Chainer) } return nil } @@ -40,7 +40,10 @@ func (w *chainWrapper) Marker() selector.Marker { if v == nil { return nil } - return v.Marker() + if mi, ok := v.(selector.Markable); ok { + return mi.Marker() + } + return nil } func (w *chainWrapper) Metadata() metadata.Metadata { @@ -48,7 +51,11 @@ func (w *chainWrapper) Metadata() metadata.Metadata { if v == nil { return nil } - return v.Metadata() + + if mi, ok := v.(metadata.Metadatable); ok { + return mi.Metadata() + } + return nil } func (w *chainWrapper) Route(ctx context.Context, network, address string) chain.Route { diff --git a/registry/registry.go b/registry/registry.go index 1ca93e6..f9aa5b4 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -25,14 +25,14 @@ var ( dialerReg Registry[NewDialer] = &dialerRegistry{} connectorReg Registry[NewConnector] = &connectorRegistry{} - serviceReg Registry[service.Service] = &serviceRegistry{} - chainReg Registry[chain.SelectableChainer] = &chainRegistry{} - autherReg Registry[auth.Authenticator] = &autherRegistry{} - admissionReg Registry[admission.Admission] = &admissionRegistry{} - bypassReg Registry[bypass.Bypass] = &bypassRegistry{} - resolverReg Registry[resolver.Resolver] = &resolverRegistry{} - hostsReg Registry[hosts.HostMapper] = &hostsRegistry{} - recorderReg Registry[recorder.Recorder] = &recorderRegistry{} + serviceReg Registry[service.Service] = &serviceRegistry{} + chainReg Registry[chain.Chainer] = &chainRegistry{} + autherReg Registry[auth.Authenticator] = &autherRegistry{} + admissionReg Registry[admission.Admission] = &admissionRegistry{} + bypassReg Registry[bypass.Bypass] = &bypassRegistry{} + resolverReg Registry[resolver.Resolver] = &resolverRegistry{} + hostsReg Registry[hosts.HostMapper] = &hostsRegistry{} + recorderReg Registry[recorder.Recorder] = &recorderRegistry{} ) type Registry[T any] interface { @@ -99,7 +99,7 @@ func ServiceRegistry() Registry[service.Service] { return serviceReg } -func ChainRegistry() Registry[chain.SelectableChainer] { +func ChainRegistry() Registry[chain.Chainer] { return chainReg } diff --git a/selector/filter.go b/selector/filter.go index 077ae72..9d3785b 100644 --- a/selector/filter.go +++ b/selector/filter.go @@ -4,18 +4,19 @@ import ( "context" "time" + "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" "github.com/go-gost/core/selector" ) -type failFilter[T selector.Selectable] struct { +type failFilter[T any] struct { maxFails int failTimeout time.Duration } // FailFilter filters the dead objects. // An object is marked as dead if its failed count is greater than MaxFails. -func FailFilter[T selector.Selectable](maxFails int, timeout time.Duration) selector.Filter[T] { +func FailFilter[T any](maxFails int, timeout time.Duration) selector.Filter[T] { return &failFilter[T]{ maxFails: maxFails, failTimeout: timeout, @@ -31,12 +32,14 @@ func (f *failFilter[T]) Filter(ctx context.Context, vs ...T) []T { for _, v := range vs { maxFails := f.maxFails failTimeout := f.failTimeout - if md := v.Metadata(); md != nil { - if md.IsExists(labelMaxFails) { - maxFails = mdutil.GetInt(md, labelMaxFails) - } - if md.IsExists(labelFailTimeout) { - failTimeout = mdutil.GetDuration(md, labelFailTimeout) + if mi, _ := any(v).(metadata.Metadatable); mi != nil { + if md := mi.Metadata(); md != nil { + if md.IsExists(labelMaxFails) { + maxFails = mdutil.GetInt(md, labelMaxFails) + } + if md.IsExists(labelFailTimeout) { + failTimeout = mdutil.GetDuration(md, labelFailTimeout) + } } } if maxFails <= 0 { @@ -46,23 +49,25 @@ func (f *failFilter[T]) Filter(ctx context.Context, vs ...T) []T { failTimeout = DefaultFailTimeout } - if marker := v.Marker(); marker != nil { - if marker.Count() < int64(maxFails) || - time.Since(marker.Time()) >= failTimeout { - l = append(l, v) + if mi, _ := any(v).(selector.Markable); mi != nil { + if marker := mi.Marker(); marker != nil { + if marker.Count() < int64(maxFails) || + time.Since(marker.Time()) >= failTimeout { + l = append(l, v) + } + continue } - } else { - l = append(l, v) } + l = append(l, v) } return l } -type backupFilter[T selector.Selectable] struct{} +type backupFilter[T any] struct{} // BackupFilter filters the backup objects. // An object is marked as backup if its metadata has backup flag. -func BackupFilter[T selector.Selectable]() selector.Filter[T] { +func BackupFilter[T any]() selector.Filter[T] { return &backupFilter[T]{} } @@ -74,11 +79,13 @@ func (f *backupFilter[T]) Filter(ctx context.Context, vs ...T) []T { var l, backups []T for _, v := range vs { - if mdutil.GetBool(v.Metadata(), labelBackup) { - backups = append(backups, v) - } else { - l = append(l, v) + if mi, _ := any(v).(metadata.Metadatable); mi != nil { + if mdutil.GetBool(mi.Metadata(), labelBackup) { + backups = append(backups, v) + continue + } } + l = append(l, v) } if len(l) == 0 { diff --git a/selector/selector.go b/selector/selector.go index 53108c5..470b90a 100644 --- a/selector/selector.go +++ b/selector/selector.go @@ -20,12 +20,12 @@ const ( labelFailTimeout = "failTimeout" ) -type defaultSelector[T selector.Selectable] struct { +type defaultSelector[T any] struct { strategy selector.Strategy[T] filters []selector.Filter[T] } -func NewSelector[T selector.Selectable](strategy selector.Strategy[T], filters ...selector.Filter[T]) selector.Selector[T] { +func NewSelector[T any](strategy selector.Strategy[T], filters ...selector.Filter[T]) selector.Selector[T] { return &defaultSelector[T]{ filters: filters, strategy: strategy, diff --git a/selector/strategy.go b/selector/strategy.go index 579df8a..bba51f1 100644 --- a/selector/strategy.go +++ b/selector/strategy.go @@ -7,18 +7,19 @@ import ( "sync/atomic" "time" + "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" "github.com/go-gost/core/selector" sx "github.com/go-gost/x/internal/util/selector" ) -type roundRobinStrategy[T selector.Selectable] struct { +type roundRobinStrategy[T any] struct { counter uint64 } // RoundRobinStrategy is a strategy for node selector. // The node will be selected by round-robin algorithm. -func RoundRobinStrategy[T selector.Selectable]() selector.Strategy[T] { +func RoundRobinStrategy[T any]() selector.Strategy[T] { return &roundRobinStrategy[T]{} } @@ -31,14 +32,14 @@ func (s *roundRobinStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) { return vs[int(n%uint64(len(vs)))] } -type randomStrategy[T selector.Selectable] struct { +type randomStrategy[T any] struct { rw *randomWeighted[T] mu sync.Mutex } // RandomStrategy is a strategy for node selector. // The node will be selected randomly. -func RandomStrategy[T selector.Selectable]() selector.Strategy[T] { +func RandomStrategy[T any]() selector.Strategy[T] { return &randomStrategy[T]{ rw: newRandomWeighted[T](), } @@ -54,7 +55,10 @@ func (s *randomStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) { s.rw.Reset() for i := range vs { - weight := mdutil.GetInt(vs[i].Metadata(), labelWeight) + weight := 0 + if md, _ := any(vs[i]).(metadata.Metadatable); md != nil { + weight = mdutil.GetInt(md.Metadata(), labelWeight) + } if weight <= 0 { weight = 1 } @@ -64,12 +68,12 @@ func (s *randomStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) { return s.rw.Next() } -type fifoStrategy[T selector.Selectable] struct{} +type fifoStrategy[T any] struct{} // FIFOStrategy is a strategy for node selector. // The node will be selected from first to last, // and will stick to the selected node until it is failed. -func FIFOStrategy[T selector.Selectable]() selector.Strategy[T] { +func FIFOStrategy[T any]() selector.Strategy[T] { return &fifoStrategy[T]{} } @@ -81,12 +85,12 @@ func (s *fifoStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) { return vs[0] } -type hashStrategy[T selector.Selectable] struct { +type hashStrategy[T any] struct { r *rand.Rand mu sync.Mutex } -func HashStrategy[T selector.Selectable]() selector.Strategy[T] { +func HashStrategy[T any]() selector.Strategy[T] { return &hashStrategy[T]{ r: rand.New(rand.NewSource(time.Now().UnixNano())), } diff --git a/selector/weighted.go b/selector/weighted.go index 478dda1..956342e 100644 --- a/selector/weighted.go +++ b/selector/weighted.go @@ -3,22 +3,20 @@ package selector import ( "math/rand" "time" - - "github.com/go-gost/core/selector" ) -type randomWeightedItem[T selector.Selectable] struct { +type randomWeightedItem[T any] struct { item T weight int } -type randomWeighted[T selector.Selectable] struct { +type randomWeighted[T any] struct { items []*randomWeightedItem[T] sum int r *rand.Rand } -func newRandomWeighted[T selector.Selectable]() *randomWeighted[T] { +func newRandomWeighted[T any]() *randomWeighted[T] { return &randomWeighted[T]{ r: rand.New(rand.NewSource(time.Now().UnixNano())), }