add selector

This commit is contained in:
ginuerzh
2022-09-02 15:00:07 +08:00
parent 09dbdbb03c
commit c643014e12
9 changed files with 264 additions and 33 deletions

View File

@ -10,6 +10,7 @@ import (
tls_util "github.com/go-gost/x/internal/util/tls" tls_util "github.com/go-gost/x/internal/util/tls"
"github.com/go-gost/x/metadata" "github.com/go-gost/x/metadata"
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
xs "github.com/go-gost/x/selector"
) )
func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) { func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) {
@ -140,6 +141,9 @@ func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) {
if s := parseNodeSelector(hop.Selector); s != nil { if s := parseNodeSelector(hop.Selector); s != nil {
sel = s sel = s
} }
if sel == nil {
sel = xs.DefaultNodeSelector
}
group.WithSelector(sel). group.WithSelector(sel).
WithBypass(bypass.BypassGroup(bypassList(hop.Bypass, hop.Bypasses...)...)) WithBypass(bypass.BypassGroup(bypassList(hop.Bypass, hop.Bypasses...)...))

View File

@ -12,6 +12,7 @@ import (
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
"github.com/go-gost/core/recorder" "github.com/go-gost/core/recorder"
"github.com/go-gost/core/resolver" "github.com/go-gost/core/resolver"
"github.com/go-gost/core/selector"
admission_impl "github.com/go-gost/x/admission" admission_impl "github.com/go-gost/x/admission"
auth_impl "github.com/go-gost/x/auth" auth_impl "github.com/go-gost/x/auth"
bypass_impl "github.com/go-gost/x/bypass" bypass_impl "github.com/go-gost/x/bypass"
@ -21,6 +22,7 @@ import (
recorder_impl "github.com/go-gost/x/recorder" recorder_impl "github.com/go-gost/x/recorder"
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
resolver_impl "github.com/go-gost/x/resolver" resolver_impl "github.com/go-gost/x/resolver"
xs "github.com/go-gost/x/selector"
) )
func ParseAuther(cfg *config.AutherConfig) auth.Authenticator { func ParseAuther(cfg *config.AutherConfig) auth.Authenticator {
@ -83,50 +85,50 @@ func parseAuth(cfg *config.AuthConfig) *url.Userinfo {
return url.UserPassword(cfg.Username, cfg.Password) return url.UserPassword(cfg.Username, cfg.Password)
} }
func parseChainSelector(cfg *config.SelectorConfig) chain.Selector[chain.SelectableChainer] { func parseChainSelector(cfg *config.SelectorConfig) selector.Selector[chain.SelectableChainer] {
if cfg == nil { if cfg == nil {
return nil return nil
} }
var strategy chain.Strategy[chain.SelectableChainer] var strategy selector.Strategy[chain.SelectableChainer]
switch cfg.Strategy { switch cfg.Strategy {
case "round", "rr": case "round", "rr":
strategy = chain.RoundRobinStrategy[chain.SelectableChainer]() strategy = xs.RoundRobinStrategy[chain.SelectableChainer]()
case "random", "rand": case "random", "rand":
strategy = chain.RandomStrategy[chain.SelectableChainer]() strategy = xs.RandomStrategy[chain.SelectableChainer]()
case "fifo", "ha": case "fifo", "ha":
strategy = chain.FIFOStrategy[chain.SelectableChainer]() strategy = xs.FIFOStrategy[chain.SelectableChainer]()
default: default:
strategy = chain.RoundRobinStrategy[chain.SelectableChainer]() strategy = xs.RoundRobinStrategy[chain.SelectableChainer]()
} }
return chain.NewSelector( return xs.NewSelector(
strategy, strategy,
chain.FailFilter[chain.SelectableChainer](cfg.MaxFails, cfg.FailTimeout), xs.FailFilter[chain.SelectableChainer](cfg.MaxFails, cfg.FailTimeout),
chain.BackupFilter[chain.SelectableChainer](), xs.BackupFilter[chain.SelectableChainer](),
) )
} }
func parseNodeSelector(cfg *config.SelectorConfig) chain.Selector[*chain.Node] { func parseNodeSelector(cfg *config.SelectorConfig) selector.Selector[*chain.Node] {
if cfg == nil { if cfg == nil {
return nil return nil
} }
var strategy chain.Strategy[*chain.Node] var strategy selector.Strategy[*chain.Node]
switch cfg.Strategy { switch cfg.Strategy {
case "round", "rr": case "round", "rr":
strategy = chain.RoundRobinStrategy[*chain.Node]() strategy = xs.RoundRobinStrategy[*chain.Node]()
case "random", "rand": case "random", "rand":
strategy = chain.RandomStrategy[*chain.Node]() strategy = xs.RandomStrategy[*chain.Node]()
case "fifo", "ha": case "fifo", "ha":
strategy = chain.FIFOStrategy[*chain.Node]() strategy = xs.FIFOStrategy[*chain.Node]()
default: default:
strategy = chain.RoundRobinStrategy[*chain.Node]() strategy = xs.RoundRobinStrategy[*chain.Node]()
} }
return chain.NewSelector( return xs.NewSelector(
strategy, strategy,
chain.FailFilter[*chain.Node](cfg.MaxFails, cfg.FailTimeout), xs.FailFilter[*chain.Node](cfg.MaxFails, cfg.FailTimeout),
chain.BackupFilter[*chain.Node](), xs.BackupFilter[*chain.Node](),
) )
} }

View File

@ -11,11 +11,13 @@ import (
"github.com/go-gost/core/listener" "github.com/go-gost/core/listener"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
"github.com/go-gost/core/recorder" "github.com/go-gost/core/recorder"
"github.com/go-gost/core/selector"
"github.com/go-gost/core/service" "github.com/go-gost/core/service"
"github.com/go-gost/x/config" "github.com/go-gost/x/config"
tls_util "github.com/go-gost/x/internal/util/tls" tls_util "github.com/go-gost/x/internal/util/tls"
"github.com/go-gost/x/metadata" "github.com/go-gost/x/metadata"
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
xs "github.com/go-gost/x/selector"
) )
func ParseService(cfg *config.ServiceConfig) (service.Service, error) { func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
@ -184,7 +186,11 @@ func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup {
} }
} }
return group.WithSelector(parseNodeSelector(cfg.Selector)) sel := parseNodeSelector(cfg.Selector)
if sel == nil {
sel = xs.DefaultNodeSelector
}
return group.WithSelector(sel)
} }
func bypassList(name string, names ...string) []bypass.Bypass { func bypassList(name string, names ...string) []bypass.Bypass {
@ -228,18 +234,25 @@ func admissionList(name string, names ...string) []admission.Admission {
} }
func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer { func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer {
cg := &chain.ChainGroup{} var chains []chain.SelectableChainer
var sel selector.Selector[chain.SelectableChainer]
if c := registry.ChainRegistry().Get(name); c != nil { if c := registry.ChainRegistry().Get(name); c != nil {
cg.Chains = append(cg.Chains, c) chains = append(chains, c)
} }
if group != nil { if group != nil {
for _, s := range group.Chains { for _, s := range group.Chains {
if c := registry.ChainRegistry().Get(s); c != nil { if c := registry.ChainRegistry().Get(s); c != nil {
cg.Chains = append(cg.Chains, c) chains = append(chains, c)
} }
} }
cg.Selector = parseChainSelector(group.Selector) sel = parseChainSelector(group.Selector)
} }
return cg if sel == nil {
sel = xs.DefaultChainSelector
}
return chain.NewChainGroup(chains...).
WithSelector(sel)
} }

View File

@ -5,7 +5,6 @@ import (
"net" "net"
"github.com/go-gost/core/chain" "github.com/go-gost/core/chain"
"github.com/go-gost/core/connector"
"github.com/go-gost/core/listener" "github.com/go-gost/core/listener"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
md "github.com/go-gost/core/metadata" md "github.com/go-gost/core/metadata"
@ -71,7 +70,7 @@ func (l *rtcpListener) Accept() (conn net.Conn, err error) {
if l.ln == nil { if l.ln == nil {
l.ln, err = l.router.Bind( l.ln, err = l.router.Bind(
context.Background(), "tcp", l.laddr.String(), context.Background(), "tcp", l.laddr.String(),
connector.MuxBindOption(true), chain.MuxBindOption(true),
) )
if err != nil { if err != nil {
return nil, listener.NewAcceptError(err) return nil, listener.NewAcceptError(err)

View File

@ -5,7 +5,6 @@ import (
"net" "net"
"github.com/go-gost/core/chain" "github.com/go-gost/core/chain"
"github.com/go-gost/core/connector"
"github.com/go-gost/core/listener" "github.com/go-gost/core/listener"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
md "github.com/go-gost/core/metadata" md "github.com/go-gost/core/metadata"
@ -72,10 +71,10 @@ func (l *rudpListener) Accept() (conn net.Conn, err error) {
if l.ln == nil { if l.ln == nil {
l.ln, err = l.router.Bind( l.ln, err = l.router.Bind(
context.Background(), "udp", l.laddr.String(), context.Background(), "udp", l.laddr.String(),
connector.BacklogBindOption(l.md.backlog), chain.BacklogBindOption(l.md.backlog),
connector.UDPConnTTLBindOption(l.md.ttl), chain.UDPConnTTLBindOption(l.md.ttl),
connector.UDPDataBufferSizeBindOption(l.md.readBufferSize), chain.UDPDataBufferSizeBindOption(l.md.readBufferSize),
connector.UDPDataQueueSizeBindOption(l.md.readQueueSize), chain.UDPDataQueueSizeBindOption(l.md.readQueueSize),
) )
if err != nil { if err != nil {
return nil, listener.NewAcceptError(err) return nil, listener.NewAcceptError(err)

View File

@ -3,6 +3,7 @@ package registry
import ( import (
"github.com/go-gost/core/chain" "github.com/go-gost/core/chain"
"github.com/go-gost/core/metadata" "github.com/go-gost/core/metadata"
"github.com/go-gost/core/selector"
) )
type chainRegistry struct { type chainRegistry struct {
@ -32,7 +33,7 @@ type chainWrapper struct {
r *chainRegistry r *chainRegistry
} }
func (w *chainWrapper) Marker() chain.Marker { func (w *chainWrapper) Marker() selector.Marker {
v := w.r.get(w.name) v := w.r.get(w.name)
if v == nil { if v == nil {
return nil return nil
@ -48,7 +49,7 @@ func (w *chainWrapper) Metadata() metadata.Metadata {
return v.Metadata() return v.Metadata()
} }
func (w *chainWrapper) Route(network, address string) *chain.Route { func (w *chainWrapper) Route(network, address string) chain.Route {
v := w.r.get(w.name) v := w.r.get(w.name)
if v == nil { if v == nil {
return nil return nil

87
selector/filter.go Normal file
View File

@ -0,0 +1,87 @@
package selector
import (
"time"
mdutil "github.com/go-gost/core/metadata/util"
"github.com/go-gost/core/selector"
)
type failFilter[T selector.Selectable] struct {
maxFails int
failTimeout time.Duration
}
// FailFilter filters the dead objects.
// An object is marked as dead if its failed count is greater than MaxFails.
func FailFilter[T selector.Selectable](maxFails int, timeout time.Duration) selector.Filter[T] {
return &failFilter[T]{
maxFails: maxFails,
failTimeout: timeout,
}
}
// Filter filters dead objects.
func (f *failFilter[T]) Filter(vs ...T) []T {
if len(vs) <= 1 {
return vs
}
var l []T
for _, v := range vs {
maxFails := f.maxFails
failTimeout := f.failTimeout
if md := v.Metadata(); md != nil {
if md.IsExists(labelMaxFails) {
maxFails = mdutil.GetInt(md, labelMaxFails)
}
if md.IsExists(labelFailTimeout) {
failTimeout = mdutil.GetDuration(md, labelFailTimeout)
}
}
if maxFails <= 0 {
maxFails = 1
}
if failTimeout <= 0 {
failTimeout = DefaultFailTimeout
}
if marker := v.Marker(); marker != nil {
if marker.Count() < int64(maxFails) ||
time.Since(marker.Time()) >= failTimeout {
l = append(l, v)
}
} else {
l = append(l, v)
}
}
return l
}
type backupFilter[T selector.Selectable] struct{}
// BackupFilter filters the backup objects.
// An object is marked as backup if its metadata has backup flag.
func BackupFilter[T selector.Selectable]() selector.Filter[T] {
return &backupFilter[T]{}
}
// Filter filters backup objects.
func (f *backupFilter[T]) Filter(vs ...T) []T {
if len(vs) <= 1 {
return vs
}
var l, backups []T
for _, v := range vs {
if mdutil.GetBool(v.Metadata(), labelBackup) {
backups = append(backups, v)
} else {
l = append(l, v)
}
}
if len(l) == 0 {
return backups
}
return l
}

54
selector/selector.go Normal file
View File

@ -0,0 +1,54 @@
package selector
import (
"time"
"github.com/go-gost/core/chain"
"github.com/go-gost/core/selector"
)
// default options for FailFilter
const (
DefaultMaxFails = 1
DefaultFailTimeout = 10 * time.Second
)
const (
labelWeight = "weight"
labelBackup = "backup"
labelMaxFails = "maxFails"
labelFailTimeout = "failTimeout"
)
var (
DefaultNodeSelector = NewSelector(
RoundRobinStrategy[*chain.Node](),
// FailFilter[*Node](1, DefaultFailTimeout),
)
DefaultChainSelector = NewSelector(
RoundRobinStrategy[chain.SelectableChainer](),
// FailFilter[SelectableChainer](1, DefaultFailTimeout),
)
)
type defaultSelector[T selector.Selectable] struct {
strategy selector.Strategy[T]
filters []selector.Filter[T]
}
func NewSelector[T selector.Selectable](strategy selector.Strategy[T], filters ...selector.Filter[T]) selector.Selector[T] {
return &defaultSelector[T]{
filters: filters,
strategy: strategy,
}
}
func (s *defaultSelector[T]) Select(vs ...T) (v T) {
for _, filter := range s.filters {
vs = filter.Filter(vs...)
}
if len(vs) == 0 {
return
}
return s.strategy.Apply(vs...)
}

72
selector/strategy.go Normal file
View File

@ -0,0 +1,72 @@
package selector
import (
"math/rand"
"sync"
"sync/atomic"
"time"
"github.com/go-gost/core/selector"
)
type roundRobinStrategy[T selector.Selectable] struct {
counter uint64
}
// RoundRobinStrategy is a strategy for node selector.
// The node will be selected by round-robin algorithm.
func RoundRobinStrategy[T selector.Selectable]() selector.Strategy[T] {
return &roundRobinStrategy[T]{}
}
func (s *roundRobinStrategy[T]) Apply(vs ...T) (v T) {
if len(vs) == 0 {
return
}
n := atomic.AddUint64(&s.counter, 1) - 1
return vs[int(n%uint64(len(vs)))]
}
type randomStrategy[T selector.Selectable] struct {
rand *rand.Rand
mux sync.Mutex
}
// RandomStrategy is a strategy for node selector.
// The node will be selected randomly.
func RandomStrategy[T selector.Selectable]() selector.Strategy[T] {
return &randomStrategy[T]{
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
func (s *randomStrategy[T]) Apply(vs ...T) (v T) {
if len(vs) == 0 {
return
}
s.mux.Lock()
defer s.mux.Unlock()
r := s.rand.Int()
return vs[r%len(vs)]
}
type fifoStrategy[T selector.Selectable] struct{}
// FIFOStrategy is a strategy for node selector.
// The node will be selected from first to last,
// and will stick to the selected node until it is failed.
func FIFOStrategy[T selector.Selectable]() selector.Strategy[T] {
return &fifoStrategy[T]{}
}
// Apply applies the fifo strategy for the nodes.
func (s *fifoStrategy[T]) Apply(vs ...T) (v T) {
if len(vs) == 0 {
return
}
return vs[0]
}