add proxy protocol support for relay tunnel and rtcp

This commit is contained in:
ginuerzh 2023-10-08 19:49:03 +08:00
parent 07db20c9a8
commit 6a36ebcc9f
9 changed files with 83 additions and 33 deletions

View File

@ -48,6 +48,7 @@ func (p *bindListener) getPeerConn(conn net.Conn) (net.Conn, error) {
} }
var address, host string var address, host string
// the first addr is the client address, the optional second addr is the target host address.
for _, f := range resp.Features { for _, f := range resp.Features {
if f.Type() == relay.FeatureAddr { if f.Type() == relay.FeatureAddr {
if fa, ok := f.(*relay.AddrFeature); ok { if fa, ok := f.(*relay.AddrFeature); ok {

2
go.mod
View File

@ -21,7 +21,7 @@ require (
github.com/miekg/dns v1.1.56 github.com/miekg/dns v1.1.56
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/dtls/v2 v2.2.6 github.com/pion/dtls/v2 v2.2.6
github.com/pires/go-proxyproto v0.6.2 github.com/pires/go-proxyproto v0.7.0
github.com/prometheus/client_golang v1.14.0 github.com/prometheus/client_golang v1.14.0
github.com/quic-go/quic-go v0.38.1 github.com/quic-go/quic-go v0.38.1
github.com/rs/xid v1.3.0 github.com/rs/xid v1.3.0

4
go.sum
View File

@ -300,8 +300,8 @@ github.com/pion/transport/v2 v2.0.2 h1:St+8o+1PEzPT51O9bv+tH/KYYLMNR5Vwm5Z3Qkjsy
github.com/pion/transport/v2 v2.0.2/go.mod h1:vrz6bUbFr/cjdwbnxq8OdDDzHf7JJfGsIRkxfpZoTA0= github.com/pion/transport/v2 v2.0.2/go.mod h1:vrz6bUbFr/cjdwbnxq8OdDDzHf7JJfGsIRkxfpZoTA0=
github.com/pion/udp/v2 v2.0.1 h1:xP0z6WNux1zWEjhC7onRA3EwwSliXqu1ElUZAQhUP54= github.com/pion/udp/v2 v2.0.1 h1:xP0z6WNux1zWEjhC7onRA3EwwSliXqu1ElUZAQhUP54=
github.com/pion/udp/v2 v2.0.1/go.mod h1:B7uvTMP00lzWdyMr/1PVZXtV3wpPIxBRd4Wl6AksXn8= github.com/pion/udp/v2 v2.0.1/go.mod h1:B7uvTMP00lzWdyMr/1PVZXtV3wpPIxBRd4Wl6AksXn8=
github.com/pires/go-proxyproto v0.6.2 h1:KAZ7UteSOt6urjme6ZldyFm4wDe/z0ZUP0Yv0Dos0d8= github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs=
github.com/pires/go-proxyproto v0.6.2/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY= github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=

View File

@ -20,6 +20,7 @@ import (
mdata "github.com/go-gost/core/metadata" mdata "github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util" mdutil "github.com/go-gost/core/metadata/util"
xnet "github.com/go-gost/x/internal/net" xnet "github.com/go-gost/x/internal/net"
"github.com/go-gost/x/internal/net/proxyproto"
auth_util "github.com/go-gost/x/internal/util/auth" auth_util "github.com/go-gost/x/internal/util/auth"
"github.com/go-gost/x/internal/util/forward" "github.com/go-gost/x/internal/util/forward"
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
@ -107,7 +108,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
} }
} }
if protocol == forward.ProtoHTTP { if protocol == forward.ProtoHTTP {
h.handleHTTP(ctx, rw, log) h.handleHTTP(ctx, rw, conn.RemoteAddr(), conn.LocalAddr(), log)
return nil return nil
} }
@ -157,6 +158,13 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
marker.Reset() marker.Reset()
} }
if dst, ok := conn.LocalAddr().(*net.TCPAddr); ok {
if dst.IP.Equal(net.IPv6zero) {
dst.IP = net.IPv4zero
}
}
cc = proxyproto.WrapClientConn(h.md.proxyProtocol, conn.RemoteAddr(), conn.LocalAddr(), cc)
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr)
xnet.Transport(rw, cc) xnet.Transport(rw, cc)
@ -167,7 +175,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
return nil return nil
} }
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log logger.Logger) (err error) { func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) {
br := bufio.NewReader(rw) br := bufio.NewReader(rw)
var connPool sync.Map var connPool sync.Map
@ -246,6 +254,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
}) })
} }
cc = proxyproto.WrapClientConn(h.md.proxyProtocol, remoteAddr, localAddr, cc)
connPool.Store(target, cc) connPool.Store(target, cc)
log.Debugf("new connection to node %s(%s)", target.Name, target.Addr) log.Debugf("new connection to node %s(%s)", target.Name, target.Addr)

View File

@ -11,16 +11,19 @@ type metadata struct {
readTimeout time.Duration readTimeout time.Duration
sniffing bool sniffing bool
sniffingTimeout time.Duration sniffingTimeout time.Duration
proxyProtocol int
} }
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
const ( const (
readTimeout = "readTimeout" readTimeout = "readTimeout"
sniffing = "sniffing" sniffing = "sniffing"
proxyProtocol = "proxyProtocol"
) )
h.md.readTimeout = mdutil.GetDuration(md, readTimeout) h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
h.md.sniffing = mdutil.GetBool(md, sniffing) h.md.sniffing = mdutil.GetBool(md, sniffing)
h.md.sniffingTimeout = mdutil.GetDuration(md, "sniffing.timeout") h.md.sniffingTimeout = mdutil.GetDuration(md, "sniffing.timeout")
h.md.proxyProtocol = mdutil.GetInt(md, proxyProtocol)
return return
} }

View File

@ -142,16 +142,22 @@ func (h *relayHandler) handleConnectTunnel(ctx context.Context, conn net.Conn, n
if ingress := h.md.ingress; ingress != nil { if ingress := h.md.ingress; ingress != nil {
tid = parseTunnelID(ingress.Get(ctx, host)) tid = parseTunnelID(ingress.Get(ctx, host))
} }
// client is not an public entrypoint.
if h.md.entryPointID.IsZero() || !tunnelID.Equal(h.md.entryPointID) {
if !tid.Equal(tunnelID) && !h.md.directTunnel { if !tid.Equal(tunnelID) && !h.md.directTunnel {
resp.Status = relay.StatusBadRequest resp.Status = relay.StatusHostUnreachable
resp.WriteTo(conn) resp.WriteTo(conn)
err := fmt.Errorf("not route to host %s", host) err := fmt.Errorf("no route to host %s", host)
log.Error(err) log.Error(err)
return err return err
} }
}
cc, _, err := getTunnelConn(network, h.pool, tunnelID, 3, log) cc, _, err := getTunnelConn(network, h.pool, tid, 3, log)
if err != nil { if err != nil {
resp.Status = relay.StatusServiceUnavailable
resp.WriteTo(conn)
log.Error(err) log.Error(err)
return err return err
} }

View File

@ -210,12 +210,22 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl
log.Debugf("%s >> %s", conn.RemoteAddr(), cc.RemoteAddr()) log.Debugf("%s >> %s", conn.RemoteAddr(), cc.RemoteAddr())
var features []relay.Feature
af := &relay.AddrFeature{} af := &relay.AddrFeature{}
af.ParseFrom(conn.RemoteAddr().String()) af.ParseFrom(conn.RemoteAddr().String()) // client address
features = append(features, af)
if host != "" {
// target host
af := &relay.AddrFeature{}
af.ParseFrom(host)
features = append(features, af)
}
resp := relay.Response{ resp := relay.Response{
Version: relay.Version1, Version: relay.Version1,
Status: relay.StatusOK, Status: relay.StatusOK,
Features: []relay.Feature{af}, Features: features,
} }
resp.WriteTo(cc) resp.WriteTo(cc)
@ -284,12 +294,24 @@ func (h *tunnelHandler) handleHTTP(ctx context.Context, raddr net.Addr, rw io.Re
connPool.Store(tunnelID, cc) connPool.Store(tunnelID, cc)
log.Debugf("new connection to tunnel %s(connector %s)", tunnelID, cid) log.Debugf("new connection to tunnel %s(connector %s)", tunnelID, cid)
var features []relay.Feature
af := &relay.AddrFeature{} af := &relay.AddrFeature{}
af.ParseFrom(raddr.String()) af.ParseFrom(raddr.String())
features = append(features, af)
if host := req.Host; host != "" {
if h, _, _ := net.SplitHostPort(host); h == "" {
host = net.JoinHostPort(host, "80")
}
af := &relay.AddrFeature{}
af.ParseFrom(host)
features = append(features, af)
}
(&relay.Response{ (&relay.Response{
Version: relay.Version1, Version: relay.Version1,
Status: relay.StatusOK, Status: relay.StatusOK,
Features: []relay.Feature{af}, Features: features,
}).WriteTo(cc) }).WriteTo(cc)
go func() { go func() {

View File

@ -95,6 +95,7 @@ func (h *relayHandler) initEntryPoint() (err error) {
epListener := newTCPListener(ln, epListener := newTCPListener(ln,
listener.AddrOption(h.md.entryPoint), listener.AddrOption(h.md.entryPoint),
listener.ServiceOption(serviceName), listener.ServiceOption(serviceName),
listener.ProxyProtocolOption(h.md.entryPointProxyProtocol),
listener.LoggerOption(log.WithFields(map[string]any{ listener.LoggerOption(log.WithFields(map[string]any{
"kind": "listener", "kind": "listener",
})), })),

View File

@ -9,6 +9,7 @@ import (
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
mdata "github.com/go-gost/core/metadata" mdata "github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util" mdutil "github.com/go-gost/core/metadata/util"
"github.com/go-gost/relay"
xingress "github.com/go-gost/x/ingress" xingress "github.com/go-gost/x/ingress"
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
) )
@ -19,9 +20,11 @@ type metadata struct {
udpBufferSize int udpBufferSize int
noDelay bool noDelay bool
hash string hash string
entryPoint string
ingress ingress.Ingress
directTunnel bool directTunnel bool
entryPoint string
entryPointID relay.TunnelID
entryPointProxyProtocol int
ingress ingress.Ingress
} }
func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) { func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) {
@ -32,6 +35,8 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) {
noDelay = "nodelay" noDelay = "nodelay"
hash = "hash" hash = "hash"
entryPoint = "entryPoint" entryPoint = "entryPoint"
entryPointID = "entryPoint.id"
entryPointProxyProtocol = "entryPoint.proxyProtocol"
) )
h.md.readTimeout = mdutil.GetDuration(md, readTimeout) h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
@ -46,10 +51,12 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) {
h.md.hash = mdutil.GetString(md, hash) h.md.hash = mdutil.GetString(md, hash)
h.md.entryPoint = mdutil.GetString(md, entryPoint)
h.md.ingress = registry.IngressRegistry().Get(mdutil.GetString(md, "ingress"))
h.md.directTunnel = mdutil.GetBool(md, "tunnel.direct") h.md.directTunnel = mdutil.GetBool(md, "tunnel.direct")
h.md.entryPoint = mdutil.GetString(md, entryPoint)
h.md.entryPointID = parseTunnelID(mdutil.GetString(md, entryPointID))
h.md.entryPointProxyProtocol = mdutil.GetInt(md, entryPointProxyProtocol)
h.md.ingress = registry.IngressRegistry().Get(mdutil.GetString(md, "ingress"))
if h.md.ingress == nil { if h.md.ingress == nil {
var rules []xingress.Rule var rules []xingress.Rule
for _, s := range strings.Split(mdutil.GetString(md, "tunnel"), ",") { for _, s := range strings.Split(mdutil.GetString(md, "tunnel"), ",") {