From 24037aba7bb5ba1336fb3bc93ef4434d2951d150 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sun, 29 Jan 2023 11:55:26 +0800 Subject: [PATCH] relay: added private tunnel --- connector/relay/bind.go | 4 +- connector/relay/connector.go | 10 ++-- connector/relay/metadata.go | 2 +- go.mod | 2 +- go.sum | 4 +- handler/forward/local/handler.go | 25 +++++----- handler/relay/bind.go | 3 +- handler/relay/connect.go | 83 +++++++++++++++++++++++++++++++- handler/relay/entrypoint.go | 27 +++-------- handler/relay/handler.go | 19 +++++--- handler/relay/tunnel.go | 54 +++++++++++++++++++-- 11 files changed, 174 insertions(+), 59 deletions(-) diff --git a/connector/relay/bind.go b/connector/relay/bind.go index f99d33f..1980a53 100644 --- a/connector/relay/bind.go +++ b/connector/relay/bind.go @@ -48,7 +48,7 @@ func (c *relayConnector) tunnel(ctx context.Context, conn net.Conn, log logger.L if err != nil { return nil, err } - log.Debugf("create tunnel %s connector %s OK", c.md.tunnelID, cid) + log.Debugf("create tunnel %s connector %s OK", c.md.tunnelID.String(), cid) session, err := mux.ServerSession(conn) if err != nil { @@ -90,7 +90,7 @@ func (c *relayConnector) initTunnel(conn net.Conn) (addr net.Addr, cid relay.Con } if resp.Status != relay.StatusOK { - err = fmt.Errorf("%d: create tunnel %s failed", resp.Status, c.md.tunnelID) + err = fmt.Errorf("%d: create tunnel %s failed", resp.Status, c.md.tunnelID.String()) return } diff --git a/connector/relay/connector.go b/connector/relay/connector.go index 1e6f540..f752d5c 100644 --- a/connector/relay/connector.go +++ b/connector/relay/connector.go @@ -83,11 +83,13 @@ func (c *relayConnector) Connect(ctx context.Context, conn net.Conn, network, ad if err := af.ParseFrom(address); err != nil { return nil, err } + req.Features = append(req.Features, af) + } - // forward mode if port is 0. - if af.Port > 0 { - req.Features = append(req.Features, af) - } + if !c.md.tunnelID.IsZero() { + req.Features = append(req.Features, &relay.TunnelFeature{ + ID: c.md.tunnelID, + }) } if c.md.noDelay { diff --git a/connector/relay/metadata.go b/connector/relay/metadata.go index 910fd3d..29fc927 100644 --- a/connector/relay/metadata.go +++ b/connector/relay/metadata.go @@ -30,7 +30,7 @@ func (c *relayConnector) parseMetadata(md mdata.Metadata) (err error) { if err != nil { return err } - copy(c.md.tunnelID[:], uuid[:]) + c.md.tunnelID = relay.NewTunnelID(uuid[:]) } return diff --git a/go.mod b/go.mod index 603f587..1145672 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/go-gost/core v0.0.0-20230114050924-1a8c1ccb1dc5 github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 - github.com/go-gost/relay v0.2.0 + github.com/go-gost/relay v0.3.1 github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 github.com/go-redis/redis/v8 v8.11.5 github.com/gobwas/glob v0.2.3 diff --git a/go.sum b/go.sum index e6e38a0..c1b497e 100644 --- a/go.sum +++ b/go.sum @@ -97,8 +97,8 @@ github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2 github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= -github.com/go-gost/relay v0.2.0 h1:8udTweykgDUdOY1j1U90fApNuG7Sp7pvKoiIp3eV6ME= -github.com/go-gost/relay v0.2.0/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= +github.com/go-gost/relay v0.3.1 h1:mkKtvMT5n3mTSHbQo//DXXLxTsIUJKRQ4Fn4atma+Ds= +github.com/go-gost/relay v0.3.1/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 h1:xj8gUZGYO3nb5+6Bjw9+tsFkA9sYynrOvDvvC4uDV2I= github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451/go.mod h1:/9QfdewqmHdaE362Hv5nDaSWLx3pCmtD870d6GaquXs= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index 63ca7e8..0c87273 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -11,7 +11,6 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/handler" md "github.com/go-gost/core/metadata" - xchain "github.com/go-gost/x/chain" netpkg "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/util/forward" "github.com/go-gost/x/registry" @@ -46,13 +45,6 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { return } - if h.hop == nil { - // dummy node used by relay connector. - h.hop = xchain.NewChainHop([]*chain.Node{ - {Name: "dummy", Addr: ":0"}, - }) - } - h.router = h.options.Router if h.router == nil { h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) @@ -101,17 +93,22 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand } } - var target *chain.Node + if _, _, err := net.SplitHostPort(host); err != nil { + host = net.JoinHostPort(host, "0") + } + target := &chain.Node{ + Addr: host, + } if h.hop != nil { target = h.hop.Select(ctx, chain.HostSelectOption(host), chain.ProtocolSelectOption(protocol), ) - } - if target == nil { - err := errors.New("target not available") - log.Error(err) - return err + if target == nil { + err := errors.New("target not available") + log.Error(err) + return err + } } log = log.WithFields(map[string]any{ diff --git a/handler/relay/bind.go b/handler/relay/bind.go index 1b03466..bde6f9f 100644 --- a/handler/relay/bind.go +++ b/handler/relay/bind.go @@ -212,8 +212,7 @@ func (h *relayHandler) handleTunnel(ctx context.Context, conn net.Conn, tunnelID return } - var connectorID relay.ConnectorID - copy(connectorID[:], uuid[:]) + connectorID := relay.NewTunnelID(uuid[:]) af := &relay.AddrFeature{} err = af.ParseFrom(h.ep.Addr().String()) diff --git a/handler/relay/connect.go b/handler/relay/connect.go index 45f82c6..5d8eb53 100644 --- a/handler/relay/connect.go +++ b/handler/relay/connect.go @@ -10,6 +10,7 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/relay" netpkg "github.com/go-gost/x/internal/net" + xnet "github.com/go-gost/x/internal/net" sx "github.com/go-gost/x/internal/util/selector" ) @@ -19,7 +20,7 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network "cmd": "connect", }) - log.Debugf("%s >> %s", conn.RemoteAddr(), address) + log.Debugf("%s >> %s/%s", conn.RemoteAddr(), address, network) resp := relay.Response{ Version: relay.Version1, @@ -94,3 +95,83 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network return nil } + +func (h *relayHandler) handleConnectTunnel(ctx context.Context, conn net.Conn, network, address string, tunnelID relay.TunnelID, log logger.Logger) error { + log = log.WithFields(map[string]any{ + "dst": fmt.Sprintf("%s/%s", address, network), + "cmd": "connect", + "tunnel": tunnelID.String(), + }) + + log.Debugf("%s >> %s/%s", conn.RemoteAddr(), address, network) + + resp := relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + } + + host, _, _ := net.SplitHostPort(address) + + if h.options.Bypass != nil && h.options.Bypass.Contains(address) { + log.Debug("bypass: ", address) + resp.Status = relay.StatusForbidden + _, err := resp.WriteTo(conn) + return err + } + + var tid relay.TunnelID + if ingress := h.md.ingress; ingress != nil { + tid = parseTunnelID(ingress.Get(host)) + } + + if !tid.Equal(tunnelID) { + resp.Status = relay.StatusBadRequest + resp.WriteTo(conn) + err := fmt.Errorf("tunnel %s not found", tunnelID.String()) + log.Error(err) + return err + } + + cc, err := getTunnelConn(h.pool, tunnelID, 3, log) + if err != nil { + log.Error(err) + return err + } + defer cc.Close() + + log.Debugf("%s >> %s", conn.RemoteAddr(), cc.RemoteAddr()) + + if h.md.noDelay { + if _, err := resp.WriteTo(conn); err != nil { + log.Error(err) + return err + } + } else { + rc := &tcpConn{ + Conn: conn, + } + // cache the header + if _, err := resp.WriteTo(&rc.wbuf); err != nil { + return err + } + conn = rc + } + + af := &relay.AddrFeature{} + af.ParseFrom(conn.RemoteAddr().String()) + resp = relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + Features: []relay.Feature{af}, + } + resp.WriteTo(cc) + + t := time.Now() + log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) + xnet.Transport(conn, cc) + log.WithFields(map[string]any{ + "duration": time.Since(t), + }).Debugf("%s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) + + return nil +} diff --git a/handler/relay/entrypoint.go b/handler/relay/entrypoint.go index 22e677f..a8e6e31 100644 --- a/handler/relay/entrypoint.go +++ b/handler/relay/entrypoint.go @@ -19,7 +19,6 @@ import ( climiter "github.com/go-gost/x/limiter/conn/wrapper" limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" - "github.com/google/uuid" ) type epListener struct { @@ -118,30 +117,18 @@ func (h *epHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.H var tunnelID relay.TunnelID if h.ingress != nil { - v := h.ingress.Get(host) - uuid, _ := uuid.Parse(v) - copy(tunnelID[:], uuid[:]) + tunnelID = parseTunnelID(h.ingress.Get(host)) + } + if tunnelID.IsPrivate() { + err := fmt.Errorf("tunnel %s is private", tunnelID) + log.Error(err) + return err } log = log.WithFields(map[string]any{ "tunnel": tunnelID.String(), }) - var cc net.Conn - var err error - for i := 0; i < 3; i++ { - c := h.pool.Get(tunnelID) - if c == nil { - err = fmt.Errorf("tunnel %s not available", tunnelID.String()) - break - } - - cc, err = c.Session().GetConn() - if err != nil { - log.Error(err) - continue - } - break - } + cc, err := getTunnelConn(h.pool, tunnelID, 3, log) if err != nil { log.Error(err) return err diff --git a/handler/relay/handler.go b/handler/relay/handler.go index ec361ac..7540ed9 100644 --- a/handler/relay/handler.go +++ b/handler/relay/handler.go @@ -195,18 +195,23 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle } if h.hop != nil { - if address != "" { - resp.Status = relay.StatusForbidden - log.Error("forward mode, connect is forbidden") - _, err := resp.WriteTo(conn) - return err - } + /* + if address != "" { + resp.Status = relay.StatusForbidden + log.Error("forward mode, CONNECT method is forbidden") + _, err := resp.WriteTo(conn) + return err + } + */ // forward mode return h.handleForward(ctx, conn, network, log) } switch req.Cmd & relay.CmdMask { - case 0, relay.CmdConnect: + case relay.CmdConnect: + if !tunnelID.IsZero() { + return h.handleConnectTunnel(ctx, conn, network, address, tunnelID, log) + } return h.handleConnect(ctx, conn, network, address, log) case relay.CmdBind: if !tunnelID.IsZero() { diff --git a/handler/relay/tunnel.go b/handler/relay/tunnel.go index 27e9745..3a37a15 100644 --- a/handler/relay/tunnel.go +++ b/handler/relay/tunnel.go @@ -1,6 +1,8 @@ package relay import ( + "fmt" + "net" "sync" "sync/atomic" "time" @@ -8,6 +10,7 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/relay" "github.com/go-gost/x/internal/util/mux" + "github.com/google/uuid" ) type Connector struct { @@ -110,13 +113,13 @@ func (t *Tunnel) clean() { } type ConnectorPool struct { - tunnels map[relay.TunnelID]*Tunnel + tunnels map[string]*Tunnel mu sync.RWMutex } func NewConnectorPool() *ConnectorPool { return &ConnectorPool{ - tunnels: make(map[relay.TunnelID]*Tunnel), + tunnels: make(map[string]*Tunnel), } } @@ -124,10 +127,12 @@ func (p *ConnectorPool) Add(tid relay.TunnelID, c *Connector) { p.mu.Lock() defer p.mu.Unlock() - t := p.tunnels[tid] + s := tid.String() + + t := p.tunnels[s] if t == nil { t = NewTunnel(tid) - p.tunnels[tid] = t + p.tunnels[s] = t } t.AddConnector(c) } @@ -140,10 +145,49 @@ func (p *ConnectorPool) Get(tid relay.TunnelID) *Connector { p.mu.RLock() defer p.mu.RUnlock() - t := p.tunnels[tid] + t := p.tunnels[tid.String()] if t == nil { return nil } return t.GetConnector() } + +func parseTunnelID(s string) (tid relay.TunnelID) { + if s == "" { + return + } + private := false + if s[0] == '$' { + private = true + s = s[1:] + } + uuid, _ := uuid.Parse(s) + + if private { + return relay.NewPrivateTunnelID(uuid[:]) + } + return relay.NewTunnelID(uuid[:]) +} + +func getTunnelConn(pool *ConnectorPool, tunnelID relay.TunnelID, retry int, log logger.Logger) (conn net.Conn, err error) { + if retry <= 0 { + retry = 1 + } + for i := 0; i < retry; i++ { + c := pool.Get(tunnelID) + if c == nil { + err = fmt.Errorf("tunnel %s not available", tunnelID.String()) + break + } + + conn, err = c.Session().GetConn() + if err != nil { + log.Error(err) + continue + } + break + } + + return +}