add selector
This commit is contained in:
@ -10,6 +10,7 @@ import (
|
||||
tls_util "github.com/go-gost/x/internal/util/tls"
|
||||
"github.com/go-gost/x/metadata"
|
||||
"github.com/go-gost/x/registry"
|
||||
xs "github.com/go-gost/x/selector"
|
||||
)
|
||||
|
||||
func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) {
|
||||
@ -140,6 +141,9 @@ func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) {
|
||||
if s := parseNodeSelector(hop.Selector); s != nil {
|
||||
sel = s
|
||||
}
|
||||
if sel == nil {
|
||||
sel = xs.DefaultNodeSelector
|
||||
}
|
||||
group.WithSelector(sel).
|
||||
WithBypass(bypass.BypassGroup(bypassList(hop.Bypass, hop.Bypasses...)...))
|
||||
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"github.com/go-gost/core/logger"
|
||||
"github.com/go-gost/core/recorder"
|
||||
"github.com/go-gost/core/resolver"
|
||||
"github.com/go-gost/core/selector"
|
||||
admission_impl "github.com/go-gost/x/admission"
|
||||
auth_impl "github.com/go-gost/x/auth"
|
||||
bypass_impl "github.com/go-gost/x/bypass"
|
||||
@ -21,6 +22,7 @@ import (
|
||||
recorder_impl "github.com/go-gost/x/recorder"
|
||||
"github.com/go-gost/x/registry"
|
||||
resolver_impl "github.com/go-gost/x/resolver"
|
||||
xs "github.com/go-gost/x/selector"
|
||||
)
|
||||
|
||||
func ParseAuther(cfg *config.AutherConfig) auth.Authenticator {
|
||||
@ -83,50 +85,50 @@ func parseAuth(cfg *config.AuthConfig) *url.Userinfo {
|
||||
return url.UserPassword(cfg.Username, cfg.Password)
|
||||
}
|
||||
|
||||
func parseChainSelector(cfg *config.SelectorConfig) chain.Selector[chain.SelectableChainer] {
|
||||
func parseChainSelector(cfg *config.SelectorConfig) selector.Selector[chain.SelectableChainer] {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var strategy chain.Strategy[chain.SelectableChainer]
|
||||
var strategy selector.Strategy[chain.SelectableChainer]
|
||||
switch cfg.Strategy {
|
||||
case "round", "rr":
|
||||
strategy = chain.RoundRobinStrategy[chain.SelectableChainer]()
|
||||
strategy = xs.RoundRobinStrategy[chain.SelectableChainer]()
|
||||
case "random", "rand":
|
||||
strategy = chain.RandomStrategy[chain.SelectableChainer]()
|
||||
strategy = xs.RandomStrategy[chain.SelectableChainer]()
|
||||
case "fifo", "ha":
|
||||
strategy = chain.FIFOStrategy[chain.SelectableChainer]()
|
||||
strategy = xs.FIFOStrategy[chain.SelectableChainer]()
|
||||
default:
|
||||
strategy = chain.RoundRobinStrategy[chain.SelectableChainer]()
|
||||
strategy = xs.RoundRobinStrategy[chain.SelectableChainer]()
|
||||
}
|
||||
return chain.NewSelector(
|
||||
return xs.NewSelector(
|
||||
strategy,
|
||||
chain.FailFilter[chain.SelectableChainer](cfg.MaxFails, cfg.FailTimeout),
|
||||
chain.BackupFilter[chain.SelectableChainer](),
|
||||
xs.FailFilter[chain.SelectableChainer](cfg.MaxFails, cfg.FailTimeout),
|
||||
xs.BackupFilter[chain.SelectableChainer](),
|
||||
)
|
||||
}
|
||||
|
||||
func parseNodeSelector(cfg *config.SelectorConfig) chain.Selector[*chain.Node] {
|
||||
func parseNodeSelector(cfg *config.SelectorConfig) selector.Selector[*chain.Node] {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var strategy chain.Strategy[*chain.Node]
|
||||
var strategy selector.Strategy[*chain.Node]
|
||||
switch cfg.Strategy {
|
||||
case "round", "rr":
|
||||
strategy = chain.RoundRobinStrategy[*chain.Node]()
|
||||
strategy = xs.RoundRobinStrategy[*chain.Node]()
|
||||
case "random", "rand":
|
||||
strategy = chain.RandomStrategy[*chain.Node]()
|
||||
strategy = xs.RandomStrategy[*chain.Node]()
|
||||
case "fifo", "ha":
|
||||
strategy = chain.FIFOStrategy[*chain.Node]()
|
||||
strategy = xs.FIFOStrategy[*chain.Node]()
|
||||
default:
|
||||
strategy = chain.RoundRobinStrategy[*chain.Node]()
|
||||
strategy = xs.RoundRobinStrategy[*chain.Node]()
|
||||
}
|
||||
|
||||
return chain.NewSelector(
|
||||
return xs.NewSelector(
|
||||
strategy,
|
||||
chain.FailFilter[*chain.Node](cfg.MaxFails, cfg.FailTimeout),
|
||||
chain.BackupFilter[*chain.Node](),
|
||||
xs.FailFilter[*chain.Node](cfg.MaxFails, cfg.FailTimeout),
|
||||
xs.BackupFilter[*chain.Node](),
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -11,11 +11,13 @@ import (
|
||||
"github.com/go-gost/core/listener"
|
||||
"github.com/go-gost/core/logger"
|
||||
"github.com/go-gost/core/recorder"
|
||||
"github.com/go-gost/core/selector"
|
||||
"github.com/go-gost/core/service"
|
||||
"github.com/go-gost/x/config"
|
||||
tls_util "github.com/go-gost/x/internal/util/tls"
|
||||
"github.com/go-gost/x/metadata"
|
||||
"github.com/go-gost/x/registry"
|
||||
xs "github.com/go-gost/x/selector"
|
||||
)
|
||||
|
||||
func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
|
||||
@ -184,7 +186,11 @@ func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup {
|
||||
}
|
||||
}
|
||||
|
||||
return group.WithSelector(parseNodeSelector(cfg.Selector))
|
||||
sel := parseNodeSelector(cfg.Selector)
|
||||
if sel == nil {
|
||||
sel = xs.DefaultNodeSelector
|
||||
}
|
||||
return group.WithSelector(sel)
|
||||
}
|
||||
|
||||
func bypassList(name string, names ...string) []bypass.Bypass {
|
||||
@ -228,18 +234,25 @@ func admissionList(name string, names ...string) []admission.Admission {
|
||||
}
|
||||
|
||||
func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer {
|
||||
cg := &chain.ChainGroup{}
|
||||
var chains []chain.SelectableChainer
|
||||
var sel selector.Selector[chain.SelectableChainer]
|
||||
|
||||
if c := registry.ChainRegistry().Get(name); c != nil {
|
||||
cg.Chains = append(cg.Chains, c)
|
||||
chains = append(chains, c)
|
||||
}
|
||||
if group != nil {
|
||||
for _, s := range group.Chains {
|
||||
if c := registry.ChainRegistry().Get(s); c != nil {
|
||||
cg.Chains = append(cg.Chains, c)
|
||||
chains = append(chains, c)
|
||||
}
|
||||
}
|
||||
cg.Selector = parseChainSelector(group.Selector)
|
||||
sel = parseChainSelector(group.Selector)
|
||||
}
|
||||
|
||||
return cg
|
||||
if sel == nil {
|
||||
sel = xs.DefaultChainSelector
|
||||
}
|
||||
|
||||
return chain.NewChainGroup(chains...).
|
||||
WithSelector(sel)
|
||||
}
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/core/chain"
|
||||
"github.com/go-gost/core/connector"
|
||||
"github.com/go-gost/core/listener"
|
||||
"github.com/go-gost/core/logger"
|
||||
md "github.com/go-gost/core/metadata"
|
||||
@ -71,7 +70,7 @@ func (l *rtcpListener) Accept() (conn net.Conn, err error) {
|
||||
if l.ln == nil {
|
||||
l.ln, err = l.router.Bind(
|
||||
context.Background(), "tcp", l.laddr.String(),
|
||||
connector.MuxBindOption(true),
|
||||
chain.MuxBindOption(true),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, listener.NewAcceptError(err)
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/core/chain"
|
||||
"github.com/go-gost/core/connector"
|
||||
"github.com/go-gost/core/listener"
|
||||
"github.com/go-gost/core/logger"
|
||||
md "github.com/go-gost/core/metadata"
|
||||
@ -72,10 +71,10 @@ func (l *rudpListener) Accept() (conn net.Conn, err error) {
|
||||
if l.ln == nil {
|
||||
l.ln, err = l.router.Bind(
|
||||
context.Background(), "udp", l.laddr.String(),
|
||||
connector.BacklogBindOption(l.md.backlog),
|
||||
connector.UDPConnTTLBindOption(l.md.ttl),
|
||||
connector.UDPDataBufferSizeBindOption(l.md.readBufferSize),
|
||||
connector.UDPDataQueueSizeBindOption(l.md.readQueueSize),
|
||||
chain.BacklogBindOption(l.md.backlog),
|
||||
chain.UDPConnTTLBindOption(l.md.ttl),
|
||||
chain.UDPDataBufferSizeBindOption(l.md.readBufferSize),
|
||||
chain.UDPDataQueueSizeBindOption(l.md.readQueueSize),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, listener.NewAcceptError(err)
|
||||
|
@ -3,6 +3,7 @@ package registry
|
||||
import (
|
||||
"github.com/go-gost/core/chain"
|
||||
"github.com/go-gost/core/metadata"
|
||||
"github.com/go-gost/core/selector"
|
||||
)
|
||||
|
||||
type chainRegistry struct {
|
||||
@ -32,7 +33,7 @@ type chainWrapper struct {
|
||||
r *chainRegistry
|
||||
}
|
||||
|
||||
func (w *chainWrapper) Marker() chain.Marker {
|
||||
func (w *chainWrapper) Marker() selector.Marker {
|
||||
v := w.r.get(w.name)
|
||||
if v == nil {
|
||||
return nil
|
||||
@ -48,7 +49,7 @@ func (w *chainWrapper) Metadata() metadata.Metadata {
|
||||
return v.Metadata()
|
||||
}
|
||||
|
||||
func (w *chainWrapper) Route(network, address string) *chain.Route {
|
||||
func (w *chainWrapper) Route(network, address string) chain.Route {
|
||||
v := w.r.get(w.name)
|
||||
if v == nil {
|
||||
return nil
|
||||
|
87
selector/filter.go
Normal file
87
selector/filter.go
Normal file
@ -0,0 +1,87 @@
|
||||
package selector
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
mdutil "github.com/go-gost/core/metadata/util"
|
||||
"github.com/go-gost/core/selector"
|
||||
)
|
||||
|
||||
type failFilter[T selector.Selectable] struct {
|
||||
maxFails int
|
||||
failTimeout time.Duration
|
||||
}
|
||||
|
||||
// FailFilter filters the dead objects.
|
||||
// An object is marked as dead if its failed count is greater than MaxFails.
|
||||
func FailFilter[T selector.Selectable](maxFails int, timeout time.Duration) selector.Filter[T] {
|
||||
return &failFilter[T]{
|
||||
maxFails: maxFails,
|
||||
failTimeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Filter filters dead objects.
|
||||
func (f *failFilter[T]) Filter(vs ...T) []T {
|
||||
if len(vs) <= 1 {
|
||||
return vs
|
||||
}
|
||||
var l []T
|
||||
for _, v := range vs {
|
||||
maxFails := f.maxFails
|
||||
failTimeout := f.failTimeout
|
||||
if md := v.Metadata(); md != nil {
|
||||
if md.IsExists(labelMaxFails) {
|
||||
maxFails = mdutil.GetInt(md, labelMaxFails)
|
||||
}
|
||||
if md.IsExists(labelFailTimeout) {
|
||||
failTimeout = mdutil.GetDuration(md, labelFailTimeout)
|
||||
}
|
||||
}
|
||||
if maxFails <= 0 {
|
||||
maxFails = 1
|
||||
}
|
||||
if failTimeout <= 0 {
|
||||
failTimeout = DefaultFailTimeout
|
||||
}
|
||||
|
||||
if marker := v.Marker(); marker != nil {
|
||||
if marker.Count() < int64(maxFails) ||
|
||||
time.Since(marker.Time()) >= failTimeout {
|
||||
l = append(l, v)
|
||||
}
|
||||
} else {
|
||||
l = append(l, v)
|
||||
}
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
type backupFilter[T selector.Selectable] struct{}
|
||||
|
||||
// BackupFilter filters the backup objects.
|
||||
// An object is marked as backup if its metadata has backup flag.
|
||||
func BackupFilter[T selector.Selectable]() selector.Filter[T] {
|
||||
return &backupFilter[T]{}
|
||||
}
|
||||
|
||||
// Filter filters backup objects.
|
||||
func (f *backupFilter[T]) Filter(vs ...T) []T {
|
||||
if len(vs) <= 1 {
|
||||
return vs
|
||||
}
|
||||
|
||||
var l, backups []T
|
||||
for _, v := range vs {
|
||||
if mdutil.GetBool(v.Metadata(), labelBackup) {
|
||||
backups = append(backups, v)
|
||||
} else {
|
||||
l = append(l, v)
|
||||
}
|
||||
}
|
||||
|
||||
if len(l) == 0 {
|
||||
return backups
|
||||
}
|
||||
return l
|
||||
}
|
54
selector/selector.go
Normal file
54
selector/selector.go
Normal file
@ -0,0 +1,54 @@
|
||||
package selector
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/chain"
|
||||
"github.com/go-gost/core/selector"
|
||||
)
|
||||
|
||||
// default options for FailFilter
|
||||
const (
|
||||
DefaultMaxFails = 1
|
||||
DefaultFailTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
labelWeight = "weight"
|
||||
labelBackup = "backup"
|
||||
labelMaxFails = "maxFails"
|
||||
labelFailTimeout = "failTimeout"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultNodeSelector = NewSelector(
|
||||
RoundRobinStrategy[*chain.Node](),
|
||||
// FailFilter[*Node](1, DefaultFailTimeout),
|
||||
)
|
||||
DefaultChainSelector = NewSelector(
|
||||
RoundRobinStrategy[chain.SelectableChainer](),
|
||||
// FailFilter[SelectableChainer](1, DefaultFailTimeout),
|
||||
)
|
||||
)
|
||||
|
||||
type defaultSelector[T selector.Selectable] struct {
|
||||
strategy selector.Strategy[T]
|
||||
filters []selector.Filter[T]
|
||||
}
|
||||
|
||||
func NewSelector[T selector.Selectable](strategy selector.Strategy[T], filters ...selector.Filter[T]) selector.Selector[T] {
|
||||
return &defaultSelector[T]{
|
||||
filters: filters,
|
||||
strategy: strategy,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *defaultSelector[T]) Select(vs ...T) (v T) {
|
||||
for _, filter := range s.filters {
|
||||
vs = filter.Filter(vs...)
|
||||
}
|
||||
if len(vs) == 0 {
|
||||
return
|
||||
}
|
||||
return s.strategy.Apply(vs...)
|
||||
}
|
72
selector/strategy.go
Normal file
72
selector/strategy.go
Normal file
@ -0,0 +1,72 @@
|
||||
package selector
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/selector"
|
||||
)
|
||||
|
||||
type roundRobinStrategy[T selector.Selectable] struct {
|
||||
counter uint64
|
||||
}
|
||||
|
||||
// RoundRobinStrategy is a strategy for node selector.
|
||||
// The node will be selected by round-robin algorithm.
|
||||
func RoundRobinStrategy[T selector.Selectable]() selector.Strategy[T] {
|
||||
return &roundRobinStrategy[T]{}
|
||||
}
|
||||
|
||||
func (s *roundRobinStrategy[T]) Apply(vs ...T) (v T) {
|
||||
if len(vs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
n := atomic.AddUint64(&s.counter, 1) - 1
|
||||
return vs[int(n%uint64(len(vs)))]
|
||||
}
|
||||
|
||||
type randomStrategy[T selector.Selectable] struct {
|
||||
rand *rand.Rand
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
// RandomStrategy is a strategy for node selector.
|
||||
// The node will be selected randomly.
|
||||
func RandomStrategy[T selector.Selectable]() selector.Strategy[T] {
|
||||
return &randomStrategy[T]{
|
||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *randomStrategy[T]) Apply(vs ...T) (v T) {
|
||||
if len(vs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
r := s.rand.Int()
|
||||
|
||||
return vs[r%len(vs)]
|
||||
}
|
||||
|
||||
type fifoStrategy[T selector.Selectable] struct{}
|
||||
|
||||
// FIFOStrategy is a strategy for node selector.
|
||||
// The node will be selected from first to last,
|
||||
// and will stick to the selected node until it is failed.
|
||||
func FIFOStrategy[T selector.Selectable]() selector.Strategy[T] {
|
||||
return &fifoStrategy[T]{}
|
||||
}
|
||||
|
||||
// Apply applies the fifo strategy for the nodes.
|
||||
func (s *fifoStrategy[T]) Apply(vs ...T) (v T) {
|
||||
if len(vs) == 0 {
|
||||
return
|
||||
}
|
||||
return vs[0]
|
||||
}
|
Reference in New Issue
Block a user