296 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			296 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package tunnel
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/go-gost/core/logger"
 | |
| 	"github.com/go-gost/core/sd"
 | |
| 	"github.com/go-gost/relay"
 | |
| 	"github.com/go-gost/x/internal/util/mux"
 | |
| 	"github.com/go-gost/x/selector"
 | |
| 	"github.com/google/uuid"
 | |
| )
 | |
| 
 | |
// MaxWeight is the largest connector weight. GetConnector treats it
// specially: connectors carrying this weight are preferred over others.
const MaxWeight uint8 = 255
 | |
| 
 | |
// Connector is a single connector registered on a tunnel, backed by a
// multiplexed session.
type Connector struct {
	id   relay.ConnectorID // connector identity; GetConnector reads its weight and UDP flag
	tid  relay.TunnelID    // tunnel this connector belongs to
	node string            // node name used in service-discovery records
	sd   sd.SD             // optional service-discovery backend; nil disables SD calls
	t    time.Time         // set to time.Now() in NewConnector
	s    *mux.Session      // session the connector is served over
}
 | |
| 
 | |
| func NewConnector(id relay.ConnectorID, tid relay.TunnelID, node string, s *mux.Session, sd sd.SD) *Connector {
 | |
| 	c := &Connector{
 | |
| 		id:   id,
 | |
| 		tid:  tid,
 | |
| 		node: node,
 | |
| 		sd:   sd,
 | |
| 		t:    time.Now(),
 | |
| 		s:    s,
 | |
| 	}
 | |
| 	go c.accept()
 | |
| 	return c
 | |
| }
 | |
| 
 | |
| func (c *Connector) accept() {
 | |
| 	for {
 | |
| 		conn, err := c.s.Accept()
 | |
| 		if err != nil {
 | |
| 			logger.Default().Errorf("connector %s: %v", c.id, err)
 | |
| 			c.s.Close()
 | |
| 			if c.sd != nil {
 | |
| 				c.sd.Deregister(context.Background(), &sd.Service{
 | |
| 					ID:   c.id.String(),
 | |
| 					Name: c.tid.String(),
 | |
| 					Node: c.node,
 | |
| 				})
 | |
| 			}
 | |
| 			return
 | |
| 		}
 | |
| 		conn.Close()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (c *Connector) ID() relay.ConnectorID {
 | |
| 	return c.id
 | |
| }
 | |
| 
 | |
| func (c *Connector) Session() *mux.Session {
 | |
| 	return c.s
 | |
| }
 | |
| 
 | |
// Tunnel groups the connectors registered under one tunnel ID and picks
// among them for outgoing requests.
type Tunnel struct {
	node       string                               // node name used in service-discovery records
	id         relay.TunnelID                       // tunnel identity
	connectors []*Connector                         // registered connectors; accessed under mu
	t          time.Time                            // set to time.Now() in NewTunnel
	close      chan struct{}                        // closed by CloseOnIdle; stops the clean goroutine
	mu         sync.RWMutex                         // protects connectors
	sd         sd.SD                                // optional service-discovery backend
	ttl        time.Duration                        // tick interval of the clean goroutine
	rw         *selector.RandomWeighted[*Connector] // selector reused (reset) on every GetConnector call
}
 | |
| 
 | |
| func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel {
 | |
| 	t := &Tunnel{
 | |
| 		node:  node,
 | |
| 		id:    tid,
 | |
| 		t:     time.Now(),
 | |
| 		close: make(chan struct{}),
 | |
| 		ttl:   ttl,
 | |
| 		rw:    selector.NewRandomWeighted[*Connector](),
 | |
| 	}
 | |
| 	if t.ttl <= 0 {
 | |
| 		t.ttl = defaultTTL
 | |
| 	}
 | |
| 	go t.clean()
 | |
| 	return t
 | |
| }
 | |
| 
 | |
| func (t *Tunnel) WithSD(sd sd.SD) {
 | |
| 	t.sd = sd
 | |
| }
 | |
| 
 | |
| func (t *Tunnel) ID() relay.TunnelID {
 | |
| 	return t.id
 | |
| }
 | |
| 
 | |
| func (t *Tunnel) AddConnector(c *Connector) {
 | |
| 	if c == nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	t.mu.Lock()
 | |
| 	defer t.mu.Unlock()
 | |
| 
 | |
| 	t.connectors = append(t.connectors, c)
 | |
| }
 | |
| 
 | |
| func (t *Tunnel) GetConnector(network string) *Connector {
 | |
| 	t.mu.RLock()
 | |
| 	defer t.mu.RUnlock()
 | |
| 
 | |
| 	rw := t.rw
 | |
| 	rw.Reset()
 | |
| 
 | |
| 	found := false
 | |
| 	var connectors []*Connector
 | |
| 	for _, c := range t.connectors {
 | |
| 		if c.Session().IsClosed() {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		weight := c.ID().Weight()
 | |
| 		if weight == 0 {
 | |
| 			weight = 1
 | |
| 		}
 | |
| 
 | |
| 		if network == "udp" && c.id.IsUDP() ||
 | |
| 			network != "udp" && !c.id.IsUDP() {
 | |
| 			if weight == MaxWeight && !found {
 | |
| 				connectors = nil
 | |
| 				found = true
 | |
| 			}
 | |
| 
 | |
| 			if weight == MaxWeight || !found {
 | |
| 				connectors = append(connectors, c)
 | |
| 				rw.Add(c, int(weight))
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	if len(connectors) == 0 {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	return rw.Next()
 | |
| }
 | |
| 
 | |
| func (t *Tunnel) CloseOnIdle() bool {
 | |
| 	t.mu.RLock()
 | |
| 	defer t.mu.RUnlock()
 | |
| 
 | |
| 	select {
 | |
| 	case <-t.close:
 | |
| 	default:
 | |
| 		if len(t.connectors) == 0 {
 | |
| 			close(t.close)
 | |
| 			return true
 | |
| 		}
 | |
| 	}
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| func (t *Tunnel) clean() {
 | |
| 	ticker := time.NewTicker(t.ttl)
 | |
| 	defer ticker.Stop()
 | |
| 
 | |
| 	for {
 | |
| 		select {
 | |
| 		case <-ticker.C:
 | |
| 			t.mu.Lock()
 | |
| 			if len(t.connectors) == 0 {
 | |
| 				t.mu.Unlock()
 | |
| 				break
 | |
| 			}
 | |
| 			var connectors []*Connector
 | |
| 			for _, c := range t.connectors {
 | |
| 				if c.Session().IsClosed() {
 | |
| 					logger.Default().Debugf("remove tunnel: %s, connector: %s", t.id, c.id)
 | |
| 					if t.sd != nil {
 | |
| 						t.sd.Deregister(context.Background(), &sd.Service{
 | |
| 							ID:   c.id.String(),
 | |
| 							Name: t.id.String(),
 | |
| 							Node: t.node,
 | |
| 						})
 | |
| 					}
 | |
| 					continue
 | |
| 				}
 | |
| 
 | |
| 				connectors = append(connectors, c)
 | |
| 				if t.sd != nil {
 | |
| 					t.sd.Renew(context.Background(), &sd.Service{
 | |
| 						ID:   c.id.String(),
 | |
| 						Name: t.id.String(),
 | |
| 						Node: t.node,
 | |
| 					})
 | |
| 				}
 | |
| 			}
 | |
| 			if len(connectors) != len(t.connectors) {
 | |
| 				t.connectors = connectors
 | |
| 			}
 | |
| 			t.mu.Unlock()
 | |
| 		case <-t.close:
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
// ConnectorPool indexes tunnels by the string form of their tunnel ID.
type ConnectorPool struct {
	node    string             // node name handed to every tunnel created by Add
	sd      sd.SD              // service-discovery backend handed to new tunnels via WithSD
	tunnels map[string]*Tunnel // keyed by relay.TunnelID.String(); accessed under mu
	mu      sync.RWMutex
}
 | |
| 
 | |
| func NewConnectorPool(node string, sd sd.SD) *ConnectorPool {
 | |
| 	p := &ConnectorPool{
 | |
| 		node:    node,
 | |
| 		sd:      sd,
 | |
| 		tunnels: make(map[string]*Tunnel),
 | |
| 	}
 | |
| 	go p.closeIdles()
 | |
| 	return p
 | |
| }
 | |
| 
 | |
| func (p *ConnectorPool) Add(tid relay.TunnelID, c *Connector, ttl time.Duration) {
 | |
| 	p.mu.Lock()
 | |
| 	defer p.mu.Unlock()
 | |
| 
 | |
| 	s := tid.String()
 | |
| 
 | |
| 	t := p.tunnels[s]
 | |
| 	if t == nil {
 | |
| 		t = NewTunnel(p.node, tid, ttl)
 | |
| 		t.WithSD(p.sd)
 | |
| 
 | |
| 		p.tunnels[s] = t
 | |
| 	}
 | |
| 	t.AddConnector(c)
 | |
| }
 | |
| 
 | |
| func (p *ConnectorPool) Get(network string, tid string) *Connector {
 | |
| 	if p == nil {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	p.mu.RLock()
 | |
| 	defer p.mu.RUnlock()
 | |
| 
 | |
| 	t := p.tunnels[tid]
 | |
| 	if t == nil {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	return t.GetConnector(network)
 | |
| }
 | |
| 
 | |
| func (p *ConnectorPool) closeIdles() {
 | |
| 	ticker := time.NewTicker(1 * time.Hour)
 | |
| 	defer ticker.Stop()
 | |
| 
 | |
| 	for range ticker.C {
 | |
| 		p.mu.Lock()
 | |
| 		for k, v := range p.tunnels {
 | |
| 			if v.CloseOnIdle() {
 | |
| 				delete(p.tunnels, k)
 | |
| 				logger.Default().Debugf("remove idle tunnel: %s", k)
 | |
| 			}
 | |
| 		}
 | |
| 		p.mu.Unlock()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func parseTunnelID(s string) (tid relay.TunnelID) {
 | |
| 	if s == "" {
 | |
| 		return
 | |
| 	}
 | |
| 	private := false
 | |
| 	if s[0] == '$' {
 | |
| 		private = true
 | |
| 		s = s[1:]
 | |
| 	}
 | |
| 	uuid, _ := uuid.Parse(s)
 | |
| 
 | |
| 	if private {
 | |
| 		return relay.NewPrivateTunnelID(uuid[:])
 | |
| 	}
 | |
| 	return relay.NewTunnelID(uuid[:])
 | |
| }
 |