add weight for tunnel connector

This commit is contained in:
ginuerzh 2024-01-27 23:31:23 +08:00
parent 43d37d0a5f
commit b5b39de62c
7 changed files with 45 additions and 14 deletions

View File

@ -40,6 +40,10 @@ func (c *tunnelConnector) parseMetadata(md mdata.Metadata) (err error) {
c.md.tunnelID = relay.NewTunnelID(uuid[:]) c.md.tunnelID = relay.NewTunnelID(uuid[:])
} }
if weight := mdutil.GetInt(md, "tunnel.weight"); weight > 0 {
c.md.tunnelID = c.md.tunnelID.SetWeight(uint8(weight))
}
c.md.muxCfg = &mux.Config{ c.md.muxCfg = &mux.Config{
Version: mdutil.GetInt(md, "mux.version"), Version: mdutil.GetInt(md, "mux.version"),
KeepAliveInterval: mdutil.GetDuration(md, "mux.keepaliveInterval"), KeepAliveInterval: mdutil.GetDuration(md, "mux.keepaliveInterval"),

2
go.mod
View File

@ -11,7 +11,7 @@ require (
github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks4 v0.0.1
github.com/go-gost/gosocks5 v0.4.0 github.com/go-gost/gosocks5 v0.4.0
github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a
github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 github.com/go-gost/relay v0.4.1-0.20240127152636-06a246ca1c1a
github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451
github.com/go-redis/redis/v8 v8.11.5 github.com/go-redis/redis/v8 v8.11.5
github.com/gobwas/glob v0.2.3 github.com/gobwas/glob v0.2.3

2
go.sum
View File

@ -61,6 +61,8 @@ github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a h1:ME7P1Brcg4C640DS
github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a/go.mod h1:qXr2Zm9Ex2ATqnWuNUzVZqySPMnuIihvblYZt4MlZLw= github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a/go.mod h1:qXr2Zm9Ex2ATqnWuNUzVZqySPMnuIihvblYZt4MlZLw=
github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 h1:qAG1OyjvdA5h221CfFSS3J359V3d2E7dJWyP29QoDSI= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 h1:qAG1OyjvdA5h221CfFSS3J359V3d2E7dJWyP29QoDSI=
github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8=
github.com/go-gost/relay v0.4.1-0.20240127152636-06a246ca1c1a h1:Y9BSyWFOnytIA4v3rTmw1YF1OI+ZGDF3hhFEZA8HNwg=
github.com/go-gost/relay v0.4.1-0.20240127152636-06a246ca1c1a/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8=
github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 h1:xj8gUZGYO3nb5+6Bjw9+tsFkA9sYynrOvDvvC4uDV2I= github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 h1:xj8gUZGYO3nb5+6Bjw9+tsFkA9sYynrOvDvvC4uDV2I=
github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451/go.mod h1:/9QfdewqmHdaE362Hv5nDaSWLx3pCmtD870d6GaquXs= github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451/go.mod h1:/9QfdewqmHdaE362Hv5nDaSWLx3pCmtD870d6GaquXs=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=

View File

@ -30,6 +30,8 @@ func (h *tunnelHandler) handleBind(ctx context.Context, conn net.Conn, network,
if network == "udp" { if network == "udp" {
connectorID = relay.NewUDPConnectorID(uuid[:]) connectorID = relay.NewUDPConnectorID(uuid[:])
} }
// copy weight from tunnelID
connectorID = connectorID.SetWeight(tunnelID.Weight())
v := md5.Sum([]byte(tunnelID.String())) v := md5.Sum([]byte(tunnelID.String()))
endpoint := hex.EncodeToString(v[:8]) endpoint := hex.EncodeToString(v[:8])

View File

@ -3,16 +3,20 @@ package tunnel
import ( import (
"context" "context"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
"github.com/go-gost/core/sd" "github.com/go-gost/core/sd"
"github.com/go-gost/relay" "github.com/go-gost/relay"
"github.com/go-gost/x/internal/util/mux" "github.com/go-gost/x/internal/util/mux"
"github.com/go-gost/x/selector"
"github.com/google/uuid" "github.com/google/uuid"
) )
const (
MaxWeight uint8 = 0xff
)
type Connector struct { type Connector struct {
id relay.ConnectorID id relay.ConnectorID
tid relay.TunnelID tid relay.TunnelID
@ -67,11 +71,11 @@ type Tunnel struct {
id relay.TunnelID id relay.TunnelID
connectors []*Connector connectors []*Connector
t time.Time t time.Time
n uint64
close chan struct{} close chan struct{}
mu sync.RWMutex mu sync.RWMutex
sd sd.SD sd sd.SD
ttl time.Duration ttl time.Duration
rw *selector.RandomWeighted[*Connector]
} }
func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel { func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel {
@ -81,6 +85,7 @@ func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel {
t: time.Now(), t: time.Now(),
close: make(chan struct{}), close: make(chan struct{}),
ttl: ttl, ttl: ttl,
rw: selector.NewRandomWeighted[*Connector](),
} }
if t.ttl <= 0 { if t.ttl <= 0 {
t.ttl = defaultTTL t.ttl = defaultTTL
@ -112,21 +117,39 @@ func (t *Tunnel) GetConnector(network string) *Connector {
t.mu.RLock() t.mu.RLock()
defer t.mu.RUnlock() defer t.mu.RUnlock()
rw := t.rw
rw.Reset()
found := false
var connectors []*Connector var connectors []*Connector
for _, c := range t.connectors { for _, c := range t.connectors {
if c.Session().IsClosed() { if c.Session().IsClosed() {
continue continue
} }
weight := c.ID().Weight()
if weight == 0 {
weight = 1
}
if network == "udp" && c.id.IsUDP() || if network == "udp" && c.id.IsUDP() ||
network != "udp" && !c.id.IsUDP() { network != "udp" && !c.id.IsUDP() {
connectors = append(connectors, c) if weight == MaxWeight && !found {
connectors = nil
found = true
}
if weight == MaxWeight || !found {
connectors = append(connectors, c)
rw.Add(c, int(weight))
}
} }
} }
if len(connectors) == 0 { if len(connectors) == 0 {
return nil return nil
} }
n := atomic.AddUint64(&t.n, 1) - 1
return connectors[n%uint64(len(connectors))] return rw.Next()
} }
func (t *Tunnel) CloseOnIdle() bool { func (t *Tunnel) CloseOnIdle() bool {

View File

@ -35,7 +35,7 @@ func (s *roundRobinStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
} }
type randomStrategy[T any] struct { type randomStrategy[T any] struct {
rw *randomWeighted[T] rw *RandomWeighted[T]
mu sync.Mutex mu sync.Mutex
} }
@ -43,7 +43,7 @@ type randomStrategy[T any] struct {
// The node will be selected randomly. // The node will be selected randomly.
func RandomStrategy[T any]() selector.Strategy[T] { func RandomStrategy[T any]() selector.Strategy[T] {
return &randomStrategy[T]{ return &randomStrategy[T]{
rw: newRandomWeighted[T](), rw: NewRandomWeighted[T](),
} }
} }

View File

@ -10,25 +10,25 @@ type randomWeightedItem[T any] struct {
weight int weight int
} }
type randomWeighted[T any] struct { type RandomWeighted[T any] struct {
items []*randomWeightedItem[T] items []*randomWeightedItem[T]
sum int sum int
r *rand.Rand r *rand.Rand
} }
func newRandomWeighted[T any]() *randomWeighted[T] { func NewRandomWeighted[T any]() *RandomWeighted[T] {
return &randomWeighted[T]{ return &RandomWeighted[T]{
r: rand.New(rand.NewSource(time.Now().UnixNano())), r: rand.New(rand.NewSource(time.Now().UnixNano())),
} }
} }
func (rw *randomWeighted[T]) Add(item T, weight int) { func (rw *RandomWeighted[T]) Add(item T, weight int) {
ri := &randomWeightedItem[T]{item: item, weight: weight} ri := &randomWeightedItem[T]{item: item, weight: weight}
rw.items = append(rw.items, ri) rw.items = append(rw.items, ri)
rw.sum += weight rw.sum += weight
} }
func (rw *randomWeighted[T]) Next() (v T) { func (rw *RandomWeighted[T]) Next() (v T) {
if len(rw.items) == 0 { if len(rw.items) == 0 {
return return
} }
@ -46,7 +46,7 @@ func (rw *randomWeighted[T]) Next() (v T) {
return rw.items[len(rw.items)-1].item return rw.items[len(rw.items)-1].item
} }
func (rw *randomWeighted[T]) Reset() { func (rw *RandomWeighted[T]) Reset() {
rw.items = nil rw.items = nil
rw.sum = 0 rw.sum = 0
} }