优化原生ssh连接代码
This commit is contained in:
344
pkg/api/ssh.go
344
pkg/api/ssh.go
@ -1,7 +1,6 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/websocket"
|
||||
@ -10,12 +9,13 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"net/http"
|
||||
"next-terminal/pkg/global"
|
||||
"next-terminal/pkg/guacd"
|
||||
"next-terminal/pkg/model"
|
||||
"next-terminal/pkg/term"
|
||||
"next-terminal/pkg/utils"
|
||||
"path"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -26,30 +26,6 @@ var UpGrader = websocket.Upgrader{
|
||||
Subprotocols: []string{"guacamole"},
|
||||
}
|
||||
|
||||
type NextWriter struct {
|
||||
b bytes.Buffer
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (w *NextWriter) Write(p []byte) (int, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.b.Write(p)
|
||||
}
|
||||
|
||||
func (w *NextWriter) Read() ([]byte, int, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
p := w.b.Bytes()
|
||||
buf := make([]byte, len(p))
|
||||
read, err := w.b.Read(buf)
|
||||
w.b.Reset()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return buf, read, err
|
||||
}
|
||||
|
||||
const (
|
||||
Connected = "connected"
|
||||
Data = "data"
|
||||
@ -75,14 +51,14 @@ func SSHEndpoint(c echo.Context) (err error) {
|
||||
}
|
||||
|
||||
sessionId := c.QueryParam("sessionId")
|
||||
width, _ := strconv.Atoi(c.QueryParam("width"))
|
||||
height, _ := strconv.Atoi(c.QueryParam("height"))
|
||||
cols, _ := strconv.Atoi(c.QueryParam("cols"))
|
||||
rows, _ := strconv.Atoi(c.QueryParam("rows"))
|
||||
|
||||
aSession, err := model.FindSessionById(sessionId)
|
||||
session, err := model.FindSessionById(sessionId)
|
||||
if err != nil {
|
||||
msg := Message{
|
||||
Type: Closed,
|
||||
Content: "get session error." + err.Error(),
|
||||
Content: "get sshSession error." + err.Error(),
|
||||
}
|
||||
_ = WriteMessage(ws, msg)
|
||||
return err
|
||||
@ -96,7 +72,7 @@ func SSHEndpoint(c echo.Context) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
if !utils.Contains(assetIds, aSession.AssetId) {
|
||||
if !utils.Contains(assetIds, session.AssetId) {
|
||||
msg := Message{
|
||||
Type: Closed,
|
||||
Content: "您没有权限访问此资产",
|
||||
@ -105,7 +81,41 @@ func SSHEndpoint(c echo.Context) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
sshClient, err := CreateSshClientBySession(aSession)
|
||||
var (
|
||||
username = session.Username
|
||||
password = session.Password
|
||||
privateKey = session.PrivateKey
|
||||
passphrase = session.Passphrase
|
||||
ip = session.IP
|
||||
port = session.Port
|
||||
)
|
||||
|
||||
recording := ""
|
||||
propertyMap := model.FindAllPropertiesMap()
|
||||
if propertyMap[guacd.EnableRecording] == "true" {
|
||||
recording = path.Join(propertyMap[guacd.RecordingPath], sessionId, "recording.cast")
|
||||
}
|
||||
|
||||
tun := global.Tun{
|
||||
Protocol: session.Protocol,
|
||||
WebSocket: ws,
|
||||
}
|
||||
|
||||
if session.ConnectionId != "" {
|
||||
// 监控会话
|
||||
observable, ok := global.Store.Get(sessionId)
|
||||
if ok {
|
||||
observers := append(observable.Observers, tun)
|
||||
observable.Observers = observers
|
||||
global.Store.Set(sessionId, observable)
|
||||
logrus.Debugf("加入会话%v,当前观察者数量为:%v", session.ConnectionId, len(observers))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
nextTerminal, err := term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, rows, cols, recording)
|
||||
|
||||
if err != nil {
|
||||
logrus.Errorf("创建SSH客户端失败:%v", err.Error())
|
||||
msg := Message{
|
||||
@ -115,129 +125,50 @@ func SSHEndpoint(c echo.Context) (err error) {
|
||||
err := WriteMessage(ws, msg)
|
||||
return err
|
||||
}
|
||||
tun.NextTerminal = nextTerminal
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
logrus.Errorf("创建SSH会话失败:%v", err.Error())
|
||||
msg := Message{
|
||||
Type: Closed,
|
||||
Content: err.Error(),
|
||||
}
|
||||
err := WriteMessage(ws, msg)
|
||||
var observers []global.Tun
|
||||
observable := global.Observable{
|
||||
Subject: &tun,
|
||||
Observers: observers,
|
||||
}
|
||||
|
||||
global.Store.Set(sessionId, &observable)
|
||||
|
||||
sess := model.Session{
|
||||
ConnectionId: sessionId,
|
||||
Width: cols,
|
||||
Height: rows,
|
||||
Status: model.Connecting,
|
||||
Recording: recording,
|
||||
}
|
||||
// 创建新会话
|
||||
logrus.Debugf("创建新会话 %v", sess.ConnectionId)
|
||||
if err := model.UpdateSessionById(&sess, sessionId); err != nil {
|
||||
return err
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
}
|
||||
|
||||
if err := session.RequestPty("xterm", height, width, modes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var b NextWriter
|
||||
session.Stdout = &b
|
||||
session.Stderr = &b
|
||||
|
||||
stdinPipe, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var recorder *Recorder
|
||||
var recording string
|
||||
property, _ := model.FindPropertyByName(guacd.RecordingPath)
|
||||
if property.Value != "" {
|
||||
dir := path.Join(property.Value, sessionId)
|
||||
recorder, recording, err = NewRecorder(dir)
|
||||
if err != nil {
|
||||
msg := Message{
|
||||
Type: Closed,
|
||||
Content: "创建录屏文件失败 :( " + err.Error(),
|
||||
}
|
||||
return WriteMessage(ws, msg)
|
||||
}
|
||||
|
||||
header := &Header{
|
||||
Title: "test",
|
||||
Version: 2,
|
||||
Height: height,
|
||||
Width: width,
|
||||
Env: Env{Shell: "/bin/bash", Term: "xterm-256color"},
|
||||
Timestamp: int(time.Now().Unix()),
|
||||
}
|
||||
|
||||
if err := recorder.WriteHeader(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := model.UpdateSessionById(&model.Session{Recording: recording}, sessionId); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
msg := Message{
|
||||
Type: Connected,
|
||||
Content: "Connect to server successfully.\r\n",
|
||||
Content: "",
|
||||
}
|
||||
_ = WriteMessage(ws, msg)
|
||||
|
||||
var mut sync.Mutex
|
||||
var active = true
|
||||
quitChan := make(chan bool)
|
||||
recordingChan := make(chan string, 1024)
|
||||
|
||||
go func() {
|
||||
for true {
|
||||
mut.Lock()
|
||||
if !active {
|
||||
logrus.Debugf("会话: %v -> %v 关闭", sshClient.LocalAddr().String(), sshClient.RemoteAddr().String())
|
||||
if recorder != nil {
|
||||
recorder.Close()
|
||||
}
|
||||
CloseSessionById(sessionId, Normal, "正常退出")
|
||||
break
|
||||
}
|
||||
mut.Unlock()
|
||||
go ReadMessage(nextTerminal, recordingChan, quitChan, ws)
|
||||
|
||||
p, n, err := b.Read()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if n > 0 {
|
||||
s := string(p)
|
||||
if recorder != nil {
|
||||
// 录屏
|
||||
_ = recorder.WriteData(s)
|
||||
}
|
||||
msg := Message{
|
||||
Type: Data,
|
||||
Content: s,
|
||||
}
|
||||
message, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
logrus.Warnf("生成Json失败 %v", err)
|
||||
continue
|
||||
}
|
||||
WriteByteMessage(ws, message)
|
||||
}
|
||||
time.Sleep(time.Duration(100) * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
go Recoding(nextTerminal, recordingChan, quitChan)
|
||||
go Monitor(sessionId, recordingChan, quitChan)
|
||||
|
||||
for true {
|
||||
for {
|
||||
_, message, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
// web socket会话关闭后主动关闭ssh会话
|
||||
_ = session.Close()
|
||||
mut.Lock()
|
||||
active = false
|
||||
mut.Unlock()
|
||||
CloseSessionById(sessionId, Normal, "正常退出")
|
||||
quitChan <- true
|
||||
quitChan <- true
|
||||
break
|
||||
}
|
||||
|
||||
@ -256,12 +187,12 @@ func SSHEndpoint(c echo.Context) (err error) {
|
||||
logrus.Warnf("解析SSH会话窗口大小失败: %v", err)
|
||||
continue
|
||||
}
|
||||
if err := session.WindowChange(winSize.Rows, winSize.Cols); err != nil {
|
||||
if err := nextTerminal.WindowChange(winSize.Rows, winSize.Cols); err != nil {
|
||||
logrus.Warnf("更改SSH会话窗口大小失败: %v", err)
|
||||
continue
|
||||
}
|
||||
case Data:
|
||||
_, err = stdinPipe.Write([]byte(msg.Content))
|
||||
_, err = nextTerminal.Write([]byte(msg.Content))
|
||||
if err != nil {
|
||||
logrus.Debugf("SSH会话写入失败: %v", err)
|
||||
msg := Message{
|
||||
@ -276,16 +207,96 @@ func SSHEndpoint(c echo.Context) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func ReadMessage(nextTerminal *term.NextTerminal, stdoutChan chan string, quitChan chan bool, ws *websocket.Conn) {
|
||||
|
||||
var quit bool
|
||||
for {
|
||||
select {
|
||||
case quit = <-quitChan:
|
||||
if quit {
|
||||
return
|
||||
}
|
||||
default:
|
||||
p, n, err := nextTerminal.Read()
|
||||
if err != nil {
|
||||
msg := Message{
|
||||
Type: Closed,
|
||||
Content: err.Error(),
|
||||
}
|
||||
_ = WriteMessage(ws, msg)
|
||||
}
|
||||
if n > 0 {
|
||||
s := string(p)
|
||||
// 发送一份数据到队列中
|
||||
stdoutChan <- s
|
||||
msg := Message{
|
||||
Type: Data,
|
||||
Content: s,
|
||||
}
|
||||
_ = WriteMessage(ws, msg)
|
||||
}
|
||||
time.Sleep(time.Duration(10) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Recoding(nextTerminal *term.NextTerminal, recordingChan chan string, quitChan chan bool) {
|
||||
var quit bool
|
||||
var s string
|
||||
for {
|
||||
select {
|
||||
case quit = <-quitChan:
|
||||
if quit {
|
||||
fmt.Println("退出录屏")
|
||||
return
|
||||
}
|
||||
case s = <-recordingChan:
|
||||
_ = nextTerminal.Recorder.WriteData(s)
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Monitor(sessionId string, recordingChan chan string, quitChan chan bool) {
|
||||
var quit bool
|
||||
var s string
|
||||
for {
|
||||
select {
|
||||
case quit = <-quitChan:
|
||||
if quit {
|
||||
fmt.Println("退出监控")
|
||||
return
|
||||
}
|
||||
case s = <-recordingChan:
|
||||
msg := Message{
|
||||
Type: Data,
|
||||
Content: s,
|
||||
}
|
||||
|
||||
observable, ok := global.Store.Get(sessionId)
|
||||
if ok {
|
||||
for i := 0; i < len(observable.Observers); i++ {
|
||||
_ = WriteMessage(observable.Observers[i].WebSocket, msg)
|
||||
}
|
||||
}
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WriteMessage(ws *websocket.Conn, msg Message) error {
|
||||
message, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
logrus.Warnf("生成Json失败 %v", err)
|
||||
return err
|
||||
}
|
||||
WriteByteMessage(ws, message)
|
||||
return err
|
||||
}
|
||||
|
||||
func CreateSshClientBySession(session model.Session) (sshClient *ssh.Client, err error) {
|
||||
|
||||
var (
|
||||
username = session.Username
|
||||
password = session.Password
|
||||
@ -293,52 +304,7 @@ func CreateSshClientBySession(session model.Session) (sshClient *ssh.Client, err
|
||||
passphrase = session.Passphrase
|
||||
)
|
||||
|
||||
var authMethod ssh.AuthMethod
|
||||
if username == "-" || username == "" {
|
||||
username = "root"
|
||||
}
|
||||
if password == "-" {
|
||||
password = ""
|
||||
}
|
||||
if privateKey == "-" {
|
||||
privateKey = ""
|
||||
}
|
||||
if passphrase == "-" {
|
||||
passphrase = ""
|
||||
}
|
||||
|
||||
if privateKey != "" {
|
||||
var key ssh.Signer
|
||||
if len(passphrase) > 0 {
|
||||
key, err = ssh.ParsePrivateKeyWithPassphrase([]byte(privateKey), []byte(passphrase))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
key, err = ssh.ParsePrivateKey([]byte(privateKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
authMethod = ssh.PublicKeys(key)
|
||||
} else {
|
||||
authMethod = ssh.Password(password)
|
||||
}
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
Timeout: 1 * time.Second,
|
||||
User: username,
|
||||
Auth: []ssh.AuthMethod{authMethod},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", session.IP, session.Port)
|
||||
|
||||
sshClient, err = ssh.Dial("tcp", addr, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sshClient, nil
|
||||
return term.NewSshClient(session.IP, session.Port, username, password, privateKey, passphrase)
|
||||
}
|
||||
|
||||
func WriteByteMessage(ws *websocket.Conn, p []byte) {
|
||||
|
Reference in New Issue
Block a user