🐶 重构部分用户数据库操作代码

This commit is contained in:
dushixiang
2021-03-18 00:07:30 +08:00
parent d6ef8aa1db
commit 805fec4a67
50 changed files with 478 additions and 453 deletions

310
server/api/account.go Normal file
View File

@ -0,0 +1,310 @@
package api
import (
"strings"
"time"
"next-terminal/server/global"
"next-terminal/server/model"
"next-terminal/server/totp"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)
const (
RememberEffectiveTime = time.Hour * time.Duration(24*14)
NotRememberEffectiveTime = time.Hour * time.Duration(2)
)
type LoginAccount struct {
Username string `json:"username"`
Password string `json:"password"`
Remember bool `json:"remember"`
TOTP string `json:"totp"`
}
type ConfirmTOTP struct {
Secret string `json:"secret"`
TOTP string `json:"totp"`
}
type ChangePassword struct {
NewPassword string `json:"newPassword"`
OldPassword string `json:"oldPassword"`
}
type Authorization struct {
Token string
Remember bool
User model.User
}
func LoginEndpoint(c echo.Context) error {
var loginAccount LoginAccount
if err := c.Bind(&loginAccount); err != nil {
return err
}
user, err := userRepository.FindByUsername(loginAccount.Username)
// 存储登录失败次数信息
loginFailCountKey := loginAccount.Username
v, ok := global.Cache.Get(loginFailCountKey)
if !ok {
v = 1
}
count := v.(int)
if count >= 5 {
return Fail(c, -1, "登录失败次数过多,请稍后再试")
}
if err != nil {
count++
global.Cache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
return FailWithData(c, -1, "您输入的账号或密码不正确", count)
}
if err := utils.Encoder.Match([]byte(user.Password), []byte(loginAccount.Password)); err != nil {
count++
global.Cache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
return FailWithData(c, -1, "您输入的账号或密码不正确", count)
}
if user.TOTPSecret != "" && user.TOTPSecret != "-" {
return Fail(c, 0, "")
}
token, err := LoginSuccess(c, loginAccount, user)
if err != nil {
return err
}
return Success(c, token)
}
func LoginSuccess(c echo.Context, loginAccount LoginAccount, user model.User) (token string, err error) {
token = strings.Join([]string{utils.UUID(), utils.UUID(), utils.UUID(), utils.UUID()}, "")
authorization := Authorization{
Token: token,
Remember: loginAccount.Remember,
User: user,
}
cacheKey := BuildCacheKeyByToken(token)
if authorization.Remember {
// 记住登录有效期两周
global.Cache.Set(cacheKey, authorization, RememberEffectiveTime)
} else {
global.Cache.Set(cacheKey, authorization, NotRememberEffectiveTime)
}
// 保存登录日志
loginLog := model.LoginLog{
ID: token,
UserId: user.ID,
ClientIP: c.RealIP(),
ClientUserAgent: c.Request().UserAgent(),
LoginTime: utils.NowJsonTime(),
Remember: authorization.Remember,
}
if model.CreateNewLoginLog(&loginLog) != nil {
return "", err
}
// 修改登录状态
err = userRepository.Update(&model.User{Online: true, ID: user.ID})
return token, err
}
func BuildCacheKeyByToken(token string) string {
cacheKey := strings.Join([]string{Token, token}, ":")
return cacheKey
}
func GetTokenFormCacheKey(cacheKey string) string {
token := strings.Split(cacheKey, ":")[1]
return token
}
func loginWithTotpEndpoint(c echo.Context) error {
var loginAccount LoginAccount
if err := c.Bind(&loginAccount); err != nil {
return err
}
// 存储登录失败次数信息
loginFailCountKey := loginAccount.Username
v, ok := global.Cache.Get(loginFailCountKey)
if !ok {
v = 1
}
count := v.(int)
if count >= 5 {
return Fail(c, -1, "登录失败次数过多,请稍后再试")
}
user, err := userRepository.FindByUsername(loginAccount.Username)
if err != nil {
count++
global.Cache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
return FailWithData(c, -1, "您输入的账号或密码不正确", count)
}
if err := utils.Encoder.Match([]byte(user.Password), []byte(loginAccount.Password)); err != nil {
count++
global.Cache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
return FailWithData(c, -1, "您输入的账号或密码不正确", count)
}
if !totp.Validate(loginAccount.TOTP, user.TOTPSecret) {
count++
global.Cache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
return FailWithData(c, -1, "您输入双因素认证授权码不正确", count)
}
token, err := LoginSuccess(c, loginAccount, user)
if err != nil {
return err
}
return Success(c, token)
}
func LogoutEndpoint(c echo.Context) error {
token := GetToken(c)
cacheKey := BuildCacheKeyByToken(token)
global.Cache.Delete(cacheKey)
err := model.Logout(token)
if err != nil {
return err
}
return Success(c, nil)
}
func ConfirmTOTPEndpoint(c echo.Context) error {
if global.Config.Demo {
return Fail(c, 0, "演示模式禁止开启两步验证")
}
account, _ := GetCurrentAccount(c)
var confirmTOTP ConfirmTOTP
if err := c.Bind(&confirmTOTP); err != nil {
return err
}
if !totp.Validate(confirmTOTP.TOTP, confirmTOTP.Secret) {
return Fail(c, -1, "TOTP 验证失败,请重试")
}
u := &model.User{
TOTPSecret: confirmTOTP.Secret,
ID: account.ID,
}
if err := userRepository.Update(u); err != nil {
return err
}
return Success(c, nil)
}
func ReloadTOTPEndpoint(c echo.Context) error {
account, _ := GetCurrentAccount(c)
key, err := totp.NewTOTP(totp.GenerateOpts{
Issuer: c.Request().Host,
AccountName: account.Username,
})
if err != nil {
return Fail(c, -1, err.Error())
}
qrcode, err := key.Image(200, 200)
if err != nil {
return Fail(c, -1, err.Error())
}
qrEncode, err := utils.ImageToBase64Encode(qrcode)
if err != nil {
return Fail(c, -1, err.Error())
}
return Success(c, map[string]string{
"qr": qrEncode,
"secret": key.Secret(),
})
}
func ResetTOTPEndpoint(c echo.Context) error {
account, _ := GetCurrentAccount(c)
u := &model.User{
TOTPSecret: "-",
ID: account.ID,
}
if err := userRepository.Update(u); err != nil {
return err
}
return Success(c, "")
}
func ChangePasswordEndpoint(c echo.Context) error {
if global.Config.Demo {
return Fail(c, 0, "演示模式禁止修改密码")
}
account, _ := GetCurrentAccount(c)
var changePassword ChangePassword
if err := c.Bind(&changePassword); err != nil {
return err
}
if err := utils.Encoder.Match([]byte(account.Password), []byte(changePassword.OldPassword)); err != nil {
return Fail(c, -1, "您输入的原密码不正确")
}
passwd, err := utils.Encoder.Encode([]byte(changePassword.NewPassword))
if err != nil {
return err
}
u := &model.User{
Password: string(passwd),
ID: account.ID,
}
if err := userRepository.Update(u); err != nil {
return err
}
return LogoutEndpoint(c)
}
type AccountInfo struct {
Id string `json:"id"`
Username string `json:"username"`
Nickname string `json:"nickname"`
Type string `json:"type"`
EnableTotp bool `json:"enableTotp"`
}
func InfoEndpoint(c echo.Context) error {
account, _ := GetCurrentAccount(c)
user, err := userRepository.FindById(account.ID)
if err != nil {
return err
}
info := AccountInfo{
Id: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Type: user.Type,
EnableTotp: user.TOTPSecret != "" && user.TOTPSecret != "-",
}
return Success(c, info)
}

324
server/api/asset.go Normal file
View File

@ -0,0 +1,324 @@
package api
import (
"bufio"
"encoding/csv"
"encoding/json"
"errors"
"strconv"
"strings"
"next-terminal/server/constant"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)
func AssetCreateEndpoint(c echo.Context) error {
m := echo.Map{}
if err := c.Bind(&m); err != nil {
return err
}
data, _ := json.Marshal(m)
var item model.Asset
if err := json.Unmarshal(data, &item); err != nil {
return err
}
account, _ := GetCurrentAccount(c)
item.Owner = account.ID
item.ID = utils.UUID()
item.Created = utils.NowJsonTime()
if err := model.CreateNewAsset(&item); err != nil {
return err
}
if err := model.UpdateAssetAttributes(item.ID, item.Protocol, m); err != nil {
return err
}
// 创建后自动检测资产是否存活
go func() {
active := utils.Tcping(item.IP, item.Port)
model.UpdateAssetActiveById(active, item.ID)
}()
return Success(c, item)
}
func AssetImportEndpoint(c echo.Context) error {
account, _ := GetCurrentAccount(c)
file, err := c.FormFile("file")
if err != nil {
return err
}
src, err := file.Open()
if err != nil {
return err
}
defer src.Close()
reader := csv.NewReader(bufio.NewReader(src))
records, err := reader.ReadAll()
if err != nil {
return err
}
total := len(records)
if total == 0 {
return errors.New("csv数据为空")
}
var successCount = 0
var errorCount = 0
m := echo.Map{}
for i := 0; i < total; i++ {
record := records[i]
if len(record) >= 9 {
port, _ := strconv.Atoi(record[3])
asset := model.Asset{
ID: utils.UUID(),
Name: record[0],
Protocol: record[1],
IP: record[2],
Port: port,
AccountType: constant.Custom,
Username: record[4],
Password: record[5],
PrivateKey: record[6],
Passphrase: record[7],
Description: record[8],
Created: utils.NowJsonTime(),
Owner: account.ID,
}
err := model.CreateNewAsset(&asset)
if err != nil {
errorCount++
m[strconv.Itoa(i)] = err.Error()
} else {
successCount++
// 创建后自动检测资产是否存活
go func() {
active := utils.Tcping(asset.IP, asset.Port)
model.UpdateAssetActiveById(active, asset.ID)
}()
}
}
}
return Success(c, echo.Map{
"successCount": successCount,
"errorCount": errorCount,
"data": m,
})
}
func AssetPagingEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))
name := c.QueryParam("name")
protocol := c.QueryParam("protocol")
tags := c.QueryParam("tags")
owner := c.QueryParam("owner")
sharer := c.QueryParam("sharer")
userGroupId := c.QueryParam("userGroupId")
ip := c.QueryParam("ip")
order := c.QueryParam("order")
field := c.QueryParam("field")
account, _ := GetCurrentAccount(c)
items, total, err := model.FindPageAsset(pageIndex, pageSize, name, protocol, tags, account, owner, sharer, userGroupId, ip, order, field)
if err != nil {
return err
}
return Success(c, H{
"total": total,
"items": items,
})
}
func AssetAllEndpoint(c echo.Context) error {
protocol := c.QueryParam("protocol")
account, _ := GetCurrentAccount(c)
items, _ := model.FindAssetByConditions(protocol, account)
return Success(c, items)
}
func AssetUpdateEndpoint(c echo.Context) error {
id := c.Param("id")
if err := PreCheckAssetPermission(c, id); err != nil {
return err
}
m := echo.Map{}
if err := c.Bind(&m); err != nil {
return err
}
data, _ := json.Marshal(m)
var item model.Asset
if err := json.Unmarshal(data, &item); err != nil {
return err
}
switch item.AccountType {
case "credential":
item.Username = "-"
item.Password = "-"
item.PrivateKey = "-"
item.Passphrase = "-"
case "private-key":
item.Password = "-"
item.CredentialId = "-"
if len(item.Username) == 0 {
item.Username = "-"
}
if len(item.Passphrase) == 0 {
item.Passphrase = "-"
}
case "custom":
item.PrivateKey = "-"
item.Passphrase = "-"
item.CredentialId = "-"
}
if len(item.Tags) == 0 {
item.Tags = "-"
}
if item.Description == "" {
item.Description = "-"
}
model.UpdateAssetById(&item, id)
if err := model.UpdateAssetAttributes(id, item.Protocol, m); err != nil {
return err
}
return Success(c, nil)
}
func AssetGetAttributeEndpoint(c echo.Context) error {
assetId := c.Param("id")
if err := PreCheckAssetPermission(c, assetId); err != nil {
return err
}
attributeMap, err := model.FindAssetAttrMapByAssetId(assetId)
if err != nil {
return err
}
return Success(c, attributeMap)
}
func AssetUpdateAttributeEndpoint(c echo.Context) error {
m := echo.Map{}
if err := c.Bind(&m); err != nil {
return err
}
assetId := c.Param("id")
protocol := c.QueryParam("protocol")
err := model.UpdateAssetAttributes(assetId, protocol, m)
if err != nil {
return err
}
return Success(c, "")
}
func AssetDeleteEndpoint(c echo.Context) error {
id := c.Param("id")
split := strings.Split(id, ",")
for i := range split {
if err := PreCheckAssetPermission(c, split[i]); err != nil {
return err
}
if err := model.DeleteAssetById(split[i]); err != nil {
return err
}
// 删除资产与用户的关系
if err := model.DeleteResourceSharerByResourceId(split[i]); err != nil {
return err
}
}
return Success(c, nil)
}
func AssetGetEndpoint(c echo.Context) (err error) {
id := c.Param("id")
if err := PreCheckAssetPermission(c, id); err != nil {
return err
}
var item model.Asset
if item, err = model.FindAssetById(id); err != nil {
return err
}
attributeMap, err := model.FindAssetAttrMapByAssetId(id)
if err != nil {
return err
}
itemMap := utils.StructToMap(item)
for key := range attributeMap {
itemMap[key] = attributeMap[key]
}
return Success(c, itemMap)
}
func AssetTcpingEndpoint(c echo.Context) (err error) {
id := c.Param("id")
var item model.Asset
if item, err = model.FindAssetById(id); err != nil {
return err
}
active := utils.Tcping(item.IP, item.Port)
model.UpdateAssetActiveById(active, item.ID)
return Success(c, active)
}
func AssetTagsEndpoint(c echo.Context) (err error) {
var items []string
if items, err = model.FindAssetTags(); err != nil {
return err
}
return Success(c, items)
}
func AssetChangeOwnerEndpoint(c echo.Context) (err error) {
id := c.Param("id")
if err := PreCheckAssetPermission(c, id); err != nil {
return err
}
owner := c.QueryParam("owner")
model.UpdateAssetById(&model.Asset{Owner: owner}, id)
return Success(c, "")
}
func PreCheckAssetPermission(c echo.Context, id string) error {
item, err := model.FindAssetById(id)
if err != nil {
return err
}
if !HasPermission(c, item.Owner) {
return errors.New("permission denied")
}
return nil
}

123
server/api/command.go Normal file
View File

@ -0,0 +1,123 @@
package api
import (
"errors"
"strconv"
"strings"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)
func CommandCreateEndpoint(c echo.Context) error {
var item model.Command
if err := c.Bind(&item); err != nil {
return err
}
account, _ := GetCurrentAccount(c)
item.Owner = account.ID
item.ID = utils.UUID()
item.Created = utils.NowJsonTime()
if err := model.CreateNewCommand(&item); err != nil {
return err
}
return Success(c, item)
}
func CommandPagingEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))
name := c.QueryParam("name")
content := c.QueryParam("content")
account, _ := GetCurrentAccount(c)
order := c.QueryParam("order")
field := c.QueryParam("field")
items, total, err := model.FindPageCommand(pageIndex, pageSize, name, content, order, field, account)
if err != nil {
return err
}
return Success(c, H{
"total": total,
"items": items,
})
}
func CommandUpdateEndpoint(c echo.Context) error {
id := c.Param("id")
if err := PreCheckCommandPermission(c, id); err != nil {
return err
}
var item model.Command
if err := c.Bind(&item); err != nil {
return err
}
model.UpdateCommandById(&item, id)
return Success(c, nil)
}
func CommandDeleteEndpoint(c echo.Context) error {
id := c.Param("id")
split := strings.Split(id, ",")
for i := range split {
if err := PreCheckCommandPermission(c, split[i]); err != nil {
return err
}
if err := model.DeleteCommandById(split[i]); err != nil {
return err
}
// 删除资产与用户的关系
if err := model.DeleteResourceSharerByResourceId(split[i]); err != nil {
return err
}
}
return Success(c, nil)
}
func CommandGetEndpoint(c echo.Context) (err error) {
id := c.Param("id")
if err := PreCheckCommandPermission(c, id); err != nil {
return err
}
var item model.Command
if item, err = model.FindCommandById(id); err != nil {
return err
}
return Success(c, item)
}
func CommandChangeOwnerEndpoint(c echo.Context) (err error) {
id := c.Param("id")
if err := PreCheckCommandPermission(c, id); err != nil {
return err
}
owner := c.QueryParam("owner")
model.UpdateCommandById(&model.Command{Owner: owner}, id)
return Success(c, "")
}
func PreCheckCommandPermission(c echo.Context, id string) error {
item, err := model.FindCommandById(id)
if err != nil {
return err
}
if !HasPermission(c, item.Owner) {
return errors.New("permission denied")
}
return nil
}

184
server/api/credential.go Normal file
View File

@ -0,0 +1,184 @@
package api
import (
"errors"
"strconv"
"strings"
"next-terminal/server/constant"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)
func CredentialAllEndpoint(c echo.Context) error {
account, _ := GetCurrentAccount(c)
items, _ := model.FindAllCredential(account)
return Success(c, items)
}
func CredentialCreateEndpoint(c echo.Context) error {
var item model.Credential
if err := c.Bind(&item); err != nil {
return err
}
account, _ := GetCurrentAccount(c)
item.Owner = account.ID
item.ID = utils.UUID()
item.Created = utils.NowJsonTime()
switch item.Type {
case constant.Custom:
item.PrivateKey = "-"
item.Passphrase = "-"
if len(item.Username) == 0 {
item.Username = "-"
}
if len(item.Password) == 0 {
item.Password = "-"
}
case constant.PrivateKey:
item.Password = "-"
if len(item.Username) == 0 {
item.Username = "-"
}
if len(item.PrivateKey) == 0 {
item.PrivateKey = "-"
}
if len(item.Passphrase) == 0 {
item.Passphrase = "-"
}
default:
return Fail(c, -1, "类型错误")
}
if err := model.CreateNewCredential(&item); err != nil {
return err
}
return Success(c, item)
}
func CredentialPagingEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))
name := c.QueryParam("name")
order := c.QueryParam("order")
field := c.QueryParam("field")
account, _ := GetCurrentAccount(c)
items, total, err := model.FindPageCredential(pageIndex, pageSize, name, order, field, account)
if err != nil {
return err
}
return Success(c, H{
"total": total,
"items": items,
})
}
func CredentialUpdateEndpoint(c echo.Context) error {
id := c.Param("id")
if err := PreCheckCredentialPermission(c, id); err != nil {
return err
}
var item model.Credential
if err := c.Bind(&item); err != nil {
return err
}
switch item.Type {
case constant.Custom:
item.PrivateKey = "-"
item.Passphrase = "-"
if len(item.Username) == 0 {
item.Username = "-"
}
if len(item.Password) == 0 {
item.Password = "-"
}
case constant.PrivateKey:
item.Password = "-"
if len(item.Username) == 0 {
item.Username = "-"
}
if len(item.PrivateKey) == 0 {
item.PrivateKey = "-"
}
if len(item.Passphrase) == 0 {
item.Passphrase = "-"
}
default:
return Fail(c, -1, "类型错误")
}
model.UpdateCredentialById(&item, id)
return Success(c, nil)
}
func CredentialDeleteEndpoint(c echo.Context) error {
id := c.Param("id")
split := strings.Split(id, ",")
for i := range split {
if err := PreCheckCredentialPermission(c, split[i]); err != nil {
return err
}
if err := model.DeleteCredentialById(split[i]); err != nil {
return err
}
// 删除资产与用户的关系
if err := model.DeleteResourceSharerByResourceId(split[i]); err != nil {
return err
}
}
return Success(c, nil)
}
func CredentialGetEndpoint(c echo.Context) error {
id := c.Param("id")
if err := PreCheckCredentialPermission(c, id); err != nil {
return err
}
item, err := model.FindCredentialById(id)
if err != nil {
return err
}
if !HasPermission(c, item.Owner) {
return errors.New("permission denied")
}
return Success(c, item)
}
func CredentialChangeOwnerEndpoint(c echo.Context) error {
id := c.Param("id")
if err := PreCheckCredentialPermission(c, id); err != nil {
return err
}
owner := c.QueryParam("owner")
model.UpdateCredentialById(&model.Credential{Owner: owner}, id)
return Success(c, "")
}
func PreCheckCredentialPermission(c echo.Context, id string) error {
item, err := model.FindCredentialById(id)
if err != nil {
return err
}
if !HasPermission(c, item.Owner) {
return errors.New("permission denied")
}
return nil
}

122
server/api/job.go Normal file
View File

@ -0,0 +1,122 @@
package api
import (
"strconv"
"strings"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)
func JobCreateEndpoint(c echo.Context) error {
var item model.Job
if err := c.Bind(&item); err != nil {
return err
}
item.ID = utils.UUID()
item.Created = utils.NowJsonTime()
if err := model.CreateNewJob(&item); err != nil {
return err
}
return Success(c, "")
}
func JobPagingEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))
name := c.QueryParam("name")
status := c.QueryParam("status")
order := c.QueryParam("order")
field := c.QueryParam("field")
items, total, err := model.FindPageJob(pageIndex, pageSize, name, status, order, field)
if err != nil {
return err
}
return Success(c, H{
"total": total,
"items": items,
})
}
func JobUpdateEndpoint(c echo.Context) error {
id := c.Param("id")
var item model.Job
if err := c.Bind(&item); err != nil {
return err
}
if err := model.UpdateJobById(&item, id); err != nil {
return err
}
return Success(c, nil)
}
func JobChangeStatusEndpoint(c echo.Context) error {
id := c.Param("id")
status := c.QueryParam("status")
if err := model.ChangeJobStatusById(id, status); err != nil {
return err
}
return Success(c, "")
}
func JobExecEndpoint(c echo.Context) error {
id := c.Param("id")
if err := model.ExecJobById(id); err != nil {
return err
}
return Success(c, "")
}
func JobDeleteEndpoint(c echo.Context) error {
ids := c.Param("id")
split := strings.Split(ids, ",")
for i := range split {
jobId := split[i]
if err := model.DeleteJobById(jobId); err != nil {
return err
}
}
return Success(c, nil)
}
func JobGetEndpoint(c echo.Context) error {
id := c.Param("id")
item, err := model.FindJobById(id)
if err != nil {
return err
}
return Success(c, item)
}
func JobGetLogsEndpoint(c echo.Context) error {
id := c.Param("id")
items, err := model.FindJobLogs(id)
if err != nil {
return err
}
return Success(c, items)
}
func JobDeleteLogsEndpoint(c echo.Context) error {
id := c.Param("id")
if err := model.DeleteJobLogByJobId(id); err != nil {
return err
}
return Success(c, "")
}

47
server/api/login-log.go Normal file
View File

@ -0,0 +1,47 @@
package api
import (
"strconv"
"strings"
"next-terminal/server/global"
"next-terminal/server/model"
"github.com/labstack/echo/v4"
"github.com/sirupsen/logrus"
)
func LoginLogPagingEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))
userId := c.QueryParam("userId")
clientIp := c.QueryParam("clientIp")
items, total, err := model.FindPageLoginLog(pageIndex, pageSize, userId, clientIp)
if err != nil {
return err
}
return Success(c, H{
"total": total,
"items": items,
})
}
func LoginLogDeleteEndpoint(c echo.Context) error {
ids := c.Param("id")
split := strings.Split(ids, ",")
for i := range split {
token := split[i]
global.Cache.Delete(token)
if err := model.Logout(token); err != nil {
logrus.WithError(err).Error("Cache Delete Failed")
}
}
if err := model.DeleteLoginLogByIdIn(split); err != nil {
return err
}
return Success(c, nil)
}

149
server/api/middleware.go Normal file
View File

@ -0,0 +1,149 @@
package api
import (
"fmt"
"net"
"regexp"
"strings"
"time"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)
func ErrorHandler(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := next(c); err != nil {
if he, ok := err.(*echo.HTTPError); ok {
message := fmt.Sprintf("%v", he.Message)
return Fail(c, he.Code, message)
}
return Fail(c, 0, err.Error())
}
return nil
}
}
func TcpWall(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if global.Securities == nil {
return next(c)
}
ip := c.RealIP()
for i := 0; i < len(global.Securities); i++ {
security := global.Securities[i]
if strings.Contains(security.IP, "/") {
// CIDR
_, ipNet, err := net.ParseCIDR(security.IP)
if err != nil {
continue
}
if !ipNet.Contains(net.ParseIP(ip)) {
continue
}
} else if strings.Contains(security.IP, "-") {
// 范围段
split := strings.Split(security.IP, "-")
if len(split) < 2 {
continue
}
start := split[0]
end := split[1]
intReqIP := utils.IpToInt(ip)
if intReqIP < utils.IpToInt(start) || intReqIP > utils.IpToInt(end) {
continue
}
} else {
// IP
if security.IP != ip {
continue
}
}
if security.Rule == constant.AccessRuleAllow {
return next(c)
}
if security.Rule == constant.AccessRuleReject {
if c.Request().Header.Get("X-Requested-With") != "" || c.Request().Header.Get(Token) != "" {
return Fail(c, 0, "您的访问请求被拒绝 :(")
} else {
return c.HTML(666, "您的访问请求被拒绝 :(")
}
}
}
return next(c)
}
}
func Auth(next echo.HandlerFunc) echo.HandlerFunc {
startWithUrls := []string{"/login", "/static", "/favicon.ico", "/logo.svg", "/asciinema"}
download := regexp.MustCompile(`^/sessions/\w{8}(-\w{4}){3}-\w{12}/download`)
recording := regexp.MustCompile(`^/sessions/\w{8}(-\w{4}){3}-\w{12}/recording`)
return func(c echo.Context) error {
uri := c.Request().RequestURI
if uri == "/" || strings.HasPrefix(uri, "/#") {
return next(c)
}
// 路由拦截 - 登录身份、资源权限判断等
for i := range startWithUrls {
if strings.HasPrefix(uri, startWithUrls[i]) {
return next(c)
}
}
if download.FindString(uri) != "" {
return next(c)
}
if recording.FindString(uri) != "" {
return next(c)
}
token := GetToken(c)
cacheKey := BuildCacheKeyByToken(token)
authorization, found := global.Cache.Get(cacheKey)
if !found {
return Fail(c, 401, "您的登录信息已失效,请重新登录后再试。")
}
if authorization.(Authorization).Remember {
// 记住登录有效期两周
global.Cache.Set(cacheKey, authorization, time.Hour*time.Duration(24*14))
} else {
global.Cache.Set(cacheKey, authorization, time.Hour*time.Duration(2))
}
return next(c)
}
}
func Admin(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
account, found := GetCurrentAccount(c)
if !found {
return Fail(c, 401, "您的登录信息已失效,请重新登录后再试。")
}
if account.Type != constant.TypeAdmin {
return Fail(c, 403, "permission denied")
}
return next(c)
}
}

59
server/api/overview.go Normal file
View File

@ -0,0 +1,59 @@
package api
import (
"next-terminal/server/constant"
"next-terminal/server/model"
"github.com/labstack/echo/v4"
)
type Counter struct {
User int64 `json:"user"`
Asset int64 `json:"asset"`
Credential int64 `json:"credential"`
OnlineSession int64 `json:"onlineSession"`
}
func OverviewCounterEndPoint(c echo.Context) error {
account, _ := GetCurrentAccount(c)
var (
countUser int64
countOnlineSession int64
credential int64
asset int64
)
if constant.TypeUser == account.Type {
countUser, _ = userRepository.CountOnlineUser()
countOnlineSession, _ = model.CountOnlineSession()
credential, _ = model.CountCredentialByUserId(account.ID)
asset, _ = model.CountAssetByUserId(account.ID)
} else {
countUser, _ = userRepository.CountOnlineUser()
countOnlineSession, _ = model.CountOnlineSession()
credential, _ = model.CountCredential()
asset, _ = model.CountAsset()
}
counter := Counter{
User: countUser,
OnlineSession: countOnlineSession,
Credential: credential,
Asset: asset,
}
return Success(c, counter)
}
func OverviewSessionPoint(c echo.Context) (err error) {
d := c.QueryParam("d")
var results []model.D
if d == "m" {
results, err = model.CountSessionByDay(30)
} else {
results, err = model.CountSessionByDay(7)
}
if err != nil {
return err
}
return Success(c, results)
}

45
server/api/property.go Normal file
View File

@ -0,0 +1,45 @@
package api
import (
"errors"
"fmt"
"next-terminal/server/model"
"github.com/labstack/echo/v4"
"gorm.io/gorm"
)
func PropertyGetEndpoint(c echo.Context) error {
properties := model.FindAllPropertiesMap()
return Success(c, properties)
}
func PropertyUpdateEndpoint(c echo.Context) error {
var item map[string]interface{}
if err := c.Bind(&item); err != nil {
return err
}
for key := range item {
value := fmt.Sprintf("%v", item[key])
if value == "" {
value = "-"
}
property := model.Property{
Name: key,
Value: value,
}
_, err := model.FindPropertyByName(key)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
if err := model.CreateNewProperty(&property); err != nil {
return err
}
} else {
model.UpdatePropertyByName(&property, key)
}
}
return Success(c, nil)
}

View File

@ -0,0 +1,68 @@
package api
import (
"next-terminal/server/model"
"github.com/labstack/echo/v4"
)
type RU struct {
UserGroupId string `json:"userGroupId"`
UserId string `json:"userId"`
ResourceType string `json:"resourceType"`
ResourceIds []string `json:"resourceIds"`
}
type UR struct {
ResourceId string `json:"resourceId"`
ResourceType string `json:"resourceType"`
UserIds []string `json:"userIds"`
}
func RSGetSharersEndPoint(c echo.Context) error {
resourceId := c.QueryParam("resourceId")
userIds, err := model.FindUserIdsByResourceId(resourceId)
if err != nil {
return err
}
return Success(c, userIds)
}
func RSOverwriteSharersEndPoint(c echo.Context) error {
var ur UR
if err := c.Bind(&ur); err != nil {
return err
}
if err := model.OverwriteUserIdsByResourceId(ur.ResourceId, ur.ResourceType, ur.UserIds); err != nil {
return err
}
return Success(c, "")
}
func ResourceRemoveByUserIdAssignEndPoint(c echo.Context) error {
var ru RU
if err := c.Bind(&ru); err != nil {
return err
}
if err := model.DeleteByUserIdAndResourceTypeAndResourceIdIn(ru.UserGroupId, ru.UserId, ru.ResourceType, ru.ResourceIds); err != nil {
return err
}
return Success(c, "")
}
func ResourceAddByUserIdAssignEndPoint(c echo.Context) error {
var ru RU
if err := c.Bind(&ru); err != nil {
return err
}
if err := model.AddSharerResources(ru.UserGroupId, ru.UserId, ru.ResourceType, ru.ResourceIds); err != nil {
return err
}
return Success(c, "")
}

246
server/api/routes.go Normal file
View File

@ -0,0 +1,246 @@
package api
import (
"net/http"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
)
const Token = "X-Auth-Token"
var (
userRepository repository.UserRepository
)
func SetupRoutes(ur repository.UserRepository) *echo.Echo {
userRepository = ur
e := echo.New()
e.HideBanner = true
e.Logger = log.GetEchoLogger()
e.File("/", "web/build/index.html")
e.File("/asciinema.html", "web/build/asciinema.html")
e.File("/asciinema-player.js", "web/build/asciinema-player.js")
e.File("/asciinema-player.css", "web/build/asciinema-player.css")
e.File("/", "web/build/index.html")
e.File("/logo.svg", "web/build/logo.svg")
e.File("/favicon.ico", "web/build/favicon.ico")
e.Static("/static", "web/build/static")
e.Use(middleware.Recover())
e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
Skipper: middleware.DefaultSkipper,
AllowOrigins: []string{"*"},
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
}))
e.Use(ErrorHandler)
e.Use(TcpWall)
e.Use(Auth)
e.POST("/login", LoginEndpoint)
e.POST("/loginWithTotp", loginWithTotpEndpoint)
e.GET("/tunnel", TunEndpoint)
e.GET("/ssh", SSHEndpoint)
e.POST("/logout", LogoutEndpoint)
e.POST("/change-password", ChangePasswordEndpoint)
e.GET("/reload-totp", ReloadTOTPEndpoint)
e.POST("/reset-totp", ResetTOTPEndpoint)
e.POST("/confirm-totp", ConfirmTOTPEndpoint)
e.GET("/info", InfoEndpoint)
users := e.Group("/users")
{
users.POST("", Admin(UserCreateEndpoint))
users.GET("/paging", UserPagingEndpoint)
users.PUT("/:id", Admin(UserUpdateEndpoint))
users.DELETE("/:id", Admin(UserDeleteEndpoint))
users.GET("/:id", Admin(UserGetEndpoint))
users.POST("/:id/change-password", Admin(UserChangePasswordEndpoint))
users.POST("/:id/reset-totp", Admin(UserResetTotpEndpoint))
}
userGroups := e.Group("/user-groups", Admin)
{
userGroups.POST("", UserGroupCreateEndpoint)
userGroups.GET("/paging", UserGroupPagingEndpoint)
userGroups.PUT("/:id", UserGroupUpdateEndpoint)
userGroups.DELETE("/:id", UserGroupDeleteEndpoint)
userGroups.GET("/:id", UserGroupGetEndpoint)
//userGroups.POST("/:id/members", UserGroupAddMembersEndpoint)
//userGroups.DELETE("/:id/members/:memberId", UserGroupDelMembersEndpoint)
}
assets := e.Group("/assets")
{
assets.GET("", AssetAllEndpoint)
assets.POST("", AssetCreateEndpoint)
assets.POST("/import", Admin(AssetImportEndpoint))
assets.GET("/paging", AssetPagingEndpoint)
assets.POST("/:id/tcping", AssetTcpingEndpoint)
assets.PUT("/:id", AssetUpdateEndpoint)
assets.DELETE("/:id", AssetDeleteEndpoint)
assets.GET("/:id", AssetGetEndpoint)
assets.GET("/:id/attributes", AssetGetAttributeEndpoint)
assets.POST("/:id/change-owner", Admin(AssetChangeOwnerEndpoint))
}
e.GET("/tags", AssetTagsEndpoint)
commands := e.Group("/commands")
{
commands.GET("/paging", CommandPagingEndpoint)
commands.POST("", CommandCreateEndpoint)
commands.PUT("/:id", CommandUpdateEndpoint)
commands.DELETE("/:id", CommandDeleteEndpoint)
commands.GET("/:id", CommandGetEndpoint)
commands.POST("/:id/change-owner", Admin(CommandChangeOwnerEndpoint))
}
credentials := e.Group("/credentials")
{
credentials.GET("", CredentialAllEndpoint)
credentials.GET("/paging", CredentialPagingEndpoint)
credentials.POST("", CredentialCreateEndpoint)
credentials.PUT("/:id", CredentialUpdateEndpoint)
credentials.DELETE("/:id", CredentialDeleteEndpoint)
credentials.GET("/:id", CredentialGetEndpoint)
credentials.POST("/:id/change-owner", Admin(CredentialChangeOwnerEndpoint))
}
sessions := e.Group("/sessions")
{
sessions.POST("", SessionCreateEndpoint)
sessions.GET("/paging", Admin(SessionPagingEndpoint))
sessions.POST("/:id/connect", SessionConnectEndpoint)
sessions.POST("/:id/disconnect", Admin(SessionDisconnectEndpoint))
sessions.POST("/:id/resize", SessionResizeEndpoint)
sessions.GET("/:id/ls", SessionLsEndpoint)
sessions.GET("/:id/download", SessionDownloadEndpoint)
sessions.POST("/:id/upload", SessionUploadEndpoint)
sessions.POST("/:id/mkdir", SessionMkDirEndpoint)
sessions.POST("/:id/rm", SessionRmEndpoint)
sessions.POST("/:id/rename", SessionRenameEndpoint)
sessions.DELETE("/:id", Admin(SessionDeleteEndpoint))
sessions.GET("/:id/recording", SessionRecordingEndpoint)
}
resourceSharers := e.Group("/resource-sharers")
{
resourceSharers.GET("/sharers", RSGetSharersEndPoint)
resourceSharers.POST("/overwrite-sharers", RSOverwriteSharersEndPoint)
resourceSharers.POST("/remove-resources", Admin(ResourceRemoveByUserIdAssignEndPoint))
resourceSharers.POST("/add-resources", Admin(ResourceAddByUserIdAssignEndPoint))
}
loginLogs := e.Group("login-logs", Admin)
{
loginLogs.GET("/paging", LoginLogPagingEndpoint)
loginLogs.DELETE("/:id", LoginLogDeleteEndpoint)
}
e.GET("/properties", Admin(PropertyGetEndpoint))
e.PUT("/properties", Admin(PropertyUpdateEndpoint))
e.GET("/overview/counter", OverviewCounterEndPoint)
e.GET("/overview/sessions", OverviewSessionPoint)
jobs := e.Group("/jobs", Admin)
{
jobs.POST("", JobCreateEndpoint)
jobs.GET("/paging", JobPagingEndpoint)
jobs.PUT("/:id", JobUpdateEndpoint)
jobs.POST("/:id/change-status", JobChangeStatusEndpoint)
jobs.POST("/:id/exec", JobExecEndpoint)
jobs.DELETE("/:id", JobDeleteEndpoint)
jobs.GET("/:id", JobGetEndpoint)
jobs.GET("/:id/logs", JobGetLogsEndpoint)
jobs.DELETE("/:id/logs", JobDeleteLogsEndpoint)
}
securities := e.Group("/securities", Admin)
{
securities.POST("", SecurityCreateEndpoint)
securities.GET("/paging", SecurityPagingEndpoint)
securities.PUT("/:id", SecurityUpdateEndpoint)
securities.DELETE("/:id", SecurityDeleteEndpoint)
securities.GET("/:id", SecurityGetEndpoint)
}
return e
}
type H map[string]interface{}
func Fail(c echo.Context, code int, message string) error {
return c.JSON(200, H{
"code": code,
"message": message,
})
}
func FailWithData(c echo.Context, code int, message string, data interface{}) error {
return c.JSON(200, H{
"code": code,
"message": message,
"data": data,
})
}
func Success(c echo.Context, data interface{}) error {
return c.JSON(200, H{
"code": 1,
"message": "success",
"data": data,
})
}
func NotFound(c echo.Context, message string) error {
return c.JSON(200, H{
"code": -1,
"message": message,
})
}
func GetToken(c echo.Context) string {
token := c.Request().Header.Get(Token)
if len(token) > 0 {
return token
}
return c.QueryParam(Token)
}
func GetCurrentAccount(c echo.Context) (model.User, bool) {
token := GetToken(c)
cacheKey := BuildCacheKeyByToken(token)
get, b := global.Cache.Get(cacheKey)
if b {
return get.(Authorization).User, true
}
return model.User{}, false
}
func HasPermission(c echo.Context, owner string) bool {
// 检测是否登录
account, found := GetCurrentAccount(c)
if !found {
return false
}
// 检测是否为管理人员
if constant.TypeAdmin == account.Type {
return true
}
// 检测是否为所有者
if owner == account.ID {
return true
}
return false
}

116
server/api/security.go Normal file
View File

@ -0,0 +1,116 @@
package api
import (
"strconv"
"strings"
"next-terminal/server/global"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)
func SecurityCreateEndpoint(c echo.Context) error {
var item model.AccessSecurity
if err := c.Bind(&item); err != nil {
return err
}
item.ID = utils.UUID()
item.Source = "管理员添加"
if err := model.CreateNewSecurity(&item); err != nil {
return err
}
// 更新内存中的安全规则
if err := ReloadAccessSecurity(); err != nil {
return err
}
return Success(c, "")
}
func ReloadAccessSecurity() error {
rules, err := model.FindAllAccessSecurities()
if err != nil {
return err
}
if len(rules) > 0 {
var securities []*global.Security
for i := 0; i < len(rules); i++ {
rule := global.Security{
IP: rules[i].IP,
Rule: rules[i].Rule,
}
securities = append(securities, &rule)
}
global.Securities = securities
}
return nil
}
func SecurityPagingEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))
ip := c.QueryParam("ip")
rule := c.QueryParam("rule")
order := c.QueryParam("order")
field := c.QueryParam("field")
items, total, err := model.FindPageSecurity(pageIndex, pageSize, ip, rule, order, field)
if err != nil {
return err
}
return Success(c, H{
"total": total,
"items": items,
})
}
func SecurityUpdateEndpoint(c echo.Context) error {
id := c.Param("id")
var item model.AccessSecurity
if err := c.Bind(&item); err != nil {
return err
}
if err := model.UpdateSecurityById(&item, id); err != nil {
return err
}
// 更新内存中的安全规则
if err := ReloadAccessSecurity(); err != nil {
return err
}
return Success(c, nil)
}
func SecurityDeleteEndpoint(c echo.Context) error {
ids := c.Param("id")
split := strings.Split(ids, ",")
for i := range split {
jobId := split[i]
if err := model.DeleteSecurityById(jobId); err != nil {
return err
}
}
// 更新内存中的安全规则
if err := ReloadAccessSecurity(); err != nil {
return err
}
return Success(c, nil)
}
func SecurityGetEndpoint(c echo.Context) error {
id := c.Param("id")
item, err := model.FindSecurityById(id)
if err != nil {
return err
}
return Success(c, item)
}

617
server/api/session.go Normal file
View File

@ -0,0 +1,617 @@
package api
import (
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path"
"strconv"
"strings"
"sync"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
"github.com/pkg/sftp"
"github.com/sirupsen/logrus"
)
func SessionPagingEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))
status := c.QueryParam("status")
userId := c.QueryParam("userId")
clientIp := c.QueryParam("clientIp")
assetId := c.QueryParam("assetId")
protocol := c.QueryParam("protocol")
items, total, err := model.FindPageSession(pageIndex, pageSize, status, userId, clientIp, assetId, protocol)
if err != nil {
return err
}
for i := 0; i < len(items); i++ {
if status == constant.Disconnected && len(items[i].Recording) > 0 {
var recording string
if items[i].Mode == constant.Naive {
recording = items[i].Recording
} else {
recording = items[i].Recording + "/recording"
}
if utils.FileExists(recording) {
items[i].Recording = "1"
} else {
items[i].Recording = "0"
}
} else {
items[i].Recording = "0"
}
}
return Success(c, H{
"total": total,
"items": items,
})
}
func SessionDeleteEndpoint(c echo.Context) error {
sessionIds := c.Param("id")
split := strings.Split(sessionIds, ",")
err := model.DeleteSessionByIds(split)
if err != nil {
return err
}
return Success(c, nil)
}
func SessionConnectEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session := model.Session{}
session.ID = sessionId
session.Status = constant.Connected
session.ConnectedTime = utils.NowJsonTime()
if err := model.UpdateSessionById(&session, sessionId); err != nil {
return err
}
return Success(c, nil)
}
func SessionDisconnectEndpoint(c echo.Context) error {
sessionIds := c.Param("id")
split := strings.Split(sessionIds, ",")
for i := range split {
CloseSessionById(split[i], ForcedDisconnect, "管理员强制关闭了此会话")
}
return Success(c, nil)
}
var mutex sync.Mutex
func CloseSessionById(sessionId string, code int, reason string) {
mutex.Lock()
defer mutex.Unlock()
observable, _ := global.Store.Get(sessionId)
if observable != nil {
logrus.Debugf("会话%v创建者退出原因%v", sessionId, reason)
observable.Subject.Close(code, reason)
for i := 0; i < len(observable.Observers); i++ {
observable.Observers[i].Close(code, reason)
logrus.Debugf("强制踢出会话%v的观察者", sessionId)
}
}
global.Store.Del(sessionId)
s, err := model.FindSessionById(sessionId)
if err != nil {
return
}
if s.Status == constant.Disconnected {
return
}
if s.Status == constant.Connecting {
// 会话还未建立成功,无需保留数据
_ = model.DeleteSessionById(sessionId)
return
}
session := model.Session{}
session.ID = sessionId
session.Status = constant.Disconnected
session.DisconnectedTime = utils.NowJsonTime()
session.Code = code
session.Message = reason
_ = model.UpdateSessionById(&session, sessionId)
}
func SessionResizeEndpoint(c echo.Context) error {
width := c.QueryParam("width")
height := c.QueryParam("height")
sessionId := c.Param("id")
if len(width) == 0 || len(height) == 0 {
panic("参数异常")
}
intWidth, _ := strconv.Atoi(width)
intHeight, _ := strconv.Atoi(height)
if err := model.UpdateSessionWindowSizeById(intWidth, intHeight, sessionId); err != nil {
return err
}
return Success(c, "")
}
func SessionCreateEndpoint(c echo.Context) error {
assetId := c.QueryParam("assetId")
mode := c.QueryParam("mode")
if mode == constant.Naive {
mode = constant.Naive
} else {
mode = constant.Guacd
}
user, _ := GetCurrentAccount(c)
if constant.TypeUser == user.Type {
// 检测是否有访问权限
assetIds, err := model.FindAssetIdsByUserId(user.ID)
if err != nil {
return err
}
if !utils.Contains(assetIds, assetId) {
return errors.New("您没有权限访问此资产")
}
}
asset, err := model.FindAssetById(assetId)
if err != nil {
return err
}
session := &model.Session{
ID: utils.UUID(),
AssetId: asset.ID,
Username: asset.Username,
Password: asset.Password,
PrivateKey: asset.PrivateKey,
Passphrase: asset.Passphrase,
Protocol: asset.Protocol,
IP: asset.IP,
Port: asset.Port,
Status: constant.NoConnect,
Creator: user.ID,
ClientIP: c.RealIP(),
Mode: mode,
}
if asset.AccountType == "credential" {
credential, err := model.FindCredentialById(asset.CredentialId)
if err != nil {
return err
}
if credential.Type == constant.Custom {
session.Username = credential.Username
session.Password = credential.Password
} else {
session.Username = credential.Username
session.PrivateKey = credential.PrivateKey
session.Passphrase = credential.Passphrase
}
}
if err := model.CreateNewSession(session); err != nil {
return err
}
return Success(c, echo.Map{"id": session.ID})
}
func SessionUploadEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session, err := model.FindSessionById(sessionId)
if err != nil {
return err
}
file, err := c.FormFile("file")
if err != nil {
return err
}
filename := file.Filename
src, err := file.Open()
if err != nil {
return err
}
remoteDir := c.QueryParam("dir")
remoteFile := path.Join(remoteDir, filename)
if "ssh" == session.Protocol {
tun, ok := global.Store.Get(sessionId)
if !ok {
return errors.New("获取sftp客户端失败")
}
dstFile, err := tun.Subject.NextTerminal.SftpClient.Create(remoteFile)
if err != nil {
return err
}
defer dstFile.Close()
buf := make([]byte, 1024)
for {
n, err := src.Read(buf)
if err != nil {
if err != io.EOF {
logrus.Warnf("文件上传错误 %v", err)
} else {
break
}
}
_, _ = dstFile.Write(buf[:n])
}
return Success(c, nil)
} else if "rdp" == session.Protocol {
if strings.Contains(remoteFile, "../") {
SafetyRuleTrigger(c)
return Fail(c, -1, ":) 您的IP已被记录请去向管理员自首。")
}
drivePath, err := model.GetDrivePath()
if err != nil {
return err
}
// Destination
dst, err := os.Create(path.Join(drivePath, remoteFile))
if err != nil {
return err
}
defer dst.Close()
// Copy
if _, err = io.Copy(dst, src); err != nil {
return err
}
return Success(c, nil)
}
return err
}
func SessionDownloadEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session, err := model.FindSessionById(sessionId)
if err != nil {
return err
}
//remoteDir := c.Query("dir")
remoteFile := c.QueryParam("file")
// 获取带后缀的文件名称
filenameWithSuffix := path.Base(remoteFile)
if "ssh" == session.Protocol {
tun, ok := global.Store.Get(sessionId)
if !ok {
return errors.New("获取sftp客户端失败")
}
dstFile, err := tun.Subject.NextTerminal.SftpClient.Open(remoteFile)
if err != nil {
return err
}
defer dstFile.Close()
c.Response().Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filenameWithSuffix))
var buff bytes.Buffer
if _, err := dstFile.WriteTo(&buff); err != nil {
return err
}
return c.Stream(http.StatusOK, echo.MIMEOctetStream, bytes.NewReader(buff.Bytes()))
} else if "rdp" == session.Protocol {
if strings.Contains(remoteFile, "../") {
SafetyRuleTrigger(c)
return Fail(c, -1, ":) 您的IP已被记录请去向管理员自首。")
}
drivePath, err := model.GetDrivePath()
if err != nil {
return err
}
return c.Attachment(path.Join(drivePath, remoteFile), filenameWithSuffix)
}
return err
}
type File struct {
Name string `json:"name"`
Path string `json:"path"`
IsDir bool `json:"isDir"`
Mode string `json:"mode"`
IsLink bool `json:"isLink"`
ModTime utils.JsonTime `json:"modTime"`
Size int64 `json:"size"`
}
func SessionLsEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session, err := model.FindSessionById(sessionId)
if err != nil {
return err
}
remoteDir := c.QueryParam("dir")
if "ssh" == session.Protocol {
tun, ok := global.Store.Get(sessionId)
if !ok {
return errors.New("获取sftp客户端失败")
}
if tun.Subject.NextTerminal == nil {
nextTerminal, err := CreateNextTerminalBySession(session)
if err != nil {
return err
}
tun.Subject.NextTerminal = nextTerminal
}
if tun.Subject.NextTerminal.SftpClient == nil {
sftpClient, err := sftp.NewClient(tun.Subject.NextTerminal.SshClient)
if err != nil {
logrus.Errorf("创建sftp客户端失败%v", err.Error())
return err
}
tun.Subject.NextTerminal.SftpClient = sftpClient
}
fileInfos, err := tun.Subject.NextTerminal.SftpClient.ReadDir(remoteDir)
if err != nil {
return err
}
var files = make([]File, 0)
for i := range fileInfos {
// 忽略因此文件
if strings.HasPrefix(fileInfos[i].Name(), ".") {
continue
}
file := File{
Name: fileInfos[i].Name(),
Path: path.Join(remoteDir, fileInfos[i].Name()),
IsDir: fileInfos[i].IsDir(),
Mode: fileInfos[i].Mode().String(),
IsLink: fileInfos[i].Mode()&os.ModeSymlink == os.ModeSymlink,
ModTime: utils.NewJsonTime(fileInfos[i].ModTime()),
Size: fileInfos[i].Size(),
}
files = append(files, file)
}
return Success(c, files)
} else if "rdp" == session.Protocol {
if strings.Contains(remoteDir, "../") {
SafetyRuleTrigger(c)
return Fail(c, -1, ":) 您的IP已被记录请去向管理员自首。")
}
drivePath, err := model.GetDrivePath()
if err != nil {
return err
}
fileInfos, err := ioutil.ReadDir(path.Join(drivePath, remoteDir))
if err != nil {
return err
}
var files = make([]File, 0)
for i := range fileInfos {
file := File{
Name: fileInfos[i].Name(),
Path: path.Join(remoteDir, fileInfos[i].Name()),
IsDir: fileInfos[i].IsDir(),
Mode: fileInfos[i].Mode().String(),
IsLink: fileInfos[i].Mode()&os.ModeSymlink == os.ModeSymlink,
ModTime: utils.NewJsonTime(fileInfos[i].ModTime()),
Size: fileInfos[i].Size(),
}
files = append(files, file)
}
return Success(c, files)
}
return errors.New("当前协议不支持此操作")
}
func SafetyRuleTrigger(c echo.Context) {
logrus.Warnf("IP %v 尝试进行攻击请ban掉此IP", c.RealIP())
security := model.AccessSecurity{
ID: utils.UUID(),
Source: "安全规则触发",
IP: c.RealIP(),
Rule: constant.AccessRuleReject,
}
_ = model.CreateNewSecurity(&security)
}
func SessionMkDirEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session, err := model.FindSessionById(sessionId)
if err != nil {
return err
}
remoteDir := c.QueryParam("dir")
if "ssh" == session.Protocol {
tun, ok := global.Store.Get(sessionId)
if !ok {
return errors.New("获取sftp客户端失败")
}
if err := tun.Subject.NextTerminal.SftpClient.Mkdir(remoteDir); err != nil {
return err
}
return Success(c, nil)
} else if "rdp" == session.Protocol {
if strings.Contains(remoteDir, "../") {
SafetyRuleTrigger(c)
return Fail(c, -1, ":) 您的IP已被记录请去向管理员自首。")
}
drivePath, err := model.GetDrivePath()
if err != nil {
return err
}
if err := os.MkdirAll(path.Join(drivePath, remoteDir), os.ModePerm); err != nil {
return err
}
return Success(c, nil)
}
return errors.New("当前协议不支持此操作")
}
func SessionRmEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session, err := model.FindSessionById(sessionId)
if err != nil {
return err
}
key := c.QueryParam("key")
if "ssh" == session.Protocol {
tun, ok := global.Store.Get(sessionId)
if !ok {
return errors.New("获取sftp客户端失败")
}
sftpClient := tun.Subject.NextTerminal.SftpClient
stat, err := sftpClient.Stat(key)
if err != nil {
return err
}
if stat.IsDir() {
fileInfos, err := sftpClient.ReadDir(key)
if err != nil {
return err
}
for i := range fileInfos {
if err := sftpClient.Remove(path.Join(key, fileInfos[i].Name())); err != nil {
return err
}
}
if err := sftpClient.RemoveDirectory(key); err != nil {
return err
}
} else {
if err := sftpClient.Remove(key); err != nil {
return err
}
}
return Success(c, nil)
} else if "rdp" == session.Protocol {
if strings.Contains(key, "../") {
SafetyRuleTrigger(c)
return Fail(c, -1, ":) 您的IP已被记录请去向管理员自首。")
}
drivePath, err := model.GetDrivePath()
if err != nil {
return err
}
if err := os.RemoveAll(path.Join(drivePath, key)); err != nil {
return err
}
return Success(c, nil)
}
return errors.New("当前协议不支持此操作")
}
func SessionRenameEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session, err := model.FindSessionById(sessionId)
if err != nil {
return err
}
oldName := c.QueryParam("oldName")
newName := c.QueryParam("newName")
if "ssh" == session.Protocol {
tun, ok := global.Store.Get(sessionId)
if !ok {
return errors.New("获取sftp客户端失败")
}
sftpClient := tun.Subject.NextTerminal.SftpClient
if err := sftpClient.Rename(oldName, newName); err != nil {
return err
}
return Success(c, nil)
} else if "rdp" == session.Protocol {
if strings.Contains(oldName, "../") {
SafetyRuleTrigger(c)
return Fail(c, -1, ":) 您的IP已被记录请去向管理员自首。")
}
drivePath, err := model.GetDrivePath()
if err != nil {
return err
}
if err := os.Rename(path.Join(drivePath, oldName), path.Join(drivePath, newName)); err != nil {
return err
}
return Success(c, nil)
}
return errors.New("当前协议不支持此操作")
}
func SessionRecordingEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session, err := model.FindSessionById(sessionId)
if err != nil {
return err
}
var recording string
if session.Mode == constant.Naive {
recording = session.Recording
} else {
recording = session.Recording + "/recording"
}
logrus.Debugf("读取录屏文件:%v,是否存在: %v, 是否为文件: %v", recording, utils.FileExists(recording), utils.IsFile(recording))
return c.File(recording)
}

264
server/api/ssh.go Normal file
View File

@ -0,0 +1,264 @@
package api
import (
"encoding/json"
"net/http"
"path"
"strconv"
"time"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/guacd"
"next-terminal/server/model"
"next-terminal/server/term"
"next-terminal/server/utils"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/sirupsen/logrus"
)
var UpGrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
Subprotocols: []string{"guacamole"},
}
const (
Connected = "connected"
Data = "data"
Resize = "resize"
Closed = "closed"
)
type Message struct {
Type string `json:"type"`
Content string `json:"content"`
}
type WindowSize struct {
Cols int `json:"cols"`
Rows int `json:"rows"`
}
func SSHEndpoint(c echo.Context) (err error) {
ws, err := UpGrader.Upgrade(c.Response().Writer, c.Request(), nil)
if err != nil {
logrus.Errorf("升级为WebSocket协议失败%v", err.Error())
return err
}
sessionId := c.QueryParam("sessionId")
cols, _ := strconv.Atoi(c.QueryParam("cols"))
rows, _ := strconv.Atoi(c.QueryParam("rows"))
session, err := model.FindSessionById(sessionId)
if err != nil {
msg := Message{
Type: Closed,
Content: "get sshSession error." + err.Error(),
}
_ = WriteMessage(ws, msg)
return err
}
user, _ := GetCurrentAccount(c)
if constant.TypeUser == user.Type {
// 检测是否有访问权限
assetIds, err := model.FindAssetIdsByUserId(user.ID)
if err != nil {
return err
}
if !utils.Contains(assetIds, session.AssetId) {
msg := Message{
Type: Closed,
Content: "您没有权限访问此资产",
}
return WriteMessage(ws, msg)
}
}
var (
username = session.Username
password = session.Password
privateKey = session.PrivateKey
passphrase = session.Passphrase
ip = session.IP
port = session.Port
)
recording := ""
propertyMap := model.FindAllPropertiesMap()
if propertyMap[guacd.EnableRecording] == "true" {
recording = path.Join(propertyMap[guacd.RecordingPath], sessionId, "recording.cast")
}
tun := global.Tun{
Protocol: session.Protocol,
Mode: session.Mode,
WebSocket: ws,
}
if session.ConnectionId != "" {
// 监控会话
observable, ok := global.Store.Get(sessionId)
if ok {
observers := append(observable.Observers, tun)
observable.Observers = observers
global.Store.Set(sessionId, observable)
logrus.Debugf("加入会话%v,当前观察者数量为:%v", session.ConnectionId, len(observers))
}
return err
}
nextTerminal, err := term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, rows, cols, recording)
if err != nil {
logrus.Errorf("创建SSH客户端失败%v", err.Error())
msg := Message{
Type: Closed,
Content: err.Error(),
}
err := WriteMessage(ws, msg)
return err
}
tun.NextTerminal = nextTerminal
var observers []global.Tun
observable := global.Observable{
Subject: &tun,
Observers: observers,
}
global.Store.Set(sessionId, &observable)
sess := model.Session{
ConnectionId: sessionId,
Width: cols,
Height: rows,
Status: constant.Connecting,
Recording: recording,
}
// 创建新会话
logrus.Debugf("创建新会话 %v", sess.ConnectionId)
if err := model.UpdateSessionById(&sess, sessionId); err != nil {
return err
}
msg := Message{
Type: Connected,
Content: "",
}
_ = WriteMessage(ws, msg)
quitChan := make(chan bool)
go ReadMessage(nextTerminal, quitChan, ws)
for {
_, message, err := ws.ReadMessage()
if err != nil {
// web socket会话关闭后主动关闭ssh会话
CloseSessionById(sessionId, Normal, "正常退出")
quitChan <- true
quitChan <- true
break
}
var msg Message
err = json.Unmarshal(message, &msg)
if err != nil {
logrus.Warnf("解析Json失败: %v, 原始字符串:%v", err, string(message))
continue
}
switch msg.Type {
case Resize:
var winSize WindowSize
err = json.Unmarshal([]byte(msg.Content), &winSize)
if err != nil {
logrus.Warnf("解析SSH会话窗口大小失败: %v", err)
continue
}
if err := nextTerminal.WindowChange(winSize.Rows, winSize.Cols); err != nil {
logrus.Warnf("更改SSH会话窗口大小失败: %v", err)
continue
}
case Data:
_, err = nextTerminal.Write([]byte(msg.Content))
if err != nil {
logrus.Debugf("SSH会话写入失败: %v", err)
msg := Message{
Type: Closed,
Content: "the remote connection is closed.",
}
_ = WriteMessage(ws, msg)
}
}
}
return err
}
func ReadMessage(nextTerminal *term.NextTerminal, quitChan chan bool, ws *websocket.Conn) {
var quit bool
for {
select {
case quit = <-quitChan:
if quit {
return
}
default:
p, n, err := nextTerminal.Read()
if err != nil {
msg := Message{
Type: Closed,
Content: err.Error(),
}
_ = WriteMessage(ws, msg)
}
if n > 0 {
s := string(p)
msg := Message{
Type: Data,
Content: s,
}
_ = WriteMessage(ws, msg)
}
time.Sleep(time.Duration(10) * time.Millisecond)
}
}
}
func WriteMessage(ws *websocket.Conn, msg Message) error {
message, err := json.Marshal(msg)
if err != nil {
return err
}
WriteByteMessage(ws, message)
return err
}
func WriteByteMessage(ws *websocket.Conn, p []byte) {
err := ws.WriteMessage(websocket.TextMessage, p)
if err != nil {
logrus.Debugf("write: %v", err)
}
}
func CreateNextTerminalBySession(session model.Session) (*term.NextTerminal, error) {
var (
username = session.Username
password = session.Password
privateKey = session.PrivateKey
passphrase = session.Passphrase
ip = session.IP
port = session.Port
)
return term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, 10, 10, "")
}

251
server/api/tunnel.go Normal file
View File

@ -0,0 +1,251 @@
package api
import (
"errors"
"path"
"strconv"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/guacd"
"next-terminal/server/model"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/sirupsen/logrus"
)
const (
TunnelClosed int = -1
Normal int = 0
NotFoundSession int = 800
NewTunnelError int = 801
ForcedDisconnect int = 802
)
func TunEndpoint(c echo.Context) error {
ws, err := UpGrader.Upgrade(c.Response().Writer, c.Request(), nil)
if err != nil {
logrus.Errorf("升级为WebSocket协议失败%v", err.Error())
return err
}
width := c.QueryParam("width")
height := c.QueryParam("height")
dpi := c.QueryParam("dpi")
sessionId := c.QueryParam("sessionId")
connectionId := c.QueryParam("connectionId")
intWidth, _ := strconv.Atoi(width)
intHeight, _ := strconv.Atoi(height)
configuration := guacd.NewConfiguration()
propertyMap := model.FindAllPropertiesMap()
var session model.Session
if len(connectionId) > 0 {
session, err = model.FindSessionByConnectionId(connectionId)
if err != nil {
logrus.Warnf("会话不存在")
return err
}
if session.Status != constant.Connected {
logrus.Warnf("会话未在线")
return errors.New("会话未在线")
}
configuration.ConnectionID = connectionId
sessionId = session.ID
configuration.SetParameter("width", strconv.Itoa(session.Width))
configuration.SetParameter("height", strconv.Itoa(session.Height))
configuration.SetParameter("dpi", "96")
} else {
configuration.SetParameter("width", width)
configuration.SetParameter("height", height)
configuration.SetParameter("dpi", dpi)
session, err = model.FindSessionById(sessionId)
if err != nil {
CloseSessionById(sessionId, NotFoundSession, "会话不存在")
return err
}
if propertyMap[guacd.EnableRecording] == "true" {
configuration.SetParameter(guacd.RecordingPath, path.Join(propertyMap[guacd.RecordingPath], sessionId))
configuration.SetParameter(guacd.CreateRecordingPath, propertyMap[guacd.CreateRecordingPath])
} else {
configuration.SetParameter(guacd.RecordingPath, "")
}
configuration.Protocol = session.Protocol
switch configuration.Protocol {
case "rdp":
configuration.SetParameter("username", session.Username)
configuration.SetParameter("password", session.Password)
configuration.SetParameter("security", "any")
configuration.SetParameter("ignore-cert", "true")
configuration.SetParameter("create-drive-path", "true")
configuration.SetParameter("resize-method", "reconnect")
configuration.SetParameter(guacd.EnableDrive, propertyMap[guacd.EnableDrive])
configuration.SetParameter(guacd.DriveName, propertyMap[guacd.DriveName])
configuration.SetParameter(guacd.DrivePath, propertyMap[guacd.DrivePath])
configuration.SetParameter(guacd.EnableWallpaper, propertyMap[guacd.EnableWallpaper])
configuration.SetParameter(guacd.EnableTheming, propertyMap[guacd.EnableTheming])
configuration.SetParameter(guacd.EnableFontSmoothing, propertyMap[guacd.EnableFontSmoothing])
configuration.SetParameter(guacd.EnableFullWindowDrag, propertyMap[guacd.EnableFullWindowDrag])
configuration.SetParameter(guacd.EnableDesktopComposition, propertyMap[guacd.EnableDesktopComposition])
configuration.SetParameter(guacd.EnableMenuAnimations, propertyMap[guacd.EnableMenuAnimations])
configuration.SetParameter(guacd.DisableBitmapCaching, propertyMap[guacd.DisableBitmapCaching])
configuration.SetParameter(guacd.DisableOffscreenCaching, propertyMap[guacd.DisableOffscreenCaching])
configuration.SetParameter(guacd.DisableGlyphCaching, propertyMap[guacd.DisableGlyphCaching])
case "ssh":
if len(session.PrivateKey) > 0 && session.PrivateKey != "-" {
configuration.SetParameter("username", session.Username)
configuration.SetParameter("private-key", session.PrivateKey)
configuration.SetParameter("passphrase", session.Passphrase)
} else {
configuration.SetParameter("username", session.Username)
configuration.SetParameter("password", session.Password)
}
configuration.SetParameter(guacd.FontSize, propertyMap[guacd.FontSize])
configuration.SetParameter(guacd.FontName, propertyMap[guacd.FontName])
configuration.SetParameter(guacd.ColorScheme, propertyMap[guacd.ColorScheme])
configuration.SetParameter(guacd.Backspace, propertyMap[guacd.Backspace])
configuration.SetParameter(guacd.TerminalType, propertyMap[guacd.TerminalType])
case "vnc":
configuration.SetParameter("username", session.Username)
configuration.SetParameter("password", session.Password)
case "telnet":
configuration.SetParameter("username", session.Username)
configuration.SetParameter("password", session.Password)
configuration.SetParameter(guacd.FontSize, propertyMap[guacd.FontSize])
configuration.SetParameter(guacd.FontName, propertyMap[guacd.FontName])
configuration.SetParameter(guacd.ColorScheme, propertyMap[guacd.ColorScheme])
configuration.SetParameter(guacd.Backspace, propertyMap[guacd.Backspace])
configuration.SetParameter(guacd.TerminalType, propertyMap[guacd.TerminalType])
case "kubernetes":
configuration.SetParameter(guacd.FontSize, propertyMap[guacd.FontSize])
configuration.SetParameter(guacd.FontName, propertyMap[guacd.FontName])
configuration.SetParameter(guacd.ColorScheme, propertyMap[guacd.ColorScheme])
configuration.SetParameter(guacd.Backspace, propertyMap[guacd.Backspace])
configuration.SetParameter(guacd.TerminalType, propertyMap[guacd.TerminalType])
default:
logrus.WithField("configuration.Protocol", configuration.Protocol).Error("UnSupport Protocol")
return Fail(c, 400, "不支持的协议")
}
configuration.SetParameter("hostname", session.IP)
configuration.SetParameter("port", strconv.Itoa(session.Port))
// 加载资产配置的属性,优先级比全局配置的高,因此最后加载,覆盖掉全局配置
attributes, _ := model.FindAssetAttributeByAssetId(session.AssetId)
if len(attributes) > 0 {
for i := range attributes {
attribute := attributes[i]
configuration.SetParameter(attribute.Name, attribute.Value)
}
}
}
for name := range configuration.Parameters {
// 替换数据库空格字符串占位符为真正的空格
if configuration.Parameters[name] == "-" {
configuration.Parameters[name] = ""
}
}
addr := propertyMap[guacd.Host] + ":" + propertyMap[guacd.Port]
tunnel, err := guacd.NewTunnel(addr, configuration)
if err != nil {
if connectionId == "" {
CloseSessionById(sessionId, NewTunnelError, err.Error())
}
logrus.Printf("建立连接失败: %v", err.Error())
return err
}
tun := global.Tun{
Protocol: session.Protocol,
Mode: session.Mode,
WebSocket: ws,
Tunnel: tunnel,
}
if len(session.ConnectionId) == 0 {
var observers []global.Tun
observable := global.Observable{
Subject: &tun,
Observers: observers,
}
global.Store.Set(sessionId, &observable)
sess := model.Session{
ConnectionId: tunnel.UUID,
Width: intWidth,
Height: intHeight,
Status: constant.Connecting,
Recording: configuration.GetParameter(guacd.RecordingPath),
}
// 创建新会话
logrus.Debugf("创建新会话 %v", sess.ConnectionId)
if err := model.UpdateSessionById(&sess, sessionId); err != nil {
return err
}
} else {
// 监控会话
observable, ok := global.Store.Get(sessionId)
if ok {
observers := append(observable.Observers, tun)
observable.Observers = observers
global.Store.Set(sessionId, observable)
logrus.Debugf("加入会话%v,当前观察者数量为:%v", session.ConnectionId, len(observers))
}
}
go func() {
for {
instruction, err := tunnel.Read()
if err != nil {
if connectionId == "" {
CloseSessionById(sessionId, TunnelClosed, "远程连接关闭")
}
break
}
if len(instruction) == 0 {
continue
}
err = ws.WriteMessage(websocket.TextMessage, instruction)
if err != nil {
if connectionId == "" {
CloseSessionById(sessionId, Normal, "正常退出")
}
break
}
}
}()
for {
_, message, err := ws.ReadMessage()
if err != nil {
if connectionId == "" {
CloseSessionById(sessionId, Normal, "正常退出")
}
break
}
_, err = tunnel.WriteAndFlush(message)
if err != nil {
if connectionId == "" {
CloseSessionById(sessionId, TunnelClosed, "远程连接关闭")
}
break
}
}
return err
}

136
server/api/user-group.go Normal file
View File

@ -0,0 +1,136 @@
package api
import (
"strconv"
"strings"
"next-terminal/server/global"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)
type UserGroup struct {
Id string `json:"id"`
Name string `json:"name"`
Members []string `json:"members"`
}
func UserGroupCreateEndpoint(c echo.Context) error {
var item UserGroup
if err := c.Bind(&item); err != nil {
return err
}
userGroup := model.UserGroup{
ID: utils.UUID(),
Created: utils.NowJsonTime(),
Name: item.Name,
}
if err := model.CreateNewUserGroup(&userGroup, item.Members); err != nil {
return err
}
return Success(c, item)
}
func UserGroupPagingEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))
name := c.QueryParam("name")
order := c.QueryParam("order")
field := c.QueryParam("field")
items, total, err := model.FindPageUserGroup(pageIndex, pageSize, name, order, field)
if err != nil {
return err
}
return Success(c, H{
"total": total,
"items": items,
})
}
func UserGroupUpdateEndpoint(c echo.Context) error {
id := c.Param("id")
var item UserGroup
if err := c.Bind(&item); err != nil {
return err
}
userGroup := model.UserGroup{
Name: item.Name,
}
if err := model.UpdateUserGroupById(&userGroup, item.Members, id); err != nil {
return err
}
return Success(c, nil)
}
func UserGroupDeleteEndpoint(c echo.Context) error {
ids := c.Param("id")
split := strings.Split(ids, ",")
for i := range split {
userId := split[i]
model.DeleteUserGroupById(userId)
}
return Success(c, nil)
}
func UserGroupGetEndpoint(c echo.Context) error {
id := c.Param("id")
item, err := model.FindUserGroupById(id)
if err != nil {
return err
}
members, err := model.FindUserGroupMembersByUserGroupId(id)
if err != nil {
return err
}
userGroup := UserGroup{
Id: item.ID,
Name: item.Name,
Members: members,
}
return Success(c, userGroup)
}
func UserGroupAddMembersEndpoint(c echo.Context) error {
id := c.Param("id")
var items []string
if err := c.Bind(&items); err != nil {
return err
}
if err := model.AddUserGroupMembers(global.DB, items, id); err != nil {
return err
}
return Success(c, "")
}
func UserGroupDelMembersEndpoint(c echo.Context) (err error) {
id := c.Param("id")
memberIdsStr := c.Param("memberId")
memberIds := strings.Split(memberIdsStr, ",")
for i := range memberIds {
memberId := memberIds[i]
err = global.DB.Where("user_group_id = ? and user_id = ?", id, memberId).Delete(&model.UserGroupMember{}).Error
if err != nil {
return err
}
}
return Success(c, "")
}

163
server/api/user.go Normal file
View File

@ -0,0 +1,163 @@
package api
import (
"strconv"
"strings"
"next-terminal/server/global"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
"github.com/sirupsen/logrus"
)
func UserCreateEndpoint(c echo.Context) error {
var item model.User
if err := c.Bind(&item); err != nil {
return err
}
password := item.Password
var pass []byte
var err error
if pass, err = utils.Encoder.Encode([]byte(password)); err != nil {
return err
}
item.Password = string(pass)
item.ID = utils.UUID()
item.Created = utils.NowJsonTime()
if err := userRepository.Create(&item); err != nil {
return err
}
if item.Mail != "" {
go model.SendMail(item.Mail, "[Next Terminal] 注册通知", "你好,"+item.Nickname+"。管理员为你注册了账号:"+item.Username+" 密码:"+password)
}
return Success(c, item)
}
func UserPagingEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))
username := c.QueryParam("username")
nickname := c.QueryParam("nickname")
mail := c.QueryParam("mail")
order := c.QueryParam("order")
field := c.QueryParam("field")
items, total, err := userRepository.Find(pageIndex, pageSize, username, nickname, mail, order, field)
if err != nil {
return err
}
return Success(c, H{
"total": total,
"items": items,
})
}
func UserUpdateEndpoint(c echo.Context) error {
id := c.Param("id")
var item model.User
if err := c.Bind(&item); err != nil {
return err
}
item.ID = id
if err := userRepository.Update(&item); err != nil {
return err
}
return Success(c, nil)
}
func UserDeleteEndpoint(c echo.Context) error {
ids := c.Param("id")
account, found := GetCurrentAccount(c)
if !found {
return Fail(c, -1, "获取当前登录账户失败")
}
split := strings.Split(ids, ",")
for i := range split {
userId := split[i]
if account.ID == userId {
return Fail(c, -1, "不允许删除自身账户")
}
// 将用户强制下线
loginLogs, err := model.FindAliveLoginLogsByUserId(userId)
if err != nil {
return err
}
for j := range loginLogs {
global.Cache.Delete(loginLogs[j].ID)
if err := model.Logout(loginLogs[j].ID); err != nil {
logrus.WithError(err).WithField("id:", loginLogs[j].ID).Error("Cache Deleted Error")
return Fail(c, 500, "强制下线错误")
}
}
// 删除用户
if err := userRepository.DeleteById(userId); err != nil {
return err
}
}
return Success(c, nil)
}
func UserGetEndpoint(c echo.Context) error {
id := c.Param("id")
item, err := userRepository.FindById(id)
if err != nil {
return err
}
return Success(c, item)
}
func UserChangePasswordEndpoint(c echo.Context) error {
id := c.Param("id")
password := c.QueryParam("password")
user, err := userRepository.FindById(id)
if err != nil {
return err
}
passwd, err := utils.Encoder.Encode([]byte(password))
if err != nil {
return err
}
u := &model.User{
Password: string(passwd),
ID: id,
}
if err := userRepository.Update(u); err != nil {
return err
}
if user.Mail != "" {
go model.SendMail(user.Mail, "[Next Terminal] 密码修改通知", "你好,"+user.Nickname+"。管理员已将你的密码修改为:"+password)
}
return Success(c, "")
}
func UserResetTotpEndpoint(c echo.Context) error {
id := c.Param("id")
u := &model.User{
TOTPSecret: "-",
ID: id,
}
if err := userRepository.Update(u); err != nil {
return err
}
return Success(c, "")
}

89
server/config/config.go Normal file
View File

@ -0,0 +1,89 @@
package config
import (
"strings"
"github.com/spf13/pflag"
"github.com/spf13/viper"
)
type Config struct {
Debug bool
Demo bool
DB string
Server *Server
Mysql *Mysql
Sqlite *Sqlite
ResetPassword string
}
type Mysql struct {
Hostname string
Port int
Username string
Password string
Database string
}
type Sqlite struct {
File string
}
type Server struct {
Addr string
Cert string
Key string
}
func SetupConfig() *Config {
viper.SetConfigName("config")
viper.SetConfigType("yml")
viper.AddConfigPath("/etc/next-terminal/")
viper.AddConfigPath("$HOME/.next-terminal")
viper.AddConfigPath(".")
viper.AutomaticEnv()
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
pflag.String("db", "sqlite", "db mode")
pflag.String("sqlite.file", "next-terminal.db", "sqlite db file")
pflag.String("mysql.hostname", "127.0.0.1", "mysql hostname")
pflag.Int("mysql.port", 3306, "mysql port")
pflag.String("mysql.username", "mysql", "mysql username")
pflag.String("mysql.password", "mysql", "mysql password")
pflag.String("mysql.database", "next_terminal", "mysql database")
pflag.String("server.addr", "", "server listen addr")
pflag.String("server.cert", "", "tls cert file")
pflag.String("server.key", "", "tls key file")
pflag.String("reset-password", "", "")
pflag.Parse()
_ = viper.BindPFlags(pflag.CommandLine)
_ = viper.ReadInConfig()
var config = &Config{
DB: viper.GetString("db"),
Mysql: &Mysql{
Hostname: viper.GetString("mysql.hostname"),
Port: viper.GetInt("mysql.port"),
Username: viper.GetString("mysql.username"),
Password: viper.GetString("mysql.password"),
Database: viper.GetString("mysql.database"),
},
Sqlite: &Sqlite{
File: viper.GetString("sqlite.file"),
},
Server: &Server{
Addr: viper.GetString("server.addr"),
Cert: viper.GetString("server.cert"),
Key: viper.GetString("server.key"),
},
ResetPassword: viper.GetString("reset-password"),
Debug: viper.GetBool("debug"),
Demo: viper.GetBool("demo"),
}
return config
}

33
server/constant/const.go Normal file
View File

@ -0,0 +1,33 @@
package constant
const (
AccessRuleAllow = "allow" // 允许访问
AccessRuleReject = "reject" // 拒绝访问
Custom = "custom" // 密码
PrivateKey = "private-key" // 密钥
JobStatusRunning = "running" // 计划任务运行状态
JobStatusNotRunning = "not-running" // 计划任务未运行状态
FuncCheckAssetStatusJob = "check-asset-status-job" // 检测资产是否在线
FuncShellJob = "shell-job" // 执行Shell脚本
JobModeAll = "all" // 全部资产
JobModeCustom = "custom" // 自定义选择资产
SshMode = "ssh-mode" // ssh模式
MailHost = "mail-host" // 邮件服务器地址
MailPort = "mail-port" // 邮件服务器端口
MailUsername = "mail-username" // 邮件服务账号
MailPassword = "mail-password" // 邮件服务密码
NoConnect = "no_connect" // 会话状态:未连接
Connecting = "connecting" // 会话状态:连接中
Connected = "connected" // 会话状态:已连接
Disconnected = "disconnected" // 会话状态:已断开连接
Guacd = "guacd" // 接入模式guacd
Naive = "naive" // 接入模式:原生
TypeUser = "user" // 普通用户
TypeAdmin = "admin" // 管理员
)

26
server/global/global.go Normal file
View File

@ -0,0 +1,26 @@
package global
import (
"next-terminal/server/config"
"github.com/patrickmn/go-cache"
"github.com/robfig/cron/v3"
"gorm.io/gorm"
)
var DB *gorm.DB
var Cache *cache.Cache
var Config *config.Config
var Store *TunStore
var Cron *cron.Cron
type Security struct {
Rule string
IP string
}
var Securities []*Security

71
server/global/store.go Normal file
View File

@ -0,0 +1,71 @@
package global
import (
"strconv"
"sync"
"next-terminal/server/guacd"
"next-terminal/server/term"
"github.com/gorilla/websocket"
)
type Tun struct {
Protocol string
Mode string
WebSocket *websocket.Conn
Tunnel *guacd.Tunnel
NextTerminal *term.NextTerminal
}
func (r *Tun) Close(code int, reason string) {
if r.Tunnel != nil {
_ = r.Tunnel.Close()
}
if r.NextTerminal != nil {
_ = r.NextTerminal.Close()
}
ws := r.WebSocket
if ws != nil {
if r.Mode == "guacd" {
err := guacd.NewInstruction("error", "", strconv.Itoa(code))
_ = ws.WriteMessage(websocket.TextMessage, []byte(err.String()))
disconnect := guacd.NewInstruction("disconnect")
_ = ws.WriteMessage(websocket.TextMessage, []byte(disconnect.String()))
} else {
msg := `{"type":"closed","content":"` + reason + `"}`
_ = ws.WriteMessage(websocket.TextMessage, []byte(msg))
}
}
}
type Observable struct {
Subject *Tun
Observers []Tun
}
type TunStore struct {
m sync.Map
}
func (s *TunStore) Set(k string, v *Observable) {
s.m.Store(k, v)
}
func (s *TunStore) Del(k string) {
s.m.Delete(k)
}
func (s *TunStore) Get(k string) (item *Observable, ok bool) {
value, ok := s.m.Load(k)
if ok {
return value.(*Observable), true
}
return item, false
}
func NewStore() *TunStore {
store := TunStore{sync.Map{}}
return &store
}

292
server/guacd/guacd.go Normal file
View File

@ -0,0 +1,292 @@
package guacd
import (
"bufio"
"errors"
"fmt"
"net"
"strings"
)
const (
Host = "host"
Port = "port"
EnableRecording = "enable-recording"
RecordingPath = "recording-path"
CreateRecordingPath = "create-recording-path"
FontName = "font-name"
FontSize = "font-size"
ColorScheme = "color-scheme"
Backspace = "backspace"
TerminalType = "terminal-type"
EnableDrive = "enable-drive"
DriveName = "drive-name"
DrivePath = "drive-path"
EnableWallpaper = "enable-wallpaper"
EnableTheming = "enable-theming"
EnableFontSmoothing = "enable-font-smoothing"
EnableFullWindowDrag = "enable-full-window-drag"
EnableDesktopComposition = "enable-desktop-composition"
EnableMenuAnimations = "enable-menu-animations"
DisableBitmapCaching = "disable-bitmap-caching"
DisableOffscreenCaching = "disable-offscreen-caching"
DisableGlyphCaching = "disable-glyph-caching"
Domain = "domain"
RemoteApp = "remote-app"
RemoteAppDir = "remote-app-dir"
RemoteAppArgs = "remote-app-args"
ColorDepth = "color-depth"
Cursor = "cursor"
SwapRedBlue = "swap-red-blue"
DestHost = "dest-host"
DestPort = "dest-port"
UsernameRegex = "username-regex"
PasswordRegex = "password-regex"
LoginSuccessRegex = "login-success-regex"
LoginFailureRegex = "login-failure-regex"
Namespace = "namespace"
Pod = "pod"
Container = "container"
UesSSL = "use-ssl"
ClientCert = "client-cert"
ClientKey = "client-key"
CaCert = "ca-cert"
IgnoreCert = "ignore-cert"
)
const Delimiter = ';'
const Version = "VERSION_1_3_0"
type Configuration struct {
ConnectionID string
Protocol string
Parameters map[string]string
}
func NewConfiguration() (ret Configuration) {
ret.Parameters = make(map[string]string)
return ret
}
func (opt *Configuration) SetParameter(name, value string) {
opt.Parameters[name] = value
}
func (opt *Configuration) UnSetParameter(name string) {
delete(opt.Parameters, name)
}
func (opt *Configuration) GetParameter(name string) string {
return opt.Parameters[name]
}
type Instruction struct {
Opcode string
Args []string
ProtocolForm string
}
func NewInstruction(opcode string, args ...string) (ret Instruction) {
ret.Opcode = opcode
ret.Args = args
return ret
}
func (opt *Instruction) String() string {
if len(opt.ProtocolForm) > 0 {
return opt.ProtocolForm
}
opt.ProtocolForm = fmt.Sprintf("%d.%s", len(opt.Opcode), opt.Opcode)
for _, value := range opt.Args {
opt.ProtocolForm += fmt.Sprintf(",%d.%s", len(value), value)
}
opt.ProtocolForm += string(Delimiter)
return opt.ProtocolForm
}
func (opt *Instruction) Parse(content string) Instruction {
if strings.LastIndex(content, ";") > 0 {
content = strings.TrimRight(content, ";")
}
messages := strings.Split(content, ",")
var args = make([]string, len(messages))
for i := range messages {
lm := strings.Split(messages[i], ".")
args[i] = lm[1]
}
return NewInstruction(args[0], args[1:]...)
}
type Tunnel struct {
rw *bufio.ReadWriter
conn net.Conn
UUID string
Config Configuration
IsOpen bool
}
func NewTunnel(address string, config Configuration) (ret *Tunnel, err error) {
conn, err := net.Dial("tcp", address)
if err != nil {
return
}
ret = &Tunnel{}
ret.conn = conn
ret.rw = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
ret.Config = config
selectArg := config.ConnectionID
if selectArg == "" {
selectArg = config.Protocol
}
if err := ret.WriteInstructionAndFlush(NewInstruction("select", selectArg)); err != nil {
return nil, err
}
args, err := ret.expect("args")
if err != nil {
return
}
width := config.GetParameter("width")
height := config.GetParameter("height")
dpi := config.GetParameter("dpi")
// send size
if err := ret.WriteInstructionAndFlush(NewInstruction("size", width, height, dpi)); err != nil {
return nil, err
}
if err := ret.WriteInstructionAndFlush(NewInstruction("audio", "audio/L8", "audio/L16")); err != nil {
return nil, err
}
if err := ret.WriteInstructionAndFlush(NewInstruction("video")); err != nil {
return nil, err
}
if err := ret.WriteInstructionAndFlush(NewInstruction("image", "image/jpeg", "image/png", "image/webp")); err != nil {
return nil, err
}
if err := ret.WriteInstructionAndFlush(NewInstruction("timezone", "Asia/Shanghai")); err != nil {
return nil, err
}
parameters := make([]string, len(args.Args))
for i := range args.Args {
argName := args.Args[i]
if strings.Contains(argName, "VERSION") {
parameters[i] = Version
continue
}
parameters[i] = config.GetParameter(argName)
}
// send connect
if err := ret.WriteInstructionAndFlush(NewInstruction("connect", parameters...)); err != nil {
return nil, err
}
ready, err := ret.expect("ready")
if err != nil {
return
}
if len(ready.Args) == 0 {
return nil, errors.New("no connection id received")
}
ret.UUID = ready.Args[0]
ret.IsOpen = true
return ret, nil
}
func (opt *Tunnel) WriteInstructionAndFlush(instruction Instruction) error {
if _, err := opt.WriteAndFlush([]byte(instruction.String())); err != nil {
return err
}
return nil
}
func (opt *Tunnel) WriteInstruction(instruction Instruction) error {
if _, err := opt.Write([]byte(instruction.String())); err != nil {
return err
}
return nil
}
func (opt *Tunnel) WriteAndFlush(p []byte) (int, error) {
//fmt.Printf("-> %v\n", string(p))
nn, err := opt.rw.Write(p)
if err != nil {
return nn, err
}
err = opt.rw.Flush()
if err != nil {
return nn, err
}
return nn, nil
}
func (opt *Tunnel) Write(p []byte) (int, error) {
//fmt.Printf("-> %v \n", string(p))
nn, err := opt.rw.Write(p)
if err != nil {
return nn, err
}
return nn, nil
}
func (opt *Tunnel) Flush() error {
return opt.rw.Flush()
}
func (opt *Tunnel) ReadInstruction() (instruction Instruction, err error) {
msg, err := opt.rw.ReadString(Delimiter)
//fmt.Printf("<- %v \n", msg)
if err != nil {
return instruction, err
}
return instruction.Parse(msg), err
}
func (opt *Tunnel) Read() (p []byte, err error) {
p, err = opt.rw.ReadBytes(Delimiter)
//fmt.Printf("<- %v \n", string(p))
s := string(p)
if s == "rate=44100,channels=2;" {
return make([]byte, 0), nil
}
if s == "rate=22050,channels=2;" {
return make([]byte, 0), nil
}
if s == "5.audio,1.1,31.audio/L16;" {
s += "rate=44100,channels=2;"
}
return []byte(s), err
}
func (opt *Tunnel) expect(opcode string) (instruction Instruction, err error) {
instruction, err = opt.ReadInstruction()
if err != nil {
return instruction, err
}
if opcode != instruction.Opcode {
msg := fmt.Sprintf(`expected "%s" instruction but instead received "%s"`, opcode, instruction.Opcode)
return instruction, errors.New(msg)
}
return instruction, nil
}
func (opt *Tunnel) Close() error {
opt.IsOpen = false
return opt.conn.Close()
}

303
server/handle/runner.go Normal file
View File

@ -0,0 +1,303 @@
package handle
import (
"os"
"strconv"
"time"
"next-terminal/server/constant"
"next-terminal/server/guacd"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/sirupsen/logrus"
)
func RunTicker() {
// 每隔一小时删除一次未使用的会话信息
unUsedSessionTicker := time.NewTicker(time.Minute * 60)
go func() {
for range unUsedSessionTicker.C {
sessions, _ := model.FindSessionByStatusIn([]string{constant.NoConnect, constant.Connecting})
if len(sessions) > 0 {
now := time.Now()
for i := range sessions {
if now.Sub(sessions[i].ConnectedTime.Time) > time.Hour*1 {
_ = model.DeleteSessionById(sessions[i].ID)
s := sessions[i].Username + "@" + sessions[i].IP + ":" + strconv.Itoa(sessions[i].Port)
logrus.Infof("会话「%v」ID「%v」超过1小时未打开已删除。", s, sessions[i].ID)
}
}
}
}
}()
// 每日凌晨删除超过时长限制的会话
timeoutSessionTicker := time.NewTicker(time.Hour * 24)
go func() {
for range timeoutSessionTicker.C {
property, err := model.FindPropertyByName("session-saved-limit")
if err != nil {
return
}
if property.Value == "" || property.Value == "-" {
return
}
limit, err := strconv.Atoi(property.Value)
if err != nil {
return
}
sessions, err := model.FindOutTimeSessions(limit)
if err != nil {
return
}
if len(sessions) > 0 {
var sessionIds []string
for i := range sessions {
sessionIds = append(sessionIds, sessions[i].ID)
}
err := model.DeleteSessionByIds(sessionIds)
if err != nil {
logrus.Errorf("删除离线会话失败 %v", err)
}
}
}
}()
}
func RunDataFix() {
sessions, _ := model.FindSessionByStatus(constant.Connected)
if sessions == nil {
return
}
for i := range sessions {
session := model.Session{
Status: constant.Disconnected,
DisconnectedTime: utils.NowJsonTime(),
}
_ = model.UpdateSessionById(&session, sessions[i].ID)
}
}
func InitProperties() error {
propertyMap := model.FindAllPropertiesMap()
if len(propertyMap[guacd.Host]) == 0 {
property := model.Property{
Name: guacd.Host,
Value: "127.0.0.1",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.Port]) == 0 {
property := model.Property{
Name: guacd.Port,
Value: "4822",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.EnableRecording]) == 0 {
property := model.Property{
Name: guacd.EnableRecording,
Value: "true",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.RecordingPath]) == 0 {
path, _ := os.Getwd()
property := model.Property{
Name: guacd.RecordingPath,
Value: path + "/recording/",
}
if !utils.FileExists(property.Value) {
if err := os.Mkdir(property.Value, os.ModePerm); err != nil {
return err
}
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.CreateRecordingPath]) == 0 {
property := model.Property{
Name: guacd.CreateRecordingPath,
Value: "true",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.DriveName]) == 0 {
property := model.Property{
Name: guacd.DriveName,
Value: "File-System",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.DrivePath]) == 0 {
path, _ := os.Getwd()
property := model.Property{
Name: guacd.DrivePath,
Value: path + "/drive/",
}
if !utils.FileExists(property.Value) {
if err := os.Mkdir(property.Value, os.ModePerm); err != nil {
return err
}
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.FontName]) == 0 {
property := model.Property{
Name: guacd.FontName,
Value: "menlo",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.FontSize]) == 0 {
property := model.Property{
Name: guacd.FontSize,
Value: "12",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.ColorScheme]) == 0 {
property := model.Property{
Name: guacd.ColorScheme,
Value: "gray-black",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.EnableDrive]) == 0 {
property := model.Property{
Name: guacd.EnableDrive,
Value: "true",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.EnableWallpaper]) == 0 {
property := model.Property{
Name: guacd.EnableWallpaper,
Value: "false",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.EnableTheming]) == 0 {
property := model.Property{
Name: guacd.EnableTheming,
Value: "false",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.EnableFontSmoothing]) == 0 {
property := model.Property{
Name: guacd.EnableFontSmoothing,
Value: "false",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.EnableFullWindowDrag]) == 0 {
property := model.Property{
Name: guacd.EnableFullWindowDrag,
Value: "false",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.EnableDesktopComposition]) == 0 {
property := model.Property{
Name: guacd.EnableDesktopComposition,
Value: "false",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.EnableMenuAnimations]) == 0 {
property := model.Property{
Name: guacd.EnableMenuAnimations,
Value: "false",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.DisableBitmapCaching]) == 0 {
property := model.Property{
Name: guacd.DisableBitmapCaching,
Value: "false",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.DisableOffscreenCaching]) == 0 {
property := model.Property{
Name: guacd.DisableOffscreenCaching,
Value: "false",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.DisableGlyphCaching]) == 0 {
property := model.Property{
Name: guacd.DisableGlyphCaching,
Value: "false",
}
if err := model.CreateNewProperty(&property); err != nil {
return err
}
}
return nil
}

193
server/log/logger.go Normal file
View File

@ -0,0 +1,193 @@
package log
import (
"io"
"strconv"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
"github.com/sirupsen/logrus"
)
// Logrus : implement Logger
type Logrus struct {
*logrus.Logger
}
// Logger ...
var Logger = logrus.New()
// GetEchoLogger for e.Logger
func GetEchoLogger() Logrus {
return Logrus{Logger}
}
// Level returns logger level
func (l Logrus) Level() log.Lvl {
switch l.Logger.Level {
case logrus.DebugLevel:
return log.DEBUG
case logrus.WarnLevel:
return log.WARN
case logrus.ErrorLevel:
return log.ERROR
case logrus.InfoLevel:
return log.INFO
default:
l.Panic("Invalid level")
}
return log.OFF
}
// SetHeader is a stub to satisfy interface
// It's controlled by Logger
func (l Logrus) SetHeader(_ string) {}
// SetPrefix It's controlled by Logger
func (l Logrus) SetPrefix(s string) {}
// Prefix It's controlled by Logger
func (l Logrus) Prefix() string {
return ""
}
// SetLevel set level to logger from given log.Lvl
func (l Logrus) SetLevel(lvl log.Lvl) {
switch lvl {
case log.DEBUG:
Logger.SetLevel(logrus.DebugLevel)
case log.WARN:
Logger.SetLevel(logrus.WarnLevel)
case log.ERROR:
Logger.SetLevel(logrus.ErrorLevel)
case log.INFO:
Logger.SetLevel(logrus.InfoLevel)
default:
l.Panic("Invalid level")
}
}
// Output logger output func
func (l Logrus) Output() io.Writer {
return l.Out
}
// SetOutput change output, default os.Stdout
func (l Logrus) SetOutput(w io.Writer) {
Logger.SetOutput(w)
}
// Printj print json log
func (l Logrus) Printj(j log.JSON) {
Logger.WithFields(logrus.Fields(j)).Print()
}
// Debugj debug json log
func (l Logrus) Debugj(j log.JSON) {
Logger.WithFields(logrus.Fields(j)).Debug()
}
// Infoj info json log
func (l Logrus) Infoj(j log.JSON) {
Logger.WithFields(logrus.Fields(j)).Info()
}
// Warnj warning json log
func (l Logrus) Warnj(j log.JSON) {
Logger.WithFields(logrus.Fields(j)).Warn()
}
// Errorj error json log
func (l Logrus) Errorj(j log.JSON) {
Logger.WithFields(logrus.Fields(j)).Error()
}
// Fatalj fatal json log
func (l Logrus) Fatalj(j log.JSON) {
Logger.WithFields(logrus.Fields(j)).Fatal()
}
// Panicj panic json log
func (l Logrus) Panicj(j log.JSON) {
Logger.WithFields(logrus.Fields(j)).Panic()
}
// Print string log
func (l Logrus) Print(i ...interface{}) {
Logger.Print(i[0].(string))
}
// Debug string log
func (l Logrus) Debug(i ...interface{}) {
Logger.Debug(i[0].(string))
}
// Info string log
func (l Logrus) Info(i ...interface{}) {
Logger.Info(i[0].(string))
}
// Warn string log
func (l Logrus) Warn(i ...interface{}) {
Logger.Warn(i[0].(string))
}
// Error string log
func (l Logrus) Error(i ...interface{}) {
Logger.Error(i[0].(string))
}
// Fatal string log
func (l Logrus) Fatal(i ...interface{}) {
Logger.Fatal(i[0].(string))
}
// Panic string log
func (l Logrus) Panic(i ...interface{}) {
Logger.Panic(i[0].(string))
}
func logrusMiddlewareHandler(c echo.Context, next echo.HandlerFunc) error {
req := c.Request()
res := c.Response()
start := time.Now()
if err := next(c); err != nil {
c.Error(err)
}
stop := time.Now()
p := req.URL.Path
bytesIn := req.Header.Get(echo.HeaderContentLength)
Logger.WithFields(map[string]interface{}{
"time_rfc3339": time.Now().Format(time.RFC3339),
"remote_ip": c.RealIP(),
"host": req.Host,
"uri": req.RequestURI,
"method": req.Method,
"path": p,
"referer": req.Referer(),
"user_agent": req.UserAgent(),
"status": res.Status,
"latency": strconv.FormatInt(stop.Sub(start).Nanoseconds()/1000, 10),
"latency_human": stop.Sub(start).String(),
"bytes_in": bytesIn,
"bytes_out": strconv.FormatInt(res.Size, 10),
}).Info("Handled request")
return nil
}
func logger(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
return logrusMiddlewareHandler(c, next)
}
}
// Hook is a function to process log.
func Hook() echo.MiddlewareFunc {
return logger
}

View File

@ -0,0 +1,83 @@
package model
import (
"next-terminal/server/global"
)
type AccessSecurity struct {
ID string `json:"id"`
Rule string `json:"rule"`
IP string `json:"ip"`
Source string `json:"source"`
Priority int64 `json:"priority"` // 越小优先级越高
}
func (r *AccessSecurity) TableName() string {
return "access_securities"
}
func FindAllAccessSecurities() (o []AccessSecurity, err error) {
db := global.DB
err = db.Order("priority asc").Find(&o).Error
return
}
func FindPageSecurity(pageIndex, pageSize int, ip, rule, order, field string) (o []AccessSecurity, total int64, err error) {
t := AccessSecurity{}
db := global.DB.Table(t.TableName())
dbCounter := global.DB.Table(t.TableName())
if len(ip) > 0 {
db = db.Where("ip like ?", "%"+ip+"%")
dbCounter = dbCounter.Where("ip like ?", "%"+ip+"%")
}
if len(rule) > 0 {
db = db.Where("rule = ?", rule)
dbCounter = dbCounter.Where("rule = ?", rule)
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "descend" {
order = "desc"
} else {
order = "asc"
}
if field == "ip" {
field = "ip"
} else if field == "rule" {
field = "rule"
} else {
field = "priority"
}
err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error
if o == nil {
o = make([]AccessSecurity, 0)
}
return
}
func CreateNewSecurity(o *AccessSecurity) error {
return global.DB.Create(o).Error
}
func UpdateSecurityById(o *AccessSecurity, id string) error {
o.ID = id
return global.DB.Updates(o).Error
}
func DeleteSecurityById(id string) error {
return global.DB.Where("id = ?", id).Delete(AccessSecurity{}).Error
}
func FindSecurityById(id string) (o *AccessSecurity, err error) {
err = global.DB.Where("id = ?", id).First(&o).Error
return
}

239
server/model/asset.go Normal file
View File

@ -0,0 +1,239 @@
package model
import (
"strings"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/utils"
)
type Asset struct {
ID string `gorm:"primary_key " json:"id"`
Name string `json:"name"`
Protocol string `json:"protocol"`
IP string `json:"ip"`
Port int `json:"port"`
AccountType string `json:"accountType"`
Username string `json:"username"`
Password string `json:"password"`
CredentialId string `gorm:"index" json:"credentialId"`
PrivateKey string `json:"privateKey"`
Passphrase string `json:"passphrase"`
Description string `json:"description"`
Active bool `json:"active"`
Created utils.JsonTime `json:"created"`
Tags string `json:"tags"`
Owner string `gorm:"index" json:"owner"`
}
type AssetVo struct {
ID string `json:"id"`
Name string `json:"name"`
IP string `json:"ip"`
Protocol string `json:"protocol"`
Port int `json:"port"`
Active bool `json:"active"`
Created utils.JsonTime `json:"created"`
Tags string `json:"tags"`
Owner string `json:"owner"`
OwnerName string `json:"ownerName"`
SharerCount int64 `json:"sharerCount"`
}
func (r *Asset) TableName() string {
return "assets"
}
func FindAllAsset() (o []Asset, err error) {
err = global.DB.Find(&o).Error
return
}
func FindAssetByIds(assetIds []string) (o []Asset, err error) {
err = global.DB.Where("id in ?", assetIds).Find(&o).Error
return
}
func FindAssetByProtocol(protocol string) (o []Asset, err error) {
err = global.DB.Where("protocol = ?", protocol).Find(&o).Error
return
}
func FindAssetByProtocolAndIds(protocol string, assetIds []string) (o []Asset, err error) {
err = global.DB.Where("protocol = ? and id in ?", protocol, assetIds).Find(&o).Error
return
}
func FindAssetByConditions(protocol string, account User) (o []Asset, err error) {
db := global.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id")
if constant.TypeUser == account.Type {
owner := account.ID
db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner)
}
if len(protocol) > 0 {
db = db.Where("assets.protocol = ?", protocol)
}
err = db.Find(&o).Error
return
}
func FindPageAsset(pageIndex, pageSize int, name, protocol, tags string, account User, owner, sharer, userGroupId, ip, order, field string) (o []AssetVo, total int64, err error) {
db := global.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id")
dbCounter := global.DB.Table("assets").Select("DISTINCT assets.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id")
if constant.TypeUser == account.Type {
owner := account.ID
db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner)
dbCounter = dbCounter.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner)
// 查询用户所在用户组列表
userGroupIds, err := FindUserGroupIdsByUserId(account.ID)
if err != nil {
return nil, 0, err
}
if len(userGroupIds) > 0 {
db = db.Or("resource_sharers.user_group_id in ?", userGroupIds)
dbCounter = dbCounter.Or("resource_sharers.user_group_id in ?", userGroupIds)
}
} else {
if len(owner) > 0 {
db = db.Where("assets.owner = ?", owner)
dbCounter = dbCounter.Where("assets.owner = ?", owner)
}
if len(sharer) > 0 {
db = db.Where("resource_sharers.user_id = ?", sharer)
dbCounter = dbCounter.Where("resource_sharers.user_id = ?", sharer)
}
if len(userGroupId) > 0 {
db = db.Where("resource_sharers.user_group_id = ?", userGroupId)
dbCounter = dbCounter.Where("resource_sharers.user_group_id = ?", userGroupId)
}
}
if len(name) > 0 {
db = db.Where("assets.name like ?", "%"+name+"%")
dbCounter = dbCounter.Where("assets.name like ?", "%"+name+"%")
}
if len(ip) > 0 {
db = db.Where("assets.ip like ?", "%"+ip+"%")
dbCounter = dbCounter.Where("assets.ip like ?", "%"+ip+"%")
}
if len(protocol) > 0 {
db = db.Where("assets.protocol = ?", protocol)
dbCounter = dbCounter.Where("assets.protocol = ?", protocol)
}
if len(tags) > 0 {
tagArr := strings.Split(tags, ",")
for i := range tagArr {
if global.Config.DB == "sqlite" {
db = db.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%")
dbCounter = dbCounter.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%")
} else {
db = db.Where("find_in_set(?, assets.tags)", tagArr[i])
dbCounter = dbCounter.Where("find_in_set(?, assets.tags)", tagArr[i])
}
}
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "name" {
field = "name"
} else {
field = "created"
}
err = db.Order("assets." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error
if o == nil {
o = make([]AssetVo, 0)
}
return
}
func CreateNewAsset(o *Asset) (err error) {
if err = global.DB.Create(o).Error; err != nil {
return err
}
return nil
}
func FindAssetById(id string) (o Asset, err error) {
err = global.DB.Where("id = ?", id).First(&o).Error
return
}
func UpdateAssetById(o *Asset, id string) {
o.ID = id
global.DB.Updates(o)
}
func UpdateAssetActiveById(active bool, id string) {
sql := "update assets set active = ? where id = ?"
global.DB.Exec(sql, active, id)
}
func DeleteAssetById(id string) error {
return global.DB.Where("id = ?", id).Delete(&Asset{}).Error
}
func CountAsset() (total int64, err error) {
err = global.DB.Find(&Asset{}).Count(&total).Error
return
}
func CountAssetByUserId(userId string) (total int64, err error) {
db := global.DB.Joins("left join resource_sharers on assets.id = resource_sharers.resource_id")
db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", userId, userId)
// 查询用户所在用户组列表
userGroupIds, err := FindUserGroupIdsByUserId(userId)
if err != nil {
return 0, err
}
if len(userGroupIds) > 0 {
db = db.Or("resource_sharers.user_group_id in ?", userGroupIds)
}
err = db.Find(&Asset{}).Count(&total).Error
return
}
func FindAssetTags() (o []string, err error) {
var assets []Asset
err = global.DB.Not("tags = ?", "").Find(&assets).Error
if err != nil {
return nil, err
}
o = make([]string, 0)
for i := range assets {
if len(assets[i].Tags) == 0 {
continue
}
split := strings.Split(assets[i].Tags, ",")
o = append(o, split...)
}
return utils.Distinct(o), nil
}

View File

@ -0,0 +1,119 @@
package model
import (
"fmt"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/guacd"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
"gorm.io/gorm"
)
type AssetAttribute struct {
Id string `gorm:"index" json:"id"`
AssetId string `gorm:"index" json:"assetId"`
Name string `gorm:"index" json:"name"`
Value string `json:"value"`
}
func (r *AssetAttribute) TableName() string {
return "asset_attributes"
}
var SSHParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, constant.SshMode}
var RDPParameterNames = []string{guacd.Domain, guacd.RemoteApp, guacd.RemoteAppDir, guacd.RemoteAppArgs}
var VNCParameterNames = []string{guacd.ColorDepth, guacd.Cursor, guacd.SwapRedBlue, guacd.DestHost, guacd.DestPort}
var TelnetParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, guacd.UsernameRegex, guacd.PasswordRegex, guacd.LoginSuccessRegex, guacd.LoginFailureRegex}
var KubernetesParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, guacd.Namespace, guacd.Pod, guacd.Container, guacd.UesSSL, guacd.ClientCert, guacd.ClientKey, guacd.CaCert, guacd.IgnoreCert}
func UpdateAssetAttributes(assetId, protocol string, m echo.Map) error {
var data []AssetAttribute
var parameterNames []string
switch protocol {
case "ssh":
parameterNames = SSHParameterNames
case "rdp":
parameterNames = RDPParameterNames
case "vnc":
parameterNames = VNCParameterNames
case "telnet":
parameterNames = TelnetParameterNames
case "kubernetes":
parameterNames = KubernetesParameterNames
}
for i := range parameterNames {
name := parameterNames[i]
if m[name] != nil && m[name] != "" {
data = append(data, genAttribute(assetId, name, m))
}
}
return global.DB.Transaction(func(tx *gorm.DB) error {
err := tx.Where("asset_id = ?", assetId).Delete(&AssetAttribute{}).Error
if err != nil {
return err
}
return tx.CreateInBatches(&data, len(data)).Error
})
}
func genAttribute(assetId, name string, m echo.Map) AssetAttribute {
value := fmt.Sprintf("%v", m[name])
attribute := AssetAttribute{
Id: utils.Sign([]string{assetId, name}),
AssetId: assetId,
Name: name,
Value: value,
}
return attribute
}
func FindAssetAttributeByAssetId(assetId string) (o []AssetAttribute, err error) {
err = global.DB.Where("asset_id = ?", assetId).Find(&o).Error
if o == nil {
o = make([]AssetAttribute, 0)
}
return o, err
}
func FindAssetAttrMapByAssetId(assetId string) (map[string]interface{}, error) {
asset, err := FindAssetById(assetId)
if err != nil {
return nil, err
}
attributes, err := FindAssetAttributeByAssetId(assetId)
if err != nil {
return nil, err
}
var parameterNames []string
switch asset.Protocol {
case "ssh":
parameterNames = SSHParameterNames
case "rdp":
parameterNames = RDPParameterNames
case "vnc":
parameterNames = VNCParameterNames
case "telnet":
parameterNames = TelnetParameterNames
case "kubernetes":
parameterNames = KubernetesParameterNames
}
propertiesMap := FindAllPropertiesMap()
var attributeMap = make(map[string]interface{})
for name := range propertiesMap {
if utils.Contains(parameterNames, name) {
attributeMap[name] = propertiesMap[name]
}
}
for i := range attributes {
attributeMap[attributes[i].Name] = attributes[i].Value
}
return attributeMap, nil
}

95
server/model/command.go Normal file
View File

@ -0,0 +1,95 @@
package model
import (
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/utils"
)
type Command struct {
ID string `gorm:"primary_key" json:"id"`
Name string `json:"name"`
Content string `json:"content"`
Created utils.JsonTime `json:"created"`
Owner string `gorm:"index" json:"owner"`
}
type CommandVo struct {
ID string `gorm:"primary_key" json:"id"`
Name string `json:"name"`
Content string `json:"content"`
Created utils.JsonTime `json:"created"`
Owner string `json:"owner"`
OwnerName string `json:"ownerName"`
SharerCount int64 `json:"sharerCount"`
}
func (r *Command) TableName() string {
return "commands"
}
func FindPageCommand(pageIndex, pageSize int, name, content, order, field string, account User) (o []CommandVo, total int64, err error) {
db := global.DB.Table("commands").Select("commands.id,commands.name,commands.content,commands.owner,commands.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on commands.owner = users.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id")
dbCounter := global.DB.Table("commands").Select("DISTINCT commands.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id")
if constant.TypeUser == account.Type {
owner := account.ID
db = db.Where("commands.owner = ? or resource_sharers.user_id = ?", owner, owner)
dbCounter = dbCounter.Where("commands.owner = ? or resource_sharers.user_id = ?", owner, owner)
}
if len(name) > 0 {
db = db.Where("commands.name like ?", "%"+name+"%")
dbCounter = dbCounter.Where("commands.name like ?", "%"+name+"%")
}
if len(content) > 0 {
db = db.Where("commands.content like ?", "%"+content+"%")
dbCounter = dbCounter.Where("commands.content like ?", "%"+content+"%")
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "name" {
field = "name"
} else {
field = "created"
}
err = db.Order("commands." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error
if o == nil {
o = make([]CommandVo, 0)
}
return
}
func CreateNewCommand(o *Command) (err error) {
if err = global.DB.Create(o).Error; err != nil {
return err
}
return nil
}
func FindCommandById(id string) (o Command, err error) {
err = global.DB.Where("id = ?", id).First(&o).Error
return
}
func UpdateCommandById(o *Command, id string) {
o.ID = id
global.DB.Updates(o)
}
func DeleteCommandById(id string) error {
return global.DB.Where("id = ?", id).Delete(&Command{}).Error
}

131
server/model/credential.go Normal file
View File

@ -0,0 +1,131 @@
package model
import (
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/utils"
)
type Credential struct {
ID string `gorm:"primary_key" json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Username string `json:"username"`
Password string `json:"password"`
PrivateKey string `json:"privateKey"`
Passphrase string `json:"passphrase"`
Created utils.JsonTime `json:"created"`
Owner string `gorm:"index" json:"owner"`
}
func (r *Credential) TableName() string {
return "credentials"
}
type CredentialVo struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Username string `json:"username"`
Created utils.JsonTime `json:"created"`
Owner string `json:"owner"`
OwnerName string `json:"ownerName"`
SharerCount int64 `json:"sharerCount"`
}
type CredentialSimpleVo struct {
ID string `json:"id"`
Name string `json:"name"`
}
func FindAllCredential(account User) (o []CredentialSimpleVo, err error) {
db := global.DB.Table("credentials").Select("DISTINCT credentials.id,credentials.name").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id")
if account.Type == constant.TypeUser {
db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", account.ID, account.ID)
}
err = db.Find(&o).Error
return
}
func FindPageCredential(pageIndex, pageSize int, name, order, field string, account User) (o []CredentialVo, total int64, err error) {
db := global.DB.Table("credentials").Select("credentials.id,credentials.name,credentials.type,credentials.username,credentials.owner,credentials.created,users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on credentials.owner = users.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id")
dbCounter := global.DB.Table("credentials").Select("DISTINCT credentials.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id")
if constant.TypeUser == account.Type {
owner := account.ID
db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", owner, owner)
dbCounter = dbCounter.Where("credentials.owner = ? or resource_sharers.user_id = ?", owner, owner)
}
if len(name) > 0 {
db = db.Where("credentials.name like ?", "%"+name+"%")
dbCounter = dbCounter.Where("credentials.name like ?", "%"+name+"%")
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "name" {
field = "name"
} else {
field = "created"
}
err = db.Order("credentials." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error
if o == nil {
o = make([]CredentialVo, 0)
}
return
}
func CreateNewCredential(o *Credential) (err error) {
if err = global.DB.Create(o).Error; err != nil {
return err
}
return nil
}
func FindCredentialById(id string) (o Credential, err error) {
err = global.DB.Where("id = ?", id).First(&o).Error
return
}
func UpdateCredentialById(o *Credential, id string) {
o.ID = id
global.DB.Updates(o)
}
func DeleteCredentialById(id string) error {
return global.DB.Where("id = ?", id).Delete(&Credential{}).Error
}
func CountCredential() (total int64, err error) {
err = global.DB.Find(&Credential{}).Count(&total).Error
return
}
func CountCredentialByUserId(userId string) (total int64, err error) {
db := global.DB.Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id")
db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", userId, userId)
// 查询用户所在用户组列表
userGroupIds, err := FindUserGroupIdsByUserId(userId)
if err != nil {
return 0, err
}
if len(userGroupIds) > 0 {
db = db.Or("resource_sharers.user_group_id in ?", userGroupIds)
}
err = db.Find(&Credential{}).Count(&total).Error
return
}

356
server/model/job.go Normal file
View File

@ -0,0 +1,356 @@
package model
import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/term"
"next-terminal/server/utils"
"github.com/robfig/cron/v3"
"github.com/sirupsen/logrus"
)
type Job struct {
ID string `gorm:"primary_key" json:"id"`
CronJobId int `json:"cronJobId"`
Name string `json:"name"`
Func string `json:"func"`
Cron string `json:"cron"`
Mode string `json:"mode"`
ResourceIds string `json:"resourceIds"`
Status string `json:"status"`
Metadata string `json:"metadata"`
Created utils.JsonTime `json:"created"`
Updated utils.JsonTime `json:"updated"`
}
func (r *Job) TableName() string {
return "jobs"
}
func FindPageJob(pageIndex, pageSize int, name, status, order, field string) (o []Job, total int64, err error) {
job := Job{}
db := global.DB.Table(job.TableName())
dbCounter := global.DB.Table(job.TableName())
if len(name) > 0 {
db = db.Where("name like ?", "%"+name+"%")
dbCounter = dbCounter.Where("name like ?", "%"+name+"%")
}
if len(status) > 0 {
db = db.Where("status = ?", status)
dbCounter = dbCounter.Where("status = ?", status)
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "name" {
field = "name"
} else if field == "created" {
field = "created"
} else {
field = "updated"
}
err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error
if o == nil {
o = make([]Job, 0)
}
return
}
func FindJobByFunc(function string) (o []Job, err error) {
db := global.DB
err = db.Where("func = ?", function).Find(&o).Error
return
}
func CreateNewJob(o *Job) (err error) {
if o.Status == constant.JobStatusRunning {
j, err := getJob(o)
if err != nil {
return err
}
jobId, err := global.Cron.AddJob(o.Cron, j)
if err != nil {
return err
}
o.CronJobId = int(jobId)
}
return global.DB.Create(o).Error
}
func UpdateJobById(o *Job, id string) (err error) {
if o.Status == constant.JobStatusRunning {
return errors.New("请先停止定时任务后再修改")
}
o.ID = id
return global.DB.Updates(o).Error
}
func UpdateJonUpdatedById(id string) (err error) {
err = global.DB.Updates(Job{ID: id, Updated: utils.NowJsonTime()}).Error
return
}
func ChangeJobStatusById(id, status string) (err error) {
var job Job
err = global.DB.Where("id = ?", id).First(&job).Error
if err != nil {
return err
}
if status == constant.JobStatusRunning {
j, err := getJob(&job)
if err != nil {
return err
}
entryID, err := global.Cron.AddJob(job.Cron, j)
if err != nil {
return err
}
logrus.Debugf("开启计划任务「%v」,运行中计划任务数量「%v」", job.Name, len(global.Cron.Entries()))
return global.DB.Updates(Job{ID: id, Status: constant.JobStatusRunning, CronJobId: int(entryID)}).Error
} else {
global.Cron.Remove(cron.EntryID(job.CronJobId))
logrus.Debugf("关闭计划任务「%v」,运行中计划任务数量「%v」", job.Name, len(global.Cron.Entries()))
return global.DB.Updates(Job{ID: id, Status: constant.JobStatusNotRunning}).Error
}
}
func ExecJobById(id string) (err error) {
job, err := FindJobById(id)
if err != nil {
return err
}
j, err := getJob(&job)
if err != nil {
return err
}
j.Run()
return nil
}
func FindJobById(id string) (o Job, err error) {
err = global.DB.Where("id = ?", id).First(&o).Error
return
}
func DeleteJobById(id string) error {
job, err := FindJobById(id)
if err != nil {
return err
}
if job.Status == constant.JobStatusRunning {
if err := ChangeJobStatusById(id, constant.JobStatusNotRunning); err != nil {
return err
}
}
return global.DB.Where("id = ?", id).Delete(Job{}).Error
}
func getJob(j *Job) (job cron.Job, err error) {
switch j.Func {
case constant.FuncCheckAssetStatusJob:
job = CheckAssetStatusJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata}
case constant.FuncShellJob:
job = ShellJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata}
default:
return nil, errors.New("未识别的任务")
}
return job, err
}
type CheckAssetStatusJob struct {
ID string
Mode string
ResourceIds string
Metadata string
}
func (r CheckAssetStatusJob) Run() {
if r.ID == "" {
return
}
var assets []Asset
if r.Mode == constant.JobModeAll {
assets, _ = FindAllAsset()
} else {
assets, _ = FindAssetByIds(strings.Split(r.ResourceIds, ","))
}
if len(assets) == 0 {
return
}
msgChan := make(chan string)
for i := range assets {
asset := assets[i]
go func() {
t1 := time.Now()
active := utils.Tcping(asset.IP, asset.Port)
elapsed := time.Since(t1)
msg := fmt.Sprintf("资产「%v」存活状态检测完成存活「%v」耗时「%v」", asset.Name, active, elapsed)
UpdateAssetActiveById(active, asset.ID)
logrus.Infof(msg)
msgChan <- msg
}()
}
var message = ""
for i := 0; i < len(assets); i++ {
message += <-msgChan + "\n"
}
_ = UpdateJonUpdatedById(r.ID)
jobLog := JobLog{
ID: utils.UUID(),
JobId: r.ID,
Timestamp: utils.NowJsonTime(),
Message: message,
}
_ = CreateNewJobLog(&jobLog)
}
type ShellJob struct {
ID string
Mode string
ResourceIds string
Metadata string
}
type MetadataShell struct {
Shell string
}
func (r ShellJob) Run() {
if r.ID == "" {
return
}
var assets []Asset
if r.Mode == constant.JobModeAll {
assets, _ = FindAssetByProtocol("ssh")
} else {
assets, _ = FindAssetByProtocolAndIds("ssh", strings.Split(r.ResourceIds, ","))
}
if len(assets) == 0 {
return
}
var metadataShell MetadataShell
err := json.Unmarshal([]byte(r.Metadata), &metadataShell)
if err != nil {
logrus.Errorf("JSON数据解析失败 %v", err)
return
}
msgChan := make(chan string)
for i := range assets {
asset, err := FindAssetById(assets[i].ID)
if err != nil {
msgChan <- fmt.Sprintf("资产「%v」Shell执行失败查询数据异常「%v」", assets[i].Name, err.Error())
return
}
var (
username = asset.Username
password = asset.Password
privateKey = asset.PrivateKey
passphrase = asset.Passphrase
ip = asset.IP
port = asset.Port
)
if asset.AccountType == "credential" {
credential, err := FindCredentialById(asset.CredentialId)
if err != nil {
msgChan <- fmt.Sprintf("资产「%v」Shell执行失败查询授权凭证数据异常「%v」", assets[i].Name, err.Error())
return
}
if credential.Type == constant.Custom {
username = credential.Username
password = credential.Password
} else {
username = credential.Username
privateKey = credential.PrivateKey
passphrase = credential.Passphrase
}
}
go func() {
t1 := time.Now()
result, err := ExecCommandBySSH(metadataShell.Shell, ip, port, username, password, privateKey, passphrase)
elapsed := time.Since(t1)
var msg string
if err != nil {
msg = fmt.Sprintf("资产「%v」Shell执行失败返回值「%v」耗时「%v」", asset.Name, err.Error(), elapsed)
logrus.Infof(msg)
} else {
msg = fmt.Sprintf("资产「%v」Shell执行成功返回值「%v」耗时「%v」", asset.Name, result, elapsed)
logrus.Infof(msg)
}
msgChan <- msg
}()
}
var message = ""
for i := 0; i < len(assets); i++ {
message += <-msgChan + "\n"
}
_ = UpdateJonUpdatedById(r.ID)
jobLog := JobLog{
ID: utils.UUID(),
JobId: r.ID,
Timestamp: utils.NowJsonTime(),
Message: message,
}
_ = CreateNewJobLog(&jobLog)
}
func ExecCommandBySSH(cmd, ip string, port int, username, password, privateKey, passphrase string) (result string, err error) {
sshClient, err := term.NewSshClient(ip, port, username, password, privateKey, passphrase)
if err != nil {
return "", err
}
session, err := sshClient.NewSession()
if err != nil {
return "", err
}
defer session.Close()
//执行远程命令
combo, err := session.CombinedOutput(cmd)
if err != nil {
return "", err
}
return string(combo), nil
}

30
server/model/job_log.go Normal file
View File

@ -0,0 +1,30 @@
package model
import (
"next-terminal/server/global"
"next-terminal/server/utils"
)
type JobLog struct {
ID string `json:"id"`
Timestamp utils.JsonTime `json:"timestamp"`
JobId string `json:"jobId"`
Message string `json:"message"`
}
func (r *JobLog) TableName() string {
return "job_logs"
}
func CreateNewJobLog(o *JobLog) error {
return global.DB.Create(o).Error
}
func FindJobLogs(jobId string) (o []JobLog, err error) {
err = global.DB.Where("job_id = ?", jobId).Order("timestamp asc").Find(&o).Error
return
}
func DeleteJobLogByJobId(jobId string) error {
return global.DB.Where("job_id = ?", jobId).Delete(JobLog{}).Error
}

106
server/model/login_log.go Normal file
View File

@ -0,0 +1,106 @@
package model
import (
"next-terminal/server/global"
"next-terminal/server/utils"
)
type LoginLog struct {
ID string `gorm:"primary_key" json:"id"`
UserId string `gorm:"index" json:"userId"`
ClientIP string `json:"clientIp"`
ClientUserAgent string `json:"clientUserAgent"`
LoginTime utils.JsonTime `json:"loginTime"`
LogoutTime utils.JsonTime `json:"logoutTime"`
Remember bool `json:"remember"`
}
type LoginLogVo struct {
ID string `json:"id"`
UserId string `json:"userId"`
UserName string `json:"userName"`
ClientIP string `json:"clientIp"`
ClientUserAgent string `json:"clientUserAgent"`
LoginTime utils.JsonTime `json:"loginTime"`
LogoutTime utils.JsonTime `json:"logoutTime"`
Remember bool `json:"remember"`
}
func (r *LoginLog) TableName() string {
return "login_logs"
}
func FindPageLoginLog(pageIndex, pageSize int, userId, clientIp string) (o []LoginLogVo, total int64, err error) {
db := global.DB.Table("login_logs").Select("login_logs.id,login_logs.user_id,login_logs.client_ip,login_logs.client_user_agent,login_logs.login_time, login_logs.logout_time, users.nickname as user_name").Joins("left join users on login_logs.user_id = users.id")
dbCounter := global.DB.Table("login_logs").Select("DISTINCT login_logs.id")
if userId != "" {
db = db.Where("login_logs.user_id = ?", userId)
dbCounter = dbCounter.Where("login_logs.user_id = ?", userId)
}
if clientIp != "" {
db = db.Where("login_logs.client_ip like ?", "%"+clientIp+"%")
dbCounter = dbCounter.Where("login_logs.client_ip like ?", "%"+clientIp+"%")
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
err = db.Order("login_logs.login_time desc").Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error
if o == nil {
o = make([]LoginLogVo, 0)
}
return
}
func FindAliveLoginLogs() (o []LoginLog, err error) {
err = global.DB.Where("logout_time is null").Find(&o).Error
return
}
func FindAliveLoginLogsByUserId(userId string) (o []LoginLog, err error) {
err = global.DB.Where("logout_time is null and user_id = ?", userId).Find(&o).Error
return
}
func CreateNewLoginLog(o *LoginLog) (err error) {
return global.DB.Create(o).Error
}
func DeleteLoginLogByIdIn(ids []string) (err error) {
return global.DB.Where("id in ?", ids).Delete(&LoginLog{}).Error
}
func FindLoginLogById(id string) (o LoginLog, err error) {
err = global.DB.Where("id = ?", id).First(&o).Error
return
}
func Logout(token string) (err error) {
//
//loginLog, err := FindLoginLogById(token)
//if err != nil {
// logrus.Warnf("登录日志「%v」获取失败", token)
// return
//}
//
//err = global.DB.Updates(&LoginLog{LogoutTime: utils.NowJsonTime(), ID: token}).Error
//if err != nil {
// return err
//}
//
//loginLogs, err := FindAliveLoginLogsByUserId(loginLog.UserId)
//if err != nil {
// return
//}
//
//if len(loginLogs) == 0 {
// // TODO
// err = UpdateUserOnline(false, loginLog.UserId)
//}
return
}

25
server/model/num.go Normal file
View File

@ -0,0 +1,25 @@
package model
import (
"next-terminal/server/global"
)
type Num struct {
I string `gorm:"primary_key" json:"i"`
}
func (r *Num) TableName() string {
return "nums"
}
func FindAllTemp() (o []Num) {
if global.DB.Find(&o).Error != nil {
return nil
}
return
}
func CreateNewTemp(o *Num) (err error) {
err = global.DB.Create(o).Error
return
}

91
server/model/property.go Normal file
View File

@ -0,0 +1,91 @@
package model
import (
"net/smtp"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/guacd"
"github.com/jordan-wright/email"
"github.com/sirupsen/logrus"
)
type Property struct {
Name string `gorm:"primary_key" json:"name"`
Value string `json:"value"`
}
func (r *Property) TableName() string {
return "properties"
}
func FindAllProperties() (o []Property) {
if global.DB.Find(&o).Error != nil {
return nil
}
return
}
func CreateNewProperty(o *Property) (err error) {
err = global.DB.Create(o).Error
return
}
func UpdatePropertyByName(o *Property, name string) {
o.Name = name
global.DB.Updates(o)
}
func FindPropertyByName(name string) (o Property, err error) {
err = global.DB.Where("name = ?", name).First(&o).Error
return
}
func FindAllPropertiesMap() map[string]string {
properties := FindAllProperties()
propertyMap := make(map[string]string)
for i := range properties {
propertyMap[properties[i].Name] = properties[i].Value
}
return propertyMap
}
func GetDrivePath() (string, error) {
property, err := FindPropertyByName(guacd.DrivePath)
if err != nil {
return "", err
}
return property.Value, nil
}
func GetRecordingPath() (string, error) {
property, err := FindPropertyByName(guacd.RecordingPath)
if err != nil {
return "", err
}
return property.Value, nil
}
func SendMail(to, subject, text string) {
propertiesMap := FindAllPropertiesMap()
host := propertiesMap[constant.MailHost]
port := propertiesMap[constant.MailPort]
username := propertiesMap[constant.MailUsername]
password := propertiesMap[constant.MailPassword]
if host == "" || port == "" || username == "" || password == "" {
logrus.Debugf("邮箱信息不完整,跳过发送邮件。")
return
}
e := email.NewEmail()
e.From = "Next Terminal <" + username + ">"
e.To = []string{to}
e.Subject = subject
e.Text = []byte(text)
err := e.Send(host+":"+port, smtp.PlainAuth("", username, password, host))
if err != nil {
logrus.Errorf("邮件发送失败: %v", err.Error())
}
}

View File

@ -0,0 +1,198 @@
package model
import (
"next-terminal/server/global"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"gorm.io/gorm"
)
type ResourceSharer struct {
ID string `gorm:"primary_key" json:"id"`
ResourceId string `gorm:"index" json:"resourceId"`
ResourceType string `gorm:"index" json:"resourceType"`
UserId string `gorm:"index" json:"userId"`
UserGroupId string `gorm:"index" json:"userGroupId"`
}
func (r *ResourceSharer) TableName() string {
return "resource_sharers"
}
func FindUserIdsByResourceId(resourceId string) (r []string, err error) {
db := global.DB
err = db.Table("resource_sharers").Select("user_id").Where("resource_id = ?", resourceId).Find(&r).Error
if r == nil {
r = make([]string, 0)
}
return
}
func OverwriteUserIdsByResourceId(resourceId, resourceType string, userIds []string) (err error) {
db := global.DB.Begin()
var owner string
// 检查资产是否存在
switch resourceType {
case "asset":
resource := Asset{}
err = db.Where("id = ?", resourceId).First(&resource).Error
owner = resource.Owner
case "command":
resource := Command{}
err = db.Where("id = ?", resourceId).First(&resource).Error
owner = resource.Owner
case "credential":
resource := Credential{}
err = db.Where("id = ?", resourceId).First(&resource).Error
owner = resource.Owner
}
if err == gorm.ErrRecordNotFound {
return echo.NewHTTPError(404, "资源「"+resourceId+"」不存在")
}
for i := range userIds {
if owner == userIds[i] {
return echo.NewHTTPError(400, "参数错误")
}
}
db.Where("resource_id = ?", resourceId).Delete(&ResourceSharer{})
for i := range userIds {
userId := userIds[i]
if len(userId) == 0 {
continue
}
id := utils.Sign([]string{resourceId, resourceType, userId})
resource := &ResourceSharer{
ID: id,
ResourceId: resourceId,
ResourceType: resourceType,
UserId: userId,
}
err = db.Create(resource).Error
if err != nil {
return err
}
}
db.Commit()
return nil
}
func DeleteByUserIdAndResourceTypeAndResourceIdIn(userGroupId, userId, resourceType string, resourceIds []string) error {
db := global.DB
if userGroupId != "" {
db = db.Where("user_group_id = ?", userGroupId)
}
if userId != "" {
db = db.Where("user_id = ?", userId)
}
if resourceType != "" {
db = db.Where("resource_type = ?", resourceType)
}
if resourceIds != nil {
db = db.Where("resource_id in ?", resourceIds)
}
return db.Delete(&ResourceSharer{}).Error
}
func DeleteResourceSharerByResourceId(resourceId string) error {
return global.DB.Where("resource_id = ?", resourceId).Delete(&ResourceSharer{}).Error
}
func AddSharerResources(userGroupId, userId, resourceType string, resourceIds []string) error {
return global.DB.Transaction(func(tx *gorm.DB) (err error) {
for i := range resourceIds {
resourceId := resourceIds[i]
var owner string
// 检查资产是否存在
switch resourceType {
case "asset":
resource := Asset{}
if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil {
return errors.Wrap(err, "find asset fail")
}
owner = resource.Owner
case "command":
resource := Command{}
if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil {
return errors.Wrap(err, "find command fail")
}
owner = resource.Owner
case "credential":
resource := Credential{}
if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil {
return errors.Wrap(err, "find credential fail")
}
owner = resource.Owner
}
if owner == userId {
return echo.NewHTTPError(400, "参数错误")
}
id := utils.Sign([]string{resourceId, resourceType, userId, userGroupId})
resource := &ResourceSharer{
ID: id,
ResourceId: resourceId,
ResourceType: resourceType,
UserId: userId,
UserGroupId: userGroupId,
}
err = tx.Create(resource).Error
if err != nil {
return err
}
}
return nil
})
}
func FindAssetIdsByUserId(userId string) (assetIds []string, err error) {
// 查询当前用户创建的资产
var ownerAssetIds, sharerAssetIds []string
asset := Asset{}
err = global.DB.Table(asset.TableName()).Select("id").Where("owner = ?", userId).Find(&ownerAssetIds).Error
if err != nil {
return nil, err
}
// 查询其他用户授权给该用户的资产
groupIds, err := FindUserGroupIdsByUserId(userId)
if err != nil {
return nil, err
}
db := global.DB.Table("resource_sharers").Select("resource_id").Where("user_id = ?", userId)
if len(groupIds) > 0 {
db = db.Or("user_group_id in ?", groupIds)
}
err = db.Find(&sharerAssetIds).Error
if err != nil {
return nil, err
}
// 合并查询到的资产ID
assetIds = make([]string, 0)
if ownerAssetIds != nil {
assetIds = append(assetIds, ownerAssetIds...)
}
if sharerAssetIds != nil {
assetIds = append(assetIds, sharerAssetIds...)
}
return
}

210
server/model/session.go Normal file
View File

@ -0,0 +1,210 @@
package model
import (
"os"
"path"
"time"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/utils"
)
type Session struct {
ID string `gorm:"primary_key" json:"id"`
Protocol string `json:"protocol"`
IP string `json:"ip"`
Port int `json:"port"`
ConnectionId string `json:"connectionId"`
AssetId string `gorm:"index" json:"assetId"`
Username string `json:"username"`
Password string `json:"password"`
Creator string `gorm:"index" json:"creator"`
ClientIP string `json:"clientIp"`
Width int `json:"width"`
Height int `json:"height"`
Status string `gorm:"index" json:"status"`
Recording string `json:"recording"`
PrivateKey string `json:"privateKey"`
Passphrase string `json:"passphrase"`
Code int `json:"code"`
Message string `json:"message"`
ConnectedTime utils.JsonTime `json:"connectedTime"`
DisconnectedTime utils.JsonTime `json:"disconnectedTime"`
Mode string `json:"mode"`
}
func (r *Session) TableName() string {
return "sessions"
}
type SessionVo struct {
ID string `json:"id"`
Protocol string `json:"protocol"`
IP string `json:"ip"`
Port int `json:"port"`
Username string `json:"username"`
ConnectionId string `json:"connectionId"`
AssetId string `json:"assetId"`
Creator string `json:"creator"`
ClientIP string `json:"clientIp"`
Width int `json:"width"`
Height int `json:"height"`
Status string `json:"status"`
Recording string `json:"recording"`
ConnectedTime utils.JsonTime `json:"connectedTime"`
DisconnectedTime utils.JsonTime `json:"disconnectedTime"`
AssetName string `json:"assetName"`
CreatorName string `json:"creatorName"`
Code int `json:"code"`
Message string `json:"message"`
Mode string `json:"mode"`
}
func FindPageSession(pageIndex, pageSize int, status, userId, clientIp, assetId, protocol string) (results []SessionVo, total int64, err error) {
db := global.DB
var params []interface{}
params = append(params, status)
itemSql := "SELECT s.id,s.mode, s.protocol,s.recording, s.connection_id, s.asset_id, s.creator, s.client_ip, s.width, s.height, s.ip, s.port, s.username, s.status, s.connected_time, s.disconnected_time,s.code, s.message, a.name AS asset_name, u.nickname AS creator_name FROM sessions s LEFT JOIN assets a ON s.asset_id = a.id LEFT JOIN users u ON s.creator = u.id WHERE s.STATUS = ? "
countSql := "select count(*) from sessions as s where s.status = ? "
if len(userId) > 0 {
itemSql += " and s.creator = ?"
countSql += " and s.creator = ?"
params = append(params, userId)
}
if len(clientIp) > 0 {
itemSql += " and s.client_ip like ?"
countSql += " and s.client_ip like ?"
params = append(params, "%"+clientIp+"%")
}
if len(assetId) > 0 {
itemSql += " and s.asset_id = ?"
countSql += " and s.asset_id = ?"
params = append(params, assetId)
}
if len(protocol) > 0 {
itemSql += " and s.protocol = ?"
countSql += " and s.protocol = ?"
params = append(params, protocol)
}
params = append(params, (pageIndex-1)*pageSize, pageSize)
itemSql += " order by s.connected_time desc LIMIT ?, ?"
db.Raw(countSql, params...).Scan(&total)
err = db.Raw(itemSql, params...).Scan(&results).Error
if results == nil {
results = make([]SessionVo, 0)
}
return
}
func FindSessionByStatus(status string) (o []Session, err error) {
err = global.DB.Where("status = ?", status).Find(&o).Error
return
}
func FindSessionByStatusIn(statuses []string) (o []Session, err error) {
err = global.DB.Where("status in ?", statuses).Find(&o).Error
return
}
func FindOutTimeSessions(dayLimit int) (o []Session, err error) {
limitTime := time.Now().Add(time.Duration(-dayLimit*24) * time.Hour)
err = global.DB.Where("status = ? and connected_time < ?", constant.Disconnected, limitTime).Find(&o).Error
return
}
func CreateNewSession(o *Session) (err error) {
err = global.DB.Create(o).Error
return
}
func FindSessionById(id string) (o Session, err error) {
err = global.DB.Where("id = ?", id).First(&o).Error
return
}
func FindSessionByConnectionId(connectionId string) (o Session, err error) {
err = global.DB.Where("connection_id = ?", connectionId).First(&o).Error
return
}
func UpdateSessionById(o *Session, id string) error {
o.ID = id
return global.DB.Updates(o).Error
}
func UpdateSessionWindowSizeById(width, height int, id string) error {
session := Session{}
session.Width = width
session.Height = height
return UpdateSessionById(&session, id)
}
func DeleteSessionById(id string) error {
return global.DB.Where("id = ?", id).Delete(&Session{}).Error
}
func DeleteSessionByIds(sessionIds []string) error {
drivePath, err := GetRecordingPath()
if err != nil {
return err
}
for i := range sessionIds {
if err := os.RemoveAll(path.Join(drivePath, sessionIds[i])); err != nil {
return err
}
if err := DeleteSessionById(sessionIds[i]); err != nil {
return err
}
}
return nil
}
func DeleteSessionByStatus(status string) {
global.DB.Where("status = ?", status).Delete(&Session{})
}
func CountOnlineSession() (total int64, err error) {
err = global.DB.Where("status = ?", constant.Connected).Find(&Session{}).Count(&total).Error
return
}
type D struct {
Day string `json:"day"`
Count int `json:"count"`
Protocol string `json:"protocol"`
}
func CountSessionByDay(day int) (results []D, err error) {
today := time.Now().Format("20060102")
sql := "select t1.`day`, count(t2.id) as count\nfrom (\n SELECT @date := DATE_ADD(@date, INTERVAL - 1 DAY) day\n FROM (SELECT @date := DATE_ADD('" + today + "', INTERVAL + 1 DAY) FROM nums) as t0\n LIMIT ?\n )\n as t1\n left join\n (\n select DATE(s.connected_time) as day, s.id\n from sessions as s\n WHERE protocol = ? and DATE(connected_time) <= '" + today + "'\n AND DATE(connected_time) > DATE_SUB('" + today + "', INTERVAL ? DAY)\n ) as t2 on t1.day = t2.day\ngroup by t1.day"
protocols := []string{"rdp", "ssh", "vnc", "telnet"}
for i := range protocols {
var result []D
err = global.DB.Raw(sql, day, protocols[i], day).Scan(&result).Error
if err != nil {
return nil, err
}
for j := range result {
result[j].Protocol = protocols[i]
}
results = append(results, result...)
}
return
}

View File

@ -0,0 +1,23 @@
package model
import "next-terminal/server/global"
type UserAttribute struct {
Id string `gorm:"index" json:"id"`
UserId string `gorm:"index" json:"userId"`
Name string `gorm:"index" json:"name"`
Value string `json:"value"`
}
func (r *UserAttribute) TableName() string {
return "user_attributes"
}
func CreateUserAttribute(o *UserAttribute) error {
return global.DB.Create(o).Error
}
func FindUserAttributeByUserId(userId string) (o []UserAttribute, err error) {
err = global.DB.Where("user_id = ?", userId).Find(&o).Error
return o, err
}

View File

@ -0,0 +1,18 @@
package model
import "next-terminal/server/global"
type UserGroupMember struct {
ID string `gorm:"primary_key" json:"name"`
UserId string `gorm:"index" json:"userId"`
UserGroupId string `gorm:"index" json:"userGroupId"`
}
func (r *UserGroupMember) TableName() string {
return "user_group_members"
}
func FindUserGroupMembersByUserGroupId(id string) (o []string, err error) {
err = global.DB.Table("user_group_members").Select("user_id").Where("user_group_id = ?", id).Find(&o).Error
return
}

137
server/model/user-group.go Normal file
View File

@ -0,0 +1,137 @@
package model
import (
"next-terminal/server/global"
"next-terminal/server/utils"
"gorm.io/gorm"
)
type UserGroup struct {
ID string `gorm:"primary_key" json:"id"`
Name string `json:"name"`
Created utils.JsonTime `json:"created"`
}
type UserGroupVo struct {
ID string `json:"id"`
Name string `json:"name"`
Created utils.JsonTime `json:"created"`
AssetCount int64 `json:"assetCount"`
}
func (r *UserGroup) TableName() string {
return "user_groups"
}
func FindPageUserGroup(pageIndex, pageSize int, name, order, field string) (o []UserGroupVo, total int64, err error) {
db := global.DB.Table("user_groups").Select("user_groups.id, user_groups.name, user_groups.created, count(resource_sharers.user_group_id) as asset_count").Joins("left join resource_sharers on user_groups.id = resource_sharers.user_group_id and resource_sharers.resource_type = 'asset'").Group("user_groups.id")
dbCounter := global.DB.Table("user_groups")
if len(name) > 0 {
db = db.Where("user_groups.name like ?", "%"+name+"%")
dbCounter = dbCounter.Where("name like ?", "%"+name+"%")
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "name" {
field = "name"
} else {
field = "created"
}
err = db.Order("user_groups." + field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error
if o == nil {
o = make([]UserGroupVo, 0)
}
return
}
func CreateNewUserGroup(o *UserGroup, members []string) (err error) {
return global.DB.Transaction(func(tx *gorm.DB) error {
err = tx.Create(o).Error
if err != nil {
return err
}
if members != nil {
userGroupId := o.ID
err = AddUserGroupMembers(tx, members, userGroupId)
if err != nil {
return err
}
}
return err
})
}
func AddUserGroupMembers(tx *gorm.DB, userIds []string, userGroupId string) error {
//for i := range userIds {
// userId := userIds[i]
// // TODO
// _, err := FindUserById(userId)
// if err != nil {
// return err
// }
//
// userGroupMember := UserGroupMember{
// ID: utils.Sign([]string{userGroupId, userId}),
// UserId: userId,
// UserGroupId: userGroupId,
// }
// err = tx.Create(&userGroupMember).Error
// if err != nil {
// return err
// }
//}
return nil
}
func FindUserGroupById(id string) (o UserGroup, err error) {
err = global.DB.Where("id = ?", id).First(&o).Error
return
}
func FindUserGroupIdsByUserId(userId string) (o []string, err error) {
// 先查询用户所在的用户
err = global.DB.Table("user_group_members").Select("user_group_id").Where("user_id = ?", userId).Find(&o).Error
return
}
func UpdateUserGroupById(o *UserGroup, members []string, id string) error {
return global.DB.Transaction(func(tx *gorm.DB) error {
o.ID = id
err := tx.Updates(o).Error
if err != nil {
return err
}
err = tx.Where("user_group_id = ?", id).Delete(&UserGroupMember{}).Error
if err != nil {
return err
}
if members != nil {
userGroupId := o.ID
err = AddUserGroupMembers(tx, members, userGroupId)
if err != nil {
return err
}
}
return err
})
}
func DeleteUserGroupById(id string) {
global.DB.Where("id = ?", id).Delete(&UserGroup{})
global.DB.Where("user_group_id = ?", id).Delete(&UserGroupMember{})
}

35
server/model/user.go Normal file
View File

@ -0,0 +1,35 @@
package model
import (
"next-terminal/server/utils"
)
type User struct {
ID string `gorm:"primary_key" json:"id"`
Username string `gorm:"index" json:"username"`
Password string `json:"password"`
Nickname string `json:"nickname"`
TOTPSecret string `json:"-"`
Online bool `json:"online"`
Enabled bool `json:"enabled"`
Created utils.JsonTime `json:"created"`
Type string `json:"type"`
Mail string `json:"mail"`
}
type UserVo struct {
ID string `json:"id"`
Username string `json:"username"`
Nickname string `json:"nickname"`
TOTPSecret string `json:"totpSecret"`
Mail string `json:"mail"`
Online bool `json:"online"`
Enabled bool `json:"enabled"`
Created utils.JsonTime `json:"created"`
Type string `json:"type"`
SharerAssetCount int64 `json:"sharerAssetCount"`
}
func (r *User) TableName() string {
return "users"
}

123
server/repository/user.go Normal file
View File

@ -0,0 +1,123 @@
package repository
import (
"gorm.io/gorm"
"next-terminal/server/model"
)
type UserRepository struct {
DB *gorm.DB
}
func (r UserRepository) FindAll() (o []model.User) {
if r.DB.Find(&o).Error != nil {
return nil
}
return
}
func (r UserRepository) Find(pageIndex, pageSize int, username, nickname, mail, order, field string) (o []model.UserVo, total int64, err error) {
db := r.DB.Table("users").Select("users.id,users.username,users.nickname,users.mail,users.online,users.enabled,users.created,users.type, count(resource_sharers.user_id) as sharer_asset_count, users.totp_secret").Joins("left join resource_sharers on users.id = resource_sharers.user_id and resource_sharers.resource_type = 'asset'").Group("users.id")
dbCounter := r.DB.Table("users")
if len(username) > 0 {
db = db.Where("users.username like ?", "%"+username+"%")
dbCounter = dbCounter.Where("username like ?", "%"+username+"%")
}
if len(nickname) > 0 {
db = db.Where("users.nickname like ?", "%"+nickname+"%")
dbCounter = dbCounter.Where("nickname like ?", "%"+nickname+"%")
}
if len(mail) > 0 {
db = db.Where("users.mail like ?", "%"+mail+"%")
dbCounter = dbCounter.Where("mail like ?", "%"+mail+"%")
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "username" {
field = "username"
} else if field == "nickname" {
field = "nickname"
} else {
field = "created"
}
err = db.Order("users." + field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error
if o == nil {
o = make([]model.UserVo, 0)
}
for i := 0; i < len(o); i++ {
if o[i].TOTPSecret == "" || o[i].TOTPSecret == "-" {
o[i].TOTPSecret = "0"
} else {
o[i].TOTPSecret = "1"
}
}
return
}
func (r UserRepository) FindById(id string) (o model.User, err error) {
err = r.DB.Where("id = ?", id).First(&o).Error
return
}
func (r UserRepository) FindByUsername(username string) (o model.User, err error) {
err = r.DB.Where("username = ?", username).First(&o).Error
return
}
func (r UserRepository) FindOnlineUsers() (o []model.User, err error) {
err = r.DB.Where("online = ?", true).Find(&o).Error
return
}
func (r UserRepository) Create(o *model.User) error {
return r.DB.Create(o).Error
}
func (r UserRepository) Update(o *model.User) error {
return r.DB.Updates(o).Error
}
func (r UserRepository) UpdateOnline(id string, online bool) error {
sql := "update users set online = ? where id = ?"
return r.DB.Exec(sql, online, id).Error
}
func (r UserRepository) DeleteById(id string) error {
return r.DB.Transaction(func(tx *gorm.DB) (err error) {
// 删除用户
err = tx.Where("id = ?", id).Delete(&model.User{}).Error
if err != nil {
return err
}
// 删除用户组中的用户关系
err = tx.Where("user_id = ?", id).Delete(&model.UserGroupMember{}).Error
if err != nil {
return err
}
// 删除用户分享到的资产
err = tx.Where("user_id = ?", id).Delete(&model.ResourceSharer{}).Error
if err != nil {
return err
}
return nil
})
}
func (r UserRepository) CountOnlineUser() (total int64, err error) {
err = r.DB.Where("online = ?", true).Find(&model.User{}).Count(&total).Error
return
}

View File

@ -0,0 +1,107 @@
package term
import (
"io"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
type NextTerminal struct {
SshClient *ssh.Client
SshSession *ssh.Session
StdinPipe io.WriteCloser
SftpClient *sftp.Client
Recorder *Recorder
NextWriter *NextWriter
}
func NewNextTerminal(ip string, port int, username, password, privateKey, passphrase string, rows, cols int, recording string) (*NextTerminal, error) {
sshClient, err := NewSshClient(ip, port, username, password, privateKey, passphrase)
if err != nil {
return nil, err
}
sshSession, err := sshClient.NewSession()
if err != nil {
return nil, err
}
//defer sshSession.Close()
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
}
if err := sshSession.RequestPty("xterm-256color", rows, cols, modes); err != nil {
return nil, err
}
var nextWriter NextWriter
sshSession.Stdout = &nextWriter
sshSession.Stderr = &nextWriter
stdinPipe, err := sshSession.StdinPipe()
if err != nil {
return nil, err
}
if err := sshSession.Shell(); err != nil {
return nil, err
}
var recorder *Recorder
if recording != "" {
recorder, err = CreateRecording(recording, rows, cols)
if err != nil {
return nil, err
}
}
terminal := NextTerminal{
SshClient: sshClient,
SshSession: sshSession,
Recorder: recorder,
StdinPipe: stdinPipe,
NextWriter: &nextWriter,
}
return &terminal, nil
}
func (ret *NextTerminal) Write(p []byte) (int, error) {
return ret.StdinPipe.Write(p)
}
func (ret *NextTerminal) Read() ([]byte, int, error) {
bytes, n, err := ret.NextWriter.Read()
if err != nil {
return nil, 0, err
}
if ret.Recorder != nil && n > 0 {
_ = ret.Recorder.WriteData(string(bytes))
}
return bytes, n, nil
}
func (ret *NextTerminal) Close() error {
if ret.SshSession != nil {
return ret.SshSession.Close()
}
if ret.SshClient != nil {
return ret.SshClient.Close()
}
if ret.Recorder != nil {
return ret.Close()
}
return nil
}
func (ret *NextTerminal) WindowChange(h int, w int) error {
return ret.SshSession.WindowChange(h, w)
}

View File

@ -0,0 +1,30 @@
package term
import (
"bytes"
"sync"
)
type NextWriter struct {
b bytes.Buffer
mu sync.Mutex
}
func (w *NextWriter) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
return w.b.Write(p)
}
func (w *NextWriter) Read() ([]byte, int, error) {
w.mu.Lock()
defer w.mu.Unlock()
p := w.b.Bytes()
buf := make([]byte, len(p))
read, err := w.b.Read(buf)
w.b.Reset()
if err != nil {
return nil, 0, err
}
return buf, read, err
}

123
server/term/recording.go Normal file
View File

@ -0,0 +1,123 @@
package term
import (
"encoding/json"
"os"
"time"
"next-terminal/server/utils"
)
type Env struct {
Shell string `json:"SHELL"`
Term string `json:"TERM"`
}
type Header struct {
Title string `json:"title"`
Version int `json:"version"`
Height int `json:"height"`
Width int `json:"width"`
Env Env `json:"env"`
Timestamp int `json:"Timestamp"`
}
type Recorder struct {
File *os.File
Timestamp int
}
func NewRecorder(recoding string) (recorder *Recorder, err error) {
recorder = &Recorder{}
parentDirectory := utils.GetParentDirectory(recoding)
if utils.FileExists(parentDirectory) {
if err := os.RemoveAll(parentDirectory); err != nil {
return nil, err
}
}
if err = os.MkdirAll(parentDirectory, 0777); err != nil {
return
}
var file *os.File
file, err = os.Create(recoding)
if err != nil {
return nil, err
}
recorder.File = file
return recorder, nil
}
func (recorder *Recorder) Close() {
if recorder.File != nil {
recorder.File.Close()
}
}
func (recorder *Recorder) WriteHeader(header *Header) (err error) {
var p []byte
if p, err = json.Marshal(header); err != nil {
return
}
if _, err := recorder.File.Write(p); err != nil {
return err
}
if _, err := recorder.File.Write([]byte("\n")); err != nil {
return err
}
recorder.Timestamp = header.Timestamp
return
}
func (recorder *Recorder) WriteData(data string) (err error) {
now := int(time.Now().UnixNano())
delta := float64(now-recorder.Timestamp*1000*1000*1000) / 1000 / 1000 / 1000
row := make([]interface{}, 0)
row = append(row, delta)
row = append(row, "o")
row = append(row, data)
var s []byte
if s, err = json.Marshal(row); err != nil {
return
}
if _, err := recorder.File.Write(s); err != nil {
return err
}
if _, err := recorder.File.Write([]byte("\n")); err != nil {
return err
}
return
}
func CreateRecording(recordingPath string, h int, w int) (*Recorder, error) {
recorder, err := NewRecorder(recordingPath)
if err != nil {
return nil, err
}
header := &Header{
Title: "",
Version: 2,
Height: 42,
Width: 150,
Env: Env{Shell: "/bin/bash", Term: "xterm-256color"},
Timestamp: int(time.Now().Unix()),
}
if err := recorder.WriteHeader(header); err != nil {
return nil, err
}
return recorder, nil
}

54
server/term/ssh.go Normal file
View File

@ -0,0 +1,54 @@
package term
import (
"fmt"
"time"
"golang.org/x/crypto/ssh"
)
func NewSshClient(ip string, port int, username, password, privateKey, passphrase string) (*ssh.Client, error) {
var authMethod ssh.AuthMethod
if username == "-" || username == "" {
username = "root"
}
if password == "-" {
password = ""
}
if privateKey == "-" {
privateKey = ""
}
if passphrase == "-" {
passphrase = ""
}
var err error
if privateKey != "" {
var key ssh.Signer
if len(passphrase) > 0 {
key, err = ssh.ParsePrivateKeyWithPassphrase([]byte(privateKey), []byte(passphrase))
if err != nil {
return nil, err
}
} else {
key, err = ssh.ParsePrivateKey([]byte(privateKey))
if err != nil {
return nil, err
}
}
authMethod = ssh.PublicKeys(key)
} else {
authMethod = ssh.Password(password)
}
config := &ssh.ClientConfig{
Timeout: 1 * time.Second,
User: username,
Auth: []ssh.AuthMethod{authMethod},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
addr := fmt.Sprintf("%s:%d", ip, port)
return ssh.Dial("tcp", addr, config)
}

View File

@ -0,0 +1,175 @@
package main
import (
"fmt"
"io"
"os"
"time"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
)
type SSHTerminal struct {
Session *ssh.Session
exitMsg string
stdout io.Reader
stdin io.Writer
stderr io.Reader
}
func main() {
sshConfig := &ssh.ClientConfig{
User: "root",
Auth: []ssh.AuthMethod{
ssh.Password("root"),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
client, err := ssh.Dial("tcp", "172.16.101.32:22", sshConfig)
if err != nil {
logrus.Error(err)
}
defer client.Close()
err = New(client)
if err != nil {
fmt.Println(err)
}
}
func (t *SSHTerminal) updateTerminalSize() {
go func() {
// SIGWINCH is sent to the process when the window size of the terminal has
// changed.
sigwinchCh := make(chan os.Signal, 1)
//signal.Notify(sigwinchCh, syscall.SIN)
fd := int(os.Stdin.Fd())
termWidth, termHeight, err := terminal.GetSize(fd)
if err != nil {
fmt.Println(err)
}
for {
select {
// The client updated the size of the local PTY. This change needs to occur
// on the server side PTY as well.
case sigwinch := <-sigwinchCh:
if sigwinch == nil {
return
}
currTermWidth, currTermHeight, err := terminal.GetSize(fd)
// Terminal size has not changed, don't do anything.
if currTermHeight == termHeight && currTermWidth == termWidth {
continue
}
err = t.Session.WindowChange(currTermHeight, currTermWidth)
if err != nil {
fmt.Printf("Unable to send window-change reqest: %s.", err)
continue
}
termWidth, termHeight = currTermWidth, currTermHeight
}
}
}()
}
func (t *SSHTerminal) interactiveSession() error {
defer func() {
if t.exitMsg == "" {
logrus.Info(os.Stdout, "the connection was closed on the remote side on ", time.Now().Format(time.RFC822))
} else {
logrus.Info(os.Stdout, t.exitMsg)
}
}()
fd := int(os.Stdin.Fd())
state, err := terminal.MakeRaw(fd)
if err != nil {
return err
}
defer terminal.Restore(fd, state)
termWidth, termHeight, err := terminal.GetSize(fd)
if err != nil {
return err
}
termType := os.Getenv("TERM")
if termType == "" {
termType = "xterm-256color"
}
err = t.Session.RequestPty(termType, termHeight, termWidth, ssh.TerminalModes{})
if err != nil {
return err
}
t.updateTerminalSize()
t.stdin, err = t.Session.StdinPipe()
if err != nil {
return err
}
t.stdout, err = t.Session.StdoutPipe()
if err != nil {
return err
}
t.stderr, err = t.Session.StderrPipe()
go io.Copy(os.Stderr, t.stderr)
go io.Copy(os.Stdout, t.stdout)
go func() {
buf := make([]byte, 128)
for {
n, err := os.Stdin.Read(buf)
if err != nil {
fmt.Println(err)
return
}
if n > 0 {
_, err = t.stdin.Write(buf[:n])
if err != nil {
fmt.Println(err)
t.exitMsg = err.Error()
return
}
}
}
}()
err = t.Session.Shell()
if err != nil {
return err
}
err = t.Session.Wait()
if err != nil {
return err
}
return nil
}
func New(client *ssh.Client) error {
session, err := client.NewSession()
if err != nil {
return err
}
defer session.Close()
s := SSHTerminal{
Session: session,
}
return s.interactiveSession()
}

19
server/totp/totp.go Normal file
View File

@ -0,0 +1,19 @@
package totp
import (
otp_t "github.com/pquerna/otp"
totp_t "github.com/pquerna/otp/totp"
)
type GenerateOpts totp_t.GenerateOpts
func NewTOTP(opt GenerateOpts) (*otp_t.Key, error) {
return totp_t.Generate(totp_t.GenerateOpts(opt))
}
func Validate(code string, secret string) bool {
if secret == "" {
return true
}
return totp_t.Validate(code, secret)
}

215
server/utils/utils.go Normal file
View File

@ -0,0 +1,215 @@
package utils
import (
"bytes"
"crypto/md5"
"database/sql/driver"
"encoding/base64"
"fmt"
"image"
"image/png"
"net"
"os"
"path/filepath"
"reflect"
"sort"
"strconv"
"strings"
"time"
"github.com/gofrs/uuid"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
)
type JsonTime struct {
time.Time
}
func NewJsonTime(t time.Time) JsonTime {
return JsonTime{
Time: t,
}
}
func NowJsonTime() JsonTime {
return JsonTime{
Time: time.Now(),
}
}
func (t JsonTime) MarshalJSON() ([]byte, error) {
var stamp = fmt.Sprintf("\"%s\"", t.Format("2006-01-02 15:04:05"))
return []byte(stamp), nil
}
func (t JsonTime) Value() (driver.Value, error) {
var zeroTime time.Time
if t.Time.UnixNano() == zeroTime.UnixNano() {
return nil, nil
}
return t.Time, nil
}
func (t *JsonTime) Scan(v interface{}) error {
value, ok := v.(time.Time)
if ok {
*t = JsonTime{Time: value}
return nil
}
return fmt.Errorf("can not convert %v to timestamp", v)
}
type Bcrypt struct {
cost int
}
func (b *Bcrypt) Encode(password []byte) ([]byte, error) {
return bcrypt.GenerateFromPassword(password, b.cost)
}
func (b *Bcrypt) Match(hashedPassword, password []byte) error {
return bcrypt.CompareHashAndPassword(hashedPassword, password)
}
var Encoder = Bcrypt{
cost: bcrypt.DefaultCost,
}
func UUID() string {
v4, _ := uuid.NewV4()
return v4.String()
}
func Tcping(ip string, port int) bool {
var conn net.Conn
var err error
if conn, err = net.DialTimeout("tcp", ip+":"+strconv.Itoa(port), 2*time.Second); err != nil {
return false
}
defer conn.Close()
return true
}
func ImageToBase64Encode(img image.Image) (string, error) {
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}
// 判断所给路径文件/文件夹是否存在
func FileExists(path string) bool {
_, err := os.Stat(path) //os.Stat获取文件信息
if err != nil {
return os.IsExist(err)
}
return true
}
// 判断所给路径是否为文件夹
func IsDir(path string) bool {
s, err := os.Stat(path)
if err != nil {
return false
}
return s.IsDir()
}
// 判断所给路径是否为文件
func IsFile(path string) bool {
return !IsDir(path)
}
func GetParentDirectory(directory string) string {
return filepath.Dir(directory)
}
// 去除重复元素
func Distinct(a []string) []string {
result := make([]string, 0, len(a))
temp := map[string]struct{}{}
for _, item := range a {
if _, ok := temp[item]; !ok {
temp[item] = struct{}{}
result = append(result, item)
}
}
return result
}
// 排序+拼接+摘要
func Sign(a []string) string {
sort.Strings(a)
data := []byte(strings.Join(a, ""))
has := md5.Sum(data)
return fmt.Sprintf("%x", has)
}
func Contains(s []string, str string) bool {
for _, v := range s {
if v == str {
return true
}
}
return false
}
func StructToMap(obj interface{}) map[string]interface{} {
t := reflect.TypeOf(obj)
v := reflect.ValueOf(obj)
if t.Kind() == reflect.Ptr {
// 如果是指针,则获取其所指向的元素
t = t.Elem()
v = v.Elem()
}
var data = make(map[string]interface{})
if t.Kind() == reflect.Struct {
// 只有结构体可以获取其字段信息
for i := 0; i < t.NumField(); i++ {
jsonName := t.Field(i).Tag.Get("json")
if jsonName != "" {
data[jsonName] = v.Field(i).Interface()
} else {
data[t.Field(i).Name] = v.Field(i).Interface()
}
}
}
return data
}
func IpToInt(ip string) int64 {
if len(ip) == 0 {
return 0
}
bits := strings.Split(ip, ".")
if len(bits) < 4 {
return 0
}
b0 := StringToInt(bits[0])
b1 := StringToInt(bits[1])
b2 := StringToInt(bits[2])
b3 := StringToInt(bits[3])
var sum int64
sum += int64(b0) << 24
sum += int64(b1) << 16
sum += int64(b2) << 8
sum += int64(b3)
return sum
}
func StringToInt(in string) (out int) {
out, _ = strconv.Atoi(in)
return
}
func Check(f func() error) {
if err := f(); err != nil {
logrus.Error("Received error:", err)
}
}