diff --git a/handler/tunnel/bind.go b/handler/tunnel/bind.go index d0ac0a8..1099e07 100644 --- a/handler/tunnel/bind.go +++ b/handler/tunnel/bind.go @@ -54,15 +54,15 @@ func (h *tunnelHandler) handleBind(ctx context.Context, conn net.Conn, network, return } - h.pool.Add(tunnelID, NewConnector(connectorID, session)) + h.pool.Add(tunnelID, NewConnector(connectorID, tunnelID, h.id, session, h.md.sd), h.md.tunnelTTL) if h.md.ingress != nil { h.md.ingress.Set(ctx, addr, tunnelID.String()) } if h.md.sd != nil { err := h.md.sd.Register(ctx, &sd.Service{ - ID: connectorID.String(), - Name: tunnelID.String(), - Node: h.id, + ID: connectorID.String(), + Name: tunnelID.String(), + Node: h.id, Network: network, Address: h.md.entryPoint, }) diff --git a/handler/tunnel/dialer.go b/handler/tunnel/dialer.go index b8e14b3..34d9e76 100644 --- a/handler/tunnel/dialer.go +++ b/handler/tunnel/dialer.go @@ -20,7 +20,9 @@ type Dialer struct { func (d *Dialer) Dial(ctx context.Context, network string, tid string) (conn net.Conn, node string, cid string, err error) { retry := d.retry - retry = 1 + if retry <= 0 { + retry = 1 + } for i := 0; i < retry; i++ { c := d.pool.Get(network, tid) diff --git a/handler/tunnel/metadata.go b/handler/tunnel/metadata.go index 67d3040..62ca919 100644 --- a/handler/tunnel/metadata.go +++ b/handler/tunnel/metadata.go @@ -15,12 +15,17 @@ import ( "github.com/go-gost/x/registry" ) +const ( + defaultTTL = 15 * time.Second +) + type metadata struct { readTimeout time.Duration entryPoint string entryPointID relay.TunnelID entryPointProxyProtocol int directTunnel bool + tunnelTTL time.Duration ingress ingress.Ingress sd sd.SD muxCfg *mux.Config @@ -29,6 +34,10 @@ type metadata struct { func (h *tunnelHandler) parseMetadata(md mdata.Metadata) (err error) { h.md.readTimeout = mdutil.GetDuration(md, "readTimeout") + h.md.tunnelTTL = mdutil.GetDuration(md, "tunnel.ttl") + if h.md.tunnelTTL <= 0 { + h.md.tunnelTTL = defaultTTL + } h.md.directTunnel = mdutil.GetBool(md, "tunnel.direct") h.md.entryPoint = mdutil.GetString(md, "entrypoint") h.md.entryPointID = parseTunnelID(mdutil.GetString(md, "entrypoint.id")) diff --git a/handler/tunnel/tunnel.go b/handler/tunnel/tunnel.go index e12d1fd..b71f6eb 100644 --- a/handler/tunnel/tunnel.go +++ b/handler/tunnel/tunnel.go @@ -14,16 +14,22 @@ import ( ) type Connector struct { - id relay.ConnectorID - t time.Time - s *mux.Session + id relay.ConnectorID + tid relay.TunnelID + node string + sd sd.SD + t time.Time + s *mux.Session } -func NewConnector(id relay.ConnectorID, s *mux.Session) *Connector { +func NewConnector(id relay.ConnectorID, tid relay.TunnelID, node string, s *mux.Session, sd sd.SD) *Connector { c := &Connector{ - id: id, - t: time.Now(), - s: s, + id: id, + tid: tid, + node: node, + sd: sd, + t: time.Now(), + s: s, } go c.accept() return c @@ -35,6 +41,13 @@ func (c *Connector) accept() { if err != nil { logger.Default().Errorf("connector %s: %v", c.id, err) c.s.Close() + if c.sd != nil { + c.sd.Deregister(context.Background(), &sd.Service{ + ID: c.id.String(), + Name: c.tid.String(), + Node: c.node, + }) + } return } conn.Close() @@ -58,14 +71,19 @@ type Tunnel struct { close chan struct{} mu sync.RWMutex sd sd.SD + ttl time.Duration } -func NewTunnel(node string, tid relay.TunnelID) *Tunnel { +func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel { t := &Tunnel{ node: node, id: tid, t: time.Now(), close: make(chan struct{}), + ttl: ttl, + } + if t.ttl <= 0 { + t.ttl = defaultTTL } go t.clean() return t @@ -127,7 +145,7 @@ func (t *Tunnel) CloseOnIdle() bool { } func (t *Tunnel) clean() { - ticker := time.NewTicker(1 * time.Minute) + ticker := time.NewTicker(t.ttl) defer ticker.Stop() for { @@ -188,7 +206,7 @@ func NewConnectorPool(node string, sd sd.SD) *ConnectorPool { return p } -func (p *ConnectorPool) Add(tid relay.TunnelID, c *Connector) { +func (p *ConnectorPool) Add(tid relay.TunnelID, c *Connector, ttl time.Duration) { p.mu.Lock() defer p.mu.Unlock() @@ -196,7 +214,7 @@ func (p *ConnectorPool) Add(tid relay.TunnelID, c *Connector) { t := p.tunnels[s] if t == nil { - t = NewTunnel(p.node, tid) + t = NewTunnel(p.node, tid, ttl) t.WithSD(p.sd) p.tunnels[s] = t diff --git a/sd/plugin.go b/sd/plugin.go index 7795c4d..e41b16b 100644 --- a/sd/plugin.go +++ b/sd/plugin.go @@ -146,7 +146,7 @@ type sdService struct { } type httpGetResponse struct { - Services []*sdService + Services []*sdService `json:"services"` } type httpPlugin struct { @@ -327,8 +327,9 @@ func (p *httpPlugin) Get(ctx context.Context, name string) (services []*sd.Servi continue } services = append(services, &sd.Service{ - Node: v.Node, + ID: v.ID, Name: v.Name, + Node: v.Node, Network: v.Network, Address: v.Address, })