diff --git a/handler/tunnel/conn.go b/handler/tunnel/conn.go deleted file mode 100644 index 91b7666..0000000 --- a/handler/tunnel/conn.go +++ /dev/null @@ -1,26 +0,0 @@ -package tunnel - -import ( - "bytes" - "net" -) - -type tcpConn struct { - net.Conn - wbuf bytes.Buffer -} - -func (c *tcpConn) Read(b []byte) (n int, err error) { - return c.Conn.Read(b) -} - -func (c *tcpConn) Write(b []byte) (n int, err error) { - n = len(b) // force byte length consistent - if c.wbuf.Len() > 0 { - c.wbuf.Write(b) // append the data to the cached header - _, err = c.wbuf.WriteTo(c.Conn) - return - } - _, err = c.Conn.Write(b) - return -} diff --git a/handler/tunnel/connect.go b/handler/tunnel/connect.go index 0157381..78ecc39 100644 --- a/handler/tunnel/connect.go +++ b/handler/tunnel/connect.go @@ -32,39 +32,26 @@ func (h *tunnelHandler) handleConnect(ctx context.Context, conn net.Conn, networ host, _, _ := net.SplitHostPort(dstAddr) + // client is a public entrypoint. + if tunnelID.Equal(h.md.entryPointID) && !h.md.entryPointID.IsZero() { + resp.WriteTo(conn) + return h.ep.handle(ctx, conn) + } + var tid relay.TunnelID if ingress := h.md.ingress; ingress != nil && host != "" { tid = parseTunnelID(ingress.Get(ctx, host)) } - // client is a public entrypoint. - if tunnelID.Equal(h.md.entryPointID) && !h.md.entryPointID.IsZero() { - if tid.IsZero() { - resp.Status = relay.StatusNetworkUnreachable - resp.WriteTo(conn) - err := fmt.Errorf("no route to host %s", host) - log.Error(err) - return err - } - - if tid.IsPrivate() { - resp.Status = relay.StatusHostUnreachable - resp.WriteTo(conn) - err := fmt.Errorf("access denied: tunnel %s is private for host %s", tunnelID, host) - log.Error(err) - return err - } - } else { - // direct routing - if h.md.directTunnel { - tid = tunnelID - } else if !tid.Equal(tunnelID) { - resp.Status = relay.StatusHostUnreachable - resp.WriteTo(conn) - err := fmt.Errorf("no route to host %s", host) - log.Error(err) - return err - } + // direct routing + if h.md.directTunnel { + tid = tunnelID + } else if !tid.Equal(tunnelID) { + resp.Status = relay.StatusHostUnreachable + resp.WriteTo(conn) + err := fmt.Errorf("no route to host %s", host) + log.Error(err) + return err } cc, _, err := getTunnelConn(network, h.pool, tid, 3, log) @@ -78,20 +65,9 @@ func (h *tunnelHandler) handleConnect(ctx context.Context, conn net.Conn, networ log.Debugf("%s >> %s", conn.RemoteAddr(), cc.RemoteAddr()) - if h.md.noDelay { - if _, err := resp.WriteTo(conn); err != nil { - log.Error(err) - return err - } - } else { - rc := &tcpConn{ - Conn: conn, - } - // cache the header - if _, err := resp.WriteTo(&rc.wbuf); err != nil { - return err - } - conn = rc + if _, err := resp.WriteTo(conn); err != nil { + log.Error(err) + return err } resp = relay.Response{ diff --git a/handler/tunnel/entrypoint.go b/handler/tunnel/entrypoint.go index 7f22503..a6e3f63 100644 --- a/handler/tunnel/entrypoint.go +++ b/handler/tunnel/entrypoint.go @@ -19,16 +19,204 @@ import ( md "github.com/go-gost/core/metadata" "github.com/go-gost/relay" admission "github.com/go-gost/x/admission/wrapper" + xio "github.com/go-gost/x/internal/io" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" - auth_util "github.com/go-gost/x/internal/util/auth" - "github.com/go-gost/x/internal/util/forward" - "github.com/go-gost/x/internal/util/mux" climiter "github.com/go-gost/x/limiter/conn/wrapper" limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" ) +type entrypoint struct { + pool *ConnectorPool + ingress ingress.Ingress + log logger.Logger +} + +func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error { + defer conn.Close() + + start := time.Now() + log := ep.log.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + log.WithFields(map[string]any{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + br := bufio.NewReader(conn) + + var cc net.Conn + for { + resp := &http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{}, + StatusCode: http.StatusServiceUnavailable, + } + + err := func() error { + req, err := http.ReadRequest(br) + if err != nil { + // log.Errorf("read http request: %v", err) + return err + } + + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpRequest(req, false) + log.Trace(string(dump)) + } + + var tunnelID relay.TunnelID + if ep.ingress != nil { + tunnelID = parseTunnelID(ep.ingress.Get(ctx, req.Host)) + } + if tunnelID.IsZero() { + err := fmt.Errorf("no route to host %s", req.Host) + log.Error(err) + resp.StatusCode = http.StatusBadGateway + return resp.Write(conn) + } + if tunnelID.IsPrivate() { + err := fmt.Errorf("access denied: tunnel %s is private for host %s", tunnelID, req.Host) + log.Error(err) + resp.StatusCode = http.StatusBadGateway + return resp.Write(conn) + } + + log = log.WithFields(map[string]any{ + "host": req.Host, + "tunnel": tunnelID.String(), + }) + + remoteAddr := conn.RemoteAddr() + if addr := ep.getRealClientAddr(req, remoteAddr); addr != remoteAddr { + log = log.WithFields(map[string]any{ + "src": addr.String(), + }) + remoteAddr = addr + } + + cc, cid, err := getTunnelConn("tcp", ep.pool, tunnelID, 3, log) + if err != nil { + log.Error(err) + return resp.Write(conn) + } + + log.Debugf("new connection to tunnel: %s, connector: %s", tunnelID, cid) + + var features []relay.Feature + af := &relay.AddrFeature{} + af.ParseFrom(remoteAddr.String()) + features = append(features, af) // src address + + host := req.Host + if h, _, _ := net.SplitHostPort(host); h == "" { + host = net.JoinHostPort(host, "80") + } + af = &relay.AddrFeature{} + af.ParseFrom(host) + features = append(features, af) // dst address + + (&relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + Features: features, + }).WriteTo(cc) + + if err := req.Write(cc); err != nil { + cc.Close() + log.Errorf("send request: %v", err) + return resp.Write(conn) + } + + if req.Header.Get("Upgrade") == "websocket" { + err := xnet.Transport(cc, xio.NewReadWriter(br, conn)) + if err == nil { + err = io.EOF + } + return err + } + + go func() { + defer cc.Close() + + t := time.Now() + log.Debugf("%s <-> %s", remoteAddr, host) + + defer func() { + log.WithFields(map[string]any{ + "duration": time.Since(t), + }).Debugf("%s >-< %s", remoteAddr, host) + }() + + res, err := http.ReadResponse(bufio.NewReader(cc), req) + if err != nil { + log.Errorf("read response: %v", err) + resp.Write(conn) + return + } + + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpResponse(res, false) + log.Trace(string(dump)) + } + + if err = res.Write(conn); err != nil { + log.Errorf("write response: %v", err) + } + }() + + return nil + }() + + if err != nil { + if cc != nil { + cc.Close() + } + break + } + } + + return nil +} + +func (ep *entrypoint) getRealClientAddr(req *http.Request, raddr net.Addr) net.Addr { + if req == nil { + return nil + } + // cloudflare CDN + sip := req.Header.Get("CF-Connecting-IP") + if sip == "" { + ss := strings.Split(req.Header.Get("X-Forwarded-For"), ",") + if len(ss) > 0 && ss[0] != "" { + sip = ss[0] + } + } + if sip == "" { + sip = req.Header.Get("X-Real-Ip") + } + + ip := net.ParseIP(sip) + if ip == nil { + return raddr + } + + _, sp, _ := net.SplitHostPort(raddr.String()) + + port, _ := strconv.Atoi(sp) + + return &net.TCPAddr{ + IP: ip, + Port: port, + } +} + type tcpListener struct { ln net.Listener options listener.Options @@ -70,87 +258,8 @@ func (l *tcpListener) Close() error { return l.ln.Close() } -type tcpHandler struct { - session *mux.Session - options handler.Options -} - -func newTCPHandler(session *mux.Session, opts ...handler.Option) handler.Handler { - options := handler.Options{} - for _, opt := range opts { - opt(&options) - } - - return &tcpHandler{ - session: session, - options: options, - } -} - -func (h *tcpHandler) Init(md md.Metadata) (err error) { - return -} - -func (h *tcpHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { - defer conn.Close() - - start := time.Now() - log := h.options.Logger.WithFields(map[string]any{ - "remote": conn.RemoteAddr().String(), - "local": conn.LocalAddr().String(), - }) - - log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) - defer func() { - log.WithFields(map[string]any{ - "duration": time.Since(start), - }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) - }() - - cc, err := h.session.GetConn() - if err != nil { - log.Error(err) - return err - } - defer cc.Close() - - af := &relay.AddrFeature{} - af.ParseFrom(conn.RemoteAddr().String()) - resp := relay.Response{ - Version: relay.Version1, - Status: relay.StatusOK, - Features: []relay.Feature{af}, - } - if _, err := resp.WriteTo(cc); err != nil { - log.Error(err) - return err - } - - t := time.Now() - log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) - xnet.Transport(conn, cc) - log.WithFields(map[string]any{"duration": time.Since(t)}). - Debugf("%s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) - return nil -} - type entrypointHandler struct { - pool *ConnectorPool - ingress ingress.Ingress - options handler.Options -} - -func newEntrypointHandler(pool *ConnectorPool, ingress ingress.Ingress, opts ...handler.Option) handler.Handler { - options := handler.Options{} - for _, opt := range opts { - opt(&options) - } - - return &entrypointHandler{ - pool: pool, - ingress: ingress, - options: options, - } + ep *entrypoint } func (h *entrypointHandler) Init(md md.Metadata) (err error) { @@ -158,219 +267,5 @@ func (h *entrypointHandler) Init(md md.Metadata) (err error) { } func (h *entrypointHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { - defer conn.Close() - - start := time.Now() - log := h.options.Logger.WithFields(map[string]any{ - "remote": conn.RemoteAddr().String(), - "local": conn.LocalAddr().String(), - }) - - log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) - defer func() { - log.WithFields(map[string]any{ - "duration": time.Since(start), - }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) - }() - - var rw io.ReadWriter = conn - var host string - var protocol string - rw, host, protocol, _ = forward.Sniffing(ctx, conn) - h.options.Logger.Debugf("sniffing: host=%s, protocol=%s", host, protocol) - - if protocol == forward.ProtoHTTP { - return h.handleHTTP(ctx, conn.RemoteAddr(), rw, log) - } - - var tunnelID relay.TunnelID - if h.ingress != nil { - tunnelID = parseTunnelID(h.ingress.Get(ctx, host)) - } - if tunnelID.IsZero() { - err := fmt.Errorf("no route to host %s", host) - log.Error(err) - return err - } - if tunnelID.IsPrivate() { - err := fmt.Errorf("access denied: tunnel %s is private for host %s", tunnelID, host) - log.Error(err) - return err - } - log = log.WithFields(map[string]any{ - "tunnel": tunnelID.String(), - }) - - cc, _, err := getTunnelConn("tcp", h.pool, tunnelID, 3, log) - if err != nil { - log.Error(err) - return err - } - defer cc.Close() - - log.Debugf("%s >> %s", conn.RemoteAddr(), cc.RemoteAddr()) - - var features []relay.Feature - af := &relay.AddrFeature{} - af.ParseFrom(conn.RemoteAddr().String()) // client address - features = append(features, af) - - if host != "" { - // target host - af := &relay.AddrFeature{} - af.ParseFrom(host) - features = append(features, af) - } - - resp := relay.Response{ - Version: relay.Version1, - Status: relay.StatusOK, - Features: features, - } - resp.WriteTo(cc) - - t := time.Now() - log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) - xnet.Transport(rw, cc) - log.WithFields(map[string]any{ - "duration": time.Since(t), - }).Debugf("%s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) - - return nil -} - -func (h *entrypointHandler) handleHTTP(ctx context.Context, raddr net.Addr, rw io.ReadWriter, log logger.Logger) (err error) { - br := bufio.NewReader(rw) - - for { - resp := &http.Response{ - ProtoMajor: 1, - ProtoMinor: 1, - Header: http.Header{}, - StatusCode: http.StatusServiceUnavailable, - } - - err = func() error { - req, err := http.ReadRequest(br) - if err != nil { - return err - } - - if log.IsLevelEnabled(logger.TraceLevel) { - dump, _ := httputil.DumpRequest(req, false) - log.Trace(string(dump)) - } - - var tunnelID relay.TunnelID - if h.ingress != nil { - tunnelID = parseTunnelID(h.ingress.Get(ctx, req.Host)) - } - if tunnelID.IsZero() { - err := fmt.Errorf("no route to host %s", req.Host) - log.Error(err) - resp.StatusCode = http.StatusBadGateway - return resp.Write(rw) - } - if tunnelID.IsPrivate() { - err := fmt.Errorf("access denied: tunnel %s is private for host %s", tunnelID, req.Host) - log.Error(err) - resp.StatusCode = http.StatusBadGateway - return resp.Write(rw) - } - - log = log.WithFields(map[string]any{ - "host": req.Host, - "tunnel": tunnelID.String(), - }) - - if addr := getRealClientAddr(req, raddr); addr != raddr { - log = log.WithFields(map[string]any{ - "src": addr.String(), - }) - raddr = addr - } - - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(raddr.String())) - - cc, cid, err := getTunnelConn("tcp", h.pool, tunnelID, 3, log) - if err != nil { - log.Error(err) - return resp.Write(rw) - } - defer cc.Close() - - log.Debugf("new connection to tunnel %s(connector %s)", tunnelID, cid) - - var features []relay.Feature - af := &relay.AddrFeature{} - af.ParseFrom(raddr.String()) // src address - features = append(features, af) - - if host := req.Host; host != "" { - if h, _, _ := net.SplitHostPort(host); h == "" { - host = net.JoinHostPort(host, "80") - } - af := &relay.AddrFeature{} - af.ParseFrom(host) // dst address - features = append(features, af) - } - - (&relay.Response{ - Version: relay.Version1, - Status: relay.StatusOK, - Features: features, - }).WriteTo(cc) - - if err := req.Write(cc); err != nil { - log.Warnf("send request to tunnel %s: %v", tunnelID, err) - return resp.Write(rw) - } - - res, err := http.ReadResponse(bufio.NewReader(cc), req) - if err != nil { - log.Warnf("read response from tunnel %s: %v", tunnelID, err) - return resp.Write(rw) - } - defer res.Body.Close() - - return res.Write(rw) - }() - if err != nil { - // log.Error(err) - break - } - } - - return -} - -func getRealClientAddr(req *http.Request, raddr net.Addr) net.Addr { - if req == nil { - return nil - } - // cloudflare CDN - sip := req.Header.Get("CF-Connecting-IP") - if sip == "" { - ss := strings.Split(req.Header.Get("X-Forwarded-For"), ",") - if len(ss) > 0 && ss[0] != "" { - sip = ss[0] - } - } - if sip == "" { - sip = req.Header.Get("X-Real-Ip") - } - - ip := net.ParseIP(sip) - if ip == nil { - return raddr - } - - _, sp, _ := net.SplitHostPort(raddr.String()) - - port, _ := strconv.Atoi(sp) - - return &net.TCPAddr{ - IP: ip, - Port: port, - } + return h.ep.handle(ctx, conn) } diff --git a/handler/tunnel/handler.go b/handler/tunnel/handler.go index 16838b4..4ebfe52 100644 --- a/handler/tunnel/handler.go +++ b/handler/tunnel/handler.go @@ -38,9 +38,10 @@ type tunnelHandler struct { router *chain.Router md metadata options handler.Options - ep service.Service pool *ConnectorPool recorder recorder.RecorderObject + epSvc service.Service + ep *entrypoint } func NewHandler(opts ...handler.Option) handler.Handler { @@ -74,6 +75,13 @@ func (h *tunnelHandler) Init(md md.Metadata) (err error) { } } + h.ep = &entrypoint{ + pool: h.pool, + ingress: h.md.ingress, + log: h.options.Logger.WithFields(map[string]any{ + "kind": "entrypoint", + }), + } if err = h.initEntrypoint(); err != nil { return } @@ -115,24 +123,19 @@ func (h *tunnelHandler) initEntrypoint() (err error) { if err = epListener.Init(nil); err != nil { return } - epHandler := newEntrypointHandler( - h.pool, - h.md.ingress, - handler.ServiceOption(serviceName), - handler.LoggerOption(log.WithFields(map[string]any{ - "kind": "handler", - })), - ) + epHandler := &entrypointHandler{ + ep: h.ep, + } if err = epHandler.Init(nil); err != nil { return } - h.ep = xservice.NewService( + h.epSvc = xservice.NewService( serviceName, epListener, epHandler, xservice.LoggerOption(log), ) - go h.ep.Serve() - log.Infof("entrypoint: %s", h.ep.Addr()) + go h.epSvc.Serve() + log.Infof("entrypoint: %s", h.epSvc.Addr()) return } diff --git a/handler/tunnel/metadata.go b/handler/tunnel/metadata.go index 05f9d62..4d5868e 100644 --- a/handler/tunnel/metadata.go +++ b/handler/tunnel/metadata.go @@ -16,19 +16,16 @@ import ( type metadata struct { readTimeout time.Duration - noDelay bool - hash string - directTunnel bool entryPoint string - entryPointProxyProtocol int entryPointID relay.TunnelID + entryPointProxyProtocol int + directTunnel bool ingress ingress.Ingress muxCfg *mux.Config } func (h *tunnelHandler) parseMetadata(md mdata.Metadata) (err error) { h.md.readTimeout = mdutil.GetDuration(md, "readTimeout") - h.md.noDelay = mdutil.GetBool(md, "nodelay") h.md.directTunnel = mdutil.GetBool(md, "tunnel.direct") h.md.entryPoint = mdutil.GetString(md, "entrypoint") @@ -71,7 +68,5 @@ func (h *tunnelHandler) parseMetadata(md mdata.Metadata) (err error) { h.md.muxCfg.Version = 2 } - h.md.hash = mdutil.GetString(md, "hash") - return }