diff --git a/chain/hop.go b/chain/hop.go index 1e3ef2b..b9d13c3 100644 --- a/chain/hop.go +++ b/chain/hop.go @@ -97,7 +97,7 @@ func (p *chainHop) Select(ctx context.Context, opts ...chain.SelectOption) *chai if vhost == host || vhost[0] == '.' && strings.HasSuffix(host, vhost[1:]) { filters = append(filters, node) - p.options.logger.Debugf("find node for host: %s(match %s)", host, vhost) + p.options.logger.Debugf("find node for host: %s -> %s(%s)", host, node.Name, node.Addr) } } if len(filters) == 0 { diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index 84eb9b7..e668874 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/http/httputil" + "sync" "time" "github.com/go-gost/core/chain" @@ -97,6 +98,11 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand } } + if protocol == forward.ProtoHTTP { + h.handleHTTP(ctx, rw, log) + return nil + } + if _, _, err := net.SplitHostPort(host); err != nil { host = net.JoinHostPort(host, "0") } @@ -138,14 +144,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr) - - if protocol == forward.ProtoHTTP && - target.Options().HTTP != nil { - h.handleHTTP(ctx, rw, cc, target.Options().HTTP, log) - } else { - xnet.Transport(rw, cc) - } - + xnet.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr) @@ -153,6 +152,110 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand return nil } +func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log logger.Logger) (err error) { + br := bufio.NewReader(rw) + var connPool sync.Map + + resp := &http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + StatusCode: http.StatusServiceUnavailable, + } + + for { + err = func() error { + req, err := http.ReadRequest(br) + if err != nil { + return err + } + + var target *chain.Node + if h.hop != nil { + target = h.hop.Select(ctx, + chain.HostSelectOption(req.Host), + chain.ProtocolSelectOption(forward.ProtoHTTP), + ) + } + if target == nil { + return resp.Write(rw) + } + + log = log.WithFields(map[string]any{ + "dst": target.Addr, + }) + + // log.Debugf("%s >> %s", conn.RemoteAddr(), target.Addr) + + var cc net.Conn + if v, ok := connPool.Load(target); ok { + cc = v.(net.Conn) + log.Debugf("connection to node %s(%s) found in pool", target.Name, target.Addr) + } + if cc == nil { + cc, err = h.router.Dial(ctx, "tcp", target.Addr) + if err != nil { + // TODO: the router itself may be failed due to the failed node in the router, + // the dead marker may be a wrong operation. + if marker := target.Marker(); marker != nil { + marker.Mark() + } + return resp.Write(rw) + } + if marker := target.Marker(); marker != nil { + marker.Reset() + } + connPool.Store(target, cc) + log.Debugf("new connection to node %s(%s)", target.Name, target.Addr) + + go func() { + defer cc.Close() + xnet.CopyBuffer(rw, cc, 8192) + connPool.Delete(target) + }() + } + + if httpSettings := target.Options().HTTP; httpSettings != nil { + if httpSettings.Host != "" { + req.Host = httpSettings.Host + } + for k, v := range httpSettings.Header { + req.Header.Set(k, v) + } + } + + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpRequest(req, false) + log.Trace(string(dump)) + } + if err := req.Write(cc); err != nil { + return resp.Write(rw) + } + + if req.Header.Get("Upgrade") == "websocket" { + err := xnet.CopyBuffer(cc, br, 8192) + if err == nil { + err = io.EOF + } + return err + } + + return nil + }() + if err != nil { + break + } + } + + connPool.Range(func(key, value any) bool { + if value != nil { + value.(net.Conn).Close() + } + return true + }) + + return +} + func (h *forwardHandler) checkRateLimit(addr net.Addr) bool { if h.options.RateLimiter == nil { return true @@ -164,56 +267,3 @@ func (h *forwardHandler) checkRateLimit(addr net.Addr) bool { return true } - -func (h *forwardHandler) handleHTTP(ctx context.Context, src, dst io.ReadWriter, httpSettings *chain.HTTPNodeSettings, log logger.Logger) error { - errc := make(chan error, 1) - go func() { - errc <- xnet.CopyBuffer(src, dst, 8192) - }() - - go func() { - br := bufio.NewReader(src) - for { - err := func() error { - req, err := http.ReadRequest(br) - if err != nil { - return err - } - - if httpSettings.Host != "" { - req.Host = httpSettings.Host - } - for k, v := range httpSettings.Header { - req.Header.Set(k, v) - } - - if log.IsLevelEnabled(logger.TraceLevel) { - dump, _ := httputil.DumpRequest(req, false) - log.Trace(string(dump)) - } - if err := req.Write(dst); err != nil { - return err - } - - if req.Header.Get("Upgrade") == "websocket" { - err := xnet.CopyBuffer(dst, src, 8192) - if err == nil { - err = io.EOF - } - return err - } - return nil - }() - if err != nil { - errc <- err - break - } - } - }() - - if err := <-errc; err != nil && err != io.EOF { - return err - } - - return nil -} diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index 4760306..b13edb8 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/http/httputil" + "sync" "time" "github.com/go-gost/core/chain" @@ -94,6 +95,12 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand rw, host, protocol, _ = forward.Sniffing(ctx, conn) } } + + if protocol == forward.ProtoHTTP { + h.handleHTTP(ctx, rw, log) + return nil + } + var target *chain.Node if h.hop != nil { target = h.hop.Select(ctx, @@ -130,14 +137,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr) - - if protocol == forward.ProtoHTTP && - target.Options().HTTP != nil { - h.handleHTTP(ctx, rw, cc, target.Options().HTTP, log) - } else { - xnet.Transport(rw, cc) - } - + xnet.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr) @@ -145,6 +145,110 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand return nil } +func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log logger.Logger) (err error) { + br := bufio.NewReader(rw) + var connPool sync.Map + + resp := &http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + StatusCode: http.StatusServiceUnavailable, + } + + for { + err = func() error { + req, err := http.ReadRequest(br) + if err != nil { + return err + } + + var target *chain.Node + if h.hop != nil { + target = h.hop.Select(ctx, + chain.HostSelectOption(req.Host), + chain.ProtocolSelectOption(forward.ProtoHTTP), + ) + } + if target == nil { + return resp.Write(rw) + } + + log = log.WithFields(map[string]any{ + "dst": target.Addr, + }) + + // log.Debugf("%s >> %s", conn.RemoteAddr(), target.Addr) + + var cc net.Conn + if v, ok := connPool.Load(target); ok { + cc = v.(net.Conn) + log.Debugf("reuse connection to node %s(%s)", target.Name, target.Addr) + } + if cc == nil { + cc, err = h.router.Dial(ctx, "tcp", target.Addr) + if err != nil { + // TODO: the router itself may be failed due to the failed node in the router, + // the dead marker may be a wrong operation. + if marker := target.Marker(); marker != nil { + marker.Mark() + } + return resp.Write(rw) + } + if marker := target.Marker(); marker != nil { + marker.Reset() + } + connPool.Store(target, cc) + log.Debugf("new connection to node %s(%s)", target.Name, target.Addr) + + go func() { + defer cc.Close() + xnet.CopyBuffer(rw, cc, 8192) + connPool.Delete(target) + }() + } + + if httpSettings := target.Options().HTTP; httpSettings != nil { + if httpSettings.Host != "" { + req.Host = httpSettings.Host + } + for k, v := range httpSettings.Header { + req.Header.Set(k, v) + } + } + + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpRequest(req, false) + log.Trace(string(dump)) + } + if err := req.Write(cc); err != nil { + return resp.Write(rw) + } + + if req.Header.Get("Upgrade") == "websocket" { + err := xnet.CopyBuffer(cc, br, 8192) + if err == nil { + err = io.EOF + } + return err + } + + return nil + }() + if err != nil { + break + } + } + + connPool.Range(func(key, value any) bool { + if value != nil { + value.(net.Conn).Close() + } + return true + }) + + return +} + func (h *forwardHandler) checkRateLimit(addr net.Addr) bool { if h.options.RateLimiter == nil { return true @@ -156,56 +260,3 @@ func (h *forwardHandler) checkRateLimit(addr net.Addr) bool { return true } - -func (h *forwardHandler) handleHTTP(ctx context.Context, src, dst io.ReadWriter, httpSettings *chain.HTTPNodeSettings, log logger.Logger) error { - errc := make(chan error, 1) - go func() { - errc <- xnet.CopyBuffer(src, dst, 8192) - }() - - go func() { - br := bufio.NewReader(src) - for { - err := func() error { - req, err := http.ReadRequest(br) - if err != nil { - return err - } - - if httpSettings.Host != "" { - req.Host = httpSettings.Host - } - for k, v := range httpSettings.Header { - req.Header.Set(k, v) - } - - if log.IsLevelEnabled(logger.TraceLevel) { - dump, _ := httputil.DumpRequest(req, false) - log.Trace(string(dump)) - } - if err := req.Write(dst); err != nil { - return err - } - - if req.Header.Get("Upgrade") == "websocket" { - err := xnet.CopyBuffer(dst, src, 8192) - if err == nil { - err = io.EOF - } - return err - } - return nil - }() - if err != nil { - errc <- err - break - } - } - }() - - if err := <-errc; err != nil && err != io.EOF { - return err - } - - return nil -}