From d72ab4e21ee94d066645003e471807689d889e25 Mon Sep 17 00:00:00 2001 From: dushixiang Date: Sat, 6 Feb 2021 00:25:48 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=8E=9F=E7=94=9Fssh?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.go | 5 +- pkg/api/recording.go | 101 ----- pkg/api/routes.go | 2 + pkg/api/session.go | 44 +-- pkg/api/ssh.go | 344 ++++++++---------- pkg/api/tunnel.go | 4 +- pkg/global/store.go | 37 +- pkg/log/logger.go | 193 ++++++++++ pkg/term/next_terminal.go | 99 +++++ pkg/term/next_writer.go | 30 ++ pkg/term/recording.go | 122 +++++++ pkg/term/ssh.go | 53 +++ pkg/term/test/test_ssh.go | 174 +++++++++ pkg/utils/utils.go | 5 + web/package.json | 2 +- web/src/App.js | 4 +- web/src/components/access/Access.js | 2 +- .../access/{AccessNaive.css => Term.css} | 0 .../access/{AccessNaive.js => Term.js} | 32 +- web/src/components/asset/Asset.js | 12 +- 20 files changed, 896 insertions(+), 369 deletions(-) delete mode 100644 pkg/api/recording.go create mode 100644 pkg/log/logger.go create mode 100644 pkg/term/next_terminal.go create mode 100644 pkg/term/next_writer.go create mode 100644 pkg/term/recording.go create mode 100644 pkg/term/ssh.go create mode 100644 pkg/term/test/test_ssh.go rename web/src/components/access/{AccessNaive.css => Term.css} (100%) rename web/src/components/access/{AccessNaive.js => Term.js} (90%) diff --git a/main.go b/main.go index 69795b4..994fe36 100644 --- a/main.go +++ b/main.go @@ -10,7 +10,6 @@ import ( "gorm.io/driver/mysql" "gorm.io/driver/sqlite" "gorm.io/gorm" - "gorm.io/gorm/logger" "io" "next-terminal/pkg/api" "next-terminal/pkg/config" @@ -23,7 +22,7 @@ import ( "time" ) -const Version = "v0.1.1" +const Version = "v0.2.0" func main() { log.Fatal(Run()) @@ -71,7 +70,7 @@ func Run() error { global.Config.Mysql.Database, ) global.DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ - Logger: logger.Default.LogMode(logger.Info), + //Logger: logger.Default.LogMode(logger.Info), }) } else { global.DB, err = gorm.Open(sqlite.Open(global.Config.Sqlite.File), &gorm.Config{}) diff --git a/pkg/api/recording.go b/pkg/api/recording.go deleted file mode 100644 index 68f50ec..0000000 --- a/pkg/api/recording.go +++ /dev/null @@ -1,101 +0,0 @@ -package api - -import ( - "encoding/json" - "next-terminal/pkg/utils" - "os" - "path" - "time" -) - -type Env struct { - Shell string `json:"SHELL"` - Term string `json:"TERM"` -} - -type Header struct { - Title string `json:"title"` - Version int `json:"version"` - Height int `json:"height"` - Width int `json:"width"` - Env Env `json:"env"` - Timestamp int `json:"timestamp"` -} - -type Recorder struct { - file *os.File - timestamp int -} - -func NewRecorder(dir string) (recorder *Recorder, filename string, err error) { - recorder = &Recorder{} - - if utils.FileExists(dir) { - if err := os.RemoveAll(dir); err != nil { - return nil, "", err - } - } - - if err = os.MkdirAll(dir, 0777); err != nil { - return - } - - filename = path.Join(dir, "recording.cast") - - var file *os.File - file, err = os.Create(filename) - if err != nil { - return nil, "", err - } - - recorder.file = file - return recorder, filename, nil -} - -func (recorder *Recorder) Close() { - if recorder.file != nil { - recorder.file.Close() - } -} - -func (recorder *Recorder) WriteHeader(header *Header) (err error) { - var p []byte - - if p, err = json.Marshal(header); err != nil { - return - } - - if _, err := recorder.file.Write(p); err != nil { - return err - } - if _, err := recorder.file.Write([]byte("\n")); err != nil { - return err - } - - recorder.timestamp = header.Timestamp - - return -} - -func (recorder *Recorder) WriteData(data string) (err error) { - now := int(time.Now().UnixNano()) - - delta := float64(now-recorder.timestamp*1000*1000*1000) / 1000 / 1000 / 1000 - - row := make([]interface{}, 0) - row = append(row, delta) - row = append(row, "o") - row = append(row, data) - - var s []byte - if s, err = json.Marshal(row); err != nil { - return - } - if _, err := recorder.file.Write(s); err != nil { - return err - } - if _, err := recorder.file.Write([]byte("\n")); err != nil { - return err - } - return -} diff --git a/pkg/api/routes.go b/pkg/api/routes.go index 1fb9c66..1936864 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -3,6 +3,7 @@ package api import ( "net/http" "next-terminal/pkg/global" + "next-terminal/pkg/log" "next-terminal/pkg/model" "github.com/labstack/echo/v4" @@ -15,6 +16,7 @@ func SetupRoutes() *echo.Echo { e := echo.New() e.HideBanner = true + e.Logger = log.GetEchoLogger() e.File("/", "web/build/index.html") e.File("/logo.svg", "web/build/logo.svg") diff --git a/pkg/api/session.go b/pkg/api/session.go index 0c145ee..44138b5 100644 --- a/pkg/api/session.go +++ b/pkg/api/session.go @@ -4,14 +4,12 @@ import ( "bytes" "errors" "fmt" - "github.com/gorilla/websocket" "github.com/labstack/echo/v4" "github.com/sirupsen/logrus" "io" "io/ioutil" "net/http" "next-terminal/pkg/global" - "next-terminal/pkg/guacd" "next-terminal/pkg/model" "next-terminal/pkg/utils" "os" @@ -100,7 +98,7 @@ func SessionDisconnectEndpoint(c echo.Context) error { split := strings.Split(sessionIds, ",") for i := range split { - CloseSessionById(split[i], ForcedDisconnect, "forced disconnect") + CloseSessionById(split[i], ForcedDisconnect, "管理员强制关闭了此会话") } return Success(c, nil) } @@ -112,16 +110,13 @@ func CloseSessionById(sessionId string, code int, reason string) { defer mutex.Unlock() observable, _ := global.Store.Get(sessionId) if observable != nil { - logrus.Debugf("会话%v创建者退出", observable.Subject.Tunnel.UUID) - observable.Subject.Close() + logrus.Debugf("会话%v创建者退出", sessionId) + observable.Subject.Close(code, reason) for i := 0; i < len(observable.Observers); i++ { - observable.Observers[i].Close() - CloseWebSocket(observable.Observers[i].WebSocket, code, reason) - logrus.Debugf("强制踢出会话%v的观察者", observable.Observers[i].Tunnel.UUID) + observable.Observers[i].Close(code, reason) + logrus.Debugf("强制踢出会话%v的观察者", sessionId) } - - CloseWebSocket(observable.Subject.WebSocket, code, reason) } global.Store.Del(sessionId) @@ -150,17 +145,6 @@ func CloseSessionById(sessionId string, code int, reason string) { _ = model.UpdateSessionById(&session, sessionId) } -func CloseWebSocket(ws *websocket.Conn, c int, t string) { - if ws == nil { - return - } - err := guacd.NewInstruction("error", "", strconv.Itoa(c)) - _ = ws.WriteMessage(websocket.TextMessage, []byte(err.String())) - disconnect := guacd.NewInstruction("disconnect") - _ = ws.WriteMessage(websocket.TextMessage, []byte(disconnect.String())) - //defer ws.Close() -} - func SessionResizeEndpoint(c echo.Context) error { width := c.QueryParam("width") height := c.QueryParam("height") @@ -274,11 +258,11 @@ func SessionUploadEndpoint(c echo.Context) error { return errors.New("获取sftp客户端失败") } - dstFile, err := tun.Subject.SftpClient.Create(remoteFile) - defer dstFile.Close() + dstFile, err := tun.Subject.NextTerminal.SftpClient.Create(remoteFile) if err != nil { return err } + defer dstFile.Close() buf := make([]byte, 1024) for { @@ -327,7 +311,7 @@ func SessionDownloadEndpoint(c echo.Context) error { return errors.New("获取sftp客户端失败") } - dstFile, err := tun.Subject.SftpClient.Open(remoteFile) + dstFile, err := tun.Subject.NextTerminal.SftpClient.Open(remoteFile) if err != nil { return err } @@ -378,16 +362,16 @@ func SessionLsEndpoint(c echo.Context) error { return errors.New("获取sftp客户端失败") } - if tun.Subject.SftpClient == nil { + if tun.Subject.NextTerminal.SftpClient == nil { sftpClient, err := CreateSftpClient(session) if err != nil { logrus.Errorf("创建sftp客户端失败:%v", err.Error()) return err } - tun.Subject.SftpClient = sftpClient + tun.Subject.NextTerminal.SftpClient = sftpClient } - fileInfos, err := tun.Subject.SftpClient.ReadDir(remoteDir) + fileInfos, err := tun.Subject.NextTerminal.SftpClient.ReadDir(remoteDir) if err != nil { return err } @@ -457,7 +441,7 @@ func SessionMkDirEndpoint(c echo.Context) error { if !ok { return errors.New("获取sftp客户端失败") } - if err := tun.Subject.SftpClient.Mkdir(remoteDir); err != nil { + if err := tun.Subject.NextTerminal.SftpClient.Mkdir(remoteDir); err != nil { return err } return Success(c, nil) @@ -489,7 +473,7 @@ func SessionRmEndpoint(c echo.Context) error { return errors.New("获取sftp客户端失败") } - sftpClient := tun.Subject.SftpClient + sftpClient := tun.Subject.NextTerminal.SftpClient stat, err := sftpClient.Stat(key) if err != nil { @@ -548,7 +532,7 @@ func SessionRenameEndpoint(c echo.Context) error { return errors.New("获取sftp客户端失败") } - sftpClient := tun.Subject.SftpClient + sftpClient := tun.Subject.NextTerminal.SftpClient if err := sftpClient.Rename(oldName, newName); err != nil { return err diff --git a/pkg/api/ssh.go b/pkg/api/ssh.go index b035564..0f74eca 100644 --- a/pkg/api/ssh.go +++ b/pkg/api/ssh.go @@ -1,7 +1,6 @@ package api import ( - "bytes" "encoding/json" "fmt" "github.com/gorilla/websocket" @@ -10,12 +9,13 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "net/http" + "next-terminal/pkg/global" "next-terminal/pkg/guacd" "next-terminal/pkg/model" + "next-terminal/pkg/term" "next-terminal/pkg/utils" "path" "strconv" - "sync" "time" ) @@ -26,30 +26,6 @@ var UpGrader = websocket.Upgrader{ Subprotocols: []string{"guacamole"}, } -type NextWriter struct { - b bytes.Buffer - mu sync.Mutex -} - -func (w *NextWriter) Write(p []byte) (int, error) { - w.mu.Lock() - defer w.mu.Unlock() - return w.b.Write(p) -} - -func (w *NextWriter) Read() ([]byte, int, error) { - w.mu.Lock() - defer w.mu.Unlock() - p := w.b.Bytes() - buf := make([]byte, len(p)) - read, err := w.b.Read(buf) - w.b.Reset() - if err != nil { - return nil, 0, err - } - return buf, read, err -} - const ( Connected = "connected" Data = "data" @@ -75,14 +51,14 @@ func SSHEndpoint(c echo.Context) (err error) { } sessionId := c.QueryParam("sessionId") - width, _ := strconv.Atoi(c.QueryParam("width")) - height, _ := strconv.Atoi(c.QueryParam("height")) + cols, _ := strconv.Atoi(c.QueryParam("cols")) + rows, _ := strconv.Atoi(c.QueryParam("rows")) - aSession, err := model.FindSessionById(sessionId) + session, err := model.FindSessionById(sessionId) if err != nil { msg := Message{ Type: Closed, - Content: "get session error." + err.Error(), + Content: "get sshSession error." + err.Error(), } _ = WriteMessage(ws, msg) return err @@ -96,7 +72,7 @@ func SSHEndpoint(c echo.Context) (err error) { return err } - if !utils.Contains(assetIds, aSession.AssetId) { + if !utils.Contains(assetIds, session.AssetId) { msg := Message{ Type: Closed, Content: "您没有权限访问此资产", @@ -105,7 +81,41 @@ func SSHEndpoint(c echo.Context) (err error) { } } - sshClient, err := CreateSshClientBySession(aSession) + var ( + username = session.Username + password = session.Password + privateKey = session.PrivateKey + passphrase = session.Passphrase + ip = session.IP + port = session.Port + ) + + recording := "" + propertyMap := model.FindAllPropertiesMap() + if propertyMap[guacd.EnableRecording] == "true" { + recording = path.Join(propertyMap[guacd.RecordingPath], sessionId, "recording.cast") + } + + tun := global.Tun{ + Protocol: session.Protocol, + WebSocket: ws, + } + + if session.ConnectionId != "" { + // 监控会话 + observable, ok := global.Store.Get(sessionId) + if ok { + observers := append(observable.Observers, tun) + observable.Observers = observers + global.Store.Set(sessionId, observable) + logrus.Debugf("加入会话%v,当前观察者数量为:%v", session.ConnectionId, len(observers)) + } + + return err + } + + nextTerminal, err := term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, rows, cols, recording) + if err != nil { logrus.Errorf("创建SSH客户端失败:%v", err.Error()) msg := Message{ @@ -115,129 +125,50 @@ func SSHEndpoint(c echo.Context) (err error) { err := WriteMessage(ws, msg) return err } + tun.NextTerminal = nextTerminal - session, err := sshClient.NewSession() - if err != nil { - logrus.Errorf("创建SSH会话失败:%v", err.Error()) - msg := Message{ - Type: Closed, - Content: err.Error(), - } - err := WriteMessage(ws, msg) + var observers []global.Tun + observable := global.Observable{ + Subject: &tun, + Observers: observers, + } + + global.Store.Set(sessionId, &observable) + + sess := model.Session{ + ConnectionId: sessionId, + Width: cols, + Height: rows, + Status: model.Connecting, + Recording: recording, + } + // 创建新会话 + logrus.Debugf("创建新会话 %v", sess.ConnectionId) + if err := model.UpdateSessionById(&sess, sessionId); err != nil { return err } - defer session.Close() - - modes := ssh.TerminalModes{ - ssh.ECHO: 1, - ssh.TTY_OP_ISPEED: 14400, - ssh.TTY_OP_OSPEED: 14400, - } - - if err := session.RequestPty("xterm", height, width, modes); err != nil { - return err - } - - var b NextWriter - session.Stdout = &b - session.Stderr = &b - - stdinPipe, err := session.StdinPipe() - if err != nil { - return err - } - - if err := session.Shell(); err != nil { - return err - } - - var recorder *Recorder - var recording string - property, _ := model.FindPropertyByName(guacd.RecordingPath) - if property.Value != "" { - dir := path.Join(property.Value, sessionId) - recorder, recording, err = NewRecorder(dir) - if err != nil { - msg := Message{ - Type: Closed, - Content: "创建录屏文件失败 :( " + err.Error(), - } - return WriteMessage(ws, msg) - } - - header := &Header{ - Title: "test", - Version: 2, - Height: height, - Width: width, - Env: Env{Shell: "/bin/bash", Term: "xterm-256color"}, - Timestamp: int(time.Now().Unix()), - } - - if err := recorder.WriteHeader(header); err != nil { - return err - } - - if err := model.UpdateSessionById(&model.Session{Recording: recording}, sessionId); err != nil { - return err - } - } msg := Message{ Type: Connected, - Content: "Connect to server successfully.\r\n", + Content: "", } _ = WriteMessage(ws, msg) - var mut sync.Mutex - var active = true + quitChan := make(chan bool) + recordingChan := make(chan string, 1024) - go func() { - for true { - mut.Lock() - if !active { - logrus.Debugf("会话: %v -> %v 关闭", sshClient.LocalAddr().String(), sshClient.RemoteAddr().String()) - if recorder != nil { - recorder.Close() - } - CloseSessionById(sessionId, Normal, "正常退出") - break - } - mut.Unlock() + go ReadMessage(nextTerminal, recordingChan, quitChan, ws) - p, n, err := b.Read() - if err != nil { - continue - } - if n > 0 { - s := string(p) - if recorder != nil { - // 录屏 - _ = recorder.WriteData(s) - } - msg := Message{ - Type: Data, - Content: s, - } - message, err := json.Marshal(msg) - if err != nil { - logrus.Warnf("生成Json失败 %v", err) - continue - } - WriteByteMessage(ws, message) - } - time.Sleep(time.Duration(100) * time.Millisecond) - } - }() + go Recoding(nextTerminal, recordingChan, quitChan) + go Monitor(sessionId, recordingChan, quitChan) - for true { + for { _, message, err := ws.ReadMessage() if err != nil { // web socket会话关闭后主动关闭ssh会话 - _ = session.Close() - mut.Lock() - active = false - mut.Unlock() + CloseSessionById(sessionId, Normal, "正常退出") + quitChan <- true + quitChan <- true break } @@ -256,12 +187,12 @@ func SSHEndpoint(c echo.Context) (err error) { logrus.Warnf("解析SSH会话窗口大小失败: %v", err) continue } - if err := session.WindowChange(winSize.Rows, winSize.Cols); err != nil { + if err := nextTerminal.WindowChange(winSize.Rows, winSize.Cols); err != nil { logrus.Warnf("更改SSH会话窗口大小失败: %v", err) continue } case Data: - _, err = stdinPipe.Write([]byte(msg.Content)) + _, err = nextTerminal.Write([]byte(msg.Content)) if err != nil { logrus.Debugf("SSH会话写入失败: %v", err) msg := Message{ @@ -276,16 +207,96 @@ func SSHEndpoint(c echo.Context) (err error) { return err } +func ReadMessage(nextTerminal *term.NextTerminal, stdoutChan chan string, quitChan chan bool, ws *websocket.Conn) { + + var quit bool + for { + select { + case quit = <-quitChan: + if quit { + return + } + default: + p, n, err := nextTerminal.Read() + if err != nil { + msg := Message{ + Type: Closed, + Content: err.Error(), + } + _ = WriteMessage(ws, msg) + } + if n > 0 { + s := string(p) + // 发送一份数据到队列中 + stdoutChan <- s + msg := Message{ + Type: Data, + Content: s, + } + _ = WriteMessage(ws, msg) + } + time.Sleep(time.Duration(10) * time.Millisecond) + } + } +} + +func Recoding(nextTerminal *term.NextTerminal, recordingChan chan string, quitChan chan bool) { + var quit bool + var s string + for { + select { + case quit = <-quitChan: + if quit { + fmt.Println("退出录屏") + return + } + case s = <-recordingChan: + _ = nextTerminal.Recorder.WriteData(s) + default: + + } + } +} + +func Monitor(sessionId string, recordingChan chan string, quitChan chan bool) { + var quit bool + var s string + for { + select { + case quit = <-quitChan: + if quit { + fmt.Println("退出监控") + return + } + case s = <-recordingChan: + msg := Message{ + Type: Data, + Content: s, + } + + observable, ok := global.Store.Get(sessionId) + if ok { + for i := 0; i < len(observable.Observers); i++ { + _ = WriteMessage(observable.Observers[i].WebSocket, msg) + } + } + default: + + } + } +} + func WriteMessage(ws *websocket.Conn, msg Message) error { message, err := json.Marshal(msg) if err != nil { - logrus.Warnf("生成Json失败 %v", err) + return err } WriteByteMessage(ws, message) return err } func CreateSshClientBySession(session model.Session) (sshClient *ssh.Client, err error) { + var ( username = session.Username password = session.Password @@ -293,52 +304,7 @@ func CreateSshClientBySession(session model.Session) (sshClient *ssh.Client, err passphrase = session.Passphrase ) - var authMethod ssh.AuthMethod - if username == "-" || username == "" { - username = "root" - } - if password == "-" { - password = "" - } - if privateKey == "-" { - privateKey = "" - } - if passphrase == "-" { - passphrase = "" - } - - if privateKey != "" { - var key ssh.Signer - if len(passphrase) > 0 { - key, err = ssh.ParsePrivateKeyWithPassphrase([]byte(privateKey), []byte(passphrase)) - if err != nil { - return nil, err - } - } else { - key, err = ssh.ParsePrivateKey([]byte(privateKey)) - if err != nil { - return nil, err - } - } - authMethod = ssh.PublicKeys(key) - } else { - authMethod = ssh.Password(password) - } - - config := &ssh.ClientConfig{ - Timeout: 1 * time.Second, - User: username, - Auth: []ssh.AuthMethod{authMethod}, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - } - - addr := fmt.Sprintf("%s:%d", session.IP, session.Port) - - sshClient, err = ssh.Dial("tcp", addr, config) - if err != nil { - return nil, err - } - return sshClient, nil + return term.NewSshClient(session.IP, session.Port, username, password, privateKey, passphrase) } func WriteByteMessage(ws *websocket.Conn, p []byte) { diff --git a/pkg/api/tunnel.go b/pkg/api/tunnel.go index 75b047e..496d8ab 100644 --- a/pkg/api/tunnel.go +++ b/pkg/api/tunnel.go @@ -46,11 +46,11 @@ func TunEndpoint(c echo.Context) error { if len(connectionId) > 0 { session, err = model.FindSessionByConnectionId(connectionId) if err != nil { - CloseWebSocket(ws, NotFoundSession, "会话不存在") + logrus.Warnf("会话不存在") return err } if session.Status != model.Connected { - CloseWebSocket(ws, NotFoundSession, "会话未在线") + logrus.Warnf("会话未在线") return errors.New("会话未在线") } configuration.ConnectionID = connectionId diff --git a/pkg/global/store.go b/pkg/global/store.go index 9e0cf99..a418503 100644 --- a/pkg/global/store.go +++ b/pkg/global/store.go @@ -2,30 +2,37 @@ package global import ( "github.com/gorilla/websocket" - "github.com/pkg/sftp" - "golang.org/x/crypto/ssh" "next-terminal/pkg/guacd" + "next-terminal/pkg/term" + "strconv" "sync" ) type Tun struct { - Protocol string - Tunnel *guacd.Tunnel - SshClient *ssh.Client - SftpClient *sftp.Client - WebSocket *websocket.Conn + Protocol string + WebSocket *websocket.Conn + Tunnel *guacd.Tunnel + NextTerminal *term.NextTerminal } -func (r *Tun) Close() { - if r.Protocol == "rdp" || r.Protocol == "vnc" { +func (r *Tun) Close(code int, reason string) { + if r.Tunnel != nil { _ = r.Tunnel.Close() - } else { - if r.SshClient != nil { - _ = r.SshClient.Close() - } + } + if r.NextTerminal != nil { + _ = r.NextTerminal.Close() + } - if r.SftpClient != nil { - _ = r.SftpClient.Close() + ws := r.WebSocket + if ws != nil { + if r.Protocol == "rdp" || r.Protocol == "vnc" { + err := guacd.NewInstruction("error", reason, strconv.Itoa(code)) + _ = ws.WriteMessage(websocket.TextMessage, []byte(err.String())) + disconnect := guacd.NewInstruction("disconnect") + _ = ws.WriteMessage(websocket.TextMessage, []byte(disconnect.String())) + } else { + msg := `{"type":"closed","content":"` + reason + `"}` + _ = ws.WriteMessage(websocket.TextMessage, []byte(msg)) } } } diff --git a/pkg/log/logger.go b/pkg/log/logger.go new file mode 100644 index 0000000..ed3c485 --- /dev/null +++ b/pkg/log/logger.go @@ -0,0 +1,193 @@ +package log + +import ( + "io" + "strconv" + "time" + + "github.com/labstack/echo/v4" + "github.com/labstack/gommon/log" + "github.com/sirupsen/logrus" +) + +// Logrus : implement Logger +type Logrus struct { + *logrus.Logger +} + +// Logger ... +var Logger = logrus.New() + +// GetEchoLogger for e.Logger +func GetEchoLogger() Logrus { + return Logrus{Logger} +} + +// Level returns logger level +func (l Logrus) Level() log.Lvl { + switch l.Logger.Level { + case logrus.DebugLevel: + return log.DEBUG + case logrus.WarnLevel: + return log.WARN + case logrus.ErrorLevel: + return log.ERROR + case logrus.InfoLevel: + return log.INFO + default: + l.Panic("Invalid level") + } + + return log.OFF +} + +// SetHeader is a stub to satisfy interface +// It's controlled by Logger +func (l Logrus) SetHeader(_ string) {} + +// SetPrefix It's controlled by Logger +func (l Logrus) SetPrefix(s string) {} + +// Prefix It's controlled by Logger +func (l Logrus) Prefix() string { + return "" +} + +// SetLevel set level to logger from given log.Lvl +func (l Logrus) SetLevel(lvl log.Lvl) { + switch lvl { + case log.DEBUG: + Logger.SetLevel(logrus.DebugLevel) + case log.WARN: + Logger.SetLevel(logrus.WarnLevel) + case log.ERROR: + Logger.SetLevel(logrus.ErrorLevel) + case log.INFO: + Logger.SetLevel(logrus.InfoLevel) + default: + l.Panic("Invalid level") + } +} + +// Output logger output func +func (l Logrus) Output() io.Writer { + return l.Out +} + +// SetOutput change output, default os.Stdout +func (l Logrus) SetOutput(w io.Writer) { + Logger.SetOutput(w) +} + +// Printj print json log +func (l Logrus) Printj(j log.JSON) { + Logger.WithFields(logrus.Fields(j)).Print() +} + +// Debugj debug json log +func (l Logrus) Debugj(j log.JSON) { + Logger.WithFields(logrus.Fields(j)).Debug() +} + +// Infoj info json log +func (l Logrus) Infoj(j log.JSON) { + Logger.WithFields(logrus.Fields(j)).Info() +} + +// Warnj warning json log +func (l Logrus) Warnj(j log.JSON) { + Logger.WithFields(logrus.Fields(j)).Warn() +} + +// Errorj error json log +func (l Logrus) Errorj(j log.JSON) { + Logger.WithFields(logrus.Fields(j)).Error() +} + +// Fatalj fatal json log +func (l Logrus) Fatalj(j log.JSON) { + Logger.WithFields(logrus.Fields(j)).Fatal() +} + +// Panicj panic json log +func (l Logrus) Panicj(j log.JSON) { + Logger.WithFields(logrus.Fields(j)).Panic() +} + +// Print string log +func (l Logrus) Print(i ...interface{}) { + Logger.Print(i[0].(string)) +} + +// Debug string log +func (l Logrus) Debug(i ...interface{}) { + Logger.Debug(i[0].(string)) +} + +// Info string log +func (l Logrus) Info(i ...interface{}) { + Logger.Info(i[0].(string)) +} + +// Warn string log +func (l Logrus) Warn(i ...interface{}) { + Logger.Warn(i[0].(string)) +} + +// Error string log +func (l Logrus) Error(i ...interface{}) { + Logger.Error(i[0].(string)) +} + +// Fatal string log +func (l Logrus) Fatal(i ...interface{}) { + Logger.Fatal(i[0].(string)) +} + +// Panic string log +func (l Logrus) Panic(i ...interface{}) { + Logger.Panic(i[0].(string)) +} + +func logrusMiddlewareHandler(c echo.Context, next echo.HandlerFunc) error { + req := c.Request() + res := c.Response() + start := time.Now() + if err := next(c); err != nil { + c.Error(err) + } + stop := time.Now() + + p := req.URL.Path + + bytesIn := req.Header.Get(echo.HeaderContentLength) + + Logger.WithFields(map[string]interface{}{ + "time_rfc3339": time.Now().Format(time.RFC3339), + "remote_ip": c.RealIP(), + "host": req.Host, + "uri": req.RequestURI, + "method": req.Method, + "path": p, + "referer": req.Referer(), + "user_agent": req.UserAgent(), + "status": res.Status, + "latency": strconv.FormatInt(stop.Sub(start).Nanoseconds()/1000, 10), + "latency_human": stop.Sub(start).String(), + "bytes_in": bytesIn, + "bytes_out": strconv.FormatInt(res.Size, 10), + }).Info("Handled request") + + return nil +} + +func logger(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + return logrusMiddlewareHandler(c, next) + } +} + +// Hook is a function to process log. +func Hook() echo.MiddlewareFunc { + return logger +} diff --git a/pkg/term/next_terminal.go b/pkg/term/next_terminal.go new file mode 100644 index 0000000..55c5241 --- /dev/null +++ b/pkg/term/next_terminal.go @@ -0,0 +1,99 @@ +package term + +import ( + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" + "io" +) + +type NextTerminal struct { + SshClient *ssh.Client + SshSession *ssh.Session + StdinPipe io.WriteCloser + SftpClient *sftp.Client + Recorder *Recorder + NextWriter *NextWriter +} + +func NewNextTerminal(ip string, port int, username, password, privateKey, passphrase string, rows, cols int, recording string) (*NextTerminal, error) { + + sshClient, err := NewSshClient(ip, port, username, password, privateKey, passphrase) + if err != nil { + return nil, err + } + + sshSession, err := sshClient.NewSession() + if err != nil { + return nil, err + } + //defer sshSession.Close() + + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + } + + if err := sshSession.RequestPty("xterm-256color", rows, cols, modes); err != nil { + return nil, err + } + + var nextWriter NextWriter + sshSession.Stdout = &nextWriter + sshSession.Stderr = &nextWriter + + stdinPipe, err := sshSession.StdinPipe() + if err != nil { + return nil, err + } + + if err := sshSession.Shell(); err != nil { + return nil, err + } + + var recorder *Recorder + if recording != "" { + recorder, err = CreateRecording(recording, rows, cols) + if err != nil { + return nil, err + } + } + + terminal := NextTerminal{ + SshClient: sshClient, + SshSession: sshSession, + Recorder: recorder, + StdinPipe: stdinPipe, + NextWriter: &nextWriter, + } + + return &terminal, nil +} + +func (ret *NextTerminal) Write(p []byte) (int, error) { + return ret.StdinPipe.Write(p) +} + +func (ret *NextTerminal) Read() ([]byte, int, error) { + return ret.NextWriter.Read() +} + +func (ret *NextTerminal) Close() error { + if ret.SshSession != nil { + return ret.SshSession.Close() + } + + if ret.SshClient != nil { + return ret.SshClient.Close() + } + + if ret.Recorder != nil { + return ret.Close() + } + + return nil +} + +func (ret *NextTerminal) WindowChange(h int, w int) error { + return ret.SshSession.WindowChange(h, w) +} diff --git a/pkg/term/next_writer.go b/pkg/term/next_writer.go new file mode 100644 index 0000000..cb472ad --- /dev/null +++ b/pkg/term/next_writer.go @@ -0,0 +1,30 @@ +package term + +import ( + "bytes" + "sync" +) + +type NextWriter struct { + b bytes.Buffer + mu sync.Mutex +} + +func (w *NextWriter) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + return w.b.Write(p) +} + +func (w *NextWriter) Read() ([]byte, int, error) { + w.mu.Lock() + defer w.mu.Unlock() + p := w.b.Bytes() + buf := make([]byte, len(p)) + read, err := w.b.Read(buf) + w.b.Reset() + if err != nil { + return nil, 0, err + } + return buf, read, err +} diff --git a/pkg/term/recording.go b/pkg/term/recording.go new file mode 100644 index 0000000..4eda1b2 --- /dev/null +++ b/pkg/term/recording.go @@ -0,0 +1,122 @@ +package term + +import ( + "encoding/json" + "next-terminal/pkg/utils" + "os" + "time" +) + +type Env struct { + Shell string `json:"SHELL"` + Term string `json:"TERM"` +} + +type Header struct { + Title string `json:"title"` + Version int `json:"version"` + Height int `json:"height"` + Width int `json:"width"` + Env Env `json:"env"` + Timestamp int `json:"Timestamp"` +} + +type Recorder struct { + File *os.File + Timestamp int +} + +func NewRecorder(recoding string) (recorder *Recorder, err error) { + recorder = &Recorder{} + + parentDirectory := utils.GetParentDirectory(recoding) + + if utils.FileExists(parentDirectory) { + if err := os.RemoveAll(parentDirectory); err != nil { + return nil, err + } + } + + if err = os.MkdirAll(parentDirectory, 0777); err != nil { + return + } + + var file *os.File + file, err = os.Create(recoding) + if err != nil { + return nil, err + } + + recorder.File = file + return recorder, nil +} + +func (recorder *Recorder) Close() { + if recorder.File != nil { + recorder.File.Close() + } +} + +func (recorder *Recorder) WriteHeader(header *Header) (err error) { + var p []byte + + if p, err = json.Marshal(header); err != nil { + return + } + + if _, err := recorder.File.Write(p); err != nil { + return err + } + if _, err := recorder.File.Write([]byte("\n")); err != nil { + return err + } + + recorder.Timestamp = header.Timestamp + + return +} + +func (recorder *Recorder) WriteData(data string) (err error) { + now := int(time.Now().UnixNano()) + + delta := float64(now-recorder.Timestamp*1000*1000*1000) / 1000 / 1000 / 1000 + + row := make([]interface{}, 0) + row = append(row, delta) + row = append(row, "o") + row = append(row, data) + + var s []byte + if s, err = json.Marshal(row); err != nil { + return + } + if _, err := recorder.File.Write(s); err != nil { + return err + } + if _, err := recorder.File.Write([]byte("\n")); err != nil { + return err + } + return +} + +func CreateRecording(recordingPath string, h int, w int) (*Recorder, error) { + recorder, err := NewRecorder(recordingPath) + if err != nil { + return nil, err + } + + header := &Header{ + Title: "", + Version: 2, + Height: h, + Width: w, + Env: Env{Shell: "/bin/bash", Term: "xterm-256color"}, + Timestamp: int(time.Now().Unix()), + } + + if err := recorder.WriteHeader(header); err != nil { + return nil, err + } + + return recorder, nil +} diff --git a/pkg/term/ssh.go b/pkg/term/ssh.go new file mode 100644 index 0000000..f27b6d2 --- /dev/null +++ b/pkg/term/ssh.go @@ -0,0 +1,53 @@ +package term + +import ( + "fmt" + "golang.org/x/crypto/ssh" + "time" +) + +func NewSshClient(ip string, port int, username, password, privateKey, passphrase string) (*ssh.Client, error) { + var authMethod ssh.AuthMethod + if username == "-" || username == "" { + username = "root" + } + if password == "-" { + password = "" + } + if privateKey == "-" { + privateKey = "" + } + if passphrase == "-" { + passphrase = "" + } + + var err error + if privateKey != "" { + var key ssh.Signer + if len(passphrase) > 0 { + key, err = ssh.ParsePrivateKeyWithPassphrase([]byte(privateKey), []byte(passphrase)) + if err != nil { + return nil, err + } + } else { + key, err = ssh.ParsePrivateKey([]byte(privateKey)) + if err != nil { + return nil, err + } + } + authMethod = ssh.PublicKeys(key) + } else { + authMethod = ssh.Password(password) + } + + config := &ssh.ClientConfig{ + Timeout: 1 * time.Second, + User: username, + Auth: []ssh.AuthMethod{authMethod}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + addr := fmt.Sprintf("%s:%d", ip, port) + + return ssh.Dial("tcp", addr, config) +} diff --git a/pkg/term/test/test_ssh.go b/pkg/term/test/test_ssh.go new file mode 100644 index 0000000..fabefdd --- /dev/null +++ b/pkg/term/test/test_ssh.go @@ -0,0 +1,174 @@ +package main + +import ( + "fmt" + "io" + "os" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/terminal" +) + +type SSHTerminal struct { + Session *ssh.Session + exitMsg string + stdout io.Reader + stdin io.Writer + stderr io.Reader +} + +func main() { + sshConfig := &ssh.ClientConfig{ + User: "root", + Auth: []ssh.AuthMethod{ + ssh.Password("root"), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + client, err := ssh.Dial("tcp", "172.16.101.32:22", sshConfig) + if err != nil { + fmt.Println(err) + } + defer client.Close() + + err = New(client) + if err != nil { + fmt.Println(err) + } +} + +func (t *SSHTerminal) updateTerminalSize() { + + go func() { + // SIGWINCH is sent to the process when the window size of the terminal has + // changed. + sigwinchCh := make(chan os.Signal, 1) + //signal.Notify(sigwinchCh, syscall.SIN) + + fd := int(os.Stdin.Fd()) + termWidth, termHeight, err := terminal.GetSize(fd) + if err != nil { + fmt.Println(err) + } + + for { + select { + // The client updated the size of the local PTY. This change needs to occur + // on the server side PTY as well. + case sigwinch := <-sigwinchCh: + if sigwinch == nil { + return + } + currTermWidth, currTermHeight, err := terminal.GetSize(fd) + + // Terminal size has not changed, don't do anything. + if currTermHeight == termHeight && currTermWidth == termWidth { + continue + } + + t.Session.WindowChange(currTermHeight, currTermWidth) + if err != nil { + fmt.Printf("Unable to send window-change reqest: %s.", err) + continue + } + + termWidth, termHeight = currTermWidth, currTermHeight + + } + } + }() + +} + +func (t *SSHTerminal) interactiveSession() error { + + defer func() { + if t.exitMsg == "" { + fmt.Fprintln(os.Stdout, "the connection was closed on the remote side on ", time.Now().Format(time.RFC822)) + } else { + fmt.Fprintln(os.Stdout, t.exitMsg) + } + }() + + fd := int(os.Stdin.Fd()) + state, err := terminal.MakeRaw(fd) + if err != nil { + return err + } + defer terminal.Restore(fd, state) + + termWidth, termHeight, err := terminal.GetSize(fd) + if err != nil { + return err + } + + termType := os.Getenv("TERM") + if termType == "" { + termType = "xterm-256color" + } + + err = t.Session.RequestPty(termType, termHeight, termWidth, ssh.TerminalModes{}) + if err != nil { + return err + } + + t.updateTerminalSize() + + t.stdin, err = t.Session.StdinPipe() + if err != nil { + return err + } + t.stdout, err = t.Session.StdoutPipe() + if err != nil { + return err + } + t.stderr, err = t.Session.StderrPipe() + + go io.Copy(os.Stderr, t.stderr) + go io.Copy(os.Stdout, t.stdout) + go func() { + buf := make([]byte, 128) + for { + n, err := os.Stdin.Read(buf) + if err != nil { + fmt.Println(err) + return + } + if n > 0 { + _, err = t.stdin.Write(buf[:n]) + if err != nil { + fmt.Println(err) + t.exitMsg = err.Error() + return + } + } + } + }() + + err = t.Session.Shell() + if err != nil { + return err + } + err = t.Session.Wait() + if err != nil { + return err + } + return nil +} + +func New(client *ssh.Client) error { + + session, err := client.NewSession() + if err != nil { + return err + } + defer session.Close() + + s := SSHTerminal{ + Session: session, + } + + return s.interactiveSession() +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 8d0b22f..930b842 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -10,6 +10,7 @@ import ( "image/png" "net" "os" + "path/filepath" "sort" "strconv" "strings" @@ -123,6 +124,10 @@ func IsFile(path string) bool { return !IsDir(path) } +func GetParentDirectory(directory string) string { + return filepath.Dir(directory) +} + // 去除重复元素 func Distinct(a []string) []string { result := make([]string, 0, len(a)) diff --git a/web/package.json b/web/package.json index 7961bd2..d771c9f 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "next-terminal", - "version": "0.1.1", + "version": "0.2.0", "private": true, "dependencies": { "@ant-design/icons": "^4.3.0", diff --git a/web/src/App.js b/web/src/App.js index aa63f55..9be13bd 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -40,7 +40,7 @@ import {isEmpty, NT_PACKAGE} from "./utils/utils"; import {isAdmin} from "./service/permission"; import UserGroup from "./components/user/UserGroup"; import LoginLog from "./components/session/LoginLog"; -import AccessNaive from "./components/access/AccessNaive"; +import Term from "./components/access/Term"; const {Footer, Sider} = Layout; @@ -114,7 +114,7 @@ class App extends Component { - + diff --git a/web/src/components/access/Access.js b/web/src/components/access/Access.js index 9a39a08..907cdb6 100644 --- a/web/src/components/access/Access.js +++ b/web/src/components/access/Access.js @@ -240,7 +240,7 @@ class Access extends Component { this.showMessage('创建隧道失败'); break; case 802: - this.showMessage('管理员强制断开了此会话'); + this.showMessage('管理员强制关闭了此会话'); break; default: this.showMessage('未知错误。'); diff --git a/web/src/components/access/AccessNaive.css b/web/src/components/access/Term.css similarity index 100% rename from web/src/components/access/AccessNaive.css rename to web/src/components/access/Term.css diff --git a/web/src/components/access/AccessNaive.js b/web/src/components/access/Term.js similarity index 90% rename from web/src/components/access/AccessNaive.js rename to web/src/components/access/Term.js index 9307d74..20d9998 100644 --- a/web/src/components/access/AccessNaive.js +++ b/web/src/components/access/Term.js @@ -9,7 +9,7 @@ import "./Access.css" import request from "../../common/request"; import {message} from "antd"; -class AccessNaive extends Component { +class Term extends Component { state = { width: window.innerWidth, @@ -30,17 +30,9 @@ class AccessNaive extends Component { return; } - let params = { - 'width': this.state.width, - 'height': this.state.height, - 'sessionId': sessionId - }; - - let paramStr = qs.stringify(params); - let term = new Terminal({ fontFamily: 'monaco, Consolas, "Lucida Console", monospace', - fontSize: 14, + fontSize: 15, // theme: { // background: '#1b1b1b', // lineHeight: 17 @@ -81,20 +73,22 @@ class AccessNaive extends Component { }); let token = getToken(); + let params = { + 'cols': term.cols, + 'rows': term.rows, + 'sessionId': sessionId, + 'X-Auth-Token': token + }; - let webSocket = new WebSocket(wsServer + '/ssh?X-Auth-Token=' + token + '&' + paramStr); + let paramStr = qs.stringify(params); + + let webSocket = new WebSocket(wsServer + '/ssh?' + paramStr); let pingInterval; webSocket.onopen = (e => { pingInterval = setInterval(() => { webSocket.send(JSON.stringify({type: 'ping'})) }, 5000); - - let terminalSize = { - cols: term.cols, - rows: term.rows - } - webSocket.send(JSON.stringify({type: 'resize', content: JSON.stringify(terminalSize)})); }); webSocket.onerror = (e) => { @@ -111,8 +105,8 @@ class AccessNaive extends Component { let msg = JSON.parse(e.data); switch (msg['type']) { case 'connected': + // term.write(msg['content']) term.clear(); - this.onWindowResize(); this.updateSessionStatus(sessionId); break; case 'data': @@ -202,4 +196,4 @@ class AccessNaive extends Component { } } -export default AccessNaive; +export default Term; diff --git a/web/src/components/asset/Asset.js b/web/src/components/asset/Asset.js index 85eb187..cf6f16a 100644 --- a/web/src/components/asset/Asset.js +++ b/web/src/components/asset/Asset.js @@ -323,12 +323,12 @@ class Asset extends Component { if (result.code === 1) { if (result.data === true) { message.success({content: '检测完成,您访问的资产在线,即将打开窗口进行访问。', key: id, duration: 3}); - window.open(`#/access?assetId=${id}&assetName=${name}&protocol=${protocol}`); - // if (protocol === 'ssh') { - // window.open(`#/access-naive?assetId=${id}&assetName=${name}`); - // } else { - // window.open(`#/access?assetId=${id}&assetName=${name}&protocol=${protocol}`); - // } + // window.open(`#/access?assetId=${id}&assetName=${name}&protocol=${protocol}`); + if (protocol === 'ssh') { + window.open(`#/term?assetId=${id}&assetName=${name}`); + } else { + window.open(`#/access?assetId=${id}&assetName=${name}&protocol=${protocol}`); + } } else { message.warn('您访问的资产未在线,请确认网络状态。', 10); }