add keepalive for websocket

This commit is contained in:
ginuerzh
2022-03-03 22:21:38 +08:00
parent b96d37d4cc
commit 8d8785f534
7 changed files with 61 additions and 31 deletions

View File

@ -8,7 +8,6 @@ import (
"github.com/go-gost/gost/pkg/dialer"
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/gorilla/websocket"
@ -21,7 +20,6 @@ func init() {
type wsDialer struct {
tlsEnabled bool
logger logger.Logger
md metadata
options dialer.Options
}
@ -33,7 +31,6 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer {
}
return &wsDialer{
logger: options.Logger,
options: options,
}
}
@ -46,7 +43,6 @@ func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
return &wsDialer{
tlsEnabled: true,
logger: options.Logger,
options: options,
}
}
@ -63,7 +59,7 @@ func (d *wsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOpt
conn, err := options.NetDialer.Dial(ctx, "tcp", addr)
if err != nil {
d.logger.Error(err)
d.options.Logger.Error(err)
}
return conn, err
}
@ -101,11 +97,34 @@ func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dial
dialer.TLSClientConfig = d.options.TLSConfig
}
c, resp, err := dialer.Dial(url.String(), d.md.header)
c, resp, err := dialer.DialContext(ctx, url.String(), d.md.header)
if err != nil {
return nil, err
}
resp.Body.Close()
if d.md.keepAlive > 0 {
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
c.SetPongHandler(func(string) error {
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
d.options.Logger.Infof("pong: set read deadline: %v", d.md.keepAlive*2)
return nil
})
go d.keepAlive(c)
}
return ws_util.Conn(c), nil
}
func (d *wsDialer) keepAlive(conn *websocket.Conn) {
ticker := time.NewTicker(d.md.keepAlive)
defer ticker.Stop()
for range ticker.C {
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
d.options.Logger.Infof("send ping")
}
}

View File

@ -21,7 +21,8 @@ type metadata struct {
writeBufferSize int
enableCompression bool
header http.Header
header http.Header
keepAlive time.Duration
}
func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) {
@ -35,7 +36,8 @@ func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) {
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
header = "header"
header = "header"
keepAlive = "keepAlive"
)
d.md.host = mdata.GetString(md, host)
@ -58,6 +60,7 @@ func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) {
}
d.md.header = h
}
d.md.keepAlive = mdata.GetDuration(md, keepAlive)
return
}

View File

@ -10,7 +10,6 @@ import (
"github.com/go-gost/gost/pkg/dialer"
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/gorilla/websocket"
@ -26,7 +25,6 @@ type mwsDialer struct {
sessions map[string]*muxSession
sessionMutex sync.Mutex
tlsEnabled bool
logger logger.Logger
md metadata
options dialer.Options
}
@ -39,7 +37,6 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer {
return &mwsDialer{
sessions: make(map[string]*muxSession),
logger: options.Logger,
options: options,
}
}
@ -53,7 +50,6 @@ func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
return &mwsDialer{
tlsEnabled: true,
sessions: make(map[string]*muxSession),
logger: options.Logger,
options: options,
}
}
@ -125,7 +121,7 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
}
s, err := d.initSession(ctx, host, conn)
if err != nil {
d.logger.Error(err)
d.options.Logger.Error(err)
conn.Close()
delete(d.sessions, opts.Addr)
return nil, err
@ -160,12 +156,21 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
dialer.TLSClientConfig = d.options.TLSConfig
}
c, resp, err := dialer.Dial(url.String(), d.md.header)
c, resp, err := dialer.DialContext(ctx, url.String(), d.md.header)
if err != nil {
return nil, err
}
resp.Body.Close()
if d.md.keepAlive > 0 {
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
c.SetPongHandler(func(string) error {
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
return nil
})
go d.keepAlive(c)
}
conn = ws_util.Conn(c)
// stream multiplex
@ -193,3 +198,15 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
}
return &muxSession{conn: conn, session: session}, nil
}
func (d *mwsDialer) keepAlive(conn *websocket.Conn) {
ticker := time.NewTicker(d.md.keepAlive)
defer ticker.Stop()
for range ticker.C {
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}

View File

@ -28,7 +28,8 @@ type metadata struct {
muxMaxReceiveBuffer int
muxMaxStreamBuffer int
header http.Header
header http.Header
keepAlive time.Duration
}
func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
@ -42,7 +43,8 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
header = "header"
header = "header"
keepAlive = "keepAlive"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAliveInterval = "muxKeepAliveInterval"
@ -79,5 +81,7 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
}
d.md.header = h
}
d.md.keepAlive = mdata.GetDuration(md, keepAlive)
return
}