update selector

This commit is contained in:
ginuerzh 2022-09-04 13:24:32 +08:00
parent 05bfeb8a0f
commit a04c8b45f3
19 changed files with 97 additions and 81 deletions

View File

@ -12,7 +12,7 @@ import (
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
) )
func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) { func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) {
if cfg == nil { if cfg == nil {
return nil, nil return nil, nil
} }

View File

@ -89,26 +89,26 @@ func parseAuth(cfg *config.AuthConfig) *url.Userinfo {
return url.UserPassword(cfg.Username, cfg.Password) return url.UserPassword(cfg.Username, cfg.Password)
} }
func parseChainSelector(cfg *config.SelectorConfig) selector.Selector[chain.SelectableChainer] { func parseChainSelector(cfg *config.SelectorConfig) selector.Selector[chain.Chainer] {
if cfg == nil { if cfg == nil {
return nil return nil
} }
var strategy selector.Strategy[chain.SelectableChainer] var strategy selector.Strategy[chain.Chainer]
switch cfg.Strategy { switch cfg.Strategy {
case "round", "rr": case "round", "rr":
strategy = xs.RoundRobinStrategy[chain.SelectableChainer]() strategy = xs.RoundRobinStrategy[chain.Chainer]()
case "random", "rand": case "random", "rand":
strategy = xs.RandomStrategy[chain.SelectableChainer]() strategy = xs.RandomStrategy[chain.Chainer]()
case "fifo", "ha": case "fifo", "ha":
strategy = xs.FIFOStrategy[chain.SelectableChainer]() strategy = xs.FIFOStrategy[chain.Chainer]()
default: default:
strategy = xs.RoundRobinStrategy[chain.SelectableChainer]() strategy = xs.RoundRobinStrategy[chain.Chainer]()
} }
return xs.NewSelector( return xs.NewSelector(
strategy, strategy,
xs.FailFilter[chain.SelectableChainer](cfg.MaxFails, cfg.FailTimeout), xs.FailFilter[chain.Chainer](cfg.MaxFails, cfg.FailTimeout),
xs.BackupFilter[chain.SelectableChainer](), xs.BackupFilter[chain.Chainer](),
) )
} }
@ -311,10 +311,10 @@ func defaultNodeSelector() selector.Selector[*chain.Node] {
) )
} }
func defaultChainSelector() selector.Selector[chain.SelectableChainer] { func defaultChainSelector() selector.Selector[chain.Chainer] {
return xs.NewSelector( return xs.NewSelector(
xs.RoundRobinStrategy[chain.SelectableChainer](), xs.RoundRobinStrategy[chain.Chainer](),
xs.FailFilter[chain.SelectableChainer](xs.DefaultMaxFails, xs.DefaultFailTimeout), xs.FailFilter[chain.Chainer](xs.DefaultMaxFails, xs.DefaultFailTimeout),
xs.BackupFilter[chain.SelectableChainer](), xs.BackupFilter[chain.Chainer](),
) )
} }

View File

@ -242,8 +242,8 @@ func admissionList(name string, names ...string) []admission.Admission {
} }
func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer { func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer {
var chains []chain.SelectableChainer var chains []chain.Chainer
var sel selector.Selector[chain.SelectableChainer] var sel selector.Selector[chain.Chainer]
if c := registry.ChainRegistry().Get(name); c != nil { if c := registry.ChainRegistry().Get(name); c != nil {
chains = append(chains, c) chains = append(chains, c)

View File

@ -90,7 +90,7 @@ func (c *http2Connector) Connect(ctx context.Context, conn net.Conn, network, ad
defer conn.SetDeadline(time.Time{}) defer conn.SetDeadline(time.Time{})
} }
client := v.GetMetadata().Get("client").(*http.Client) client := v.Metadata().Get("client").(*http.Client)
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
log.Error(err) log.Error(err)

View File

@ -51,7 +51,7 @@ func (c *conn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
} }
// GetMetadata implements metadata.Metadatable interface. // Metadata implements metadata.Metadatable interface.
func (c *conn) GetMetadata() mdata.Metadata { func (c *conn) Metadata() mdata.Metadata {
return c.md return c.md
} }

2
go.mod
View File

@ -6,7 +6,7 @@ require (
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
github.com/gin-contrib/cors v1.3.1 github.com/gin-contrib/cors v1.3.1
github.com/gin-gonic/gin v1.7.7 github.com/gin-gonic/gin v1.7.7
github.com/go-gost/core v0.0.0-20220902092328-713671067369 github.com/go-gost/core v0.0.0-20220904052234-99adf4bb0692
github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks4 v0.0.1
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09
github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7

4
go.sum
View File

@ -96,8 +96,8 @@ github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gost/core v0.0.0-20220902092328-713671067369 h1:qPZgaT7p3WP06X0uVGv5bVxD2DUP7x+RiMaYxJyQuwI= github.com/go-gost/core v0.0.0-20220904052234-99adf4bb0692 h1:exs+esWEKuK/ZtmaIiUGxHmC1FG2YZSUZOLls0t2O4I=
github.com/go-gost/core v0.0.0-20220902092328-713671067369/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= github.com/go-gost/core v0.0.0-20220904052234-99adf4bb0692/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ=
github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s=
github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc=
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04=

View File

@ -81,7 +81,7 @@ func (h *http2Handler) Handle(ctx context.Context, conn net.Conn, opts ...handle
log.Error(err) log.Error(err)
return err return err
} }
md := v.GetMetadata() md := v.Metadata()
return h.roundTrip(ctx, return h.roundTrip(ctx,
md.Get("w").(http.ResponseWriter), md.Get("w").(http.ResponseWriter),
md.Get("r").(*http.Request), md.Get("r").(*http.Request),

View File

@ -118,7 +118,7 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
log.Debugf("%s >> %s", conn.RemoteAddr(), target.Addr) log.Debugf("%s >> %s", conn.RemoteAddr(), target.Addr)
} }
config := v.GetMetadata().Get("config").(*tap_util.Config) config := v.Metadata().Get("config").(*tap_util.Config)
h.handleLoop(ctx, conn, raddr, config, log) h.handleLoop(ctx, conn, raddr, config, log)
return nil return nil
} }

View File

@ -68,7 +68,7 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
log.Error(err) log.Error(err)
return err return err
} }
config := v.GetMetadata().Get("config").(*tun_util.Config) config := v.Metadata().Get("config").(*tun_util.Config)
start := time.Now() start := time.Now()
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{

View File

@ -60,7 +60,7 @@ func (c *conn) Done() <-chan struct{} {
return c.closed return c.closed
} }
// GetMetadata implements metadata.Metadatable interface. // Metadata implements metadata.Metadatable interface.
func (c *conn) GetMetadata() mdata.Metadata { func (c *conn) Metadata() mdata.Metadata {
return c.md return c.md
} }

View File

@ -52,8 +52,8 @@ type metadataConn struct {
md mdata.Metadata md mdata.Metadata
} }
// GetMetadata implements metadata.Metadatable interface. // Metadata implements metadata.Metadatable interface.
func (c *metadataConn) GetMetadata() mdata.Metadata { func (c *metadataConn) Metadata() mdata.Metadata {
return c.md return c.md
} }

View File

@ -57,8 +57,8 @@ type metadataConn struct {
md mdata.Metadata md mdata.Metadata
} }
// GetMetadata implements metadata.Metadatable interface. // Metadata implements metadata.Metadatable interface.
func (c *metadataConn) GetMetadata() mdata.Metadata { func (c *metadataConn) Metadata() mdata.Metadata {
return c.md return c.md
} }

View File

@ -12,20 +12,20 @@ type chainRegistry struct {
registry registry
} }
func (r *chainRegistry) Register(name string, v chain.SelectableChainer) error { func (r *chainRegistry) Register(name string, v chain.Chainer) error {
return r.registry.Register(name, v) return r.registry.Register(name, v)
} }
func (r *chainRegistry) Get(name string) chain.SelectableChainer { func (r *chainRegistry) Get(name string) chain.Chainer {
if name != "" { if name != "" {
return &chainWrapper{name: name, r: r} return &chainWrapper{name: name, r: r}
} }
return nil return nil
} }
func (r *chainRegistry) get(name string) chain.SelectableChainer { func (r *chainRegistry) get(name string) chain.Chainer {
if v := r.registry.Get(name); v != nil { if v := r.registry.Get(name); v != nil {
return v.(chain.SelectableChainer) return v.(chain.Chainer)
} }
return nil return nil
} }
@ -40,7 +40,10 @@ func (w *chainWrapper) Marker() selector.Marker {
if v == nil { if v == nil {
return nil return nil
} }
return v.Marker() if mi, ok := v.(selector.Markable); ok {
return mi.Marker()
}
return nil
} }
func (w *chainWrapper) Metadata() metadata.Metadata { func (w *chainWrapper) Metadata() metadata.Metadata {
@ -48,7 +51,11 @@ func (w *chainWrapper) Metadata() metadata.Metadata {
if v == nil { if v == nil {
return nil return nil
} }
return v.Metadata()
if mi, ok := v.(metadata.Metadatable); ok {
return mi.Metadata()
}
return nil
} }
func (w *chainWrapper) Route(ctx context.Context, network, address string) chain.Route { func (w *chainWrapper) Route(ctx context.Context, network, address string) chain.Route {

View File

@ -26,7 +26,7 @@ var (
connectorReg Registry[NewConnector] = &connectorRegistry{} connectorReg Registry[NewConnector] = &connectorRegistry{}
serviceReg Registry[service.Service] = &serviceRegistry{} serviceReg Registry[service.Service] = &serviceRegistry{}
chainReg Registry[chain.SelectableChainer] = &chainRegistry{} chainReg Registry[chain.Chainer] = &chainRegistry{}
autherReg Registry[auth.Authenticator] = &autherRegistry{} autherReg Registry[auth.Authenticator] = &autherRegistry{}
admissionReg Registry[admission.Admission] = &admissionRegistry{} admissionReg Registry[admission.Admission] = &admissionRegistry{}
bypassReg Registry[bypass.Bypass] = &bypassRegistry{} bypassReg Registry[bypass.Bypass] = &bypassRegistry{}
@ -99,7 +99,7 @@ func ServiceRegistry() Registry[service.Service] {
return serviceReg return serviceReg
} }
func ChainRegistry() Registry[chain.SelectableChainer] { func ChainRegistry() Registry[chain.Chainer] {
return chainReg return chainReg
} }

View File

@ -4,18 +4,19 @@ import (
"context" "context"
"time" "time"
"github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util" mdutil "github.com/go-gost/core/metadata/util"
"github.com/go-gost/core/selector" "github.com/go-gost/core/selector"
) )
type failFilter[T selector.Selectable] struct { type failFilter[T any] struct {
maxFails int maxFails int
failTimeout time.Duration failTimeout time.Duration
} }
// FailFilter filters the dead objects. // FailFilter filters the dead objects.
// An object is marked as dead if its failed count is greater than MaxFails. // An object is marked as dead if its failed count is greater than MaxFails.
func FailFilter[T selector.Selectable](maxFails int, timeout time.Duration) selector.Filter[T] { func FailFilter[T any](maxFails int, timeout time.Duration) selector.Filter[T] {
return &failFilter[T]{ return &failFilter[T]{
maxFails: maxFails, maxFails: maxFails,
failTimeout: timeout, failTimeout: timeout,
@ -31,7 +32,8 @@ func (f *failFilter[T]) Filter(ctx context.Context, vs ...T) []T {
for _, v := range vs { for _, v := range vs {
maxFails := f.maxFails maxFails := f.maxFails
failTimeout := f.failTimeout failTimeout := f.failTimeout
if md := v.Metadata(); md != nil { if mi, _ := any(v).(metadata.Metadatable); mi != nil {
if md := mi.Metadata(); md != nil {
if md.IsExists(labelMaxFails) { if md.IsExists(labelMaxFails) {
maxFails = mdutil.GetInt(md, labelMaxFails) maxFails = mdutil.GetInt(md, labelMaxFails)
} }
@ -39,6 +41,7 @@ func (f *failFilter[T]) Filter(ctx context.Context, vs ...T) []T {
failTimeout = mdutil.GetDuration(md, labelFailTimeout) failTimeout = mdutil.GetDuration(md, labelFailTimeout)
} }
} }
}
if maxFails <= 0 { if maxFails <= 0 {
maxFails = 1 maxFails = 1
} }
@ -46,23 +49,25 @@ func (f *failFilter[T]) Filter(ctx context.Context, vs ...T) []T {
failTimeout = DefaultFailTimeout failTimeout = DefaultFailTimeout
} }
if marker := v.Marker(); marker != nil { if mi, _ := any(v).(selector.Markable); mi != nil {
if marker := mi.Marker(); marker != nil {
if marker.Count() < int64(maxFails) || if marker.Count() < int64(maxFails) ||
time.Since(marker.Time()) >= failTimeout { time.Since(marker.Time()) >= failTimeout {
l = append(l, v) l = append(l, v)
} }
} else { continue
l = append(l, v)
} }
} }
l = append(l, v)
}
return l return l
} }
type backupFilter[T selector.Selectable] struct{} type backupFilter[T any] struct{}
// BackupFilter filters the backup objects. // BackupFilter filters the backup objects.
// An object is marked as backup if its metadata has backup flag. // An object is marked as backup if its metadata has backup flag.
func BackupFilter[T selector.Selectable]() selector.Filter[T] { func BackupFilter[T any]() selector.Filter[T] {
return &backupFilter[T]{} return &backupFilter[T]{}
} }
@ -74,12 +79,14 @@ func (f *backupFilter[T]) Filter(ctx context.Context, vs ...T) []T {
var l, backups []T var l, backups []T
for _, v := range vs { for _, v := range vs {
if mdutil.GetBool(v.Metadata(), labelBackup) { if mi, _ := any(v).(metadata.Metadatable); mi != nil {
if mdutil.GetBool(mi.Metadata(), labelBackup) {
backups = append(backups, v) backups = append(backups, v)
} else { continue
l = append(l, v)
} }
} }
l = append(l, v)
}
if len(l) == 0 { if len(l) == 0 {
return backups return backups

View File

@ -20,12 +20,12 @@ const (
labelFailTimeout = "failTimeout" labelFailTimeout = "failTimeout"
) )
type defaultSelector[T selector.Selectable] struct { type defaultSelector[T any] struct {
strategy selector.Strategy[T] strategy selector.Strategy[T]
filters []selector.Filter[T] filters []selector.Filter[T]
} }
func NewSelector[T selector.Selectable](strategy selector.Strategy[T], filters ...selector.Filter[T]) selector.Selector[T] { func NewSelector[T any](strategy selector.Strategy[T], filters ...selector.Filter[T]) selector.Selector[T] {
return &defaultSelector[T]{ return &defaultSelector[T]{
filters: filters, filters: filters,
strategy: strategy, strategy: strategy,

View File

@ -7,18 +7,19 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util" mdutil "github.com/go-gost/core/metadata/util"
"github.com/go-gost/core/selector" "github.com/go-gost/core/selector"
sx "github.com/go-gost/x/internal/util/selector" sx "github.com/go-gost/x/internal/util/selector"
) )
type roundRobinStrategy[T selector.Selectable] struct { type roundRobinStrategy[T any] struct {
counter uint64 counter uint64
} }
// RoundRobinStrategy is a strategy for node selector. // RoundRobinStrategy is a strategy for node selector.
// The node will be selected by round-robin algorithm. // The node will be selected by round-robin algorithm.
func RoundRobinStrategy[T selector.Selectable]() selector.Strategy[T] { func RoundRobinStrategy[T any]() selector.Strategy[T] {
return &roundRobinStrategy[T]{} return &roundRobinStrategy[T]{}
} }
@ -31,14 +32,14 @@ func (s *roundRobinStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
return vs[int(n%uint64(len(vs)))] return vs[int(n%uint64(len(vs)))]
} }
type randomStrategy[T selector.Selectable] struct { type randomStrategy[T any] struct {
rw *randomWeighted[T] rw *randomWeighted[T]
mu sync.Mutex mu sync.Mutex
} }
// RandomStrategy is a strategy for node selector. // RandomStrategy is a strategy for node selector.
// The node will be selected randomly. // The node will be selected randomly.
func RandomStrategy[T selector.Selectable]() selector.Strategy[T] { func RandomStrategy[T any]() selector.Strategy[T] {
return &randomStrategy[T]{ return &randomStrategy[T]{
rw: newRandomWeighted[T](), rw: newRandomWeighted[T](),
} }
@ -54,7 +55,10 @@ func (s *randomStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
s.rw.Reset() s.rw.Reset()
for i := range vs { for i := range vs {
weight := mdutil.GetInt(vs[i].Metadata(), labelWeight) weight := 0
if md, _ := any(vs[i]).(metadata.Metadatable); md != nil {
weight = mdutil.GetInt(md.Metadata(), labelWeight)
}
if weight <= 0 { if weight <= 0 {
weight = 1 weight = 1
} }
@ -64,12 +68,12 @@ func (s *randomStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
return s.rw.Next() return s.rw.Next()
} }
type fifoStrategy[T selector.Selectable] struct{} type fifoStrategy[T any] struct{}
// FIFOStrategy is a strategy for node selector. // FIFOStrategy is a strategy for node selector.
// The node will be selected from first to last, // The node will be selected from first to last,
// and will stick to the selected node until it is failed. // and will stick to the selected node until it is failed.
func FIFOStrategy[T selector.Selectable]() selector.Strategy[T] { func FIFOStrategy[T any]() selector.Strategy[T] {
return &fifoStrategy[T]{} return &fifoStrategy[T]{}
} }
@ -81,12 +85,12 @@ func (s *fifoStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
return vs[0] return vs[0]
} }
type hashStrategy[T selector.Selectable] struct { type hashStrategy[T any] struct {
r *rand.Rand r *rand.Rand
mu sync.Mutex mu sync.Mutex
} }
func HashStrategy[T selector.Selectable]() selector.Strategy[T] { func HashStrategy[T any]() selector.Strategy[T] {
return &hashStrategy[T]{ return &hashStrategy[T]{
r: rand.New(rand.NewSource(time.Now().UnixNano())), r: rand.New(rand.NewSource(time.Now().UnixNano())),
} }

View File

@ -3,22 +3,20 @@ package selector
import ( import (
"math/rand" "math/rand"
"time" "time"
"github.com/go-gost/core/selector"
) )
type randomWeightedItem[T selector.Selectable] struct { type randomWeightedItem[T any] struct {
item T item T
weight int weight int
} }
type randomWeighted[T selector.Selectable] struct { type randomWeighted[T any] struct {
items []*randomWeightedItem[T] items []*randomWeightedItem[T]
sum int sum int
r *rand.Rand r *rand.Rand
} }
func newRandomWeighted[T selector.Selectable]() *randomWeighted[T] { func newRandomWeighted[T any]() *randomWeighted[T] {
return &randomWeighted[T]{ return &randomWeighted[T]{
r: rand.New(rand.NewSource(time.Now().UnixNano())), r: rand.New(rand.NewSource(time.Now().UnixNano())),
} }