fix http traffic forwarding

This commit is contained in:
ginuerzh 2023-10-18 14:32:42 +08:00
parent 5d57852c8a
commit 3f3deb98b8
3 changed files with 101 additions and 69 deletions

View File

@ -2,6 +2,7 @@ package local
import ( import (
"bufio" "bufio"
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
@ -182,6 +183,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log logger.Logger) (err error) { func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log logger.Logger) (err error) {
br := bufio.NewReader(rw) br := bufio.NewReader(rw)
var cc net.Conn
for { for {
resp := &http.Response{ resp := &http.Response{
ProtoMajor: 1, ProtoMajor: 1,
@ -193,6 +195,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
err = func() error { err = func() error {
req, err := http.ReadRequest(br) req, err := http.ReadRequest(br)
if err != nil { if err != nil {
log.Errorf("read http request: %v", err)
return err return err
} }
@ -229,8 +232,20 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
} }
ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) ctx = auth_util.ContextWithID(ctx, auth_util.ID(id))
} }
if httpSettings := target.Options().HTTP; httpSettings != nil {
if httpSettings.Host != "" {
req.Host = httpSettings.Host
}
for k, v := range httpSettings.Header {
req.Header.Set(k, v)
}
}
if log.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(req, false)
log.Trace(string(dump))
}
cc, err := h.router.Dial(ctx, "tcp", target.Addr) cc, err = h.router.Dial(ctx, "tcp", target.Addr)
if err != nil { if err != nil {
// TODO: the router itself may be failed due to the failed node in the router, // TODO: the router itself may be failed due to the failed node in the router,
// the dead marker may be a wrong operation. // the dead marker may be a wrong operation.
@ -243,9 +258,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
if marker := target.Marker(); marker != nil { if marker := target.Marker(); marker != nil {
marker.Reset() marker.Reset()
} }
defer cc.Close()
log.Debugf("new connection to node %s(%s)", target.Name, target.Addr) log.Debugf("connection to node %s(%s)", target.Name, target.Addr)
if tlsSettings := target.Options().TLS; tlsSettings != nil { if tlsSettings := target.Options().TLS; tlsSettings != nil {
cc = tls.Client(cc, &tls.Config{ cc = tls.Client(cc, &tls.Config{
@ -254,47 +268,49 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
}) })
} }
if httpSettings := target.Options().HTTP; httpSettings != nil {
if httpSettings.Host != "" {
req.Host = httpSettings.Host
}
for k, v := range httpSettings.Header {
req.Header.Set(k, v)
}
}
if log.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(req, false)
log.Trace(string(dump))
}
if err := req.Write(cc); err != nil {
log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err)
return resp.Write(rw)
}
if req.Header.Get("Upgrade") == "websocket" { if req.Header.Get("Upgrade") == "websocket" {
err := xnet.Transport(cc, xio.NewReadWriter(br, rw)) var buf bytes.Buffer
req.Write(&buf)
err := xnet.Transport(cc, xio.NewReadWriter(io.MultiReader(&buf, br), rw))
if err == nil { if err == nil {
err = io.EOF err = io.EOF
} }
return err return err
} }
res, err := http.ReadResponse(bufio.NewReader(cc), req) go func() {
if err != nil { defer cc.Close()
log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err)
return resp.Write(rw)
}
if log.IsLevelEnabled(logger.TraceLevel) { if err := req.Write(cc); err != nil {
dump, _ := httputil.DumpResponse(res, false) log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err)
log.Trace(string(dump)) resp.Write(rw)
} return
}
return res.Write(rw) res, err := http.ReadResponse(bufio.NewReader(cc), req)
if err != nil {
log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err)
resp.Write(rw)
return
}
if log.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpResponse(res, false)
log.Trace(string(dump))
}
if err = res.Write(rw); err != nil {
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
}
}()
return nil
}() }()
if err != nil { if err != nil {
// log.Error(err) if cc != nil {
cc.Close()
}
break break
} }
} }

View File

@ -2,6 +2,7 @@ package remote
import ( import (
"bufio" "bufio"
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
@ -183,6 +184,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) { func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) {
br := bufio.NewReader(rw) br := bufio.NewReader(rw)
var cc net.Conn
for { for {
resp := &http.Response{ resp := &http.Response{
@ -195,6 +197,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
err = func() error { err = func() error {
req, err := http.ReadRequest(br) req, err := http.ReadRequest(br)
if err != nil { if err != nil {
log.Errorf("read http request: %v", err)
return err return err
} }
@ -231,8 +234,20 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
} }
ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) ctx = auth_util.ContextWithID(ctx, auth_util.ID(id))
} }
if httpSettings := target.Options().HTTP; httpSettings != nil {
if httpSettings.Host != "" {
req.Host = httpSettings.Host
}
for k, v := range httpSettings.Header {
req.Header.Set(k, v)
}
}
if log.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(req, false)
log.Trace(string(dump))
}
cc, err := h.router.Dial(ctx, "tcp", target.Addr) cc, err = h.router.Dial(ctx, "tcp", target.Addr)
if err != nil { if err != nil {
// TODO: the router itself may be failed due to the failed node in the router, // TODO: the router itself may be failed due to the failed node in the router,
// the dead marker may be a wrong operation. // the dead marker may be a wrong operation.
@ -245,7 +260,6 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
if marker := target.Marker(); marker != nil { if marker := target.Marker(); marker != nil {
marker.Reset() marker.Reset()
} }
defer cc.Close()
log.Debugf("new connection to node %s(%s)", target.Name, target.Addr) log.Debugf("new connection to node %s(%s)", target.Name, target.Addr)
@ -256,49 +270,51 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
}) })
} }
if httpSettings := target.Options().HTTP; httpSettings != nil {
if httpSettings.Host != "" {
req.Host = httpSettings.Host
}
for k, v := range httpSettings.Header {
req.Header.Set(k, v)
}
}
if log.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(req, false)
log.Trace(string(dump))
}
cc = proxyproto.WrapClientConn(h.md.proxyProtocol, remoteAddr, localAddr, cc) cc = proxyproto.WrapClientConn(h.md.proxyProtocol, remoteAddr, localAddr, cc)
if err := req.Write(cc); err != nil {
log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err)
return resp.Write(rw)
}
if req.Header.Get("Upgrade") == "websocket" { if req.Header.Get("Upgrade") == "websocket" {
err := xnet.Transport(cc, xio.NewReadWriter(br, rw)) var buf bytes.Buffer
req.Write(&buf)
err := xnet.Transport(cc, xio.NewReadWriter(io.MultiReader(&buf, br), rw))
if err == nil { if err == nil {
err = io.EOF err = io.EOF
} }
return err return err
} }
res, err := http.ReadResponse(bufio.NewReader(cc), req) go func() {
if err != nil { defer cc.Close()
log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err)
return resp.Write(rw)
}
if log.IsLevelEnabled(logger.TraceLevel) { if err := req.Write(cc); err != nil {
dump, _ := httputil.DumpResponse(res, false) log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err)
log.Trace(string(dump)) resp.Write(rw)
} return
}
return res.Write(rw) res, err := http.ReadResponse(bufio.NewReader(cc), req)
if err != nil {
log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err)
resp.Write(rw)
return
}
if log.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpResponse(res, false)
log.Trace(string(dump))
}
if err = res.Write(rw); err != nil {
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
}
}()
return nil
}() }()
if err != nil { if err != nil {
if cc != nil {
cc.Close()
}
break break
} }
} }

View File

@ -43,8 +43,8 @@ func WrapConn(limiter limiter.TrafficLimiter, c net.Conn) net.Conn {
func (c *serverConn) getInLimiter(addr net.Addr) limiter.Limiter { func (c *serverConn) getInLimiter(addr net.Addr) limiter.Limiter {
now := time.Now().UnixNano() now := time.Now().UnixNano()
// cache the limiter for 1s // cache the limiter for 60s
if c.limiter != nil && time.Duration(now-c.expIn) > time.Second { if c.limiter != nil && time.Duration(now-c.expIn) > 60*time.Second {
c.limiterIn = c.limiter.In(addr.String()) c.limiterIn = c.limiter.In(addr.String())
c.expIn = now c.expIn = now
} }
@ -53,8 +53,8 @@ func (c *serverConn) getInLimiter(addr net.Addr) limiter.Limiter {
func (c *serverConn) getOutLimiter(addr net.Addr) limiter.Limiter { func (c *serverConn) getOutLimiter(addr net.Addr) limiter.Limiter {
now := time.Now().UnixNano() now := time.Now().UnixNano()
// cache the limiter for 1s // cache the limiter for 60s
if c.limiter != nil && time.Duration(now-c.expOut) > time.Second { if c.limiter != nil && time.Duration(now-c.expOut) > 60*time.Second {
c.limiterOut = c.limiter.Out(addr.String()) c.limiterOut = c.limiter.Out(addr.String())
c.expOut = now c.expOut = now
} }