fix mws dialer
This commit is contained in:
parent
5b1183661f
commit
33adbb9027
@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-gost/core/dialer"
|
"github.com/go-gost/core/dialer"
|
||||||
|
"github.com/go-gost/core/logger"
|
||||||
md "github.com/go-gost/core/metadata"
|
md "github.com/go-gost/core/metadata"
|
||||||
"github.com/go-gost/x/internal/util/mux"
|
"github.com/go-gost/x/internal/util/mux"
|
||||||
ws_util "github.com/go-gost/x/internal/util/ws"
|
ws_util "github.com/go-gost/x/internal/util/ws"
|
||||||
@ -100,13 +101,20 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
|
|||||||
option(opts)
|
option(opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log := d.options.Logger.WithFields(map[string]any{
|
||||||
|
"local": conn.LocalAddr().String(),
|
||||||
|
"remote": conn.RemoteAddr().String(),
|
||||||
|
})
|
||||||
|
|
||||||
d.sessionMutex.Lock()
|
d.sessionMutex.Lock()
|
||||||
defer d.sessionMutex.Unlock()
|
defer d.sessionMutex.Unlock()
|
||||||
|
|
||||||
session, ok := d.sessions[opts.Addr]
|
session, ok := d.sessions[opts.Addr]
|
||||||
if session != nil && session.conn != conn {
|
if session != nil && session.conn != conn {
|
||||||
|
err := errors.New("mws: unrecognized connection")
|
||||||
|
log.Error(err)
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, errors.New("mtls: unrecognized connection")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ok || session.session == nil {
|
if !ok || session.session == nil {
|
||||||
@ -114,9 +122,9 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
|
|||||||
if host == "" {
|
if host == "" {
|
||||||
host = opts.Addr
|
host = opts.Addr
|
||||||
}
|
}
|
||||||
s, err := d.initSession(ctx, host, conn)
|
s, err := d.initSession(ctx, host, conn, log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
d.options.Logger.Error(err)
|
log.Error(err)
|
||||||
conn.Close()
|
conn.Close()
|
||||||
delete(d.sessions, opts.Addr)
|
delete(d.sessions, opts.Addr)
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -126,6 +134,7 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
|
|||||||
}
|
}
|
||||||
cc, err := session.GetConn()
|
cc, err := session.GetConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
session.Close()
|
session.Close()
|
||||||
delete(d.sessions, opts.Addr)
|
delete(d.sessions, opts.Addr)
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -134,7 +143,7 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
|
|||||||
return cc, nil
|
return cc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) (*muxSession, error) {
|
func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn, log logger.Logger) (*muxSession, error) {
|
||||||
dialer := websocket.Dialer{
|
dialer := websocket.Dialer{
|
||||||
HandshakeTimeout: d.md.handshakeTimeout,
|
HandshakeTimeout: d.md.handshakeTimeout,
|
||||||
ReadBufferSize: d.md.readBufferSize,
|
ReadBufferSize: d.md.readBufferSize,
|
||||||
@ -168,7 +177,7 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
|
|||||||
cc := ws_util.Conn(c)
|
cc := ws_util.Conn(c)
|
||||||
|
|
||||||
if d.md.keepaliveInterval > 0 {
|
if d.md.keepaliveInterval > 0 {
|
||||||
d.options.Logger.Debugf("keepalive is enabled, ttl: %v", d.md.keepaliveInterval)
|
log.Debugf("keepalive is enabled, ttl: %v", d.md.keepaliveInterval)
|
||||||
c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2))
|
c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2))
|
||||||
c.SetPongHandler(func(string) error {
|
c.SetPongHandler(func(string) error {
|
||||||
c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2))
|
c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2))
|
||||||
@ -178,8 +187,9 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// stream multiplex
|
// stream multiplex
|
||||||
session, err := mux.ClientSession(conn, d.md.muxCfg)
|
session, err := mux.ClientSession(cc, d.md.muxCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &muxSession{conn: cc, session: session}, nil
|
return &muxSession{conn: cc, session: session}, nil
|
||||||
|
@ -149,30 +149,30 @@ func (l *mwsListener) Addr() net.Addr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) {
|
func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log := l.logger.WithFields(map[string]any{
|
||||||
|
"local": l.addr.String(),
|
||||||
|
"remote": r.RemoteAddr,
|
||||||
|
})
|
||||||
if l.logger.IsLevelEnabled(logger.TraceLevel) {
|
if l.logger.IsLevelEnabled(logger.TraceLevel) {
|
||||||
log := l.logger.WithFields(map[string]any{
|
|
||||||
"local": l.addr.String(),
|
|
||||||
"remote": r.RemoteAddr,
|
|
||||||
})
|
|
||||||
dump, _ := httputil.DumpRequest(r, false)
|
dump, _ := httputil.DumpRequest(r, false)
|
||||||
log.Trace(string(dump))
|
log.Trace(string(dump))
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := l.upgrader.Upgrade(w, r, l.md.header)
|
conn, err := l.upgrader.Upgrade(w, r, l.md.header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.logger.Error(err)
|
log.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
l.mux(ws_util.Conn(conn))
|
l.mux(ws_util.Conn(conn), log)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *mwsListener) mux(conn net.Conn) {
|
func (l *mwsListener) mux(conn net.Conn, log logger.Logger) {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
session, err := mux.ServerSession(conn, l.md.muxCfg)
|
session, err := mux.ServerSession(conn, l.md.muxCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.logger.Error(err)
|
log.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
@ -180,7 +180,7 @@ func (l *mwsListener) mux(conn net.Conn) {
|
|||||||
for {
|
for {
|
||||||
stream, err := session.Accept()
|
stream, err := session.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.logger.Error("accept stream: ", err)
|
log.Error("accept stream: ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -188,7 +188,7 @@ func (l *mwsListener) mux(conn net.Conn) {
|
|||||||
case l.cqueue <- stream:
|
case l.cqueue <- stream:
|
||||||
default:
|
default:
|
||||||
stream.Close()
|
stream.Close()
|
||||||
l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
|
log.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user