add keepalive for websocket

This commit is contained in:
ginuerzh
2022-03-03 22:21:38 +08:00
parent b96d37d4cc
commit 8d8785f534
7 changed files with 61 additions and 31 deletions

View File

@ -8,7 +8,6 @@ import (
"github.com/go-gost/gost/pkg/dialer"
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/gorilla/websocket"
@ -21,7 +20,6 @@ func init() {
type wsDialer struct {
tlsEnabled bool
logger logger.Logger
md metadata
options dialer.Options
}
@ -33,7 +31,6 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer {
}
return &wsDialer{
logger: options.Logger,
options: options,
}
}
@ -46,7 +43,6 @@ func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
return &wsDialer{
tlsEnabled: true,
logger: options.Logger,
options: options,
}
}
@ -63,7 +59,7 @@ func (d *wsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOpt
conn, err := options.NetDialer.Dial(ctx, "tcp", addr)
if err != nil {
d.logger.Error(err)
d.options.Logger.Error(err)
}
return conn, err
}
@ -101,11 +97,34 @@ func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dial
dialer.TLSClientConfig = d.options.TLSConfig
}
c, resp, err := dialer.Dial(url.String(), d.md.header)
c, resp, err := dialer.DialContext(ctx, url.String(), d.md.header)
if err != nil {
return nil, err
}
resp.Body.Close()
if d.md.keepAlive > 0 {
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
c.SetPongHandler(func(string) error {
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
d.options.Logger.Infof("pong: set read deadline: %v", d.md.keepAlive*2)
return nil
})
go d.keepAlive(c)
}
return ws_util.Conn(c), nil
}
func (d *wsDialer) keepAlive(conn *websocket.Conn) {
ticker := time.NewTicker(d.md.keepAlive)
defer ticker.Stop()
for range ticker.C {
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
d.options.Logger.Infof("send ping")
}
}