add keepalive for websocket
This commit is contained in:
@ -39,7 +39,7 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, e
|
||||
if d.DialFunc != nil {
|
||||
return d.DialFunc(ctx, network, addr)
|
||||
}
|
||||
logger.Default().Infof("interface: %s %s/%s", ifceName, ifAddr, network)
|
||||
logger.Default().Infof("interface: %s %v/%s", ifceName, ifAddr, network)
|
||||
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6":
|
||||
|
@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/go-gost/gost/pkg/dialer"
|
||||
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
"github.com/gorilla/websocket"
|
||||
@ -21,7 +20,6 @@ func init() {
|
||||
|
||||
type wsDialer struct {
|
||||
tlsEnabled bool
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options dialer.Options
|
||||
}
|
||||
@ -33,7 +31,6 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer {
|
||||
}
|
||||
|
||||
return &wsDialer{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
@ -46,7 +43,6 @@ func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
|
||||
|
||||
return &wsDialer{
|
||||
tlsEnabled: true,
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
@ -63,7 +59,7 @@ func (d *wsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOpt
|
||||
|
||||
conn, err := options.NetDialer.Dial(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
d.logger.Error(err)
|
||||
d.options.Logger.Error(err)
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
@ -101,11 +97,34 @@ func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dial
|
||||
dialer.TLSClientConfig = d.options.TLSConfig
|
||||
}
|
||||
|
||||
c, resp, err := dialer.Dial(url.String(), d.md.header)
|
||||
c, resp, err := dialer.DialContext(ctx, url.String(), d.md.header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if d.md.keepAlive > 0 {
|
||||
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
|
||||
c.SetPongHandler(func(string) error {
|
||||
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
|
||||
d.options.Logger.Infof("pong: set read deadline: %v", d.md.keepAlive*2)
|
||||
return nil
|
||||
})
|
||||
go d.keepAlive(c)
|
||||
}
|
||||
|
||||
return ws_util.Conn(c), nil
|
||||
}
|
||||
|
||||
func (d *wsDialer) keepAlive(conn *websocket.Conn) {
|
||||
ticker := time.NewTicker(d.md.keepAlive)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
d.options.Logger.Infof("send ping")
|
||||
}
|
||||
}
|
||||
|
@ -21,7 +21,8 @@ type metadata struct {
|
||||
writeBufferSize int
|
||||
enableCompression bool
|
||||
|
||||
header http.Header
|
||||
header http.Header
|
||||
keepAlive time.Duration
|
||||
}
|
||||
|
||||
func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) {
|
||||
@ -35,7 +36,8 @@ func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) {
|
||||
writeBufferSize = "writeBufferSize"
|
||||
enableCompression = "enableCompression"
|
||||
|
||||
header = "header"
|
||||
header = "header"
|
||||
keepAlive = "keepAlive"
|
||||
)
|
||||
|
||||
d.md.host = mdata.GetString(md, host)
|
||||
@ -58,6 +60,7 @@ func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) {
|
||||
}
|
||||
d.md.header = h
|
||||
}
|
||||
d.md.keepAlive = mdata.GetDuration(md, keepAlive)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/go-gost/gost/pkg/dialer"
|
||||
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
"github.com/gorilla/websocket"
|
||||
@ -26,7 +25,6 @@ type mwsDialer struct {
|
||||
sessions map[string]*muxSession
|
||||
sessionMutex sync.Mutex
|
||||
tlsEnabled bool
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options dialer.Options
|
||||
}
|
||||
@ -39,7 +37,6 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer {
|
||||
|
||||
return &mwsDialer{
|
||||
sessions: make(map[string]*muxSession),
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
@ -53,7 +50,6 @@ func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
|
||||
return &mwsDialer{
|
||||
tlsEnabled: true,
|
||||
sessions: make(map[string]*muxSession),
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
@ -125,7 +121,7 @@ func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
|
||||
}
|
||||
s, err := d.initSession(ctx, host, conn)
|
||||
if err != nil {
|
||||
d.logger.Error(err)
|
||||
d.options.Logger.Error(err)
|
||||
conn.Close()
|
||||
delete(d.sessions, opts.Addr)
|
||||
return nil, err
|
||||
@ -160,12 +156,21 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
|
||||
dialer.TLSClientConfig = d.options.TLSConfig
|
||||
}
|
||||
|
||||
c, resp, err := dialer.Dial(url.String(), d.md.header)
|
||||
c, resp, err := dialer.DialContext(ctx, url.String(), d.md.header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if d.md.keepAlive > 0 {
|
||||
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
|
||||
c.SetPongHandler(func(string) error {
|
||||
c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2))
|
||||
return nil
|
||||
})
|
||||
go d.keepAlive(c)
|
||||
}
|
||||
|
||||
conn = ws_util.Conn(c)
|
||||
|
||||
// stream multiplex
|
||||
@ -193,3 +198,15 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
|
||||
}
|
||||
return &muxSession{conn: conn, session: session}, nil
|
||||
}
|
||||
|
||||
func (d *mwsDialer) keepAlive(conn *websocket.Conn) {
|
||||
ticker := time.NewTicker(d.md.keepAlive)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -28,7 +28,8 @@ type metadata struct {
|
||||
muxMaxReceiveBuffer int
|
||||
muxMaxStreamBuffer int
|
||||
|
||||
header http.Header
|
||||
header http.Header
|
||||
keepAlive time.Duration
|
||||
}
|
||||
|
||||
func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
|
||||
@ -42,7 +43,8 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
|
||||
writeBufferSize = "writeBufferSize"
|
||||
enableCompression = "enableCompression"
|
||||
|
||||
header = "header"
|
||||
header = "header"
|
||||
keepAlive = "keepAlive"
|
||||
|
||||
muxKeepAliveDisabled = "muxKeepAliveDisabled"
|
||||
muxKeepAliveInterval = "muxKeepAliveInterval"
|
||||
@ -79,5 +81,7 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
|
||||
}
|
||||
d.md.header = h
|
||||
}
|
||||
d.md.keepAlive = mdata.GetDuration(md, keepAlive)
|
||||
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user