From 34d6e393a1d64e41495ab32b5e9138438978fa9c Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sun, 19 Dec 2021 17:24:51 +0800 Subject: [PATCH] add ssh tunnel --- cmd/gost/register.go | 2 + pkg/dialer/forward/ssh/dialer.go | 6 +- pkg/dialer/ssh/conn.go | 31 +++++ pkg/dialer/ssh/dialer.go | 193 ++++++++++++++++++++++++++++ pkg/dialer/ssh/metadata.go | 56 ++++++++ pkg/handler/forward/ssh/metadata.go | 22 +--- pkg/internal/util/ssh/conn.go | 24 ++++ pkg/internal/util/ssh/ssh.go | 29 +++++ pkg/listener/ssh/listener.go | 136 ++++++++++++++++++++ pkg/listener/ssh/metadata.go | 87 +++++++++++++ pkg/listener/tcp/listener.go | 9 -- pkg/listener/tcp/metadata.go | 16 --- 12 files changed, 561 insertions(+), 50 deletions(-) create mode 100644 pkg/dialer/ssh/conn.go create mode 100644 pkg/dialer/ssh/dialer.go create mode 100644 pkg/dialer/ssh/metadata.go create mode 100644 pkg/listener/ssh/listener.go create mode 100644 pkg/listener/ssh/metadata.go diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 3a3f5a1..f66d971 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -22,6 +22,7 @@ import ( _ "github.com/go-gost/gost/pkg/dialer/obfs/http" _ "github.com/go-gost/gost/pkg/dialer/obfs/tls" _ "github.com/go-gost/gost/pkg/dialer/quic" + _ "github.com/go-gost/gost/pkg/dialer/ssh" _ "github.com/go-gost/gost/pkg/dialer/tcp" _ "github.com/go-gost/gost/pkg/dialer/tls" _ "github.com/go-gost/gost/pkg/dialer/tls/mux" @@ -55,6 +56,7 @@ import ( _ "github.com/go-gost/gost/pkg/listener/redirect/udp" _ "github.com/go-gost/gost/pkg/listener/rtcp" _ "github.com/go-gost/gost/pkg/listener/rudp" + _ "github.com/go-gost/gost/pkg/listener/ssh" _ "github.com/go-gost/gost/pkg/listener/tcp" _ "github.com/go-gost/gost/pkg/listener/tls" _ "github.com/go-gost/gost/pkg/listener/tls/mux" diff --git a/pkg/dialer/forward/ssh/dialer.go b/pkg/dialer/forward/ssh/dialer.go index ea7ff0b..c4c3891 100644 --- a/pkg/dialer/forward/ssh/dialer.go +++ b/pkg/dialer/forward/ssh/dialer.go @@ -15,10 +15,6 @@ import ( "golang.org/x/crypto/ssh" ) -var ( - ErrSessionDead = errors.New("ssh: session is dead") -) - func init() { registry.RegisterDialer("sshd", NewDialer) } @@ -126,7 +122,7 @@ func (d *forwardDialer) Handshake(ctx context.Context, conn net.Conn, options .. } if session.IsClosed() { delete(d.sessions, opts.Addr) - return nil, ErrSessionDead + return nil, ssh_util.ErrSessionDead } return ssh_util.NewClientConn(session.conn, session.client), nil diff --git a/pkg/dialer/ssh/conn.go b/pkg/dialer/ssh/conn.go new file mode 100644 index 0000000..fb14c6f --- /dev/null +++ b/pkg/dialer/ssh/conn.go @@ -0,0 +1,31 @@ +package ssh + +import ( + "net" + + "golang.org/x/crypto/ssh" +) + +type sshSession struct { + addr string + conn net.Conn + client *ssh.Client + closed chan struct{} + dead chan struct{} +} + +func (s *sshSession) IsClosed() bool { + select { + case <-s.dead: + return true + case <-s.closed: + return true + default: + } + return false +} + +func (s *sshSession) wait() error { + defer close(s.closed) + return s.client.Wait() +} diff --git a/pkg/dialer/ssh/dialer.go b/pkg/dialer/ssh/dialer.go new file mode 100644 index 0000000..23d21b0 --- /dev/null +++ b/pkg/dialer/ssh/dialer.go @@ -0,0 +1,193 @@ +package ssh + +import ( + "context" + "errors" + "net" + "sync" + "time" + + "github.com/go-gost/gost/pkg/dialer" + ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "golang.org/x/crypto/ssh" +) + +func init() { + registry.RegisterDialer("ssh", NewDialer) +} + +type sshDialer struct { + sessions map[string]*sshSession + sessionMutex sync.Mutex + logger logger.Logger + md metadata +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := &dialer.Options{} + for _, opt := range opts { + opt(options) + } + + return &sshDialer{ + sessions: make(map[string]*sshSession), + logger: options.Logger, + } +} + +func (d *sshDialer) Init(md md.Metadata) (err error) { + if err = d.parseMetadata(md); err != nil { + return + } + + return nil +} + +// Multiplex implements dialer.Multiplexer interface. +func (d *sshDialer) Multiplex() bool { + return true +} + +func (d *sshDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) { + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) + } + + d.sessionMutex.Lock() + defer d.sessionMutex.Unlock() + + session, ok := d.sessions[addr] + if session != nil && session.IsClosed() { + delete(d.sessions, addr) // session is dead + ok = false + } + if !ok { + conn, err = d.dial(ctx, "tcp", addr, &options) + if err != nil { + return + } + + session = &sshSession{ + addr: addr, + conn: conn, + } + d.sessions[addr] = session + } + + return session.conn, err +} + +// Handshake implements dialer.Handshaker +func (d *sshDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) { + opts := &dialer.HandshakeOptions{} + for _, option := range options { + option(opts) + } + + d.sessionMutex.Lock() + defer d.sessionMutex.Unlock() + + if d.md.handshakeTimeout > 0 { + conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + session, ok := d.sessions[opts.Addr] + if session != nil && session.conn != conn { + err := errors.New("ssh: unrecognized connection") + d.logger.Error(err) + conn.Close() + delete(d.sessions, opts.Addr) + return nil, err + } + + if !ok || session.client == nil { + s, err := d.initSession(ctx, opts.Addr, conn) + if err != nil { + d.logger.Error(err) + conn.Close() + delete(d.sessions, opts.Addr) + return nil, err + } + session = s + go func() { + s.wait() + d.logger.Debug("session closed") + }() + d.sessions[opts.Addr] = session + } + if session.IsClosed() { + delete(d.sessions, opts.Addr) + return nil, ssh_util.ErrSessionDead + } + + channel, reqs, err := session.client.OpenChannel(ssh_util.GostSSHTunnelRequest, nil) + if err != nil { + return nil, err + } + go ssh.DiscardRequests(reqs) + + return ssh_util.NewConn(conn, channel), nil +} + +func (d *sshDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) { + dial := opts.DialFunc + if dial != nil { + conn, err := dial(ctx, addr) + if err != nil { + d.logger.Error(err) + } else { + d.logger.WithFields(map[string]interface{}{ + "src": conn.LocalAddr().String(), + "dst": addr, + }).Debug("dial with dial func") + } + return conn, err + } + + var netd net.Dialer + conn, err := netd.DialContext(ctx, network, addr) + if err != nil { + d.logger.Error(err) + } else { + d.logger.WithFields(map[string]interface{}{ + "src": conn.LocalAddr().String(), + "dst": addr, + }).Debugf("dial direct %s/%s", addr, network) + } + return conn, err +} + +func (d *sshDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*sshSession, error) { + config := ssh.ClientConfig{ + // Timeout: timeout, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + if d.md.user != nil { + config.User = d.md.user.Username() + if password, _ := d.md.user.Password(); password != "" { + config.Auth = []ssh.AuthMethod{ + ssh.Password(password), + } + } + } + if d.md.signer != nil { + config.Auth = append(config.Auth, ssh.PublicKeys(d.md.signer)) + } + + sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, &config) + if err != nil { + return nil, err + } + + return &sshSession{ + conn: conn, + client: ssh.NewClient(sshConn, chans, reqs), + closed: make(chan struct{}), + dead: make(chan struct{}), + }, nil +} diff --git a/pkg/dialer/ssh/metadata.go b/pkg/dialer/ssh/metadata.go new file mode 100644 index 0000000..986e5d2 --- /dev/null +++ b/pkg/dialer/ssh/metadata.go @@ -0,0 +1,56 @@ +package ssh + +import ( + "io/ioutil" + "net/url" + "strings" + "time" + + md "github.com/go-gost/gost/pkg/metadata" + "golang.org/x/crypto/ssh" +) + +type metadata struct { + handshakeTimeout time.Duration + user *url.Userinfo + signer ssh.Signer +} + +func (d *sshDialer) parseMetadata(md md.Metadata) (err error) { + const ( + handshakeTimeout = "handshakeTimeout" + user = "user" + privateKeyFile = "privateKeyFile" + passphrase = "passphrase" + ) + + if v := md.GetString(user); v != "" { + ss := strings.SplitN(v, ":", 2) + if len(ss) == 1 { + d.md.user = url.User(ss[0]) + } else { + d.md.user = url.UserPassword(ss[0], ss[1]) + } + } + + if key := md.GetString(privateKeyFile); key != "" { + data, err := ioutil.ReadFile(key) + if err != nil { + return err + } + + pp := md.GetString(passphrase) + if pp == "" { + d.md.signer, err = ssh.ParsePrivateKey(data) + } else { + d.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp)) + } + if err != nil { + return err + } + } + + d.md.handshakeTimeout = md.GetDuration(handshakeTimeout) + + return +} diff --git a/pkg/handler/forward/ssh/metadata.go b/pkg/handler/forward/ssh/metadata.go index 3f705a1..c904ae6 100644 --- a/pkg/handler/forward/ssh/metadata.go +++ b/pkg/handler/forward/ssh/metadata.go @@ -6,6 +6,7 @@ import ( "github.com/go-gost/gost/pkg/auth" tls_util "github.com/go-gost/gost/pkg/common/util/tls" + ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" md "github.com/go-gost/gost/pkg/metadata" "golang.org/x/crypto/ssh" ) @@ -64,7 +65,7 @@ func (h *forwardHandler) parseMetadata(md md.Metadata) (err error) { } if name := md.GetString(authorizedKeys); name != "" { - m, err := parseAuthorizedKeysFile(name) + m, err := ssh_util.ParseAuthorizedKeysFile(name) if err != nil { return err } @@ -73,22 +74,3 @@ func (h *forwardHandler) parseMetadata(md md.Metadata) (err error) { return } - -// parseSSHAuthorizedKeysFile parses ssh authorized keys file. -func parseAuthorizedKeysFile(name string) (map[string]bool, error) { - authorizedKeysBytes, err := ioutil.ReadFile(name) - if err != nil { - return nil, err - } - authorizedKeysMap := make(map[string]bool) - for len(authorizedKeysBytes) > 0 { - pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) - if err != nil { - return nil, err - } - authorizedKeysMap[string(pubKey.Marshal())] = true - authorizedKeysBytes = rest - } - - return authorizedKeysMap, nil -} diff --git a/pkg/internal/util/ssh/conn.go b/pkg/internal/util/ssh/conn.go index fbc689b..7cb9382 100644 --- a/pkg/internal/util/ssh/conn.go +++ b/pkg/internal/util/ssh/conn.go @@ -22,3 +22,27 @@ func NewClientConn(conn net.Conn, client *ssh.Client) net.Conn { func (c *ClientConn) Client() *ssh.Client { return c.client } + +type sshConn struct { + channel ssh.Channel + net.Conn +} + +func NewConn(conn net.Conn, channel ssh.Channel) net.Conn { + return &sshConn{ + Conn: conn, + channel: channel, + } +} + +func (c *sshConn) Read(b []byte) (n int, err error) { + return c.channel.Read(b) +} + +func (c *sshConn) Write(b []byte) (n int, err error) { + return c.channel.Write(b) +} + +func (c *sshConn) Close() error { + return c.channel.Close() +} diff --git a/pkg/internal/util/ssh/ssh.go b/pkg/internal/util/ssh/ssh.go index 22e7ba7..4b835a8 100644 --- a/pkg/internal/util/ssh/ssh.go +++ b/pkg/internal/util/ssh/ssh.go @@ -1,12 +1,22 @@ package ssh import ( + "errors" "fmt" + "io/ioutil" "github.com/go-gost/gost/pkg/auth" "golang.org/x/crypto/ssh" ) +const ( + GostSSHTunnelRequest = "gost-tunnel" // extended request type for ssh tunnel +) + +var ( + ErrSessionDead = errors.New("session is dead") +) + // PasswordCallbackFunc is a callback function used by SSH server. // It authenticates user using a password. type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) @@ -44,3 +54,22 @@ func PublicKeyCallback(keys map[string]bool) PublicKeyCallbackFunc { return nil, fmt.Errorf("unknown public key for %q", c.User()) } } + +// ParseSSHAuthorizedKeysFile parses ssh authorized keys file. +func ParseAuthorizedKeysFile(name string) (map[string]bool, error) { + authorizedKeysBytes, err := ioutil.ReadFile(name) + if err != nil { + return nil, err + } + authorizedKeysMap := make(map[string]bool) + for len(authorizedKeysBytes) > 0 { + pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + return nil, err + } + authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysBytes = rest + } + + return authorizedKeysMap, nil +} diff --git a/pkg/listener/ssh/listener.go b/pkg/listener/ssh/listener.go new file mode 100644 index 0000000..e57c67b --- /dev/null +++ b/pkg/listener/ssh/listener.go @@ -0,0 +1,136 @@ +package ssh + +import ( + "fmt" + "net" + + ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" + "github.com/go-gost/gost/pkg/listener" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "golang.org/x/crypto/ssh" +) + +func init() { + registry.RegisterListener("ssh", NewListener) +} + +type sshListener struct { + addr string + net.Listener + config *ssh.ServerConfig + cqueue chan net.Conn + errChan chan error + logger logger.Logger + md metadata +} + +func NewListener(opts ...listener.Option) listener.Listener { + options := &listener.Options{} + for _, opt := range opts { + opt(options) + } + return &sshListener{ + addr: options.Addr, + logger: options.Logger, + } +} + +func (l *sshListener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { + return + } + + ln, err := net.Listen("tcp", l.addr) + if err != nil { + return err + } + + l.Listener = ln + + config := &ssh.ServerConfig{ + PasswordCallback: ssh_util.PasswordCallback(l.md.authenticator), + PublicKeyCallback: ssh_util.PublicKeyCallback(l.md.authorizedKeys), + } + + config.AddHostKey(l.md.signer) + + if l.md.authenticator == nil && len(l.md.authorizedKeys) == 0 { + config.NoClientAuth = true + } + + l.config = config + l.cqueue = make(chan net.Conn, l.md.backlog) + l.errChan = make(chan error, 1) + + go l.listenLoop() + + return +} + +func (l *sshListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.cqueue: + case err, ok = <-l.errChan: + if !ok { + err = listener.ErrClosed + } + } + return +} + +func (l *sshListener) listenLoop() { + for { + conn, err := l.Listener.Accept() + if err != nil { + l.logger.Error("accept:", err) + l.errChan <- err + close(l.errChan) + return + } + go l.serveConn(conn) + } +} + +func (l *sshListener) serveConn(conn net.Conn) { + sc, chans, reqs, err := ssh.NewServerConn(conn, l.config) + if err != nil { + l.logger.Error(err) + conn.Close() + return + } + defer sc.Close() + + go ssh.DiscardRequests(reqs) + go func() { + for newChannel := range chans { + // Check the type of channel + t := newChannel.ChannelType() + switch t { + case ssh_util.GostSSHTunnelRequest: + channel, requests, err := newChannel.Accept() + if err != nil { + l.logger.Warnf("could not accept channel: %s", err.Error()) + continue + } + + go ssh.DiscardRequests(requests) + cc := ssh_util.NewConn(conn, channel) + select { + case l.cqueue <- cc: + default: + cc.Close() + l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr()) + } + + default: + l.logger.Warnf("unsupported channel type: %s", t) + newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unsupported channel type: %s", t)) + } + } + }() + + sc.Wait() +} diff --git a/pkg/listener/ssh/metadata.go b/pkg/listener/ssh/metadata.go new file mode 100644 index 0000000..96e7800 --- /dev/null +++ b/pkg/listener/ssh/metadata.go @@ -0,0 +1,87 @@ +package ssh + +import ( + "io/ioutil" + "strings" + + "github.com/go-gost/gost/pkg/auth" + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" + md "github.com/go-gost/gost/pkg/metadata" + "golang.org/x/crypto/ssh" +) + +const ( + defaultBacklog = 128 +) + +type metadata struct { + authenticator auth.Authenticator + signer ssh.Signer + authorizedKeys map[string]bool + backlog int +} + +func (l *sshListener) parseMetadata(md md.Metadata) (err error) { + const ( + users = "users" + authorizedKeys = "authorizedKeys" + privateKeyFile = "privateKeyFile" + passphrase = "passphrase" + backlog = "backlog" + ) + + if v, _ := md.Get(users).([]interface{}); len(v) > 0 { + authenticator := auth.NewLocalAuthenticator(nil) + for _, auth := range v { + if s, _ := auth.(string); s != "" { + ss := strings.SplitN(s, ":", 2) + if len(ss) == 1 { + authenticator.Add(ss[0], "") + } else { + authenticator.Add(ss[0], ss[1]) + } + } + } + l.md.authenticator = authenticator + } + + if key := md.GetString(privateKeyFile); key != "" { + data, err := ioutil.ReadFile(key) + if err != nil { + return err + } + + pp := md.GetString(passphrase) + if pp == "" { + l.md.signer, err = ssh.ParsePrivateKey(data) + } else { + l.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp)) + } + if err != nil { + return err + } + } + if l.md.signer == nil { + signer, err := ssh.NewSignerFromKey(tls_util.DefaultConfig.Clone().Certificates[0].PrivateKey) + if err != nil { + return err + } + l.md.signer = signer + } + + if name := md.GetString(authorizedKeys); name != "" { + m, err := ssh_util.ParseAuthorizedKeysFile(name) + if err != nil { + return err + } + l.md.authorizedKeys = m + } + + l.md.backlog = md.GetInt(backlog) + if l.md.backlog <= 0 { + l.md.backlog = defaultBacklog + } + + return +} diff --git a/pkg/listener/tcp/listener.go b/pkg/listener/tcp/listener.go index 2fd2368..0af1139 100644 --- a/pkg/listener/tcp/listener.go +++ b/pkg/listener/tcp/listener.go @@ -3,7 +3,6 @@ package tcp import ( "net" - util "github.com/go-gost/gost/pkg/common/util" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -46,14 +45,6 @@ func (l *tcpListener) Init(md md.Metadata) (err error) { return } - if l.md.keepAlive { - l.Listener = &util.TCPKeepAliveListener{ - TCPListener: ln, - KeepAlivePeriod: l.md.keepAlivePeriod, - } - return - } - l.Listener = ln return } diff --git a/pkg/listener/tcp/metadata.go b/pkg/listener/tcp/metadata.go index b93af21..ea71797 100644 --- a/pkg/listener/tcp/metadata.go +++ b/pkg/listener/tcp/metadata.go @@ -1,28 +1,12 @@ package tcp import ( - "time" - md "github.com/go-gost/gost/pkg/metadata" ) -const ( - defaultKeepAlivePeriod = 180 * time.Second -) - type metadata struct { - keepAlive bool - keepAlivePeriod time.Duration } func (l *tcpListener) parseMetadata(md md.Metadata) (err error) { - const ( - keepAlive = "keepAlive" - keepAlivePeriod = "keepAlivePeriod" - ) - - l.md.keepAlive = md.GetBool(keepAlive) - l.md.keepAlivePeriod = md.GetDuration(keepAlivePeriod) - return }