diff --git a/handler/http/handler.go b/handler/http/handler.go index 149e349..b6887be 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "hash/crc32" + "io" "net" "net/http" "net/http/httputil" @@ -148,7 +149,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt fields := map[string]any{ "dst": addr, } - if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log); u != "" { + if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")); u != "" { fields["user"] = u } log = log.WithFields(fields) @@ -222,26 +223,6 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt } defer cc.Close() - if req.Method == http.MethodConnect { - resp.StatusCode = http.StatusOK - resp.Status = "200 Connection established" - - if log.IsLevelEnabled(logger.TraceLevel) { - dump, _ := httputil.DumpResponse(resp, false) - log.Trace(string(dump)) - } - if err = resp.Write(conn); err != nil { - log.Error(err) - return err - } - } else { - req.Header.Del("Proxy-Connection") - if err = req.Write(cc); err != nil { - log.Error(err) - return err - } - } - rw := traffic_wrapper.WrapReadWriter(h.options.Limiter, conn, traffic.NetworkOption(network), traffic.AddrOption(addr), @@ -256,6 +237,22 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt rw = stats_wrapper.WrapReadWriter(rw, pstats) } + if req.Method != http.MethodConnect { + return h.handleProxy(rw, cc, req, log) + } + + resp.StatusCode = http.StatusOK + resp.Status = "200 Connection established" + + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpResponse(resp, false) + log.Trace(string(dump)) + } + if err = resp.Write(rw); err != nil { + log.Error(err) + return err + } + start := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), addr) netpkg.Transport(rw, cc) @@ -266,6 +263,49 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt return nil } +func (h *httpHandler) handleProxy(rw, cc io.ReadWriter, req *http.Request, log logger.Logger) (err error) { + req.Header.Del("Proxy-Connection") + + if err = req.Write(cc); err != nil { + log.Error(err) + return err + } + + ch := make(chan error, 1) + + go func() { + ch <- netpkg.CopyBuffer(rw, cc, 32*1024) + }() + + for { + err := func() error { + req, err := http.ReadRequest(bufio.NewReader(rw)) + if err != nil { + return err + } + + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpRequest(req, false) + log.Trace(string(dump)) + } + + req.Header.Del("Proxy-Connection") + + if err = req.Write(cc); err != nil { + return err + } + return nil + }() + ch <- err + + if err != nil { + break + } + } + + return <-ch +} + func (h *httpHandler) decodeServerName(s string) (string, error) { b, err := base64.RawURLEncoding.DecodeString(s) if err != nil { @@ -284,7 +324,7 @@ func (h *httpHandler) decodeServerName(s string) (string, error) { return string(v), nil } -func (h *httpHandler) basicProxyAuth(proxyAuth string, log logger.Logger) (username, password string, ok bool) { +func (h *httpHandler) basicProxyAuth(proxyAuth string) (username, password string, ok bool) { if proxyAuth == "" { return } @@ -306,7 +346,7 @@ func (h *httpHandler) basicProxyAuth(proxyAuth string, log logger.Logger) (usern } func (h *httpHandler) authenticate(ctx context.Context, conn net.Conn, req *http.Request, resp *http.Response, log logger.Logger) (id string, ok bool) { - u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log) + u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")) if h.options.Auther == nil { return "", true }