fix mws dialer
This commit is contained in:
parent
5b1183661f
commit
33adbb9027
@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/dialer"
|
||||
"github.com/go-gost/core/logger"
|
||||
md "github.com/go-gost/core/metadata"
|
||||
"github.com/go-gost/x/internal/util/mux"
|
||||
ws_util "github.com/go-gost/x/internal/util/ws"
|
||||
@ -100,13 +101,20 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
|
||||
option(opts)
|
||||
}
|
||||
|
||||
log := d.options.Logger.WithFields(map[string]any{
|
||||
"local": conn.LocalAddr().String(),
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
})
|
||||
|
||||
d.sessionMutex.Lock()
|
||||
defer d.sessionMutex.Unlock()
|
||||
|
||||
session, ok := d.sessions[opts.Addr]
|
||||
if session != nil && session.conn != conn {
|
||||
err := errors.New("mws: unrecognized connection")
|
||||
log.Error(err)
|
||||
conn.Close()
|
||||
return nil, errors.New("mtls: unrecognized connection")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !ok || session.session == nil {
|
||||
@ -114,9 +122,9 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
|
||||
if host == "" {
|
||||
host = opts.Addr
|
||||
}
|
||||
s, err := d.initSession(ctx, host, conn)
|
||||
s, err := d.initSession(ctx, host, conn, log)
|
||||
if err != nil {
|
||||
d.options.Logger.Error(err)
|
||||
log.Error(err)
|
||||
conn.Close()
|
||||
delete(d.sessions, opts.Addr)
|
||||
return nil, err
|
||||
@ -126,6 +134,7 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
|
||||
}
|
||||
cc, err := session.GetConn()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
session.Close()
|
||||
delete(d.sessions, opts.Addr)
|
||||
return nil, err
|
||||
@ -134,7 +143,7 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) (*muxSession, error) {
|
||||
func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn, log logger.Logger) (*muxSession, error) {
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: d.md.handshakeTimeout,
|
||||
ReadBufferSize: d.md.readBufferSize,
|
||||
@ -168,7 +177,7 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
|
||||
cc := ws_util.Conn(c)
|
||||
|
||||
if d.md.keepaliveInterval > 0 {
|
||||
d.options.Logger.Debugf("keepalive is enabled, ttl: %v", d.md.keepaliveInterval)
|
||||
log.Debugf("keepalive is enabled, ttl: %v", d.md.keepaliveInterval)
|
||||
c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2))
|
||||
c.SetPongHandler(func(string) error {
|
||||
c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2))
|
||||
@ -178,8 +187,9 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
|
||||
}
|
||||
|
||||
// stream multiplex
|
||||
session, err := mux.ClientSession(conn, d.md.muxCfg)
|
||||
session, err := mux.ClientSession(cc, d.md.muxCfg)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
return &muxSession{conn: cc, session: session}, nil
|
||||
|
@ -149,30 +149,30 @@ func (l *mwsListener) Addr() net.Addr {
|
||||
}
|
||||
|
||||
func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) {
|
||||
if l.logger.IsLevelEnabled(logger.TraceLevel) {
|
||||
log := l.logger.WithFields(map[string]any{
|
||||
"local": l.addr.String(),
|
||||
"remote": r.RemoteAddr,
|
||||
})
|
||||
if l.logger.IsLevelEnabled(logger.TraceLevel) {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
log.Trace(string(dump))
|
||||
}
|
||||
|
||||
conn, err := l.upgrader.Upgrade(w, r, l.md.header)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
l.mux(ws_util.Conn(conn))
|
||||
l.mux(ws_util.Conn(conn), log)
|
||||
}
|
||||
|
||||
func (l *mwsListener) mux(conn net.Conn) {
|
||||
func (l *mwsListener) mux(conn net.Conn, log logger.Logger) {
|
||||
defer conn.Close()
|
||||
|
||||
session, err := mux.ServerSession(conn, l.md.muxCfg)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
@ -180,7 +180,7 @@ func (l *mwsListener) mux(conn net.Conn) {
|
||||
for {
|
||||
stream, err := session.Accept()
|
||||
if err != nil {
|
||||
l.logger.Error("accept stream: ", err)
|
||||
log.Error("accept stream: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -188,7 +188,7 @@ func (l *mwsListener) mux(conn net.Conn) {
|
||||
case l.cqueue <- stream:
|
||||
default:
|
||||
stream.Close()
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
|
||||
log.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user