From 6a36ebcc9fdeff98aab65480409fd5a2a6539e9d Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sun, 8 Oct 2023 19:49:03 +0800 Subject: [PATCH] add proxy protocol support for relay tunnel and rtcp --- connector/relay/listener.go | 1 + go.mod | 2 +- go.sum | 4 +-- handler/forward/remote/handler.go | 14 +++++++++-- handler/forward/remote/metadata.go | 7 ++++-- handler/relay/connect.go | 20 +++++++++------ handler/relay/entrypoint.go | 28 ++++++++++++++++++--- handler/relay/handler.go | 1 + handler/relay/metadata.go | 39 ++++++++++++++++++------------ 9 files changed, 83 insertions(+), 33 deletions(-) diff --git a/connector/relay/listener.go b/connector/relay/listener.go index 3c6f8a6..f29e1fa 100644 --- a/connector/relay/listener.go +++ b/connector/relay/listener.go @@ -48,6 +48,7 @@ func (p *bindListener) getPeerConn(conn net.Conn) (net.Conn, error) { } var address, host string + // the first addr is the client address, the optional second addr is the target host address. for _, f := range resp.Features { if f.Type() == relay.FeatureAddr { if fa, ok := f.(*relay.AddrFeature); ok { diff --git a/go.mod b/go.mod index 6f76cd1..830ec49 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/miekg/dns v1.1.56 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pion/dtls/v2 v2.2.6 - github.com/pires/go-proxyproto v0.6.2 + github.com/pires/go-proxyproto v0.7.0 github.com/prometheus/client_golang v1.14.0 github.com/quic-go/quic-go v0.38.1 github.com/rs/xid v1.3.0 diff --git a/go.sum b/go.sum index 99af399..7fe0bad 100644 --- a/go.sum +++ b/go.sum @@ -300,8 +300,8 @@ github.com/pion/transport/v2 v2.0.2 h1:St+8o+1PEzPT51O9bv+tH/KYYLMNR5Vwm5Z3Qkjsy github.com/pion/transport/v2 v2.0.2/go.mod h1:vrz6bUbFr/cjdwbnxq8OdDDzHf7JJfGsIRkxfpZoTA0= github.com/pion/udp/v2 v2.0.1 h1:xP0z6WNux1zWEjhC7onRA3EwwSliXqu1ElUZAQhUP54= github.com/pion/udp/v2 v2.0.1/go.mod h1:B7uvTMP00lzWdyMr/1PVZXtV3wpPIxBRd4Wl6AksXn8= -github.com/pires/go-proxyproto v0.6.2 h1:KAZ7UteSOt6urjme6ZldyFm4wDe/z0ZUP0Yv0Dos0d8= -github.com/pires/go-proxyproto v0.6.2/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY= +github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs= +github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index 0a85fda..d947198 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -20,6 +20,7 @@ import ( mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" xnet "github.com/go-gost/x/internal/net" + "github.com/go-gost/x/internal/net/proxyproto" auth_util "github.com/go-gost/x/internal/util/auth" "github.com/go-gost/x/internal/util/forward" "github.com/go-gost/x/registry" @@ -107,7 +108,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand } } if protocol == forward.ProtoHTTP { - h.handleHTTP(ctx, rw, log) + h.handleHTTP(ctx, rw, conn.RemoteAddr(), conn.LocalAddr(), log) return nil } @@ -157,6 +158,13 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand marker.Reset() } + if dst, ok := conn.LocalAddr().(*net.TCPAddr); ok { + if dst.IP.Equal(net.IPv6zero) { + dst.IP = net.IPv4zero + } + } + cc = proxyproto.WrapClientConn(h.md.proxyProtocol, conn.RemoteAddr(), conn.LocalAddr(), cc) + t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) xnet.Transport(rw, cc) @@ -167,7 +175,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand return nil } -func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log logger.Logger) (err error) { +func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) { br := bufio.NewReader(rw) var connPool sync.Map @@ -246,6 +254,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l }) } + cc = proxyproto.WrapClientConn(h.md.proxyProtocol, remoteAddr, localAddr, cc) + connPool.Store(target, cc) log.Debugf("new connection to node %s(%s)", target.Name, target.Addr) diff --git a/handler/forward/remote/metadata.go b/handler/forward/remote/metadata.go index ea903d8..f811e31 100644 --- a/handler/forward/remote/metadata.go +++ b/handler/forward/remote/metadata.go @@ -11,16 +11,19 @@ type metadata struct { readTimeout time.Duration sniffing bool sniffingTimeout time.Duration + proxyProtocol int } func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { const ( - readTimeout = "readTimeout" - sniffing = "sniffing" + readTimeout = "readTimeout" + sniffing = "sniffing" + proxyProtocol = "proxyProtocol" ) h.md.readTimeout = mdutil.GetDuration(md, readTimeout) h.md.sniffing = mdutil.GetBool(md, sniffing) h.md.sniffingTimeout = mdutil.GetDuration(md, "sniffing.timeout") + h.md.proxyProtocol = mdutil.GetInt(md, proxyProtocol) return } diff --git a/handler/relay/connect.go b/handler/relay/connect.go index c4fe48e..52d9e6a 100644 --- a/handler/relay/connect.go +++ b/handler/relay/connect.go @@ -142,16 +142,22 @@ func (h *relayHandler) handleConnectTunnel(ctx context.Context, conn net.Conn, n if ingress := h.md.ingress; ingress != nil { tid = parseTunnelID(ingress.Get(ctx, host)) } - if !tid.Equal(tunnelID) && !h.md.directTunnel { - resp.Status = relay.StatusBadRequest - resp.WriteTo(conn) - err := fmt.Errorf("not route to host %s", host) - log.Error(err) - return err + + // client is not an public entrypoint. + if h.md.entryPointID.IsZero() || !tunnelID.Equal(h.md.entryPointID) { + if !tid.Equal(tunnelID) && !h.md.directTunnel { + resp.Status = relay.StatusHostUnreachable + resp.WriteTo(conn) + err := fmt.Errorf("no route to host %s", host) + log.Error(err) + return err + } } - cc, _, err := getTunnelConn(network, h.pool, tunnelID, 3, log) + cc, _, err := getTunnelConn(network, h.pool, tid, 3, log) if err != nil { + resp.Status = relay.StatusServiceUnavailable + resp.WriteTo(conn) log.Error(err) return err } diff --git a/handler/relay/entrypoint.go b/handler/relay/entrypoint.go index 6e872cd..d69001e 100644 --- a/handler/relay/entrypoint.go +++ b/handler/relay/entrypoint.go @@ -210,12 +210,22 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl log.Debugf("%s >> %s", conn.RemoteAddr(), cc.RemoteAddr()) + var features []relay.Feature af := &relay.AddrFeature{} - af.ParseFrom(conn.RemoteAddr().String()) + af.ParseFrom(conn.RemoteAddr().String()) // client address + features = append(features, af) + + if host != "" { + // target host + af := &relay.AddrFeature{} + af.ParseFrom(host) + features = append(features, af) + } + resp := relay.Response{ Version: relay.Version1, Status: relay.StatusOK, - Features: []relay.Feature{af}, + Features: features, } resp.WriteTo(cc) @@ -284,12 +294,24 @@ func (h *tunnelHandler) handleHTTP(ctx context.Context, raddr net.Addr, rw io.Re connPool.Store(tunnelID, cc) log.Debugf("new connection to tunnel %s(connector %s)", tunnelID, cid) + var features []relay.Feature af := &relay.AddrFeature{} af.ParseFrom(raddr.String()) + features = append(features, af) + + if host := req.Host; host != "" { + if h, _, _ := net.SplitHostPort(host); h == "" { + host = net.JoinHostPort(host, "80") + } + af := &relay.AddrFeature{} + af.ParseFrom(host) + features = append(features, af) + } + (&relay.Response{ Version: relay.Version1, Status: relay.StatusOK, - Features: []relay.Feature{af}, + Features: features, }).WriteTo(cc) go func() { diff --git a/handler/relay/handler.go b/handler/relay/handler.go index fed36df..f095720 100644 --- a/handler/relay/handler.go +++ b/handler/relay/handler.go @@ -95,6 +95,7 @@ func (h *relayHandler) initEntryPoint() (err error) { epListener := newTCPListener(ln, listener.AddrOption(h.md.entryPoint), listener.ServiceOption(serviceName), + listener.ProxyProtocolOption(h.md.entryPointProxyProtocol), listener.LoggerOption(log.WithFields(map[string]any{ "kind": "listener", })), diff --git a/handler/relay/metadata.go b/handler/relay/metadata.go index 6f40c2b..25b1817 100644 --- a/handler/relay/metadata.go +++ b/handler/relay/metadata.go @@ -9,29 +9,34 @@ import ( "github.com/go-gost/core/logger" mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" + "github.com/go-gost/relay" xingress "github.com/go-gost/x/ingress" "github.com/go-gost/x/registry" ) type metadata struct { - readTimeout time.Duration - enableBind bool - udpBufferSize int - noDelay bool - hash string - entryPoint string - ingress ingress.Ingress - directTunnel bool + readTimeout time.Duration + enableBind bool + udpBufferSize int + noDelay bool + hash string + directTunnel bool + entryPoint string + entryPointID relay.TunnelID + entryPointProxyProtocol int + ingress ingress.Ingress } func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) { const ( - readTimeout = "readTimeout" - enableBind = "bind" - udpBufferSize = "udpBufferSize" - noDelay = "nodelay" - hash = "hash" - entryPoint = "entryPoint" + readTimeout = "readTimeout" + enableBind = "bind" + udpBufferSize = "udpBufferSize" + noDelay = "nodelay" + hash = "hash" + entryPoint = "entryPoint" + entryPointID = "entryPoint.id" + entryPointProxyProtocol = "entryPoint.proxyProtocol" ) h.md.readTimeout = mdutil.GetDuration(md, readTimeout) @@ -46,10 +51,12 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) { h.md.hash = mdutil.GetString(md, hash) - h.md.entryPoint = mdutil.GetString(md, entryPoint) - h.md.ingress = registry.IngressRegistry().Get(mdutil.GetString(md, "ingress")) h.md.directTunnel = mdutil.GetBool(md, "tunnel.direct") + h.md.entryPoint = mdutil.GetString(md, entryPoint) + h.md.entryPointID = parseTunnelID(mdutil.GetString(md, entryPointID)) + h.md.entryPointProxyProtocol = mdutil.GetInt(md, entryPointProxyProtocol) + h.md.ingress = registry.IngressRegistry().Get(mdutil.GetString(md, "ingress")) if h.md.ingress == nil { var rules []xingress.Rule for _, s := range strings.Split(mdutil.GetString(md, "tunnel"), ",") {