add weight for selector
This commit is contained in:
@ -10,7 +10,6 @@ import (
|
||||
tls_util "github.com/go-gost/x/internal/util/tls"
|
||||
"github.com/go-gost/x/metadata"
|
||||
"github.com/go-gost/x/registry"
|
||||
xs "github.com/go-gost/x/selector"
|
||||
)
|
||||
|
||||
func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) {
|
||||
@ -142,7 +141,7 @@ func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) {
|
||||
sel = s
|
||||
}
|
||||
if sel == nil {
|
||||
sel = xs.DefaultNodeSelector
|
||||
sel = defaultNodeSelector()
|
||||
}
|
||||
group.WithSelector(sel).
|
||||
WithBypass(bypass.BypassGroup(bypassList(hop.Bypass, hop.Bypasses...)...))
|
||||
|
@ -298,3 +298,11 @@ func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func defaultNodeSelector() selector.Selector[*chain.Node] {
|
||||
return xs.NewSelector(xs.RoundRobinStrategy[*chain.Node]())
|
||||
}
|
||||
|
||||
func defaultChainSelector() selector.Selector[chain.SelectableChainer] {
|
||||
return xs.NewSelector(xs.RoundRobinStrategy[chain.SelectableChainer]())
|
||||
}
|
||||
|
@ -17,7 +17,6 @@ import (
|
||||
tls_util "github.com/go-gost/x/internal/util/tls"
|
||||
"github.com/go-gost/x/metadata"
|
||||
"github.com/go-gost/x/registry"
|
||||
xs "github.com/go-gost/x/selector"
|
||||
)
|
||||
|
||||
func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
|
||||
@ -188,7 +187,7 @@ func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup {
|
||||
|
||||
sel := parseNodeSelector(cfg.Selector)
|
||||
if sel == nil {
|
||||
sel = xs.DefaultNodeSelector
|
||||
sel = defaultNodeSelector()
|
||||
}
|
||||
return group.WithSelector(sel)
|
||||
}
|
||||
@ -250,7 +249,7 @@ func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer {
|
||||
}
|
||||
|
||||
if sel == nil {
|
||||
sel = xs.DefaultChainSelector
|
||||
sel = defaultChainSelector()
|
||||
}
|
||||
|
||||
return chain.NewChainGroup(chains...).
|
||||
|
3
go.mod
3
go.mod
@ -76,6 +76,7 @@ require (
|
||||
github.com/prometheus/common v0.32.1 // indirect
|
||||
github.com/prometheus/procfs v0.7.3 // indirect
|
||||
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
|
||||
github.com/smallnest/weighted v0.0.0-20201102054551-85ac5c79528c // indirect
|
||||
github.com/spf13/afero v1.8.2 // indirect
|
||||
github.com/spf13/cast v1.4.1 // indirect
|
||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||
@ -86,7 +87,7 @@ require (
|
||||
github.com/tjfoc/gmsm v1.3.2 // indirect
|
||||
github.com/ugorji/go/codec v1.2.7 // indirect
|
||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
|
||||
golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect
|
||||
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
|
||||
golang.org/x/text v0.3.7 // indirect
|
||||
golang.org/x/tools v0.1.12 // indirect
|
||||
|
4
go.sum
4
go.sum
@ -335,6 +335,8 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd
|
||||
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
|
||||
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
|
||||
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
|
||||
github.com/smallnest/weighted v0.0.0-20201102054551-85ac5c79528c h1:XBpqxCr2X2HYZMOA+HTDhj8njR4PGhsK+M+geaMAQ20=
|
||||
github.com/smallnest/weighted v0.0.0-20201102054551-85ac5c79528c/go.mod h1:xc9CoZ+ZBGwajnWto5Aqw/wWg8euy4HtOr6K9Fxp9iw=
|
||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
|
||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
|
||||
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
||||
@ -424,6 +426,8 @@ golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EH
|
||||
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
|
||||
golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e h1:+WEEuIdZHnUeJJmEUjyYC2gfUMj69yZXw17EnHg/otA=
|
||||
golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA=
|
||||
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 h1:tnebWN09GYg9OLPss1KXj8txwZc6X6uMr6VFdcGNbHw=
|
||||
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
|
@ -221,7 +221,7 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ex := h.selectExchanger(strings.Trim(mq.Question[0].Name, "."))
|
||||
ex := h.selectExchanger(ctx, strings.Trim(mq.Question[0].Name, "."))
|
||||
if ex == nil {
|
||||
err := fmt.Errorf("exchange not found for %s", mq.Question[0].Name)
|
||||
log.Error(err)
|
||||
@ -293,11 +293,11 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
func (h *dnsHandler) selectExchanger(addr string) exchanger.Exchanger {
|
||||
func (h *dnsHandler) selectExchanger(ctx context.Context, addr string) exchanger.Exchanger {
|
||||
if h.group == nil {
|
||||
return nil
|
||||
}
|
||||
node := h.group.FilterAddr(addr).Next()
|
||||
node := h.group.FilterAddr(addr).Next(ctx)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -77,7 +77,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
target := h.group.Next()
|
||||
target := h.group.Next(ctx)
|
||||
if target == nil {
|
||||
err := errors.New("target not available")
|
||||
log.Error(err)
|
||||
|
@ -71,7 +71,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
|
||||
target := h.group.Next()
|
||||
target := h.group.Next(ctx)
|
||||
if target == nil {
|
||||
err := errors.New("target not available")
|
||||
log.Error(err)
|
||||
|
@ -17,7 +17,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
|
||||
Version: relay.Version1,
|
||||
Status: relay.StatusOK,
|
||||
}
|
||||
target := h.group.Next()
|
||||
target := h.group.Next(ctx)
|
||||
if target == nil {
|
||||
resp.Status = relay.StatusServiceUnavailable
|
||||
resp.WriteTo(conn)
|
||||
|
@ -105,7 +105,7 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
|
||||
var raddr net.Addr
|
||||
var err error
|
||||
|
||||
target := h.group.Next()
|
||||
target := h.group.Next(ctx)
|
||||
if target != nil {
|
||||
raddr, err = net.ResolveUDPAddr(network, target.Addr)
|
||||
if err != nil {
|
||||
|
@ -87,7 +87,7 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
|
||||
var raddr net.Addr
|
||||
var err error
|
||||
|
||||
target := h.group.Next()
|
||||
target := h.group.Next(ctx)
|
||||
if target != nil {
|
||||
raddr, err = net.ResolveUDPAddr(network, target.Addr)
|
||||
if err != nil {
|
||||
|
@ -1,6 +1,8 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-gost/core/chain"
|
||||
"github.com/go-gost/core/metadata"
|
||||
"github.com/go-gost/core/selector"
|
||||
@ -49,10 +51,10 @@ func (w *chainWrapper) Metadata() metadata.Metadata {
|
||||
return v.Metadata()
|
||||
}
|
||||
|
||||
func (w *chainWrapper) Route(network, address string) chain.Route {
|
||||
func (w *chainWrapper) Route(ctx context.Context, network, address string) chain.Route {
|
||||
v := w.r.get(w.name)
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return v.Route(network, address)
|
||||
return v.Route(ctx, network, address)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package selector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
mdutil "github.com/go-gost/core/metadata/util"
|
||||
@ -22,7 +23,7 @@ func FailFilter[T selector.Selectable](maxFails int, timeout time.Duration) sele
|
||||
}
|
||||
|
||||
// Filter filters dead objects.
|
||||
func (f *failFilter[T]) Filter(vs ...T) []T {
|
||||
func (f *failFilter[T]) Filter(ctx context.Context, vs ...T) []T {
|
||||
if len(vs) <= 1 {
|
||||
return vs
|
||||
}
|
||||
@ -66,7 +67,7 @@ func BackupFilter[T selector.Selectable]() selector.Filter[T] {
|
||||
}
|
||||
|
||||
// Filter filters backup objects.
|
||||
func (f *backupFilter[T]) Filter(vs ...T) []T {
|
||||
func (f *backupFilter[T]) Filter(ctx context.Context, vs ...T) []T {
|
||||
if len(vs) <= 1 {
|
||||
return vs
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
package selector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/chain"
|
||||
"github.com/go-gost/core/selector"
|
||||
)
|
||||
|
||||
@ -20,17 +20,6 @@ const (
|
||||
labelFailTimeout = "failTimeout"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultNodeSelector = NewSelector(
|
||||
RoundRobinStrategy[*chain.Node](),
|
||||
// FailFilter[*Node](1, DefaultFailTimeout),
|
||||
)
|
||||
DefaultChainSelector = NewSelector(
|
||||
RoundRobinStrategy[chain.SelectableChainer](),
|
||||
// FailFilter[SelectableChainer](1, DefaultFailTimeout),
|
||||
)
|
||||
)
|
||||
|
||||
type defaultSelector[T selector.Selectable] struct {
|
||||
strategy selector.Strategy[T]
|
||||
filters []selector.Filter[T]
|
||||
@ -43,12 +32,12 @@ func NewSelector[T selector.Selectable](strategy selector.Strategy[T], filters .
|
||||
}
|
||||
}
|
||||
|
||||
func (s *defaultSelector[T]) Select(vs ...T) (v T) {
|
||||
func (s *defaultSelector[T]) Select(ctx context.Context, vs ...T) (v T) {
|
||||
for _, filter := range s.filters {
|
||||
vs = filter.Filter(vs...)
|
||||
vs = filter.Filter(ctx, vs...)
|
||||
}
|
||||
if len(vs) == 0 {
|
||||
return
|
||||
}
|
||||
return s.strategy.Apply(vs...)
|
||||
return s.strategy.Apply(ctx, vs...)
|
||||
}
|
||||
|
@ -1,11 +1,11 @@
|
||||
package selector
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
mdutil "github.com/go-gost/core/metadata/util"
|
||||
"github.com/go-gost/core/selector"
|
||||
)
|
||||
|
||||
@ -19,7 +19,7 @@ func RoundRobinStrategy[T selector.Selectable]() selector.Strategy[T] {
|
||||
return &roundRobinStrategy[T]{}
|
||||
}
|
||||
|
||||
func (s *roundRobinStrategy[T]) Apply(vs ...T) (v T) {
|
||||
func (s *roundRobinStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
|
||||
if len(vs) == 0 {
|
||||
return
|
||||
}
|
||||
@ -29,29 +29,36 @@ func (s *roundRobinStrategy[T]) Apply(vs ...T) (v T) {
|
||||
}
|
||||
|
||||
type randomStrategy[T selector.Selectable] struct {
|
||||
rand *rand.Rand
|
||||
mux sync.Mutex
|
||||
rw *randomWeighted[T]
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// RandomStrategy is a strategy for node selector.
|
||||
// The node will be selected randomly.
|
||||
func RandomStrategy[T selector.Selectable]() selector.Strategy[T] {
|
||||
return &randomStrategy[T]{
|
||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
rw: newRandomWeighted[T](),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *randomStrategy[T]) Apply(vs ...T) (v T) {
|
||||
func (s *randomStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
|
||||
if len(vs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
r := s.rand.Int()
|
||||
s.rw.Reset()
|
||||
for i := range vs {
|
||||
weight := mdutil.GetInt(vs[i].Metadata(), labelWeight)
|
||||
if weight <= 0 {
|
||||
weight = 1
|
||||
}
|
||||
s.rw.Add(vs[i], weight)
|
||||
}
|
||||
|
||||
return vs[r%len(vs)]
|
||||
return s.rw.Next()
|
||||
}
|
||||
|
||||
type fifoStrategy[T selector.Selectable] struct{}
|
||||
@ -64,7 +71,7 @@ func FIFOStrategy[T selector.Selectable]() selector.Strategy[T] {
|
||||
}
|
||||
|
||||
// Apply applies the fifo strategy for the nodes.
|
||||
func (s *fifoStrategy[T]) Apply(vs ...T) (v T) {
|
||||
func (s *fifoStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
|
||||
if len(vs) == 0 {
|
||||
return
|
||||
}
|
||||
|
54
selector/weighted.go
Normal file
54
selector/weighted.go
Normal file
@ -0,0 +1,54 @@
|
||||
package selector
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/selector"
|
||||
)
|
||||
|
||||
type randomWeightedItem[T selector.Selectable] struct {
|
||||
item T
|
||||
weight int
|
||||
}
|
||||
|
||||
type randomWeighted[T selector.Selectable] struct {
|
||||
items []*randomWeightedItem[T]
|
||||
sum int
|
||||
r *rand.Rand
|
||||
}
|
||||
|
||||
func newRandomWeighted[T selector.Selectable]() *randomWeighted[T] {
|
||||
return &randomWeighted[T]{
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *randomWeighted[T]) Add(item T, weight int) {
|
||||
ri := &randomWeightedItem[T]{item: item, weight: weight}
|
||||
rw.items = append(rw.items, ri)
|
||||
rw.sum += weight
|
||||
}
|
||||
|
||||
func (rw *randomWeighted[T]) Next() (v T) {
|
||||
if len(rw.items) == 0 {
|
||||
return
|
||||
}
|
||||
if rw.sum <= 0 {
|
||||
return
|
||||
}
|
||||
weight := rw.r.Intn(rw.sum) + 1
|
||||
for _, item := range rw.items {
|
||||
weight -= item.weight
|
||||
if weight <= 0 {
|
||||
return item.item
|
||||
}
|
||||
}
|
||||
|
||||
return rw.items[len(rw.items)-1].item
|
||||
}
|
||||
|
||||
func (rw *randomWeighted[T]) Reset() {
|
||||
rw.items = nil
|
||||
rw.sum = 0
|
||||
}
|
Reference in New Issue
Block a user