fix http traffic forwarding
This commit is contained in:
parent
5d57852c8a
commit
3f3deb98b8
@ -2,6 +2,7 @@ package local
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
@ -182,6 +183,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
|||||||
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log logger.Logger) (err error) {
|
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log logger.Logger) (err error) {
|
||||||
br := bufio.NewReader(rw)
|
br := bufio.NewReader(rw)
|
||||||
|
|
||||||
|
var cc net.Conn
|
||||||
for {
|
for {
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
ProtoMajor: 1,
|
ProtoMajor: 1,
|
||||||
@ -193,6 +195,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
|
|||||||
err = func() error {
|
err = func() error {
|
||||||
req, err := http.ReadRequest(br)
|
req, err := http.ReadRequest(br)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Errorf("read http request: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -229,8 +232,20 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
|
|||||||
}
|
}
|
||||||
ctx = auth_util.ContextWithID(ctx, auth_util.ID(id))
|
ctx = auth_util.ContextWithID(ctx, auth_util.ID(id))
|
||||||
}
|
}
|
||||||
|
if httpSettings := target.Options().HTTP; httpSettings != nil {
|
||||||
|
if httpSettings.Host != "" {
|
||||||
|
req.Host = httpSettings.Host
|
||||||
|
}
|
||||||
|
for k, v := range httpSettings.Header {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if log.IsLevelEnabled(logger.TraceLevel) {
|
||||||
|
dump, _ := httputil.DumpRequest(req, false)
|
||||||
|
log.Trace(string(dump))
|
||||||
|
}
|
||||||
|
|
||||||
cc, err := h.router.Dial(ctx, "tcp", target.Addr)
|
cc, err = h.router.Dial(ctx, "tcp", target.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: the router itself may be failed due to the failed node in the router,
|
// TODO: the router itself may be failed due to the failed node in the router,
|
||||||
// the dead marker may be a wrong operation.
|
// the dead marker may be a wrong operation.
|
||||||
@ -243,9 +258,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
|
|||||||
if marker := target.Marker(); marker != nil {
|
if marker := target.Marker(); marker != nil {
|
||||||
marker.Reset()
|
marker.Reset()
|
||||||
}
|
}
|
||||||
defer cc.Close()
|
|
||||||
|
|
||||||
log.Debugf("new connection to node %s(%s)", target.Name, target.Addr)
|
log.Debugf("connection to node %s(%s)", target.Name, target.Addr)
|
||||||
|
|
||||||
if tlsSettings := target.Options().TLS; tlsSettings != nil {
|
if tlsSettings := target.Options().TLS; tlsSettings != nil {
|
||||||
cc = tls.Client(cc, &tls.Config{
|
cc = tls.Client(cc, &tls.Config{
|
||||||
@ -254,47 +268,49 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if httpSettings := target.Options().HTTP; httpSettings != nil {
|
|
||||||
if httpSettings.Host != "" {
|
|
||||||
req.Host = httpSettings.Host
|
|
||||||
}
|
|
||||||
for k, v := range httpSettings.Header {
|
|
||||||
req.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if log.IsLevelEnabled(logger.TraceLevel) {
|
|
||||||
dump, _ := httputil.DumpRequest(req, false)
|
|
||||||
log.Trace(string(dump))
|
|
||||||
}
|
|
||||||
if err := req.Write(cc); err != nil {
|
|
||||||
log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err)
|
|
||||||
return resp.Write(rw)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Header.Get("Upgrade") == "websocket" {
|
if req.Header.Get("Upgrade") == "websocket" {
|
||||||
err := xnet.Transport(cc, xio.NewReadWriter(br, rw))
|
var buf bytes.Buffer
|
||||||
|
req.Write(&buf)
|
||||||
|
err := xnet.Transport(cc, xio.NewReadWriter(io.MultiReader(&buf, br), rw))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = io.EOF
|
err = io.EOF
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := http.ReadResponse(bufio.NewReader(cc), req)
|
go func() {
|
||||||
if err != nil {
|
defer cc.Close()
|
||||||
log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err)
|
|
||||||
return resp.Write(rw)
|
|
||||||
}
|
|
||||||
|
|
||||||
if log.IsLevelEnabled(logger.TraceLevel) {
|
if err := req.Write(cc); err != nil {
|
||||||
dump, _ := httputil.DumpResponse(res, false)
|
log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err)
|
||||||
log.Trace(string(dump))
|
resp.Write(rw)
|
||||||
}
|
return
|
||||||
|
}
|
||||||
|
|
||||||
return res.Write(rw)
|
res, err := http.ReadResponse(bufio.NewReader(cc), req)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err)
|
||||||
|
resp.Write(rw)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if log.IsLevelEnabled(logger.TraceLevel) {
|
||||||
|
dump, _ := httputil.DumpResponse(res, false)
|
||||||
|
log.Trace(string(dump))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = res.Write(rw); err != nil {
|
||||||
|
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// log.Error(err)
|
if cc != nil {
|
||||||
|
cc.Close()
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package remote
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
@ -183,6 +184,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
|||||||
|
|
||||||
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) {
|
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) {
|
||||||
br := bufio.NewReader(rw)
|
br := bufio.NewReader(rw)
|
||||||
|
var cc net.Conn
|
||||||
|
|
||||||
for {
|
for {
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
@ -195,6 +197,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
err = func() error {
|
err = func() error {
|
||||||
req, err := http.ReadRequest(br)
|
req, err := http.ReadRequest(br)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Errorf("read http request: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -231,8 +234,20 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
}
|
}
|
||||||
ctx = auth_util.ContextWithID(ctx, auth_util.ID(id))
|
ctx = auth_util.ContextWithID(ctx, auth_util.ID(id))
|
||||||
}
|
}
|
||||||
|
if httpSettings := target.Options().HTTP; httpSettings != nil {
|
||||||
|
if httpSettings.Host != "" {
|
||||||
|
req.Host = httpSettings.Host
|
||||||
|
}
|
||||||
|
for k, v := range httpSettings.Header {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if log.IsLevelEnabled(logger.TraceLevel) {
|
||||||
|
dump, _ := httputil.DumpRequest(req, false)
|
||||||
|
log.Trace(string(dump))
|
||||||
|
}
|
||||||
|
|
||||||
cc, err := h.router.Dial(ctx, "tcp", target.Addr)
|
cc, err = h.router.Dial(ctx, "tcp", target.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: the router itself may be failed due to the failed node in the router,
|
// TODO: the router itself may be failed due to the failed node in the router,
|
||||||
// the dead marker may be a wrong operation.
|
// the dead marker may be a wrong operation.
|
||||||
@ -245,7 +260,6 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
if marker := target.Marker(); marker != nil {
|
if marker := target.Marker(); marker != nil {
|
||||||
marker.Reset()
|
marker.Reset()
|
||||||
}
|
}
|
||||||
defer cc.Close()
|
|
||||||
|
|
||||||
log.Debugf("new connection to node %s(%s)", target.Name, target.Addr)
|
log.Debugf("new connection to node %s(%s)", target.Name, target.Addr)
|
||||||
|
|
||||||
@ -256,49 +270,51 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if httpSettings := target.Options().HTTP; httpSettings != nil {
|
|
||||||
if httpSettings.Host != "" {
|
|
||||||
req.Host = httpSettings.Host
|
|
||||||
}
|
|
||||||
for k, v := range httpSettings.Header {
|
|
||||||
req.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if log.IsLevelEnabled(logger.TraceLevel) {
|
|
||||||
dump, _ := httputil.DumpRequest(req, false)
|
|
||||||
log.Trace(string(dump))
|
|
||||||
}
|
|
||||||
|
|
||||||
cc = proxyproto.WrapClientConn(h.md.proxyProtocol, remoteAddr, localAddr, cc)
|
cc = proxyproto.WrapClientConn(h.md.proxyProtocol, remoteAddr, localAddr, cc)
|
||||||
|
|
||||||
if err := req.Write(cc); err != nil {
|
|
||||||
log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err)
|
|
||||||
return resp.Write(rw)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Header.Get("Upgrade") == "websocket" {
|
if req.Header.Get("Upgrade") == "websocket" {
|
||||||
err := xnet.Transport(cc, xio.NewReadWriter(br, rw))
|
var buf bytes.Buffer
|
||||||
|
req.Write(&buf)
|
||||||
|
err := xnet.Transport(cc, xio.NewReadWriter(io.MultiReader(&buf, br), rw))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = io.EOF
|
err = io.EOF
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := http.ReadResponse(bufio.NewReader(cc), req)
|
go func() {
|
||||||
if err != nil {
|
defer cc.Close()
|
||||||
log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err)
|
|
||||||
return resp.Write(rw)
|
|
||||||
}
|
|
||||||
|
|
||||||
if log.IsLevelEnabled(logger.TraceLevel) {
|
if err := req.Write(cc); err != nil {
|
||||||
dump, _ := httputil.DumpResponse(res, false)
|
log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err)
|
||||||
log.Trace(string(dump))
|
resp.Write(rw)
|
||||||
}
|
return
|
||||||
|
}
|
||||||
|
|
||||||
return res.Write(rw)
|
res, err := http.ReadResponse(bufio.NewReader(cc), req)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err)
|
||||||
|
resp.Write(rw)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if log.IsLevelEnabled(logger.TraceLevel) {
|
||||||
|
dump, _ := httputil.DumpResponse(res, false)
|
||||||
|
log.Trace(string(dump))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = res.Write(rw); err != nil {
|
||||||
|
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if cc != nil {
|
||||||
|
cc.Close()
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -43,8 +43,8 @@ func WrapConn(limiter limiter.TrafficLimiter, c net.Conn) net.Conn {
|
|||||||
|
|
||||||
func (c *serverConn) getInLimiter(addr net.Addr) limiter.Limiter {
|
func (c *serverConn) getInLimiter(addr net.Addr) limiter.Limiter {
|
||||||
now := time.Now().UnixNano()
|
now := time.Now().UnixNano()
|
||||||
// cache the limiter for 1s
|
// cache the limiter for 60s
|
||||||
if c.limiter != nil && time.Duration(now-c.expIn) > time.Second {
|
if c.limiter != nil && time.Duration(now-c.expIn) > 60*time.Second {
|
||||||
c.limiterIn = c.limiter.In(addr.String())
|
c.limiterIn = c.limiter.In(addr.String())
|
||||||
c.expIn = now
|
c.expIn = now
|
||||||
}
|
}
|
||||||
@ -53,8 +53,8 @@ func (c *serverConn) getInLimiter(addr net.Addr) limiter.Limiter {
|
|||||||
|
|
||||||
func (c *serverConn) getOutLimiter(addr net.Addr) limiter.Limiter {
|
func (c *serverConn) getOutLimiter(addr net.Addr) limiter.Limiter {
|
||||||
now := time.Now().UnixNano()
|
now := time.Now().UnixNano()
|
||||||
// cache the limiter for 1s
|
// cache the limiter for 60s
|
||||||
if c.limiter != nil && time.Duration(now-c.expOut) > time.Second {
|
if c.limiter != nil && time.Duration(now-c.expOut) > 60*time.Second {
|
||||||
c.limiterOut = c.limiter.Out(addr.String())
|
c.limiterOut = c.limiter.Out(addr.String())
|
||||||
c.expOut = now
|
c.expOut = now
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user