From 7a21c7eb6fe775f01130dccdf215c4cf0ae57b5c Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Fri, 3 Mar 2023 13:06:20 +0800 Subject: [PATCH] add keepalive for ssh --- dialer/ssh/conn.go | 31 --------- dialer/ssh/dialer.go | 82 +++++------------------ dialer/ssh/metadata.go | 21 ++++-- dialer/sshd/conn.go | 31 --------- dialer/sshd/dialer.go | 85 ++++++----------------- dialer/sshd/metadata.go | 13 +++- dialer/ws/dialer.go | 14 ++-- dialer/ws/metadata.go | 16 ++--- go.mod | 2 +- go.sum | 3 +- internal/util/ssh/conn.go | 6 +- internal/util/ssh/session.go | 126 +++++++++++++++++++++++++++++++++++ listener/sshd/listener.go | 2 + 13 files changed, 213 insertions(+), 219 deletions(-) delete mode 100644 dialer/ssh/conn.go delete mode 100644 dialer/sshd/conn.go create mode 100644 internal/util/ssh/session.go diff --git a/dialer/ssh/conn.go b/dialer/ssh/conn.go deleted file mode 100644 index fb14c6f..0000000 --- a/dialer/ssh/conn.go +++ /dev/null @@ -1,31 +0,0 @@ -package ssh - -import ( - "net" - - "golang.org/x/crypto/ssh" -) - -type sshSession struct { - addr string - conn net.Conn - client *ssh.Client - closed chan struct{} - dead chan struct{} -} - -func (s *sshSession) IsClosed() bool { - select { - case <-s.dead: - return true - case <-s.closed: - return true - default: - } - return false -} - -func (s *sshSession) wait() error { - defer close(s.closed) - return s.client.Wait() -} diff --git a/dialer/ssh/dialer.go b/dialer/ssh/dialer.go index f91c788..509c3cb 100644 --- a/dialer/ssh/dialer.go +++ b/dialer/ssh/dialer.go @@ -2,13 +2,11 @@ package ssh import ( "context" - "errors" "net" "sync" "time" "github.com/go-gost/core/dialer" - "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" ssh_util "github.com/go-gost/x/internal/util/ssh" "github.com/go-gost/x/registry" @@ -20,9 +18,8 @@ func init() { } type sshDialer struct { - sessions map[string]*sshSession + sessions map[string]*ssh_util.Session sessionMutex sync.Mutex - logger logger.Logger md metadata options dialer.Options } @@ -34,8 +31,7 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer { } return &sshDialer{ - sessions: make(map[string]*sshSession), - logger: options.Logger, + sessions: make(map[string]*ssh_util.Session), options: options, } } @@ -72,73 +68,36 @@ func (d *sshDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp if err != nil { return } - - session = &sshSession{ - addr: addr, - conn: conn, + if d.md.handshakeTimeout > 0 { + conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout)) + defer conn.SetDeadline(time.Time{}) } - d.sessions[addr] = session - } - return session.conn, err -} - -// Handshake implements dialer.Handshaker -func (d *sshDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) { - opts := &dialer.HandshakeOptions{} - for _, option := range options { - option(opts) - } - - d.sessionMutex.Lock() - defer d.sessionMutex.Unlock() - - if d.md.handshakeTimeout > 0 { - conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout)) - defer conn.SetDeadline(time.Time{}) - } - - session, ok := d.sessions[opts.Addr] - if session != nil && session.conn != conn { - err := errors.New("ssh: unrecognized connection") - d.logger.Error(err) - conn.Close() - delete(d.sessions, opts.Addr) - return nil, err - } - - if !ok || session.client == nil { - s, err := d.initSession(ctx, opts.Addr, conn) + session, err = d.initSession(ctx, addr, conn) if err != nil { - d.logger.Error(err) conn.Close() - delete(d.sessions, opts.Addr) return nil, err } - session = s - go func() { - s.wait() - d.logger.Debug("session closed") - }() - d.sessions[opts.Addr] = session - } - if session.IsClosed() { - delete(d.sessions, opts.Addr) - return nil, ssh_util.ErrSessionDead - } + if d.md.keepalive { + go session.Keepalive(d.md.keepaliveInterval, d.md.keepaliveTimeout, d.md.keepaliveRetries) + } + go session.Wait() + go session.WaitClose() - channel, reqs, err := session.client.OpenChannel(ssh_util.GostSSHTunnelRequest, nil) + d.sessions[addr] = session + } + channel, reqs, err := session.OpenChannel(ssh_util.GostSSHTunnelRequest) if err != nil { return nil, err } go ssh.DiscardRequests(reqs) - return ssh_util.NewConn(conn, channel), nil + return ssh_util.NewConn(session, channel), nil } -func (d *sshDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*sshSession, error) { +func (d *sshDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*ssh_util.Session, error) { config := ssh.ClientConfig{ - Timeout: 30 * time.Second, + Timeout: d.md.handshakeTimeout, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } if d.options.Auth != nil { @@ -158,10 +117,5 @@ func (d *sshDialer) initSession(ctx context.Context, addr string, conn net.Conn) return nil, err } - return &sshSession{ - conn: conn, - client: ssh.NewClient(sshConn, chans, reqs), - closed: make(chan struct{}), - dead: make(chan struct{}), - }, nil + return ssh_util.NewSession(conn, ssh.NewClient(sshConn, chans, reqs), d.options.Logger), nil } diff --git a/dialer/ssh/metadata.go b/dialer/ssh/metadata.go index 725fd3b..3ed5fca 100644 --- a/dialer/ssh/metadata.go +++ b/dialer/ssh/metadata.go @@ -10,8 +10,12 @@ import ( ) type metadata struct { - handshakeTimeout time.Duration - signer ssh.Signer + handshakeTimeout time.Duration + signer ssh.Signer + keepalive bool + keepaliveInterval time.Duration + keepaliveTimeout time.Duration + keepaliveRetries int } func (d *sshDialer) parseMetadata(md mdata.Metadata) (err error) { @@ -27,11 +31,10 @@ func (d *sshDialer) parseMetadata(md mdata.Metadata) (err error) { return err } - pp := mdutil.GetString(md, passphrase) - if pp == "" { - d.md.signer, err = ssh.ParsePrivateKey(data) - } else { + if pp := mdutil.GetString(md, passphrase); pp != "" { d.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp)) + } else { + d.md.signer, err = ssh.ParsePrivateKey(data) } if err != nil { return err @@ -40,5 +43,11 @@ func (d *sshDialer) parseMetadata(md mdata.Metadata) (err error) { d.md.handshakeTimeout = mdutil.GetDuration(md, handshakeTimeout) + if d.md.keepalive = mdutil.GetBool(md, "keepalive"); d.md.keepalive { + d.md.keepaliveInterval = mdutil.GetDuration(md, "ttl", "keepalive.interval") + d.md.keepaliveTimeout = mdutil.GetDuration(md, "keepalive.timeout") + d.md.keepaliveRetries = mdutil.GetInt(md, "keepalive.retries") + } + return } diff --git a/dialer/sshd/conn.go b/dialer/sshd/conn.go deleted file mode 100644 index 120ffb5..0000000 --- a/dialer/sshd/conn.go +++ /dev/null @@ -1,31 +0,0 @@ -package sshd - -import ( - "net" - - "golang.org/x/crypto/ssh" -) - -type sshSession struct { - addr string - conn net.Conn - client *ssh.Client - closed chan struct{} - dead chan struct{} -} - -func (s *sshSession) IsClosed() bool { - select { - case <-s.dead: - return true - case <-s.closed: - return true - default: - } - return false -} - -func (s *sshSession) wait() error { - defer close(s.closed) - return s.client.Wait() -} diff --git a/dialer/sshd/dialer.go b/dialer/sshd/dialer.go index 2282007..c09c36b 100644 --- a/dialer/sshd/dialer.go +++ b/dialer/sshd/dialer.go @@ -2,7 +2,6 @@ package sshd import ( "context" - "errors" "net" "sync" "time" @@ -19,7 +18,7 @@ func init() { } type sshdDialer struct { - sessions map[string]*sshSession + sessions map[string]*ssh_util.Session sessionMutex sync.Mutex md metadata options dialer.Options @@ -32,7 +31,7 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer { } return &sshdDialer{ - sessions: make(map[string]*sshSession), + sessions: make(map[string]*ssh_util.Session), options: options, } } @@ -70,68 +69,31 @@ func (d *sshdDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO return } - session = &sshSession{ - addr: addr, - conn: conn, + if d.md.handshakeTimeout > 0 { + conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout)) + defer conn.SetDeadline(time.Time{}) } + + session, err = d.initSession(ctx, addr, conn) + if err != nil { + conn.Close() + return nil, err + } + if d.md.keepalive { + go session.Keepalive(d.md.keepaliveInterval, d.md.keepaliveTimeout, d.md.keepaliveRetries) + } + go session.Wait() + go session.WaitClose() + d.sessions[addr] = session } - return session.conn, err + return ssh_util.NewClientConn(session), nil } -// Handshake implements dialer.Handshaker -func (d *sshdDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) { - opts := &dialer.HandshakeOptions{} - for _, option := range options { - option(opts) - } - - d.sessionMutex.Lock() - defer d.sessionMutex.Unlock() - - if d.md.handshakeTimeout > 0 { - conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout)) - defer conn.SetDeadline(time.Time{}) - } - - log := d.options.Logger - - session, ok := d.sessions[opts.Addr] - if session != nil && session.conn != conn { - err := errors.New("ssh: unrecognized connection") - log.Error(err) - conn.Close() - delete(d.sessions, opts.Addr) - return nil, err - } - - if !ok || session.client == nil { - s, err := d.initSession(ctx, opts.Addr, conn) - if err != nil { - log.Error(err) - conn.Close() - delete(d.sessions, opts.Addr) - return nil, err - } - session = s - go func() { - s.wait() - log.Debug("session closed") - }() - d.sessions[opts.Addr] = session - } - if session.IsClosed() { - delete(d.sessions, opts.Addr) - return nil, ssh_util.ErrSessionDead - } - - return ssh_util.NewClientConn(session.conn, session.client), nil -} - -func (d *sshdDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*sshSession, error) { +func (d *sshdDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*ssh_util.Session, error) { config := ssh.ClientConfig{ - // Timeout: timeout, + Timeout: d.md.handshakeTimeout, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } if d.options.Auth != nil { @@ -151,10 +113,5 @@ func (d *sshdDialer) initSession(ctx context.Context, addr string, conn net.Conn return nil, err } - return &sshSession{ - conn: conn, - client: ssh.NewClient(sshConn, chans, reqs), - closed: make(chan struct{}), - dead: make(chan struct{}), - }, nil + return ssh_util.NewSession(conn, ssh.NewClient(sshConn, chans, reqs), d.options.Logger), nil } diff --git a/dialer/sshd/metadata.go b/dialer/sshd/metadata.go index 77cdd1f..62fef81 100644 --- a/dialer/sshd/metadata.go +++ b/dialer/sshd/metadata.go @@ -10,8 +10,12 @@ import ( ) type metadata struct { - handshakeTimeout time.Duration - signer ssh.Signer + handshakeTimeout time.Duration + signer ssh.Signer + keepalive bool + keepaliveInterval time.Duration + keepaliveTimeout time.Duration + keepaliveRetries int } func (d *sshdDialer) parseMetadata(md mdata.Metadata) (err error) { @@ -40,5 +44,10 @@ func (d *sshdDialer) parseMetadata(md mdata.Metadata) (err error) { d.md.handshakeTimeout = mdutil.GetDuration(md, handshakeTimeout) + if d.md.keepalive = mdutil.GetBool(md, "keepalive"); d.md.keepalive { + d.md.keepaliveInterval = mdutil.GetDuration(md, "ttl", "keepalive.interval") + d.md.keepaliveTimeout = mdutil.GetDuration(md, "keepalive.timeout") + d.md.keepaliveRetries = mdutil.GetInt(md, "keepalive.retries") + } return } diff --git a/dialer/ws/dialer.go b/dialer/ws/dialer.go index abfac69..7f66cf9 100644 --- a/dialer/ws/dialer.go +++ b/dialer/ws/dialer.go @@ -105,21 +105,21 @@ func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dial cc := ws_util.Conn(c) - if d.md.keepAlive > 0 { - c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2)) + if d.md.keepaliveInterval > 0 { + c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2)) c.SetPongHandler(func(string) error { - c.SetReadDeadline(time.Now().Add(d.md.keepAlive * 2)) - d.options.Logger.Debugf("pong: set read deadline: %v", d.md.keepAlive*2) + c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2)) + d.options.Logger.Debugf("pong: set read deadline: %v", d.md.keepaliveInterval*2) return nil }) - go d.keepAlive(cc) + go d.keepalive(cc) } return cc, nil } -func (d *wsDialer) keepAlive(conn ws_util.WebsocketConn) { - ticker := time.NewTicker(d.md.keepAlive) +func (d *wsDialer) keepalive(conn ws_util.WebsocketConn) { + ticker := time.NewTicker(d.md.keepaliveInterval) defer ticker.Stop() for range ticker.C { diff --git a/dialer/ws/metadata.go b/dialer/ws/metadata.go index e00ad9f..eaf1108 100644 --- a/dialer/ws/metadata.go +++ b/dialer/ws/metadata.go @@ -23,8 +23,8 @@ type metadata struct { writeBufferSize int enableCompression bool - header http.Header - keepAlive time.Duration + header http.Header + keepaliveInterval time.Duration } func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) { @@ -38,9 +38,7 @@ func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) { writeBufferSize = "writeBufferSize" enableCompression = "enableCompression" - header = "header" - keepAlive = "keepAlive" - keepAlivePeriod = "ttl" + header = "header" ) d.md.host = mdutil.GetString(md, host) @@ -64,10 +62,10 @@ func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) { d.md.header = h } - if mdutil.GetBool(md, keepAlive) { - d.md.keepAlive = mdutil.GetDuration(md, keepAlivePeriod) - if d.md.keepAlive <= 0 { - d.md.keepAlive = defaultKeepAlivePeriod + if mdutil.GetBool(md, "keepalive") { + d.md.keepaliveInterval = mdutil.GetDuration(md, "ttl", "keepalive.interval") + if d.md.keepaliveInterval <= 0 { + d.md.keepaliveInterval = defaultKeepAlivePeriod } } diff --git a/go.mod b/go.mod index f9f2e37..2accaca 100644 --- a/go.mod +++ b/go.mod @@ -34,7 +34,7 @@ require ( github.com/xtaci/smux v1.5.16 github.com/xtaci/tcpraw v1.2.25 github.com/yl2chen/cidranger v1.0.2 - golang.org/x/crypto v0.5.0 + golang.org/x/crypto v0.6.0 golang.org/x/net v0.7.0 golang.org/x/sys v0.5.0 golang.org/x/time v0.3.0 diff --git a/go.sum b/go.sum index e70a44c..b0e2800 100644 --- a/go.sum +++ b/go.sum @@ -424,8 +424,9 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= +golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= +golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= diff --git a/internal/util/ssh/conn.go b/internal/util/ssh/conn.go index 7cb9382..f09b0e3 100644 --- a/internal/util/ssh/conn.go +++ b/internal/util/ssh/conn.go @@ -12,10 +12,10 @@ type ClientConn struct { client *ssh.Client } -func NewClientConn(conn net.Conn, client *ssh.Client) net.Conn { +func NewClientConn(session *Session) net.Conn { return &ClientConn{ - Conn: conn, - client: client, + Conn: session.Conn, + client: session.client, } } diff --git a/internal/util/ssh/session.go b/internal/util/ssh/session.go new file mode 100644 index 0000000..0922f2b --- /dev/null +++ b/internal/util/ssh/session.go @@ -0,0 +1,126 @@ +package ssh + +import ( + "context" + "net" + "time" + + "github.com/go-gost/core/logger" + "golang.org/x/crypto/ssh" +) + +const ( + defaultKeepaliveInterval = 30 * time.Second + defaultKeepaliveTimeout = 15 * time.Second + defaultkeepaliveRetries = 1 +) + +type Session struct { + net.Conn + client *ssh.Client + closed chan struct{} + dead chan struct{} + log logger.Logger +} + +func NewSession(c net.Conn, client *ssh.Client, log logger.Logger) *Session { + return &Session{ + Conn: c, + client: client, + closed: make(chan struct{}), + dead: make(chan struct{}), + log: log, + } +} + +func (s *Session) OpenChannel(name string) (ssh.Channel, <-chan *ssh.Request, error) { + return s.client.OpenChannel(name, nil) +} + +func (s *Session) IsClosed() bool { + select { + case <-s.dead: + return true + case <-s.closed: + return true + default: + } + return false +} + +func (s *Session) Wait() error { + defer close(s.closed) + + return s.client.Wait() +} + +func (s *Session) WaitClose() { + defer s.client.Close() + + select { + case <-s.dead: + s.log.Debugf("session is dead") + case <-s.closed: + s.log.Debugf("session is closed") + } +} + +func (s *Session) Keepalive(interval, timeout time.Duration, retries int) { + if interval <= 0 { + interval = defaultKeepaliveInterval + } + if timeout <= 0 { + timeout = defaultKeepaliveTimeout + } + if retries <= 0 { + retries = defaultkeepaliveRetries + } + + s.log.Debugf("keepalive is enabled, interval: %v, timeout: %v, retries: %d", interval, timeout, retries) + defer close(s.dead) + + t := time.NewTicker(interval) + defer t.Stop() + + count := retries + for { + select { + case <-t.C: + start := time.Now() + err := func() error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + select { + case err := <-s.ping(): + return err + case <-ctx.Done(): + return ctx.Err() + } + }() + if err != nil { + s.log.Debugf("ssh ping: %v", err) + count-- + if count == 0 { + return + } + continue + } + s.log.Debugf("ssh ping OK, RTT: %v", time.Since(start)) + count = retries + case <-s.closed: + return + } + } +} + +func (s *Session) ping() <-chan error { + ch := make(chan error, 1) + go func() { + defer close(ch) + if _, _, err := s.client.SendRequest("ping", true, nil); err != nil { + ch <- err + } + }() + return ch +} diff --git a/listener/sshd/listener.go b/listener/sshd/listener.go index 23ea87f..70a1640 100644 --- a/listener/sshd/listener.go +++ b/listener/sshd/listener.go @@ -188,6 +188,8 @@ func (l *sshdListener) serveConn(conn net.Conn) { req.Reply(false, []byte("connection queue is full")) cc.Close() } + case "ping": + req.Reply(true, []byte("pong")) default: l.logger.Warnf("unsupported request type: %s, want reply: %v", req.Type, req.WantReply) req.Reply(false, nil)