From 43036f8e644274099fbb2b1396f43b208b851631 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sun, 29 Jan 2023 23:32:13 +0800 Subject: [PATCH] fix http traffic forwarding --- handler/forward/local/handler.go | 89 +++++++++++++++++++++---------- handler/forward/remote/handler.go | 88 +++++++++++++++++++++--------- internal/net/transport.go | 8 +-- 3 files changed, 128 insertions(+), 57 deletions(-) diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index 786754d..84eb9b7 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -15,7 +15,7 @@ import ( "github.com/go-gost/core/handler" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - netpkg "github.com/go-gost/x/internal/net" + xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/util/forward" "github.com/go-gost/x/registry" ) @@ -136,34 +136,16 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand marker.Reset() } - if protocol == forward.ProtoHTTP && - target.Options().HTTP != nil { - req, err := http.ReadRequest(bufio.NewReader(rw)) - if err != nil { - return err - } - - httpSettings := target.Options().HTTP - if httpSettings.Host != "" { - req.Host = httpSettings.Host - } - for k, v := range httpSettings.Header { - req.Header.Set(k, v) - } - - if log.IsLevelEnabled(logger.TraceLevel) { - dump, _ := httputil.DumpRequest(req, false) - log.Trace(string(dump)) - } - - if err := req.Write(cc); err != nil { - return err - } - } - t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr) - netpkg.Transport(rw, cc) + + if protocol == forward.ProtoHTTP && + target.Options().HTTP != nil { + h.handleHTTP(ctx, rw, cc, target.Options().HTTP, log) + } else { + xnet.Transport(rw, cc) + } + log.WithFields(map[string]any{ "duration": time.Since(t), }).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr) @@ -182,3 +164,56 @@ func (h *forwardHandler) checkRateLimit(addr net.Addr) bool { return true } + +func (h *forwardHandler) handleHTTP(ctx context.Context, src, dst io.ReadWriter, httpSettings *chain.HTTPNodeSettings, log logger.Logger) error { + errc := make(chan error, 1) + go func() { + errc <- xnet.CopyBuffer(src, dst, 8192) + }() + + go func() { + br := bufio.NewReader(src) + for { + err := func() error { + req, err := http.ReadRequest(br) + if err != nil { + return err + } + + if httpSettings.Host != "" { + req.Host = httpSettings.Host + } + for k, v := range httpSettings.Header { + req.Header.Set(k, v) + } + + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpRequest(req, false) + log.Trace(string(dump)) + } + if err := req.Write(dst); err != nil { + return err + } + + if req.Header.Get("Upgrade") == "websocket" { + err := xnet.CopyBuffer(dst, src, 8192) + if err == nil { + err = io.EOF + } + return err + } + return nil + }() + if err != nil { + errc <- err + break + } + } + }() + + if err := <-errc; err != nil && err != io.EOF { + return err + } + + return nil +} diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index 8a1bdee..4760306 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -15,7 +15,7 @@ import ( "github.com/go-gost/core/handler" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - netpkg "github.com/go-gost/x/internal/net" + xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/util/forward" "github.com/go-gost/x/registry" ) @@ -128,33 +128,16 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand marker.Reset() } - if protocol == forward.ProtoHTTP && - target.Options().HTTP != nil { - req, err := http.ReadRequest(bufio.NewReader(rw)) - if err != nil { - return err - } - - httpSettings := target.Options().HTTP - if httpSettings.Host != "" { - req.Host = httpSettings.Host - } - for k, v := range httpSettings.Header { - req.Header.Set(k, v) - } - - if log.IsLevelEnabled(logger.TraceLevel) { - dump, _ := httputil.DumpRequest(req, false) - log.Trace(string(dump)) - } - if err := req.Write(cc); err != nil { - return err - } - } - t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr) - netpkg.Transport(rw, cc) + + if protocol == forward.ProtoHTTP && + target.Options().HTTP != nil { + h.handleHTTP(ctx, rw, cc, target.Options().HTTP, log) + } else { + xnet.Transport(rw, cc) + } + log.WithFields(map[string]any{ "duration": time.Since(t), }).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr) @@ -173,3 +156,56 @@ func (h *forwardHandler) checkRateLimit(addr net.Addr) bool { return true } + +func (h *forwardHandler) handleHTTP(ctx context.Context, src, dst io.ReadWriter, httpSettings *chain.HTTPNodeSettings, log logger.Logger) error { + errc := make(chan error, 1) + go func() { + errc <- xnet.CopyBuffer(src, dst, 8192) + }() + + go func() { + br := bufio.NewReader(src) + for { + err := func() error { + req, err := http.ReadRequest(br) + if err != nil { + return err + } + + if httpSettings.Host != "" { + req.Host = httpSettings.Host + } + for k, v := range httpSettings.Header { + req.Header.Set(k, v) + } + + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpRequest(req, false) + log.Trace(string(dump)) + } + if err := req.Write(dst); err != nil { + return err + } + + if req.Header.Get("Upgrade") == "websocket" { + err := xnet.CopyBuffer(dst, src, 8192) + if err == nil { + err = io.EOF + } + return err + } + return nil + }() + if err != nil { + errc <- err + break + } + } + }() + + if err := <-errc; err != nil && err != io.EOF { + return err + } + + return nil +} diff --git a/internal/net/transport.go b/internal/net/transport.go index 9b76542..375cd36 100644 --- a/internal/net/transport.go +++ b/internal/net/transport.go @@ -11,11 +11,11 @@ import ( func Transport(rw1, rw2 io.ReadWriter) error { errc := make(chan error, 1) go func() { - errc <- copyBuffer(rw1, rw2) + errc <- CopyBuffer(rw1, rw2, 8192) }() go func() { - errc <- copyBuffer(rw2, rw1) + errc <- CopyBuffer(rw2, rw1, 8192) }() if err := <-errc; err != nil && err != io.EOF { @@ -25,8 +25,8 @@ func Transport(rw1, rw2 io.ReadWriter) error { return nil } -func copyBuffer(dst io.Writer, src io.Reader) error { - buf := bufpool.Get(4 * 1024) +func CopyBuffer(dst io.Writer, src io.Reader, bufSize int) error { + buf := bufpool.Get(8192) defer bufpool.Put(buf) _, err := io.CopyBuffer(dst, src, *buf)