add weight for tunnel connector
This commit is contained in:
parent
43d37d0a5f
commit
b5b39de62c
@ -40,6 +40,10 @@ func (c *tunnelConnector) parseMetadata(md mdata.Metadata) (err error) {
|
|||||||
c.md.tunnelID = relay.NewTunnelID(uuid[:])
|
c.md.tunnelID = relay.NewTunnelID(uuid[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if weight := mdutil.GetInt(md, "tunnel.weight"); weight > 0 {
|
||||||
|
c.md.tunnelID = c.md.tunnelID.SetWeight(uint8(weight))
|
||||||
|
}
|
||||||
|
|
||||||
c.md.muxCfg = &mux.Config{
|
c.md.muxCfg = &mux.Config{
|
||||||
Version: mdutil.GetInt(md, "mux.version"),
|
Version: mdutil.GetInt(md, "mux.version"),
|
||||||
KeepAliveInterval: mdutil.GetDuration(md, "mux.keepaliveInterval"),
|
KeepAliveInterval: mdutil.GetDuration(md, "mux.keepaliveInterval"),
|
||||||
|
2
go.mod
2
go.mod
@ -11,7 +11,7 @@ require (
|
|||||||
github.com/go-gost/gosocks4 v0.0.1
|
github.com/go-gost/gosocks4 v0.0.1
|
||||||
github.com/go-gost/gosocks5 v0.4.0
|
github.com/go-gost/gosocks5 v0.4.0
|
||||||
github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a
|
github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a
|
||||||
github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7
|
github.com/go-gost/relay v0.4.1-0.20240127152636-06a246ca1c1a
|
||||||
github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451
|
github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451
|
||||||
github.com/go-redis/redis/v8 v8.11.5
|
github.com/go-redis/redis/v8 v8.11.5
|
||||||
github.com/gobwas/glob v0.2.3
|
github.com/gobwas/glob v0.2.3
|
||||||
|
2
go.sum
2
go.sum
@ -61,6 +61,8 @@ github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a h1:ME7P1Brcg4C640DS
|
|||||||
github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a/go.mod h1:qXr2Zm9Ex2ATqnWuNUzVZqySPMnuIihvblYZt4MlZLw=
|
github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a/go.mod h1:qXr2Zm9Ex2ATqnWuNUzVZqySPMnuIihvblYZt4MlZLw=
|
||||||
github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 h1:qAG1OyjvdA5h221CfFSS3J359V3d2E7dJWyP29QoDSI=
|
github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 h1:qAG1OyjvdA5h221CfFSS3J359V3d2E7dJWyP29QoDSI=
|
||||||
github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8=
|
github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8=
|
||||||
|
github.com/go-gost/relay v0.4.1-0.20240127152636-06a246ca1c1a h1:Y9BSyWFOnytIA4v3rTmw1YF1OI+ZGDF3hhFEZA8HNwg=
|
||||||
|
github.com/go-gost/relay v0.4.1-0.20240127152636-06a246ca1c1a/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8=
|
||||||
github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 h1:xj8gUZGYO3nb5+6Bjw9+tsFkA9sYynrOvDvvC4uDV2I=
|
github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 h1:xj8gUZGYO3nb5+6Bjw9+tsFkA9sYynrOvDvvC4uDV2I=
|
||||||
github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451/go.mod h1:/9QfdewqmHdaE362Hv5nDaSWLx3pCmtD870d6GaquXs=
|
github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451/go.mod h1:/9QfdewqmHdaE362Hv5nDaSWLx3pCmtD870d6GaquXs=
|
||||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||||
|
@ -30,6 +30,8 @@ func (h *tunnelHandler) handleBind(ctx context.Context, conn net.Conn, network,
|
|||||||
if network == "udp" {
|
if network == "udp" {
|
||||||
connectorID = relay.NewUDPConnectorID(uuid[:])
|
connectorID = relay.NewUDPConnectorID(uuid[:])
|
||||||
}
|
}
|
||||||
|
// copy weight from tunnelID
|
||||||
|
connectorID = connectorID.SetWeight(tunnelID.Weight())
|
||||||
|
|
||||||
v := md5.Sum([]byte(tunnelID.String()))
|
v := md5.Sum([]byte(tunnelID.String()))
|
||||||
endpoint := hex.EncodeToString(v[:8])
|
endpoint := hex.EncodeToString(v[:8])
|
||||||
|
@ -3,16 +3,20 @@ package tunnel
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-gost/core/logger"
|
"github.com/go-gost/core/logger"
|
||||||
"github.com/go-gost/core/sd"
|
"github.com/go-gost/core/sd"
|
||||||
"github.com/go-gost/relay"
|
"github.com/go-gost/relay"
|
||||||
"github.com/go-gost/x/internal/util/mux"
|
"github.com/go-gost/x/internal/util/mux"
|
||||||
|
"github.com/go-gost/x/selector"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MaxWeight uint8 = 0xff
|
||||||
|
)
|
||||||
|
|
||||||
type Connector struct {
|
type Connector struct {
|
||||||
id relay.ConnectorID
|
id relay.ConnectorID
|
||||||
tid relay.TunnelID
|
tid relay.TunnelID
|
||||||
@ -67,11 +71,11 @@ type Tunnel struct {
|
|||||||
id relay.TunnelID
|
id relay.TunnelID
|
||||||
connectors []*Connector
|
connectors []*Connector
|
||||||
t time.Time
|
t time.Time
|
||||||
n uint64
|
|
||||||
close chan struct{}
|
close chan struct{}
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
sd sd.SD
|
sd sd.SD
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
|
rw *selector.RandomWeighted[*Connector]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel {
|
func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel {
|
||||||
@ -81,6 +85,7 @@ func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel {
|
|||||||
t: time.Now(),
|
t: time.Now(),
|
||||||
close: make(chan struct{}),
|
close: make(chan struct{}),
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
|
rw: selector.NewRandomWeighted[*Connector](),
|
||||||
}
|
}
|
||||||
if t.ttl <= 0 {
|
if t.ttl <= 0 {
|
||||||
t.ttl = defaultTTL
|
t.ttl = defaultTTL
|
||||||
@ -112,21 +117,39 @@ func (t *Tunnel) GetConnector(network string) *Connector {
|
|||||||
t.mu.RLock()
|
t.mu.RLock()
|
||||||
defer t.mu.RUnlock()
|
defer t.mu.RUnlock()
|
||||||
|
|
||||||
|
rw := t.rw
|
||||||
|
rw.Reset()
|
||||||
|
|
||||||
|
found := false
|
||||||
var connectors []*Connector
|
var connectors []*Connector
|
||||||
for _, c := range t.connectors {
|
for _, c := range t.connectors {
|
||||||
if c.Session().IsClosed() {
|
if c.Session().IsClosed() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
weight := c.ID().Weight()
|
||||||
|
if weight == 0 {
|
||||||
|
weight = 1
|
||||||
|
}
|
||||||
|
|
||||||
if network == "udp" && c.id.IsUDP() ||
|
if network == "udp" && c.id.IsUDP() ||
|
||||||
network != "udp" && !c.id.IsUDP() {
|
network != "udp" && !c.id.IsUDP() {
|
||||||
connectors = append(connectors, c)
|
if weight == MaxWeight && !found {
|
||||||
|
connectors = nil
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if weight == MaxWeight || !found {
|
||||||
|
connectors = append(connectors, c)
|
||||||
|
rw.Add(c, int(weight))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(connectors) == 0 {
|
if len(connectors) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
n := atomic.AddUint64(&t.n, 1) - 1
|
|
||||||
return connectors[n%uint64(len(connectors))]
|
return rw.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tunnel) CloseOnIdle() bool {
|
func (t *Tunnel) CloseOnIdle() bool {
|
||||||
|
@ -35,7 +35,7 @@ func (s *roundRobinStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type randomStrategy[T any] struct {
|
type randomStrategy[T any] struct {
|
||||||
rw *randomWeighted[T]
|
rw *RandomWeighted[T]
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,7 +43,7 @@ type randomStrategy[T any] struct {
|
|||||||
// The node will be selected randomly.
|
// The node will be selected randomly.
|
||||||
func RandomStrategy[T any]() selector.Strategy[T] {
|
func RandomStrategy[T any]() selector.Strategy[T] {
|
||||||
return &randomStrategy[T]{
|
return &randomStrategy[T]{
|
||||||
rw: newRandomWeighted[T](),
|
rw: NewRandomWeighted[T](),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10,25 +10,25 @@ type randomWeightedItem[T any] struct {
|
|||||||
weight int
|
weight int
|
||||||
}
|
}
|
||||||
|
|
||||||
type randomWeighted[T any] struct {
|
type RandomWeighted[T any] struct {
|
||||||
items []*randomWeightedItem[T]
|
items []*randomWeightedItem[T]
|
||||||
sum int
|
sum int
|
||||||
r *rand.Rand
|
r *rand.Rand
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRandomWeighted[T any]() *randomWeighted[T] {
|
func NewRandomWeighted[T any]() *RandomWeighted[T] {
|
||||||
return &randomWeighted[T]{
|
return &RandomWeighted[T]{
|
||||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *randomWeighted[T]) Add(item T, weight int) {
|
func (rw *RandomWeighted[T]) Add(item T, weight int) {
|
||||||
ri := &randomWeightedItem[T]{item: item, weight: weight}
|
ri := &randomWeightedItem[T]{item: item, weight: weight}
|
||||||
rw.items = append(rw.items, ri)
|
rw.items = append(rw.items, ri)
|
||||||
rw.sum += weight
|
rw.sum += weight
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *randomWeighted[T]) Next() (v T) {
|
func (rw *RandomWeighted[T]) Next() (v T) {
|
||||||
if len(rw.items) == 0 {
|
if len(rw.items) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -46,7 +46,7 @@ func (rw *randomWeighted[T]) Next() (v T) {
|
|||||||
return rw.items[len(rw.items)-1].item
|
return rw.items[len(rw.items)-1].item
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *randomWeighted[T]) Reset() {
|
func (rw *RandomWeighted[T]) Reset() {
|
||||||
rw.items = nil
|
rw.items = nil
|
||||||
rw.sum = 0
|
rw.sum = 0
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user