From b8027864a3507244da917a763d042c1b485bf6dc Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sun, 12 Feb 2023 20:42:46 +0800 Subject: [PATCH] relay: fix routing for http traffic from entrypoint --- handler/relay/connect.go | 2 +- handler/relay/entrypoint.go | 123 +++++++++++++++++++++++++++++++++++- handler/relay/tunnel.go | 3 +- 3 files changed, 125 insertions(+), 3 deletions(-) diff --git a/handler/relay/connect.go b/handler/relay/connect.go index 2598771..dd9ad84 100644 --- a/handler/relay/connect.go +++ b/handler/relay/connect.go @@ -131,7 +131,7 @@ func (h *relayHandler) handleConnectTunnel(ctx context.Context, conn net.Conn, n return err } - cc, err := getTunnelConn(network, h.pool, tunnelID, 3, log) + cc, _, err := getTunnelConn(network, h.pool, tunnelID, 3, log) if err != nil { log.Error(err) return err diff --git a/handler/relay/entrypoint.go b/handler/relay/entrypoint.go index 7c7701d..f8fbab6 100644 --- a/handler/relay/entrypoint.go +++ b/handler/relay/entrypoint.go @@ -1,15 +1,20 @@ package relay import ( + "bufio" "context" "fmt" "io" "net" + "net/http" + "net/http/httputil" + "sync" "time" "github.com/go-gost/core/handler" "github.com/go-gost/core/ingress" "github.com/go-gost/core/listener" + "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" "github.com/go-gost/relay" admission "github.com/go-gost/x/admission/wrapper" @@ -173,6 +178,11 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl rw, host, protocol, _ = forward.Sniffing(ctx, conn) h.options.Logger.Debugf("sniffing: host=%s, protocol=%s", host, protocol) + if protocol == forward.ProtoHTTP { + h.handleHTTP(ctx, conn.RemoteAddr(), rw, log) + return nil + } + var tunnelID relay.TunnelID if h.ingress != nil { tunnelID = parseTunnelID(h.ingress.Get(host)) @@ -191,7 +201,7 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl "tunnel": tunnelID.String(), }) - cc, err := getTunnelConn("tcp", h.pool, tunnelID, 3, log) + cc, _, err := getTunnelConn("tcp", h.pool, tunnelID, 3, log) if err != nil { log.Error(err) return err @@ -218,3 +228,114 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl return nil } + +func (h *tunnelHandler) handleHTTP(ctx context.Context, raddr net.Addr, rw io.ReadWriter, log logger.Logger) (err error) { + br := bufio.NewReader(rw) + var connPool sync.Map + + for { + resp := &http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + StatusCode: http.StatusServiceUnavailable, + } + + err = func() error { + req, err := http.ReadRequest(br) + if err != nil { + return err + } + + var tunnelID relay.TunnelID + if h.ingress != nil { + tunnelID = parseTunnelID(h.ingress.Get(req.Host)) + } + if tunnelID.IsZero() { + err := fmt.Errorf("no route to host %s", req.Host) + log.Error(err) + resp.StatusCode = http.StatusBadGateway + return resp.Write(rw) + } + if tunnelID.IsPrivate() { + err := fmt.Errorf("access denied: tunnel %s is private for host %s", tunnelID, req.Host) + log.Error(err) + resp.StatusCode = http.StatusBadGateway + return resp.Write(rw) + } + + log = log.WithFields(map[string]any{ + "host": req.Host, + "tunnel": tunnelID.String(), + }) + + var cc net.Conn + if v, ok := connPool.Load(tunnelID); ok { + cc = v.(net.Conn) + log.Debugf("connection to tunnel %s found in pool", tunnelID) + } + if cc == nil { + var cid relay.ConnectorID + cc, cid, err = getTunnelConn("tcp", h.pool, tunnelID, 3, log) + if err != nil { + log.Error(err) + return resp.Write(rw) + } + + connPool.Store(tunnelID, cc) + log.Debugf("new connection to tunnel %s(connector %s)", tunnelID, cid) + + af := &relay.AddrFeature{} + af.ParseFrom(raddr.String()) + (&relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + Features: []relay.Feature{af}, + }).WriteTo(cc) + + go func() { + defer cc.Close() + err := xnet.CopyBuffer(rw, cc, 8192) + if err != nil { + resp.Write(rw) + } + log.Debugf("close connection to tunnel %s(connector %s), reason: %v", tunnelID, cid, err) + connPool.Delete(tunnelID) + }() + } + + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpRequest(req, false) + log.Trace(string(dump)) + } + if err := req.Write(cc); err != nil { + log.Warnf("send request to tunnel %s failed: %v", tunnelID, err) + return resp.Write(rw) + } + + if req.Header.Get("Upgrade") == "websocket" { + err := xnet.CopyBuffer(cc, br, 8192) + if err == nil { + err = io.EOF + } + return err + } + + // cc.SetReadDeadline(time.Now().Add(10 * time.Second)) + + return nil + }() + if err != nil { + // log.Error(err) + break + } + } + + connPool.Range(func(key, value any) bool { + if value != nil { + value.(net.Conn).Close() + } + return true + }) + + return +} diff --git a/handler/relay/tunnel.go b/handler/relay/tunnel.go index 1cc9ada..18d1bfa 100644 --- a/handler/relay/tunnel.go +++ b/handler/relay/tunnel.go @@ -176,7 +176,7 @@ func parseTunnelID(s string) (tid relay.TunnelID) { return relay.NewTunnelID(uuid[:]) } -func getTunnelConn(network string, pool *ConnectorPool, tid relay.TunnelID, retry int, log logger.Logger) (conn net.Conn, err error) { +func getTunnelConn(network string, pool *ConnectorPool, tid relay.TunnelID, retry int, log logger.Logger) (conn net.Conn, cid relay.ConnectorID, err error) { if retry <= 0 { retry = 1 } @@ -192,6 +192,7 @@ func getTunnelConn(network string, pool *ConnectorPool, tid relay.TunnelID, retr log.Error(err) continue } + cid = c.id break }