x/handler/tunnel/tunnel.go
2024-01-27 23:31:23 +08:00

296 lines
5.0 KiB
Go

package tunnel
import (
"context"
"sync"
"time"
"github.com/go-gost/core/logger"
"github.com/go-gost/core/sd"
"github.com/go-gost/relay"
"github.com/go-gost/x/internal/util/mux"
"github.com/go-gost/x/selector"
"github.com/google/uuid"
)
// MaxWeight is the largest weight a connector can carry (the maximum uint8
// value). GetConnector treats connectors with this weight specially.
const MaxWeight = uint8(0xff)
// Connector is a member of a tunnel, backed by a mux session.
type Connector struct {
	id   relay.ConnectorID
	tid  relay.TunnelID // tunnel this connector belongs to
	node string         // node name put into service-discovery records
	sd   sd.SD          // service-discovery backend; nil disables deregistration
	t    time.Time      // set to time.Now() by NewConnector
	s    *mux.Session
}
// NewConnector wraps the mux session s as connector id of tunnel tid on node,
// and starts a goroutine (accept) that drains the session until it fails.
// sd may be nil, in which case no service-discovery deregistration is done
// when the session ends.
func NewConnector(id relay.ConnectorID, tid relay.TunnelID, node string, s *mux.Session, sd sd.SD) *Connector {
	c := new(Connector)
	c.id, c.tid, c.node = id, tid, node
	c.sd = sd
	c.t = time.Now()
	c.s = s

	go c.accept()
	return c
}
// accept loops on the connector's mux session, closing every accepted stream
// immediately without reading from it. When Accept returns an error, the error
// is logged, the session is closed and, if a service-discovery backend is set,
// the connector is deregistered; then the goroutine exits.
func (c *Connector) accept() {
	for {
		conn, err := c.s.Accept()
		if err == nil {
			conn.Close()
			continue
		}

		logger.Default().Errorf("connector %s: %v", c.id, err)
		c.s.Close()
		if c.sd == nil {
			return
		}

		svc := &sd.Service{
			ID:   c.id.String(),
			Name: c.tid.String(),
			Node: c.node,
		}
		c.sd.Deregister(context.Background(), svc)
		return
	}
}
// ID reports the identifier this connector was created with.
func (c *Connector) ID() relay.ConnectorID {
	return c.id
}

// Session exposes the mux session that backs this connector.
func (c *Connector) Session() *mux.Session {
	return c.s
}
// Tunnel holds the connectors registered under one tunnel ID and picks one of
// them on request (GetConnector). A background goroutine (clean) prunes
// connectors whose session has closed.
type Tunnel struct {
	node       string
	id         relay.TunnelID
	connectors []*Connector   // guarded by mu
	t          time.Time      // set to time.Now() by NewTunnel
	close      chan struct{}  // closed by CloseOnIdle; stops the clean goroutine
	mu         sync.RWMutex
	sd         sd.SD          // service-discovery backend; nil disables renew/deregister
	ttl        time.Duration  // tick interval of the clean goroutine
	rw         *selector.RandomWeighted[*Connector] // selector refilled by GetConnector
}
// NewTunnel creates tunnel tid hosted on node and starts its clean goroutine.
// A non-positive ttl is replaced by defaultTTL; the resulting ttl is the
// interval at which the clean goroutine runs.
func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel {
	if ttl <= 0 {
		ttl = defaultTTL
	}

	t := &Tunnel{
		node:  node,
		id:    tid,
		ttl:   ttl,
		t:     time.Now(),
		close: make(chan struct{}),
		rw:    selector.NewRandomWeighted[*Connector](),
	}

	go t.clean()
	return t
}
// WithSD sets the service-discovery backend used by the tunnel's clean
// goroutine to renew live connectors and deregister dead ones. A nil sd
// disables those calls.
//
// The clean goroutine is already running (it is started by NewTunnel) and
// reads t.sd while holding t.mu, so the write must hold t.mu too. Otherwise it
// is a data race.
func (t *Tunnel) WithSD(sd sd.SD) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.sd = sd
}
// ID reports the identifier this tunnel was created with.
func (t *Tunnel) ID() relay.TunnelID {
	return t.id
}
// AddConnector appends c to the tunnel's connector list. A nil c is ignored.
func (t *Tunnel) AddConnector(c *Connector) {
	if c != nil {
		t.mu.Lock()
		t.connectors = append(t.connectors, c)
		t.mu.Unlock()
	}
}
// GetConnector picks a connector for network by weighted random selection.
//
// Selection rules:
//   - Connectors whose session is closed are skipped.
//   - For network "udp", connectors whose ID is marked UDP are eligible; for
//     any other network, connectors whose ID is not marked UDP are eligible.
//   - A weight of 0 is treated as 1.
//   - If any eligible connector has weight MaxWeight, only MaxWeight
//     connectors are candidates.
//
// It returns nil when no connector is eligible.
//
// The shared selector t.rw is reset and refilled on every call, so the
// exclusive lock is required. With only a read lock, concurrent callers would
// mutate t.rw at the same time.
func (t *Tunnel) GetConnector(network string) *Connector {
	t.mu.Lock()
	defer t.mu.Unlock()

	udp := network == "udp"

	rw := t.rw
	rw.Reset()

	found := false // a MaxWeight connector has been seen
	candidates := 0
	for _, c := range t.connectors {
		if c.Session().IsClosed() || c.id.IsUDP() != udp {
			continue
		}

		weight := c.ID().Weight()
		if weight == 0 {
			weight = 1
		}

		if weight == MaxWeight && !found {
			// Lower-weight candidates gathered so far must be discarded from
			// the selector too, not just from the count.
			rw.Reset()
			candidates = 0
			found = true
		}
		if weight == MaxWeight || !found {
			rw.Add(c, int(weight))
			candidates++
		}
	}

	if candidates == 0 {
		return nil
	}
	return rw.Next()
}
// CloseOnIdle closes the tunnel, stopping its clean goroutine, if it has no
// connectors. It reports whether this call closed the tunnel. It returns false
// when the tunnel still has connectors or was already closed.
//
// Closing t.close is a state change, so the exclusive lock is taken. Under a
// shared read lock, two concurrent callers could both see the channel open and
// both close it, which panics.
func (t *Tunnel) CloseOnIdle() bool {
	t.mu.Lock()
	defer t.mu.Unlock()

	select {
	case <-t.close:
		return false // already closed
	default:
	}

	if len(t.connectors) != 0 {
		return false
	}
	close(t.close)
	return true
}
// clean runs one sweepConnectors pass every t.ttl until t.close is closed
// (see CloseOnIdle).
func (t *Tunnel) clean() {
	ticker := time.NewTicker(t.ttl)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			t.sweepConnectors()
		case <-t.close:
			return
		}
	}
}

// sweepConnectors does one cleaning pass:
//   - connectors whose session has closed are removed from the tunnel and,
//     if a service-discovery backend is set, deregistered;
//   - the remaining connectors have their registration renewed.
//
// The connector list is pruned while holding t.mu. The service-discovery
// calls, which may involve network I/O, are made after the lock is released
// so they do not stall GetConnector and AddConnector.
func (t *Tunnel) sweepConnectors() {
	t.mu.Lock()
	if len(t.connectors) == 0 {
		t.mu.Unlock()
		return
	}

	var alive, dead []*Connector
	for _, c := range t.connectors {
		if c.Session().IsClosed() {
			dead = append(dead, c)
			continue
		}
		alive = append(alive, c)
	}
	if len(alive) != len(t.connectors) {
		t.connectors = alive
	}
	registry := t.sd
	t.mu.Unlock()

	for _, c := range dead {
		logger.Default().Debugf("remove tunnel: %s, connector: %s", t.id, c.id)
		if registry != nil {
			registry.Deregister(context.Background(), &sd.Service{
				ID:   c.id.String(),
				Name: t.id.String(),
				Node: t.node,
			})
		}
	}

	if registry == nil {
		return
	}
	for _, c := range alive {
		registry.Renew(context.Background(), &sd.Service{
			ID:   c.id.String(),
			Name: t.id.String(),
			Node: t.node,
		})
	}
}
// ConnectorPool indexes tunnels by the string form of their tunnel ID. Every
// tunnel it creates is given the pool's node name and service-discovery backend.
type ConnectorPool struct {
	node    string
	sd      sd.SD               // may be nil
	tunnels map[string]*Tunnel  // keyed by relay.TunnelID.String(); guarded by mu
	mu      sync.RWMutex
}
// NewConnectorPool returns an empty pool for node. It starts a goroutine
// (closeIdles) that periodically drops tunnels with no connectors; that
// goroutine has no stop mechanism and runs until the process exits. sd is
// handed to every tunnel the pool creates and may be nil.
func NewConnectorPool(node string, sd sd.SD) *ConnectorPool {
	p := new(ConnectorPool)
	p.node = node
	p.sd = sd
	p.tunnels = map[string]*Tunnel{}

	go p.closeIdles()
	return p
}
// Add registers connector c under tunnel tid. If the tunnel does not exist yet,
// it is created with the pool's node, the given ttl and the pool's
// service-discovery backend. ttl has no effect on an existing tunnel.
func (p *ConnectorPool) Add(tid relay.TunnelID, c *Connector, ttl time.Duration) {
	key := tid.String()

	p.mu.Lock()
	defer p.mu.Unlock()

	tunnel := p.tunnels[key]
	if tunnel == nil {
		tunnel = NewTunnel(p.node, tid, ttl)
		tunnel.WithSD(p.sd)
		p.tunnels[key] = tunnel
	}
	tunnel.AddConnector(c)
}
// Get returns a connector of the tunnel whose string ID is tid, chosen for
// network by Tunnel.GetConnector. It returns nil when p is nil, the tunnel is
// unknown, or the tunnel has no eligible connector.
func (p *ConnectorPool) Get(network string, tid string) *Connector {
	if p == nil {
		return nil
	}

	p.mu.RLock()
	defer p.mu.RUnlock()

	if tunnel := p.tunnels[tid]; tunnel != nil {
		return tunnel.GetConnector(network)
	}
	return nil
}
// closeIdles never returns. Once an hour it asks every tunnel to close itself
// if it has no connectors (Tunnel.CloseOnIdle), and removes from the pool each
// tunnel that did.
func (p *ConnectorPool) closeIdles() {
	ticker := time.NewTicker(time.Hour)
	defer ticker.Stop()

	for {
		<-ticker.C

		p.mu.Lock()
		for key, tunnel := range p.tunnels {
			if !tunnel.CloseOnIdle() {
				continue
			}
			// Deleting the current key during range is permitted in Go.
			delete(p.tunnels, key)
			logger.Default().Debugf("remove idle tunnel: %s", key)
		}
		p.mu.Unlock()
	}
}
// parseTunnelID converts the textual form of a tunnel ID into a relay.TunnelID.
// A leading '$' marks a private tunnel; the rest of the string must be a UUID.
// The zero TunnelID is returned for an empty string or when the UUID part
// cannot be parsed.
func parseTunnelID(s string) (tid relay.TunnelID) {
	if s == "" {
		return
	}

	private := s[0] == '$'
	if private {
		s = s[1:]
	}

	id, err := uuid.Parse(s)
	if err != nil {
		// Don't build an ID from uuid.Nil. Unparseable input is treated like
		// empty input and yields the zero ID.
		return
	}

	if private {
		return relay.NewPrivateTunnelID(id[:])
	}
	return relay.NewTunnelID(id[:])
}