diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index a760b24..847f2e3 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -185,6 +185,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l resp := &http.Response{ ProtoMajor: 1, ProtoMinor: 1, + Header: http.Header{}, StatusCode: http.StatusServiceUnavailable, } diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index dec8343..233c57a 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -11,7 +11,6 @@ import ( "net/http" "net/http/httputil" "strconv" - "sync" "time" "github.com/go-gost/core/chain" @@ -183,7 +182,6 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) { br := bufio.NewReader(rw) - var connPool sync.Map for { resp := &http.Response{ @@ -233,47 +231,28 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) } - var cc net.Conn - if v, ok := connPool.Load(target); ok { - cc = v.(net.Conn) - log.Debugf("reuse connection to node %s(%s)", target.Name, target.Addr) - } - if cc == nil { - cc, err = h.router.Dial(ctx, "tcp", target.Addr) - if err != nil { - // TODO: the router itself may be failed due to the failed node in the router, - // the dead marker may be a wrong operation. - if marker := target.Marker(); marker != nil { - marker.Mark() - } - log.Warnf("connect to node %s(%s) failed: %v", target.Name, target.Addr, err) - return resp.Write(rw) - } + cc, err := h.router.Dial(ctx, "tcp", target.Addr) + if err != nil { + // TODO: the router itself may be failed due to the failed node in the router, + // the dead marker may be a wrong operation. if marker := target.Marker(); marker != nil { - marker.Reset() + marker.Mark() } + log.Warnf("connect to node %s(%s) failed: %v", target.Name, target.Addr, err) + return resp.Write(rw) + } + if marker := target.Marker(); marker != nil { + marker.Reset() + } + defer cc.Close() - if tlsSettings := target.Options().TLS; tlsSettings != nil { - cc = tls.Client(cc, &tls.Config{ - ServerName: tlsSettings.ServerName, - InsecureSkipVerify: !tlsSettings.Secure, - }) - } + log.Debugf("new connection to node %s(%s)", target.Name, target.Addr) - cc = proxyproto.WrapClientConn(h.md.proxyProtocol, remoteAddr, localAddr, cc) - - connPool.Store(target, cc) - log.Debugf("new connection to node %s(%s)", target.Name, target.Addr) - - go func() { - defer cc.Close() - err := xnet.CopyBuffer(rw, cc, 8192) - if err != nil { - resp.Write(rw) - } - log.Debugf("close connection to node %s(%s), reason: %v", target.Name, target.Addr, err) - connPool.Delete(target) - }() + if tlsSettings := target.Options().TLS; tlsSettings != nil { + cc = tls.Client(cc, &tls.Config{ + ServerName: tlsSettings.ServerName, + InsecureSkipVerify: !tlsSettings.Secure, + }) } if httpSettings := target.Options().HTTP; httpSettings != nil { @@ -289,35 +268,28 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot dump, _ := httputil.DumpRequest(req, false) log.Trace(string(dump)) } + + cc = proxyproto.WrapClientConn(h.md.proxyProtocol, remoteAddr, localAddr, cc) + if err := req.Write(cc); err != nil { - log.Warnf("send request to node %s(%s) failed: %v", target.Name, target.Addr, err) + log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err) return resp.Write(rw) } - if req.Header.Get("Upgrade") == "websocket" { - err := xnet.CopyBuffer(cc, br, 8192) - if err == nil { - err = io.EOF - } - return err + res, err := http.ReadResponse(bufio.NewReader(cc), req) + if err != nil { + log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err) + return resp.Write(rw) } + defer res.Body.Close() - // cc.SetReadDeadline(time.Now().Add(10 * time.Second)) - - return nil + return res.Write(rw) }() if err != nil { break } } - connPool.Range(func(key, value any) bool { - if value != nil { - value.(net.Conn).Close() - } - return true - }) - return } diff --git a/handler/relay/entrypoint.go b/handler/relay/entrypoint.go index d69001e..c77c67c 100644 --- a/handler/relay/entrypoint.go +++ b/handler/relay/entrypoint.go @@ -8,7 +8,6 @@ import ( "net" "net/http" "net/http/httputil" - "sync" "time" "github.com/go-gost/core/handler" @@ -179,8 +178,7 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl h.options.Logger.Debugf("sniffing: host=%s, protocol=%s", host, protocol) if protocol == forward.ProtoHTTP { - h.handleHTTP(ctx, conn.RemoteAddr(), rw, log) - return nil + return h.handleHTTP(ctx, conn.RemoteAddr(), rw, log) } var tunnelID relay.TunnelID @@ -241,12 +239,12 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl func (h *tunnelHandler) handleHTTP(ctx context.Context, raddr net.Addr, rw io.ReadWriter, log logger.Logger) (err error) { br := bufio.NewReader(rw) - var connPool sync.Map for { resp := &http.Response{ ProtoMajor: 1, ProtoMinor: 1, + Header: http.Header{}, StatusCode: http.StatusServiceUnavailable, } @@ -278,73 +276,52 @@ func (h *tunnelHandler) handleHTTP(ctx context.Context, raddr net.Addr, rw io.Re "tunnel": tunnelID.String(), }) - var cc net.Conn - if v, ok := connPool.Load(tunnelID); ok { - cc = v.(net.Conn) - log.Debugf("connection to tunnel %s found in pool", tunnelID) + cc, cid, err := getTunnelConn("tcp", h.pool, tunnelID, 3, log) + if err != nil { + log.Error(err) + return resp.Write(rw) } - if cc == nil { - var cid relay.ConnectorID - cc, cid, err = getTunnelConn("tcp", h.pool, tunnelID, 3, log) - if err != nil { - log.Error(err) - return resp.Write(rw) + defer cc.Close() + + log.Debugf("new connection to tunnel %s(connector %s)", tunnelID, cid) + + var features []relay.Feature + af := &relay.AddrFeature{} + af.ParseFrom(raddr.String()) + features = append(features, af) + + if host := req.Host; host != "" { + if h, _, _ := net.SplitHostPort(host); h == "" { + host = net.JoinHostPort(host, "80") } - - connPool.Store(tunnelID, cc) - log.Debugf("new connection to tunnel %s(connector %s)", tunnelID, cid) - - var features []relay.Feature af := &relay.AddrFeature{} - af.ParseFrom(raddr.String()) + af.ParseFrom(host) features = append(features, af) - - if host := req.Host; host != "" { - if h, _, _ := net.SplitHostPort(host); h == "" { - host = net.JoinHostPort(host, "80") - } - af := &relay.AddrFeature{} - af.ParseFrom(host) - features = append(features, af) - } - - (&relay.Response{ - Version: relay.Version1, - Status: relay.StatusOK, - Features: features, - }).WriteTo(cc) - - go func() { - defer cc.Close() - err := xnet.CopyBuffer(rw, cc, 8192) - if err != nil { - resp.Write(rw) - } - log.Debugf("close connection to tunnel %s(connector %s), reason: %v", tunnelID, cid, err) - connPool.Delete(tunnelID) - }() } + (&relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + Features: features, + }).WriteTo(cc) + if log.IsLevelEnabled(logger.TraceLevel) { dump, _ := httputil.DumpRequest(req, false) log.Trace(string(dump)) } if err := req.Write(cc); err != nil { - log.Warnf("send request to tunnel %s failed: %v", tunnelID, err) + log.Warnf("send request to tunnel %s: %v", tunnelID, err) return resp.Write(rw) } - if req.Header.Get("Upgrade") == "websocket" { - err := xnet.CopyBuffer(cc, br, 8192) - if err == nil { - err = io.EOF - } - return err + res, err := http.ReadResponse(bufio.NewReader(cc), req) + if err != nil { + log.Warnf("read response from tunnel %s: %v", tunnelID, err) + return resp.Write(rw) } + defer res.Body.Close() - // cc.SetReadDeadline(time.Now().Add(10 * time.Second)) - - return nil + return res.Write(rw) }() if err != nil { // log.Error(err) @@ -352,12 +329,5 @@ func (h *tunnelHandler) handleHTTP(ctx context.Context, raddr net.Addr, rw io.Re } } - connPool.Range(func(key, value any) bool { - if value != nil { - value.(net.Conn).Close() - } - return true - }) - return }