add proxy protocol support for relay tunnel and rtcp

This commit is contained in:
ginuerzh 2023-10-08 19:49:03 +08:00
parent 07db20c9a8
commit 6a36ebcc9f
9 changed files with 83 additions and 33 deletions

View File

@ -48,6 +48,7 @@ func (p *bindListener) getPeerConn(conn net.Conn) (net.Conn, error) {
} }
var address, host string var address, host string
// the first addr is the client address, the optional second addr is the target host address.
for _, f := range resp.Features { for _, f := range resp.Features {
if f.Type() == relay.FeatureAddr { if f.Type() == relay.FeatureAddr {
if fa, ok := f.(*relay.AddrFeature); ok { if fa, ok := f.(*relay.AddrFeature); ok {

2
go.mod
View File

@ -21,7 +21,7 @@ require (
github.com/miekg/dns v1.1.56 github.com/miekg/dns v1.1.56
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/dtls/v2 v2.2.6 github.com/pion/dtls/v2 v2.2.6
github.com/pires/go-proxyproto v0.6.2 github.com/pires/go-proxyproto v0.7.0
github.com/prometheus/client_golang v1.14.0 github.com/prometheus/client_golang v1.14.0
github.com/quic-go/quic-go v0.38.1 github.com/quic-go/quic-go v0.38.1
github.com/rs/xid v1.3.0 github.com/rs/xid v1.3.0

4
go.sum
View File

@ -300,8 +300,8 @@ github.com/pion/transport/v2 v2.0.2 h1:St+8o+1PEzPT51O9bv+tH/KYYLMNR5Vwm5Z3Qkjsy
github.com/pion/transport/v2 v2.0.2/go.mod h1:vrz6bUbFr/cjdwbnxq8OdDDzHf7JJfGsIRkxfpZoTA0= github.com/pion/transport/v2 v2.0.2/go.mod h1:vrz6bUbFr/cjdwbnxq8OdDDzHf7JJfGsIRkxfpZoTA0=
github.com/pion/udp/v2 v2.0.1 h1:xP0z6WNux1zWEjhC7onRA3EwwSliXqu1ElUZAQhUP54= github.com/pion/udp/v2 v2.0.1 h1:xP0z6WNux1zWEjhC7onRA3EwwSliXqu1ElUZAQhUP54=
github.com/pion/udp/v2 v2.0.1/go.mod h1:B7uvTMP00lzWdyMr/1PVZXtV3wpPIxBRd4Wl6AksXn8= github.com/pion/udp/v2 v2.0.1/go.mod h1:B7uvTMP00lzWdyMr/1PVZXtV3wpPIxBRd4Wl6AksXn8=
github.com/pires/go-proxyproto v0.6.2 h1:KAZ7UteSOt6urjme6ZldyFm4wDe/z0ZUP0Yv0Dos0d8= github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs=
github.com/pires/go-proxyproto v0.6.2/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY= github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=

View File

@ -20,6 +20,7 @@ import (
mdata "github.com/go-gost/core/metadata" mdata "github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util" mdutil "github.com/go-gost/core/metadata/util"
xnet "github.com/go-gost/x/internal/net" xnet "github.com/go-gost/x/internal/net"
"github.com/go-gost/x/internal/net/proxyproto"
auth_util "github.com/go-gost/x/internal/util/auth" auth_util "github.com/go-gost/x/internal/util/auth"
"github.com/go-gost/x/internal/util/forward" "github.com/go-gost/x/internal/util/forward"
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
@ -107,7 +108,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
} }
} }
if protocol == forward.ProtoHTTP { if protocol == forward.ProtoHTTP {
h.handleHTTP(ctx, rw, log) h.handleHTTP(ctx, rw, conn.RemoteAddr(), conn.LocalAddr(), log)
return nil return nil
} }
@ -157,6 +158,13 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
marker.Reset() marker.Reset()
} }
if dst, ok := conn.LocalAddr().(*net.TCPAddr); ok {
if dst.IP.Equal(net.IPv6zero) {
dst.IP = net.IPv4zero
}
}
cc = proxyproto.WrapClientConn(h.md.proxyProtocol, conn.RemoteAddr(), conn.LocalAddr(), cc)
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr)
xnet.Transport(rw, cc) xnet.Transport(rw, cc)
@ -167,7 +175,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
return nil return nil
} }
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log logger.Logger) (err error) { func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) {
br := bufio.NewReader(rw) br := bufio.NewReader(rw)
var connPool sync.Map var connPool sync.Map
@ -246,6 +254,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
}) })
} }
cc = proxyproto.WrapClientConn(h.md.proxyProtocol, remoteAddr, localAddr, cc)
connPool.Store(target, cc) connPool.Store(target, cc)
log.Debugf("new connection to node %s(%s)", target.Name, target.Addr) log.Debugf("new connection to node %s(%s)", target.Name, target.Addr)

View File

@ -11,16 +11,19 @@ type metadata struct {
readTimeout time.Duration readTimeout time.Duration
sniffing bool sniffing bool
sniffingTimeout time.Duration sniffingTimeout time.Duration
proxyProtocol int
} }
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
const ( const (
readTimeout = "readTimeout" readTimeout = "readTimeout"
sniffing = "sniffing" sniffing = "sniffing"
proxyProtocol = "proxyProtocol"
) )
h.md.readTimeout = mdutil.GetDuration(md, readTimeout) h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
h.md.sniffing = mdutil.GetBool(md, sniffing) h.md.sniffing = mdutil.GetBool(md, sniffing)
h.md.sniffingTimeout = mdutil.GetDuration(md, "sniffing.timeout") h.md.sniffingTimeout = mdutil.GetDuration(md, "sniffing.timeout")
h.md.proxyProtocol = mdutil.GetInt(md, proxyProtocol)
return return
} }

View File

@ -142,16 +142,22 @@ func (h *relayHandler) handleConnectTunnel(ctx context.Context, conn net.Conn, n
if ingress := h.md.ingress; ingress != nil { if ingress := h.md.ingress; ingress != nil {
tid = parseTunnelID(ingress.Get(ctx, host)) tid = parseTunnelID(ingress.Get(ctx, host))
} }
if !tid.Equal(tunnelID) && !h.md.directTunnel {
resp.Status = relay.StatusBadRequest // client is not an public entrypoint.
resp.WriteTo(conn) if h.md.entryPointID.IsZero() || !tunnelID.Equal(h.md.entryPointID) {
err := fmt.Errorf("not route to host %s", host) if !tid.Equal(tunnelID) && !h.md.directTunnel {
log.Error(err) resp.Status = relay.StatusHostUnreachable
return err resp.WriteTo(conn)
err := fmt.Errorf("no route to host %s", host)
log.Error(err)
return err
}
} }
cc, _, err := getTunnelConn(network, h.pool, tunnelID, 3, log) cc, _, err := getTunnelConn(network, h.pool, tid, 3, log)
if err != nil { if err != nil {
resp.Status = relay.StatusServiceUnavailable
resp.WriteTo(conn)
log.Error(err) log.Error(err)
return err return err
} }

View File

@ -210,12 +210,22 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl
log.Debugf("%s >> %s", conn.RemoteAddr(), cc.RemoteAddr()) log.Debugf("%s >> %s", conn.RemoteAddr(), cc.RemoteAddr())
var features []relay.Feature
af := &relay.AddrFeature{} af := &relay.AddrFeature{}
af.ParseFrom(conn.RemoteAddr().String()) af.ParseFrom(conn.RemoteAddr().String()) // client address
features = append(features, af)
if host != "" {
// target host
af := &relay.AddrFeature{}
af.ParseFrom(host)
features = append(features, af)
}
resp := relay.Response{ resp := relay.Response{
Version: relay.Version1, Version: relay.Version1,
Status: relay.StatusOK, Status: relay.StatusOK,
Features: []relay.Feature{af}, Features: features,
} }
resp.WriteTo(cc) resp.WriteTo(cc)
@ -284,12 +294,24 @@ func (h *tunnelHandler) handleHTTP(ctx context.Context, raddr net.Addr, rw io.Re
connPool.Store(tunnelID, cc) connPool.Store(tunnelID, cc)
log.Debugf("new connection to tunnel %s(connector %s)", tunnelID, cid) log.Debugf("new connection to tunnel %s(connector %s)", tunnelID, cid)
var features []relay.Feature
af := &relay.AddrFeature{} af := &relay.AddrFeature{}
af.ParseFrom(raddr.String()) af.ParseFrom(raddr.String())
features = append(features, af)
if host := req.Host; host != "" {
if h, _, _ := net.SplitHostPort(host); h == "" {
host = net.JoinHostPort(host, "80")
}
af := &relay.AddrFeature{}
af.ParseFrom(host)
features = append(features, af)
}
(&relay.Response{ (&relay.Response{
Version: relay.Version1, Version: relay.Version1,
Status: relay.StatusOK, Status: relay.StatusOK,
Features: []relay.Feature{af}, Features: features,
}).WriteTo(cc) }).WriteTo(cc)
go func() { go func() {

View File

@ -95,6 +95,7 @@ func (h *relayHandler) initEntryPoint() (err error) {
epListener := newTCPListener(ln, epListener := newTCPListener(ln,
listener.AddrOption(h.md.entryPoint), listener.AddrOption(h.md.entryPoint),
listener.ServiceOption(serviceName), listener.ServiceOption(serviceName),
listener.ProxyProtocolOption(h.md.entryPointProxyProtocol),
listener.LoggerOption(log.WithFields(map[string]any{ listener.LoggerOption(log.WithFields(map[string]any{
"kind": "listener", "kind": "listener",
})), })),

View File

@ -9,29 +9,34 @@ import (
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
mdata "github.com/go-gost/core/metadata" mdata "github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util" mdutil "github.com/go-gost/core/metadata/util"
"github.com/go-gost/relay"
xingress "github.com/go-gost/x/ingress" xingress "github.com/go-gost/x/ingress"
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
) )
type metadata struct { type metadata struct {
readTimeout time.Duration readTimeout time.Duration
enableBind bool enableBind bool
udpBufferSize int udpBufferSize int
noDelay bool noDelay bool
hash string hash string
entryPoint string directTunnel bool
ingress ingress.Ingress entryPoint string
directTunnel bool entryPointID relay.TunnelID
entryPointProxyProtocol int
ingress ingress.Ingress
} }
func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) { func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) {
const ( const (
readTimeout = "readTimeout" readTimeout = "readTimeout"
enableBind = "bind" enableBind = "bind"
udpBufferSize = "udpBufferSize" udpBufferSize = "udpBufferSize"
noDelay = "nodelay" noDelay = "nodelay"
hash = "hash" hash = "hash"
entryPoint = "entryPoint" entryPoint = "entryPoint"
entryPointID = "entryPoint.id"
entryPointProxyProtocol = "entryPoint.proxyProtocol"
) )
h.md.readTimeout = mdutil.GetDuration(md, readTimeout) h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
@ -46,10 +51,12 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) {
h.md.hash = mdutil.GetString(md, hash) h.md.hash = mdutil.GetString(md, hash)
h.md.entryPoint = mdutil.GetString(md, entryPoint)
h.md.ingress = registry.IngressRegistry().Get(mdutil.GetString(md, "ingress"))
h.md.directTunnel = mdutil.GetBool(md, "tunnel.direct") h.md.directTunnel = mdutil.GetBool(md, "tunnel.direct")
h.md.entryPoint = mdutil.GetString(md, entryPoint)
h.md.entryPointID = parseTunnelID(mdutil.GetString(md, entryPointID))
h.md.entryPointProxyProtocol = mdutil.GetInt(md, entryPointProxyProtocol)
h.md.ingress = registry.IngressRegistry().Get(mdutil.GetString(md, "ingress"))
if h.md.ingress == nil { if h.md.ingress == nil {
var rules []xingress.Rule var rules []xingress.Rule
for _, s := range strings.Split(mdutil.GetString(md, "tunnel"), ",") { for _, s := range strings.Split(mdutil.GetString(md, "tunnel"), ",") {