From 54b56df2141ad44c465b0dc5c30567cbeb63d756 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 18 Oct 2023 19:19:43 +0800 Subject: [PATCH] fix race condition --- handler/forward/local/handler.go | 19 +++++++--------- handler/forward/remote/handler.go | 19 +++++++--------- listener/rtcp/listener.go | 37 ++++++++++++++++++++++--------- 3 files changed, 43 insertions(+), 32 deletions(-) diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index 9a86b6e..26f266e 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -2,7 +2,6 @@ package local import ( "bufio" - "bytes" "context" "crypto/tls" "errors" @@ -195,7 +194,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l err = func() error { req, err := http.ReadRequest(br) if err != nil { - log.Errorf("read http request: %v", err) + // log.Errorf("read http request: %v", err) return err } @@ -268,10 +267,14 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l }) } + if err := req.Write(cc); err != nil { + cc.Close() + log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err) + return resp.Write(rw) + } + if req.Header.Get("Upgrade") == "websocket" { - var buf bytes.Buffer - req.Write(&buf) - err := xnet.Transport(cc, xio.NewReadWriter(io.MultiReader(&buf, br), rw)) + err := xnet.Transport(cc, xio.NewReadWriter(br, rw)) if err == nil { err = io.EOF } @@ -281,12 +284,6 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l go func() { defer cc.Close() - if err := req.Write(cc); err != nil { - log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err) - resp.Write(rw) - return - } - res, err := http.ReadResponse(bufio.NewReader(cc), req) if err != nil { log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err) diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index 33c6f53..eb39da1 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -2,7 +2,6 @@ package remote import ( "bufio" - "bytes" "context" "crypto/tls" "errors" @@ -197,7 +196,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot err = func() error { req, err := http.ReadRequest(br) if err != nil { - log.Errorf("read http request: %v", err) + // log.Errorf("read http request: %v", err) return err } @@ -272,10 +271,14 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot cc = proxyproto.WrapClientConn(h.md.proxyProtocol, remoteAddr, localAddr, cc) + if err := req.Write(cc); err != nil { + cc.Close() + log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err) + return resp.Write(rw) + } + if req.Header.Get("Upgrade") == "websocket" { - var buf bytes.Buffer - req.Write(&buf) - err := xnet.Transport(cc, xio.NewReadWriter(io.MultiReader(&buf, br), rw)) + err := xnet.Transport(cc, xio.NewReadWriter(br, rw)) if err == nil { err = io.EOF } @@ -285,12 +288,6 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot go func() { defer cc.Close() - if err := req.Write(cc); err != nil { - log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err) - resp.Write(rw) - return - } - res, err := http.ReadResponse(bufio.NewReader(cc), req) if err != nil { log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err) diff --git a/listener/rtcp/listener.go b/listener/rtcp/listener.go index bbbc7cf..95ba00a 100644 --- a/listener/rtcp/listener.go +++ b/listener/rtcp/listener.go @@ -3,6 +3,7 @@ package rtcp import ( "context" "net" + "sync" "github.com/go-gost/core/chain" "github.com/go-gost/core/listener" @@ -27,6 +28,7 @@ type rtcpListener struct { logger logger.Logger closed chan struct{} options listener.Options + mu sync.Mutex } func NewListener(opts ...listener.Option) listener.Listener { @@ -72,23 +74,25 @@ func (l *rtcpListener) Accept() (conn net.Conn, err error) { default: } - if l.ln == nil { - l.ln, err = l.router.Bind( + ln := l.getListener() + if ln == nil { + ln, err = l.router.Bind( context.Background(), "tcp", l.laddr.String(), chain.MuxBindOption(true), ) if err != nil { return nil, listener.NewAcceptError(err) } - l.ln = metrics.WrapListener(l.options.Service, l.ln) - l.ln = admission.WrapListener(l.options.Admission, l.ln) - l.ln = limiter.WrapListener(l.options.TrafficLimiter, l.ln) - l.ln = climiter.WrapListener(l.options.ConnLimiter, l.ln) + ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) + l.setListener(ln) } conn, err = l.ln.Accept() if err != nil { - l.ln.Close() - l.ln = nil + ln.Close() + l.setListener(nil) return nil, listener.NewAcceptError(err) } return @@ -103,8 +107,9 @@ func (l *rtcpListener) Close() error { case <-l.closed: default: close(l.closed) - if l.ln != nil { - l.ln.Close() + ln := l.getListener() + if ln != nil { + ln.Close() // l.ln = nil } } @@ -112,6 +117,18 @@ func (l *rtcpListener) Close() error { return nil } +func (l *rtcpListener) setListener(ln net.Listener) { + l.mu.Lock() + defer l.mu.Unlock() + l.ln = ln +} + +func (l *rtcpListener) getListener() net.Listener { + l.mu.Lock() + defer l.mu.Unlock() + return l.ln +} + type bindAddr struct { addr string }