add weight for selector

This commit is contained in:
ginuerzh
2022-09-02 17:23:59 +08:00
parent c643014e12
commit 00f7fa2997
16 changed files with 109 additions and 45 deletions

View File

@ -10,7 +10,6 @@ import (
tls_util "github.com/go-gost/x/internal/util/tls" tls_util "github.com/go-gost/x/internal/util/tls"
"github.com/go-gost/x/metadata" "github.com/go-gost/x/metadata"
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
xs "github.com/go-gost/x/selector"
) )
func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) { func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) {
@ -142,7 +141,7 @@ func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) {
sel = s sel = s
} }
if sel == nil { if sel == nil {
sel = xs.DefaultNodeSelector sel = defaultNodeSelector()
} }
group.WithSelector(sel). group.WithSelector(sel).
WithBypass(bypass.BypassGroup(bypassList(hop.Bypass, hop.Bypasses...)...)) WithBypass(bypass.BypassGroup(bypassList(hop.Bypass, hop.Bypasses...)...))

View File

@ -298,3 +298,11 @@ func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) {
return return
} }
func defaultNodeSelector() selector.Selector[*chain.Node] {
return xs.NewSelector(xs.RoundRobinStrategy[*chain.Node]())
}
func defaultChainSelector() selector.Selector[chain.SelectableChainer] {
return xs.NewSelector(xs.RoundRobinStrategy[chain.SelectableChainer]())
}

View File

@ -17,7 +17,6 @@ import (
tls_util "github.com/go-gost/x/internal/util/tls" tls_util "github.com/go-gost/x/internal/util/tls"
"github.com/go-gost/x/metadata" "github.com/go-gost/x/metadata"
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
xs "github.com/go-gost/x/selector"
) )
func ParseService(cfg *config.ServiceConfig) (service.Service, error) { func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
@ -188,7 +187,7 @@ func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup {
sel := parseNodeSelector(cfg.Selector) sel := parseNodeSelector(cfg.Selector)
if sel == nil { if sel == nil {
sel = xs.DefaultNodeSelector sel = defaultNodeSelector()
} }
return group.WithSelector(sel) return group.WithSelector(sel)
} }
@ -250,7 +249,7 @@ func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer {
} }
if sel == nil { if sel == nil {
sel = xs.DefaultChainSelector sel = defaultChainSelector()
} }
return chain.NewChainGroup(chains...). return chain.NewChainGroup(chains...).

3
go.mod
View File

@ -76,6 +76,7 @@ require (
github.com/prometheus/common v0.32.1 // indirect github.com/prometheus/common v0.32.1 // indirect
github.com/prometheus/procfs v0.7.3 // indirect github.com/prometheus/procfs v0.7.3 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
github.com/smallnest/weighted v0.0.0-20201102054551-85ac5c79528c // indirect
github.com/spf13/afero v1.8.2 // indirect github.com/spf13/afero v1.8.2 // indirect
github.com/spf13/cast v1.4.1 // indirect github.com/spf13/cast v1.4.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect
@ -86,7 +87,7 @@ require (
github.com/tjfoc/gmsm v1.3.2 // indirect github.com/tjfoc/gmsm v1.3.2 // indirect
github.com/ugorji/go/codec v1.2.7 // indirect github.com/ugorji/go/codec v1.2.7 // indirect
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
golang.org/x/text v0.3.7 // indirect golang.org/x/text v0.3.7 // indirect
golang.org/x/tools v0.1.12 // indirect golang.org/x/tools v0.1.12 // indirect

4
go.sum
View File

@ -335,6 +335,8 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/smallnest/weighted v0.0.0-20201102054551-85ac5c79528c h1:XBpqxCr2X2HYZMOA+HTDhj8njR4PGhsK+M+geaMAQ20=
github.com/smallnest/weighted v0.0.0-20201102054551-85ac5c79528c/go.mod h1:xc9CoZ+ZBGwajnWto5Aqw/wWg8euy4HtOr6K9Fxp9iw=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
@ -424,6 +426,8 @@ golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EH
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e h1:+WEEuIdZHnUeJJmEUjyYC2gfUMj69yZXw17EnHg/otA= golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e h1:+WEEuIdZHnUeJJmEUjyYC2gfUMj69yZXw17EnHg/otA=
golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA= golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA=
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 h1:tnebWN09GYg9OLPss1KXj8txwZc6X6uMr6VFdcGNbHw=
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=

View File

@ -221,7 +221,7 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger
return nil, err return nil, err
} }
ex := h.selectExchanger(strings.Trim(mq.Question[0].Name, ".")) ex := h.selectExchanger(ctx, strings.Trim(mq.Question[0].Name, "."))
if ex == nil { if ex == nil {
err := fmt.Errorf("exchange not found for %s", mq.Question[0].Name) err := fmt.Errorf("exchange not found for %s", mq.Question[0].Name)
log.Error(err) log.Error(err)
@ -293,11 +293,11 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) {
return return
} }
func (h *dnsHandler) selectExchanger(addr string) exchanger.Exchanger { func (h *dnsHandler) selectExchanger(ctx context.Context, addr string) exchanger.Exchanger {
if h.group == nil { if h.group == nil {
return nil return nil
} }
node := h.group.FilterAddr(addr).Next() node := h.group.FilterAddr(addr).Next(ctx)
if node == nil { if node == nil {
return nil return nil
} }

View File

@ -77,7 +77,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}() }()
target := h.group.Next() target := h.group.Next(ctx)
if target == nil { if target == nil {
err := errors.New("target not available") err := errors.New("target not available")
log.Error(err) log.Error(err)

View File

@ -71,7 +71,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}() }()
target := h.group.Next() target := h.group.Next(ctx)
if target == nil { if target == nil {
err := errors.New("target not available") err := errors.New("target not available")
log.Error(err) log.Error(err)

View File

@ -17,7 +17,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network
Version: relay.Version1, Version: relay.Version1,
Status: relay.StatusOK, Status: relay.StatusOK,
} }
target := h.group.Next() target := h.group.Next(ctx)
if target == nil { if target == nil {
resp.Status = relay.StatusServiceUnavailable resp.Status = relay.StatusServiceUnavailable
resp.WriteTo(conn) resp.WriteTo(conn)

View File

@ -105,7 +105,7 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
var raddr net.Addr var raddr net.Addr
var err error var err error
target := h.group.Next() target := h.group.Next(ctx)
if target != nil { if target != nil {
raddr, err = net.ResolveUDPAddr(network, target.Addr) raddr, err = net.ResolveUDPAddr(network, target.Addr)
if err != nil { if err != nil {

View File

@ -87,7 +87,7 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
var raddr net.Addr var raddr net.Addr
var err error var err error
target := h.group.Next() target := h.group.Next(ctx)
if target != nil { if target != nil {
raddr, err = net.ResolveUDPAddr(network, target.Addr) raddr, err = net.ResolveUDPAddr(network, target.Addr)
if err != nil { if err != nil {

View File

@ -1,6 +1,8 @@
package registry package registry
import ( import (
"context"
"github.com/go-gost/core/chain" "github.com/go-gost/core/chain"
"github.com/go-gost/core/metadata" "github.com/go-gost/core/metadata"
"github.com/go-gost/core/selector" "github.com/go-gost/core/selector"
@ -49,10 +51,10 @@ func (w *chainWrapper) Metadata() metadata.Metadata {
return v.Metadata() return v.Metadata()
} }
func (w *chainWrapper) Route(network, address string) chain.Route { func (w *chainWrapper) Route(ctx context.Context, network, address string) chain.Route {
v := w.r.get(w.name) v := w.r.get(w.name)
if v == nil { if v == nil {
return nil return nil
} }
return v.Route(network, address) return v.Route(ctx, network, address)
} }

View File

@ -1,6 +1,7 @@
package selector package selector
import ( import (
"context"
"time" "time"
mdutil "github.com/go-gost/core/metadata/util" mdutil "github.com/go-gost/core/metadata/util"
@ -22,7 +23,7 @@ func FailFilter[T selector.Selectable](maxFails int, timeout time.Duration) sele
} }
// Filter filters dead objects. // Filter filters dead objects.
func (f *failFilter[T]) Filter(vs ...T) []T { func (f *failFilter[T]) Filter(ctx context.Context, vs ...T) []T {
if len(vs) <= 1 { if len(vs) <= 1 {
return vs return vs
} }
@ -66,7 +67,7 @@ func BackupFilter[T selector.Selectable]() selector.Filter[T] {
} }
// Filter filters backup objects. // Filter filters backup objects.
func (f *backupFilter[T]) Filter(vs ...T) []T { func (f *backupFilter[T]) Filter(ctx context.Context, vs ...T) []T {
if len(vs) <= 1 { if len(vs) <= 1 {
return vs return vs
} }

View File

@ -1,9 +1,9 @@
package selector package selector
import ( import (
"context"
"time" "time"
"github.com/go-gost/core/chain"
"github.com/go-gost/core/selector" "github.com/go-gost/core/selector"
) )
@ -20,17 +20,6 @@ const (
labelFailTimeout = "failTimeout" labelFailTimeout = "failTimeout"
) )
var (
DefaultNodeSelector = NewSelector(
RoundRobinStrategy[*chain.Node](),
// FailFilter[*Node](1, DefaultFailTimeout),
)
DefaultChainSelector = NewSelector(
RoundRobinStrategy[chain.SelectableChainer](),
// FailFilter[SelectableChainer](1, DefaultFailTimeout),
)
)
type defaultSelector[T selector.Selectable] struct { type defaultSelector[T selector.Selectable] struct {
strategy selector.Strategy[T] strategy selector.Strategy[T]
filters []selector.Filter[T] filters []selector.Filter[T]
@ -43,12 +32,12 @@ func NewSelector[T selector.Selectable](strategy selector.Strategy[T], filters .
} }
} }
func (s *defaultSelector[T]) Select(vs ...T) (v T) { func (s *defaultSelector[T]) Select(ctx context.Context, vs ...T) (v T) {
for _, filter := range s.filters { for _, filter := range s.filters {
vs = filter.Filter(vs...) vs = filter.Filter(ctx, vs...)
} }
if len(vs) == 0 { if len(vs) == 0 {
return return
} }
return s.strategy.Apply(vs...) return s.strategy.Apply(ctx, vs...)
} }

View File

@ -1,11 +1,11 @@
package selector package selector
import ( import (
"math/rand" "context"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
mdutil "github.com/go-gost/core/metadata/util"
"github.com/go-gost/core/selector" "github.com/go-gost/core/selector"
) )
@ -19,7 +19,7 @@ func RoundRobinStrategy[T selector.Selectable]() selector.Strategy[T] {
return &roundRobinStrategy[T]{} return &roundRobinStrategy[T]{}
} }
func (s *roundRobinStrategy[T]) Apply(vs ...T) (v T) { func (s *roundRobinStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
if len(vs) == 0 { if len(vs) == 0 {
return return
} }
@ -29,29 +29,36 @@ func (s *roundRobinStrategy[T]) Apply(vs ...T) (v T) {
} }
type randomStrategy[T selector.Selectable] struct { type randomStrategy[T selector.Selectable] struct {
rand *rand.Rand rw *randomWeighted[T]
mux sync.Mutex mu sync.Mutex
} }
// RandomStrategy is a strategy for node selector. // RandomStrategy is a strategy for node selector.
// The node will be selected randomly. // The node will be selected randomly.
func RandomStrategy[T selector.Selectable]() selector.Strategy[T] { func RandomStrategy[T selector.Selectable]() selector.Strategy[T] {
return &randomStrategy[T]{ return &randomStrategy[T]{
rand: rand.New(rand.NewSource(time.Now().UnixNano())), rw: newRandomWeighted[T](),
} }
} }
func (s *randomStrategy[T]) Apply(vs ...T) (v T) { func (s *randomStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
if len(vs) == 0 { if len(vs) == 0 {
return return
} }
s.mux.Lock() s.mu.Lock()
defer s.mux.Unlock() defer s.mu.Unlock()
r := s.rand.Int() s.rw.Reset()
for i := range vs {
weight := mdutil.GetInt(vs[i].Metadata(), labelWeight)
if weight <= 0 {
weight = 1
}
s.rw.Add(vs[i], weight)
}
return vs[r%len(vs)] return s.rw.Next()
} }
type fifoStrategy[T selector.Selectable] struct{} type fifoStrategy[T selector.Selectable] struct{}
@ -64,7 +71,7 @@ func FIFOStrategy[T selector.Selectable]() selector.Strategy[T] {
} }
// Apply applies the fifo strategy for the nodes. // Apply applies the fifo strategy for the nodes.
func (s *fifoStrategy[T]) Apply(vs ...T) (v T) { func (s *fifoStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
if len(vs) == 0 { if len(vs) == 0 {
return return
} }

54
selector/weighted.go Normal file
View File

@ -0,0 +1,54 @@
package selector
import (
"math/rand"
"time"
"github.com/go-gost/core/selector"
)
type randomWeightedItem[T selector.Selectable] struct {
item T
weight int
}
type randomWeighted[T selector.Selectable] struct {
items []*randomWeightedItem[T]
sum int
r *rand.Rand
}
func newRandomWeighted[T selector.Selectable]() *randomWeighted[T] {
return &randomWeighted[T]{
r: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
func (rw *randomWeighted[T]) Add(item T, weight int) {
ri := &randomWeightedItem[T]{item: item, weight: weight}
rw.items = append(rw.items, ri)
rw.sum += weight
}
func (rw *randomWeighted[T]) Next() (v T) {
if len(rw.items) == 0 {
return
}
if rw.sum <= 0 {
return
}
weight := rw.r.Intn(rw.sum) + 1
for _, item := range rw.items {
weight -= item.weight
if weight <= 0 {
return item.item
}
}
return rw.items[len(rw.items)-1].item
}
func (rw *randomWeighted[T]) Reset() {
rw.items = nil
rw.sum = 0
}