fix http traffic forwarding

This commit is contained in:
ginuerzh 2023-10-18 14:32:42 +08:00
parent 5d57852c8a
commit 3f3deb98b8
3 changed files with 101 additions and 69 deletions

View File

@ -2,6 +2,7 @@ package local
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
@ -182,6 +183,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log logger.Logger) (err error) {
br := bufio.NewReader(rw)
var cc net.Conn
for {
resp := &http.Response{
ProtoMajor: 1,
@ -193,6 +195,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
err = func() error {
req, err := http.ReadRequest(br)
if err != nil {
log.Errorf("read http request: %v", err)
return err
}
@ -229,8 +232,20 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
}
ctx = auth_util.ContextWithID(ctx, auth_util.ID(id))
}
if httpSettings := target.Options().HTTP; httpSettings != nil {
if httpSettings.Host != "" {
req.Host = httpSettings.Host
}
for k, v := range httpSettings.Header {
req.Header.Set(k, v)
}
}
if log.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(req, false)
log.Trace(string(dump))
}
cc, err := h.router.Dial(ctx, "tcp", target.Addr)
cc, err = h.router.Dial(ctx, "tcp", target.Addr)
if err != nil {
// TODO: the router itself may be failed due to the failed node in the router,
// the dead marker may be a wrong operation.
@ -243,9 +258,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
if marker := target.Marker(); marker != nil {
marker.Reset()
}
defer cc.Close()
log.Debugf("new connection to node %s(%s)", target.Name, target.Addr)
log.Debugf("connection to node %s(%s)", target.Name, target.Addr)
if tlsSettings := target.Options().TLS; tlsSettings != nil {
cc = tls.Client(cc, &tls.Config{
@ -254,36 +268,30 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
})
}
if httpSettings := target.Options().HTTP; httpSettings != nil {
if httpSettings.Host != "" {
req.Host = httpSettings.Host
}
for k, v := range httpSettings.Header {
req.Header.Set(k, v)
}
}
if log.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(req, false)
log.Trace(string(dump))
}
if err := req.Write(cc); err != nil {
log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err)
return resp.Write(rw)
}
if req.Header.Get("Upgrade") == "websocket" {
err := xnet.Transport(cc, xio.NewReadWriter(br, rw))
var buf bytes.Buffer
req.Write(&buf)
err := xnet.Transport(cc, xio.NewReadWriter(io.MultiReader(&buf, br), rw))
if err == nil {
err = io.EOF
}
return err
}
go func() {
defer cc.Close()
if err := req.Write(cc); err != nil {
log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err)
resp.Write(rw)
return
}
res, err := http.ReadResponse(bufio.NewReader(cc), req)
if err != nil {
log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err)
return resp.Write(rw)
resp.Write(rw)
return
}
if log.IsLevelEnabled(logger.TraceLevel) {
@ -291,10 +299,18 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
log.Trace(string(dump))
}
return res.Write(rw)
if err = res.Write(rw); err != nil {
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
}
}()
return nil
}()
if err != nil {
// log.Error(err)
if cc != nil {
cc.Close()
}
break
}
}

View File

@ -2,6 +2,7 @@ package remote
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
@ -183,6 +184,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) {
br := bufio.NewReader(rw)
var cc net.Conn
for {
resp := &http.Response{
@ -195,6 +197,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
err = func() error {
req, err := http.ReadRequest(br)
if err != nil {
log.Errorf("read http request: %v", err)
return err
}
@ -231,8 +234,20 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
}
ctx = auth_util.ContextWithID(ctx, auth_util.ID(id))
}
if httpSettings := target.Options().HTTP; httpSettings != nil {
if httpSettings.Host != "" {
req.Host = httpSettings.Host
}
for k, v := range httpSettings.Header {
req.Header.Set(k, v)
}
}
if log.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(req, false)
log.Trace(string(dump))
}
cc, err := h.router.Dial(ctx, "tcp", target.Addr)
cc, err = h.router.Dial(ctx, "tcp", target.Addr)
if err != nil {
// TODO: the router itself may be failed due to the failed node in the router,
// the dead marker may be a wrong operation.
@ -245,7 +260,6 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
if marker := target.Marker(); marker != nil {
marker.Reset()
}
defer cc.Close()
log.Debugf("new connection to node %s(%s)", target.Name, target.Addr)
@ -256,39 +270,32 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
})
}
if httpSettings := target.Options().HTTP; httpSettings != nil {
if httpSettings.Host != "" {
req.Host = httpSettings.Host
}
for k, v := range httpSettings.Header {
req.Header.Set(k, v)
}
}
if log.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(req, false)
log.Trace(string(dump))
}
cc = proxyproto.WrapClientConn(h.md.proxyProtocol, remoteAddr, localAddr, cc)
if err := req.Write(cc); err != nil {
log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err)
return resp.Write(rw)
}
if req.Header.Get("Upgrade") == "websocket" {
err := xnet.Transport(cc, xio.NewReadWriter(br, rw))
var buf bytes.Buffer
req.Write(&buf)
err := xnet.Transport(cc, xio.NewReadWriter(io.MultiReader(&buf, br), rw))
if err == nil {
err = io.EOF
}
return err
}
go func() {
defer cc.Close()
if err := req.Write(cc); err != nil {
log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err)
resp.Write(rw)
return
}
res, err := http.ReadResponse(bufio.NewReader(cc), req)
if err != nil {
log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err)
return resp.Write(rw)
resp.Write(rw)
return
}
if log.IsLevelEnabled(logger.TraceLevel) {
@ -296,9 +303,18 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
log.Trace(string(dump))
}
return res.Write(rw)
if err = res.Write(rw); err != nil {
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
}
}()
return nil
}()
if err != nil {
if cc != nil {
cc.Close()
}
break
}
}

View File

@ -43,8 +43,8 @@ func WrapConn(limiter limiter.TrafficLimiter, c net.Conn) net.Conn {
func (c *serverConn) getInLimiter(addr net.Addr) limiter.Limiter {
now := time.Now().UnixNano()
// cache the limiter for 1s
if c.limiter != nil && time.Duration(now-c.expIn) > time.Second {
// cache the limiter for 60s
if c.limiter != nil && time.Duration(now-c.expIn) > 60*time.Second {
c.limiterIn = c.limiter.In(addr.String())
c.expIn = now
}
@ -53,8 +53,8 @@ func (c *serverConn) getInLimiter(addr net.Addr) limiter.Limiter {
func (c *serverConn) getOutLimiter(addr net.Addr) limiter.Limiter {
now := time.Now().UnixNano()
// cache the limiter for 1s
if c.limiter != nil && time.Duration(now-c.expOut) > time.Second {
// cache the limiter for 60s
if c.limiter != nil && time.Duration(now-c.expOut) > 60*time.Second {
c.limiterOut = c.limiter.Out(addr.String())
c.expOut = now
}