add keepalive for websocket
This commit is contained in:
@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/go-gost/gost/pkg/dialer"
|
||||
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
"github.com/gorilla/websocket"
|
||||
@ -26,7 +25,6 @@ type mwsDialer struct {
|
||||
sessions map[string]*muxSession
|
||||
sessionMutex sync.Mutex
|
||||
tlsEnabled bool
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options dialer.Options
|
||||
}
|
||||
@ -39,7 +37,6 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer {
|
||||
|
||||
return &mwsDialer{
|
||||
sessions: make(map[string]*muxSession),
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
@ -53,7 +50,6 @@ func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
|
||||
return &mwsDialer{
|
||||
tlsEnabled: true,
|
||||
sessions: make(map[string]*muxSession),
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
@ -125,7 +121,7 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
|
||||
}
|
||||
s, err := d.initSession(ctx, host, conn)
|
||||
if err != nil {
|
||||
d.logger.Error(err)
|
||||
d.options.Logger.Error(err)
|
||||
conn.Close()
|
||||
delete(d.sessions, opts.Addr)
|
||||
return nil, err
|
||||
@ -160,12 +156,21 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
|
||||
dialer.TLSClientConfig = d.options.TLSConfig
|
||||
}
|
||||
|
||||
c, resp, err := dialer.Dial(url.String(), d.md.header)
|
||||
c, resp, err := dialer.DialContext(ctx, url.String(), d.md.header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if d.md.keepAlive > 0 {
|
||||
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
|
||||
c.SetPongHandler(func(string) error {
|
||||
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
|
||||
return nil
|
||||
})
|
||||
go d.keepAlive(c)
|
||||
}
|
||||
|
||||
conn = ws_util.Conn(c)
|
||||
|
||||
// stream multiplex
|
||||
@ -193,3 +198,15 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
|
||||
}
|
||||
return &muxSession{conn: conn, session: session}, nil
|
||||
}
|
||||
|
||||
func (d *mwsDialer) keepAlive(conn *websocket.Conn) {
|
||||
ticker := time.NewTicker(d.md.keepAlive)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -28,7 +28,8 @@ type metadata struct {
|
||||
muxMaxReceiveBuffer int
|
||||
muxMaxStreamBuffer int
|
||||
|
||||
header http.Header
|
||||
header http.Header
|
||||
keepAlive time.Duration
|
||||
}
|
||||
|
||||
func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
|
||||
@ -42,7 +43,8 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
|
||||
writeBufferSize = "writeBufferSize"
|
||||
enableCompression = "enableCompression"
|
||||
|
||||
header = "header"
|
||||
header = "header"
|
||||
keepAlive = "keepAlive"
|
||||
|
||||
muxKeepAliveDisabled = "muxKeepAliveDisabled"
|
||||
muxKeepAliveInterval = "muxKeepAliveInterval"
|
||||
@ -79,5 +81,7 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
|
||||
}
|
||||
d.md.header = h
|
||||
}
|
||||
d.md.keepAlive = mdata.GetDuration(md, keepAlive)
|
||||
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user