From a9f0dda80573232e3f578de98fa76959cf069156 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Tue, 17 Oct 2023 23:04:30 +0800 Subject: [PATCH] fix websocket for tunnel --- dialer/mtls/conn.go | 8 +++---- dialer/mtls/dialer.go | 22 ++--------------- dialer/mtls/metadata.go | 38 ++++++++++------------------- dialer/mws/conn.go | 8 +++---- dialer/mws/dialer.go | 22 ++--------------- dialer/mws/metadata.go | 31 +++++++++--------------- handler/forward/local/handler.go | 8 +++++++ handler/forward/remote/handler.go | 8 +++++++ internal/util/forward/forward.go | 11 +++------ listener/mtls/listener.go | 25 +++---------------- listener/mtls/metadata.go | 40 ++++++++++--------------------- listener/mws/listener.go | 25 +++---------------- listener/mws/metadata.go | 23 +++++++++--------- 13 files changed, 83 insertions(+), 186 deletions(-) diff --git a/dialer/mtls/conn.go b/dialer/mtls/conn.go index 4b5f8c2..78d6cde 100644 --- a/dialer/mtls/conn.go +++ b/dialer/mtls/conn.go @@ -3,20 +3,20 @@ package mtls import ( "net" - "github.com/xtaci/smux" + "github.com/go-gost/x/internal/util/mux" ) type muxSession struct { conn net.Conn - session *smux.Session + session *mux.Session } func (session *muxSession) GetConn() (net.Conn, error) { - return session.session.OpenStream() + return session.session.GetConn() } func (session *muxSession) Accept() (net.Conn, error) { - return session.session.AcceptStream() + return session.session.Accept() } func (session *muxSession) Close() error { diff --git a/dialer/mtls/dialer.go b/dialer/mtls/dialer.go index 6835fea..f7052a6 100644 --- a/dialer/mtls/dialer.go +++ b/dialer/mtls/dialer.go @@ -11,8 +11,8 @@ import ( "github.com/go-gost/core/dialer" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" + "github.com/go-gost/x/internal/util/mux" "github.com/go-gost/x/registry" - "github.com/xtaci/smux" ) func init() { @@ -130,25 +130,7 @@ func (d *mtlsDialer) initSession(ctx context.Context, conn net.Conn) (*muxSessio conn = tlsConn // stream multiplex - smuxConfig := smux.DefaultConfig() - smuxConfig.KeepAliveDisabled = d.md.muxKeepAliveDisabled - if d.md.muxKeepAliveInterval > 0 { - smuxConfig.KeepAliveInterval = d.md.muxKeepAliveInterval - } - if d.md.muxKeepAliveTimeout > 0 { - smuxConfig.KeepAliveTimeout = d.md.muxKeepAliveTimeout - } - if d.md.muxMaxFrameSize > 0 { - smuxConfig.MaxFrameSize = d.md.muxMaxFrameSize - } - if d.md.muxMaxReceiveBuffer > 0 { - smuxConfig.MaxReceiveBuffer = d.md.muxMaxReceiveBuffer - } - if d.md.muxMaxStreamBuffer > 0 { - smuxConfig.MaxStreamBuffer = d.md.muxMaxStreamBuffer - } - - session, err := smux.Client(conn, smuxConfig) + session, err := mux.ClientSession(conn, d.md.muxCfg) if err != nil { return nil, err } diff --git a/dialer/mtls/metadata.go b/dialer/mtls/metadata.go index d7e3093..0ba7e70 100644 --- a/dialer/mtls/metadata.go +++ b/dialer/mtls/metadata.go @@ -5,39 +5,25 @@ import ( mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" + "github.com/go-gost/x/internal/util/mux" ) type metadata struct { handshakeTimeout time.Duration - - muxKeepAliveDisabled bool - muxKeepAliveInterval time.Duration - muxKeepAliveTimeout time.Duration - muxMaxFrameSize int - muxMaxReceiveBuffer int - muxMaxStreamBuffer int + muxCfg *mux.Config } func (d *mtlsDialer) parseMetadata(md mdata.Metadata) (err error) { - const ( - handshakeTimeout = "handshakeTimeout" - - muxKeepAliveDisabled = "muxKeepAliveDisabled" - muxKeepAliveInterval = "muxKeepAliveInterval" - muxKeepAliveTimeout = "muxKeepAliveTimeout" - muxMaxFrameSize = "muxMaxFrameSize" - muxMaxReceiveBuffer = "muxMaxReceiveBuffer" - muxMaxStreamBuffer = "muxMaxStreamBuffer" - ) - - d.md.handshakeTimeout = mdutil.GetDuration(md, handshakeTimeout) - - d.md.muxKeepAliveDisabled = mdutil.GetBool(md, muxKeepAliveDisabled) - d.md.muxKeepAliveInterval = mdutil.GetDuration(md, muxKeepAliveInterval) - d.md.muxKeepAliveTimeout = mdutil.GetDuration(md, muxKeepAliveTimeout) - d.md.muxMaxFrameSize = mdutil.GetInt(md, muxMaxFrameSize) - d.md.muxMaxReceiveBuffer = mdutil.GetInt(md, muxMaxReceiveBuffer) - d.md.muxMaxStreamBuffer = mdutil.GetInt(md, muxMaxStreamBuffer) + d.md.handshakeTimeout = mdutil.GetDuration(md, "handshakeTimeout") + d.md.muxCfg = &mux.Config{ + Version: mdutil.GetInt(md, "mux.version"), + KeepAliveInterval: mdutil.GetDuration(md, "mux.keepaliveInterval"), + KeepAliveDisabled: mdutil.GetBool(md, "mux.keepaliveDisabled"), + KeepAliveTimeout: mdutil.GetDuration(md, "mux.keepaliveTimeout"), + MaxFrameSize: mdutil.GetInt(md, "mux.maxFrameSize"), + MaxReceiveBuffer: mdutil.GetInt(md, "mux.maxReceiveBuffer"), + MaxStreamBuffer: mdutil.GetInt(md, "mux.maxStreamBuffer"), + } return } diff --git a/dialer/mws/conn.go b/dialer/mws/conn.go index 7c5dbe5..6b76ca5 100644 --- a/dialer/mws/conn.go +++ b/dialer/mws/conn.go @@ -3,20 +3,20 @@ package mws import ( "net" - "github.com/xtaci/smux" + "github.com/go-gost/x/internal/util/mux" ) type muxSession struct { conn net.Conn - session *smux.Session + session *mux.Session } func (session *muxSession) GetConn() (net.Conn, error) { - return session.session.OpenStream() + return session.session.GetConn() } func (session *muxSession) Accept() (net.Conn, error) { - return session.session.AcceptStream() + return session.session.Accept() } func (session *muxSession) Close() error { diff --git a/dialer/mws/dialer.go b/dialer/mws/dialer.go index 063f337..08e526e 100644 --- a/dialer/mws/dialer.go +++ b/dialer/mws/dialer.go @@ -10,10 +10,10 @@ import ( "github.com/go-gost/core/dialer" md "github.com/go-gost/core/metadata" + "github.com/go-gost/x/internal/util/mux" ws_util "github.com/go-gost/x/internal/util/ws" "github.com/go-gost/x/registry" "github.com/gorilla/websocket" - "github.com/xtaci/smux" ) func init() { @@ -178,25 +178,7 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) } // stream multiplex - smuxConfig := smux.DefaultConfig() - smuxConfig.KeepAliveDisabled = d.md.muxKeepAliveDisabled - if d.md.muxKeepAliveInterval > 0 { - smuxConfig.KeepAliveInterval = d.md.muxKeepAliveInterval - } - if d.md.muxKeepAliveTimeout > 0 { - smuxConfig.KeepAliveTimeout = d.md.muxKeepAliveTimeout - } - if d.md.muxMaxFrameSize > 0 { - smuxConfig.MaxFrameSize = d.md.muxMaxFrameSize - } - if d.md.muxMaxReceiveBuffer > 0 { - smuxConfig.MaxReceiveBuffer = d.md.muxMaxReceiveBuffer - } - if d.md.muxMaxStreamBuffer > 0 { - smuxConfig.MaxStreamBuffer = d.md.muxMaxStreamBuffer - } - - session, err := smux.Client(cc, smuxConfig) + session, err := mux.ClientSession(conn, d.md.muxCfg) if err != nil { return nil, err } diff --git a/dialer/mws/metadata.go b/dialer/mws/metadata.go index 16faec5..7dc2c29 100644 --- a/dialer/mws/metadata.go +++ b/dialer/mws/metadata.go @@ -6,6 +6,7 @@ import ( mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" + "github.com/go-gost/x/internal/util/mux" ) const ( @@ -23,15 +24,9 @@ type metadata struct { writeBufferSize int enableCompression bool - muxKeepAliveDisabled bool - muxKeepAliveInterval time.Duration - muxKeepAliveTimeout time.Duration - muxMaxFrameSize int - muxMaxReceiveBuffer int - muxMaxStreamBuffer int - header http.Header keepaliveInterval time.Duration + muxCfg *mux.Config } func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) { @@ -46,13 +41,6 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) { enableCompression = "enableCompression" header = "header" - - muxKeepAliveDisabled = "muxKeepAliveDisabled" - muxKeepAliveInterval = "muxKeepAliveInterval" - muxKeepAliveTimeout = "muxKeepAliveTimeout" - muxMaxFrameSize = "muxMaxFrameSize" - muxMaxReceiveBuffer = "muxMaxReceiveBuffer" - muxMaxStreamBuffer = "muxMaxStreamBuffer" ) d.md.host = mdutil.GetString(md, host) @@ -62,12 +50,15 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) { d.md.path = defaultPath } - d.md.muxKeepAliveDisabled = mdutil.GetBool(md, muxKeepAliveDisabled) - d.md.muxKeepAliveInterval = mdutil.GetDuration(md, muxKeepAliveInterval) - d.md.muxKeepAliveTimeout = mdutil.GetDuration(md, muxKeepAliveTimeout) - d.md.muxMaxFrameSize = mdutil.GetInt(md, muxMaxFrameSize) - d.md.muxMaxReceiveBuffer = mdutil.GetInt(md, muxMaxReceiveBuffer) - d.md.muxMaxStreamBuffer = mdutil.GetInt(md, muxMaxStreamBuffer) + d.md.muxCfg = &mux.Config{ + Version: mdutil.GetInt(md, "mux.version"), + KeepAliveInterval: mdutil.GetDuration(md, "mux.keepaliveInterval"), + KeepAliveDisabled: mdutil.GetBool(md, "mux.keepaliveDisabled"), + KeepAliveTimeout: mdutil.GetDuration(md, "mux.keepaliveTimeout"), + MaxFrameSize: mdutil.GetInt(md, "mux.maxFrameSize"), + MaxReceiveBuffer: mdutil.GetInt(md, "mux.maxReceiveBuffer"), + MaxStreamBuffer: mdutil.GetInt(md, "mux.maxStreamBuffer"), + } d.md.handshakeTimeout = mdutil.GetDuration(md, handshakeTimeout) d.md.readHeaderTimeout = mdutil.GetDuration(md, readHeaderTimeout) diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index c8fa60e..3e205c8 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -271,6 +271,14 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l return resp.Write(rw) } + if req.Header.Get("Upgrade") == "websocket" { + err := xnet.CopyBuffer(cc, br, 8192) + if err == nil { + err = io.EOF + } + return err + } + res, err := http.ReadResponse(bufio.NewReader(cc), req) if err != nil { log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err) diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index 1191a23..84589a2 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -276,6 +276,14 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot return resp.Write(rw) } + if req.Header.Get("Upgrade") == "websocket" { + err := xnet.CopyBuffer(cc, br, 8192) + if err == nil { + err = io.EOF + } + return err + } + res, err := http.ReadResponse(bufio.NewReader(cc), req) if err != nil { log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err) diff --git a/internal/util/forward/forward.go b/internal/util/forward/forward.go index fe459c6..42a4b93 100644 --- a/internal/util/forward/forward.go +++ b/internal/util/forward/forward.go @@ -38,10 +38,6 @@ func Sniffing(ctx context.Context, rdw io.ReadWriter) (rw io.ReadWriter, host st if err == nil { host = r.Host protocol = ProtoHTTP - - if r.Header.Get("Upgrade") == "websocket" { - protocol = ProtoWebsocket - } return } } @@ -93,10 +89,9 @@ func isHTTP(s string) bool { } const ( - ProtoHTTP = "http" - ProtoWebsocket = "ws" - ProtoTLS = "tls" - ProtoSSHv2 = "SSH-2" + ProtoHTTP = "http" + ProtoTLS = "tls" + ProtoSSHv2 = "SSH-2" ) func sniffProtocol(hdr []byte) string { diff --git a/listener/mtls/listener.go b/listener/mtls/listener.go index 6cdd920..7e7061a 100644 --- a/listener/mtls/listener.go +++ b/listener/mtls/listener.go @@ -12,11 +12,11 @@ import ( admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" + "github.com/go-gost/x/internal/util/mux" climiter "github.com/go-gost/x/limiter/conn/wrapper" limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" - "github.com/xtaci/smux" ) func init() { @@ -105,24 +105,7 @@ func (l *mtlsListener) listenLoop() { func (l *mtlsListener) mux(conn net.Conn) { defer conn.Close() - smuxConfig := smux.DefaultConfig() - smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled - if l.md.muxKeepAliveInterval > 0 { - smuxConfig.KeepAliveInterval = l.md.muxKeepAliveInterval - } - if l.md.muxKeepAliveTimeout > 0 { - smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout - } - if l.md.muxMaxFrameSize > 0 { - smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize - } - if l.md.muxMaxReceiveBuffer > 0 { - smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer - } - if l.md.muxMaxStreamBuffer > 0 { - smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer - } - session, err := smux.Server(conn, smuxConfig) + session, err := mux.ServerSession(conn, l.md.muxCfg) if err != nil { l.logger.Error(err) return @@ -130,7 +113,7 @@ func (l *mtlsListener) mux(conn net.Conn) { defer session.Close() for { - stream, err := session.AcceptStream() + stream, err := session.Accept() if err != nil { l.logger.Error("accept stream: ", err) return @@ -138,8 +121,6 @@ func (l *mtlsListener) mux(conn net.Conn) { select { case l.cqueue <- stream: - case <-stream.GetDieCh(): - stream.Close() default: stream.Close() l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr()) diff --git a/listener/mtls/metadata.go b/listener/mtls/metadata.go index e3b7a7a..fa36de3 100644 --- a/listener/mtls/metadata.go +++ b/listener/mtls/metadata.go @@ -1,10 +1,9 @@ package mtls import ( - "time" - mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" + "github.com/go-gost/x/internal/util/mux" ) const ( @@ -12,41 +11,26 @@ const ( ) type metadata struct { - muxKeepAliveDisabled bool - muxKeepAliveInterval time.Duration - muxKeepAliveTimeout time.Duration - muxMaxFrameSize int - muxMaxReceiveBuffer int - muxMaxStreamBuffer int - + muxCfg *mux.Config backlog int mptcp bool } func (l *mtlsListener) parseMetadata(md mdata.Metadata) (err error) { - const ( - backlog = "backlog" - - muxKeepAliveDisabled = "muxKeepAliveDisabled" - muxKeepAliveInterval = "muxKeepAliveInterval" - muxKeepAliveTimeout = "muxKeepAliveTimeout" - muxMaxFrameSize = "muxMaxFrameSize" - muxMaxReceiveBuffer = "muxMaxReceiveBuffer" - muxMaxStreamBuffer = "muxMaxStreamBuffer" - ) - - l.md.backlog = mdutil.GetInt(md, backlog) + l.md.backlog = mdutil.GetInt(md, "backlog") if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - l.md.muxKeepAliveDisabled = mdutil.GetBool(md, muxKeepAliveDisabled) - l.md.muxKeepAliveInterval = mdutil.GetDuration(md, muxKeepAliveInterval) - l.md.muxKeepAliveTimeout = mdutil.GetDuration(md, muxKeepAliveTimeout) - l.md.muxMaxFrameSize = mdutil.GetInt(md, muxMaxFrameSize) - l.md.muxMaxReceiveBuffer = mdutil.GetInt(md, muxMaxReceiveBuffer) - l.md.muxMaxStreamBuffer = mdutil.GetInt(md, muxMaxStreamBuffer) - + l.md.muxCfg = &mux.Config{ + Version: mdutil.GetInt(md, "mux.version"), + KeepAliveInterval: mdutil.GetDuration(md, "mux.keepaliveInterval"), + KeepAliveDisabled: mdutil.GetBool(md, "mux.keepaliveDisabled"), + KeepAliveTimeout: mdutil.GetDuration(md, "mux.keepaliveTimeout"), + MaxFrameSize: mdutil.GetInt(md, "mux.maxFrameSize"), + MaxReceiveBuffer: mdutil.GetInt(md, "mux.maxReceiveBuffer"), + MaxStreamBuffer: mdutil.GetInt(md, "mux.maxStreamBuffer"), + } l.md.mptcp = mdutil.GetBool(md, "mptcp") return diff --git a/listener/mws/listener.go b/listener/mws/listener.go index 14b2780..120364a 100644 --- a/listener/mws/listener.go +++ b/listener/mws/listener.go @@ -14,13 +14,13 @@ import ( admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" + "github.com/go-gost/x/internal/util/mux" ws_util "github.com/go-gost/x/internal/util/ws" climiter "github.com/go-gost/x/limiter/conn/wrapper" limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" "github.com/gorilla/websocket" - "github.com/xtaci/smux" ) func init() { @@ -170,24 +170,7 @@ func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) { func (l *mwsListener) mux(conn net.Conn) { defer conn.Close() - smuxConfig := smux.DefaultConfig() - smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled - if l.md.muxKeepAliveInterval > 0 { - smuxConfig.KeepAliveInterval = l.md.muxKeepAliveInterval - } - if l.md.muxKeepAliveTimeout > 0 { - smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout - } - if l.md.muxMaxFrameSize > 0 { - smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize - } - if l.md.muxMaxReceiveBuffer > 0 { - smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer - } - if l.md.muxMaxStreamBuffer > 0 { - smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer - } - session, err := smux.Server(conn, smuxConfig) + session, err := mux.ServerSession(conn, l.md.muxCfg) if err != nil { l.logger.Error(err) return @@ -195,7 +178,7 @@ func (l *mwsListener) mux(conn net.Conn) { defer session.Close() for { - stream, err := session.AcceptStream() + stream, err := session.Accept() if err != nil { l.logger.Error("accept stream: ", err) return @@ -203,8 +186,6 @@ func (l *mwsListener) mux(conn net.Conn) { select { case l.cqueue <- stream: - case <-stream.GetDieCh(): - stream.Close() default: stream.Close() l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr()) diff --git a/listener/mws/metadata.go b/listener/mws/metadata.go index 402941c..74bd365 100644 --- a/listener/mws/metadata.go +++ b/listener/mws/metadata.go @@ -6,6 +6,7 @@ import ( mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" + "github.com/go-gost/x/internal/util/mux" ) const ( @@ -24,12 +25,7 @@ type metadata struct { writeBufferSize int enableCompression bool - muxKeepAliveDisabled bool - muxKeepAliveInterval time.Duration - muxKeepAliveTimeout time.Duration - muxMaxFrameSize int - muxMaxReceiveBuffer int - muxMaxStreamBuffer int + muxCfg *mux.Config mptcp bool } @@ -70,12 +66,15 @@ func (l *mwsListener) parseMetadata(md mdata.Metadata) (err error) { l.md.writeBufferSize = mdutil.GetInt(md, writeBufferSize) l.md.enableCompression = mdutil.GetBool(md, enableCompression) - l.md.muxKeepAliveDisabled = mdutil.GetBool(md, muxKeepAliveDisabled) - l.md.muxKeepAliveInterval = mdutil.GetDuration(md, muxKeepAliveInterval) - l.md.muxKeepAliveTimeout = mdutil.GetDuration(md, muxKeepAliveTimeout) - l.md.muxMaxFrameSize = mdutil.GetInt(md, muxMaxFrameSize) - l.md.muxMaxReceiveBuffer = mdutil.GetInt(md, muxMaxReceiveBuffer) - l.md.muxMaxStreamBuffer = mdutil.GetInt(md, muxMaxStreamBuffer) + l.md.muxCfg = &mux.Config{ + Version: mdutil.GetInt(md, "mux.version"), + KeepAliveInterval: mdutil.GetDuration(md, "mux.keepaliveInterval"), + KeepAliveDisabled: mdutil.GetBool(md, "mux.keepaliveDisabled"), + KeepAliveTimeout: mdutil.GetDuration(md, "mux.keepaliveTimeout"), + MaxFrameSize: mdutil.GetInt(md, "mux.maxFrameSize"), + MaxReceiveBuffer: mdutil.GetInt(md, "mux.maxReceiveBuffer"), + MaxStreamBuffer: mdutil.GetInt(md, "mux.maxStreamBuffer"), + } if mm := mdutil.GetStringMapString(md, header); len(mm) > 0 { hd := http.Header{}