add proxy protocol support for relay tunnel and rtcp

This commit is contained in:
ginuerzh 2023-10-08 19:49:03 +08:00
parent 07db20c9a8
commit 6a36ebcc9f
9 changed files with 83 additions and 33 deletions

View File

@ -48,6 +48,7 @@ func (p *bindListener) getPeerConn(conn net.Conn) (net.Conn, error) {
}
var address, host string
// the first addr is the client address, the optional second addr is the target host address.
for _, f := range resp.Features {
if f.Type() == relay.FeatureAddr {
if fa, ok := f.(*relay.AddrFeature); ok {

2
go.mod
View File

@ -21,7 +21,7 @@ require (
github.com/miekg/dns v1.1.56
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/dtls/v2 v2.2.6
github.com/pires/go-proxyproto v0.6.2
github.com/pires/go-proxyproto v0.7.0
github.com/prometheus/client_golang v1.14.0
github.com/quic-go/quic-go v0.38.1
github.com/rs/xid v1.3.0

4
go.sum
View File

@ -300,8 +300,8 @@ github.com/pion/transport/v2 v2.0.2 h1:St+8o+1PEzPT51O9bv+tH/KYYLMNR5Vwm5Z3Qkjsy
github.com/pion/transport/v2 v2.0.2/go.mod h1:vrz6bUbFr/cjdwbnxq8OdDDzHf7JJfGsIRkxfpZoTA0=
github.com/pion/udp/v2 v2.0.1 h1:xP0z6WNux1zWEjhC7onRA3EwwSliXqu1ElUZAQhUP54=
github.com/pion/udp/v2 v2.0.1/go.mod h1:B7uvTMP00lzWdyMr/1PVZXtV3wpPIxBRd4Wl6AksXn8=
github.com/pires/go-proxyproto v0.6.2 h1:KAZ7UteSOt6urjme6ZldyFm4wDe/z0ZUP0Yv0Dos0d8=
github.com/pires/go-proxyproto v0.6.2/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY=
github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs=
github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=

View File

@ -20,6 +20,7 @@ import (
mdata "github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util"
xnet "github.com/go-gost/x/internal/net"
"github.com/go-gost/x/internal/net/proxyproto"
auth_util "github.com/go-gost/x/internal/util/auth"
"github.com/go-gost/x/internal/util/forward"
"github.com/go-gost/x/registry"
@ -107,7 +108,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
}
}
if protocol == forward.ProtoHTTP {
h.handleHTTP(ctx, rw, log)
h.handleHTTP(ctx, rw, conn.RemoteAddr(), conn.LocalAddr(), log)
return nil
}
@ -157,6 +158,13 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
marker.Reset()
}
if dst, ok := conn.LocalAddr().(*net.TCPAddr); ok {
if dst.IP.Equal(net.IPv6zero) {
dst.IP = net.IPv4zero
}
}
cc = proxyproto.WrapClientConn(h.md.proxyProtocol, conn.RemoteAddr(), conn.LocalAddr(), cc)
t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr)
xnet.Transport(rw, cc)
@ -167,7 +175,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
return nil
}
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log logger.Logger) (err error) {
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) {
br := bufio.NewReader(rw)
var connPool sync.Map
@ -246,6 +254,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
})
}
cc = proxyproto.WrapClientConn(h.md.proxyProtocol, remoteAddr, localAddr, cc)
connPool.Store(target, cc)
log.Debugf("new connection to node %s(%s)", target.Name, target.Addr)

View File

@ -11,16 +11,19 @@ type metadata struct {
readTimeout time.Duration
sniffing bool
sniffingTimeout time.Duration
proxyProtocol int
}
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
readTimeout = "readTimeout"
sniffing = "sniffing"
readTimeout = "readTimeout"
sniffing = "sniffing"
proxyProtocol = "proxyProtocol"
)
h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
h.md.sniffing = mdutil.GetBool(md, sniffing)
h.md.sniffingTimeout = mdutil.GetDuration(md, "sniffing.timeout")
h.md.proxyProtocol = mdutil.GetInt(md, proxyProtocol)
return
}

View File

@ -142,16 +142,22 @@ func (h *relayHandler) handleConnectTunnel(ctx context.Context, conn net.Conn, n
if ingress := h.md.ingress; ingress != nil {
tid = parseTunnelID(ingress.Get(ctx, host))
}
if !tid.Equal(tunnelID) && !h.md.directTunnel {
resp.Status = relay.StatusBadRequest
resp.WriteTo(conn)
err := fmt.Errorf("not route to host %s", host)
log.Error(err)
return err
// client is not an public entrypoint.
if h.md.entryPointID.IsZero() || !tunnelID.Equal(h.md.entryPointID) {
if !tid.Equal(tunnelID) && !h.md.directTunnel {
resp.Status = relay.StatusHostUnreachable
resp.WriteTo(conn)
err := fmt.Errorf("no route to host %s", host)
log.Error(err)
return err
}
}
cc, _, err := getTunnelConn(network, h.pool, tunnelID, 3, log)
cc, _, err := getTunnelConn(network, h.pool, tid, 3, log)
if err != nil {
resp.Status = relay.StatusServiceUnavailable
resp.WriteTo(conn)
log.Error(err)
return err
}

View File

@ -210,12 +210,22 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl
log.Debugf("%s >> %s", conn.RemoteAddr(), cc.RemoteAddr())
var features []relay.Feature
af := &relay.AddrFeature{}
af.ParseFrom(conn.RemoteAddr().String())
af.ParseFrom(conn.RemoteAddr().String()) // client address
features = append(features, af)
if host != "" {
// target host
af := &relay.AddrFeature{}
af.ParseFrom(host)
features = append(features, af)
}
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
Features: []relay.Feature{af},
Features: features,
}
resp.WriteTo(cc)
@ -284,12 +294,24 @@ func (h *tunnelHandler) handleHTTP(ctx context.Context, raddr net.Addr, rw io.Re
connPool.Store(tunnelID, cc)
log.Debugf("new connection to tunnel %s(connector %s)", tunnelID, cid)
var features []relay.Feature
af := &relay.AddrFeature{}
af.ParseFrom(raddr.String())
features = append(features, af)
if host := req.Host; host != "" {
if h, _, _ := net.SplitHostPort(host); h == "" {
host = net.JoinHostPort(host, "80")
}
af := &relay.AddrFeature{}
af.ParseFrom(host)
features = append(features, af)
}
(&relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
Features: []relay.Feature{af},
Features: features,
}).WriteTo(cc)
go func() {

View File

@ -95,6 +95,7 @@ func (h *relayHandler) initEntryPoint() (err error) {
epListener := newTCPListener(ln,
listener.AddrOption(h.md.entryPoint),
listener.ServiceOption(serviceName),
listener.ProxyProtocolOption(h.md.entryPointProxyProtocol),
listener.LoggerOption(log.WithFields(map[string]any{
"kind": "listener",
})),

View File

@ -9,29 +9,34 @@ import (
"github.com/go-gost/core/logger"
mdata "github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util"
"github.com/go-gost/relay"
xingress "github.com/go-gost/x/ingress"
"github.com/go-gost/x/registry"
)
type metadata struct {
readTimeout time.Duration
enableBind bool
udpBufferSize int
noDelay bool
hash string
entryPoint string
ingress ingress.Ingress
directTunnel bool
readTimeout time.Duration
enableBind bool
udpBufferSize int
noDelay bool
hash string
directTunnel bool
entryPoint string
entryPointID relay.TunnelID
entryPointProxyProtocol int
ingress ingress.Ingress
}
func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
readTimeout = "readTimeout"
enableBind = "bind"
udpBufferSize = "udpBufferSize"
noDelay = "nodelay"
hash = "hash"
entryPoint = "entryPoint"
readTimeout = "readTimeout"
enableBind = "bind"
udpBufferSize = "udpBufferSize"
noDelay = "nodelay"
hash = "hash"
entryPoint = "entryPoint"
entryPointID = "entryPoint.id"
entryPointProxyProtocol = "entryPoint.proxyProtocol"
)
h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
@ -46,10 +51,12 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) {
h.md.hash = mdutil.GetString(md, hash)
h.md.entryPoint = mdutil.GetString(md, entryPoint)
h.md.ingress = registry.IngressRegistry().Get(mdutil.GetString(md, "ingress"))
h.md.directTunnel = mdutil.GetBool(md, "tunnel.direct")
h.md.entryPoint = mdutil.GetString(md, entryPoint)
h.md.entryPointID = parseTunnelID(mdutil.GetString(md, entryPointID))
h.md.entryPointProxyProtocol = mdutil.GetInt(md, entryPointProxyProtocol)
h.md.ingress = registry.IngressRegistry().Get(mdutil.GetString(md, "ingress"))
if h.md.ingress == nil {
var rules []xingress.Rule
for _, s := range strings.Split(mdutil.GetString(md, "tunnel"), ",") {