From 10bcc59370cb51a9072964d1971a886bc15cee53 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sat, 18 Dec 2021 21:22:36 +0800 Subject: [PATCH] add ssh dialer --- cmd/gost/register.go | 3 + go.mod | 2 +- go.sum | 2 + pkg/connector/forward/ssh/connector.go | 79 +++++++ pkg/dialer/forward/ssh/conn.go | 31 +++ pkg/dialer/forward/ssh/dialer.go | 191 ++++++++++++++++ pkg/dialer/forward/ssh/metadata.go | 56 +++++ pkg/handler/forward/local/handler.go | 2 +- pkg/handler/forward/local/metadata.go | 2 +- pkg/handler/forward/remote/handler.go | 2 +- pkg/handler/forward/remote/metadata.go | 2 +- pkg/handler/forward/ssh/handler.go | 300 +++++++++++++++++++++++++ pkg/handler/forward/ssh/metadata.go | 94 ++++++++ pkg/internal/util/ssh/conn.go | 24 ++ pkg/internal/util/ssh/ssh.go | 46 ++++ 15 files changed, 831 insertions(+), 5 deletions(-) create mode 100644 pkg/connector/forward/ssh/connector.go create mode 100644 pkg/dialer/forward/ssh/conn.go create mode 100644 pkg/dialer/forward/ssh/dialer.go create mode 100644 pkg/dialer/forward/ssh/metadata.go create mode 100644 pkg/handler/forward/ssh/handler.go create mode 100644 pkg/handler/forward/ssh/metadata.go create mode 100644 pkg/internal/util/ssh/conn.go create mode 100644 pkg/internal/util/ssh/ssh.go diff --git a/cmd/gost/register.go b/cmd/gost/register.go index c6782a0..3a3f5a1 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -3,6 +3,7 @@ package main import ( // Register connectors _ "github.com/go-gost/gost/pkg/connector/forward" + _ "github.com/go-gost/gost/pkg/connector/forward/ssh" _ "github.com/go-gost/gost/pkg/connector/http" _ "github.com/go-gost/gost/pkg/connector/http2" _ "github.com/go-gost/gost/pkg/connector/relay" @@ -13,6 +14,7 @@ import ( _ "github.com/go-gost/gost/pkg/connector/ss/udp" // Register dialers + _ "github.com/go-gost/gost/pkg/dialer/forward/ssh" _ "github.com/go-gost/gost/pkg/dialer/ftcp" _ "github.com/go-gost/gost/pkg/dialer/http2" _ "github.com/go-gost/gost/pkg/dialer/http2/h2" @@ -31,6 +33,7 @@ import ( _ "github.com/go-gost/gost/pkg/handler/auto" _ "github.com/go-gost/gost/pkg/handler/forward/local" _ "github.com/go-gost/gost/pkg/handler/forward/remote" + _ "github.com/go-gost/gost/pkg/handler/forward/ssh" _ "github.com/go-gost/gost/pkg/handler/http" _ "github.com/go-gost/gost/pkg/handler/http2" _ "github.com/go-gost/gost/pkg/handler/redirect" diff --git a/go.mod b/go.mod index 84aa030..f74aa5d 100644 --- a/go.mod +++ b/go.mod @@ -47,7 +47,7 @@ require ( github.com/xtaci/kcp-go/v5 v5.6.1 github.com/xtaci/smux v1.5.16 github.com/xtaci/tcpraw v1.2.25 - golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 + golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 golang.org/x/mod v0.4.2 // indirect golang.org/x/net v0.0.0-20211209124913-491a49abca63 golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf // indirect diff --git a/go.sum b/go.sum index b99f18d..55b8a62 100644 --- a/go.sum +++ b/go.sum @@ -442,6 +442,8 @@ golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 h1:0es+/5331RGQPcXlMfP+WrnIIS6dNnNRe0WB02W0F4M= +golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= diff --git a/pkg/connector/forward/ssh/connector.go b/pkg/connector/forward/ssh/connector.go new file mode 100644 index 0000000..ae3f38f --- /dev/null +++ b/pkg/connector/forward/ssh/connector.go @@ -0,0 +1,79 @@ +package ssh + +import ( + "context" + "errors" + "net" + + "github.com/go-gost/gost/pkg/connector" + ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegiserConnector("sshd", NewConnector) +} + +type forwardConnector struct { + logger logger.Logger +} + +func NewConnector(opts ...connector.Option) connector.Connector { + options := &connector.Options{} + for _, opt := range opts { + opt(options) + } + + return &forwardConnector{ + logger: options.Logger, + } +} + +func (c *forwardConnector) Init(md md.Metadata) (err error) { + return nil +} + +func (c *forwardConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { + c.logger = c.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + "network": network, + "address": address, + }) + c.logger.Infof("connect %s/%s", address, network) + + cc, ok := conn.(*ssh_util.ClientConn) + if !ok { + return nil, errors.New("ssh: invalid connection") + } + + conn, err := cc.Client().Dial(network, address) + if err != nil { + c.logger.Error(err) + return nil, err + } + + return conn, nil +} + +// Bind implements connector.Binder. +func (c *forwardConnector) Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error) { + c.logger = c.logger.WithFields(map[string]interface{}{ + "network": network, + "address": address, + }) + c.logger.Infof("bind on %s/%s", address, network) + + cc, ok := conn.(*ssh_util.ClientConn) + if !ok { + return nil, errors.New("ssh: invalid connection") + } + + if host, port, _ := net.SplitHostPort(address); host == "" { + address = net.JoinHostPort("0.0.0.0", port) + } + + return cc.Client().Listen(network, address) +} diff --git a/pkg/dialer/forward/ssh/conn.go b/pkg/dialer/forward/ssh/conn.go new file mode 100644 index 0000000..fb14c6f --- /dev/null +++ b/pkg/dialer/forward/ssh/conn.go @@ -0,0 +1,31 @@ +package ssh + +import ( + "net" + + "golang.org/x/crypto/ssh" +) + +type sshSession struct { + addr string + conn net.Conn + client *ssh.Client + closed chan struct{} + dead chan struct{} +} + +func (s *sshSession) IsClosed() bool { + select { + case <-s.dead: + return true + case <-s.closed: + return true + default: + } + return false +} + +func (s *sshSession) wait() error { + defer close(s.closed) + return s.client.Wait() +} diff --git a/pkg/dialer/forward/ssh/dialer.go b/pkg/dialer/forward/ssh/dialer.go new file mode 100644 index 0000000..ea7ff0b --- /dev/null +++ b/pkg/dialer/forward/ssh/dialer.go @@ -0,0 +1,191 @@ +package ssh + +import ( + "context" + "errors" + "net" + "sync" + "time" + + "github.com/go-gost/gost/pkg/dialer" + ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "golang.org/x/crypto/ssh" +) + +var ( + ErrSessionDead = errors.New("ssh: session is dead") +) + +func init() { + registry.RegisterDialer("sshd", NewDialer) +} + +type forwardDialer struct { + sessions map[string]*sshSession + sessionMutex sync.Mutex + logger logger.Logger + md metadata +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := &dialer.Options{} + for _, opt := range opts { + opt(options) + } + + return &forwardDialer{ + sessions: make(map[string]*sshSession), + logger: options.Logger, + } +} + +func (d *forwardDialer) Init(md md.Metadata) (err error) { + if err = d.parseMetadata(md); err != nil { + return + } + + return nil +} + +// Multiplex implements dialer.Multiplexer interface. +func (d *forwardDialer) Multiplex() bool { + return true +} + +func (d *forwardDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) { + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) + } + + d.sessionMutex.Lock() + defer d.sessionMutex.Unlock() + + session, ok := d.sessions[addr] + if session != nil && session.IsClosed() { + delete(d.sessions, addr) // session is dead + ok = false + } + if !ok { + conn, err = d.dial(ctx, "tcp", addr, &options) + if err != nil { + return + } + + session = &sshSession{ + addr: addr, + conn: conn, + } + d.sessions[addr] = session + } + + return session.conn, err +} + +// Handshake implements dialer.Handshaker +func (d *forwardDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) { + opts := &dialer.HandshakeOptions{} + for _, option := range options { + option(opts) + } + + d.sessionMutex.Lock() + defer d.sessionMutex.Unlock() + + if d.md.handshakeTimeout > 0 { + conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + session, ok := d.sessions[opts.Addr] + if session != nil && session.conn != conn { + err := errors.New("ssh: unrecognized connection") + d.logger.Error(err) + conn.Close() + delete(d.sessions, opts.Addr) + return nil, err + } + + if !ok || session.client == nil { + s, err := d.initSession(ctx, opts.Addr, conn) + if err != nil { + d.logger.Error(err) + conn.Close() + delete(d.sessions, opts.Addr) + return nil, err + } + session = s + go func() { + s.wait() + d.logger.Debug("session closed") + }() + d.sessions[opts.Addr] = session + } + if session.IsClosed() { + delete(d.sessions, opts.Addr) + return nil, ErrSessionDead + } + + return ssh_util.NewClientConn(session.conn, session.client), nil +} + +func (d *forwardDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) { + dial := opts.DialFunc + if dial != nil { + conn, err := dial(ctx, addr) + if err != nil { + d.logger.Error(err) + } else { + d.logger.WithFields(map[string]interface{}{ + "src": conn.LocalAddr().String(), + "dst": addr, + }).Debug("dial with dial func") + } + return conn, err + } + + var netd net.Dialer + conn, err := netd.DialContext(ctx, network, addr) + if err != nil { + d.logger.Error(err) + } else { + d.logger.WithFields(map[string]interface{}{ + "src": conn.LocalAddr().String(), + "dst": addr, + }).Debugf("dial direct %s/%s", addr, network) + } + return conn, err +} + +func (d *forwardDialer) initSession(ctx context.Context, addr string, conn net.Conn) (*sshSession, error) { + config := ssh.ClientConfig{ + // Timeout: timeout, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + if d.md.user != nil { + config.User = d.md.user.Username() + if password, _ := d.md.user.Password(); password != "" { + config.Auth = []ssh.AuthMethod{ + ssh.Password(password), + } + } + } + if d.md.signer != nil { + config.Auth = append(config.Auth, ssh.PublicKeys(d.md.signer)) + } + + sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, &config) + if err != nil { + return nil, err + } + + return &sshSession{ + conn: conn, + client: ssh.NewClient(sshConn, chans, reqs), + closed: make(chan struct{}), + dead: make(chan struct{}), + }, nil +} diff --git a/pkg/dialer/forward/ssh/metadata.go b/pkg/dialer/forward/ssh/metadata.go new file mode 100644 index 0000000..d2fcd03 --- /dev/null +++ b/pkg/dialer/forward/ssh/metadata.go @@ -0,0 +1,56 @@ +package ssh + +import ( + "io/ioutil" + "net/url" + "strings" + "time" + + md "github.com/go-gost/gost/pkg/metadata" + "golang.org/x/crypto/ssh" +) + +type metadata struct { + handshakeTimeout time.Duration + user *url.Userinfo + signer ssh.Signer +} + +func (d *forwardDialer) parseMetadata(md md.Metadata) (err error) { + const ( + handshakeTimeout = "handshakeTimeout" + user = "user" + privateKeyFile = "privateKeyFile" + passphrase = "passphrase" + ) + + if v := md.GetString(user); v != "" { + ss := strings.SplitN(v, ":", 2) + if len(ss) == 1 { + d.md.user = url.User(ss[0]) + } else { + d.md.user = url.UserPassword(ss[0], ss[1]) + } + } + + if key := md.GetString(privateKeyFile); key != "" { + data, err := ioutil.ReadFile(key) + if err != nil { + return err + } + + pp := md.GetString(passphrase) + if pp == "" { + d.md.signer, err = ssh.ParsePrivateKey(data) + } else { + d.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp)) + } + if err != nil { + return err + } + } + + d.md.handshakeTimeout = md.GetDuration(handshakeTimeout) + + return +} diff --git a/pkg/handler/forward/local/handler.go b/pkg/handler/forward/local/handler.go index ad6a57a..8342061 100644 --- a/pkg/handler/forward/local/handler.go +++ b/pkg/handler/forward/local/handler.go @@ -1,4 +1,4 @@ -package forward +package local import ( "context" diff --git a/pkg/handler/forward/local/metadata.go b/pkg/handler/forward/local/metadata.go index 9bf7df0..b54f8eb 100644 --- a/pkg/handler/forward/local/metadata.go +++ b/pkg/handler/forward/local/metadata.go @@ -1,4 +1,4 @@ -package forward +package local import ( "time" diff --git a/pkg/handler/forward/remote/handler.go b/pkg/handler/forward/remote/handler.go index af118da..8bf529a 100644 --- a/pkg/handler/forward/remote/handler.go +++ b/pkg/handler/forward/remote/handler.go @@ -1,4 +1,4 @@ -package forward +package remote import ( "context" diff --git a/pkg/handler/forward/remote/metadata.go b/pkg/handler/forward/remote/metadata.go index 9bf7df0..50210b5 100644 --- a/pkg/handler/forward/remote/metadata.go +++ b/pkg/handler/forward/remote/metadata.go @@ -1,4 +1,4 @@ -package forward +package remote import ( "time" diff --git a/pkg/handler/forward/ssh/handler.go b/pkg/handler/forward/ssh/handler.go new file mode 100644 index 0000000..9ed9e76 --- /dev/null +++ b/pkg/handler/forward/ssh/handler.go @@ -0,0 +1,300 @@ +package ssh + +import ( + "context" + "encoding/binary" + "fmt" + "net" + "strconv" + "time" + + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/handler" + ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "golang.org/x/crypto/ssh" +) + +// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X +const ( + DirectForwardRequest = "direct-tcpip" // RFC 4254 7.2 + RemoteForwardRequest = "tcpip-forward" // RFC 4254 7.1 + ForwardedTCPReturnRequest = "forwarded-tcpip" // RFC 4254 7.2 + CancelRemoteForwardRequest = "cancel-tcpip-forward" // RFC 4254 7.1 +) + +func init() { + registry.RegisterHandler("sshd", NewHandler) +} + +type forwardHandler struct { + chain *chain.Chain + bypass bypass.Bypass + config *ssh.ServerConfig + logger logger.Logger + md metadata +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := &handler.Options{} + for _, opt := range opts { + opt(options) + } + + return &forwardHandler{ + bypass: options.Bypass, + logger: options.Logger, + } +} + +func (h *forwardHandler) Init(md md.Metadata) (err error) { + if err = h.parseMetadata(md); err != nil { + return + } + + config := &ssh.ServerConfig{ + PasswordCallback: ssh_util.PasswordCallback(h.md.authenticator), + PublicKeyCallback: ssh_util.PublicKeyCallback(h.md.authorizedKeys), + } + + config.AddHostKey(h.md.signer) + + if h.md.authenticator == nil && len(h.md.authorizedKeys) == 0 { + config.NoClientAuth = true + } + + h.config = config + + return nil +} + +// WithChain implements chain.Chainable interface +func (h *forwardHandler) WithChain(chain *chain.Chain) { + h.chain = chain +} + +func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { + defer conn.Close() + + start := time.Now() + h.logger = h.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + sshConn, chans, reqs, err := ssh.NewServerConn(conn, h.config) + if err != nil { + h.logger.Error(err) + return + } + + h.handleForward(ctx, sshConn, chans, reqs) +} + +func (h *forwardHandler) handleForward(ctx context.Context, conn ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { + quit := make(chan struct{}) + defer close(quit) // quit signal + + go func() { + for req := range reqs { + switch req.Type { + case RemoteForwardRequest: + go h.tcpipForwardRequest(conn, req, quit) + default: + h.logger.Warnf("unsupported request type: %s, want reply: %v", req.Type, req.WantReply) + if req.WantReply { + req.Reply(false, nil) + } + } + } + }() + + go func() { + for newChannel := range chans { + // Check the type of channel + t := newChannel.ChannelType() + switch t { + case DirectForwardRequest: + channel, requests, err := newChannel.Accept() + if err != nil { + h.logger.Warnf("could not accept channel: %s", err.Error()) + continue + } + p := directForward{} + ssh.Unmarshal(newChannel.ExtraData(), &p) + + h.logger.Debug(p.String()) + + if p.Host1 == "" { + p.Host1 = "" + } + + go ssh.DiscardRequests(requests) + go h.directPortForwardChannel(ctx, channel, net.JoinHostPort(p.Host1, strconv.Itoa(int(p.Port1)))) + default: + h.logger.Warnf("unsupported channel type: %s", t) + newChannel.Reject(ssh.Prohibited, fmt.Sprintf("unsupported channel type: %s", t)) + } + } + }() + + conn.Wait() +} + +func (h *forwardHandler) directPortForwardChannel(ctx context.Context, channel ssh.Channel, raddr string) { + defer channel.Close() + + // log.Logf("[ssh-tcp] %s - %s", h.options.Node.Addr, raddr) + + /* + if !Can("tcp", raddr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[ssh-tcp] Unauthorized to tcp connect to %s", raddr) + return + } + */ + + if h.bypass != nil && h.bypass.Contains(raddr) { + h.logger.Infof("bypass %s", raddr) + return + } + + r := (&chain.Router{}). + WithChain(h.chain). + // WithRetry(h.md.retryCount). + WithLogger(h.logger) + conn, err := r.Dial(ctx, "tcp", raddr) + if err != nil { + return + } + defer conn.Close() + + t := time.Now() + h.logger.Infof("%s <-> %s", conn.LocalAddr(), conn.RemoteAddr()) + handler.Transport(conn, channel) + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.LocalAddr(), conn.RemoteAddr()) +} + +// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip" +type directForward struct { + Host1 string + Port1 uint32 + Host2 string + Port2 uint32 +} + +func (p directForward) String() string { + return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1) +} + +func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { + host, portString, err := net.SplitHostPort(addr.String()) + if err != nil { + return + } + port, err = strconv.Atoi(portString) + return +} + +// tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request +type tcpipForward struct { + Host string + Port uint32 +} + +func (h *forwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit <-chan struct{}) { + t := tcpipForward{} + ssh.Unmarshal(req.Payload, &t) + + addr := net.JoinHostPort(t.Host, strconv.Itoa(int(t.Port))) + + /* + if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[ssh-rtcp] Unauthorized to tcp bind to %s", addr) + req.Reply(false, nil) + return + } + */ + + // tie to the client connection + ln, err := net.Listen("tcp", addr) + if err != nil { + h.logger.Error(err) + req.Reply(false, nil) + return + } + defer ln.Close() + + h.logger.Debugf("bind on %s OK", ln.Addr()) + + err = func() error { + if t.Port == 0 && req.WantReply { // Client sent port 0. let them know which port is actually being used + _, port, err := getHostPortFromAddr(ln.Addr()) + if err != nil { + return err + } + var b [4]byte + binary.BigEndian.PutUint32(b[:], uint32(port)) + t.Port = uint32(port) + return req.Reply(true, b[:]) + } + return req.Reply(true, nil) + }() + if err != nil { + h.logger.Error(err) + return + } + + go func() { + for { + conn, err := ln.Accept() + if err != nil { // Unable to accept new connection - listener is likely closed + return + } + + go func(conn net.Conn) { + defer conn.Close() + + p := directForward{} + var err error + + var portnum int + p.Host1 = t.Host + p.Port1 = t.Port + p.Host2, portnum, err = getHostPortFromAddr(conn.RemoteAddr()) + if err != nil { + return + } + + p.Port2 = uint32(portnum) + ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p)) + if err != nil { + h.logger.Error("open forwarded channel: ", err) + return + } + defer ch.Close() + go ssh.DiscardRequests(reqs) + + t := time.Now() + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) + handler.Transport(ch, conn) + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(t), + }).Infof("%s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) + }(conn) + } + }() + + <-quit +} diff --git a/pkg/handler/forward/ssh/metadata.go b/pkg/handler/forward/ssh/metadata.go new file mode 100644 index 0000000..3f705a1 --- /dev/null +++ b/pkg/handler/forward/ssh/metadata.go @@ -0,0 +1,94 @@ +package ssh + +import ( + "io/ioutil" + "strings" + + "github.com/go-gost/gost/pkg/auth" + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + md "github.com/go-gost/gost/pkg/metadata" + "golang.org/x/crypto/ssh" +) + +type metadata struct { + authenticator auth.Authenticator + signer ssh.Signer + authorizedKeys map[string]bool +} + +func (h *forwardHandler) parseMetadata(md md.Metadata) (err error) { + const ( + users = "users" + authorizedKeys = "authorizedKeys" + privateKeyFile = "privateKeyFile" + passphrase = "passphrase" + ) + + if v, _ := md.Get(users).([]interface{}); len(v) > 0 { + authenticator := auth.NewLocalAuthenticator(nil) + for _, auth := range v { + if s, _ := auth.(string); s != "" { + ss := strings.SplitN(s, ":", 2) + if len(ss) == 1 { + authenticator.Add(ss[0], "") + } else { + authenticator.Add(ss[0], ss[1]) + } + } + } + h.md.authenticator = authenticator + } + + if key := md.GetString(privateKeyFile); key != "" { + data, err := ioutil.ReadFile(key) + if err != nil { + return err + } + + pp := md.GetString(passphrase) + if pp == "" { + h.md.signer, err = ssh.ParsePrivateKey(data) + } else { + h.md.signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(pp)) + } + if err != nil { + return err + } + } + if h.md.signer == nil { + signer, err := ssh.NewSignerFromKey(tls_util.DefaultConfig.Clone().Certificates[0].PrivateKey) + if err != nil { + return err + } + h.md.signer = signer + } + + if name := md.GetString(authorizedKeys); name != "" { + m, err := parseAuthorizedKeysFile(name) + if err != nil { + return err + } + h.md.authorizedKeys = m + } + + return +} + +// parseSSHAuthorizedKeysFile parses ssh authorized keys file. +func parseAuthorizedKeysFile(name string) (map[string]bool, error) { + authorizedKeysBytes, err := ioutil.ReadFile(name) + if err != nil { + return nil, err + } + authorizedKeysMap := make(map[string]bool) + for len(authorizedKeysBytes) > 0 { + pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + return nil, err + } + authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysBytes = rest + } + + return authorizedKeysMap, nil +} diff --git a/pkg/internal/util/ssh/conn.go b/pkg/internal/util/ssh/conn.go new file mode 100644 index 0000000..fbc689b --- /dev/null +++ b/pkg/internal/util/ssh/conn.go @@ -0,0 +1,24 @@ +package ssh + +import ( + "net" + + "golang.org/x/crypto/ssh" +) + +// a dummy ssh client conn used by client connector +type ClientConn struct { + net.Conn + client *ssh.Client +} + +func NewClientConn(conn net.Conn, client *ssh.Client) net.Conn { + return &ClientConn{ + Conn: conn, + client: client, + } +} + +func (c *ClientConn) Client() *ssh.Client { + return c.client +} diff --git a/pkg/internal/util/ssh/ssh.go b/pkg/internal/util/ssh/ssh.go new file mode 100644 index 0000000..22e7ba7 --- /dev/null +++ b/pkg/internal/util/ssh/ssh.go @@ -0,0 +1,46 @@ +package ssh + +import ( + "fmt" + + "github.com/go-gost/gost/pkg/auth" + "golang.org/x/crypto/ssh" +) + +// PasswordCallbackFunc is a callback function used by SSH server. +// It authenticates user using a password. +type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) + +func PasswordCallback(au auth.Authenticator) PasswordCallbackFunc { + if au == nil { + return nil + } + return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + if au.Authenticate(conn.User(), string(password)) { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %s", conn.User()) + } +} + +// PublicKeyCallbackFunc is a callback function used by SSH server. +// It offers a public key for authentication. +type PublicKeyCallbackFunc func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) + +func PublicKeyCallback(keys map[string]bool) PublicKeyCallbackFunc { + if len(keys) == 0 { + return nil + } + + return func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + if keys[string(pubKey.Marshal())] { + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "pubkey-fp": ssh.FingerprintSHA256(pubKey), + }, + }, nil + } + return nil, fmt.Errorf("unknown public key for %q", c.User()) + } +}