diff --git a/admission/plugin.go b/admission/plugin.go index 91ea9d9..6942ab0 100644 --- a/admission/plugin.go +++ b/admission/plugin.go @@ -115,7 +115,7 @@ func (p *httpPlugin) Admit(ctx context.Context, addr string, opts ...admission.O return } - req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) if err != nil { return } diff --git a/auth/plugin.go b/auth/plugin.go index d3246f5..72fe797 100644 --- a/auth/plugin.go +++ b/auth/plugin.go @@ -125,7 +125,7 @@ func (p *httpPlugin) Authenticate(ctx context.Context, user, password string, op return } - req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) if err != nil { return } diff --git a/bypass/plugin.go b/bypass/plugin.go index b19ce6a..4e20703 100644 --- a/bypass/plugin.go +++ b/bypass/plugin.go @@ -138,7 +138,7 @@ func (p *httpPlugin) Contains(ctx context.Context, network, addr string, opts .. return } - req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) if err != nil { return } diff --git a/go.mod b/go.mod index d7f97b5..6e4661f 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.9.1 - github.com/go-gost/core v0.0.0-20231020111249-6431cd8bb957 + github.com/go-gost/core v0.0.0-20231027140845-d975ec3c7477 github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.4.0 github.com/go-gost/plugin v0.0.0-20231020155519-e190e1c74d78 diff --git a/go.sum b/go.sum index d150015..ee210c5 100644 --- a/go.sum +++ b/go.sum @@ -91,8 +91,10 @@ github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SU github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gost/core v0.0.0-20231020111249-6431cd8bb957 h1:Ch7m/rplsCHjpGuOzgXV+OrXmGxNa/UVLUGV2yUFGhQ= -github.com/go-gost/core v0.0.0-20231020111249-6431cd8bb957/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= +github.com/go-gost/core v0.0.0-20231026142046-9e767d674527 h1:BLhpnK+J9A3vugXCJrC+BNjz2Q4qdEE8IWIlWr7VOaw= +github.com/go-gost/core v0.0.0-20231026142046-9e767d674527/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= +github.com/go-gost/core v0.0.0-20231027140845-d975ec3c7477 h1:a49XfrB4mgbw7z7oN/WTovx0X7SbxdfoANsEDTy9CqI= +github.com/go-gost/core v0.0.0-20231027140845-d975ec3c7477/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.4.0 h1:EIrOEkpJez4gwHrMa33frA+hHXJyevjp47thpMQsJzI= diff --git a/handler/tunnel/bind.go b/handler/tunnel/bind.go index fc0d6dc..6b0eb01 100644 --- a/handler/tunnel/bind.go +++ b/handler/tunnel/bind.go @@ -4,9 +4,11 @@ import ( "context" "crypto/md5" "encoding/hex" + "fmt" "net" "github.com/go-gost/core/logger" + "github.com/go-gost/core/recorder" "github.com/go-gost/relay" "github.com/go-gost/x/internal/util/mux" "github.com/google/uuid" @@ -57,8 +59,15 @@ func (h *tunnelHandler) handleBind(ctx context.Context, conn net.Conn, network, if h.md.ingress != nil { h.md.ingress.Set(ctx, addr, tunnelID.String()) } - if h.recorder.Recorder != nil { - h.recorder.Recorder.Record(ctx, tunnelID[:]) + if h.recorder != nil { + h.recorder.Record(ctx, + []byte(fmt.Sprintf("%s:%s", tunnelID, connectorID)), + recorder.MetadataReocrdOption(connectorMetadata{ + Op: "add", + Network: network, + Server: conn.LocalAddr().String(), + }), + ) } log.Debugf("%s/%s: tunnel=%s, connector=%s established", addr, network, tunnelID, connectorID) diff --git a/handler/tunnel/handler.go b/handler/tunnel/handler.go index 4ebfe52..ad8d491 100644 --- a/handler/tunnel/handler.go +++ b/handler/tunnel/handler.go @@ -39,7 +39,7 @@ type tunnelHandler struct { md metadata options handler.Options pool *ConnectorPool - recorder recorder.RecorderObject + recorder recorder.Recorder epSvc service.Service ep *entrypoint } @@ -52,7 +52,6 @@ func NewHandler(opts ...handler.Option) handler.Handler { return &tunnelHandler{ options: options, - pool: NewConnectorPool(), } } @@ -68,12 +67,14 @@ func (h *tunnelHandler) Init(md md.Metadata) (err error) { if opts := h.router.Options(); opts != nil { for _, ro := range opts.Recorders { - if ro.Record == xrecorder.RecorderServiceHandlerTunnelEndpoint { - h.recorder = ro + if ro.Record == xrecorder.RecorderServiceHandlerTunnelConnector { + h.recorder = ro.Recorder break } } } + h.pool = NewConnectorPool() + h.pool.WithRecorder(h.recorder) h.ep = &entrypoint{ pool: h.pool, diff --git a/handler/tunnel/tunnel.go b/handler/tunnel/tunnel.go index a45392a..134db6d 100644 --- a/handler/tunnel/tunnel.go +++ b/handler/tunnel/tunnel.go @@ -1,6 +1,7 @@ package tunnel import ( + "context" "fmt" "net" "sync" @@ -8,11 +9,18 @@ import ( "time" "github.com/go-gost/core/logger" + "github.com/go-gost/core/recorder" "github.com/go-gost/relay" "github.com/go-gost/x/internal/util/mux" "github.com/google/uuid" ) +type connectorMetadata struct { + Op string + Network string + Server string +} + type Connector struct { id relay.ConnectorID t time.Time @@ -54,18 +62,25 @@ type Tunnel struct { connectors []*Connector t time.Time n uint64 + close chan struct{} mu sync.RWMutex + recorder recorder.Recorder } func NewTunnel(id relay.TunnelID) *Tunnel { t := &Tunnel{ - id: id, - t: time.Now(), + id: id, + t: time.Now(), + close: make(chan struct{}), } go t.clean() return t } +func (t *Tunnel) WithRecorder(recorder recorder.Recorder) { + t.recorder = recorder +} + func (t *Tunnel) ID() relay.TunnelID { return t.id } @@ -87,6 +102,9 @@ func (t *Tunnel) GetConnector(network string) *Connector { var connectors []*Connector for _, c := range t.connectors { + if c.Session().IsClosed() { + continue + } if network == "udp" && c.id.IsUDP() || network != "udp" && !c.id.IsUDP() { connectors = append(connectors, c) @@ -99,34 +117,83 @@ func (t *Tunnel) GetConnector(network string) *Connector { return connectors[n%uint64(len(connectors))] } +func (t *Tunnel) CloseOnIdle() bool { + t.mu.RLock() + defer t.mu.RUnlock() + + select { + case <-t.close: + default: + if len(t.connectors) == 0 { + close(t.close) + return true + } + } + return false +} + func (t *Tunnel) clean() { - ticker := time.NewTicker(30 * time.Second) - for range ticker.C { - t.mu.Lock() - var connectors []*Connector - for _, c := range t.connectors { - if c.Session().IsClosed() { - logger.Default().Debugf("remove tunnel %s connector %s", t.id, c.id) - continue + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + t.mu.Lock() + if len(t.connectors) == 0 { + t.mu.Unlock() } - connectors = append(connectors, c) + var connectors []*Connector + for _, c := range t.connectors { + if c.Session().IsClosed() { + logger.Default().Debugf("remove tunnel: %s, connector: %s", t.id, c.id) + if t.recorder != nil { + t.recorder.Record(context.Background(), + []byte(fmt.Sprintf("%s:%s", t.id, c.id)), + recorder.MetadataReocrdOption(connectorMetadata{ + Op: "del", + }), + ) + } + continue + } + + connectors = append(connectors, c) + if t.recorder != nil { + t.recorder.Record(context.Background(), + []byte(fmt.Sprintf("%s:%s", t.id, c.id)), + recorder.MetadataReocrdOption(connectorMetadata{ + Op: "set", + }), + ) + } + } + if len(connectors) != len(t.connectors) { + t.connectors = connectors + } + t.mu.Unlock() + case <-t.close: + return } - if len(connectors) != len(t.connectors) { - t.connectors = connectors - } - t.mu.Unlock() } } type ConnectorPool struct { - tunnels map[string]*Tunnel - mu sync.RWMutex + tunnels map[string]*Tunnel + mu sync.RWMutex + recorder recorder.Recorder } func NewConnectorPool() *ConnectorPool { - return &ConnectorPool{ + p := &ConnectorPool{ tunnels: make(map[string]*Tunnel), } + go p.closeIdles() + return p +} + +func (p *ConnectorPool) WithRecorder(recorder recorder.Recorder) { + p.recorder = recorder } func (p *ConnectorPool) Add(tid relay.TunnelID, c *Connector) { @@ -138,6 +205,8 @@ func (p *ConnectorPool) Add(tid relay.TunnelID, c *Connector) { t := p.tunnels[s] if t == nil { t = NewTunnel(tid) + t.WithRecorder(p.recorder) + p.tunnels[s] = t } t.AddConnector(c) @@ -159,6 +228,22 @@ func (p *ConnectorPool) Get(network string, tid relay.TunnelID) *Connector { return t.GetConnector(network) } +func (p *ConnectorPool) closeIdles() { + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + + for range ticker.C { + p.mu.Lock() + for k, v := range p.tunnels { + if v.CloseOnIdle() { + delete(p.tunnels, k) + logger.Default().Debugf("remove idle tunnel: %s", k) + } + } + p.mu.Unlock() + } +} + func parseTunnelID(s string) (tid relay.TunnelID) { if s == "" { return diff --git a/hop/plugin.go b/hop/plugin.go index 65e09f9..882ca18 100644 --- a/hop/plugin.go +++ b/hop/plugin.go @@ -159,7 +159,7 @@ func (p *httpPlugin) Select(ctx context.Context, opts ...hop.SelectOption) *chai return nil } - req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) if err != nil { p.log.Error(err) return nil diff --git a/hosts/plugin.go b/hosts/plugin.go index 1752d28..07f15d6 100644 --- a/hosts/plugin.go +++ b/hosts/plugin.go @@ -11,8 +11,8 @@ import ( "github.com/go-gost/core/hosts" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/hosts/proto" - auth_util "github.com/go-gost/x/internal/util/auth" "github.com/go-gost/x/internal/plugin" + auth_util "github.com/go-gost/x/internal/util/auth" "google.golang.org/grpc" ) @@ -133,7 +133,7 @@ func (p *httpPlugin) Lookup(ctx context.Context, network, host string, opts ...h return } - req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) if err != nil { return } diff --git a/ingress/plugin.go b/ingress/plugin.go index b41a0ae..7097a23 100644 --- a/ingress/plugin.go +++ b/ingress/plugin.go @@ -134,7 +134,7 @@ func (p *httpPlugin) Get(ctx context.Context, host string, opts ...ingress.GetOp return } - req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) if err != nil { return } @@ -174,7 +174,7 @@ func (p *httpPlugin) Set(ctx context.Context, host, endpoint string, opts ...ing return } - req, err := http.NewRequest(http.MethodPut, p.url, bytes.NewReader(v)) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, p.url, bytes.NewReader(v)) if err != nil { return } diff --git a/recorder/plugin.go b/recorder/plugin.go index 9b74327..a0b6a5d 100644 --- a/recorder/plugin.go +++ b/recorder/plugin.go @@ -53,9 +53,17 @@ func (p *grpcPlugin) Record(ctx context.Context, b []byte, opts ...recorder.Reco return nil } - _, err := p.client.Record(context.Background(), + var options recorder.RecordOptions + for _, opt := range opts { + opt(&options) + } + + md, _ := json.Marshal(options.Metadata) + + _, err := p.client.Record(ctx, &proto.RecordRequest{ - Data: b, + Data: b, + Metadata: md, }) if err != nil { p.log.Error(err) @@ -72,7 +80,8 @@ func (p *grpcPlugin) Close() error { } type httpPluginRequest struct { - Data []byte `json:"data"` + Data []byte `json:"data"` + Metadata []byte `json:"metadata"` } type httpPluginResponse struct { @@ -109,15 +118,23 @@ func (p *httpPlugin) Record(ctx context.Context, b []byte, opts ...recorder.Reco return nil } + var options recorder.RecordOptions + for _, opt := range opts { + opt(&options) + } + + md, _ := json.Marshal(options.Metadata) + rb := httpPluginRequest{ - Data: b, + Data: b, + Metadata: md, } v, err := json.Marshal(&rb) if err != nil { return err } - req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) if err != nil { return err } diff --git a/recorder/recorder.go b/recorder/recorder.go index 857a608..9d1eb8b 100644 --- a/recorder/recorder.go +++ b/recorder/recorder.go @@ -1,6 +1,6 @@ package recorder const ( - RecorderServiceHandlerSerial = "recorder.service.handler.serial" - RecorderServiceHandlerTunnelEndpoint = "recorder.service.handler.tunnel.endpoint" + RecorderServiceHandlerSerial = "recorder.service.handler.serial" + RecorderServiceHandlerTunnelConnector = "recorder.service.handler.tunnel.connector" ) diff --git a/resolver/plugin.go b/resolver/plugin.go index dacb293..1c5d47f 100644 --- a/resolver/plugin.go +++ b/resolver/plugin.go @@ -56,7 +56,7 @@ func (p *grpcPlugin) Resolve(ctx context.Context, network, host string, opts ... return } - r, err := p.client.Resolve(context.Background(), + r, err := p.client.Resolve(ctx, &proto.ResolveRequest{ Network: network, Host: host, @@ -134,7 +134,7 @@ func (p *httpPlugin) Resolve(ctx context.Context, network, host string, opts ... return } - req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) if err != nil { return }