From fb7b827ea21a858cb1123e7510390ade1f15f50a Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Tue, 21 Mar 2023 17:58:10 +0800 Subject: [PATCH] fix keepalive for websocket --- dialer/mws/dialer.go | 11 +++++++---- dialer/mws/metadata.go | 18 ++++++++++++------ dialer/ws/dialer.go | 5 ++++- dialer/ws/metadata.go | 4 ++-- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/dialer/mws/dialer.go b/dialer/mws/dialer.go index 4bf1f79..063f337 100644 --- a/dialer/mws/dialer.go +++ b/dialer/mws/dialer.go @@ -167,10 +167,11 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) cc := ws_util.Conn(c) - if d.md.keepAlive > 0 { - c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2)) + if d.md.keepaliveInterval > 0 { + d.options.Logger.Debugf("keepalive is enabled, ttl: %v", d.md.keepaliveInterval) + c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2)) c.SetPongHandler(func(string) error { - c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2)) + c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2)) return nil }) go d.keepAlive(cc) @@ -203,13 +204,15 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) } func (d *mwsDialer) keepAlive(conn ws_util.WebsocketConn) { - ticker := time.NewTicker(d.md.keepAlive) + ticker := time.NewTicker(d.md.keepaliveInterval) defer ticker.Stop() for range ticker.C { + d.options.Logger.Debug("send ping") conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } + conn.SetWriteDeadline(time.Time{}) } } diff --git a/dialer/mws/metadata.go b/dialer/mws/metadata.go index c5044c7..16faec5 100644 --- a/dialer/mws/metadata.go +++ b/dialer/mws/metadata.go @@ -9,7 +9,8 @@ import ( ) const ( - defaultPath = "/ws" + defaultPath = "/ws" + defaultKeepalivePeriod = 15 * time.Second ) type metadata struct { @@ -29,8 +30,8 @@ type metadata struct { muxMaxReceiveBuffer int muxMaxStreamBuffer int - header http.Header - keepAlive time.Duration + header http.Header + keepaliveInterval time.Duration } func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) { @@ -44,8 +45,7 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) { writeBufferSize = "writeBufferSize" enableCompression = "enableCompression" - header = "header" - keepAlive = "keepAlive" + header = "header" muxKeepAliveDisabled = "muxKeepAliveDisabled" muxKeepAliveInterval = "muxKeepAliveInterval" @@ -82,7 +82,13 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) { } d.md.header = h } - d.md.keepAlive = mdutil.GetDuration(md, keepAlive) + + if mdutil.GetBool(md, "keepalive") { + d.md.keepaliveInterval = mdutil.GetDuration(md, "ttl", "keepalive.interval") + if d.md.keepaliveInterval <= 0 { + d.md.keepaliveInterval = defaultKeepalivePeriod + } + } return } diff --git a/dialer/ws/dialer.go b/dialer/ws/dialer.go index 7f66cf9..ccc5bc5 100644 --- a/dialer/ws/dialer.go +++ b/dialer/ws/dialer.go @@ -106,6 +106,7 @@ func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dial cc := ws_util.Conn(c) if d.md.keepaliveInterval > 0 { + d.options.Logger.Debugf("keepalive is enabled, ttl: %v", d.md.keepaliveInterval) c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2)) c.SetPongHandler(func(string) error { c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2)) @@ -123,10 +124,12 @@ func (d *wsDialer) keepalive(conn ws_util.WebsocketConn) { defer ticker.Stop() for range ticker.C { + d.options.Logger.Debug("send ping") conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + // d.options.Logger.Error(err) return } - d.options.Logger.Debug("send ping") + conn.SetWriteDeadline(time.Time{}) } } diff --git a/dialer/ws/metadata.go b/dialer/ws/metadata.go index eaf1108..2f23eac 100644 --- a/dialer/ws/metadata.go +++ b/dialer/ws/metadata.go @@ -10,7 +10,7 @@ import ( const ( defaultPath = "/ws" - defaultKeepAlivePeriod = 15 * time.Second + defaultKeepalivePeriod = 15 * time.Second ) type metadata struct { @@ -65,7 +65,7 @@ func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) { if mdutil.GetBool(md, "keepalive") { d.md.keepaliveInterval = mdutil.GetDuration(md, "ttl", "keepalive.interval") if d.md.keepaliveInterval <= 0 { - d.md.keepaliveInterval = defaultKeepAlivePeriod + d.md.keepaliveInterval = defaultKeepalivePeriod } }