From 33adbb90276493ad1534b5d9c40d315a8e3fc20b Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 25 Oct 2023 23:11:55 +0800 Subject: [PATCH] fix mws dialer --- dialer/mws/dialer.go | 22 ++++++++++++++++------ listener/mws/listener.go | 20 ++++++++++---------- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/dialer/mws/dialer.go b/dialer/mws/dialer.go index 08e526e..cdd1105 100644 --- a/dialer/mws/dialer.go +++ b/dialer/mws/dialer.go @@ -9,6 +9,7 @@ import ( "time" "github.com/go-gost/core/dialer" + "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" "github.com/go-gost/x/internal/util/mux" ws_util "github.com/go-gost/x/internal/util/ws" @@ -100,13 +101,20 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia option(opts) } + log := d.options.Logger.WithFields(map[string]any{ + "local": conn.LocalAddr().String(), + "remote": conn.RemoteAddr().String(), + }) + d.sessionMutex.Lock() defer d.sessionMutex.Unlock() session, ok := d.sessions[opts.Addr] if session != nil && session.conn != conn { + err := errors.New("mws: unrecognized connection") + log.Error(err) conn.Close() - return nil, errors.New("mtls: unrecognized connection") + return nil, err } if !ok || session.session == nil { @@ -114,9 +122,9 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia if host == "" { host = opts.Addr } - s, err := d.initSession(ctx, host, conn) + s, err := d.initSession(ctx, host, conn, log) if err != nil { - d.options.Logger.Error(err) + log.Error(err) conn.Close() delete(d.sessions, opts.Addr) return nil, err @@ -126,6 +134,7 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia } cc, err := session.GetConn() if err != nil { + log.Error(err) session.Close() delete(d.sessions, opts.Addr) return nil, err @@ -134,7 +143,7 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia return cc, nil } -func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) (*muxSession, error) { +func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn, log logger.Logger) (*muxSession, error) { dialer := websocket.Dialer{ HandshakeTimeout: d.md.handshakeTimeout, ReadBufferSize: d.md.readBufferSize, @@ -168,7 +177,7 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) cc := ws_util.Conn(c) if d.md.keepaliveInterval > 0 { - d.options.Logger.Debugf("keepalive is enabled, ttl: %v", d.md.keepaliveInterval) + log.Debugf("keepalive is enabled, ttl: %v", d.md.keepaliveInterval) c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2)) c.SetPongHandler(func(string) error { c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2)) @@ -178,8 +187,9 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) } // stream multiplex - session, err := mux.ClientSession(conn, d.md.muxCfg) + session, err := mux.ClientSession(cc, d.md.muxCfg) if err != nil { + log.Error(err) return nil, err } return &muxSession{conn: cc, session: session}, nil diff --git a/listener/mws/listener.go b/listener/mws/listener.go index 120364a..a4c9402 100644 --- a/listener/mws/listener.go +++ b/listener/mws/listener.go @@ -149,30 +149,30 @@ func (l *mwsListener) Addr() net.Addr { } func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) { + log := l.logger.WithFields(map[string]any{ + "local": l.addr.String(), + "remote": r.RemoteAddr, + }) if l.logger.IsLevelEnabled(logger.TraceLevel) { - log := l.logger.WithFields(map[string]any{ - "local": l.addr.String(), - "remote": r.RemoteAddr, - }) dump, _ := httputil.DumpRequest(r, false) log.Trace(string(dump)) } conn, err := l.upgrader.Upgrade(w, r, l.md.header) if err != nil { - l.logger.Error(err) + log.Error(err) return } - l.mux(ws_util.Conn(conn)) + l.mux(ws_util.Conn(conn), log) } -func (l *mwsListener) mux(conn net.Conn) { +func (l *mwsListener) mux(conn net.Conn, log logger.Logger) { defer conn.Close() session, err := mux.ServerSession(conn, l.md.muxCfg) if err != nil { - l.logger.Error(err) + log.Error(err) return } defer session.Close() @@ -180,7 +180,7 @@ func (l *mwsListener) mux(conn net.Conn) { for { stream, err := session.Accept() if err != nil { - l.logger.Error("accept stream: ", err) + log.Error("accept stream: ", err) return } @@ -188,7 +188,7 @@ func (l *mwsListener) mux(conn net.Conn) { case l.cqueue <- stream: default: stream.Close() - l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr()) + log.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr()) } } }