From 3f3deb98b816b7466098cbbbd834171d694331f8 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 18 Oct 2023 14:32:42 +0800 Subject: [PATCH] fix http traffic forwarding --- handler/forward/local/handler.go | 82 ++++++++++++++++++------------- handler/forward/remote/handler.go | 80 ++++++++++++++++++------------ limiter/traffic/wrapper/conn.go | 8 +-- 3 files changed, 101 insertions(+), 69 deletions(-) diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index 891829d..9a86b6e 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -2,6 +2,7 @@ package local import ( "bufio" + "bytes" "context" "crypto/tls" "errors" @@ -182,6 +183,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log logger.Logger) (err error) { br := bufio.NewReader(rw) + var cc net.Conn for { resp := &http.Response{ ProtoMajor: 1, @@ -193,6 +195,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l err = func() error { req, err := http.ReadRequest(br) if err != nil { + log.Errorf("read http request: %v", err) return err } @@ -229,8 +232,20 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l } ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) } + if httpSettings := target.Options().HTTP; httpSettings != nil { + if httpSettings.Host != "" { + req.Host = httpSettings.Host + } + for k, v := range httpSettings.Header { + req.Header.Set(k, v) + } + } + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpRequest(req, false) + log.Trace(string(dump)) + } - cc, err := h.router.Dial(ctx, "tcp", target.Addr) + cc, err = h.router.Dial(ctx, "tcp", target.Addr) if err != nil { // TODO: the router itself may be failed due to the failed node in the router, // the dead marker may be a wrong operation. @@ -243,9 +258,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l if marker := target.Marker(); marker != nil { marker.Reset() } - defer cc.Close() - log.Debugf("new connection to node %s(%s)", target.Name, target.Addr) + log.Debugf("connection to node %s(%s)", target.Name, target.Addr) if tlsSettings := target.Options().TLS; tlsSettings != nil { cc = tls.Client(cc, &tls.Config{ @@ -254,47 +268,49 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l }) } - if httpSettings := target.Options().HTTP; httpSettings != nil { - if httpSettings.Host != "" { - req.Host = httpSettings.Host - } - for k, v := range httpSettings.Header { - req.Header.Set(k, v) - } - } - - if log.IsLevelEnabled(logger.TraceLevel) { - dump, _ := httputil.DumpRequest(req, false) - log.Trace(string(dump)) - } - if err := req.Write(cc); err != nil { - log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err) - return resp.Write(rw) - } - if req.Header.Get("Upgrade") == "websocket" { - err := xnet.Transport(cc, xio.NewReadWriter(br, rw)) + var buf bytes.Buffer + req.Write(&buf) + err := xnet.Transport(cc, xio.NewReadWriter(io.MultiReader(&buf, br), rw)) if err == nil { err = io.EOF } return err } - res, err := http.ReadResponse(bufio.NewReader(cc), req) - if err != nil { - log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err) - return resp.Write(rw) - } + go func() { + defer cc.Close() - if log.IsLevelEnabled(logger.TraceLevel) { - dump, _ := httputil.DumpResponse(res, false) - log.Trace(string(dump)) - } + if err := req.Write(cc); err != nil { + log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err) + resp.Write(rw) + return + } - return res.Write(rw) + res, err := http.ReadResponse(bufio.NewReader(cc), req) + if err != nil { + log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err) + resp.Write(rw) + return + } + + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpResponse(res, false) + log.Trace(string(dump)) + } + + if err = res.Write(rw); err != nil { + log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err) + } + }() + + return nil }() + if err != nil { - // log.Error(err) + if cc != nil { + cc.Close() + } break } } diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index 7c2fe75..33c6f53 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -2,6 +2,7 @@ package remote import ( "bufio" + "bytes" "context" "crypto/tls" "errors" @@ -183,6 +184,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) { br := bufio.NewReader(rw) + var cc net.Conn for { resp := &http.Response{ @@ -195,6 +197,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot err = func() error { req, err := http.ReadRequest(br) if err != nil { + log.Errorf("read http request: %v", err) return err } @@ -231,8 +234,20 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot } ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) } + if httpSettings := target.Options().HTTP; httpSettings != nil { + if httpSettings.Host != "" { + req.Host = httpSettings.Host + } + for k, v := range httpSettings.Header { + req.Header.Set(k, v) + } + } + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpRequest(req, false) + log.Trace(string(dump)) + } - cc, err := h.router.Dial(ctx, "tcp", target.Addr) + cc, err = h.router.Dial(ctx, "tcp", target.Addr) if err != nil { // TODO: the router itself may be failed due to the failed node in the router, // the dead marker may be a wrong operation. @@ -245,7 +260,6 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot if marker := target.Marker(); marker != nil { marker.Reset() } - defer cc.Close() log.Debugf("new connection to node %s(%s)", target.Name, target.Addr) @@ -256,49 +270,51 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot }) } - if httpSettings := target.Options().HTTP; httpSettings != nil { - if httpSettings.Host != "" { - req.Host = httpSettings.Host - } - for k, v := range httpSettings.Header { - req.Header.Set(k, v) - } - } - - if log.IsLevelEnabled(logger.TraceLevel) { - dump, _ := httputil.DumpRequest(req, false) - log.Trace(string(dump)) - } - cc = proxyproto.WrapClientConn(h.md.proxyProtocol, remoteAddr, localAddr, cc) - if err := req.Write(cc); err != nil { - log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err) - return resp.Write(rw) - } - if req.Header.Get("Upgrade") == "websocket" { - err := xnet.Transport(cc, xio.NewReadWriter(br, rw)) + var buf bytes.Buffer + req.Write(&buf) + err := xnet.Transport(cc, xio.NewReadWriter(io.MultiReader(&buf, br), rw)) if err == nil { err = io.EOF } return err } - res, err := http.ReadResponse(bufio.NewReader(cc), req) - if err != nil { - log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err) - return resp.Write(rw) - } + go func() { + defer cc.Close() - if log.IsLevelEnabled(logger.TraceLevel) { - dump, _ := httputil.DumpResponse(res, false) - log.Trace(string(dump)) - } + if err := req.Write(cc); err != nil { + log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err) + resp.Write(rw) + return + } - return res.Write(rw) + res, err := http.ReadResponse(bufio.NewReader(cc), req) + if err != nil { + log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err) + resp.Write(rw) + return + } + + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpResponse(res, false) + log.Trace(string(dump)) + } + + if err = res.Write(rw); err != nil { + log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err) + } + }() + + return nil }() + if err != nil { + if cc != nil { + cc.Close() + } break } } diff --git a/limiter/traffic/wrapper/conn.go b/limiter/traffic/wrapper/conn.go index 4cec7e0..3671041 100644 --- a/limiter/traffic/wrapper/conn.go +++ b/limiter/traffic/wrapper/conn.go @@ -43,8 +43,8 @@ func WrapConn(limiter limiter.TrafficLimiter, c net.Conn) net.Conn { func (c *serverConn) getInLimiter(addr net.Addr) limiter.Limiter { now := time.Now().UnixNano() - // cache the limiter for 1s - if c.limiter != nil && time.Duration(now-c.expIn) > time.Second { + // cache the limiter for 60s + if c.limiter != nil && time.Duration(now-c.expIn) > 60*time.Second { c.limiterIn = c.limiter.In(addr.String()) c.expIn = now } @@ -53,8 +53,8 @@ func (c *serverConn) getInLimiter(addr net.Addr) limiter.Limiter { func (c *serverConn) getOutLimiter(addr net.Addr) limiter.Limiter { now := time.Now().UnixNano() - // cache the limiter for 1s - if c.limiter != nil && time.Duration(now-c.expOut) > time.Second { + // cache the limiter for 60s + if c.limiter != nil && time.Duration(now-c.expOut) > 60*time.Second { c.limiterOut = c.limiter.Out(addr.String()) c.expOut = now }