diff --git a/connector/tunnel/metadata.go b/connector/tunnel/metadata.go index d678b62..5163293 100644 --- a/connector/tunnel/metadata.go +++ b/connector/tunnel/metadata.go @@ -40,6 +40,10 @@ func (c *tunnelConnector) parseMetadata(md mdata.Metadata) (err error) { c.md.tunnelID = relay.NewTunnelID(uuid[:]) } + if weight := mdutil.GetInt(md, "tunnel.weight"); weight > 0 { + c.md.tunnelID = c.md.tunnelID.SetWeight(uint8(weight)) + } + c.md.muxCfg = &mux.Config{ Version: mdutil.GetInt(md, "mux.version"), KeepAliveInterval: mdutil.GetDuration(md, "mux.keepaliveInterval"), diff --git a/go.mod b/go.mod index a6930f7..fb6b158 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.4.0 github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a - github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 + github.com/go-gost/relay v0.4.1-0.20240127152636-06a246ca1c1a github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 github.com/go-redis/redis/v8 v8.11.5 github.com/gobwas/glob v0.2.3 diff --git a/go.sum b/go.sum index 38f6648..5305a2d 100644 --- a/go.sum +++ b/go.sum @@ -61,6 +61,8 @@ github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a h1:ME7P1Brcg4C640DS github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a/go.mod h1:qXr2Zm9Ex2ATqnWuNUzVZqySPMnuIihvblYZt4MlZLw= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 h1:qAG1OyjvdA5h221CfFSS3J359V3d2E7dJWyP29QoDSI= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= +github.com/go-gost/relay v0.4.1-0.20240127152636-06a246ca1c1a h1:Y9BSyWFOnytIA4v3rTmw1YF1OI+ZGDF3hhFEZA8HNwg= +github.com/go-gost/relay v0.4.1-0.20240127152636-06a246ca1c1a/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 h1:xj8gUZGYO3nb5+6Bjw9+tsFkA9sYynrOvDvvC4uDV2I= github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451/go.mod h1:/9QfdewqmHdaE362Hv5nDaSWLx3pCmtD870d6GaquXs= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= diff --git a/handler/tunnel/bind.go b/handler/tunnel/bind.go index 188691b..c0ffc40 100644 --- a/handler/tunnel/bind.go +++ b/handler/tunnel/bind.go @@ -30,6 +30,8 @@ func (h *tunnelHandler) handleBind(ctx context.Context, conn net.Conn, network, if network == "udp" { connectorID = relay.NewUDPConnectorID(uuid[:]) } + // copy weight from tunnelID + connectorID = connectorID.SetWeight(tunnelID.Weight()) v := md5.Sum([]byte(tunnelID.String())) endpoint := hex.EncodeToString(v[:8]) diff --git a/handler/tunnel/tunnel.go b/handler/tunnel/tunnel.go index b71f6eb..467ade8 100644 --- a/handler/tunnel/tunnel.go +++ b/handler/tunnel/tunnel.go @@ -3,16 +3,20 @@ package tunnel import ( "context" "sync" - "sync/atomic" "time" "github.com/go-gost/core/logger" "github.com/go-gost/core/sd" "github.com/go-gost/relay" "github.com/go-gost/x/internal/util/mux" + "github.com/go-gost/x/selector" "github.com/google/uuid" ) +const ( + MaxWeight uint8 = 0xff +) + type Connector struct { id relay.ConnectorID tid relay.TunnelID @@ -67,11 +71,11 @@ type Tunnel struct { id relay.TunnelID connectors []*Connector t time.Time - n uint64 close chan struct{} mu sync.RWMutex sd sd.SD ttl time.Duration + rw *selector.RandomWeighted[*Connector] } func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel { @@ -81,6 +85,7 @@ func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel { t: time.Now(), close: make(chan struct{}), ttl: ttl, + rw: selector.NewRandomWeighted[*Connector](), } if t.ttl <= 0 { t.ttl = defaultTTL @@ -112,21 +117,39 @@ func (t *Tunnel) GetConnector(network string) *Connector { t.mu.RLock() defer t.mu.RUnlock() + rw := t.rw + rw.Reset() + + found := false var connectors []*Connector for _, c := range t.connectors { if c.Session().IsClosed() { continue } + + weight := c.ID().Weight() + if weight == 0 { + weight = 1 + } + if network == "udp" && c.id.IsUDP() || network != "udp" && !c.id.IsUDP() { - connectors = append(connectors, c) + if weight == MaxWeight && !found { + connectors = nil + found = true + } + + if weight == MaxWeight || !found { + connectors = append(connectors, c) + rw.Add(c, int(weight)) + } } } if len(connectors) == 0 { return nil } - n := atomic.AddUint64(&t.n, 1) - 1 - return connectors[n%uint64(len(connectors))] + + return rw.Next() } func (t *Tunnel) CloseOnIdle() bool { diff --git a/selector/strategy.go b/selector/strategy.go index 4a93448..fe4d886 100644 --- a/selector/strategy.go +++ b/selector/strategy.go @@ -35,7 +35,7 @@ func (s *roundRobinStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) { } type randomStrategy[T any] struct { - rw *randomWeighted[T] + rw *RandomWeighted[T] mu sync.Mutex } @@ -43,7 +43,7 @@ type randomStrategy[T any] struct { // The node will be selected randomly. func RandomStrategy[T any]() selector.Strategy[T] { return &randomStrategy[T]{ - rw: newRandomWeighted[T](), + rw: NewRandomWeighted[T](), } } diff --git a/selector/weighted.go b/selector/weighted.go index 956342e..bb50f2a 100644 --- a/selector/weighted.go +++ b/selector/weighted.go @@ -10,25 +10,25 @@ type randomWeightedItem[T any] struct { weight int } -type randomWeighted[T any] struct { +type RandomWeighted[T any] struct { items []*randomWeightedItem[T] sum int r *rand.Rand } -func newRandomWeighted[T any]() *randomWeighted[T] { - return &randomWeighted[T]{ +func NewRandomWeighted[T any]() *RandomWeighted[T] { + return &RandomWeighted[T]{ r: rand.New(rand.NewSource(time.Now().UnixNano())), } } -func (rw *randomWeighted[T]) Add(item T, weight int) { +func (rw *RandomWeighted[T]) Add(item T, weight int) { ri := &randomWeightedItem[T]{item: item, weight: weight} rw.items = append(rw.items, ri) rw.sum += weight } -func (rw *randomWeighted[T]) Next() (v T) { +func (rw *RandomWeighted[T]) Next() (v T) { if len(rw.items) == 0 { return } @@ -46,7 +46,7 @@ func (rw *randomWeighted[T]) Next() (v T) { return rw.items[len(rw.items)-1].item } -func (rw *randomWeighted[T]) Reset() { +func (rw *RandomWeighted[T]) Reset() { rw.items = nil rw.sum = 0 }