update selector

This commit is contained in:
ginuerzh 2022-09-04 13:24:32 +08:00
parent 05bfeb8a0f
commit a04c8b45f3
19 changed files with 97 additions and 81 deletions

View File

@ -12,7 +12,7 @@ import (
"github.com/go-gost/x/registry"
)
func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) {
func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) {
if cfg == nil {
return nil, nil
}

View File

@ -89,26 +89,26 @@ func parseAuth(cfg *config.AuthConfig) *url.Userinfo {
return url.UserPassword(cfg.Username, cfg.Password)
}
func parseChainSelector(cfg *config.SelectorConfig) selector.Selector[chain.SelectableChainer] {
func parseChainSelector(cfg *config.SelectorConfig) selector.Selector[chain.Chainer] {
if cfg == nil {
return nil
}
var strategy selector.Strategy[chain.SelectableChainer]
var strategy selector.Strategy[chain.Chainer]
switch cfg.Strategy {
case "round", "rr":
strategy = xs.RoundRobinStrategy[chain.SelectableChainer]()
strategy = xs.RoundRobinStrategy[chain.Chainer]()
case "random", "rand":
strategy = xs.RandomStrategy[chain.SelectableChainer]()
strategy = xs.RandomStrategy[chain.Chainer]()
case "fifo", "ha":
strategy = xs.FIFOStrategy[chain.SelectableChainer]()
strategy = xs.FIFOStrategy[chain.Chainer]()
default:
strategy = xs.RoundRobinStrategy[chain.SelectableChainer]()
strategy = xs.RoundRobinStrategy[chain.Chainer]()
}
return xs.NewSelector(
strategy,
xs.FailFilter[chain.SelectableChainer](cfg.MaxFails, cfg.FailTimeout),
xs.BackupFilter[chain.SelectableChainer](),
xs.FailFilter[chain.Chainer](cfg.MaxFails, cfg.FailTimeout),
xs.BackupFilter[chain.Chainer](),
)
}
@ -311,10 +311,10 @@ func defaultNodeSelector() selector.Selector[*chain.Node] {
)
}
func defaultChainSelector() selector.Selector[chain.SelectableChainer] {
func defaultChainSelector() selector.Selector[chain.Chainer] {
return xs.NewSelector(
xs.RoundRobinStrategy[chain.SelectableChainer](),
xs.FailFilter[chain.SelectableChainer](xs.DefaultMaxFails, xs.DefaultFailTimeout),
xs.BackupFilter[chain.SelectableChainer](),
xs.RoundRobinStrategy[chain.Chainer](),
xs.FailFilter[chain.Chainer](xs.DefaultMaxFails, xs.DefaultFailTimeout),
xs.BackupFilter[chain.Chainer](),
)
}

View File

@ -242,8 +242,8 @@ func admissionList(name string, names ...string) []admission.Admission {
}
func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer {
var chains []chain.SelectableChainer
var sel selector.Selector[chain.SelectableChainer]
var chains []chain.Chainer
var sel selector.Selector[chain.Chainer]
if c := registry.ChainRegistry().Get(name); c != nil {
chains = append(chains, c)

View File

@ -90,7 +90,7 @@ func (c *http2Connector) Connect(ctx context.Context, conn net.Conn, network, ad
defer conn.SetDeadline(time.Time{})
}
client := v.GetMetadata().Get("client").(*http.Client)
client := v.Metadata().Get("client").(*http.Client)
resp, err := client.Do(req)
if err != nil {
log.Error(err)

View File

@ -51,7 +51,7 @@ func (c *conn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
// GetMetadata implements metadata.Metadatable interface.
func (c *conn) GetMetadata() mdata.Metadata {
// Metadata implements metadata.Metadatable interface.
func (c *conn) Metadata() mdata.Metadata {
return c.md
}

2
go.mod
View File

@ -6,7 +6,7 @@ require (
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
github.com/gin-contrib/cors v1.3.1
github.com/gin-gonic/gin v1.7.7
github.com/go-gost/core v0.0.0-20220902092328-713671067369
github.com/go-gost/core v0.0.0-20220904052234-99adf4bb0692
github.com/go-gost/gosocks4 v0.0.1
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09
github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7

4
go.sum
View File

@ -96,8 +96,8 @@ github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gost/core v0.0.0-20220902092328-713671067369 h1:qPZgaT7p3WP06X0uVGv5bVxD2DUP7x+RiMaYxJyQuwI=
github.com/go-gost/core v0.0.0-20220902092328-713671067369/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ=
github.com/go-gost/core v0.0.0-20220904052234-99adf4bb0692 h1:exs+esWEKuK/ZtmaIiUGxHmC1FG2YZSUZOLls0t2O4I=
github.com/go-gost/core v0.0.0-20220904052234-99adf4bb0692/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ=
github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s=
github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc=
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04=

View File

@ -81,7 +81,7 @@ func (h *http2Handler) Handle(ctx context.Context, conn net.Conn, opts ...handle
log.Error(err)
return err
}
md := v.GetMetadata()
md := v.Metadata()
return h.roundTrip(ctx,
md.Get("w").(http.ResponseWriter),
md.Get("r").(*http.Request),

View File

@ -118,7 +118,7 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
log.Debugf("%s >> %s", conn.RemoteAddr(), target.Addr)
}
config := v.GetMetadata().Get("config").(*tap_util.Config)
config := v.Metadata().Get("config").(*tap_util.Config)
h.handleLoop(ctx, conn, raddr, config, log)
return nil
}

View File

@ -68,7 +68,7 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
log.Error(err)
return err
}
config := v.GetMetadata().Get("config").(*tun_util.Config)
config := v.Metadata().Get("config").(*tun_util.Config)
start := time.Now()
log = log.WithFields(map[string]any{

View File

@ -60,7 +60,7 @@ func (c *conn) Done() <-chan struct{} {
return c.closed
}
// GetMetadata implements metadata.Metadatable interface.
func (c *conn) GetMetadata() mdata.Metadata {
// Metadata implements metadata.Metadatable interface.
func (c *conn) Metadata() mdata.Metadata {
return c.md
}

View File

@ -52,8 +52,8 @@ type metadataConn struct {
md mdata.Metadata
}
// GetMetadata implements metadata.Metadatable interface.
func (c *metadataConn) GetMetadata() mdata.Metadata {
// Metadata implements metadata.Metadatable interface.
func (c *metadataConn) Metadata() mdata.Metadata {
return c.md
}

View File

@ -57,8 +57,8 @@ type metadataConn struct {
md mdata.Metadata
}
// GetMetadata implements metadata.Metadatable interface.
func (c *metadataConn) GetMetadata() mdata.Metadata {
// Metadata implements metadata.Metadatable interface.
func (c *metadataConn) Metadata() mdata.Metadata {
return c.md
}

View File

@ -12,20 +12,20 @@ type chainRegistry struct {
registry
}
func (r *chainRegistry) Register(name string, v chain.SelectableChainer) error {
func (r *chainRegistry) Register(name string, v chain.Chainer) error {
return r.registry.Register(name, v)
}
func (r *chainRegistry) Get(name string) chain.SelectableChainer {
func (r *chainRegistry) Get(name string) chain.Chainer {
if name != "" {
return &chainWrapper{name: name, r: r}
}
return nil
}
func (r *chainRegistry) get(name string) chain.SelectableChainer {
func (r *chainRegistry) get(name string) chain.Chainer {
if v := r.registry.Get(name); v != nil {
return v.(chain.SelectableChainer)
return v.(chain.Chainer)
}
return nil
}
@ -40,7 +40,10 @@ func (w *chainWrapper) Marker() selector.Marker {
if v == nil {
return nil
}
return v.Marker()
if mi, ok := v.(selector.Markable); ok {
return mi.Marker()
}
return nil
}
func (w *chainWrapper) Metadata() metadata.Metadata {
@ -48,7 +51,11 @@ func (w *chainWrapper) Metadata() metadata.Metadata {
if v == nil {
return nil
}
return v.Metadata()
if mi, ok := v.(metadata.Metadatable); ok {
return mi.Metadata()
}
return nil
}
func (w *chainWrapper) Route(ctx context.Context, network, address string) chain.Route {

View File

@ -25,14 +25,14 @@ var (
dialerReg Registry[NewDialer] = &dialerRegistry{}
connectorReg Registry[NewConnector] = &connectorRegistry{}
serviceReg Registry[service.Service] = &serviceRegistry{}
chainReg Registry[chain.SelectableChainer] = &chainRegistry{}
autherReg Registry[auth.Authenticator] = &autherRegistry{}
admissionReg Registry[admission.Admission] = &admissionRegistry{}
bypassReg Registry[bypass.Bypass] = &bypassRegistry{}
resolverReg Registry[resolver.Resolver] = &resolverRegistry{}
hostsReg Registry[hosts.HostMapper] = &hostsRegistry{}
recorderReg Registry[recorder.Recorder] = &recorderRegistry{}
serviceReg Registry[service.Service] = &serviceRegistry{}
chainReg Registry[chain.Chainer] = &chainRegistry{}
autherReg Registry[auth.Authenticator] = &autherRegistry{}
admissionReg Registry[admission.Admission] = &admissionRegistry{}
bypassReg Registry[bypass.Bypass] = &bypassRegistry{}
resolverReg Registry[resolver.Resolver] = &resolverRegistry{}
hostsReg Registry[hosts.HostMapper] = &hostsRegistry{}
recorderReg Registry[recorder.Recorder] = &recorderRegistry{}
)
type Registry[T any] interface {
@ -99,7 +99,7 @@ func ServiceRegistry() Registry[service.Service] {
return serviceReg
}
func ChainRegistry() Registry[chain.SelectableChainer] {
func ChainRegistry() Registry[chain.Chainer] {
return chainReg
}

View File

@ -4,18 +4,19 @@ import (
"context"
"time"
"github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util"
"github.com/go-gost/core/selector"
)
type failFilter[T selector.Selectable] struct {
type failFilter[T any] struct {
maxFails int
failTimeout time.Duration
}
// FailFilter filters the dead objects.
// An object is marked as dead if its failed count is greater than MaxFails.
func FailFilter[T selector.Selectable](maxFails int, timeout time.Duration) selector.Filter[T] {
func FailFilter[T any](maxFails int, timeout time.Duration) selector.Filter[T] {
return &failFilter[T]{
maxFails: maxFails,
failTimeout: timeout,
@ -31,12 +32,14 @@ func (f *failFilter[T]) Filter(ctx context.Context, vs ...T) []T {
for _, v := range vs {
maxFails := f.maxFails
failTimeout := f.failTimeout
if md := v.Metadata(); md != nil {
if md.IsExists(labelMaxFails) {
maxFails = mdutil.GetInt(md, labelMaxFails)
}
if md.IsExists(labelFailTimeout) {
failTimeout = mdutil.GetDuration(md, labelFailTimeout)
if mi, _ := any(v).(metadata.Metadatable); mi != nil {
if md := mi.Metadata(); md != nil {
if md.IsExists(labelMaxFails) {
maxFails = mdutil.GetInt(md, labelMaxFails)
}
if md.IsExists(labelFailTimeout) {
failTimeout = mdutil.GetDuration(md, labelFailTimeout)
}
}
}
if maxFails <= 0 {
@ -46,23 +49,25 @@ func (f *failFilter[T]) Filter(ctx context.Context, vs ...T) []T {
failTimeout = DefaultFailTimeout
}
if marker := v.Marker(); marker != nil {
if marker.Count() < int64(maxFails) ||
time.Since(marker.Time()) >= failTimeout {
l = append(l, v)
if mi, _ := any(v).(selector.Markable); mi != nil {
if marker := mi.Marker(); marker != nil {
if marker.Count() < int64(maxFails) ||
time.Since(marker.Time()) >= failTimeout {
l = append(l, v)
}
continue
}
} else {
l = append(l, v)
}
l = append(l, v)
}
return l
}
type backupFilter[T selector.Selectable] struct{}
type backupFilter[T any] struct{}
// BackupFilter filters the backup objects.
// An object is marked as backup if its metadata has backup flag.
func BackupFilter[T selector.Selectable]() selector.Filter[T] {
func BackupFilter[T any]() selector.Filter[T] {
return &backupFilter[T]{}
}
@ -74,11 +79,13 @@ func (f *backupFilter[T]) Filter(ctx context.Context, vs ...T) []T {
var l, backups []T
for _, v := range vs {
if mdutil.GetBool(v.Metadata(), labelBackup) {
backups = append(backups, v)
} else {
l = append(l, v)
if mi, _ := any(v).(metadata.Metadatable); mi != nil {
if mdutil.GetBool(mi.Metadata(), labelBackup) {
backups = append(backups, v)
continue
}
}
l = append(l, v)
}
if len(l) == 0 {

View File

@ -20,12 +20,12 @@ const (
labelFailTimeout = "failTimeout"
)
type defaultSelector[T selector.Selectable] struct {
type defaultSelector[T any] struct {
strategy selector.Strategy[T]
filters []selector.Filter[T]
}
func NewSelector[T selector.Selectable](strategy selector.Strategy[T], filters ...selector.Filter[T]) selector.Selector[T] {
func NewSelector[T any](strategy selector.Strategy[T], filters ...selector.Filter[T]) selector.Selector[T] {
return &defaultSelector[T]{
filters: filters,
strategy: strategy,

View File

@ -7,18 +7,19 @@ import (
"sync/atomic"
"time"
"github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util"
"github.com/go-gost/core/selector"
sx "github.com/go-gost/x/internal/util/selector"
)
type roundRobinStrategy[T selector.Selectable] struct {
type roundRobinStrategy[T any] struct {
counter uint64
}
// RoundRobinStrategy is a strategy for node selector.
// The node will be selected by round-robin algorithm.
func RoundRobinStrategy[T selector.Selectable]() selector.Strategy[T] {
func RoundRobinStrategy[T any]() selector.Strategy[T] {
return &roundRobinStrategy[T]{}
}
@ -31,14 +32,14 @@ func (s *roundRobinStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
return vs[int(n%uint64(len(vs)))]
}
type randomStrategy[T selector.Selectable] struct {
type randomStrategy[T any] struct {
rw *randomWeighted[T]
mu sync.Mutex
}
// RandomStrategy is a strategy for node selector.
// The node will be selected randomly.
func RandomStrategy[T selector.Selectable]() selector.Strategy[T] {
func RandomStrategy[T any]() selector.Strategy[T] {
return &randomStrategy[T]{
rw: newRandomWeighted[T](),
}
@ -54,7 +55,10 @@ func (s *randomStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
s.rw.Reset()
for i := range vs {
weight := mdutil.GetInt(vs[i].Metadata(), labelWeight)
weight := 0
if md, _ := any(vs[i]).(metadata.Metadatable); md != nil {
weight = mdutil.GetInt(md.Metadata(), labelWeight)
}
if weight <= 0 {
weight = 1
}
@ -64,12 +68,12 @@ func (s *randomStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
return s.rw.Next()
}
type fifoStrategy[T selector.Selectable] struct{}
type fifoStrategy[T any] struct{}
// FIFOStrategy is a strategy for node selector.
// The node will be selected from first to last,
// and will stick to the selected node until it is failed.
func FIFOStrategy[T selector.Selectable]() selector.Strategy[T] {
func FIFOStrategy[T any]() selector.Strategy[T] {
return &fifoStrategy[T]{}
}
@ -81,12 +85,12 @@ func (s *fifoStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
return vs[0]
}
type hashStrategy[T selector.Selectable] struct {
type hashStrategy[T any] struct {
r *rand.Rand
mu sync.Mutex
}
func HashStrategy[T selector.Selectable]() selector.Strategy[T] {
func HashStrategy[T any]() selector.Strategy[T] {
return &hashStrategy[T]{
r: rand.New(rand.NewSource(time.Now().UnixNano())),
}

View File

@ -3,22 +3,20 @@ package selector
import (
"math/rand"
"time"
"github.com/go-gost/core/selector"
)
type randomWeightedItem[T selector.Selectable] struct {
type randomWeightedItem[T any] struct {
item T
weight int
}
type randomWeighted[T selector.Selectable] struct {
type randomWeighted[T any] struct {
items []*randomWeightedItem[T]
sum int
r *rand.Rand
}
func newRandomWeighted[T selector.Selectable]() *randomWeighted[T] {
func newRandomWeighted[T any]() *randomWeighted[T] {
return &randomWeighted[T]{
r: rand.New(rand.NewSource(time.Now().UnixNano())),
}