http: fix non-connect method request handler

This commit is contained in:
ginuerzh 2024-06-11 21:50:11 +08:00
parent e793b2743b
commit 784e4b2b01

View File

@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"hash/crc32" "hash/crc32"
"io"
"net" "net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
@ -148,7 +149,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
fields := map[string]any{ fields := map[string]any{
"dst": addr, "dst": addr,
} }
if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log); u != "" { if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")); u != "" {
fields["user"] = u fields["user"] = u
} }
log = log.WithFields(fields) log = log.WithFields(fields)
@ -222,26 +223,6 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
} }
defer cc.Close() defer cc.Close()
if req.Method == http.MethodConnect {
resp.StatusCode = http.StatusOK
resp.Status = "200 Connection established"
if log.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpResponse(resp, false)
log.Trace(string(dump))
}
if err = resp.Write(conn); err != nil {
log.Error(err)
return err
}
} else {
req.Header.Del("Proxy-Connection")
if err = req.Write(cc); err != nil {
log.Error(err)
return err
}
}
rw := traffic_wrapper.WrapReadWriter(h.options.Limiter, conn, rw := traffic_wrapper.WrapReadWriter(h.options.Limiter, conn,
traffic.NetworkOption(network), traffic.NetworkOption(network),
traffic.AddrOption(addr), traffic.AddrOption(addr),
@ -256,6 +237,22 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
rw = stats_wrapper.WrapReadWriter(rw, pstats) rw = stats_wrapper.WrapReadWriter(rw, pstats)
} }
if req.Method != http.MethodConnect {
return h.handleProxy(rw, cc, req, log)
}
resp.StatusCode = http.StatusOK
resp.Status = "200 Connection established"
if log.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpResponse(resp, false)
log.Trace(string(dump))
}
if err = resp.Write(rw); err != nil {
log.Error(err)
return err
}
start := time.Now() start := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), addr) log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
netpkg.Transport(rw, cc) netpkg.Transport(rw, cc)
@ -266,6 +263,49 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
return nil return nil
} }
func (h *httpHandler) handleProxy(rw, cc io.ReadWriter, req *http.Request, log logger.Logger) (err error) {
req.Header.Del("Proxy-Connection")
if err = req.Write(cc); err != nil {
log.Error(err)
return err
}
ch := make(chan error, 1)
go func() {
ch <- netpkg.CopyBuffer(rw, cc, 32*1024)
}()
for {
err := func() error {
req, err := http.ReadRequest(bufio.NewReader(rw))
if err != nil {
return err
}
if log.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(req, false)
log.Trace(string(dump))
}
req.Header.Del("Proxy-Connection")
if err = req.Write(cc); err != nil {
return err
}
return nil
}()
ch <- err
if err != nil {
break
}
}
return <-ch
}
func (h *httpHandler) decodeServerName(s string) (string, error) { func (h *httpHandler) decodeServerName(s string) (string, error) {
b, err := base64.RawURLEncoding.DecodeString(s) b, err := base64.RawURLEncoding.DecodeString(s)
if err != nil { if err != nil {
@ -284,7 +324,7 @@ func (h *httpHandler) decodeServerName(s string) (string, error) {
return string(v), nil return string(v), nil
} }
func (h *httpHandler) basicProxyAuth(proxyAuth string, log logger.Logger) (username, password string, ok bool) { func (h *httpHandler) basicProxyAuth(proxyAuth string) (username, password string, ok bool) {
if proxyAuth == "" { if proxyAuth == "" {
return return
} }
@ -306,7 +346,7 @@ func (h *httpHandler) basicProxyAuth(proxyAuth string, log logger.Logger) (usern
} }
func (h *httpHandler) authenticate(ctx context.Context, conn net.Conn, req *http.Request, resp *http.Response, log logger.Logger) (id string, ok bool) { func (h *httpHandler) authenticate(ctx context.Context, conn net.Conn, req *http.Request, resp *http.Response, log logger.Logger) (id string, ok bool) {
u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log) u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"))
if h.options.Auther == nil { if h.options.Auther == nil {
return "", true return "", true
} }