next-terminal/pkg/api/ssh.go
dushixiang bc9daf2b01 修改docker默认时区为上海
修复了记住登录无效的问题
修复了ssh下载文件名称不正确的问题
授权凭证增加了密钥类型
2021-01-08 23:01:41 +08:00

217 lines
4.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package api
import (
"bytes"
"fmt"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/pkg/sftp"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"net/http"
"next-terminal/pkg/model"
"strconv"
"sync"
"time"
)
var UpGrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
Subprotocols: []string{"guacamole"},
}
type NextWriter struct {
b bytes.Buffer
mu sync.Mutex
}
func (w *NextWriter) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
return w.b.Write(p)
}
func (w *NextWriter) Read() ([]byte, int, error) {
w.mu.Lock()
defer w.mu.Unlock()
p := w.b.Bytes()
buf := make([]byte, len(p))
read, err := w.b.Read(buf)
w.b.Reset()
if err != nil {
return nil, 0, err
}
return buf, read, err
}
func SSHEndpoint(c echo.Context) error {
ws, err := UpGrader.Upgrade(c.Response().Writer, c.Request(), nil)
if err != nil {
logrus.Errorf("升级为WebSocket协议失败%v", err.Error())
return err
}
assetId := c.QueryParam("assetId")
width, _ := strconv.Atoi(c.QueryParam("width"))
height, _ := strconv.Atoi(c.QueryParam("height"))
sshClient, err := CreateSshClient(assetId)
if err != nil {
logrus.Errorf("创建SSH客户端失败%v", err.Error())
return err
}
session, err := sshClient.NewSession()
if err != nil {
logrus.Errorf("创建SSH会话失败%v", err.Error())
return err
}
defer session.Close()
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
}
if err := session.RequestPty("xterm", height, width, modes); err != nil {
return err
}
var b NextWriter
session.Stdout = &b
session.Stderr = &b
stdinPipe, err := session.StdinPipe()
if err != nil {
return err
}
if err := session.Shell(); err != nil {
return err
}
go func() {
for true {
p, n, err := b.Read()
if err != nil {
continue
}
if n > 0 {
WriteByteMessage(ws, p)
}
time.Sleep(time.Duration(100) * time.Millisecond)
}
}()
for true {
_, message, err := ws.ReadMessage()
if err != nil {
continue
}
_, err = stdinPipe.Write(message)
if err != nil {
logrus.Debugf("Tunnel write: %v", err)
}
}
return err
}
func CreateSshClient(assetId string) (*ssh.Client, error) {
asset, err := model.FindAssetById(assetId)
if err != nil {
return nil, err
}
var (
accountType = asset.AccountType
username = asset.Username
password = asset.Password
privateKey = asset.PrivateKey
passphrase = asset.Passphrase
)
var authMethod ssh.AuthMethod
if accountType == "credential" {
credential, err := model.FindCredentialById(asset.CredentialId)
if err != nil {
return nil, err
}
accountType = credential.Type
username = credential.Username
password = credential.Password
privateKey = credential.PrivateKey
passphrase = credential.Passphrase
}
if username == "-" {
username = ""
}
if password == "-" {
password = ""
}
if privateKey == "-" {
privateKey = ""
}
if passphrase == "-" {
passphrase = ""
}
if accountType == model.PrivateKey {
var key ssh.Signer
if len(passphrase) > 0 {
key, err = ssh.ParsePrivateKeyWithPassphrase([]byte(privateKey), []byte(passphrase))
if err != nil {
return nil, err
}
} else {
key, err = ssh.ParsePrivateKey([]byte(privateKey))
if err != nil {
return nil, err
}
}
authMethod = ssh.PublicKeys(key)
} else {
authMethod = ssh.Password(password)
}
config := &ssh.ClientConfig{
Timeout: 1 * time.Second,
User: username,
Auth: []ssh.AuthMethod{authMethod},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
addr := fmt.Sprintf("%s:%d", asset.IP, asset.Port)
sshClient, err := ssh.Dial("tcp", addr, config)
if err != nil {
return nil, err
}
return sshClient, nil
}
func WriteMessage(ws *websocket.Conn, message string) {
WriteByteMessage(ws, []byte(message))
}
func WriteByteMessage(ws *websocket.Conn, p []byte) {
err := ws.WriteMessage(websocket.TextMessage, p)
if err != nil {
logrus.Debugf("write: %v", err)
}
}
func CreateSftpClient(assetId string) (sftpClient *sftp.Client, err error) {
sshClient, err := CreateSshClient(assetId)
if err != nil {
return nil, err
}
return sftp.NewClient(sshClient)
}