add keepalive for websocket
This commit is contained in:
@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/go-gost/gost/pkg/dialer"
|
||||
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
"github.com/gorilla/websocket"
|
||||
@ -21,7 +20,6 @@ func init() {
|
||||
|
||||
type wsDialer struct {
|
||||
tlsEnabled bool
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options dialer.Options
|
||||
}
|
||||
@ -33,7 +31,6 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer {
|
||||
}
|
||||
|
||||
return &wsDialer{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
@ -46,7 +43,6 @@ func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
|
||||
|
||||
return &wsDialer{
|
||||
tlsEnabled: true,
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
@ -63,7 +59,7 @@ func (d *wsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOpt
|
||||
|
||||
conn, err := options.NetDialer.Dial(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
d.logger.Error(err)
|
||||
d.options.Logger.Error(err)
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
@ -101,11 +97,34 @@ func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dial
|
||||
dialer.TLSClientConfig = d.options.TLSConfig
|
||||
}
|
||||
|
||||
c, resp, err := dialer.Dial(url.String(), d.md.header)
|
||||
c, resp, err := dialer.DialContext(ctx, url.String(), d.md.header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if d.md.keepAlive > 0 {
|
||||
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
|
||||
c.SetPongHandler(func(string) error {
|
||||
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
|
||||
d.options.Logger.Infof("pong: set read deadline: %v", d.md.keepAlive*2)
|
||||
return nil
|
||||
})
|
||||
go d.keepAlive(c)
|
||||
}
|
||||
|
||||
return ws_util.Conn(c), nil
|
||||
}
|
||||
|
||||
func (d *wsDialer) keepAlive(conn *websocket.Conn) {
|
||||
ticker := time.NewTicker(d.md.keepAlive)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
d.options.Logger.Infof("send ping")
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user