next-terminal/pkg/api/session.go
dushixiang bc9daf2b01 修改docker默认时区为上海
修复了记住登录无效的问题
修复了ssh下载文件名称不正确的问题
授权凭证增加了密钥类型
2021-01-08 23:01:41 +08:00

512 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package api
import (
"bytes"
"errors"
"fmt"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/sirupsen/logrus"
"io"
"io/ioutil"
"net/http"
"next-terminal/pkg/global"
"next-terminal/pkg/model"
"next-terminal/pkg/utils"
"os"
"path"
"strconv"
"strings"
"time"
)
// SessionPagingEndpoint returns one page of sessions selected by the request's
// query parameters (pageIndex, pageSize, status, userId, clientIp, assetId,
// protocol).
//
// pageIndex and pageSize are parsed with strconv.Atoi; a value that fails to
// parse is forwarded as 0. Before responding, every item's Recording field is
// overwritten with a flag: "1" when the file "<Recording>/recording" exists,
// "0" when it does not or when no recording location is stored.
func SessionPagingEndpoint(c echo.Context) error {
	pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
	pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))

	items, total, err := model.FindPageSession(
		pageIndex,
		pageSize,
		c.QueryParam("status"),
		c.QueryParam("userId"),
		c.QueryParam("clientIp"),
		c.QueryParam("assetId"),
		c.QueryParam("protocol"),
	)
	if err != nil {
		return err
	}

	for i := range items {
		// No recording location stored for this session.
		if len(items[i].Recording) == 0 {
			items[i].Recording = "0"
			continue
		}

		recording := items[i].Recording + "/recording"
		if !utils.FileExists(recording) {
			logrus.Warnf("检测到录屏文件[%v]不存在", recording)
			items[i].Recording = "0"
			continue
		}

		logrus.Debugf("检测到录屏文件[%v]存在", recording)
		items[i].Recording = "1"
	}

	return Success(c, H{
		"total": total,
		"items": items,
	})
}
// SessionDeleteEndpoint deletes one or more sessions together with their
// recording directories. The "id" path parameter is a comma-separated list of
// session IDs.
//
// For every ID the directory <recording path>/<id> is removed with
// os.RemoveAll and the session record is then deleted through
// model.DeleteSessionById. Empty list entries are skipped. An ID equal to "."
// or "..", or containing a path separator, is rejected: joined onto the
// recording path it would resolve to the recording root or to a location
// outside it, and os.RemoveAll would delete that tree.
//
// It returns the error from model.GetRecordingPath or os.RemoveAll, or an
// error for an invalid ID; entries processed before the failure stay deleted.
func SessionDeleteEndpoint(c echo.Context) error {
	// The recording root is the same for every session, so resolve it once
	// and report a failure instead of silently skipping every deletion.
	drivePath, err := model.GetRecordingPath()
	if err != nil {
		return err
	}

	for _, id := range strings.Split(c.Param("id"), ",") {
		if id == "" {
			// e.g. a trailing comma; path.Join(drivePath, "") is drivePath itself.
			continue
		}
		if id == "." || id == ".." || strings.ContainsAny(id, `/\`) {
			// Message reads "invalid session ID".
			return errors.New("会话ID不合法")
		}
		if err := os.RemoveAll(path.Join(drivePath, id)); err != nil {
			return err
		}
		model.DeleteSessionById(id)
	}
	return Success(c, nil)
}
// SessionContentEndpoint marks the session named by the "id" path parameter as
// connected: it writes Status = model.Connected and ConnectedTime =
// utils.NowJsonTime() through model.UpdateSessionById, then responds with an
// empty success payload.
func SessionContentEndpoint(c echo.Context) error {
	id := c.Param("id")

	update := &model.Session{ID: id, Status: model.Connected}
	update.ConnectedTime = utils.NowJsonTime()

	model.UpdateSessionById(update, id)
	return Success(c, nil)
}
// SessionDiscontentEndpoint force-closes every session listed in the
// comma-separated "id" path parameter by calling CloseSessionById with the
// ForcedDisconnect code, then responds with an empty success payload.
func SessionDiscontentEndpoint(c echo.Context) error {
	for _, sessionId := range strings.Split(c.Param("id"), ",") {
		CloseSessionById(sessionId, ForcedDisconnect, "管理员强制关闭了此次接入。")
	}
	return Success(c, nil)
}
// CloseSessionById tears down the live tunnel registered under sessionId, if
// any, and records the disconnect.
//
// When global.Store holds an entry for the session, its tunnel is closed and
// its WebSocket is handed to CloseSessionByWebSocket with code and reason. The
// store entry is then removed. Unless code is Normal, the session record is
// updated with Status = model.Disconnected, the disconnect time, code and
// reason; for Normal nothing is written to the session record.
func CloseSessionById(sessionId string, code int, reason string) {
	if tun, _ := global.Store.Get(sessionId); tun != nil {
		_ = tun.Tun.Close()
		CloseSessionByWebSocket(tun.WebSocket, code, reason)
	}
	global.Store.Del(sessionId)

	if code != Normal {
		record := &model.Session{ID: sessionId, Status: model.Disconnected}
		record.DisconnectedTime = utils.NowJsonTime()
		record.Code = code
		record.Message = reason
		model.UpdateSessionById(record, sessionId)
	}
}
// CloseSessionByWebSocket installs a close handler on ws and then closes the
// connection. A nil ws is ignored.
//
// The handler answers a close frame received from the peer with a close
// control message carrying code c and text t (or an empty payload when the
// peer sent no status), using a one-second write deadline.
//
// NOTE(review): the handler only runs when a close frame is read from the
// peer, and the connection is closed right after the handler is registered,
// so c and t may never actually reach the client — confirm whether an
// explicit WriteControl before Close was intended.
func CloseSessionByWebSocket(ws *websocket.Conn, c int, t string) {
	if ws == nil {
		return
	}

	onClose := func(code int, text string) error {
		var payload []byte
		if code != websocket.CloseNoStatusReceived {
			payload = websocket.FormatCloseMessage(c, t)
		}
		deadline := time.Now().Add(time.Second)
		_ = ws.WriteControl(websocket.CloseMessage, payload, deadline)
		return nil
	}
	ws.SetCloseHandler(onClose)

	_ = ws.Close()
}
// SessionResizeEndpoint stores a new terminal size for the session named by
// the "id" path parameter.
//
// The "width" and "height" query parameters must both be present and parse as
// integers; otherwise the handler returns an error with the message "参数异常"
// (invalid parameters). It used to panic on missing values and silently store
// 0 for non-numeric ones. On success the Width and Height are written through
// model.UpdateSessionById and the updated values are returned in the response.
func SessionResizeEndpoint(c echo.Context) error {
	sessionId := c.Param("id")

	// strconv.Atoi rejects the empty string, so this also covers a missing
	// parameter.
	intWidth, err := strconv.Atoi(c.QueryParam("width"))
	if err != nil {
		return errors.New("参数异常")
	}
	intHeight, err := strconv.Atoi(c.QueryParam("height"))
	if err != nil {
		return errors.New("参数异常")
	}

	session := model.Session{}
	session.ID = sessionId
	session.Width = intWidth
	session.Height = intHeight
	model.UpdateSessionById(&session, sessionId)
	return Success(c, session)
}
// SessionCreateEndpoint creates a new, not-yet-connected session for the asset
// named by the "assetId" query parameter and returns it.
//
// Connection details (protocol, IP, port, login data) are copied from the
// asset. When the asset's AccountType is "credential", the referenced
// credential overrides the login data: the username always, plus the password
// for a model.Custom credential or the private key and passphrase for any
// other credential type. The session is persisted with model.CreateNewSession.
func SessionCreateEndpoint(c echo.Context) error {
	user, _ := GetCurrentAccount(c)

	asset, err := model.FindAssetById(c.QueryParam("assetId"))
	if err != nil {
		return err
	}

	session := &model.Session{
		ID:         utils.UUID(),
		AssetId:    asset.ID,
		Username:   asset.Username,
		Password:   asset.Password,
		PrivateKey: asset.PrivateKey,
		Passphrase: asset.Passphrase,
		Protocol:   asset.Protocol,
		IP:         asset.IP,
		Port:       asset.Port,
		Status:     model.NoConnect,
		Creator:    user.ID,
		ClientIP:   c.RealIP(),
	}

	if asset.AccountType == "credential" {
		credential, err := model.FindCredentialById(asset.CredentialId)
		if err != nil {
			return err
		}

		// The username comes from the credential in both cases; only the
		// secret material differs by credential type.
		session.Username = credential.Username
		switch credential.Type {
		case model.Custom:
			session.Password = credential.Password
		default:
			session.PrivateKey = credential.PrivateKey
			session.Passphrase = credential.Passphrase
		}
	}

	if err := model.CreateNewSession(session); err != nil {
		return err
	}
	return Success(c, session)
}
// SessionUploadEndpoint stores the multipart form file "file" at
// <dir>/<filename> for the session named by the "id" path parameter, where
// "dir" is a query parameter.
//
// For "ssh" sessions the file is written through the session's SFTP client;
// an error is returned when the session has no live tunnel or no SFTP client.
// For "rdp" sessions the file is written below the directory returned by
// model.GetDrivePath; the requested path is re-rooted first so ".." segments
// cannot leave that directory.
//
// NOTE(review): any other protocol falls through and returns nil without
// writing a response, as before — confirm whether that should be an error.
func SessionUploadEndpoint(c echo.Context) error {
	sessionId := c.Param("id")
	session, err := model.FindSessionById(sessionId)
	if err != nil {
		return err
	}

	file, err := c.FormFile("file")
	if err != nil {
		return err
	}
	src, err := file.Open()
	if err != nil {
		return err
	}
	defer src.Close()

	remoteDir := c.QueryParam("dir")
	remoteFile := path.Join(remoteDir, file.Filename)

	switch session.Protocol {
	case "ssh":
		tun, ok := global.Store.Get(sessionId)
		if !ok || tun.SftpClient == nil {
			return errors.New("获取sftp客户端失败")
		}

		dstFile, err := tun.SftpClient.Create(remoteFile)
		if err != nil {
			return err
		}
		// Deferred only after the error check: dstFile is not usable when
		// Create fails.
		defer dstFile.Close()

		// io.Copy writes exactly the bytes that were read and surfaces read
		// and write errors; the previous hand-written loop wrote the whole
		// buffer on every iteration and ignored both errors.
		if _, err := io.Copy(dstFile, src); err != nil {
			return err
		}
		return Success(c, nil)

	case "rdp":
		drivePath, err := model.GetDrivePath()
		if err != nil {
			return err
		}

		// Joining onto "/" cleans the path with ".." unable to climb above
		// the root, which keeps the target inside drivePath.
		target := path.Join(drivePath, path.Join("/", remoteFile))
		dst, err := os.Create(target)
		if err != nil {
			return err
		}
		defer dst.Close()

		if _, err = io.Copy(dst, src); err != nil {
			return err
		}
		return Success(c, nil)
	}

	return nil
}
// SessionDownloadEndpoint sends the file named by the "file" query parameter
// to the client for the session named by the "id" path parameter.
//
// For "ssh" sessions the file is read through the session's SFTP client and
// sent as an attachment named after the path's base name; an error is
// returned when the session has no live tunnel or no SFTP client. The whole
// file is buffered in memory before it is sent, so a read error is reported
// before any response bytes are written. For "rdp" sessions the file is
// served from below the directory returned by model.GetDrivePath; the
// requested path is re-rooted first so ".." segments cannot leave that
// directory.
//
// NOTE(review): any other protocol falls through and returns nil without
// writing a response, as before — confirm whether that should be an error.
func SessionDownloadEndpoint(c echo.Context) error {
	sessionId := c.Param("id")
	session, err := model.FindSessionById(sessionId)
	if err != nil {
		return err
	}

	remoteFile := c.QueryParam("file")

	switch session.Protocol {
	case "ssh":
		tun, ok := global.Store.Get(sessionId)
		if !ok || tun.SftpClient == nil {
			return errors.New("获取sftp客户端失败")
		}

		dstFile, err := tun.SftpClient.Open(remoteFile)
		if err != nil {
			return err
		}
		defer dstFile.Close()

		// File name including its extension.
		filenameWithSuffix := path.Base(remoteFile)
		c.Response().Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filenameWithSuffix))

		var buff bytes.Buffer
		if _, err := dstFile.WriteTo(&buff); err != nil {
			return err
		}
		return c.Stream(http.StatusOK, echo.MIMEOctetStream, bytes.NewReader(buff.Bytes()))

	case "rdp":
		drivePath, err := model.GetDrivePath()
		if err != nil {
			return err
		}
		// Joining onto "/" cleans the path with ".." unable to climb above
		// the root, which keeps the target inside drivePath.
		return c.File(path.Join(drivePath, path.Join("/", remoteFile)))
	}

	return nil
}
// File is one directory entry as reported by SessionLsEndpoint.
type File struct {
	// Name is the entry's base name.
	Name string `json:"name"`
	// Path is the listed directory joined with Name.
	Path string `json:"path"`
	// IsDir reports whether the entry is a directory.
	IsDir bool `json:"isDir"`
	// Mode is the textual form of the entry's file mode.
	Mode string `json:"mode"`
	// IsLink reports whether the entry's mode has the symlink bit set.
	IsLink bool `json:"isLink"`
}
// SessionLsEndpoint lists the directory named by the "dir" query parameter
// for the session named by the "id" path parameter and returns the entries as
// a slice of File.
//
// For "ssh" sessions the listing comes from the session's SFTP client, which
// is created on first use with CreateSftpClient and cached on the tunnel. For
// "rdp" sessions the listing is read from below the directory returned by
// model.GetDrivePath; the requested path is re-rooted first so ".." segments
// cannot leave that directory.
//
// NOTE(review): any other protocol falls through and returns nil without
// writing a response, as before — confirm whether that should be an error.
func SessionLsEndpoint(c echo.Context) error {
	sessionId := c.Param("id")
	session, err := model.FindSessionById(sessionId)
	if err != nil {
		return err
	}

	remoteDir := c.QueryParam("dir")

	switch session.Protocol {
	case "ssh":
		tun, ok := global.Store.Get(sessionId)
		if !ok {
			return errors.New("获取sftp客户端失败")
		}

		if tun.SftpClient == nil {
			sftpClient, err := CreateSftpClient(session.AssetId)
			if err != nil {
				logrus.Errorf("创建sftp客户端失败%v", err.Error())
				return err
			}
			tun.SftpClient = sftpClient
		}

		fileInfos, err := tun.SftpClient.ReadDir(remoteDir)
		if err != nil {
			return err
		}
		return Success(c, sessionFileList(remoteDir, fileInfos))

	case "rdp":
		drivePath, err := model.GetDrivePath()
		if err != nil {
			return err
		}

		// Joining onto "/" cleans the path with ".." unable to climb above
		// the root, which keeps the listing inside drivePath.
		fileInfos, err := ioutil.ReadDir(path.Join(drivePath, path.Join("/", remoteDir)))
		if err != nil {
			return err
		}
		return Success(c, sessionFileList(remoteDir, fileInfos))
	}

	return nil
}

// sessionFileList converts directory entries read from remoteDir into the
// File values returned to the client. The result is never nil, so an empty
// directory is an empty list rather than a missing one.
func sessionFileList(remoteDir string, fileInfos []os.FileInfo) []File {
	files := make([]File, 0, len(fileInfos))
	for _, info := range fileInfos {
		files = append(files, File{
			Name:   info.Name(),
			Path:   path.Join(remoteDir, info.Name()),
			IsDir:  info.IsDir(),
			Mode:   info.Mode().String(),
			IsLink: info.Mode()&os.ModeSymlink == os.ModeSymlink,
		})
	}
	return files
}
// SessionMkDirEndpoint creates the directory named by the "dir" query
// parameter for the session named by the "id" path parameter.
//
// For "ssh" sessions the directory is created through the session's SFTP
// client; an error is returned when the session has no live tunnel or no SFTP
// client. For "rdp" sessions the directory, including missing parents, is
// created below the directory returned by model.GetDrivePath; the requested
// path is re-rooted first so ".." segments cannot leave that directory.
//
// Any other protocol returns nil without writing a response.
func SessionMkDirEndpoint(c echo.Context) error {
	sessionId := c.Param("id")
	session, err := model.FindSessionById(sessionId)
	if err != nil {
		return err
	}

	remoteDir := c.QueryParam("dir")

	switch session.Protocol {
	case "ssh":
		tun, ok := global.Store.Get(sessionId)
		if !ok || tun.SftpClient == nil {
			return errors.New("获取sftp客户端失败")
		}
		if err := tun.SftpClient.Mkdir(remoteDir); err != nil {
			return err
		}
		return Success(c, nil)

	case "rdp":
		drivePath, err := model.GetDrivePath()
		if err != nil {
			return err
		}
		// Joining onto "/" cleans the path with ".." unable to climb above
		// the root, which keeps the new directory inside drivePath.
		target := path.Join(drivePath, path.Join("/", remoteDir))
		if err := os.MkdirAll(target, os.ModePerm); err != nil {
			return err
		}
		return Success(c, nil)
	}

	return nil
}
// SessionRmDirEndpoint removes the directory named by the "dir" query
// parameter for the session named by the "id" path parameter.
//
// For "ssh" sessions every entry listed in the directory is removed with the
// SFTP client's Remove and then the directory itself is removed; an error is
// returned when the session has no live tunnel or no SFTP client. For "rdp"
// sessions the directory tree is removed from below the directory returned by
// model.GetDrivePath. The requested path is re-rooted first so ".." segments
// cannot leave that directory, and a path that resolves to the drive root
// itself (for example an empty "dir") is rejected instead of wiping the whole
// drive directory.
//
// Any other protocol returns nil without writing a response.
func SessionRmDirEndpoint(c echo.Context) error {
	sessionId := c.Param("id")
	session, err := model.FindSessionById(sessionId)
	if err != nil {
		return err
	}

	remoteDir := c.QueryParam("dir")

	switch session.Protocol {
	case "ssh":
		tun, ok := global.Store.Get(sessionId)
		if !ok || tun.SftpClient == nil {
			return errors.New("获取sftp客户端失败")
		}

		fileInfos, err := tun.SftpClient.ReadDir(remoteDir)
		if err != nil {
			return err
		}
		for _, info := range fileInfos {
			if err := tun.SftpClient.Remove(path.Join(remoteDir, info.Name())); err != nil {
				return err
			}
		}
		if err := tun.SftpClient.RemoveDirectory(remoteDir); err != nil {
			return err
		}
		return Success(c, nil)

	case "rdp":
		drivePath, err := model.GetDrivePath()
		if err != nil {
			return err
		}

		// Joining onto "/" cleans the path with ".." unable to climb above
		// the root.
		rel := path.Join("/", remoteDir)
		if rel == "/" {
			// Message reads "deleting the root directory is not allowed".
			return errors.New("不允许删除根目录")
		}
		if err := os.RemoveAll(path.Join(drivePath, rel)); err != nil {
			return err
		}
		return Success(c, nil)
	}

	return nil
}
// SessionRmEndpoint removes the file named by the "file" query parameter for
// the session named by the "id" path parameter.
//
// For "ssh" sessions the file is removed through the session's SFTP client;
// an error is returned when the session has no live tunnel or no SFTP client.
// For "rdp" sessions the file is removed from below the directory returned by
// model.GetDrivePath; the requested path is re-rooted first so ".." segments
// cannot leave that directory.
//
// Any other protocol returns nil without writing a response.
func SessionRmEndpoint(c echo.Context) error {
	sessionId := c.Param("id")
	session, err := model.FindSessionById(sessionId)
	if err != nil {
		return err
	}

	remoteFile := c.QueryParam("file")

	switch session.Protocol {
	case "ssh":
		tun, ok := global.Store.Get(sessionId)
		if !ok || tun.SftpClient == nil {
			return errors.New("获取sftp客户端失败")
		}
		if err := tun.SftpClient.Remove(remoteFile); err != nil {
			return err
		}
		return Success(c, nil)

	case "rdp":
		drivePath, err := model.GetDrivePath()
		if err != nil {
			return err
		}
		// Joining onto "/" cleans the path with ".." unable to climb above
		// the root, which keeps the target inside drivePath.
		target := path.Join(drivePath, path.Join("/", remoteFile))
		if err := os.Remove(target); err != nil {
			return err
		}
		return Success(c, nil)
	}

	return nil
}
// SessionRecordingEndpoint serves the recording file of the session named by
// the "id" path parameter. The file is "recording" inside the location stored
// in the session's Recording field.
func SessionRecordingEndpoint(c echo.Context) error {
	session, err := model.FindSessionById(c.Param("id"))
	if err != nil {
		return err
	}

	recordingFile := path.Join(session.Recording, "recording")
	logrus.Debugf("读取录屏文件:%s", recordingFile)

	return c.File(recordingFile)
}
// SessionGetEndpoint looks up the session named by the "id" path parameter
// with model.FindSessionById and returns it, or returns the lookup error.
func SessionGetEndpoint(c echo.Context) error {
	found, err := model.FindSessionById(c.Param("id"))
	if err == nil {
		return Success(c, found)
	}
	return err
}