优化原生ssh连接代码

This commit is contained in:
dushixiang 2021-02-06 00:25:48 +08:00 committed by dushixiang
parent 248815538d
commit d72ab4e21e
20 changed files with 896 additions and 369 deletions

View File

@ -10,7 +10,6 @@ import (
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"io"
"next-terminal/pkg/api"
"next-terminal/pkg/config"
@ -23,7 +22,7 @@ import (
"time"
)
const Version = "v0.1.1"
const Version = "v0.2.0"
func main() {
log.Fatal(Run())
@ -71,7 +70,7 @@ func Run() error {
global.Config.Mysql.Database,
)
global.DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
//Logger: logger.Default.LogMode(logger.Info),
})
} else {
global.DB, err = gorm.Open(sqlite.Open(global.Config.Sqlite.File), &gorm.Config{})

View File

@ -1,101 +0,0 @@
package api
import (
"encoding/json"
"next-terminal/pkg/utils"
"os"
"path"
"time"
)
type Env struct {
Shell string `json:"SHELL"`
Term string `json:"TERM"`
}
type Header struct {
Title string `json:"title"`
Version int `json:"version"`
Height int `json:"height"`
Width int `json:"width"`
Env Env `json:"env"`
Timestamp int `json:"timestamp"`
}
type Recorder struct {
file *os.File
timestamp int
}
func NewRecorder(dir string) (recorder *Recorder, filename string, err error) {
recorder = &Recorder{}
if utils.FileExists(dir) {
if err := os.RemoveAll(dir); err != nil {
return nil, "", err
}
}
if err = os.MkdirAll(dir, 0777); err != nil {
return
}
filename = path.Join(dir, "recording.cast")
var file *os.File
file, err = os.Create(filename)
if err != nil {
return nil, "", err
}
recorder.file = file
return recorder, filename, nil
}
func (recorder *Recorder) Close() {
if recorder.file != nil {
recorder.file.Close()
}
}
func (recorder *Recorder) WriteHeader(header *Header) (err error) {
var p []byte
if p, err = json.Marshal(header); err != nil {
return
}
if _, err := recorder.file.Write(p); err != nil {
return err
}
if _, err := recorder.file.Write([]byte("\n")); err != nil {
return err
}
recorder.timestamp = header.Timestamp
return
}
func (recorder *Recorder) WriteData(data string) (err error) {
now := int(time.Now().UnixNano())
delta := float64(now-recorder.timestamp*1000*1000*1000) / 1000 / 1000 / 1000
row := make([]interface{}, 0)
row = append(row, delta)
row = append(row, "o")
row = append(row, data)
var s []byte
if s, err = json.Marshal(row); err != nil {
return
}
if _, err := recorder.file.Write(s); err != nil {
return err
}
if _, err := recorder.file.Write([]byte("\n")); err != nil {
return err
}
return
}

View File

@ -3,6 +3,7 @@ package api
import (
"net/http"
"next-terminal/pkg/global"
"next-terminal/pkg/log"
"next-terminal/pkg/model"
"github.com/labstack/echo/v4"
@ -15,6 +16,7 @@ func SetupRoutes() *echo.Echo {
e := echo.New()
e.HideBanner = true
e.Logger = log.GetEchoLogger()
e.File("/", "web/build/index.html")
e.File("/logo.svg", "web/build/logo.svg")

View File

@ -4,14 +4,12 @@ import (
"bytes"
"errors"
"fmt"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/sirupsen/logrus"
"io"
"io/ioutil"
"net/http"
"next-terminal/pkg/global"
"next-terminal/pkg/guacd"
"next-terminal/pkg/model"
"next-terminal/pkg/utils"
"os"
@ -100,7 +98,7 @@ func SessionDisconnectEndpoint(c echo.Context) error {
split := strings.Split(sessionIds, ",")
for i := range split {
CloseSessionById(split[i], ForcedDisconnect, "forced disconnect")
CloseSessionById(split[i], ForcedDisconnect, "管理员强制关闭了此会话")
}
return Success(c, nil)
}
@ -112,16 +110,13 @@ func CloseSessionById(sessionId string, code int, reason string) {
defer mutex.Unlock()
observable, _ := global.Store.Get(sessionId)
if observable != nil {
logrus.Debugf("会话%v创建者退出", observable.Subject.Tunnel.UUID)
observable.Subject.Close()
logrus.Debugf("会话%v创建者退出", sessionId)
observable.Subject.Close(code, reason)
for i := 0; i < len(observable.Observers); i++ {
observable.Observers[i].Close()
CloseWebSocket(observable.Observers[i].WebSocket, code, reason)
logrus.Debugf("强制踢出会话%v的观察者", observable.Observers[i].Tunnel.UUID)
observable.Observers[i].Close(code, reason)
logrus.Debugf("强制踢出会话%v的观察者", sessionId)
}
CloseWebSocket(observable.Subject.WebSocket, code, reason)
}
global.Store.Del(sessionId)
@ -150,17 +145,6 @@ func CloseSessionById(sessionId string, code int, reason string) {
_ = model.UpdateSessionById(&session, sessionId)
}
func CloseWebSocket(ws *websocket.Conn, c int, t string) {
if ws == nil {
return
}
err := guacd.NewInstruction("error", "", strconv.Itoa(c))
_ = ws.WriteMessage(websocket.TextMessage, []byte(err.String()))
disconnect := guacd.NewInstruction("disconnect")
_ = ws.WriteMessage(websocket.TextMessage, []byte(disconnect.String()))
//defer ws.Close()
}
func SessionResizeEndpoint(c echo.Context) error {
width := c.QueryParam("width")
height := c.QueryParam("height")
@ -274,11 +258,11 @@ func SessionUploadEndpoint(c echo.Context) error {
return errors.New("获取sftp客户端失败")
}
dstFile, err := tun.Subject.SftpClient.Create(remoteFile)
defer dstFile.Close()
dstFile, err := tun.Subject.NextTerminal.SftpClient.Create(remoteFile)
if err != nil {
return err
}
defer dstFile.Close()
buf := make([]byte, 1024)
for {
@ -327,7 +311,7 @@ func SessionDownloadEndpoint(c echo.Context) error {
return errors.New("获取sftp客户端失败")
}
dstFile, err := tun.Subject.SftpClient.Open(remoteFile)
dstFile, err := tun.Subject.NextTerminal.SftpClient.Open(remoteFile)
if err != nil {
return err
}
@ -378,16 +362,16 @@ func SessionLsEndpoint(c echo.Context) error {
return errors.New("获取sftp客户端失败")
}
if tun.Subject.SftpClient == nil {
if tun.Subject.NextTerminal.SftpClient == nil {
sftpClient, err := CreateSftpClient(session)
if err != nil {
logrus.Errorf("创建sftp客户端失败%v", err.Error())
return err
}
tun.Subject.SftpClient = sftpClient
tun.Subject.NextTerminal.SftpClient = sftpClient
}
fileInfos, err := tun.Subject.SftpClient.ReadDir(remoteDir)
fileInfos, err := tun.Subject.NextTerminal.SftpClient.ReadDir(remoteDir)
if err != nil {
return err
}
@ -457,7 +441,7 @@ func SessionMkDirEndpoint(c echo.Context) error {
if !ok {
return errors.New("获取sftp客户端失败")
}
if err := tun.Subject.SftpClient.Mkdir(remoteDir); err != nil {
if err := tun.Subject.NextTerminal.SftpClient.Mkdir(remoteDir); err != nil {
return err
}
return Success(c, nil)
@ -489,7 +473,7 @@ func SessionRmEndpoint(c echo.Context) error {
return errors.New("获取sftp客户端失败")
}
sftpClient := tun.Subject.SftpClient
sftpClient := tun.Subject.NextTerminal.SftpClient
stat, err := sftpClient.Stat(key)
if err != nil {
@ -548,7 +532,7 @@ func SessionRenameEndpoint(c echo.Context) error {
return errors.New("获取sftp客户端失败")
}
sftpClient := tun.Subject.SftpClient
sftpClient := tun.Subject.NextTerminal.SftpClient
if err := sftpClient.Rename(oldName, newName); err != nil {
return err

View File

@ -1,7 +1,6 @@
package api
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
@ -10,12 +9,13 @@ import (
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"net/http"
"next-terminal/pkg/global"
"next-terminal/pkg/guacd"
"next-terminal/pkg/model"
"next-terminal/pkg/term"
"next-terminal/pkg/utils"
"path"
"strconv"
"sync"
"time"
)
@ -26,30 +26,6 @@ var UpGrader = websocket.Upgrader{
Subprotocols: []string{"guacamole"},
}
type NextWriter struct {
b bytes.Buffer
mu sync.Mutex
}
func (w *NextWriter) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
return w.b.Write(p)
}
func (w *NextWriter) Read() ([]byte, int, error) {
w.mu.Lock()
defer w.mu.Unlock()
p := w.b.Bytes()
buf := make([]byte, len(p))
read, err := w.b.Read(buf)
w.b.Reset()
if err != nil {
return nil, 0, err
}
return buf, read, err
}
const (
Connected = "connected"
Data = "data"
@ -75,14 +51,14 @@ func SSHEndpoint(c echo.Context) (err error) {
}
sessionId := c.QueryParam("sessionId")
width, _ := strconv.Atoi(c.QueryParam("width"))
height, _ := strconv.Atoi(c.QueryParam("height"))
cols, _ := strconv.Atoi(c.QueryParam("cols"))
rows, _ := strconv.Atoi(c.QueryParam("rows"))
aSession, err := model.FindSessionById(sessionId)
session, err := model.FindSessionById(sessionId)
if err != nil {
msg := Message{
Type: Closed,
Content: "get session error." + err.Error(),
Content: "get sshSession error." + err.Error(),
}
_ = WriteMessage(ws, msg)
return err
@ -96,7 +72,7 @@ func SSHEndpoint(c echo.Context) (err error) {
return err
}
if !utils.Contains(assetIds, aSession.AssetId) {
if !utils.Contains(assetIds, session.AssetId) {
msg := Message{
Type: Closed,
Content: "您没有权限访问此资产",
@ -105,7 +81,41 @@ func SSHEndpoint(c echo.Context) (err error) {
}
}
sshClient, err := CreateSshClientBySession(aSession)
var (
username = session.Username
password = session.Password
privateKey = session.PrivateKey
passphrase = session.Passphrase
ip = session.IP
port = session.Port
)
recording := ""
propertyMap := model.FindAllPropertiesMap()
if propertyMap[guacd.EnableRecording] == "true" {
recording = path.Join(propertyMap[guacd.RecordingPath], sessionId, "recording.cast")
}
tun := global.Tun{
Protocol: session.Protocol,
WebSocket: ws,
}
if session.ConnectionId != "" {
// 监控会话
observable, ok := global.Store.Get(sessionId)
if ok {
observers := append(observable.Observers, tun)
observable.Observers = observers
global.Store.Set(sessionId, observable)
logrus.Debugf("加入会话%v,当前观察者数量为:%v", session.ConnectionId, len(observers))
}
return err
}
nextTerminal, err := term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, rows, cols, recording)
if err != nil {
logrus.Errorf("创建SSH客户端失败%v", err.Error())
msg := Message{
@ -115,129 +125,50 @@ func SSHEndpoint(c echo.Context) (err error) {
err := WriteMessage(ws, msg)
return err
}
tun.NextTerminal = nextTerminal
session, err := sshClient.NewSession()
if err != nil {
logrus.Errorf("创建SSH会话失败%v", err.Error())
msg := Message{
Type: Closed,
Content: err.Error(),
}
err := WriteMessage(ws, msg)
var observers []global.Tun
observable := global.Observable{
Subject: &tun,
Observers: observers,
}
global.Store.Set(sessionId, &observable)
sess := model.Session{
ConnectionId: sessionId,
Width: cols,
Height: rows,
Status: model.Connecting,
Recording: recording,
}
// 创建新会话
logrus.Debugf("创建新会话 %v", sess.ConnectionId)
if err := model.UpdateSessionById(&sess, sessionId); err != nil {
return err
}
defer session.Close()
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
}
if err := session.RequestPty("xterm", height, width, modes); err != nil {
return err
}
var b NextWriter
session.Stdout = &b
session.Stderr = &b
stdinPipe, err := session.StdinPipe()
if err != nil {
return err
}
if err := session.Shell(); err != nil {
return err
}
var recorder *Recorder
var recording string
property, _ := model.FindPropertyByName(guacd.RecordingPath)
if property.Value != "" {
dir := path.Join(property.Value, sessionId)
recorder, recording, err = NewRecorder(dir)
if err != nil {
msg := Message{
Type: Closed,
Content: "创建录屏文件失败 :( " + err.Error(),
}
return WriteMessage(ws, msg)
}
header := &Header{
Title: "test",
Version: 2,
Height: height,
Width: width,
Env: Env{Shell: "/bin/bash", Term: "xterm-256color"},
Timestamp: int(time.Now().Unix()),
}
if err := recorder.WriteHeader(header); err != nil {
return err
}
if err := model.UpdateSessionById(&model.Session{Recording: recording}, sessionId); err != nil {
return err
}
}
msg := Message{
Type: Connected,
Content: "Connect to server successfully.\r\n",
Content: "",
}
_ = WriteMessage(ws, msg)
var mut sync.Mutex
var active = true
quitChan := make(chan bool)
recordingChan := make(chan string, 1024)
go func() {
for true {
mut.Lock()
if !active {
logrus.Debugf("会话: %v -> %v 关闭", sshClient.LocalAddr().String(), sshClient.RemoteAddr().String())
if recorder != nil {
recorder.Close()
}
CloseSessionById(sessionId, Normal, "正常退出")
break
}
mut.Unlock()
go ReadMessage(nextTerminal, recordingChan, quitChan, ws)
p, n, err := b.Read()
if err != nil {
continue
}
if n > 0 {
s := string(p)
if recorder != nil {
// 录屏
_ = recorder.WriteData(s)
}
msg := Message{
Type: Data,
Content: s,
}
message, err := json.Marshal(msg)
if err != nil {
logrus.Warnf("生成Json失败 %v", err)
continue
}
WriteByteMessage(ws, message)
}
time.Sleep(time.Duration(100) * time.Millisecond)
}
}()
go Recoding(nextTerminal, recordingChan, quitChan)
go Monitor(sessionId, recordingChan, quitChan)
for true {
for {
_, message, err := ws.ReadMessage()
if err != nil {
// web socket会话关闭后主动关闭ssh会话
_ = session.Close()
mut.Lock()
active = false
mut.Unlock()
CloseSessionById(sessionId, Normal, "正常退出")
quitChan <- true
quitChan <- true
break
}
@ -256,12 +187,12 @@ func SSHEndpoint(c echo.Context) (err error) {
logrus.Warnf("解析SSH会话窗口大小失败: %v", err)
continue
}
if err := session.WindowChange(winSize.Rows, winSize.Cols); err != nil {
if err := nextTerminal.WindowChange(winSize.Rows, winSize.Cols); err != nil {
logrus.Warnf("更改SSH会话窗口大小失败: %v", err)
continue
}
case Data:
_, err = stdinPipe.Write([]byte(msg.Content))
_, err = nextTerminal.Write([]byte(msg.Content))
if err != nil {
logrus.Debugf("SSH会话写入失败: %v", err)
msg := Message{
@ -276,16 +207,96 @@ func SSHEndpoint(c echo.Context) (err error) {
return err
}
func ReadMessage(nextTerminal *term.NextTerminal, stdoutChan chan string, quitChan chan bool, ws *websocket.Conn) {
var quit bool
for {
select {
case quit = <-quitChan:
if quit {
return
}
default:
p, n, err := nextTerminal.Read()
if err != nil {
msg := Message{
Type: Closed,
Content: err.Error(),
}
_ = WriteMessage(ws, msg)
}
if n > 0 {
s := string(p)
// 发送一份数据到队列中
stdoutChan <- s
msg := Message{
Type: Data,
Content: s,
}
_ = WriteMessage(ws, msg)
}
time.Sleep(time.Duration(10) * time.Millisecond)
}
}
}
func Recoding(nextTerminal *term.NextTerminal, recordingChan chan string, quitChan chan bool) {
var quit bool
var s string
for {
select {
case quit = <-quitChan:
if quit {
fmt.Println("退出录屏")
return
}
case s = <-recordingChan:
_ = nextTerminal.Recorder.WriteData(s)
default:
}
}
}
func Monitor(sessionId string, recordingChan chan string, quitChan chan bool) {
var quit bool
var s string
for {
select {
case quit = <-quitChan:
if quit {
fmt.Println("退出监控")
return
}
case s = <-recordingChan:
msg := Message{
Type: Data,
Content: s,
}
observable, ok := global.Store.Get(sessionId)
if ok {
for i := 0; i < len(observable.Observers); i++ {
_ = WriteMessage(observable.Observers[i].WebSocket, msg)
}
}
default:
}
}
}
func WriteMessage(ws *websocket.Conn, msg Message) error {
message, err := json.Marshal(msg)
if err != nil {
logrus.Warnf("生成Json失败 %v", err)
return err
}
WriteByteMessage(ws, message)
return err
}
func CreateSshClientBySession(session model.Session) (sshClient *ssh.Client, err error) {
var (
username = session.Username
password = session.Password
@ -293,52 +304,7 @@ func CreateSshClientBySession(session model.Session) (sshClient *ssh.Client, err
passphrase = session.Passphrase
)
var authMethod ssh.AuthMethod
if username == "-" || username == "" {
username = "root"
}
if password == "-" {
password = ""
}
if privateKey == "-" {
privateKey = ""
}
if passphrase == "-" {
passphrase = ""
}
if privateKey != "" {
var key ssh.Signer
if len(passphrase) > 0 {
key, err = ssh.ParsePrivateKeyWithPassphrase([]byte(privateKey), []byte(passphrase))
if err != nil {
return nil, err
}
} else {
key, err = ssh.ParsePrivateKey([]byte(privateKey))
if err != nil {
return nil, err
}
}
authMethod = ssh.PublicKeys(key)
} else {
authMethod = ssh.Password(password)
}
config := &ssh.ClientConfig{
Timeout: 1 * time.Second,
User: username,
Auth: []ssh.AuthMethod{authMethod},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
addr := fmt.Sprintf("%s:%d", session.IP, session.Port)
sshClient, err = ssh.Dial("tcp", addr, config)
if err != nil {
return nil, err
}
return sshClient, nil
return term.NewSshClient(session.IP, session.Port, username, password, privateKey, passphrase)
}
func WriteByteMessage(ws *websocket.Conn, p []byte) {

View File

@ -46,11 +46,11 @@ func TunEndpoint(c echo.Context) error {
if len(connectionId) > 0 {
session, err = model.FindSessionByConnectionId(connectionId)
if err != nil {
CloseWebSocket(ws, NotFoundSession, "会话不存在")
logrus.Warnf("会话不存在")
return err
}
if session.Status != model.Connected {
CloseWebSocket(ws, NotFoundSession, "会话未在线")
logrus.Warnf("会话未在线")
return errors.New("会话未在线")
}
configuration.ConnectionID = connectionId

View File

@ -2,30 +2,37 @@ package global
import (
"github.com/gorilla/websocket"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"next-terminal/pkg/guacd"
"next-terminal/pkg/term"
"strconv"
"sync"
)
type Tun struct {
Protocol string
Tunnel *guacd.Tunnel
SshClient *ssh.Client
SftpClient *sftp.Client
WebSocket *websocket.Conn
Protocol string
WebSocket *websocket.Conn
Tunnel *guacd.Tunnel
NextTerminal *term.NextTerminal
}
func (r *Tun) Close() {
if r.Protocol == "rdp" || r.Protocol == "vnc" {
func (r *Tun) Close(code int, reason string) {
if r.Tunnel != nil {
_ = r.Tunnel.Close()
} else {
if r.SshClient != nil {
_ = r.SshClient.Close()
}
}
if r.NextTerminal != nil {
_ = r.NextTerminal.Close()
}
if r.SftpClient != nil {
_ = r.SftpClient.Close()
ws := r.WebSocket
if ws != nil {
if r.Protocol == "rdp" || r.Protocol == "vnc" {
err := guacd.NewInstruction("error", reason, strconv.Itoa(code))
_ = ws.WriteMessage(websocket.TextMessage, []byte(err.String()))
disconnect := guacd.NewInstruction("disconnect")
_ = ws.WriteMessage(websocket.TextMessage, []byte(disconnect.String()))
} else {
msg := `{"type":"closed","content":"` + reason + `"}`
_ = ws.WriteMessage(websocket.TextMessage, []byte(msg))
}
}
}

193
pkg/log/logger.go Normal file
View File

@ -0,0 +1,193 @@
package log
import (
"io"
"strconv"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
"github.com/sirupsen/logrus"
)
// Logrus : implement Logger
type Logrus struct {
*logrus.Logger
}
// Logger ...
var Logger = logrus.New()
// GetEchoLogger for e.Logger
func GetEchoLogger() Logrus {
return Logrus{Logger}
}
// Level returns logger level
func (l Logrus) Level() log.Lvl {
switch l.Logger.Level {
case logrus.DebugLevel:
return log.DEBUG
case logrus.WarnLevel:
return log.WARN
case logrus.ErrorLevel:
return log.ERROR
case logrus.InfoLevel:
return log.INFO
default:
l.Panic("Invalid level")
}
return log.OFF
}
// SetHeader is a stub to satisfy interface
// It's controlled by Logger
func (l Logrus) SetHeader(_ string) {}
// SetPrefix It's controlled by Logger
func (l Logrus) SetPrefix(s string) {}
// Prefix It's controlled by Logger
func (l Logrus) Prefix() string {
return ""
}
// SetLevel set level to logger from given log.Lvl
func (l Logrus) SetLevel(lvl log.Lvl) {
switch lvl {
case log.DEBUG:
Logger.SetLevel(logrus.DebugLevel)
case log.WARN:
Logger.SetLevel(logrus.WarnLevel)
case log.ERROR:
Logger.SetLevel(logrus.ErrorLevel)
case log.INFO:
Logger.SetLevel(logrus.InfoLevel)
default:
l.Panic("Invalid level")
}
}
// Output logger output func
func (l Logrus) Output() io.Writer {
return l.Out
}
// SetOutput change output, default os.Stdout
func (l Logrus) SetOutput(w io.Writer) {
Logger.SetOutput(w)
}
// Printj print json log
func (l Logrus) Printj(j log.JSON) {
Logger.WithFields(logrus.Fields(j)).Print()
}
// Debugj debug json log
func (l Logrus) Debugj(j log.JSON) {
Logger.WithFields(logrus.Fields(j)).Debug()
}
// Infoj info json log
func (l Logrus) Infoj(j log.JSON) {
Logger.WithFields(logrus.Fields(j)).Info()
}
// Warnj warning json log
func (l Logrus) Warnj(j log.JSON) {
Logger.WithFields(logrus.Fields(j)).Warn()
}
// Errorj error json log
func (l Logrus) Errorj(j log.JSON) {
Logger.WithFields(logrus.Fields(j)).Error()
}
// Fatalj fatal json log
func (l Logrus) Fatalj(j log.JSON) {
Logger.WithFields(logrus.Fields(j)).Fatal()
}
// Panicj panic json log
func (l Logrus) Panicj(j log.JSON) {
Logger.WithFields(logrus.Fields(j)).Panic()
}
// Print string log
func (l Logrus) Print(i ...interface{}) {
Logger.Print(i[0].(string))
}
// Debug string log
func (l Logrus) Debug(i ...interface{}) {
Logger.Debug(i[0].(string))
}
// Info string log
func (l Logrus) Info(i ...interface{}) {
Logger.Info(i[0].(string))
}
// Warn string log
func (l Logrus) Warn(i ...interface{}) {
Logger.Warn(i[0].(string))
}
// Error string log
func (l Logrus) Error(i ...interface{}) {
Logger.Error(i[0].(string))
}
// Fatal string log
func (l Logrus) Fatal(i ...interface{}) {
Logger.Fatal(i[0].(string))
}
// Panic string log
func (l Logrus) Panic(i ...interface{}) {
Logger.Panic(i[0].(string))
}
func logrusMiddlewareHandler(c echo.Context, next echo.HandlerFunc) error {
req := c.Request()
res := c.Response()
start := time.Now()
if err := next(c); err != nil {
c.Error(err)
}
stop := time.Now()
p := req.URL.Path
bytesIn := req.Header.Get(echo.HeaderContentLength)
Logger.WithFields(map[string]interface{}{
"time_rfc3339": time.Now().Format(time.RFC3339),
"remote_ip": c.RealIP(),
"host": req.Host,
"uri": req.RequestURI,
"method": req.Method,
"path": p,
"referer": req.Referer(),
"user_agent": req.UserAgent(),
"status": res.Status,
"latency": strconv.FormatInt(stop.Sub(start).Nanoseconds()/1000, 10),
"latency_human": stop.Sub(start).String(),
"bytes_in": bytesIn,
"bytes_out": strconv.FormatInt(res.Size, 10),
}).Info("Handled request")
return nil
}
func logger(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
return logrusMiddlewareHandler(c, next)
}
}
// Hook is a function to process log.
func Hook() echo.MiddlewareFunc {
return logger
}

99
pkg/term/next_terminal.go Normal file
View File

@ -0,0 +1,99 @@
package term
import (
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"io"
)
type NextTerminal struct {
SshClient *ssh.Client
SshSession *ssh.Session
StdinPipe io.WriteCloser
SftpClient *sftp.Client
Recorder *Recorder
NextWriter *NextWriter
}
func NewNextTerminal(ip string, port int, username, password, privateKey, passphrase string, rows, cols int, recording string) (*NextTerminal, error) {
sshClient, err := NewSshClient(ip, port, username, password, privateKey, passphrase)
if err != nil {
return nil, err
}
sshSession, err := sshClient.NewSession()
if err != nil {
return nil, err
}
//defer sshSession.Close()
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
}
if err := sshSession.RequestPty("xterm-256color", rows, cols, modes); err != nil {
return nil, err
}
var nextWriter NextWriter
sshSession.Stdout = &nextWriter
sshSession.Stderr = &nextWriter
stdinPipe, err := sshSession.StdinPipe()
if err != nil {
return nil, err
}
if err := sshSession.Shell(); err != nil {
return nil, err
}
var recorder *Recorder
if recording != "" {
recorder, err = CreateRecording(recording, rows, cols)
if err != nil {
return nil, err
}
}
terminal := NextTerminal{
SshClient: sshClient,
SshSession: sshSession,
Recorder: recorder,
StdinPipe: stdinPipe,
NextWriter: &nextWriter,
}
return &terminal, nil
}
func (ret *NextTerminal) Write(p []byte) (int, error) {
return ret.StdinPipe.Write(p)
}
func (ret *NextTerminal) Read() ([]byte, int, error) {
return ret.NextWriter.Read()
}
func (ret *NextTerminal) Close() error {
if ret.SshSession != nil {
return ret.SshSession.Close()
}
if ret.SshClient != nil {
return ret.SshClient.Close()
}
if ret.Recorder != nil {
return ret.Close()
}
return nil
}
func (ret *NextTerminal) WindowChange(h int, w int) error {
return ret.SshSession.WindowChange(h, w)
}

30
pkg/term/next_writer.go Normal file
View File

@ -0,0 +1,30 @@
package term
import (
"bytes"
"sync"
)
type NextWriter struct {
b bytes.Buffer
mu sync.Mutex
}
func (w *NextWriter) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
return w.b.Write(p)
}
func (w *NextWriter) Read() ([]byte, int, error) {
w.mu.Lock()
defer w.mu.Unlock()
p := w.b.Bytes()
buf := make([]byte, len(p))
read, err := w.b.Read(buf)
w.b.Reset()
if err != nil {
return nil, 0, err
}
return buf, read, err
}

122
pkg/term/recording.go Normal file
View File

@ -0,0 +1,122 @@
package term
import (
"encoding/json"
"next-terminal/pkg/utils"
"os"
"time"
)
type Env struct {
Shell string `json:"SHELL"`
Term string `json:"TERM"`
}
type Header struct {
Title string `json:"title"`
Version int `json:"version"`
Height int `json:"height"`
Width int `json:"width"`
Env Env `json:"env"`
Timestamp int `json:"Timestamp"`
}
type Recorder struct {
File *os.File
Timestamp int
}
func NewRecorder(recoding string) (recorder *Recorder, err error) {
recorder = &Recorder{}
parentDirectory := utils.GetParentDirectory(recoding)
if utils.FileExists(parentDirectory) {
if err := os.RemoveAll(parentDirectory); err != nil {
return nil, err
}
}
if err = os.MkdirAll(parentDirectory, 0777); err != nil {
return
}
var file *os.File
file, err = os.Create(recoding)
if err != nil {
return nil, err
}
recorder.File = file
return recorder, nil
}
func (recorder *Recorder) Close() {
if recorder.File != nil {
recorder.File.Close()
}
}
func (recorder *Recorder) WriteHeader(header *Header) (err error) {
var p []byte
if p, err = json.Marshal(header); err != nil {
return
}
if _, err := recorder.File.Write(p); err != nil {
return err
}
if _, err := recorder.File.Write([]byte("\n")); err != nil {
return err
}
recorder.Timestamp = header.Timestamp
return
}
func (recorder *Recorder) WriteData(data string) (err error) {
now := int(time.Now().UnixNano())
delta := float64(now-recorder.Timestamp*1000*1000*1000) / 1000 / 1000 / 1000
row := make([]interface{}, 0)
row = append(row, delta)
row = append(row, "o")
row = append(row, data)
var s []byte
if s, err = json.Marshal(row); err != nil {
return
}
if _, err := recorder.File.Write(s); err != nil {
return err
}
if _, err := recorder.File.Write([]byte("\n")); err != nil {
return err
}
return
}
func CreateRecording(recordingPath string, h int, w int) (*Recorder, error) {
recorder, err := NewRecorder(recordingPath)
if err != nil {
return nil, err
}
header := &Header{
Title: "",
Version: 2,
Height: h,
Width: w,
Env: Env{Shell: "/bin/bash", Term: "xterm-256color"},
Timestamp: int(time.Now().Unix()),
}
if err := recorder.WriteHeader(header); err != nil {
return nil, err
}
return recorder, nil
}

53
pkg/term/ssh.go Normal file
View File

@ -0,0 +1,53 @@
package term
import (
"fmt"
"golang.org/x/crypto/ssh"
"time"
)
func NewSshClient(ip string, port int, username, password, privateKey, passphrase string) (*ssh.Client, error) {
var authMethod ssh.AuthMethod
if username == "-" || username == "" {
username = "root"
}
if password == "-" {
password = ""
}
if privateKey == "-" {
privateKey = ""
}
if passphrase == "-" {
passphrase = ""
}
var err error
if privateKey != "" {
var key ssh.Signer
if len(passphrase) > 0 {
key, err = ssh.ParsePrivateKeyWithPassphrase([]byte(privateKey), []byte(passphrase))
if err != nil {
return nil, err
}
} else {
key, err = ssh.ParsePrivateKey([]byte(privateKey))
if err != nil {
return nil, err
}
}
authMethod = ssh.PublicKeys(key)
} else {
authMethod = ssh.Password(password)
}
config := &ssh.ClientConfig{
Timeout: 1 * time.Second,
User: username,
Auth: []ssh.AuthMethod{authMethod},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
addr := fmt.Sprintf("%s:%d", ip, port)
return ssh.Dial("tcp", addr, config)
}

174
pkg/term/test/test_ssh.go Normal file
View File

@ -0,0 +1,174 @@
package main
import (
"fmt"
"io"
"os"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
)
type SSHTerminal struct {
Session *ssh.Session
exitMsg string
stdout io.Reader
stdin io.Writer
stderr io.Reader
}
func main() {
sshConfig := &ssh.ClientConfig{
User: "root",
Auth: []ssh.AuthMethod{
ssh.Password("root"),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
client, err := ssh.Dial("tcp", "172.16.101.32:22", sshConfig)
if err != nil {
fmt.Println(err)
}
defer client.Close()
err = New(client)
if err != nil {
fmt.Println(err)
}
}
func (t *SSHTerminal) updateTerminalSize() {
go func() {
// SIGWINCH is sent to the process when the window size of the terminal has
// changed.
sigwinchCh := make(chan os.Signal, 1)
//signal.Notify(sigwinchCh, syscall.SIN)
fd := int(os.Stdin.Fd())
termWidth, termHeight, err := terminal.GetSize(fd)
if err != nil {
fmt.Println(err)
}
for {
select {
// The client updated the size of the local PTY. This change needs to occur
// on the server side PTY as well.
case sigwinch := <-sigwinchCh:
if sigwinch == nil {
return
}
currTermWidth, currTermHeight, err := terminal.GetSize(fd)
// Terminal size has not changed, don't do anything.
if currTermHeight == termHeight && currTermWidth == termWidth {
continue
}
t.Session.WindowChange(currTermHeight, currTermWidth)
if err != nil {
fmt.Printf("Unable to send window-change reqest: %s.", err)
continue
}
termWidth, termHeight = currTermWidth, currTermHeight
}
}
}()
}
func (t *SSHTerminal) interactiveSession() error {
defer func() {
if t.exitMsg == "" {
fmt.Fprintln(os.Stdout, "the connection was closed on the remote side on ", time.Now().Format(time.RFC822))
} else {
fmt.Fprintln(os.Stdout, t.exitMsg)
}
}()
fd := int(os.Stdin.Fd())
state, err := terminal.MakeRaw(fd)
if err != nil {
return err
}
defer terminal.Restore(fd, state)
termWidth, termHeight, err := terminal.GetSize(fd)
if err != nil {
return err
}
termType := os.Getenv("TERM")
if termType == "" {
termType = "xterm-256color"
}
err = t.Session.RequestPty(termType, termHeight, termWidth, ssh.TerminalModes{})
if err != nil {
return err
}
t.updateTerminalSize()
t.stdin, err = t.Session.StdinPipe()
if err != nil {
return err
}
t.stdout, err = t.Session.StdoutPipe()
if err != nil {
return err
}
t.stderr, err = t.Session.StderrPipe()
go io.Copy(os.Stderr, t.stderr)
go io.Copy(os.Stdout, t.stdout)
go func() {
buf := make([]byte, 128)
for {
n, err := os.Stdin.Read(buf)
if err != nil {
fmt.Println(err)
return
}
if n > 0 {
_, err = t.stdin.Write(buf[:n])
if err != nil {
fmt.Println(err)
t.exitMsg = err.Error()
return
}
}
}
}()
err = t.Session.Shell()
if err != nil {
return err
}
err = t.Session.Wait()
if err != nil {
return err
}
return nil
}
func New(client *ssh.Client) error {
session, err := client.NewSession()
if err != nil {
return err
}
defer session.Close()
s := SSHTerminal{
Session: session,
}
return s.interactiveSession()
}

View File

@ -10,6 +10,7 @@ import (
"image/png"
"net"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
@ -123,6 +124,10 @@ func IsFile(path string) bool {
return !IsDir(path)
}
func GetParentDirectory(directory string) string {
return filepath.Dir(directory)
}
// 去除重复元素
func Distinct(a []string) []string {
result := make([]string, 0, len(a))

View File

@ -1,6 +1,6 @@
{
"name": "next-terminal",
"version": "0.1.1",
"version": "0.2.0",
"private": true,
"dependencies": {
"@ant-design/icons": "^4.3.0",

View File

@ -40,7 +40,7 @@ import {isEmpty, NT_PACKAGE} from "./utils/utils";
import {isAdmin} from "./service/permission";
import UserGroup from "./components/user/UserGroup";
import LoginLog from "./components/session/LoginLog";
import AccessNaive from "./components/access/AccessNaive";
import Term from "./components/access/Term";
const {Footer, Sider} = Layout;
@ -114,7 +114,7 @@ class App extends Component {
<Switch>
<Route path="/access" component={Access}/>
<Route path="/access-naive" component={AccessNaive}/>
<Route path="/term" component={Term}/>
<Route path="/login"><Login updateUser={this.updateUser}/></Route>
<Route path="/">

View File

@ -240,7 +240,7 @@ class Access extends Component {
this.showMessage('创建隧道失败');
break;
case 802:
this.showMessage('管理员强制断开了此会话');
this.showMessage('管理员强制关闭了此会话');
break;
default:
this.showMessage('未知错误。');

View File

@ -9,7 +9,7 @@ import "./Access.css"
import request from "../../common/request";
import {message} from "antd";
class AccessNaive extends Component {
class Term extends Component {
state = {
width: window.innerWidth,
@ -30,17 +30,9 @@ class AccessNaive extends Component {
return;
}
let params = {
'width': this.state.width,
'height': this.state.height,
'sessionId': sessionId
};
let paramStr = qs.stringify(params);
let term = new Terminal({
fontFamily: 'monaco, Consolas, "Lucida Console", monospace',
fontSize: 14,
fontSize: 15,
// theme: {
// background: '#1b1b1b',
// lineHeight: 17
@ -81,20 +73,22 @@ class AccessNaive extends Component {
});
let token = getToken();
let params = {
'cols': term.cols,
'rows': term.rows,
'sessionId': sessionId,
'X-Auth-Token': token
};
let webSocket = new WebSocket(wsServer + '/ssh?X-Auth-Token=' + token + '&' + paramStr);
let paramStr = qs.stringify(params);
let webSocket = new WebSocket(wsServer + '/ssh?' + paramStr);
let pingInterval;
webSocket.onopen = (e => {
pingInterval = setInterval(() => {
webSocket.send(JSON.stringify({type: 'ping'}))
}, 5000);
let terminalSize = {
cols: term.cols,
rows: term.rows
}
webSocket.send(JSON.stringify({type: 'resize', content: JSON.stringify(terminalSize)}));
});
webSocket.onerror = (e) => {
@ -111,8 +105,8 @@ class AccessNaive extends Component {
let msg = JSON.parse(e.data);
switch (msg['type']) {
case 'connected':
// term.write(msg['content'])
term.clear();
this.onWindowResize();
this.updateSessionStatus(sessionId);
break;
case 'data':
@ -202,4 +196,4 @@ class AccessNaive extends Component {
}
}
export default AccessNaive;
export default Term;

View File

@ -323,12 +323,12 @@ class Asset extends Component {
if (result.code === 1) {
if (result.data === true) {
message.success({content: '检测完成,您访问的资产在线,即将打开窗口进行访问。', key: id, duration: 3});
window.open(`#/access?assetId=${id}&assetName=${name}&protocol=${protocol}`);
// if (protocol === 'ssh') {
// window.open(`#/access-naive?assetId=${id}&assetName=${name}`);
// } else {
// window.open(`#/access?assetId=${id}&assetName=${name}&protocol=${protocol}`);
// }
// window.open(`#/access?assetId=${id}&assetName=${name}&protocol=${protocol}`);
if (protocol === 'ssh') {
window.open(`#/term?assetId=${id}&assetName=${name}`);
} else {
window.open(`#/access?assetId=${id}&assetName=${name}&protocol=${protocol}`);
}
} else {
message.warn('您访问的资产未在线,请确认网络状态。', 10);
}