diff --git a/config/parsing/chain.go b/config/parsing/chain.go index e8d26ce..6d8601f 100644 --- a/config/parsing/chain.go +++ b/config/parsing/chain.go @@ -10,7 +10,6 @@ import ( tls_util "github.com/go-gost/x/internal/util/tls" "github.com/go-gost/x/metadata" "github.com/go-gost/x/registry" - xs "github.com/go-gost/x/selector" ) func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) { @@ -142,7 +141,7 @@ func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) { sel = s } if sel == nil { - sel = xs.DefaultNodeSelector + sel = defaultNodeSelector() } group.WithSelector(sel). WithBypass(bypass.BypassGroup(bypassList(hop.Bypass, hop.Bypasses...)...)) diff --git a/config/parsing/parse.go b/config/parsing/parse.go index a7ad3db..a12ad3b 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -298,3 +298,11 @@ func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) { return } + +func defaultNodeSelector() selector.Selector[*chain.Node] { + return xs.NewSelector(xs.RoundRobinStrategy[*chain.Node]()) +} + +func defaultChainSelector() selector.Selector[chain.SelectableChainer] { + return xs.NewSelector(xs.RoundRobinStrategy[chain.SelectableChainer]()) +} diff --git a/config/parsing/service.go b/config/parsing/service.go index 11510c1..59db6d8 100644 --- a/config/parsing/service.go +++ b/config/parsing/service.go @@ -17,7 +17,6 @@ import ( tls_util "github.com/go-gost/x/internal/util/tls" "github.com/go-gost/x/metadata" "github.com/go-gost/x/registry" - xs "github.com/go-gost/x/selector" ) func ParseService(cfg *config.ServiceConfig) (service.Service, error) { @@ -188,7 +187,7 @@ func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup { sel := parseNodeSelector(cfg.Selector) if sel == nil { - sel = xs.DefaultNodeSelector + sel = defaultNodeSelector() } return group.WithSelector(sel) } @@ -250,7 +249,7 @@ func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer { } if sel == nil { - sel = xs.DefaultChainSelector + sel = defaultChainSelector() } return chain.NewChainGroup(chains...). diff --git a/go.mod b/go.mod index 1c7fb32..f02e8ae 100644 --- a/go.mod +++ b/go.mod @@ -76,6 +76,7 @@ require ( github.com/prometheus/common v0.32.1 // indirect github.com/prometheus/procfs v0.7.3 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect + github.com/smallnest/weighted v0.0.0-20201102054551-85ac5c79528c // indirect github.com/spf13/afero v1.8.2 // indirect github.com/spf13/cast v1.4.1 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect @@ -86,7 +87,7 @@ require ( github.com/tjfoc/gmsm v1.3.2 // indirect github.com/ugorji/go/codec v1.2.7 // indirect github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect - golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect + golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect golang.org/x/text v0.3.7 // indirect golang.org/x/tools v0.1.12 // indirect diff --git a/go.sum b/go.sum index 34dff07..2569d79 100644 --- a/go.sum +++ b/go.sum @@ -335,6 +335,8 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/smallnest/weighted v0.0.0-20201102054551-85ac5c79528c h1:XBpqxCr2X2HYZMOA+HTDhj8njR4PGhsK+M+geaMAQ20= +github.com/smallnest/weighted v0.0.0-20201102054551-85ac5c79528c/go.mod h1:xc9CoZ+ZBGwajnWto5Aqw/wWg8euy4HtOr6K9Fxp9iw= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= @@ -424,6 +426,8 @@ golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EH golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e h1:+WEEuIdZHnUeJJmEUjyYC2gfUMj69yZXw17EnHg/otA= golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA= +golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 h1:tnebWN09GYg9OLPss1KXj8txwZc6X6uMr6VFdcGNbHw= +golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= diff --git a/handler/dns/handler.go b/handler/dns/handler.go index c993c16..96c76c4 100644 --- a/handler/dns/handler.go +++ b/handler/dns/handler.go @@ -221,7 +221,7 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger return nil, err } - ex := h.selectExchanger(strings.Trim(mq.Question[0].Name, ".")) + ex := h.selectExchanger(ctx, strings.Trim(mq.Question[0].Name, ".")) if ex == nil { err := fmt.Errorf("exchange not found for %s", mq.Question[0].Name) log.Error(err) @@ -293,11 +293,11 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) { return } -func (h *dnsHandler) selectExchanger(addr string) exchanger.Exchanger { +func (h *dnsHandler) selectExchanger(ctx context.Context, addr string) exchanger.Exchanger { if h.group == nil { return nil } - node := h.group.FilterAddr(addr).Next() + node := h.group.FilterAddr(addr).Next(ctx) if node == nil { return nil } diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index cebad15..a93638f 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -77,7 +77,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() - target := h.group.Next() + target := h.group.Next(ctx) if target == nil { err := errors.New("target not available") log.Error(err) diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index 3cc76cb..faecbb6 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -71,7 +71,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() - target := h.group.Next() + target := h.group.Next(ctx) if target == nil { err := errors.New("target not available") log.Error(err) diff --git a/handler/relay/forward.go b/handler/relay/forward.go index aa60da7..4221a0d 100644 --- a/handler/relay/forward.go +++ b/handler/relay/forward.go @@ -17,7 +17,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network Version: relay.Version1, Status: relay.StatusOK, } - target := h.group.Next() + target := h.group.Next(ctx) if target == nil { resp.Status = relay.StatusServiceUnavailable resp.WriteTo(conn) diff --git a/handler/tap/handler.go b/handler/tap/handler.go index 2a321eb..962e033 100644 --- a/handler/tap/handler.go +++ b/handler/tap/handler.go @@ -105,7 +105,7 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. var raddr net.Addr var err error - target := h.group.Next() + target := h.group.Next(ctx) if target != nil { raddr, err = net.ResolveUDPAddr(network, target.Addr) if err != nil { diff --git a/handler/tun/handler.go b/handler/tun/handler.go index f2668b1..a7b052e 100644 --- a/handler/tun/handler.go +++ b/handler/tun/handler.go @@ -87,7 +87,7 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. var raddr net.Addr var err error - target := h.group.Next() + target := h.group.Next(ctx) if target != nil { raddr, err = net.ResolveUDPAddr(network, target.Addr) if err != nil { diff --git a/registry/chain.go b/registry/chain.go index cf440f0..1062927 100644 --- a/registry/chain.go +++ b/registry/chain.go @@ -1,6 +1,8 @@ package registry import ( + "context" + "github.com/go-gost/core/chain" "github.com/go-gost/core/metadata" "github.com/go-gost/core/selector" @@ -49,10 +51,10 @@ func (w *chainWrapper) Metadata() metadata.Metadata { return v.Metadata() } -func (w *chainWrapper) Route(network, address string) chain.Route { +func (w *chainWrapper) Route(ctx context.Context, network, address string) chain.Route { v := w.r.get(w.name) if v == nil { return nil } - return v.Route(network, address) + return v.Route(ctx, network, address) } diff --git a/selector/filter.go b/selector/filter.go index bd0bb15..077ae72 100644 --- a/selector/filter.go +++ b/selector/filter.go @@ -1,6 +1,7 @@ package selector import ( + "context" "time" mdutil "github.com/go-gost/core/metadata/util" @@ -22,7 +23,7 @@ func FailFilter[T selector.Selectable](maxFails int, timeout time.Duration) sele } // Filter filters dead objects. -func (f *failFilter[T]) Filter(vs ...T) []T { +func (f *failFilter[T]) Filter(ctx context.Context, vs ...T) []T { if len(vs) <= 1 { return vs } @@ -66,7 +67,7 @@ func BackupFilter[T selector.Selectable]() selector.Filter[T] { } // Filter filters backup objects. -func (f *backupFilter[T]) Filter(vs ...T) []T { +func (f *backupFilter[T]) Filter(ctx context.Context, vs ...T) []T { if len(vs) <= 1 { return vs } diff --git a/selector/selector.go b/selector/selector.go index 415b6a1..53108c5 100644 --- a/selector/selector.go +++ b/selector/selector.go @@ -1,9 +1,9 @@ package selector import ( + "context" "time" - "github.com/go-gost/core/chain" "github.com/go-gost/core/selector" ) @@ -20,17 +20,6 @@ const ( labelFailTimeout = "failTimeout" ) -var ( - DefaultNodeSelector = NewSelector( - RoundRobinStrategy[*chain.Node](), - // FailFilter[*Node](1, DefaultFailTimeout), - ) - DefaultChainSelector = NewSelector( - RoundRobinStrategy[chain.SelectableChainer](), - // FailFilter[SelectableChainer](1, DefaultFailTimeout), - ) -) - type defaultSelector[T selector.Selectable] struct { strategy selector.Strategy[T] filters []selector.Filter[T] @@ -43,12 +32,12 @@ func NewSelector[T selector.Selectable](strategy selector.Strategy[T], filters . } } -func (s *defaultSelector[T]) Select(vs ...T) (v T) { +func (s *defaultSelector[T]) Select(ctx context.Context, vs ...T) (v T) { for _, filter := range s.filters { - vs = filter.Filter(vs...) + vs = filter.Filter(ctx, vs...) } if len(vs) == 0 { return } - return s.strategy.Apply(vs...) + return s.strategy.Apply(ctx, vs...) } diff --git a/selector/strategy.go b/selector/strategy.go index a271717..fc57b58 100644 --- a/selector/strategy.go +++ b/selector/strategy.go @@ -1,11 +1,11 @@ package selector import ( - "math/rand" + "context" "sync" "sync/atomic" - "time" + mdutil "github.com/go-gost/core/metadata/util" "github.com/go-gost/core/selector" ) @@ -19,7 +19,7 @@ func RoundRobinStrategy[T selector.Selectable]() selector.Strategy[T] { return &roundRobinStrategy[T]{} } -func (s *roundRobinStrategy[T]) Apply(vs ...T) (v T) { +func (s *roundRobinStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) { if len(vs) == 0 { return } @@ -29,29 +29,36 @@ func (s *roundRobinStrategy[T]) Apply(vs ...T) (v T) { } type randomStrategy[T selector.Selectable] struct { - rand *rand.Rand - mux sync.Mutex + rw *randomWeighted[T] + mu sync.Mutex } // RandomStrategy is a strategy for node selector. // The node will be selected randomly. func RandomStrategy[T selector.Selectable]() selector.Strategy[T] { return &randomStrategy[T]{ - rand: rand.New(rand.NewSource(time.Now().UnixNano())), + rw: newRandomWeighted[T](), } } -func (s *randomStrategy[T]) Apply(vs ...T) (v T) { +func (s *randomStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) { if len(vs) == 0 { return } - s.mux.Lock() - defer s.mux.Unlock() + s.mu.Lock() + defer s.mu.Unlock() - r := s.rand.Int() + s.rw.Reset() + for i := range vs { + weight := mdutil.GetInt(vs[i].Metadata(), labelWeight) + if weight <= 0 { + weight = 1 + } + s.rw.Add(vs[i], weight) + } - return vs[r%len(vs)] + return s.rw.Next() } type fifoStrategy[T selector.Selectable] struct{} @@ -64,7 +71,7 @@ func FIFOStrategy[T selector.Selectable]() selector.Strategy[T] { } // Apply applies the fifo strategy for the nodes. -func (s *fifoStrategy[T]) Apply(vs ...T) (v T) { +func (s *fifoStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) { if len(vs) == 0 { return } diff --git a/selector/weighted.go b/selector/weighted.go new file mode 100644 index 0000000..478dda1 --- /dev/null +++ b/selector/weighted.go @@ -0,0 +1,54 @@ +package selector + +import ( + "math/rand" + "time" + + "github.com/go-gost/core/selector" +) + +type randomWeightedItem[T selector.Selectable] struct { + item T + weight int +} + +type randomWeighted[T selector.Selectable] struct { + items []*randomWeightedItem[T] + sum int + r *rand.Rand +} + +func newRandomWeighted[T selector.Selectable]() *randomWeighted[T] { + return &randomWeighted[T]{ + r: rand.New(rand.NewSource(time.Now().UnixNano())), + } +} + +func (rw *randomWeighted[T]) Add(item T, weight int) { + ri := &randomWeightedItem[T]{item: item, weight: weight} + rw.items = append(rw.items, ri) + rw.sum += weight +} + +func (rw *randomWeighted[T]) Next() (v T) { + if len(rw.items) == 0 { + return + } + if rw.sum <= 0 { + return + } + weight := rw.r.Intn(rw.sum) + 1 + for _, item := range rw.items { + weight -= item.weight + if weight <= 0 { + return item.item + } + } + + return rw.items[len(rw.items)-1].item +} + +func (rw *randomWeighted[T]) Reset() { + rw.items = nil + rw.sum = 0 +}