fix mws dialer

This commit is contained in:
ginuerzh 2023-10-25 23:11:55 +08:00
parent 5b1183661f
commit 33adbb9027
2 changed files with 26 additions and 16 deletions

View File

@ -9,6 +9,7 @@ import (
"time"
"github.com/go-gost/core/dialer"
"github.com/go-gost/core/logger"
md "github.com/go-gost/core/metadata"
"github.com/go-gost/x/internal/util/mux"
ws_util "github.com/go-gost/x/internal/util/ws"
@ -100,13 +101,20 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
option(opts)
}
log := d.options.Logger.WithFields(map[string]any{
"local": conn.LocalAddr().String(),
"remote": conn.RemoteAddr().String(),
})
d.sessionMutex.Lock()
defer d.sessionMutex.Unlock()
session, ok := d.sessions[opts.Addr]
if session != nil && session.conn != conn {
err := errors.New("mws: unrecognized connection")
log.Error(err)
conn.Close()
return nil, errors.New("mtls: unrecognized connection")
return nil, err
}
if !ok || session.session == nil {
@ -114,9 +122,9 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
if host == "" {
host = opts.Addr
}
s, err := d.initSession(ctx, host, conn)
s, err := d.initSession(ctx, host, conn, log)
if err != nil {
d.options.Logger.Error(err)
log.Error(err)
conn.Close()
delete(d.sessions, opts.Addr)
return nil, err
@ -126,6 +134,7 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
}
cc, err := session.GetConn()
if err != nil {
log.Error(err)
session.Close()
delete(d.sessions, opts.Addr)
return nil, err
@ -134,7 +143,7 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
return cc, nil
}
func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) (*muxSession, error) {
func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn, log logger.Logger) (*muxSession, error) {
dialer := websocket.Dialer{
HandshakeTimeout: d.md.handshakeTimeout,
ReadBufferSize: d.md.readBufferSize,
@ -168,7 +177,7 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
cc := ws_util.Conn(c)
if d.md.keepaliveInterval > 0 {
d.options.Logger.Debugf("keepalive is enabled, ttl: %v", d.md.keepaliveInterval)
log.Debugf("keepalive is enabled, ttl: %v", d.md.keepaliveInterval)
c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2))
c.SetPongHandler(func(string) error {
c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2))
@ -178,8 +187,9 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
}
// stream multiplex
session, err := mux.ClientSession(conn, d.md.muxCfg)
session, err := mux.ClientSession(cc, d.md.muxCfg)
if err != nil {
log.Error(err)
return nil, err
}
return &muxSession{conn: cc, session: session}, nil

View File

@ -149,30 +149,30 @@ func (l *mwsListener) Addr() net.Addr {
}
func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) {
if l.logger.IsLevelEnabled(logger.TraceLevel) {
log := l.logger.WithFields(map[string]any{
"local": l.addr.String(),
"remote": r.RemoteAddr,
})
if l.logger.IsLevelEnabled(logger.TraceLevel) {
dump, _ := httputil.DumpRequest(r, false)
log.Trace(string(dump))
}
conn, err := l.upgrader.Upgrade(w, r, l.md.header)
if err != nil {
l.logger.Error(err)
log.Error(err)
return
}
l.mux(ws_util.Conn(conn))
l.mux(ws_util.Conn(conn), log)
}
func (l *mwsListener) mux(conn net.Conn) {
func (l *mwsListener) mux(conn net.Conn, log logger.Logger) {
defer conn.Close()
session, err := mux.ServerSession(conn, l.md.muxCfg)
if err != nil {
l.logger.Error(err)
log.Error(err)
return
}
defer session.Close()
@ -180,7 +180,7 @@ func (l *mwsListener) mux(conn net.Conn) {
for {
stream, err := session.Accept()
if err != nil {
l.logger.Error("accept stream: ", err)
log.Error("accept stream: ", err)
return
}
@ -188,7 +188,7 @@ func (l *mwsListener) mux(conn net.Conn) {
case l.cqueue <- stream:
default:
stream.Close()
l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
log.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
}
}
}