release v1.2.0

This commit is contained in:
dushixiang
2021-10-31 17:15:35 +08:00
parent 4665ab6f78
commit 6132a05786
173 changed files with 37928 additions and 9349 deletions

View File

@ -0,0 +1,119 @@
package api
import (
"strconv"
"strings"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)
func AccessGatewayCreateEndpoint(c echo.Context) error {
var item model.AccessGateway
if err := c.Bind(&item); err != nil {
return err
}
item.ID = utils.UUID()
item.Created = utils.NowJsonTime()
if err := accessGatewayRepository.Create(&item); err != nil {
return err
}
// 连接网关
accessGatewayService.ReConnect(&item)
return Success(c, "")
}
func AccessGatewayAllEndpoint(c echo.Context) error {
gateways, err := accessGatewayRepository.FindAll()
if err != nil {
return err
}
var simpleGateways = make([]model.AccessGatewayForPage, 0)
for i := 0; i < len(gateways); i++ {
simpleGateways = append(simpleGateways, model.AccessGatewayForPage{ID: gateways[i].ID, Name: gateways[i].Name})
}
return Success(c, simpleGateways)
}
func AccessGatewayPagingEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))
ip := c.QueryParam("ip")
name := c.QueryParam("name")
order := c.QueryParam("order")
field := c.QueryParam("field")
items, total, err := accessGatewayRepository.Find(pageIndex, pageSize, ip, name, order, field)
if err != nil {
return err
}
for i := 0; i < len(items); i++ {
g, err := accessGatewayService.GetGatewayById(items[i].ID)
if err != nil {
return err
}
items[i].Connected = g.Connected
items[i].Message = g.Message
}
return Success(c, H{
"total": total,
"items": items,
})
}
func AccessGatewayUpdateEndpoint(c echo.Context) error {
id := c.Param("id")
var item model.AccessGateway
if err := c.Bind(&item); err != nil {
return err
}
if err := accessGatewayRepository.UpdateById(&item, id); err != nil {
return err
}
accessGatewayService.DisconnectById(id)
_, _ = accessGatewayService.GetGatewayAndReconnectById(id)
return Success(c, nil)
}
func AccessGatewayDeleteEndpoint(c echo.Context) error {
ids := c.Param("id")
split := strings.Split(ids, ",")
for i := range split {
id := split[i]
if err := accessGatewayRepository.DeleteById(id); err != nil {
return err
}
accessGatewayService.DisconnectById(id)
}
return Success(c, nil)
}
func AccessGatewayGetEndpoint(c echo.Context) error {
id := c.Param("id")
item, err := accessGatewayRepository.FindById(id)
if err != nil {
return err
}
return Success(c, item)
}
func AccessGatewayReconnectEndpoint(c echo.Context) error {
id := c.Param("id")
item, err := accessGatewayRepository.FindById(id)
if err != nil {
return err
}
accessGatewayService.ReConnect(&item)
return Success(c, "")
}

View File

@ -1,12 +1,15 @@
package api
import (
"path"
"strconv"
"strings"
"time"
"next-terminal/pkg/global"
"next-terminal/pkg/totp"
"next-terminal/server/config"
"next-terminal/server/global/cache"
"next-terminal/server/model"
"next-terminal/server/totp"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
@ -40,11 +43,6 @@ type Authorization struct {
User model.User
}
//
//type UserServer struct {
// repository.UserRepository
//}
func LoginEndpoint(c echo.Context) error {
var loginAccount LoginAccount
if err := c.Bind(&loginAccount); err != nil {
@ -54,25 +52,33 @@ func LoginEndpoint(c echo.Context) error {
user, err := userRepository.FindByUsername(loginAccount.Username)
// 存储登录失败次数信息
loginFailCountKey := loginAccount.Username
v, ok := global.Cache.Get(loginFailCountKey)
loginFailCountKey := c.RealIP() + loginAccount.Username
v, ok := cache.GlobalCache.Get(loginFailCountKey)
if !ok {
v = 1
}
count := v.(int)
if count >= 5 {
return Fail(c, -1, "登录失败次数过多,请后再试")
return Fail(c, -1, "登录失败次数过多,请等待5分钟后再试")
}
if err != nil {
count++
global.Cache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
cache.GlobalCache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
// 保存登录日志
if err := SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "账号或密码不正确"); err != nil {
return err
}
return FailWithData(c, -1, "您输入的账号或密码不正确", count)
}
if err := utils.Encoder.Match([]byte(user.Password), []byte(loginAccount.Password)); err != nil {
count++
global.Cache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
cache.GlobalCache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
// 保存登录日志
if err := SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "账号或密码不正确"); err != nil {
return err
}
return FailWithData(c, -1, "您输入的账号或密码不正确", count)
}
@ -80,15 +86,42 @@ func LoginEndpoint(c echo.Context) error {
return Fail(c, 0, "")
}
token, err := LoginSuccess(c, loginAccount, user)
token, err := LoginSuccess(loginAccount, user)
if err != nil {
return err
}
// 保存登录日志
if err := SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, true, loginAccount.Remember, token, ""); err != nil {
return err
}
return Success(c, token)
}
func LoginSuccess(c echo.Context, loginAccount LoginAccount, user model.User) (token string, err error) {
func SaveLoginLog(clientIP, clientUserAgent string, username string, success, remember bool, id, reason string) error {
loginLog := model.LoginLog{
Username: username,
ClientIP: clientIP,
ClientUserAgent: clientUserAgent,
LoginTime: utils.NowJsonTime(),
Reason: reason,
Remember: remember,
}
if success {
loginLog.State = "1"
loginLog.ID = id
} else {
loginLog.State = "0"
loginLog.ID = utils.UUID()
}
if err := loginLogRepository.Create(&loginLog); err != nil {
return err
}
return nil
}
func LoginSuccess(loginAccount LoginAccount, user model.User) (token string, err error) {
token = strings.Join([]string{utils.UUID(), utils.UUID(), utils.UUID(), utils.UUID()}, "")
authorization := Authorization{
@ -97,45 +130,20 @@ func LoginSuccess(c echo.Context, loginAccount LoginAccount, user model.User) (t
User: user,
}
cacheKey := BuildCacheKeyByToken(token)
cacheKey := userService.BuildCacheKeyByToken(token)
if authorization.Remember {
// 记住登录有效期两周
global.Cache.Set(cacheKey, authorization, RememberEffectiveTime)
cache.GlobalCache.Set(cacheKey, authorization, RememberEffectiveTime)
} else {
global.Cache.Set(cacheKey, authorization, NotRememberEffectiveTime)
}
// 保存登录日志
loginLog := model.LoginLog{
ID: token,
UserId: user.ID,
ClientIP: c.RealIP(),
ClientUserAgent: c.Request().UserAgent(),
LoginTime: utils.NowJsonTime(),
Remember: authorization.Remember,
}
if loginLogRepository.Create(&loginLog) != nil {
return "", err
cache.GlobalCache.Set(cacheKey, authorization, NotRememberEffectiveTime)
}
// 修改登录状态
err = userRepository.Update(&model.User{Online: true, ID: user.ID})
return token, err
}
func BuildCacheKeyByToken(token string) string {
cacheKey := strings.Join([]string{Token, token}, ":")
return cacheKey
}
func GetTokenFormCacheKey(cacheKey string) string {
token := strings.Split(cacheKey, ":")[1]
return token
}
func loginWithTotpEndpoint(c echo.Context) error {
var loginAccount LoginAccount
if err := c.Bind(&loginAccount); err != nil {
@ -143,47 +151,63 @@ func loginWithTotpEndpoint(c echo.Context) error {
}
// 存储登录失败次数信息
loginFailCountKey := loginAccount.Username
v, ok := global.Cache.Get(loginFailCountKey)
loginFailCountKey := c.RealIP() + loginAccount.Username
v, ok := cache.GlobalCache.Get(loginFailCountKey)
if !ok {
v = 1
}
count := v.(int)
if count >= 5 {
return Fail(c, -1, "登录失败次数过多,请后再试")
return Fail(c, -1, "登录失败次数过多,请等待5分钟后再试")
}
user, err := userRepository.FindByUsername(loginAccount.Username)
if err != nil {
count++
global.Cache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
cache.GlobalCache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
// 保存登录日志
if err := SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "账号或密码不正确"); err != nil {
return err
}
return FailWithData(c, -1, "您输入的账号或密码不正确", count)
}
if err := utils.Encoder.Match([]byte(user.Password), []byte(loginAccount.Password)); err != nil {
count++
global.Cache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
cache.GlobalCache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
// 保存登录日志
if err := SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "账号或密码不正确"); err != nil {
return err
}
return FailWithData(c, -1, "您输入的账号或密码不正确", count)
}
if !totp.Validate(loginAccount.TOTP, user.TOTPSecret) {
count++
global.Cache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
cache.GlobalCache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
// 保存登录日志
if err := SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "双因素认证授权码不正确"); err != nil {
return err
}
return FailWithData(c, -1, "您输入双因素认证授权码不正确", count)
}
token, err := LoginSuccess(c, loginAccount, user)
token, err := LoginSuccess(loginAccount, user)
if err != nil {
return err
}
// 保存登录日志
if err := SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, true, loginAccount.Remember, token, ""); err != nil {
return err
}
return Success(c, token)
}
func LogoutEndpoint(c echo.Context) error {
token := GetToken(c)
cacheKey := BuildCacheKeyByToken(token)
global.Cache.Delete(cacheKey)
cacheKey := userService.BuildCacheKeyByToken(token)
cache.GlobalCache.Delete(cacheKey)
err := userService.Logout(token)
if err != nil {
return err
@ -192,7 +216,7 @@ func LogoutEndpoint(c echo.Context) error {
}
func ConfirmTOTPEndpoint(c echo.Context) error {
if global.Config.Demo {
if config.GlobalCfg.Demo {
return Fail(c, 0, "演示模式禁止开启两步验证")
}
account, _ := GetCurrentAccount(c)
@ -258,7 +282,7 @@ func ResetTOTPEndpoint(c echo.Context) error {
}
func ChangePasswordEndpoint(c echo.Context) error {
if global.Config.Demo {
if config.GlobalCfg.Demo {
return Fail(c, 0, "演示模式禁止修改密码")
}
account, _ := GetCurrentAccount(c)
@ -313,3 +337,48 @@ func InfoEndpoint(c echo.Context) error {
}
return Success(c, info)
}
func AccountAssetEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))
name := c.QueryParam("name")
protocol := c.QueryParam("protocol")
tags := c.QueryParam("tags")
owner := c.QueryParam("owner")
sharer := c.QueryParam("sharer")
userGroupId := c.QueryParam("userGroupId")
ip := c.QueryParam("ip")
order := c.QueryParam("order")
field := c.QueryParam("field")
account, _ := GetCurrentAccount(c)
items, total, err := assetRepository.Find(pageIndex, pageSize, name, protocol, tags, account, owner, sharer, userGroupId, ip, order, field)
if err != nil {
return err
}
return Success(c, H{
"total": total,
"items": items,
})
}
func AccountStorageEndpoint(c echo.Context) error {
account, _ := GetCurrentAccount(c)
storageId := account.ID
storage, err := storageRepository.FindById(storageId)
if err != nil {
return err
}
structMap := utils.StructToMap(storage)
drivePath := storageService.GetBaseDrivePath()
dirSize, err := utils.DirSize(path.Join(drivePath, storageId))
if err != nil {
structMap["usedSize"] = -1
} else {
structMap["usedSize"] = dirSize
}
return Success(c, structMap)
}

View File

@ -1,8 +1,8 @@
package api
import (
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/server/constant"
"next-terminal/server/global/cache"
"next-terminal/server/model"
"github.com/labstack/echo/v4"
@ -41,17 +41,17 @@ func NotFound(c echo.Context, message string) error {
}
func GetToken(c echo.Context) string {
token := c.Request().Header.Get(Token)
token := c.Request().Header.Get(constant.Token)
if len(token) > 0 {
return token
}
return c.QueryParam(Token)
return c.QueryParam(constant.Token)
}
func GetCurrentAccount(c echo.Context) (model.User, bool) {
token := GetToken(c)
cacheKey := BuildCacheKeyByToken(token)
get, b := global.Cache.Get(cacheKey)
cacheKey := userService.BuildCacheKeyByToken(token)
get, b := cache.GlobalCache.Get(cacheKey)
if b {
return get.(Authorization).User, true
}

View File

@ -8,8 +8,8 @@ import (
"strconv"
"strings"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/server/config"
"next-terminal/server/constant"
"next-terminal/server/model"
"next-terminal/server/utils"
@ -32,6 +32,7 @@ func AssetCreateEndpoint(c echo.Context) error {
item.Owner = account.ID
item.ID = utils.UUID()
item.Created = utils.NowJsonTime()
item.Active = true
if err := assetRepository.Create(&item); err != nil {
return err
@ -41,10 +42,12 @@ func AssetCreateEndpoint(c echo.Context) error {
return err
}
// 创建后自动检测资产是否存活
go func() {
active := utils.Tcping(item.IP, item.Port)
_ = assetRepository.UpdateActiveById(active, item.ID)
active, _ := assetService.CheckStatus(item.AccessGatewayId, item.IP, item.Port)
if item.Active != active {
_ = assetRepository.UpdateActiveById(active, item.ID)
}
}()
return Success(c, item)
@ -74,7 +77,6 @@ func AssetImportEndpoint(c echo.Context) error {
if total == 0 {
return errors.New("csv数据为空")
}
var successCount = 0
var errorCount = 0
m := echo.Map{}
@ -97,6 +99,7 @@ func AssetImportEndpoint(c echo.Context) error {
Description: record[8],
Created: utils.NowJsonTime(),
Owner: account.ID,
Active: true,
}
if len(record) >= 10 {
@ -110,11 +113,6 @@ func AssetImportEndpoint(c echo.Context) error {
m[strconv.Itoa(i)] = err.Error()
} else {
successCount++
// 创建后自动检测资产是否存活
go func() {
active := utils.Tcping(asset.IP, asset.Port)
_ = assetRepository.UpdateActiveById(active, asset.ID)
}()
}
}
}
@ -141,6 +139,7 @@ func AssetPagingEndpoint(c echo.Context) error {
field := c.QueryParam("field")
account, _ := GetCurrentAccount(c)
items, total, err := assetRepository.Find(pageIndex, pageSize, name, protocol, tags, account, owner, sharer, userGroupId, ip, order, field)
if err != nil {
return err
@ -154,8 +153,7 @@ func AssetPagingEndpoint(c echo.Context) error {
func AssetAllEndpoint(c echo.Context) error {
protocol := c.QueryParam("protocol")
account, _ := GetCurrentAccount(c)
items, _ := assetRepository.FindByProtocolAndUser(protocol, account)
items, _ := assetRepository.FindByProtocol(protocol)
return Success(c, items)
}
@ -205,7 +203,7 @@ func AssetUpdateEndpoint(c echo.Context) error {
item.Description = "-"
}
if err := assetRepository.Encrypt(&item, global.Config.EncryptionPassword); err != nil {
if err := assetRepository.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil {
return err
}
if err := assetRepository.UpdateById(&item, id); err != nil {
@ -267,7 +265,7 @@ func AssetTcpingEndpoint(c echo.Context) (err error) {
return err
}
active := utils.Tcping(item.IP, item.Port)
active, err := assetService.CheckStatus(item.AccessGatewayId, item.IP, item.Port)
if item.Active != active {
if err := assetRepository.UpdateActiveById(active, item.ID); err != nil {
@ -275,7 +273,15 @@ func AssetTcpingEndpoint(c echo.Context) (err error) {
}
}
return Success(c, active)
var message = ""
if err != nil {
message = err.Error()
}
return Success(c, H{
"active": active,
"message": message,
})
}
func AssetTagsEndpoint(c echo.Context) (err error) {

View File

@ -29,6 +29,12 @@ func CommandCreateEndpoint(c echo.Context) error {
return Success(c, item)
}
func CommandAllEndpoint(c echo.Context) error {
account, _ := GetCurrentAccount(c)
items, _ := commandRepository.FindByUser(account)
return Success(c, items)
}
func CommandPagingEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))

View File

@ -6,8 +6,8 @@ import (
"strconv"
"strings"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/server/config"
"next-terminal/server/constant"
"next-terminal/server/model"
"next-terminal/server/utils"
@ -106,7 +106,7 @@ func CredentialUpdateEndpoint(c echo.Context) error {
item.Password = "-"
}
if item.Password != "-" {
encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Password), global.Config.EncryptionPassword)
encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Password), config.GlobalCfg.EncryptionPassword)
if err != nil {
return err
}
@ -121,7 +121,7 @@ func CredentialUpdateEndpoint(c echo.Context) error {
item.PrivateKey = "-"
}
if item.PrivateKey != "-" {
encryptedCBC, err := utils.AesEncryptCBC([]byte(item.PrivateKey), global.Config.EncryptionPassword)
encryptedCBC, err := utils.AesEncryptCBC([]byte(item.PrivateKey), config.GlobalCfg.EncryptionPassword)
if err != nil {
return err
}
@ -131,7 +131,7 @@ func CredentialUpdateEndpoint(c echo.Context) error {
item.Passphrase = "-"
}
if item.Passphrase != "-" {
encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Passphrase), global.Config.EncryptionPassword)
encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Passphrase), config.GlobalCfg.EncryptionPassword)
if err != nil {
return err
}

View File

@ -53,7 +53,7 @@ func JobUpdateEndpoint(c echo.Context) error {
return err
}
item.ID = id
if err := jobRepository.UpdateById(&item); err != nil {
if err := jobService.UpdateById(&item); err != nil {
return err
}
@ -83,7 +83,7 @@ func JobDeleteEndpoint(c echo.Context) error {
split := strings.Split(ids, ",")
for i := range split {
jobId := split[i]
if err := jobRepository.DeleteJobById(jobId); err != nil {
if err := jobService.DeleteJobById(jobId); err != nil {
return err
}
}

View File

@ -4,8 +4,8 @@ import (
"strconv"
"strings"
"next-terminal/pkg/global"
"next-terminal/pkg/log"
"next-terminal/server/global/cache"
"next-terminal/server/log"
"github.com/labstack/echo/v4"
)
@ -13,10 +13,11 @@ import (
func LoginLogPagingEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))
userId := c.QueryParam("userId")
username := c.QueryParam("username")
clientIp := c.QueryParam("clientIp")
state := c.QueryParam("state")
items, total, err := loginLogRepository.Find(pageIndex, pageSize, userId, clientIp)
items, total, err := loginLogRepository.Find(pageIndex, pageSize, username, clientIp, state)
if err != nil {
return err
@ -33,7 +34,7 @@ func LoginLogDeleteEndpoint(c echo.Context) error {
split := strings.Split(ids, ",")
for i := range split {
token := split[i]
global.Cache.Delete(token)
cache.GlobalCache.Delete(token)
if err := userService.Logout(token); err != nil {
log.WithError(err).Error("Cache Delete Failed")
}
@ -44,3 +45,10 @@ func LoginLogDeleteEndpoint(c echo.Context) error {
return Success(c, nil)
}
//func LoginLogClearEndpoint(c echo.Context) error {
// loginLogs, err := loginLogRepository.FindAliveLoginLogs()
// if err != nil {
// return err
// }
//}

View File

@ -3,12 +3,12 @@ package api
import (
"fmt"
"net"
"regexp"
"strings"
"time"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/server/constant"
"next-terminal/server/global/cache"
"next-terminal/server/global/security"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
@ -33,27 +33,26 @@ func ErrorHandler(next echo.HandlerFunc) echo.HandlerFunc {
func TcpWall(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if global.Securities == nil {
securities := security.GlobalSecurityManager.Values()
if len(securities) == 0 {
return next(c)
}
ip := c.RealIP()
for i := 0; i < len(global.Securities); i++ {
security := global.Securities[i]
if strings.Contains(security.IP, "/") {
for _, s := range securities {
if strings.Contains(s.IP, "/") {
// CIDR
_, ipNet, err := net.ParseCIDR(security.IP)
_, ipNet, err := net.ParseCIDR(s.IP)
if err != nil {
continue
}
if !ipNet.Contains(net.ParseIP(ip)) {
continue
}
} else if strings.Contains(security.IP, "-") {
} else if strings.Contains(s.IP, "-") {
// 范围段
split := strings.Split(security.IP, "-")
split := strings.Split(s.IP, "-")
if len(split) < 2 {
continue
}
@ -65,16 +64,16 @@ func TcpWall(next echo.HandlerFunc) echo.HandlerFunc {
}
} else {
// IP
if security.IP != ip {
if s.IP != ip {
continue
}
}
if security.Rule == constant.AccessRuleAllow {
if s.Rule == constant.AccessRuleAllow {
return next(c)
}
if security.Rule == constant.AccessRuleReject {
if c.Request().Header.Get("X-Requested-With") != "" || c.Request().Header.Get(Token) != "" {
if s.Rule == constant.AccessRuleReject {
if c.Request().Header.Get("X-Requested-With") != "" || c.Request().Header.Get(constant.Token) != "" {
return Fail(c, 0, "您的访问请求被拒绝 :(")
} else {
return c.HTML(666, "您的访问请求被拒绝 :(")
@ -88,10 +87,7 @@ func TcpWall(next echo.HandlerFunc) echo.HandlerFunc {
func Auth(next echo.HandlerFunc) echo.HandlerFunc {
startWithUrls := []string{"/login", "/static", "/favicon.ico", "/logo.svg", "/asciinema"}
download := regexp.MustCompile(`^/sessions/\w{8}(-\w{4}){3}-\w{12}/download`)
recording := regexp.MustCompile(`^/sessions/\w{8}(-\w{4}){3}-\w{12}/recording`)
anonymousUrls := []string{"/login", "/static", "/favicon.ico", "/logo.svg", "/asciinema"}
return func(c echo.Context) error {
@ -100,32 +96,27 @@ func Auth(next echo.HandlerFunc) echo.HandlerFunc {
return next(c)
}
// 路由拦截 - 登录身份、资源权限判断等
for i := range startWithUrls {
if strings.HasPrefix(uri, startWithUrls[i]) {
for i := range anonymousUrls {
if strings.HasPrefix(uri, anonymousUrls[i]) {
return next(c)
}
}
if download.FindString(uri) != "" {
return next(c)
}
if recording.FindString(uri) != "" {
return next(c)
}
token := GetToken(c)
cacheKey := BuildCacheKeyByToken(token)
authorization, found := global.Cache.Get(cacheKey)
if token == "" {
return Fail(c, 401, "您的登录信息已失效,请重新登录后再试。")
}
cacheKey := userService.BuildCacheKeyByToken(token)
authorization, found := cache.GlobalCache.Get(cacheKey)
if !found {
return Fail(c, 401, "您的登录信息已失效,请重新登录后再试。")
}
if authorization.(Authorization).Remember {
// 记住登录有效期两周
global.Cache.Set(cacheKey, authorization, time.Hour*time.Duration(24*14))
cache.GlobalCache.Set(cacheKey, authorization, time.Hour*time.Duration(24*14))
} else {
global.Cache.Set(cacheKey, authorization, time.Hour*time.Duration(2))
cache.GlobalCache.Set(cacheKey, authorization, time.Hour*time.Duration(2))
}
return next(c)

View File

@ -1,17 +0,0 @@
package api
import (
"github.com/labstack/echo/v4"
)
// todo 监控
func MonitorEndpoint(c echo.Context) (err error) {
//ws, err := UpGrader.Upgrade(c.Response().Writer, c.Request(), nil)
//if err != nil {
// log.Errorf("升级为WebSocket协议失败%v", err.Error())
// return err
//}
return
}

View File

@ -1,8 +1,7 @@
package api
import (
"next-terminal/pkg/constant"
"next-terminal/server/repository"
"next-terminal/server/constant"
"github.com/labstack/echo/v4"
)
@ -44,16 +43,43 @@ func OverviewCounterEndPoint(c echo.Context) error {
return Success(c, counter)
}
func OverviewSessionPoint(c echo.Context) (err error) {
d := c.QueryParam("d")
var results []repository.D
if d == "m" {
results, err = sessionRepository.CountSessionByDay(30)
func OverviewAssetEndPoint(c echo.Context) error {
account, _ := GetCurrentAccount(c)
var (
ssh int64
rdp int64
vnc int64
telnet int64
kubernetes int64
)
if constant.TypeUser == account.Type {
ssh, _ = assetRepository.CountByUserIdAndProtocol(account.ID, constant.SSH)
rdp, _ = assetRepository.CountByUserIdAndProtocol(account.ID, constant.RDP)
vnc, _ = assetRepository.CountByUserIdAndProtocol(account.ID, constant.VNC)
telnet, _ = assetRepository.CountByUserIdAndProtocol(account.ID, constant.Telnet)
kubernetes, _ = assetRepository.CountByUserIdAndProtocol(account.ID, constant.K8s)
} else {
results, err = sessionRepository.CountSessionByDay(7)
ssh, _ = assetRepository.CountByProtocol(constant.SSH)
rdp, _ = assetRepository.CountByProtocol(constant.RDP)
vnc, _ = assetRepository.CountByProtocol(constant.VNC)
telnet, _ = assetRepository.CountByProtocol(constant.Telnet)
kubernetes, _ = assetRepository.CountByProtocol(constant.K8s)
}
m := echo.Map{
"ssh": ssh,
"rdp": rdp,
"vnc": vnc,
"telnet": telnet,
"kubernetes": kubernetes,
}
return Success(c, m)
}
func OverviewAccessEndPoint(c echo.Context) error {
account, _ := GetCurrentAccount(c)
access, err := sessionRepository.OverviewAccess(account)
if err != nil {
return err
}
return Success(c, results)
return Success(c, access)
}

View File

@ -7,6 +7,7 @@ import (
type RU struct {
UserGroupId string `json:"userGroupId"`
UserId string `json:"userId"`
StrategyId string `json:"strategyId"`
ResourceType string `json:"resourceType"`
ResourceIds []string `json:"resourceIds"`
}
@ -19,26 +20,16 @@ type UR struct {
func RSGetSharersEndPoint(c echo.Context) error {
resourceId := c.QueryParam("resourceId")
userIds, err := resourceSharerRepository.FindUserIdsByResourceId(resourceId)
resourceType := c.QueryParam("resourceType")
userId := c.QueryParam("userId")
userGroupId := c.QueryParam("userGroupId")
userIds, err := resourceSharerRepository.Find(resourceId, resourceType, userId, userGroupId)
if err != nil {
return err
}
return Success(c, userIds)
}
func RSOverwriteSharersEndPoint(c echo.Context) error {
var ur UR
if err := c.Bind(&ur); err != nil {
return err
}
if err := resourceSharerRepository.OverwriteUserIdsByResourceId(ur.ResourceId, ur.ResourceType, ur.UserIds); err != nil {
return err
}
return Success(c, "")
}
func ResourceRemoveByUserIdAssignEndPoint(c echo.Context) error {
var ru RU
if err := c.Bind(&ru); err != nil {
@ -58,7 +49,7 @@ func ResourceAddByUserIdAssignEndPoint(c echo.Context) error {
return err
}
if err := resourceSharerRepository.AddSharerResources(ru.UserGroupId, ru.UserId, ru.ResourceType, ru.ResourceIds); err != nil {
if err := resourceSharerRepository.AddSharerResources(ru.UserGroupId, ru.UserId, ru.StrategyId, ru.ResourceType, ru.ResourceIds); err != nil {
return err
}

View File

@ -5,27 +5,23 @@ import (
"fmt"
"net/http"
"os"
"strings"
"time"
"next-terminal/pkg/global"
"next-terminal/pkg/log"
"next-terminal/pkg/service"
"next-terminal/server/config"
"next-terminal/server/global/cache"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/service"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/patrickmn/go-cache"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
const Token = "X-Auth-Token"
var (
userRepository *repository.UserRepository
userGroupRepository *repository.UserGroupRepository
@ -35,20 +31,23 @@ var (
propertyRepository *repository.PropertyRepository
commandRepository *repository.CommandRepository
sessionRepository *repository.SessionRepository
numRepository *repository.NumRepository
accessSecurityRepository *repository.AccessSecurityRepository
accessGatewayRepository *repository.AccessGatewayRepository
jobRepository *repository.JobRepository
jobLogRepository *repository.JobLogRepository
loginLogRepository *repository.LoginLogRepository
storageRepository *repository.StorageRepository
strategyRepository *repository.StrategyRepository
jobService *service.JobService
propertyService *service.PropertyService
userService *service.UserService
sessionService *service.SessionService
mailService *service.MailService
numService *service.NumService
assetService *service.AssetService
credentialService *service.CredentialService
jobService *service.JobService
propertyService *service.PropertyService
userService *service.UserService
sessionService *service.SessionService
mailService *service.MailService
assetService *service.AssetService
credentialService *service.CredentialService
storageService *service.StorageService
accessGatewayService *service.AccessGatewayService
)
func SetupRoutes(db *gorm.DB) *echo.Echo {
@ -56,8 +55,10 @@ func SetupRoutes(db *gorm.DB) *echo.Echo {
InitRepository(db)
InitService()
cache.GlobalCache.OnEvicted(userService.OnEvicted)
if err := InitDBData(); err != nil {
log.WithError(err).Error("初始化数据异常")
log.Errorf("初始化数据异常: %v", err.Error())
os.Exit(0)
}
@ -68,13 +69,10 @@ func SetupRoutes(db *gorm.DB) *echo.Echo {
e := echo.New()
e.HideBanner = true
//e.Logger = log.GetEchoLogger()
e.Use(log.Hook())
//e.Use(log.Hook())
e.File("/", "web/build/index.html")
e.File("/asciinema.html", "web/build/asciinema.html")
e.File("/asciinema-player.js", "web/build/asciinema-player.js")
e.File("/asciinema-player.css", "web/build/asciinema-player.css")
e.File("/", "web/build/index.html")
e.File("/logo.svg", "web/build/logo.svg")
e.File("/favicon.ico", "web/build/favicon.ico")
e.Static("/static", "web/build/static")
@ -93,7 +91,7 @@ func SetupRoutes(db *gorm.DB) *echo.Echo {
e.GET("/tunnel", TunEndpoint)
e.GET("/ssh", SSHEndpoint)
e.GET("/ssh-monitor", SshMonitor)
e.POST("/logout", LogoutEndpoint)
e.POST("/change-password", ChangePasswordEndpoint)
e.GET("/reload-totp", ReloadTOTPEndpoint)
@ -101,15 +99,21 @@ func SetupRoutes(db *gorm.DB) *echo.Echo {
e.POST("/confirm-totp", ConfirmTOTPEndpoint)
e.GET("/info", InfoEndpoint)
users := e.Group("/users")
account := e.Group("/account")
{
users.POST("", Admin(UserCreateEndpoint))
account.GET("/assets", AccountAssetEndpoint)
account.GET("/storage", AccountStorageEndpoint)
}
users := e.Group("/users", Admin)
{
users.POST("", UserCreateEndpoint)
users.GET("/paging", UserPagingEndpoint)
users.PUT("/:id", Admin(UserUpdateEndpoint))
users.DELETE("/:id", Admin(UserDeleteEndpoint))
users.GET("/:id", Admin(UserGetEndpoint))
users.POST("/:id/change-password", Admin(UserChangePasswordEndpoint))
users.POST("/:id/reset-totp", Admin(UserResetTotpEndpoint))
users.PUT("/:id", UserUpdateEndpoint)
users.DELETE("/:id", UserDeleteEndpoint)
users.GET("/:id", UserGetEndpoint)
users.POST("/:id/change-password", UserChangePasswordEndpoint)
users.POST("/:id/reset-totp", UserResetTotpEndpoint)
}
userGroups := e.Group("/user-groups", Admin)
@ -119,36 +123,35 @@ func SetupRoutes(db *gorm.DB) *echo.Echo {
userGroups.PUT("/:id", UserGroupUpdateEndpoint)
userGroups.DELETE("/:id", UserGroupDeleteEndpoint)
userGroups.GET("/:id", UserGroupGetEndpoint)
//userGroups.POST("/:id/members", UserGroupAddMembersEndpoint)
//userGroups.DELETE("/:id/members/:memberId", UserGroupDelMembersEndpoint)
}
assets := e.Group("/assets")
assets := e.Group("/assets", Admin)
{
assets.GET("", AssetAllEndpoint)
assets.POST("", AssetCreateEndpoint)
assets.POST("/import", Admin(AssetImportEndpoint))
assets.POST("/import", AssetImportEndpoint)
assets.GET("/paging", AssetPagingEndpoint)
assets.POST("/:id/tcping", AssetTcpingEndpoint)
assets.PUT("/:id", AssetUpdateEndpoint)
assets.DELETE("/:id", AssetDeleteEndpoint)
assets.GET("/:id", AssetGetEndpoint)
assets.POST("/:id/change-owner", Admin(AssetChangeOwnerEndpoint))
assets.DELETE("/:id", AssetDeleteEndpoint)
assets.POST("/:id/change-owner", AssetChangeOwnerEndpoint)
}
e.GET("/tags", AssetTagsEndpoint)
commands := e.Group("/commands")
{
commands.GET("", CommandAllEndpoint)
commands.GET("/paging", CommandPagingEndpoint)
commands.POST("", CommandCreateEndpoint)
commands.PUT("/:id", CommandUpdateEndpoint)
commands.DELETE("/:id", CommandDeleteEndpoint)
commands.GET("/:id", CommandGetEndpoint)
commands.POST("/:id/change-owner", Admin(CommandChangeOwnerEndpoint))
commands.POST("/:id/change-owner", CommandChangeOwnerEndpoint, Admin)
}
credentials := e.Group("/credentials")
credentials := e.Group("/credentials", Admin)
{
credentials.GET("", CredentialAllEndpoint)
credentials.GET("/paging", CredentialPagingEndpoint)
@ -156,45 +159,54 @@ func SetupRoutes(db *gorm.DB) *echo.Echo {
credentials.PUT("/:id", CredentialUpdateEndpoint)
credentials.DELETE("/:id", CredentialDeleteEndpoint)
credentials.GET("/:id", CredentialGetEndpoint)
credentials.POST("/:id/change-owner", Admin(CredentialChangeOwnerEndpoint))
credentials.POST("/:id/change-owner", CredentialChangeOwnerEndpoint)
}
sessions := e.Group("/sessions")
{
sessions.POST("", SessionCreateEndpoint)
sessions.GET("/paging", Admin(SessionPagingEndpoint))
sessions.POST("/:id/connect", SessionConnectEndpoint)
sessions.POST("/:id/disconnect", Admin(SessionDisconnectEndpoint))
sessions.DELETE("/:id", Admin(SessionDeleteEndpoint))
sessions.GET("/:id/recording", Admin(SessionRecordingEndpoint))
sessions.GET("/:id", Admin(SessionGetEndpoint))
sessions.POST("", SessionCreateEndpoint)
sessions.POST("/:id/connect", SessionConnectEndpoint)
sessions.POST("/:id/resize", SessionResizeEndpoint)
sessions.GET("/:id/ls", SessionLsEndpoint)
sessions.GET("/:id/stats", SessionStatsEndpoint)
sessions.POST("/:id/ls", SessionLsEndpoint)
sessions.GET("/:id/download", SessionDownloadEndpoint)
sessions.POST("/:id/upload", SessionUploadEndpoint)
sessions.POST("/:id/edit", SessionEditEndpoint)
sessions.POST("/:id/mkdir", SessionMkDirEndpoint)
sessions.POST("/:id/rm", SessionRmEndpoint)
sessions.POST("/:id/rename", SessionRenameEndpoint)
sessions.DELETE("/:id", Admin(SessionDeleteEndpoint))
sessions.GET("/:id/recording", SessionRecordingEndpoint)
}
resourceSharers := e.Group("/resource-sharers")
resourceSharers := e.Group("/resource-sharers", Admin)
{
resourceSharers.GET("/sharers", RSGetSharersEndPoint)
resourceSharers.POST("/overwrite-sharers", RSOverwriteSharersEndPoint)
resourceSharers.POST("/remove-resources", Admin(ResourceRemoveByUserIdAssignEndPoint))
resourceSharers.POST("/add-resources", Admin(ResourceAddByUserIdAssignEndPoint))
resourceSharers.GET("", RSGetSharersEndPoint)
resourceSharers.POST("/remove-resources", ResourceRemoveByUserIdAssignEndPoint)
resourceSharers.POST("/add-resources", ResourceAddByUserIdAssignEndPoint)
}
loginLogs := e.Group("login-logs", Admin)
{
loginLogs.GET("/paging", LoginLogPagingEndpoint)
loginLogs.DELETE("/:id", LoginLogDeleteEndpoint)
//loginLogs.DELETE("/clear", LoginLogClearEndpoint)
}
e.GET("/properties", Admin(PropertyGetEndpoint))
e.PUT("/properties", Admin(PropertyUpdateEndpoint))
e.GET("/overview/counter", OverviewCounterEndPoint)
e.GET("/overview/sessions", OverviewSessionPoint)
overview := e.Group("overview", Admin)
{
overview.GET("/counter", OverviewCounterEndPoint)
overview.GET("/asset", OverviewAssetEndPoint)
overview.GET("/access", OverviewAccessEndPoint)
}
jobs := e.Group("/jobs", Admin)
{
@ -218,6 +230,44 @@ func SetupRoutes(db *gorm.DB) *echo.Echo {
securities.GET("/:id", SecurityGetEndpoint)
}
storages := e.Group("/storages")
{
storages.GET("/paging", StoragePagingEndpoint, Admin)
storages.POST("", StorageCreateEndpoint, Admin)
storages.DELETE("/:id", StorageDeleteEndpoint, Admin)
storages.PUT("/:id", StorageUpdateEndpoint, Admin)
storages.GET("/shares", StorageSharesEndpoint, Admin)
storages.GET("/:id", StorageGetEndpoint, Admin)
storages.POST("/:storageId/ls", StorageLsEndpoint)
storages.GET("/:storageId/download", StorageDownloadEndpoint)
storages.POST("/:storageId/upload", StorageUploadEndpoint)
storages.POST("/:storageId/mkdir", StorageMkDirEndpoint)
storages.POST("/:storageId/rm", StorageRmEndpoint)
storages.POST("/:storageId/rename", StorageRenameEndpoint)
storages.POST("/:storageId/edit", StorageEditEndpoint)
}
strategies := e.Group("/strategies", Admin)
{
strategies.GET("", StrategyAllEndpoint)
strategies.GET("/paging", StrategyPagingEndpoint)
strategies.POST("", StrategyCreateEndpoint)
strategies.DELETE("/:id", StrategyDeleteEndpoint)
strategies.PUT("/:id", StrategyUpdateEndpoint)
}
accessGateways := e.Group("/access-gateways", Admin)
{
accessGateways.GET("", AccessGatewayAllEndpoint)
accessGateways.POST("", AccessGatewayCreateEndpoint)
accessGateways.GET("/paging", AccessGatewayPagingEndpoint)
accessGateways.PUT("/:id", AccessGatewayUpdateEndpoint)
accessGateways.DELETE("/:id", AccessGatewayDeleteEndpoint)
accessGateways.GET("/:id", AccessGatewayGetEndpoint)
accessGateways.POST("/:id/reconnect", AccessGatewayReconnectEndpoint)
}
return e
}
@ -241,29 +291,32 @@ func InitRepository(db *gorm.DB) {
propertyRepository = repository.NewPropertyRepository(db)
commandRepository = repository.NewCommandRepository(db)
sessionRepository = repository.NewSessionRepository(db)
numRepository = repository.NewNumRepository(db)
accessSecurityRepository = repository.NewAccessSecurityRepository(db)
accessGatewayRepository = repository.NewAccessGatewayRepository(db)
jobRepository = repository.NewJobRepository(db)
jobLogRepository = repository.NewJobLogRepository(db)
loginLogRepository = repository.NewLoginLogRepository(db)
storageRepository = repository.NewStorageRepository(db)
strategyRepository = repository.NewStrategyRepository(db)
}
func InitService() {
jobService = service.NewJobService(jobRepository, jobLogRepository, assetRepository, credentialRepository)
propertyService = service.NewPropertyService(propertyRepository)
userService = service.NewUserService(userRepository, loginLogRepository)
sessionService = service.NewSessionService(sessionRepository)
mailService = service.NewMailService(propertyRepository)
numService = service.NewNumService(numRepository)
assetService = service.NewAssetService(assetRepository)
jobService = service.NewJobService(jobRepository, jobLogRepository, assetRepository, credentialRepository, assetService)
credentialService = service.NewCredentialService(credentialRepository)
storageService = service.NewStorageService(storageRepository, userRepository, propertyRepository)
accessGatewayService = service.NewAccessGatewayService(accessGatewayRepository)
}
func InitDBData() (err error) {
if err := propertyService.InitProperties(); err != nil {
if err := propertyService.DeleteDeprecatedProperty(); err != nil {
return err
}
if err := numService.InitNums(); err != nil {
if err := propertyService.InitProperties(); err != nil {
return err
}
if err := userService.InitUser(); err != nil {
@ -287,6 +340,12 @@ func InitDBData() (err error) {
if err := assetService.Encrypt(); err != nil {
return err
}
if err := storageService.InitStorages(); err != nil {
return err
}
if err := accessGatewayService.ReConnectAll(); err != nil {
return err
}
return nil
}
@ -368,59 +427,46 @@ func ChangeEncryptionKey(oldEncryptionKey, newEncryptionKey string) error {
return nil
}
func SetupCache() *cache.Cache {
// 配置缓存器
mCache := cache.New(5*time.Minute, 10*time.Minute)
mCache.OnEvicted(func(key string, value interface{}) {
if strings.HasPrefix(key, Token) {
token := GetTokenFormCacheKey(key)
log.Debugf("用户Token「%v」过期", token)
err := userService.Logout(token)
if err != nil {
log.Errorf("退出登录失败 %v", err)
}
}
})
return mCache
}
func SetupDB() *gorm.DB {
var logMode logger.Interface
if global.Config.Debug {
if config.GlobalCfg.Debug {
logMode = logger.Default.LogMode(logger.Info)
} else {
logMode = logger.Default.LogMode(logger.Silent)
}
fmt.Printf("当前数据库模式为:%v\n", global.Config.DB)
fmt.Printf("当前数据库模式为:%v\n", config.GlobalCfg.DB)
var err error
var db *gorm.DB
if global.Config.DB == "mysql" {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
global.Config.Mysql.Username,
global.Config.Mysql.Password,
global.Config.Mysql.Hostname,
global.Config.Mysql.Port,
global.Config.Mysql.Database,
if config.GlobalCfg.DB == "mysql" {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=60s",
config.GlobalCfg.Mysql.Username,
config.GlobalCfg.Mysql.Password,
config.GlobalCfg.Mysql.Hostname,
config.GlobalCfg.Mysql.Port,
config.GlobalCfg.Mysql.Database,
)
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: logMode,
})
} else {
db, err = gorm.Open(sqlite.Open(global.Config.Sqlite.File), &gorm.Config{
db, err = gorm.Open(sqlite.Open(config.GlobalCfg.Sqlite.File), &gorm.Config{
Logger: logMode,
})
}
if err != nil {
log.WithError(err).Panic("连接数据库异常")
log.Errorf("连接数据库异常: %v", err.Error())
os.Exit(0)
}
if err := db.AutoMigrate(&model.User{}, &model.Asset{}, &model.AssetAttribute{}, &model.Session{}, &model.Command{},
&model.Credential{}, &model.Property{}, &model.ResourceSharer{}, &model.UserGroup{}, &model.UserGroupMember{},
&model.LoginLog{}, &model.Num{}, &model.Job{}, &model.JobLog{}, &model.AccessSecurity{}); err != nil {
log.WithError(err).Panic("初始化数据库表结构异常")
&model.LoginLog{}, &model.Job{}, &model.JobLog{}, &model.AccessSecurity{}, &model.AccessGateway{},
&model.Storage{}, &model.Strategy{}); err != nil {
log.Errorf("初始化数据库表结构异常: %v", err.Error())
os.Exit(0)
}
return db
}

View File

@ -4,7 +4,7 @@ import (
"strconv"
"strings"
"next-terminal/pkg/global"
"next-terminal/server/global/security"
"next-terminal/server/model"
"next-terminal/server/utils"
@ -24,9 +24,14 @@ func SecurityCreateEndpoint(c echo.Context) error {
return err
}
// 更新内存中的安全规则
if err := ReloadAccessSecurity(); err != nil {
return err
rule := &security.Security{
ID: item.ID,
IP: item.IP,
Rule: item.Rule,
Priority: item.Priority,
}
security.GlobalSecurityManager.Add <- rule
return Success(c, "")
}
@ -36,15 +41,18 @@ func ReloadAccessSecurity() error {
return err
}
if len(rules) > 0 {
var securities []*global.Security
// 先清空
security.GlobalSecurityManager.Clear()
// 再添加到全局的安全管理器中
for i := 0; i < len(rules); i++ {
rule := global.Security{
IP: rules[i].IP,
Rule: rules[i].Rule,
rule := &security.Security{
ID: rules[i].ID,
IP: rules[i].IP,
Rule: rules[i].Rule,
Priority: rules[i].Priority,
}
securities = append(securities, &rule)
security.GlobalSecurityManager.Add <- rule
}
global.Securities = securities
}
return nil
}
@ -81,9 +89,15 @@ func SecurityUpdateEndpoint(c echo.Context) error {
return err
}
// 更新内存中的安全规则
if err := ReloadAccessSecurity(); err != nil {
return err
security.GlobalSecurityManager.Del <- id
rule := &security.Security{
ID: item.ID,
IP: item.IP,
Rule: item.Rule,
Priority: item.Priority,
}
security.GlobalSecurityManager.Add <- rule
return Success(c, nil)
}
@ -92,15 +106,14 @@ func SecurityDeleteEndpoint(c echo.Context) error {
split := strings.Split(ids, ",")
for i := range split {
jobId := split[i]
if err := accessSecurityRepository.DeleteById(jobId); err != nil {
id := split[i]
if err := accessSecurityRepository.DeleteById(id); err != nil {
return err
}
// 更新内存中的安全规则
security.GlobalSecurityManager.Del <- id
}
// 更新内存中的安全规则
if err := ReloadAccessSecurity(); err != nil {
return err
}
return Success(c, nil)
}

View File

@ -1,11 +1,11 @@
package api
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path"
@ -13,14 +13,18 @@ import (
"strings"
"sync"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/log"
"next-terminal/server/constant"
"next-terminal/server/global/session"
"next-terminal/server/guacd"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/service"
"next-terminal/server/utils"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/pkg/sftp"
"gorm.io/gorm"
)
func SessionPagingEndpoint(c echo.Context) error {
@ -42,7 +46,7 @@ func SessionPagingEndpoint(c echo.Context) error {
if status == constant.Disconnected && len(items[i].Recording) > 0 {
var recording string
if items[i].Mode == constant.Naive {
if items[i].Mode == constant.Naive || items[i].Mode == constant.Terminal {
recording = items[i].Recording
} else {
recording = items[i].Recording + "/recording"
@ -78,14 +82,28 @@ func SessionDeleteEndpoint(c echo.Context) error {
func SessionConnectEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session := model.Session{}
session.ID = sessionId
session.Status = constant.Connected
session.ConnectedTime = utils.NowJsonTime()
s := model.Session{}
s.ID = sessionId
s.Status = constant.Connected
s.ConnectedTime = utils.NowJsonTime()
if err := sessionRepository.UpdateById(&session, sessionId); err != nil {
if err := sessionRepository.UpdateById(&s, sessionId); err != nil {
return err
}
o, err := sessionRepository.FindById(sessionId)
if err != nil {
return err
}
asset, err := assetRepository.FindById(o.AssetId)
if err != nil {
return err
}
if !asset.Active {
asset.Active = true
_ = assetRepository.UpdateById(&asset, asset.ID)
}
return Success(c, nil)
}
@ -104,18 +122,48 @@ var mutex sync.Mutex
func CloseSessionById(sessionId string, code int, reason string) {
mutex.Lock()
defer mutex.Unlock()
observable, _ := global.Store.Get(sessionId)
if observable != nil {
log.Debugf("会话%v创建者退出,原因:%v", sessionId, reason)
observable.Subject.Close(code, reason)
nextSession := session.GlobalSessionManager.GetById(sessionId)
if nextSession != nil {
log.Debugf("[%v] 会话关闭,原因:%v", sessionId, reason)
WriteCloseMessage(nextSession.WebSocket, nextSession.Mode, code, reason)
for i := 0; i < len(observable.Observers); i++ {
observable.Observers[i].Close(code, reason)
log.Debugf("强制踢出会话%v的观察者", sessionId)
if nextSession.Observer != nil {
obs := nextSession.Observer.All()
for _, ob := range obs {
WriteCloseMessage(ob.WebSocket, ob.Mode, code, reason)
log.Debugf("[%v] 强制踢出会话的观察者: %v", sessionId, ob.ID)
}
}
}
global.Store.Del(sessionId)
session.GlobalSessionManager.Del <- sessionId
DisDBSess(sessionId, code, reason)
}
func WriteCloseMessage(ws *websocket.Conn, mode string, code int, reason string) {
switch mode {
case constant.Guacd:
if ws != nil {
err := guacd.NewInstruction("error", "", strconv.Itoa(code))
_ = ws.WriteMessage(websocket.TextMessage, []byte(err.String()))
disconnect := guacd.NewInstruction("disconnect")
_ = ws.WriteMessage(websocket.TextMessage, []byte(disconnect.String()))
}
case constant.Naive:
if ws != nil {
msg := `0` + reason
_ = ws.WriteMessage(websocket.TextMessage, []byte(msg))
}
case constant.Terminal:
// 这里是关闭观察者的ssh会话
if ws != nil {
msg := `0` + reason
_ = ws.WriteMessage(websocket.TextMessage, []byte(msg))
}
}
}
func DisDBSess(sessionId string, code int, reason string) {
s, err := sessionRepository.FindById(sessionId)
if err != nil {
return
@ -131,17 +179,17 @@ func CloseSessionById(sessionId string, code int, reason string) {
return
}
session := model.Session{}
session.ID = sessionId
session.Status = constant.Disconnected
session.DisconnectedTime = utils.NowJsonTime()
session.Code = code
session.Message = reason
session.Password = "-"
session.PrivateKey = "-"
session.Passphrase = "-"
ss := model.Session{}
ss.ID = sessionId
ss.Status = constant.Disconnected
ss.DisconnectedTime = utils.NowJsonTime()
ss.Code = code
ss.Message = reason
ss.Password = "-"
ss.PrivateKey = "-"
ss.Passphrase = "-"
_ = sessionRepository.UpdateById(&session, sessionId)
_ = sessionRepository.UpdateById(&ss, sessionId)
}
func SessionResizeEndpoint(c echo.Context) error {
@ -150,11 +198,10 @@ func SessionResizeEndpoint(c echo.Context) error {
sessionId := c.Param("id")
if len(width) == 0 || len(height) == 0 {
panic("参数异常")
return errors.New("参数异常")
}
intWidth, _ := strconv.Atoi(width)
intHeight, _ := strconv.Atoi(height)
if err := sessionRepository.UpdateWindowSizeById(intWidth, intHeight, sessionId); err != nil {
@ -175,37 +222,83 @@ func SessionCreateEndpoint(c echo.Context) error {
user, _ := GetCurrentAccount(c)
if constant.TypeUser == user.Type {
// 检测是否有访问权限
assetIds, err := resourceSharerRepository.FindAssetIdsByUserId(user.ID)
if err != nil {
return err
}
if !utils.Contains(assetIds, assetId) {
return errors.New("您没有权限访问此资产")
}
}
asset, err := assetRepository.FindById(assetId)
if err != nil {
return err
}
session := &model.Session{
ID: utils.UUID(),
AssetId: asset.ID,
Username: asset.Username,
Password: asset.Password,
PrivateKey: asset.PrivateKey,
Passphrase: asset.Passphrase,
Protocol: asset.Protocol,
IP: asset.IP,
Port: asset.Port,
Status: constant.NoConnect,
Creator: user.ID,
ClientIP: c.RealIP(),
Mode: mode,
var (
upload = "1"
download = "1"
_delete = "1"
rename = "1"
edit = "1"
fileSystem = "1"
)
if asset.Owner != user.ID && constant.TypeUser == user.Type {
// 普通用户访问非自己创建的资产需要校验权限
resourceSharers, err := resourceSharerRepository.FindByResourceIdAndUserId(assetId, user.ID)
if err != nil {
return err
}
if len(resourceSharers) == 0 {
return errors.New("您没有权限访问此资产")
}
strategyId := resourceSharers[0].StrategyId
if strategyId != "" {
strategy, err := strategyRepository.FindById(strategyId)
if err != nil {
if !errors.Is(gorm.ErrRecordNotFound, err) {
return err
}
} else {
upload = strategy.Upload
download = strategy.Download
_delete = strategy.Delete
rename = strategy.Rename
edit = strategy.Edit
}
}
}
var storageId = ""
if constant.RDP == asset.Protocol {
attr, err := assetRepository.FindAssetAttrMapByAssetId(assetId)
if err != nil {
return err
}
if "true" == attr[guacd.EnableDrive] {
fileSystem = "1"
storageId = attr[guacd.DrivePath]
if storageId == "" {
storageId = user.ID
}
} else {
fileSystem = "0"
}
}
s := &model.Session{
ID: utils.UUID(),
AssetId: asset.ID,
Username: asset.Username,
Password: asset.Password,
PrivateKey: asset.PrivateKey,
Passphrase: asset.Passphrase,
Protocol: asset.Protocol,
IP: asset.IP,
Port: asset.Port,
Status: constant.NoConnect,
Creator: user.ID,
ClientIP: c.RealIP(),
Mode: mode,
Upload: upload,
Download: download,
Delete: _delete,
Rename: rename,
Edit: edit,
StorageId: storageId,
AccessGatewayId: asset.AccessGatewayId,
}
if asset.AccountType == "credential" {
@ -215,28 +308,41 @@ func SessionCreateEndpoint(c echo.Context) error {
}
if credential.Type == constant.Custom {
session.Username = credential.Username
session.Password = credential.Password
s.Username = credential.Username
s.Password = credential.Password
} else {
session.Username = credential.Username
session.PrivateKey = credential.PrivateKey
session.Passphrase = credential.Passphrase
s.Username = credential.Username
s.PrivateKey = credential.PrivateKey
s.Passphrase = credential.Passphrase
}
}
if err := sessionRepository.Create(session); err != nil {
if err := sessionRepository.Create(s); err != nil {
return err
}
return Success(c, echo.Map{"id": session.ID})
return Success(c, echo.Map{
"id": s.ID,
"upload": s.Upload,
"download": s.Download,
"delete": s.Delete,
"rename": s.Rename,
"edit": s.Edit,
"storageId": s.StorageId,
"fileSystem": fileSystem,
})
}
func SessionUploadEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session, err := sessionRepository.FindById(sessionId)
s, err := sessionRepository.FindById(sessionId)
if err != nil {
return err
}
if s.Upload != "1" {
return errors.New("禁止操作")
}
file, err := c.FormFile("file")
if err != nil {
return err
@ -251,77 +357,94 @@ func SessionUploadEndpoint(c echo.Context) error {
remoteDir := c.QueryParam("dir")
remoteFile := path.Join(remoteDir, filename)
if "ssh" == session.Protocol {
tun, ok := global.Store.Get(sessionId)
if !ok {
return errors.New("获取sftp客户端失败")
if "ssh" == s.Protocol {
nextSession := session.GlobalSessionManager.GetById(sessionId)
if nextSession == nil {
return errors.New("获取会话失败")
}
dstFile, err := tun.Subject.NextTerminal.SftpClient.Create(remoteFile)
sftpClient := nextSession.NextTerminal.SftpClient
// 文件夹不存在时自动创建文件夹
if _, err := sftpClient.Stat(remoteDir); os.IsNotExist(err) {
if err := sftpClient.MkdirAll(remoteDir); err != nil {
return err
}
}
dstFile, err := sftpClient.Create(remoteFile)
if err != nil {
return err
}
defer dstFile.Close()
buf := make([]byte, 1024)
for {
n, err := src.Read(buf)
if err != nil {
if err != io.EOF {
log.Warnf("文件上传错误 %v", err)
} else {
break
}
}
_, _ = dstFile.Write(buf[:n])
}
return Success(c, nil)
} else if "rdp" == session.Protocol {
if strings.Contains(remoteFile, "../") {
SafetyRuleTrigger(c)
return Fail(c, -1, ":) 您的IP已被记录请去向管理员自首。")
}
drivePath, err := propertyRepository.GetDrivePath()
if err != nil {
return err
}
// Destination
dst, err := os.Create(path.Join(drivePath, remoteFile))
if err != nil {
return err
}
defer dst.Close()
// Copy
if _, err = io.Copy(dst, src); err != nil {
if _, err = io.Copy(dstFile, src); err != nil {
return err
}
return Success(c, nil)
} else if "rdp" == s.Protocol {
return StorageUpload(c, file, s.StorageId)
}
return err
}
func SessionDownloadEndpoint(c echo.Context) error {
func SessionEditEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session, err := sessionRepository.FindById(sessionId)
s, err := sessionRepository.FindById(sessionId)
if err != nil {
return err
}
//remoteDir := c.Query("dir")
if s.Edit != "1" {
return errors.New("禁止操作")
}
file := c.FormValue("file")
fileContent := c.FormValue("fileContent")
if "ssh" == s.Protocol {
nextSession := session.GlobalSessionManager.GetById(sessionId)
if nextSession == nil {
return errors.New("获取会话失败")
}
sftpClient := nextSession.NextTerminal.SftpClient
dstFile, err := sftpClient.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_TRUNC)
if err != nil {
return err
}
defer dstFile.Close()
write := bufio.NewWriter(dstFile)
if _, err := write.WriteString(fileContent); err != nil {
return err
}
if err := write.Flush(); err != nil {
return err
}
return Success(c, nil)
} else if "rdp" == s.Protocol {
return StorageEdit(c, file, fileContent, s.StorageId)
}
return err
}
func SessionDownloadEndpoint(c echo.Context) error {
sessionId := c.Param("id")
s, err := sessionRepository.FindById(sessionId)
if err != nil {
return err
}
if s.Download != "1" {
return errors.New("禁止操作")
}
remoteFile := c.QueryParam("file")
// 获取带后缀的文件名称
filenameWithSuffix := path.Base(remoteFile)
if "ssh" == session.Protocol {
tun, ok := global.Store.Get(sessionId)
if !ok {
return errors.New("获取sftp客户端失败")
if "ssh" == s.Protocol {
nextSession := session.GlobalSessionManager.GetById(sessionId)
if nextSession == nil {
return errors.New("获取会话失败")
}
dstFile, err := tun.Subject.NextTerminal.SftpClient.Open(remoteFile)
dstFile, err := nextSession.NextTerminal.SftpClient.Open(remoteFile)
if err != nil {
return err
}
@ -335,105 +458,51 @@ func SessionDownloadEndpoint(c echo.Context) error {
}
return c.Stream(http.StatusOK, echo.MIMEOctetStream, bytes.NewReader(buff.Bytes()))
} else if "rdp" == session.Protocol {
if strings.Contains(remoteFile, "../") {
SafetyRuleTrigger(c)
return Fail(c, -1, ":) 您的IP已被记录请去向管理员自首。")
}
drivePath, err := propertyRepository.GetDrivePath()
if err != nil {
return err
}
return c.Attachment(path.Join(drivePath, remoteFile), filenameWithSuffix)
} else if "rdp" == s.Protocol {
storageId := s.StorageId
return StorageDownload(c, remoteFile, storageId)
}
return err
}
type File struct {
Name string `json:"name"`
Path string `json:"path"`
IsDir bool `json:"isDir"`
Mode string `json:"mode"`
IsLink bool `json:"isLink"`
ModTime utils.JsonTime `json:"modTime"`
Size int64 `json:"size"`
}
func SessionLsEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session, err := sessionRepository.FindByIdAndDecrypt(sessionId)
s, err := sessionRepository.FindByIdAndDecrypt(sessionId)
if err != nil {
return err
}
remoteDir := c.QueryParam("dir")
if "ssh" == session.Protocol {
tun, ok := global.Store.Get(sessionId)
if !ok {
return errors.New("获取sftp客户端失败")
remoteDir := c.FormValue("dir")
if "ssh" == s.Protocol {
nextSession := session.GlobalSessionManager.GetById(sessionId)
if nextSession == nil {
return errors.New("获取会话失败")
}
if tun.Subject.NextTerminal == nil {
nextTerminal, err := CreateNextTerminalBySession(session)
if err != nil {
return err
}
tun.Subject.NextTerminal = nextTerminal
}
if tun.Subject.NextTerminal.SftpClient == nil {
sftpClient, err := sftp.NewClient(tun.Subject.NextTerminal.SshClient)
if nextSession.NextTerminal.SftpClient == nil {
sftpClient, err := sftp.NewClient(nextSession.NextTerminal.SshClient)
if err != nil {
log.Errorf("创建sftp客户端失败%v", err.Error())
return err
}
tun.Subject.NextTerminal.SftpClient = sftpClient
nextSession.NextTerminal.SftpClient = sftpClient
}
fileInfos, err := tun.Subject.NextTerminal.SftpClient.ReadDir(remoteDir)
fileInfos, err := nextSession.NextTerminal.SftpClient.ReadDir(remoteDir)
if err != nil {
return err
}
var files = make([]File, 0)
var files = make([]service.File, 0)
for i := range fileInfos {
// 忽略因此文件
// 忽略隐藏文件
if strings.HasPrefix(fileInfos[i].Name(), ".") {
continue
}
file := File{
Name: fileInfos[i].Name(),
Path: path.Join(remoteDir, fileInfos[i].Name()),
IsDir: fileInfos[i].IsDir(),
Mode: fileInfos[i].Mode().String(),
IsLink: fileInfos[i].Mode()&os.ModeSymlink == os.ModeSymlink,
ModTime: utils.NewJsonTime(fileInfos[i].ModTime()),
Size: fileInfos[i].Size(),
}
files = append(files, file)
}
return Success(c, files)
} else if "rdp" == session.Protocol {
if strings.Contains(remoteDir, "../") {
SafetyRuleTrigger(c)
return Fail(c, -1, ":) 您的IP已被记录请去向管理员自首。")
}
drivePath, err := propertyRepository.GetDrivePath()
if err != nil {
return err
}
fileInfos, err := ioutil.ReadDir(path.Join(drivePath, remoteDir))
if err != nil {
return err
}
var files = make([]File, 0)
for i := range fileInfos {
file := File{
file := service.File{
Name: fileInfos[i].Name(),
Path: path.Join(remoteDir, fileInfos[i].Name()),
IsDir: fileInfos[i].IsDir(),
@ -447,115 +516,87 @@ func SessionLsEndpoint(c echo.Context) error {
}
return Success(c, files)
} else if "rdp" == s.Protocol {
storageId := s.StorageId
return StorageLs(c, remoteDir, storageId)
}
return errors.New("当前协议不支持此操作")
}
func SafetyRuleTrigger(c echo.Context) {
log.Warnf("IP %v 尝试进行攻击请ban掉此IP", c.RealIP())
security := model.AccessSecurity{
ID: utils.UUID(),
Source: "安全规则触发",
IP: c.RealIP(),
Rule: constant.AccessRuleReject,
}
_ = accessSecurityRepository.Create(&security)
}
func SessionMkDirEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session, err := sessionRepository.FindById(sessionId)
s, err := sessionRepository.FindById(sessionId)
if err != nil {
return err
}
remoteDir := c.QueryParam("dir")
if "ssh" == session.Protocol {
tun, ok := global.Store.Get(sessionId)
if !ok {
return errors.New("获取sftp客户端失败")
}
if err := tun.Subject.NextTerminal.SftpClient.Mkdir(remoteDir); err != nil {
return err
}
return Success(c, nil)
} else if "rdp" == session.Protocol {
if strings.Contains(remoteDir, "../") {
SafetyRuleTrigger(c)
return Fail(c, -1, ":) 您的IP已被记录请去向管理员自首。")
}
drivePath, err := propertyRepository.GetDrivePath()
if err != nil {
return err
}
if err := os.MkdirAll(path.Join(drivePath, remoteDir), os.ModePerm); err != nil {
return err
}
return Success(c, nil)
if s.Upload != "1" {
return errors.New("禁止操作")
}
remoteDir := c.QueryParam("dir")
if "ssh" == s.Protocol {
nextSession := session.GlobalSessionManager.GetById(sessionId)
if nextSession == nil {
return errors.New("获取会话失败")
}
if err := nextSession.NextTerminal.SftpClient.Mkdir(remoteDir); err != nil {
return err
}
return Success(c, nil)
} else if "rdp" == s.Protocol {
return StorageMkDir(c, remoteDir, s.StorageId)
}
return errors.New("当前协议不支持此操作")
}
func SessionRmEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session, err := sessionRepository.FindById(sessionId)
s, err := sessionRepository.FindById(sessionId)
if err != nil {
return err
}
key := c.QueryParam("key")
if "ssh" == session.Protocol {
tun, ok := global.Store.Get(sessionId)
if !ok {
return errors.New("获取sftp客户端失败")
if s.Delete != "1" {
return errors.New("禁止操作")
}
// 文件夹或者文件
file := c.FormValue("file")
if "ssh" == s.Protocol {
nextSession := session.GlobalSessionManager.GetById(sessionId)
if nextSession == nil {
return errors.New("获取会话失败")
}
sftpClient := tun.Subject.NextTerminal.SftpClient
sftpClient := nextSession.NextTerminal.SftpClient
stat, err := sftpClient.Stat(key)
stat, err := sftpClient.Stat(file)
if err != nil {
return err
}
if stat.IsDir() {
fileInfos, err := sftpClient.ReadDir(key)
fileInfos, err := sftpClient.ReadDir(file)
if err != nil {
return err
}
for i := range fileInfos {
if err := sftpClient.Remove(path.Join(key, fileInfos[i].Name())); err != nil {
if err := sftpClient.Remove(path.Join(file, fileInfos[i].Name())); err != nil {
return err
}
}
if err := sftpClient.RemoveDirectory(key); err != nil {
if err := sftpClient.RemoveDirectory(file); err != nil {
return err
}
} else {
if err := sftpClient.Remove(key); err != nil {
if err := sftpClient.Remove(file); err != nil {
return err
}
}
return Success(c, nil)
} else if "rdp" == session.Protocol {
if strings.Contains(key, "../") {
SafetyRuleTrigger(c)
return Fail(c, -1, ":) 您的IP已被记录请去向管理员自首。")
}
drivePath, err := propertyRepository.GetDrivePath()
if err != nil {
return err
}
if err := os.RemoveAll(path.Join(drivePath, key)); err != nil {
return err
}
return Success(c, nil)
} else if "rdp" == s.Protocol {
return StorageRm(c, file, s.StorageId)
}
return errors.New("当前协议不支持此操作")
@ -563,58 +604,80 @@ func SessionRmEndpoint(c echo.Context) error {
func SessionRenameEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session, err := sessionRepository.FindById(sessionId)
s, err := sessionRepository.FindById(sessionId)
if err != nil {
return err
}
if s.Rename != "1" {
return errors.New("禁止操作")
}
oldName := c.QueryParam("oldName")
newName := c.QueryParam("newName")
if "ssh" == session.Protocol {
tun, ok := global.Store.Get(sessionId)
if !ok {
return errors.New("获取sftp客户端失败")
if "ssh" == s.Protocol {
nextSession := session.GlobalSessionManager.GetById(sessionId)
if nextSession == nil {
return errors.New("获取会话失败")
}
sftpClient := tun.Subject.NextTerminal.SftpClient
sftpClient := nextSession.NextTerminal.SftpClient
if err := sftpClient.Rename(oldName, newName); err != nil {
return err
}
return Success(c, nil)
} else if "rdp" == session.Protocol {
if strings.Contains(oldName, "../") {
SafetyRuleTrigger(c)
return Fail(c, -1, ":) 您的IP已被记录请去向管理员自首。")
}
drivePath, err := propertyRepository.GetDrivePath()
if err != nil {
return err
}
if err := os.Rename(path.Join(drivePath, oldName), path.Join(drivePath, newName)); err != nil {
return err
}
return Success(c, nil)
} else if "rdp" == s.Protocol {
return StorageRename(c, oldName, newName, s.StorageId)
}
return errors.New("当前协议不支持此操作")
}
func SessionRecordingEndpoint(c echo.Context) error {
sessionId := c.Param("id")
session, err := sessionRepository.FindById(sessionId)
s, err := sessionRepository.FindById(sessionId)
if err != nil {
return err
}
var recording string
if session.Mode == constant.Naive {
recording = session.Recording
if s.Mode == constant.Naive || s.Mode == constant.Terminal {
recording = s.Recording
} else {
recording = session.Recording + "/recording"
recording = s.Recording + "/recording"
}
log.Debugf("读取录屏文件:%v,是否存在: %v, 是否为文件: %v", recording, utils.FileExists(recording), utils.IsFile(recording))
return c.File(recording)
}
func SessionGetEndpoint(c echo.Context) error {
sessionId := c.Param("id")
s, err := sessionRepository.FindById(sessionId)
if err != nil {
return err
}
return Success(c, s)
}
func SessionStatsEndpoint(c echo.Context) error {
sessionId := c.Param("id")
s, err := sessionRepository.FindByIdAndDecrypt(sessionId)
if err != nil {
return err
}
if "ssh" != s.Protocol {
return Fail(c, -1, "不支持当前协议")
}
nextSession := session.GlobalSessionManager.GetById(sessionId)
if nextSession == nil {
return errors.New("获取会话失败")
}
sshClient := nextSession.NextTerminal.SshClient
stats, err := GetAllStats(sshClient)
if err != nil {
return err
}
return Success(c, stats)
}

View File

@ -1,18 +1,23 @@
package api
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"net/http"
"path"
"strconv"
"time"
"unicode/utf8"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/guacd"
"next-terminal/pkg/log"
"next-terminal/pkg/term"
"next-terminal/server/config"
"next-terminal/server/constant"
"next-terminal/server/global/session"
"next-terminal/server/guacd"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/term"
"next-terminal/server/utils"
"github.com/gorilla/websocket"
@ -27,17 +32,44 @@ var UpGrader = websocket.Upgrader{
}
const (
Connected = "connected"
Data = "data"
Resize = "resize"
Closed = "closed"
Closed = 0
Connected = 1
Data = 2
Resize = 3
Ping = 4
)
type Message struct {
Type string `json:"type"`
Type int `json:"type"`
Content string `json:"content"`
}
func (r Message) ToString() string {
if r.Content != "" {
return strconv.Itoa(r.Type) + r.Content
} else {
return strconv.Itoa(r.Type)
}
}
func NewMessage(_type int, content string) Message {
return Message{Content: content, Type: _type}
}
func ParseMessage(value string) (message Message, err error) {
if value == "" {
return
}
_type, err := strconv.Atoi(value[:1])
if err != nil {
return
}
var content = value[1:]
message = NewMessage(_type, content)
return
}
type WindowSize struct {
Cols int `json:"cols"`
Rows int `json:"rows"`
@ -50,92 +82,73 @@ func SSHEndpoint(c echo.Context) (err error) {
return err
}
defer ws.Close()
sessionId := c.QueryParam("sessionId")
cols, _ := strconv.Atoi(c.QueryParam("cols"))
rows, _ := strconv.Atoi(c.QueryParam("rows"))
session, err := sessionRepository.FindByIdAndDecrypt(sessionId)
s, err := sessionRepository.FindByIdAndDecrypt(sessionId)
if err != nil {
msg := Message{
Type: Closed,
Content: "get sshSession error." + err.Error(),
}
_ = WriteMessage(ws, msg)
return err
return WriteMessage(ws, NewMessage(Closed, "获取会话失败"))
}
user, _ := GetCurrentAccount(c)
if constant.TypeUser == user.Type {
// 检测是否有访问权限
assetIds, err := resourceSharerRepository.FindAssetIdsByUserId(user.ID)
if err != nil {
return err
}
if !utils.Contains(assetIds, session.AssetId) {
msg := Message{
Type: Closed,
Content: "您没有权限访问此资产",
}
return WriteMessage(ws, msg)
}
if err := permissionCheck(c, s.AssetId); err != nil {
return WriteMessage(ws, NewMessage(Closed, err.Error()))
}
var (
username = session.Username
password = session.Password
privateKey = session.PrivateKey
passphrase = session.Passphrase
ip = session.IP
port = session.Port
username = s.Username
password = s.Password
privateKey = s.PrivateKey
passphrase = s.Passphrase
ip = s.IP
port = s.Port
)
recording := ""
propertyMap := propertyRepository.FindAllMap()
if propertyMap[guacd.EnableRecording] == "true" {
recording = path.Join(propertyMap[guacd.RecordingPath], sessionId, "recording.cast")
}
tun := global.Tun{
Protocol: session.Protocol,
Mode: session.Mode,
WebSocket: ws,
}
if session.ConnectionId != "" {
// 监控会话
observable, ok := global.Store.Get(sessionId)
if ok {
observers := append(observable.Observers, tun)
observable.Observers = observers
global.Store.Set(sessionId, observable)
log.Debugf("加入会话%v,当前观察者数量为:%v", session.ConnectionId, len(observers))
if s.AccessGatewayId != "" && s.AccessGatewayId != "-" {
g, err := accessGatewayService.GetGatewayAndReconnectById(s.AccessGatewayId)
if err != nil {
return WriteMessage(ws, NewMessage(Closed, "获取接入网关失败:"+err.Error()))
}
return err
if !g.Connected {
return WriteMessage(ws, NewMessage(Closed, "接入网关不可用:"+g.Message))
}
exposedIP, exposedPort, err := g.OpenSshTunnel(s.ID, ip, port)
if err != nil {
return WriteMessage(ws, NewMessage(Closed, "创建隧道失败:"+err.Error()))
}
defer g.CloseSshTunnel(s.ID)
ip = exposedIP
port = exposedPort
}
nextTerminal, err := term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, rows, cols, recording)
recording := ""
var isRecording = false
property, err := propertyRepository.FindByName(guacd.EnableRecording)
if err == nil && property.Value == "true" {
isRecording = true
}
if isRecording {
recording = path.Join(config.GlobalCfg.Guacd.Recording, sessionId, "recording.cast")
}
var xterm = "xterm-256color"
nextTerminal, err := term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, rows, cols, recording, xterm, true)
if err != nil {
log.Errorf("创建SSH客户端失败%v", err.Error())
msg := Message{
Type: Closed,
Content: err.Error(),
}
err := WriteMessage(ws, msg)
return WriteMessage(ws, NewMessage(Closed, "创建SSH客户端失败"+err.Error()))
}
if err := nextTerminal.RequestPty(xterm, rows, cols); err != nil {
return err
}
tun.NextTerminal = nextTerminal
var observers []global.Tun
observable := global.Observable{
Subject: &tun,
Observers: observers,
if err := nextTerminal.Shell(); err != nil {
return err
}
global.Store.Set(sessionId, &observable)
sess := model.Session{
ConnectionId: sessionId,
Width: cols,
@ -149,106 +162,209 @@ func SSHEndpoint(c echo.Context) (err error) {
return err
}
msg := Message{
Type: Connected,
Content: "",
if err := WriteMessage(ws, NewMessage(Connected, "")); err != nil {
return err
}
_ = WriteMessage(ws, msg)
quitChan := make(chan bool)
nextSession := &session.Session{
ID: s.ID,
Protocol: s.Protocol,
Mode: s.Mode,
WebSocket: ws,
GuacdTunnel: nil,
NextTerminal: nextTerminal,
Observer: session.NewObserver(s.ID),
}
go nextSession.Observer.Run()
session.GlobalSessionManager.Add <- nextSession
go ReadMessage(nextTerminal, quitChan, ws)
ctx, cancel := context.WithCancel(context.Background())
tick := time.NewTicker(time.Millisecond * time.Duration(60))
defer tick.Stop()
var buf []byte
dataChan := make(chan rune)
go func() {
SshLoop:
for {
select {
case <-ctx.Done():
log.Debugf("WebSocket已关闭即将关闭SSH连接...")
break SshLoop
default:
r, size, err := nextTerminal.StdoutReader.ReadRune()
if err != nil {
log.Debugf("SSH 读取失败,即将退出循环...")
_ = WriteMessage(ws, NewMessage(Closed, ""))
break SshLoop
}
if size > 0 {
dataChan <- r
}
}
}
log.Debugf("SSH 连接已关闭,退出循环。")
}()
go func() {
tickLoop:
for {
select {
case <-ctx.Done():
break tickLoop
case <-tick.C:
if len(buf) > 0 {
s := string(buf)
// 录屏
if isRecording {
_ = nextTerminal.Recorder.WriteData(s)
}
// 监控
if len(nextSession.Observer.All()) > 0 {
obs := nextSession.Observer.All()
for _, ob := range obs {
_ = WriteMessage(ob.WebSocket, NewMessage(Data, s))
}
}
if err := WriteMessage(ws, NewMessage(Data, s)); err != nil {
log.Debugf("WebSocket写入失败即将退出循环...")
cancel()
}
buf = []byte{}
}
case data := <-dataChan:
if data != utf8.RuneError {
p := make([]byte, utf8.RuneLen(data))
utf8.EncodeRune(p, data)
buf = append(buf, p...)
} else {
buf = append(buf, []byte("@")...)
}
}
}
log.Debugf("SSH 连接已关闭,退出定时器循环。")
}()
//var enterKeys []rune
//enterIndex := 0
for {
_, message, err := ws.ReadMessage()
if err != nil {
// web socket会话关闭后主动关闭ssh会话
CloseSessionById(sessionId, Normal, "正常退出")
quitChan <- true
quitChan <- true
log.Debugf("WebSocket已关闭")
CloseSessionById(sessionId, Normal, "用户正常退出")
cancel()
break
}
var msg Message
err = json.Unmarshal(message, &msg)
msg, err := ParseMessage(string(message))
if err != nil {
log.Warnf("解析Json失败: %v, 原始字符串:%v", err, string(message))
log.Warnf("消息解码失败: %v, 原始字符串:%v", err, string(message))
continue
}
switch msg.Type {
case Resize:
var winSize WindowSize
err = json.Unmarshal([]byte(msg.Content), &winSize)
decodeString, err := base64.StdEncoding.DecodeString(msg.Content)
if err != nil {
log.Warnf("解析SSH会话窗口大小失败: %v", err)
log.Warnf("Base64解码失败: %v原始字符串%v", err, msg.Content)
continue
}
var winSize WindowSize
err = json.Unmarshal(decodeString, &winSize)
if err != nil {
log.Warnf("解析SSH会话窗口大小失败: %v原始字符串%v", err, msg.Content)
continue
}
if err := nextTerminal.WindowChange(winSize.Rows, winSize.Cols); err != nil {
log.Warnf("更改SSH会话窗口大小失败: %v", err)
continue
}
_ = sessionRepository.UpdateWindowSizeById(winSize.Rows, winSize.Cols, sessionId)
case Data:
_, err = nextTerminal.Write([]byte(msg.Content))
input := []byte(msg.Content)
//hexInput := hex.EncodeToString(input)
//switch hexInput {
//case "0d": // 回车
// DealCommand(enterKeys)
// // 清空输入的字符
// enterKeys = enterKeys[:0]
// enterIndex = 0
//case "7f": // backspace
// enterIndex--
// if enterIndex < 0 {
// enterIndex = 0
// }
// temp := enterKeys[:enterIndex]
// if len(enterKeys) > enterIndex {
// enterKeys = append(temp, enterKeys[enterIndex+1:]...)
// } else {
// enterKeys = temp
// }
//case "1b5b337e": // del
// temp := enterKeys[:enterIndex]
// if len(enterKeys) > enterIndex {
// enterKeys = append(temp, enterKeys[enterIndex+1:]...)
// } else {
// enterKeys = temp
// }
// enterIndex--
// if enterIndex < 0 {
// enterIndex = 0
// }
//case "1b5b41":
//case "1b5b42":
// break
//case "1b5b43": // ->
// enterIndex++
// if enterIndex > len(enterKeys) {
// enterIndex = len(enterKeys)
// }
//case "1b5b44": // <-
// enterIndex--
// if enterIndex < 0 {
// enterIndex = 0
// }
//default:
// enterKeys = utils.InsertSlice(enterIndex, []rune(msg.Content), enterKeys)
// enterIndex++
//}
_, err := nextTerminal.Write(input)
if err != nil {
log.Debugf("SSH会话写入失败: %v", err)
msg := Message{
Type: Closed,
Content: "the remote connection is closed.",
}
_ = WriteMessage(ws, msg)
CloseSessionById(sessionId, TunnelClosed, "远程连接已关闭")
}
case Ping:
_, _, err := nextTerminal.SshClient.Conn.SendRequest("helloworld1024@foxmail.com", true, nil)
if err != nil {
CloseSessionById(sessionId, TunnelClosed, "远程连接已关闭")
} else {
_ = WriteMessage(ws, NewMessage(Ping, ""))
}
}
}
}
return err
}
func ReadMessage(nextTerminal *term.NextTerminal, quitChan chan bool, ws *websocket.Conn) {
func permissionCheck(c echo.Context, assetId string) error {
user, _ := GetCurrentAccount(c)
if constant.TypeUser == user.Type {
// 检测是否有访问权限
assetIds, err := resourceSharerRepository.FindAssetIdsByUserId(user.ID)
if err != nil {
return err
}
var quit bool
for {
select {
case quit = <-quitChan:
if quit {
return
}
default:
p, n, err := nextTerminal.Read()
if err != nil {
msg := Message{
Type: Closed,
Content: err.Error(),
}
_ = WriteMessage(ws, msg)
}
if n > 0 {
s := string(p)
msg := Message{
Type: Data,
Content: s,
}
_ = WriteMessage(ws, msg)
}
time.Sleep(time.Duration(10) * time.Millisecond)
if !utils.Contains(assetIds, assetId) {
return errors.New("您没有权限访问此资产")
}
}
return nil
}
func WriteMessage(ws *websocket.Conn, msg Message) error {
message, err := json.Marshal(msg)
if err != nil {
return err
}
WriteByteMessage(ws, message)
return err
}
func WriteByteMessage(ws *websocket.Conn, p []byte) {
err := ws.WriteMessage(websocket.TextMessage, p)
if err != nil {
log.Debugf("write: %v", err)
}
message := []byte(msg.ToString())
return ws.WriteMessage(websocket.TextMessage, message)
}
func CreateNextTerminalBySession(session model.Session) (*term.NextTerminal, error) {
@ -260,5 +376,46 @@ func CreateNextTerminalBySession(session model.Session) (*term.NextTerminal, err
ip = session.IP
port = session.Port
)
return term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, 10, 10, "")
return term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, 10, 10, "", "", false)
}
func SshMonitor(c echo.Context) error {
ws, err := UpGrader.Upgrade(c.Response().Writer, c.Request(), nil)
if err != nil {
log.Errorf("升级为WebSocket协议失败%v", err.Error())
return err
}
defer ws.Close()
sessionId := c.QueryParam("sessionId")
s, err := sessionRepository.FindById(sessionId)
if err != nil {
return WriteMessage(ws, NewMessage(Closed, "获取会话失败"))
}
nextSession := session.GlobalSessionManager.GetById(sessionId)
if nextSession == nil {
return WriteMessage(ws, NewMessage(Closed, "会话已离线"))
}
obId := utils.UUID()
obSession := &session.Session{
ID: obId,
Protocol: s.Protocol,
Mode: s.Mode,
WebSocket: ws,
}
nextSession.Observer.Add <- obSession
log.Debugf("会话 %v 观察者 %v 进入", sessionId, obId)
for {
_, _, err := ws.ReadMessage()
if err != nil {
log.Debugf("会话 %v 观察者 %v 退出", sessionId, obId)
nextSession.Observer.Del <- obId
break
}
}
return nil
}

View File

@ -0,0 +1,6 @@
package api
func DealCommand(enterKeys []rune) {
println(string(enterKeys))
}

468
server/api/sshd.go Normal file
View File

@ -0,0 +1,468 @@
package api
import (
"encoding/hex"
"errors"
"fmt"
"io"
"path"
"strings"
"time"
"next-terminal/server/config"
"next-terminal/server/constant"
"next-terminal/server/global/cache"
"next-terminal/server/global/session"
"next-terminal/server/guacd"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/term"
"next-terminal/server/totp"
"next-terminal/server/utils"
"github.com/gliderlabs/ssh"
"github.com/manifoldco/promptui"
"gorm.io/gorm"
)
func sessionHandler(sess *ssh.Session) {
defer func() {
(*sess).Close()
}()
username := (*sess).User()
remoteAddr := strings.Split((*sess).RemoteAddr().String(), ":")[0]
user, err := userRepository.FindByUsername(username)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
_, _ = io.WriteString(*sess, "您输入的账户或密码不正确.\n")
} else {
_, _ = io.WriteString(*sess, err.Error())
}
return
}
// 判断是否需要进行双因素认证
if user.TOTPSecret != "" && user.TOTPSecret != "-" {
totpUI(sess, user, remoteAddr, username)
} else {
// 保存登录日志
_ = SaveLoginLog(remoteAddr, "terminal", username, true, false, utils.UUID(), "")
mainUI(sess, user)
}
}
func totpUI(sess *ssh.Session, user model.User, remoteAddr string, username string) {
validate := func(input string) error {
if len(input) < 6 {
return errors.New("双因素认证授权码必须为6个数字")
}
return nil
}
prompt := promptui.Prompt{
Label: "请输入双因素认证授权码",
Validate: validate,
Mask: '*',
Stdin: *sess,
Stdout: *sess,
}
var success = false
for i := 0; i < 5; i++ {
result, err := prompt.Run()
if err != nil {
fmt.Printf("Prompt failed %v\n", err)
return
}
loginFailCountKey := remoteAddr + username
v, ok := cache.GlobalCache.Get(loginFailCountKey)
if !ok {
v = 1
}
count := v.(int)
if count >= 5 {
_, _ = io.WriteString(*sess, "登录失败次数过多请等待30秒后再试\r\n")
continue
}
if !totp.Validate(result, user.TOTPSecret) {
count++
println(count)
cache.GlobalCache.Set(loginFailCountKey, count, time.Second*time.Duration(30))
// 保存登录日志
_ = SaveLoginLog(remoteAddr, "terminal", username, false, false, "", "双因素认证授权码不正确")
_, _ = io.WriteString(*sess, "您输入的双因素认证授权码不匹配\r\n")
continue
}
success = true
break
}
if success {
// 保存登录日志
_ = SaveLoginLog(remoteAddr, "terminal", username, true, false, utils.UUID(), "")
mainUI(sess, user)
}
}
func mainUI(sess *ssh.Session, user model.User) {
prompt := promptui.Select{
Label: "欢迎使用 Next Terminal请选择您要使用的功能",
Items: []string{"我的资产", "退出系统"},
Stdin: *sess,
Stdout: *sess,
}
MainLoop:
for {
_, result, err := prompt.Run()
if err != nil {
fmt.Printf("Prompt failed %v\n", err)
return
}
switch result {
case "我的资产":
AssetUI(sess, user)
case "退出系统":
break MainLoop
}
}
}
func AssetUI(sess *ssh.Session, user model.User) {
assets, err := assetRepository.FindByProtocolAndUser(constant.SSH, user)
if err != nil {
return
}
quitItem := model.Asset{ID: "quit", Name: "返回上级菜单", Description: "这里是返回上级菜单的选项"}
assets = append([]model.Asset{quitItem}, assets...)
templates := &promptui.SelectTemplates{
Label: "{{ . }}?",
Active: "\U0001F336 {{ .Name | cyan }} ({{ .IP | red }}:{{ .Port | red }})",
Inactive: " {{ .Name | cyan }} ({{ .IP | red }}:{{ .Port | red }})",
Selected: "\U0001F336 {{ .Name | red | cyan }}",
Details: `
--------- 详细信息 ----------
{{ "名称:" | faint }} {{ .Name }}
{{ "主机:" | faint }} {{ .IP }}
{{ "端口:" | faint }} {{ .Port }}
{{ "标签:" | faint }} {{ .Tags }}
{{ "备注:" | faint }} {{ .Description }}
`,
}
searcher := func(input string, index int) bool {
asset := assets[index]
name := strings.Replace(strings.ToLower(asset.Name), " ", "", -1)
input = strings.Replace(strings.ToLower(input), " ", "", -1)
return strings.Contains(name, input)
}
prompt := promptui.Select{
Label: "请选择您要访问的资产",
Items: assets,
Templates: templates,
Size: 4,
Searcher: searcher,
Stdin: *sess,
Stdout: *sess,
}
AssetUILoop:
for {
i, _, err := prompt.Run()
if err != nil {
fmt.Printf("Prompt failed %v\n", err)
return
}
chooseAssetId := assets[i].ID
switch chooseAssetId {
case "quit":
break AssetUILoop
default:
if err := createSession(sess, assets[i].ID, user.ID); err != nil {
_, _ = io.WriteString(*sess, err.Error()+"\r\n")
return
}
}
}
}
func createSession(sess *ssh.Session, assetId, creator string) (err error) {
asset, err := assetRepository.FindById(assetId)
if err != nil {
return err
}
ClientIP := strings.Split((*sess).RemoteAddr().String(), ":")[0]
s := &model.Session{
ID: utils.UUID(),
AssetId: asset.ID,
Username: asset.Username,
Password: asset.Password,
PrivateKey: asset.PrivateKey,
Passphrase: asset.Passphrase,
Protocol: asset.Protocol,
IP: asset.IP,
Port: asset.Port,
Status: constant.NoConnect,
Creator: creator,
ClientIP: ClientIP,
Mode: constant.Terminal,
Upload: "0",
Download: "0",
Delete: "0",
Rename: "0",
StorageId: "",
AccessGatewayId: asset.AccessGatewayId,
}
if asset.AccountType == "credential" {
credential, err := credentialRepository.FindById(asset.CredentialId)
if err != nil {
return nil
}
if credential.Type == constant.Custom {
s.Username = credential.Username
s.Password = credential.Password
} else {
s.Username = credential.Username
s.PrivateKey = credential.PrivateKey
s.Passphrase = credential.Passphrase
}
}
if err := sessionRepository.Create(s); err != nil {
return err
}
return handleAccessAsset(sess, s.ID)
}
func handleAccessAsset(sess *ssh.Session, sessionId string) (err error) {
s, err := sessionRepository.FindByIdAndDecrypt(sessionId)
if err != nil {
return err
}
var (
username = s.Username
password = s.Password
privateKey = s.PrivateKey
passphrase = s.Passphrase
ip = s.IP
port = s.Port
)
if s.AccessGatewayId != "" && s.AccessGatewayId != "-" {
g, err := accessGatewayService.GetGatewayAndReconnectById(s.AccessGatewayId)
if err != nil {
return errors.New("获取接入网关失败:" + err.Error())
}
if !g.Connected {
return errors.New("接入网关不可用:" + g.Message)
}
exposedIP, exposedPort, err := g.OpenSshTunnel(s.ID, ip, port)
if err != nil {
return errors.New("开启SSH隧道失败" + err.Error())
}
defer g.CloseSshTunnel(s.ID)
ip = exposedIP
port = exposedPort
}
pty, winCh, isPty := (*sess).Pty()
if !isPty {
return errors.New("No PTY requested.\n")
}
recording := ""
property, err := propertyRepository.FindByName(guacd.EnableRecording)
if err == nil && property.Value == "true" {
recording = path.Join(config.GlobalCfg.Guacd.Recording, sessionId, "recording.cast")
}
nextTerminal, err := term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, pty.Window.Height, pty.Window.Width, recording, pty.Term, false)
if err != nil {
return err
}
sshSession := nextTerminal.SshSession
writer := NewWriter(sessionId, sess, nextTerminal.Recorder)
sshSession.Stdout = writer
sshSession.Stdin = *sess
sshSession.Stderr = *sess
if err := nextTerminal.RequestPty(pty.Term, pty.Window.Height, pty.Window.Width); err != nil {
return err
}
if err := nextTerminal.Shell(); err != nil {
return err
}
go func() {
log.Debugf("开启窗口大小监控...")
for win := range winCh {
_ = sshSession.WindowChange(win.Height, win.Width)
}
log.Debugf("退出窗口大小监控")
// ==== 修改数据库中的会话状态为已断开,修复用户直接关闭窗口时会话状态不正确的问题 ====
CloseSessionById(sessionId, Normal, "用户正常退出")
// ==== 修改数据库中的会话状态为已断开,修复用户直接关闭窗口时会话状态不正确的问题 ====
}()
// ==== 修改数据库中的会话状态为已连接 ====
sessionForUpdate := model.Session{}
sessionForUpdate.ID = sessionId
sessionForUpdate.Status = constant.Connected
sessionForUpdate.Recording = recording
sessionForUpdate.ConnectedTime = utils.NowJsonTime()
if err := sessionRepository.UpdateById(&sessionForUpdate, sessionId); err != nil {
return err
}
// ==== 修改数据库中的会话状态为已连接 ====
nextSession := &session.Session{
ID: s.ID,
Protocol: s.Protocol,
Mode: s.Mode,
NextTerminal: nextTerminal,
Observer: session.NewObserver(s.ID),
}
go nextSession.Observer.Run()
session.GlobalSessionManager.Add <- nextSession
if err := sshSession.Wait(); err != nil {
return err
}
// ==== 修改数据库中的会话状态为已断开 ====
CloseSessionById(sessionId, Normal, "用户正常退出")
// ==== 修改数据库中的会话状态为已断开 ====
return nil
}
func passwordAuth(ctx ssh.Context, pass string) bool {
username := ctx.User()
remoteAddr := strings.Split(ctx.RemoteAddr().String(), ":")[0]
user, err := userRepository.FindByUsername(username)
if err != nil {
// 保存登录日志
_ = SaveLoginLog(remoteAddr, "terminal", username, false, false, "", "账号或密码不正确")
return false
}
if err := utils.Encoder.Match([]byte(user.Password), []byte(pass)); err != nil {
// 保存登录日志
_ = SaveLoginLog(remoteAddr, "terminal", username, false, false, "", "账号或密码不正确")
return false
}
return true
}
func Setup() {
ssh.Handle(func(s ssh.Session) {
_, _ = io.WriteString(s, fmt.Sprintf(constant.Banner, constant.Version))
defer func() {
if e, ok := recover().(error); ok {
log.Fatal(e)
}
}()
sessionHandler(&s)
})
fmt.Printf("⇨ sshd server started on %v\n", config.GlobalCfg.Sshd.Addr)
err := ssh.ListenAndServe(
config.GlobalCfg.Sshd.Addr,
nil,
ssh.PasswordAuth(passwordAuth),
ssh.HostKeyFile(config.GlobalCfg.Sshd.Key),
)
log.Fatal(fmt.Sprintf("启动sshd服务失败: %v", err.Error()))
}
func init() {
if config.GlobalCfg.Sshd.Enable {
go Setup()
}
}
type Writer struct {
sessionId string
sess *ssh.Session
recorder *term.Recorder
rz bool
sz bool
}
func NewWriter(sessionId string, sess *ssh.Session, recorder *term.Recorder) *Writer {
return &Writer{sessionId: sessionId, sess: sess, recorder: recorder}
}
func (w *Writer) Write(p []byte) (n int, err error) {
if w.recorder != nil {
s := string(p)
if !w.sz && !w.rz {
// rz的开头字符
hexData := hex.EncodeToString(p)
if strings.Contains(hexData, "727a0d2a2a184230303030303030303030303030300d8a11") {
w.sz = true
} else if strings.Contains(hexData, "727a2077616974696e6720746f20726563656976652e2a2a184230313030303030303233626535300d8a11") {
w.rz = true
}
}
if w.sz {
// sz 会以 OO 结尾
if "OO" == s {
w.sz = false
}
} else if w.rz {
// rz 最后会显示 Received /home/xxx
if strings.Contains(s, "Received") {
w.rz = false
// 把上传的文件名称也显示一下
err := w.recorder.WriteData(s)
if err != nil {
return 0, err
}
sendObData(w.sessionId, s)
}
} else {
err := w.recorder.WriteData(s)
if err != nil {
return 0, err
}
sendObData(w.sessionId, s)
}
}
return (*w.sess).Write(p)
}
func sendObData(sessionId, s string) {
nextSession := session.GlobalSessionManager.GetById(sessionId)
if nextSession != nil {
if nextSession.Observer != nil {
obs := nextSession.Observer.All()
for _, ob := range obs {
_ = WriteMessage(ob.WebSocket, NewMessage(Data, s))
}
}
}
}

384
server/api/stats.go Normal file
View File

@ -0,0 +1,384 @@
package api
import (
"bufio"
"fmt"
"strconv"
"strings"
"time"
"next-terminal/server/utils"
"golang.org/x/crypto/ssh"
)
type FileSystem struct {
MountPoint string `json:"mountPoint"`
Used uint64 `json:"used"`
Free uint64 `json:"free"`
}
type Network struct {
IPv4 string `json:"ipv4"`
IPv6 string `json:"ipv6"`
Rx uint64 `json:"rx"`
Tx uint64 `json:"tx"`
}
type cpuRaw struct {
User uint64 // time spent in user mode
Nice uint64 // time spent in user mode with low priority (nice)
System uint64 // time spent in system mode
Idle uint64 // time spent in the idle task
Iowait uint64 // time spent waiting for I/O to complete (since Linux 2.5.41)
Irq uint64 // time spent servicing interrupts (since 2.6.0-test4)
SoftIrq uint64 // time spent servicing softirqs (since 2.6.0-test4)
Steal uint64 // time spent in other OSes when running in a virtualized environment
Guest uint64 // time spent running a virtual CPU for guest operating systems under the control of the Linux kernel.
Total uint64 // total of all time fields
}
type CPU struct {
User float32 `json:"user"`
Nice float32 `json:"nice"`
System float32 `json:"system"`
Idle float32 `json:"idle"`
IOWait float32 `json:"ioWait"`
Irq float32 `json:"irq"`
SoftIrq float32 `json:"softIrq"`
Steal float32 `json:"steal"`
Guest float32 `json:"guest"`
}
type Stat struct {
Uptime int64 `json:"uptime"`
Hostname string `json:"hostname"`
Load1 string `json:"load1"`
Load5 string `json:"load5"`
Load10 string `json:"load10"`
RunningProcess string `json:"runningProcess"`
TotalProcess string `json:"totalProcess"`
MemTotal uint64 `json:"memTotal"`
MemAvailable uint64 `json:"memAvailable"`
MemFree uint64 `json:"memFree"`
MemBuffers uint64 `json:"memBuffers"`
MemCached uint64 `json:"memCached"`
SwapTotal uint64 `json:"swapTotal"`
SwapFree uint64 `json:"swapFree"`
FileSystems []FileSystem `json:"fileSystems"`
Network map[string]Network `json:"network"`
CPU CPU `json:"cpu"`
}
func GetAllStats(client *ssh.Client) (*Stat, error) {
start := time.Now()
stats := &Stat{}
if err := getUptime(client, stats); err != nil {
return nil, err
}
if err := getHostname(client, stats); err != nil {
return nil, err
}
if err := getLoad(client, stats); err != nil {
return nil, err
}
if err := getMem(client, stats); err != nil {
return nil, err
}
if err := getFileSystems(client, stats); err != nil {
return nil, err
}
if err := getInterfaces(client, stats); err != nil {
return nil, err
}
if err := getInterfaceInfo(client, stats); err != nil {
return nil, err
}
if err := getCPU(client, stats); err != nil {
return nil, err
}
cost := time.Since(start)
fmt.Printf("%s: %v\n", "GetAllStats", cost)
return stats, nil
}
func getHostname(client *ssh.Client, stat *Stat) (err error) {
//defer utils.TimeWatcher("getHostname")
hostname, err := utils.RunCommand(client, "/bin/hostname -f")
if err != nil {
return
}
stat.Hostname = strings.TrimSpace(hostname)
return
}
func getUptime(client *ssh.Client, stat *Stat) (err error) {
//defer utils.TimeWatcher("getUptime")
uptime, err := utils.RunCommand(client, "/bin/cat /proc/uptime")
if err != nil {
return
}
parts := strings.Fields(uptime)
if len(parts) == 2 {
var upSeconds float64
upSeconds, err = strconv.ParseFloat(parts[0], 64)
if err != nil {
return
}
stat.Uptime = int64(upSeconds * 1000)
}
return
}
func getLoad(client *ssh.Client, stat *Stat) (err error) {
//defer utils.TimeWatcher("getLoad")
line, err := utils.RunCommand(client, "/bin/cat /proc/loadavg")
if err != nil {
return
}
parts := strings.Fields(line)
if len(parts) == 5 {
stat.Load1 = parts[0]
stat.Load5 = parts[1]
stat.Load10 = parts[2]
if i := strings.Index(parts[3], "/"); i != -1 {
stat.RunningProcess = parts[3][0:i]
if i+1 < len(parts[3]) {
stat.TotalProcess = parts[3][i+1:]
}
}
}
return
}
func getMem(client *ssh.Client, stat *Stat) (err error) {
//defer utils.TimeWatcher("getMem")
lines, err := utils.RunCommand(client, "/bin/cat /proc/meminfo")
if err != nil {
return
}
scanner := bufio.NewScanner(strings.NewReader(lines))
for scanner.Scan() {
line := scanner.Text()
parts := strings.Fields(line)
if len(parts) == 3 {
val, err := strconv.ParseUint(parts[1], 10, 64)
if err != nil {
continue
}
val *= 1024
switch parts[0] {
case "MemTotal:":
stat.MemTotal = val
case "MemFree:":
stat.MemFree = val
case "MemAvailable:":
stat.MemAvailable = val
case "Buffers:":
stat.MemBuffers = val
case "Cached:":
stat.MemCached = val
case "SwapTotal:":
stat.SwapTotal = val
case "SwapFree:":
stat.SwapFree = val
}
}
}
return
}
func getFileSystems(client *ssh.Client, stat *Stat) (err error) {
//defer utils.TimeWatcher("getFileSystems")
lines, err := utils.RunCommand(client, "/bin/df -B1")
if err != nil {
return
}
scanner := bufio.NewScanner(strings.NewReader(lines))
flag := 0
for scanner.Scan() {
line := scanner.Text()
parts := strings.Fields(line)
n := len(parts)
dev := n > 0 && strings.Index(parts[0], "/dev/") == 0
if n == 1 && dev {
flag = 1
} else if (n == 5 && flag == 1) || (n == 6 && dev) {
i := flag
flag = 0
used, err := strconv.ParseUint(parts[2-i], 10, 64)
if err != nil {
continue
}
free, err := strconv.ParseUint(parts[3-i], 10, 64)
if err != nil {
continue
}
stat.FileSystems = append(stat.FileSystems, FileSystem{
parts[5-i], used, free,
})
}
}
return
}
func getInterfaces(client *ssh.Client, stats *Stat) (err error) {
//defer utils.TimeWatcher("getInterfaces")
var lines string
lines, err = utils.RunCommand(client, "/bin/ip -o addr")
if err != nil {
// try /sbin/ip
lines, err = utils.RunCommand(client, "/sbin/ip -o addr")
if err != nil {
return
}
}
if stats.Network == nil {
stats.Network = make(map[string]Network)
}
scanner := bufio.NewScanner(strings.NewReader(lines))
for scanner.Scan() {
line := scanner.Text()
parts := strings.Fields(line)
if len(parts) >= 4 && (parts[2] == "inet" || parts[2] == "inet6") {
ipv4 := parts[2] == "inet"
intfname := parts[1]
if info, ok := stats.Network[intfname]; ok {
if ipv4 {
info.IPv4 = parts[3]
} else {
info.IPv6 = parts[3]
}
stats.Network[intfname] = info
} else {
info := Network{}
if ipv4 {
info.IPv4 = parts[3]
} else {
info.IPv6 = parts[3]
}
stats.Network[intfname] = info
}
}
}
return
}
func getInterfaceInfo(client *ssh.Client, stats *Stat) (err error) {
//defer utils.TimeWatcher("getInterfaceInfo")
lines, err := utils.RunCommand(client, "/bin/cat /proc/net/dev")
if err != nil {
return
}
if stats.Network == nil {
return
} // should have been here already
scanner := bufio.NewScanner(strings.NewReader(lines))
for scanner.Scan() {
line := scanner.Text()
parts := strings.Fields(line)
if len(parts) == 17 {
intf := strings.TrimSpace(parts[0])
intf = strings.TrimSuffix(intf, ":")
if info, ok := stats.Network[intf]; ok {
rx, err := strconv.ParseUint(parts[1], 10, 64)
if err != nil {
continue
}
tx, err := strconv.ParseUint(parts[9], 10, 64)
if err != nil {
continue
}
info.Rx = rx
info.Tx = tx
stats.Network[intf] = info
}
}
}
return
}
func parseCPUFields(fields []string, stat *cpuRaw) {
numFields := len(fields)
for i := 1; i < numFields; i++ {
val, err := strconv.ParseUint(fields[i], 10, 64)
if err != nil {
continue
}
stat.Total += val
switch i {
case 1:
stat.User = val
case 2:
stat.Nice = val
case 3:
stat.System = val
case 4:
stat.Idle = val
case 5:
stat.Iowait = val
case 6:
stat.Irq = val
case 7:
stat.SoftIrq = val
case 8:
stat.Steal = val
case 9:
stat.Guest = val
}
}
}
// the CPU stats that were fetched last time round
var preCPU cpuRaw
func getCPU(client *ssh.Client, stats *Stat) (err error) {
//defer utils.TimeWatcher("getCPU")
lines, err := utils.RunCommand(client, "/bin/cat /proc/stat")
if err != nil {
return
}
var (
nowCPU cpuRaw
total float32
)
scanner := bufio.NewScanner(strings.NewReader(lines))
for scanner.Scan() {
line := scanner.Text()
fields := strings.Fields(line)
if len(fields) > 0 && fields[0] == "cpu" { // changing here if want to get every cpu-core's stats
parseCPUFields(fields, &nowCPU)
break
}
}
if preCPU.Total == 0 { // having no pre raw cpu data
goto END
}
total = float32(nowCPU.Total - preCPU.Total)
stats.CPU.User = float32(nowCPU.User-preCPU.User) / total * 100
stats.CPU.Nice = float32(nowCPU.Nice-preCPU.Nice) / total * 100
stats.CPU.System = float32(nowCPU.System-preCPU.System) / total * 100
stats.CPU.Idle = float32(nowCPU.Idle-preCPU.Idle) / total * 100
stats.CPU.IOWait = float32(nowCPU.Iowait-preCPU.Iowait) / total * 100
stats.CPU.Irq = float32(nowCPU.Irq-preCPU.Irq) / total * 100
stats.CPU.SoftIrq = float32(nowCPU.SoftIrq-preCPU.SoftIrq) / total * 100
stats.CPU.Guest = float32(nowCPU.Guest-preCPU.Guest) / total * 100
END:
preCPU = nowCPU
return
}

353
server/api/storage.go Normal file
View File

@ -0,0 +1,353 @@
package api
import (
"bufio"
"errors"
"io"
"mime/multipart"
"os"
"path"
"strconv"
"strings"
"next-terminal/server/constant"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)
func StoragePagingEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))
name := c.QueryParam("name")
order := c.QueryParam("order")
field := c.QueryParam("field")
items, total, err := storageRepository.Find(pageIndex, pageSize, name, order, field)
if err != nil {
return err
}
drivePath := storageService.GetBaseDrivePath()
for i := range items {
item := items[i]
dirSize, err := utils.DirSize(path.Join(drivePath, item.ID))
if err != nil {
items[i].UsedSize = -1
} else {
items[i].UsedSize = dirSize
}
}
return Success(c, H{
"total": total,
"items": items,
})
}
func StorageCreateEndpoint(c echo.Context) error {
var item model.Storage
if err := c.Bind(&item); err != nil {
return err
}
account, _ := GetCurrentAccount(c)
item.ID = utils.UUID()
item.Created = utils.NowJsonTime()
item.Owner = account.ID
// 创建对应的目录文件夹
drivePath := storageService.GetBaseDrivePath()
if err := os.MkdirAll(path.Join(drivePath, item.ID), os.ModePerm); err != nil {
return err
}
if err := storageRepository.Create(&item); err != nil {
return err
}
return Success(c, "")
}
func StorageUpdateEndpoint(c echo.Context) error {
id := c.Param("id")
var item model.Storage
if err := c.Bind(&item); err != nil {
return err
}
drivePath := storageService.GetBaseDrivePath()
dirSize, err := utils.DirSize(path.Join(drivePath, item.ID))
if err != nil {
return err
}
if item.LimitSize > 0 && item.LimitSize < dirSize {
// 不能小于已使用的大小
return errors.New("空间大小不能小于已使用大小")
}
storage, err := storageRepository.FindById(id)
if err != nil {
return err
}
storage.Name = item.Name
storage.LimitSize = item.LimitSize
storage.IsShare = item.IsShare
if err := storageRepository.UpdateById(&storage, id); err != nil {
return err
}
return Success(c, "")
}
func StorageGetEndpoint(c echo.Context) error {
storageId := c.Param("id")
storage, err := storageRepository.FindById(storageId)
if err != nil {
return err
}
structMap := utils.StructToMap(storage)
drivePath := storageService.GetBaseDrivePath()
dirSize, err := utils.DirSize(path.Join(drivePath, storageId))
if err != nil {
structMap["usedSize"] = -1
} else {
structMap["usedSize"] = dirSize
}
return Success(c, structMap)
}
func StorageSharesEndpoint(c echo.Context) error {
storages, err := storageRepository.FindShares()
if err != nil {
return err
}
return Success(c, storages)
}
func StorageDeleteEndpoint(c echo.Context) error {
ids := c.Param("id")
split := strings.Split(ids, ",")
for i := range split {
id := split[i]
if err := storageService.DeleteStorageById(id, false); err != nil {
return err
}
}
return Success(c, nil)
}
func PermissionCheck(c echo.Context, id string) error {
storage, err := storageRepository.FindById(id)
if err != nil {
return err
}
account, _ := GetCurrentAccount(c)
if account.Type != constant.TypeAdmin {
if storage.Owner != account.ID {
return errors.New("您没有权限访问此地址 :(")
}
}
return nil
}
func StorageLsEndpoint(c echo.Context) error {
storageId := c.Param("storageId")
if err := PermissionCheck(c, storageId); err != nil {
return err
}
remoteDir := c.FormValue("dir")
return StorageLs(c, remoteDir, storageId)
}
func StorageLs(c echo.Context, remoteDir, storageId string) error {
drivePath := storageService.GetBaseDrivePath()
if strings.Contains(remoteDir, "../") {
return Fail(c, -1, "非法请求 :(")
}
files, err := storageService.Ls(path.Join(drivePath, storageId), remoteDir)
if err != nil {
return err
}
return Success(c, files)
}
func StorageDownloadEndpoint(c echo.Context) error {
storageId := c.Param("storageId")
if err := PermissionCheck(c, storageId); err != nil {
return err
}
remoteFile := c.QueryParam("file")
return StorageDownload(c, remoteFile, storageId)
}
func StorageDownload(c echo.Context, remoteFile, storageId string) error {
drivePath := storageService.GetBaseDrivePath()
if strings.Contains(remoteFile, "../") {
return Fail(c, -1, "非法请求 :(")
}
// 获取带后缀的文件名称
filenameWithSuffix := path.Base(remoteFile)
return c.Attachment(path.Join(path.Join(drivePath, storageId), remoteFile), filenameWithSuffix)
}
func StorageUploadEndpoint(c echo.Context) error {
storageId := c.Param("storageId")
if err := PermissionCheck(c, storageId); err != nil {
return err
}
file, err := c.FormFile("file")
if err != nil {
return err
}
return StorageUpload(c, file, storageId)
}
func StorageUpload(c echo.Context, file *multipart.FileHeader, storageId string) error {
drivePath := storageService.GetBaseDrivePath()
storage, _ := storageRepository.FindById(storageId)
if storage.LimitSize > 0 {
dirSize, err := utils.DirSize(path.Join(drivePath, storageId))
if err != nil {
return err
}
if dirSize+file.Size > storage.LimitSize {
return errors.New("可用空间不足")
}
}
filename := file.Filename
src, err := file.Open()
if err != nil {
return err
}
remoteDir := c.QueryParam("dir")
remoteFile := path.Join(remoteDir, filename)
if strings.Contains(remoteDir, "../") {
return Fail(c, -1, "非法请求 :(")
}
if strings.Contains(remoteFile, "../") {
return Fail(c, -1, "非法请求 :(")
}
// 判断文件夹不存在时自动创建
dir := path.Join(path.Join(drivePath, storageId), remoteDir)
if !utils.FileExists(dir) {
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
return err
}
}
// Destination
dst, err := os.Create(path.Join(path.Join(drivePath, storageId), remoteFile))
if err != nil {
return err
}
defer dst.Close()
// Copy
if _, err = io.Copy(dst, src); err != nil {
return err
}
return Success(c, nil)
}
func StorageMkDirEndpoint(c echo.Context) error {
storageId := c.Param("storageId")
if err := PermissionCheck(c, storageId); err != nil {
return err
}
remoteDir := c.QueryParam("dir")
return StorageMkDir(c, remoteDir, storageId)
}
func StorageMkDir(c echo.Context, remoteDir, storageId string) error {
drivePath := storageService.GetBaseDrivePath()
if strings.Contains(remoteDir, "../") {
return Fail(c, -1, ":) 非法请求")
}
if err := os.MkdirAll(path.Join(path.Join(drivePath, storageId), remoteDir), os.ModePerm); err != nil {
return err
}
return Success(c, nil)
}
func StorageRmEndpoint(c echo.Context) error {
storageId := c.Param("storageId")
if err := PermissionCheck(c, storageId); err != nil {
return err
}
// 文件夹或者文件
file := c.FormValue("file")
return StorageRm(c, file, storageId)
}
func StorageRm(c echo.Context, file, storageId string) error {
drivePath := storageService.GetBaseDrivePath()
if strings.Contains(file, "../") {
return Fail(c, -1, ":) 非法请求")
}
if err := os.RemoveAll(path.Join(path.Join(drivePath, storageId), file)); err != nil {
return err
}
return Success(c, nil)
}
func StorageRenameEndpoint(c echo.Context) error {
storageId := c.Param("storageId")
if err := PermissionCheck(c, storageId); err != nil {
return err
}
oldName := c.QueryParam("oldName")
newName := c.QueryParam("newName")
return StorageRename(c, oldName, newName, storageId)
}
func StorageRename(c echo.Context, oldName, newName, storageId string) error {
drivePath := storageService.GetBaseDrivePath()
if strings.Contains(oldName, "../") {
return Fail(c, -1, ":) 非法请求")
}
if strings.Contains(newName, "../") {
return Fail(c, -1, ":) 非法请求")
}
if err := os.Rename(path.Join(path.Join(drivePath, storageId), oldName), path.Join(path.Join(drivePath, storageId), newName)); err != nil {
return err
}
return Success(c, nil)
}
func StorageEditEndpoint(c echo.Context) error {
storageId := c.Param("storageId")
if err := PermissionCheck(c, storageId); err != nil {
return err
}
file := c.Param("file")
fileContent := c.Param("fileContent")
return StorageEdit(c, file, fileContent, storageId)
}
func StorageEdit(c echo.Context, file string, fileContent string, storageId string) error {
drivePath := storageService.GetBaseDrivePath()
if strings.Contains(file, "../") {
return Fail(c, -1, ":) 非法请求")
}
realFilePath := path.Join(path.Join(drivePath, storageId), file)
dstFile, err := os.OpenFile(realFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
if err != nil {
return err
}
defer dstFile.Close()
write := bufio.NewWriter(dstFile)
if _, err := write.WriteString(fileContent); err != nil {
return err
}
if err := write.Flush(); err != nil {
return err
}
return Success(c, nil)
}

77
server/api/strategy.go Normal file
View File

@ -0,0 +1,77 @@
package api
import (
"strconv"
"strings"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)
func StrategyAllEndpoint(c echo.Context) error {
items, err := strategyRepository.FindAll()
if err != nil {
return err
}
return Success(c, items)
}
func StrategyPagingEndpoint(c echo.Context) error {
pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex"))
pageSize, _ := strconv.Atoi(c.QueryParam("pageSize"))
name := c.QueryParam("name")
order := c.QueryParam("order")
field := c.QueryParam("field")
items, total, err := strategyRepository.Find(pageIndex, pageSize, name, order, field)
if err != nil {
return err
}
return Success(c, H{
"total": total,
"items": items,
})
}
func StrategyCreateEndpoint(c echo.Context) error {
var item model.Strategy
if err := c.Bind(&item); err != nil {
return err
}
item.ID = utils.UUID()
item.Created = utils.NowJsonTime()
if err := strategyRepository.Create(&item); err != nil {
return err
}
return Success(c, "")
}
func StrategyDeleteEndpoint(c echo.Context) error {
ids := c.Param("id")
split := strings.Split(ids, ",")
for i := range split {
id := split[i]
if err := strategyRepository.DeleteById(id); err != nil {
return err
}
}
return Success(c, nil)
}
func StrategyUpdateEndpoint(c echo.Context) error {
id := c.Param("id")
var item model.Strategy
if err := c.Bind(&item); err != nil {
return err
}
if err := strategyRepository.UpdateById(&item, id); err != nil {
return err
}
return Success(c, "")
}

66
server/api/test/test.go Normal file
View File

@ -0,0 +1,66 @@
package main
import (
"fmt"
"strings"
"github.com/manifoldco/promptui"
)
type pepper struct {
Name string
HeatUnit int
Peppers int
}
func main() {
peppers := []pepper{
{Name: "Bell Pepper", HeatUnit: 0, Peppers: 0},
{Name: "Banana Pepper", HeatUnit: 100, Peppers: 1},
{Name: "Poblano", HeatUnit: 1000, Peppers: 2},
{Name: "Jalapeño", HeatUnit: 3500, Peppers: 3},
{Name: "Aleppo", HeatUnit: 10000, Peppers: 4},
{Name: "Tabasco", HeatUnit: 30000, Peppers: 5},
{Name: "Malagueta", HeatUnit: 50000, Peppers: 6},
{Name: "Habanero", HeatUnit: 100000, Peppers: 7},
{Name: "Red Savina Habanero", HeatUnit: 350000, Peppers: 8},
{Name: "Dragons Breath", HeatUnit: 855000, Peppers: 9},
}
templates := &promptui.SelectTemplates{
Label: "{{ . }}?",
Active: "\U0001F336 {{ .Name | cyan }} ({{ .HeatUnit | red }})",
Inactive: " {{ .Name | cyan }} ({{ .HeatUnit | red }})",
Selected: "\U0001F336 {{ .Name | red | cyan }}",
Details: `
--------- Pepper ----------
{{ "Name:" | faint }} {{ .Name }}/
{{ "Heat Unit:" | faint }} {{ .HeatUnit }}
{{ "Peppers:" | faint }} {{ .Peppers }}`,
}
searcher := func(input string, index int) bool {
pepper := peppers[index]
name := strings.Replace(strings.ToLower(pepper.Name), " ", "", -1)
input = strings.Replace(strings.ToLower(input), " ", "", -1)
return strings.Contains(name, input)
}
prompt := promptui.Select{
Label: "Spicy Level",
Items: peppers,
Templates: templates,
Size: 4,
Searcher: searcher,
}
i, _, err := prompt.Run()
if err != nil {
fmt.Printf("Prompt failed %v\n", err)
return
}
fmt.Printf("You choose number %d: %s\n", i+1, peppers[i].Name)
}

View File

@ -1,26 +1,34 @@
package api
import (
"context"
"encoding/base64"
"errors"
"path"
"strconv"
"time"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/guacd"
"next-terminal/pkg/log"
"next-terminal/server/config"
"next-terminal/server/constant"
"next-terminal/server/global/session"
"next-terminal/server/guacd"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
)
const (
TunnelClosed int = -1
Normal int = 0
NotFoundSession int = 800
NewTunnelError int = 801
ForcedDisconnect int = 802
TunnelClosed int = -1
Normal int = 0
NotFoundSession int = 800
NewTunnelError int = 801
ForcedDisconnect int = 802
AccessGatewayUnAvailable int = 803
AccessGatewayCreateError int = 804
AccessGatewayConnectError int = 804
)
func TunEndpoint(c echo.Context) error {
@ -44,111 +52,63 @@ func TunEndpoint(c echo.Context) error {
propertyMap := propertyRepository.FindAllMap()
var session model.Session
var s model.Session
if len(connectionId) > 0 {
session, err = sessionRepository.FindByConnectionId(connectionId)
s, err = sessionRepository.FindByConnectionId(connectionId)
if err != nil {
log.Warnf("会话不存在")
return err
}
if session.Status != constant.Connected {
log.Warnf("会话未在线")
if s.Status != constant.Connected {
return errors.New("会话未在线")
}
configuration.ConnectionID = connectionId
sessionId = session.ID
configuration.SetParameter("width", strconv.Itoa(session.Width))
configuration.SetParameter("height", strconv.Itoa(session.Height))
sessionId = s.ID
configuration.SetParameter("width", strconv.Itoa(s.Width))
configuration.SetParameter("height", strconv.Itoa(s.Height))
configuration.SetParameter("dpi", "96")
} else {
configuration.SetParameter("width", width)
configuration.SetParameter("height", height)
configuration.SetParameter("dpi", dpi)
session, err = sessionRepository.FindByIdAndDecrypt(sessionId)
s, err = sessionRepository.FindByIdAndDecrypt(sessionId)
if err != nil {
CloseSessionById(sessionId, NotFoundSession, "会话不存在")
return err
}
if propertyMap[guacd.EnableRecording] == "true" {
configuration.SetParameter(guacd.RecordingPath, path.Join(propertyMap[guacd.RecordingPath], sessionId))
configuration.SetParameter(guacd.CreateRecordingPath, propertyMap[guacd.CreateRecordingPath])
} else {
configuration.SetParameter(guacd.RecordingPath, "")
}
configuration.Protocol = session.Protocol
switch configuration.Protocol {
case "rdp":
configuration.SetParameter("username", session.Username)
configuration.SetParameter("password", session.Password)
configuration.SetParameter("security", "any")
configuration.SetParameter("ignore-cert", "true")
configuration.SetParameter("create-drive-path", "true")
configuration.SetParameter("resize-method", "reconnect")
configuration.SetParameter(guacd.EnableDrive, propertyMap[guacd.EnableDrive])
configuration.SetParameter(guacd.DriveName, propertyMap[guacd.DriveName])
configuration.SetParameter(guacd.DrivePath, propertyMap[guacd.DrivePath])
configuration.SetParameter(guacd.EnableWallpaper, propertyMap[guacd.EnableWallpaper])
configuration.SetParameter(guacd.EnableTheming, propertyMap[guacd.EnableTheming])
configuration.SetParameter(guacd.EnableFontSmoothing, propertyMap[guacd.EnableFontSmoothing])
configuration.SetParameter(guacd.EnableFullWindowDrag, propertyMap[guacd.EnableFullWindowDrag])
configuration.SetParameter(guacd.EnableDesktopComposition, propertyMap[guacd.EnableDesktopComposition])
configuration.SetParameter(guacd.EnableMenuAnimations, propertyMap[guacd.EnableMenuAnimations])
configuration.SetParameter(guacd.DisableBitmapCaching, propertyMap[guacd.DisableBitmapCaching])
configuration.SetParameter(guacd.DisableOffscreenCaching, propertyMap[guacd.DisableOffscreenCaching])
configuration.SetParameter(guacd.DisableGlyphCaching, propertyMap[guacd.DisableGlyphCaching])
case "ssh":
if len(session.PrivateKey) > 0 && session.PrivateKey != "-" {
configuration.SetParameter("username", session.Username)
configuration.SetParameter("private-key", session.PrivateKey)
configuration.SetParameter("passphrase", session.Passphrase)
} else {
configuration.SetParameter("username", session.Username)
configuration.SetParameter("password", session.Password)
setConfig(propertyMap, s, configuration)
var (
ip = s.IP
port = s.Port
)
if s.AccessGatewayId != "" && s.AccessGatewayId != "-" {
g, err := accessGatewayService.GetGatewayAndReconnectById(s.AccessGatewayId)
if err != nil {
disconnect(ws, AccessGatewayUnAvailable, "获取接入网关失败:"+err.Error())
return nil
}
configuration.SetParameter(guacd.FontSize, propertyMap[guacd.FontSize])
configuration.SetParameter(guacd.FontName, propertyMap[guacd.FontName])
configuration.SetParameter(guacd.ColorScheme, propertyMap[guacd.ColorScheme])
configuration.SetParameter(guacd.Backspace, propertyMap[guacd.Backspace])
configuration.SetParameter(guacd.TerminalType, propertyMap[guacd.TerminalType])
case "vnc":
configuration.SetParameter("username", session.Username)
configuration.SetParameter("password", session.Password)
case "telnet":
configuration.SetParameter("username", session.Username)
configuration.SetParameter("password", session.Password)
configuration.SetParameter(guacd.FontSize, propertyMap[guacd.FontSize])
configuration.SetParameter(guacd.FontName, propertyMap[guacd.FontName])
configuration.SetParameter(guacd.ColorScheme, propertyMap[guacd.ColorScheme])
configuration.SetParameter(guacd.Backspace, propertyMap[guacd.Backspace])
configuration.SetParameter(guacd.TerminalType, propertyMap[guacd.TerminalType])
case "kubernetes":
configuration.SetParameter(guacd.FontSize, propertyMap[guacd.FontSize])
configuration.SetParameter(guacd.FontName, propertyMap[guacd.FontName])
configuration.SetParameter(guacd.ColorScheme, propertyMap[guacd.ColorScheme])
configuration.SetParameter(guacd.Backspace, propertyMap[guacd.Backspace])
configuration.SetParameter(guacd.TerminalType, propertyMap[guacd.TerminalType])
default:
log.WithField("configuration.Protocol", configuration.Protocol).Error("UnSupport Protocol")
return Fail(c, 400, "不支持的协议")
if !g.Connected {
disconnect(ws, AccessGatewayUnAvailable, "接入网关不可用:"+g.Message)
return nil
}
exposedIP, exposedPort, err := g.OpenSshTunnel(s.ID, ip, port)
if err != nil {
disconnect(ws, AccessGatewayCreateError, "创建SSH隧道失败"+err.Error())
return nil
}
defer g.CloseSshTunnel(s.ID)
ip = exposedIP
port = exposedPort
}
configuration.SetParameter("hostname", session.IP)
configuration.SetParameter("port", strconv.Itoa(session.Port))
configuration.SetParameter("hostname", ip)
configuration.SetParameter("port", strconv.Itoa(port))
// 加载资产配置的属性,优先级比全局配置的高,因此最后加载,覆盖掉全局配置
attributes, _ := assetRepository.FindAttrById(session.AssetId)
attributes, err := assetRepository.FindAssetAttrMapByAssetId(s.AssetId)
if err != nil {
return err
}
if len(attributes) > 0 {
for i := range attributes {
attribute := attributes[i]
configuration.SetParameter(attribute.Name, attribute.Value)
}
setAssetConfig(attributes, s, configuration)
}
}
for name := range configuration.Parameters {
@ -158,94 +118,235 @@ func TunEndpoint(c echo.Context) error {
}
}
addr := propertyMap[guacd.Host] + ":" + propertyMap[guacd.Port]
addr := config.GlobalCfg.Guacd.Hostname + ":" + strconv.Itoa(config.GlobalCfg.Guacd.Port)
log.Debugf("[%v:%v] 创建guacd隧道[%v]", sessionId, connectionId, addr)
tunnel, err := guacd.NewTunnel(addr, configuration)
guacdTunnel, err := guacd.NewTunnel(addr, configuration)
if err != nil {
if connectionId == "" {
CloseSessionById(sessionId, NewTunnelError, err.Error())
disconnect(ws, NewTunnelError, err.Error())
}
log.Printf("建立连接失败: %v", err.Error())
log.Printf("[%v:%v] 建立连接失败: %v", sessionId, connectionId, err.Error())
return err
}
tun := global.Tun{
Protocol: session.Protocol,
Mode: session.Mode,
WebSocket: ws,
Tunnel: tunnel,
nextSession := &session.Session{
ID: sessionId,
Protocol: s.Protocol,
Mode: s.Mode,
WebSocket: ws,
GuacdTunnel: guacdTunnel,
}
if len(session.ConnectionId) == 0 {
var observers []global.Tun
observable := global.Observable{
Subject: &tun,
Observers: observers,
if len(s.ConnectionId) == 0 {
if configuration.Protocol == constant.SSH {
nextTerminal, err := CreateNextTerminalBySession(s)
if err != nil {
return err
}
nextSession.NextTerminal = nextTerminal
}
global.Store.Set(sessionId, &observable)
nextSession.Observer = session.NewObserver(sessionId)
session.GlobalSessionManager.Add <- nextSession
go nextSession.Observer.Run()
sess := model.Session{
ConnectionId: tunnel.UUID,
ConnectionId: guacdTunnel.UUID,
Width: intWidth,
Height: intHeight,
Status: constant.Connecting,
Recording: configuration.GetParameter(guacd.RecordingPath),
}
// 创建新会话
log.Debugf("创建新会话 %v", sess.ConnectionId)
log.Debugf("[%v:%v] 创建新会话: %v", sessionId, connectionId, sess.ConnectionId)
if err := sessionRepository.UpdateById(&sess, sessionId); err != nil {
return err
}
} else {
// 监控会话
observable, ok := global.Store.Get(sessionId)
if ok {
observers := append(observable.Observers, tun)
observable.Observers = observers
global.Store.Set(sessionId, observable)
log.Debugf("加入会话%v,当前观察者数量为:%v", session.ConnectionId, len(observers))
// 监控会话
forObsSession := session.GlobalSessionManager.GetById(sessionId)
if forObsSession == nil {
disconnect(ws, NotFoundSession, "获取会话失败")
return nil
}
nextSession.ID = utils.UUID()
forObsSession.Observer.Add <- nextSession
log.Debugf("[%v:%v] 观察者[%v]加入会话[%v]", sessionId, connectionId, nextSession.ID, s.ConnectionId)
}
ctx, cancel := context.WithCancel(context.Background())
tick := time.NewTicker(time.Millisecond * time.Duration(60))
defer tick.Stop()
var buf []byte
dataChan := make(chan []byte)
go func() {
GuacdLoop:
for {
instruction, err := tunnel.Read()
if err != nil {
if connectionId == "" {
CloseSessionById(sessionId, TunnelClosed, "远程连接关闭")
select {
case <-ctx.Done():
log.Debugf("[%v:%v] WebSocket 已关闭,即将关闭 Guacd 连接...", sessionId, connectionId)
break GuacdLoop
default:
instruction, err := guacdTunnel.Read()
if err != nil {
log.Debugf("[%v:%v] Guacd 读取失败,即将退出循环...", sessionId, connectionId)
disconnect(ws, TunnelClosed, "远程连接已关闭")
break GuacdLoop
}
break
}
if len(instruction) == 0 {
continue
}
err = ws.WriteMessage(websocket.TextMessage, instruction)
if err != nil {
if connectionId == "" {
CloseSessionById(sessionId, Normal, "正常退出")
if len(instruction) == 0 {
continue
}
break
dataChan <- instruction
}
}
log.Debugf("[%v:%v] Guacd 连接已关闭,退出 Guacd 循环。", sessionId, connectionId)
}()
go func() {
tickLoop:
for {
select {
case <-ctx.Done():
break tickLoop
case <-tick.C:
if len(buf) > 0 {
err = ws.WriteMessage(websocket.TextMessage, buf)
if err != nil {
log.Debugf("[%v:%v] WebSocket写入失败即将关闭Guacd连接...", sessionId, connectionId)
break tickLoop
}
buf = []byte{}
}
case data := <-dataChan:
buf = append(buf, data...)
}
}
log.Debugf("[%v:%v] Guacd连接已关闭退出定时器循环。", sessionId, connectionId)
}()
for {
_, message, err := ws.ReadMessage()
if err != nil {
if connectionId == "" {
CloseSessionById(sessionId, Normal, "正常退出")
log.Debugf("[%v:%v] WebSocket已关闭", sessionId, connectionId)
// guacdTunnel.Read() 会阻塞所以要先把guacdTunnel客户端关闭才能退出Guacd循环
_ = guacdTunnel.Close()
if connectionId != "" {
observerId := nextSession.ID
forObsSession := session.GlobalSessionManager.GetById(sessionId)
if forObsSession != nil {
// 移除会话中保存的观察者信息
forObsSession.Observer.Del <- observerId
log.Debugf("[%v:%v] 观察者[%v]退出会话", sessionId, connectionId, observerId)
}
} else {
CloseSessionById(sessionId, Normal, "用户正常退出")
}
cancel()
break
}
_, err = tunnel.WriteAndFlush(message)
_, err = guacdTunnel.WriteAndFlush(message)
if err != nil {
if connectionId == "" {
CloseSessionById(sessionId, TunnelClosed, "远程连接关闭")
}
break
CloseSessionById(sessionId, TunnelClosed, "远程连接已关闭")
}
}
return err
return nil
}
func setAssetConfig(attributes map[string]string, s model.Session, configuration *guacd.Configuration) {
for key, value := range attributes {
if guacd.DrivePath == key {
// 忽略该参数
continue
}
if guacd.EnableDrive == key && value == "true" {
storageId := attributes[guacd.DrivePath]
if storageId == "" || storageId == "-" {
// 默认空间ID和用户ID相同
storageId = s.Creator
}
realPath := path.Join(storageService.GetBaseDrivePath(), storageId)
configuration.SetParameter(guacd.EnableDrive, "true")
configuration.SetParameter(guacd.DriveName, "Next Terminal Filesystem")
configuration.SetParameter(guacd.DrivePath, realPath)
log.Debugf("[%v] 会话 %v:%v 映射目录地址为 %v", s.ID, s.IP, s.Port, realPath)
} else {
configuration.SetParameter(key, value)
}
}
}
func setConfig(propertyMap map[string]string, s model.Session, configuration *guacd.Configuration) {
if propertyMap[guacd.EnableRecording] == "true" {
configuration.SetParameter(guacd.RecordingPath, path.Join(config.GlobalCfg.Guacd.Recording, s.ID))
configuration.SetParameter(guacd.CreateRecordingPath, propertyMap[guacd.CreateRecordingPath])
} else {
configuration.SetParameter(guacd.RecordingPath, "")
}
configuration.Protocol = s.Protocol
switch configuration.Protocol {
case "rdp":
configuration.SetParameter("username", s.Username)
configuration.SetParameter("password", s.Password)
configuration.SetParameter("security", "any")
configuration.SetParameter("ignore-cert", "true")
configuration.SetParameter("create-drive-path", "true")
configuration.SetParameter("resize-method", "reconnect")
configuration.SetParameter(guacd.EnableWallpaper, propertyMap[guacd.EnableWallpaper])
configuration.SetParameter(guacd.EnableTheming, propertyMap[guacd.EnableTheming])
configuration.SetParameter(guacd.EnableFontSmoothing, propertyMap[guacd.EnableFontSmoothing])
configuration.SetParameter(guacd.EnableFullWindowDrag, propertyMap[guacd.EnableFullWindowDrag])
configuration.SetParameter(guacd.EnableDesktopComposition, propertyMap[guacd.EnableDesktopComposition])
configuration.SetParameter(guacd.EnableMenuAnimations, propertyMap[guacd.EnableMenuAnimations])
configuration.SetParameter(guacd.DisableBitmapCaching, propertyMap[guacd.DisableBitmapCaching])
configuration.SetParameter(guacd.DisableOffscreenCaching, propertyMap[guacd.DisableOffscreenCaching])
configuration.SetParameter(guacd.DisableGlyphCaching, propertyMap[guacd.DisableGlyphCaching])
case "ssh":
if len(s.PrivateKey) > 0 && s.PrivateKey != "-" {
configuration.SetParameter("username", s.Username)
configuration.SetParameter("private-key", s.PrivateKey)
configuration.SetParameter("passphrase", s.Passphrase)
} else {
configuration.SetParameter("username", s.Username)
configuration.SetParameter("password", s.Password)
}
configuration.SetParameter(guacd.FontSize, propertyMap[guacd.FontSize])
configuration.SetParameter(guacd.FontName, propertyMap[guacd.FontName])
configuration.SetParameter(guacd.ColorScheme, propertyMap[guacd.ColorScheme])
configuration.SetParameter(guacd.Backspace, propertyMap[guacd.Backspace])
configuration.SetParameter(guacd.TerminalType, propertyMap[guacd.TerminalType])
case "vnc":
configuration.SetParameter("username", s.Username)
configuration.SetParameter("password", s.Password)
case "telnet":
configuration.SetParameter("username", s.Username)
configuration.SetParameter("password", s.Password)
configuration.SetParameter(guacd.FontSize, propertyMap[guacd.FontSize])
configuration.SetParameter(guacd.FontName, propertyMap[guacd.FontName])
configuration.SetParameter(guacd.ColorScheme, propertyMap[guacd.ColorScheme])
configuration.SetParameter(guacd.Backspace, propertyMap[guacd.Backspace])
configuration.SetParameter(guacd.TerminalType, propertyMap[guacd.TerminalType])
case "kubernetes":
configuration.SetParameter(guacd.FontSize, propertyMap[guacd.FontSize])
configuration.SetParameter(guacd.FontName, propertyMap[guacd.FontName])
configuration.SetParameter(guacd.ColorScheme, propertyMap[guacd.ColorScheme])
configuration.SetParameter(guacd.Backspace, propertyMap[guacd.Backspace])
configuration.SetParameter(guacd.TerminalType, propertyMap[guacd.TerminalType])
default:
}
}
func disconnect(ws *websocket.Conn, code int, reason string) {
// guacd 无法处理中文字符所以进行了base64编码。
encodeReason := base64.StdEncoding.EncodeToString([]byte(reason))
err := guacd.NewInstruction("error", encodeReason, strconv.Itoa(code))
_ = ws.WriteMessage(websocket.TextMessage, []byte(err.String()))
disconnect := guacd.NewInstruction("disconnect")
_ = ws.WriteMessage(websocket.TextMessage, []byte(disconnect.String()))
}

View File

@ -1,18 +1,20 @@
package api
import (
"errors"
"strconv"
"strings"
"next-terminal/pkg/global"
"next-terminal/pkg/log"
"next-terminal/server/global/cache"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
"gorm.io/gorm"
)
func UserCreateEndpoint(c echo.Context) error {
func UserCreateEndpoint(c echo.Context) (err error) {
var item model.User
if err := c.Bind(&item); err != nil {
return err
@ -20,7 +22,6 @@ func UserCreateEndpoint(c echo.Context) error {
password := item.Password
var pass []byte
var err error
if pass, err = utils.Encoder.Encode([]byte(password)); err != nil {
return err
}
@ -32,6 +33,10 @@ func UserCreateEndpoint(c echo.Context) error {
if err := userRepository.Create(&item); err != nil {
return err
}
err = storageService.CreateStorageByUser(&item)
if err != nil {
return err
}
if item.Mail != "" {
go mailService.SendMail(item.Mail, "[Next Terminal] 注册通知", "你好,"+item.Nickname+"。管理员为你注册了账号:"+item.Username+" 密码:"+password)
@ -89,16 +94,22 @@ func UserDeleteEndpoint(c echo.Context) error {
if account.ID == userId {
return Fail(c, -1, "不允许删除自身账户")
}
user, err := userRepository.FindById(userId)
if err != nil {
return err
}
// 将用户强制下线
loginLogs, err := loginLogRepository.FindAliveLoginLogsByUserId(userId)
loginLogs, err := loginLogRepository.FindAliveLoginLogsByUsername(user.Username)
if err != nil {
return err
}
for j := range loginLogs {
global.Cache.Delete(loginLogs[j].ID)
if err := userService.Logout(loginLogs[j].ID); err != nil {
log.WithError(err).WithField("id:", loginLogs[j].ID).Error("Cache Deleted Error")
token := loginLogs[j].ID
cacheKey := userService.BuildCacheKeyByToken(token)
cache.GlobalCache.Delete(cacheKey)
if err := userService.Logout(token); err != nil {
log.WithError(err).WithField("id:", token).Error("Cache Deleted Error")
return Fail(c, 500, "强制下线错误")
}
}
@ -107,6 +118,10 @@ func UserDeleteEndpoint(c echo.Context) error {
if err := userRepository.DeleteById(userId); err != nil {
return err
}
// 删除用户的默认磁盘空间
if err := storageService.DeleteStorageById(userId, true); err != nil {
return err
}
}
return Success(c, nil)
@ -125,7 +140,10 @@ func UserGetEndpoint(c echo.Context) error {
func UserChangePasswordEndpoint(c echo.Context) error {
id := c.Param("id")
password := c.QueryParam("password")
password := c.FormValue("password")
if password == "" {
return Fail(c, -1, "请输入密码")
}
user, err := userRepository.FindById(id)
if err != nil {
@ -172,9 +190,11 @@ func ReloadToken() error {
for i := range loginLogs {
loginLog := loginLogs[i]
token := loginLog.ID
user, err := userRepository.FindById(loginLog.UserId)
user, err := userRepository.FindByUsername(loginLog.Username)
if err != nil {
log.Debugf("用户「%v」获取失败忽略", loginLog.UserId)
if errors.Is(gorm.ErrRecordNotFound, err) {
_ = loginLogRepository.DeleteById(token)
}
continue
}
@ -184,13 +204,13 @@ func ReloadToken() error {
User: user,
}
cacheKey := BuildCacheKeyByToken(token)
cacheKey := userService.BuildCacheKeyByToken(token)
if authorization.Remember {
// 记住登录有效期两周
global.Cache.Set(cacheKey, authorization, RememberEffectiveTime)
cache.GlobalCache.Set(cacheKey, authorization, RememberEffectiveTime)
} else {
global.Cache.Set(cacheKey, authorization, NotRememberEffectiveTime)
cache.GlobalCache.Set(cacheKey, authorization, NotRememberEffectiveTime)
}
log.Debugf("重新加载用户「%v」授权Token「%v」到缓存", user.Nickname, token)
}

227
server/config/config.go Normal file
View File

@ -0,0 +1,227 @@
package config
import (
"crypto/md5"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"os"
"path"
"path/filepath"
"strings"
"next-terminal/server/utils"
"github.com/mitchellh/go-homedir"
"github.com/spf13/pflag"
"github.com/spf13/viper"
)
var GlobalCfg *Config
type Config struct {
Debug bool
Demo bool
Container bool
DB string
Server *Server
Mysql *Mysql
Sqlite *Sqlite
ResetPassword string
ResetTotp string
EncryptionKey string
EncryptionPassword []byte
NewEncryptionKey string
Guacd *Guacd
Sshd *Sshd
}
type Mysql struct {
Hostname string
Port int
Username string
Password string
Database string
}
type Sqlite struct {
File string
}
type Server struct {
Addr string
Cert string
Key string
}
type Guacd struct {
Hostname string
Port int
Recording string
Drive string
}
type Sshd struct {
Enable bool
Addr string
Key string
}
func SetupConfig() (*Config, error) {
viper.SetConfigName("config")
viper.SetConfigType("yml")
viper.AddConfigPath("/etc/next-terminal/")
viper.AddConfigPath("$HOME/.next-terminal")
viper.AddConfigPath(".")
viper.AutomaticEnv()
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
pflag.String("db", "sqlite", "db mode")
pflag.String("sqlite.file", path.Join("/usr/local/next-terminal/data", "sqlite", "next-terminal.db"), "sqlite db file")
pflag.String("mysql.hostname", "127.0.0.1", "mysql hostname")
pflag.Int("mysql.port", 3306, "mysql port")
pflag.String("mysql.username", "mysql", "mysql username")
pflag.String("mysql.password", "mysql", "mysql password")
pflag.String("mysql.database", "next-terminal", "mysql database")
pflag.String("server.addr", "", "server listen addr")
pflag.String("server.cert", "", "tls cert file")
pflag.String("server.key", "", "tls key file")
pflag.String("reset-password", "", "")
pflag.String("encryption-key", "", "")
pflag.String("new-encryption-key", "", "")
pflag.String("guacd.hostname", "127.0.0.1", "")
pflag.Int("guacd.port", 4822, "")
pflag.String("guacd.recording", "/usr/local/next-terminal/data/recording", "")
pflag.String("guacd.drive", "/usr/local/next-terminal/data/drive", "")
pflag.Bool("sshd.enable", false, "true or false")
pflag.String("sshd.addr", "", "sshd server listen addr")
pflag.String("sshd.key", "~/.ssh/id_rsa", "sshd public key filepath")
pflag.Parse()
if err := viper.BindPFlags(pflag.CommandLine); err != nil {
return nil, err
}
if err := viper.ReadInConfig(); err != nil {
return nil, err
}
sshdKey, err := homedir.Expand(viper.GetString("sshd.key"))
if err != nil {
return nil, err
}
guacdRecording, err := homedir.Expand(viper.GetString("guacd.recording"))
if err != nil {
return nil, err
}
guacdDrive, err := homedir.Expand(viper.GetString("guacd.drive"))
if err != nil {
return nil, err
}
var config = &Config{
DB: viper.GetString("db"),
Mysql: &Mysql{
Hostname: viper.GetString("mysql.hostname"),
Port: viper.GetInt("mysql.port"),
Username: viper.GetString("mysql.username"),
Password: viper.GetString("mysql.password"),
Database: viper.GetString("mysql.database"),
},
Sqlite: &Sqlite{
File: viper.GetString("sqlite.file"),
},
Server: &Server{
Addr: viper.GetString("server.addr"),
Cert: viper.GetString("server.cert"),
Key: viper.GetString("server.key"),
},
ResetPassword: viper.GetString("reset-password"),
ResetTotp: viper.GetString("reset-totp"),
Debug: viper.GetBool("debug"),
Demo: viper.GetBool("demo"),
Container: viper.GetBool("container"),
EncryptionKey: viper.GetString("encryption-key"),
NewEncryptionKey: viper.GetString("new-encryption-key"),
Guacd: &Guacd{
Hostname: viper.GetString("guacd.hostname"),
Port: viper.GetInt("guacd.port"),
Recording: guacdRecording,
Drive: guacdDrive,
},
Sshd: &Sshd{
Enable: viper.GetBool("sshd.enable"),
Addr: viper.GetString("sshd.addr"),
Key: sshdKey,
},
}
if config.EncryptionKey == "" {
config.EncryptionKey = "next-terminal"
}
md5Sum := fmt.Sprintf("%x", md5.Sum([]byte(config.EncryptionKey)))
config.EncryptionPassword = []byte(md5Sum)
// 自动创建数据存放目录
if err := utils.MkdirP(config.Guacd.Recording); err != nil {
panic(fmt.Sprintf("创建文件夹 %v 失败: %v", config.Guacd.Recording, err.Error()))
}
if err := utils.MkdirP(config.Guacd.Drive); err != nil {
panic(fmt.Sprintf("创建文件夹 %v 失败: %v", config.Guacd.Drive, err.Error()))
}
if config.DB == "sqlite" {
sqliteDir := filepath.Dir(config.Sqlite.File)
sqliteDir, err := homedir.Expand(sqliteDir)
if err != nil {
return nil, err
}
if err := utils.MkdirP(sqliteDir); err != nil {
panic(fmt.Sprintf("创建文件夹 %v 失败: %v", sqliteDir, err.Error()))
}
}
if config.Sshd.Enable && !utils.FileExists(sshdKey) {
fmt.Printf("检测到本地RSA私钥文件不存在: %v \n", sshdKey)
sshdKeyDir := filepath.Dir(sshdKey)
if !utils.FileExists(sshdKeyDir) {
if err := utils.MkdirP(sshdKeyDir); err != nil {
panic(fmt.Sprintf("创建文件夹 %v 失败: %v", sshdKeyDir, err.Error()))
}
}
// 自动创建 ID_RSA 密钥
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(err)
}
//使用X509规范,对公钥私钥进行格式化
x509PrivateKey := x509.MarshalPKCS1PrivateKey(privateKey)
block := pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509PrivateKey,
}
privateKeyFile, _ := os.Create(sshdKey)
if err := pem.Encode(privateKeyFile, &block); err != nil {
panic(err)
}
_ = privateKeyFile.Close()
fmt.Printf("自动创建RSA私钥文件成功: %v \n", sshdKey)
}
return config, nil
}
func init() {
var err error
GlobalCfg, err = SetupConfig()
if err != nil {
panic(err)
}
}

65
server/constant/const.go Normal file
View File

@ -0,0 +1,65 @@
package constant
import (
"next-terminal/server/guacd"
)
const (
Version = "v1.2.0"
Banner = `
_______ __ ___________ .__ .__
\ \ ____ ___ ____/ |_ \__ ___/__________ _____ |__| ____ _____ | |
/ | \_/ __ \\ \/ /\ __\ | |_/ __ \_ __ \/ \| |/ \\__ \ | |
/ | \ ___/ > < | | | |\ ___/| | \/ Y Y \ | | \/ __ \| |__
\____|__ /\___ >__/\_ \ |__| |____| \___ >__| |__|_| /__|___| (____ /____/
\/ \/ \/ \/ \/ \/ \/ %s
`
)
const Token = "X-Auth-Token"
const (
SSH = "ssh"
RDP = "rdp"
VNC = "vnc"
Telnet = "telnet"
K8s = "kubernetes"
AccessRuleAllow = "allow" // 允许访问
AccessRuleReject = "reject" // 拒绝访问
Custom = "custom" // 密码
PrivateKey = "private-key" // 密钥
JobStatusRunning = "running" // 计划任务运行状态
JobStatusNotRunning = "not-running" // 计划任务未运行状态
FuncCheckAssetStatusJob = "check-asset-status-job" // 检测资产是否在线
FuncShellJob = "shell-job" // 执行Shell脚本
JobModeAll = "all" // 全部资产
JobModeCustom = "custom" // 自定义选择资产
SshMode = "ssh-mode" // ssh模式
MailHost = "mail-host" // 邮件服务器地址
MailPort = "mail-port" // 邮件服务器端口
MailUsername = "mail-username" // 邮件服务账号
MailPassword = "mail-password" // 邮件服务密码
NoConnect = "no_connect" // 会话状态:未连接
Connecting = "connecting" // 会话状态:连接中
Connected = "connected" // 会话状态:已连接
Disconnected = "disconnected" // 会话状态:已断开连接
Guacd = "guacd" // 接入模式guacd
Naive = "naive" // 接入模式:原生
Terminal = "terminal" // 接入模式:终端
TypeUser = "user" // 普通用户
TypeAdmin = "admin" // 管理员
)
var SSHParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, SshMode}
var RDPParameterNames = []string{guacd.Domain, guacd.RemoteApp, guacd.RemoteAppDir, guacd.RemoteAppArgs, guacd.EnableDrive, guacd.DrivePath}
var VNCParameterNames = []string{guacd.ColorDepth, guacd.Cursor, guacd.SwapRedBlue, guacd.DestHost, guacd.DestPort}
var TelnetParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, guacd.UsernameRegex, guacd.PasswordRegex, guacd.LoginSuccessRegex, guacd.LoginFailureRegex}
var KubernetesParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, guacd.Namespace, guacd.Pod, guacd.Container, guacd.UesSSL, guacd.ClientCert, guacd.ClientKey, guacd.CaCert, guacd.IgnoreCert}

13
server/global/cache/cache.go vendored Normal file
View File

@ -0,0 +1,13 @@
package cache
import (
"time"
"github.com/patrickmn/go-cache"
)
var GlobalCache *cache.Cache
func init() {
GlobalCache = cache.New(5*time.Minute, 10*time.Minute)
}

View File

@ -0,0 +1,16 @@
package cron
import "github.com/robfig/cron/v3"
var GlobalCron *cron.Cron
type Job cron.Job
func init() {
GlobalCron = cron.New(cron.WithSeconds())
GlobalCron.Start()
}
func JobId(jobId int) cron.EntryID {
return cron.EntryID(jobId)
}

View File

@ -0,0 +1,129 @@
package gateway
import (
"context"
"errors"
"fmt"
"net"
"os"
"next-terminal/server/config"
"next-terminal/server/utils"
"golang.org/x/crypto/ssh"
)
// Gateway 接入网关
type Gateway struct {
ID string // 接入网关ID
Connected bool // 是否已连接
LocalHost string // 隧道映射到本地的IP地址
SshClient *ssh.Client
Message string // 失败原因
tunnels map[string]*Tunnel
Add chan *Tunnel
Del chan string
exit chan bool
}
func NewGateway(id, localhost string, connected bool, message string, client *ssh.Client) *Gateway {
return &Gateway{
ID: id,
LocalHost: localhost,
Connected: connected,
Message: message,
SshClient: client,
Add: make(chan *Tunnel),
Del: make(chan string),
tunnels: map[string]*Tunnel{},
exit: make(chan bool, 1),
}
}
func (g *Gateway) Run() {
for {
select {
case t := <-g.Add:
g.tunnels[t.ID] = t
go t.Run()
case k := <-g.Del:
if _, ok := g.tunnels[k]; ok {
g.tunnels[k].Close()
delete(g.tunnels, k)
}
case <-g.exit:
return
}
}
}
func (g *Gateway) Close() {
g.exit <- true
if g.SshClient != nil {
_ = g.SshClient.Close()
}
if len(g.tunnels) > 0 {
for _, tunnel := range g.tunnels {
tunnel.Close()
}
}
}
func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, exposedPort int, err error) {
if !g.Connected {
return "", 0, errors.New(g.Message)
}
localPort, err := utils.GetAvailablePort()
if err != nil {
return "", 0, err
}
localHost := g.LocalHost
if localHost == "" {
if config.GlobalCfg.Container {
localIp, err := utils.GetLocalIp()
if err != nil {
hostname, err := os.Hostname()
if err != nil {
return "", 0, err
} else {
localHost = hostname
}
} else {
localHost = localIp
}
} else {
localHost = "localhost"
}
}
localAddr := fmt.Sprintf("%s:%d", localHost, localPort)
listener, err := net.Listen("tcp", localAddr)
if err != nil {
return "", 0, err
}
ctx, cancel := context.WithCancel(context.Background())
tunnel := &Tunnel{
ID: id,
LocalHost: g.LocalHost,
LocalPort: localPort,
Gateway: g,
RemoteHost: ip,
RemotePort: port,
ctx: ctx,
cancel: cancel,
listener: listener,
}
g.Add <- tunnel
return tunnel.LocalHost, tunnel.LocalPort, nil
}
func (g Gateway) CloseSshTunnel(id string) {
if g.tunnels[id] != nil {
g.tunnels[id].Close()
}
}

View File

@ -0,0 +1,42 @@
package gateway
type Manager struct {
gateways map[string]*Gateway
Add chan *Gateway
Del chan string
}
func NewManager() *Manager {
return &Manager{
Add: make(chan *Gateway),
Del: make(chan string),
gateways: map[string]*Gateway{},
}
}
func (m *Manager) Run() {
for {
select {
case g := <-m.Add:
m.gateways[g.ID] = g
go g.Run()
case k := <-m.Del:
if _, ok := m.gateways[k]; ok {
m.gateways[k].Close()
delete(m.gateways, k)
}
}
}
}
func (m Manager) GetById(id string) *Gateway {
return m.gateways[id]
}
var GlobalGatewayManager *Manager
func init() {
GlobalGatewayManager = NewManager()
go GlobalGatewayManager.Run()
}

View File

@ -0,0 +1,63 @@
package gateway
import (
"context"
"fmt"
"io"
"net"
"next-terminal/server/log"
)
type Tunnel struct {
ID string // 唯一标识
LocalHost string // 本地监听地址
LocalPort int // 本地端口
RemoteHost string // 远程连接地址
RemotePort int // 远程端口
Gateway *Gateway
ctx context.Context
cancel context.CancelFunc
listener net.Listener
err error
}
func (r *Tunnel) Run() {
localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort)
log.Debugf("等待客户端访问 [%v] ...", localAddr)
localConn, err := r.listener.Accept()
if err != nil {
r.err = err
return
}
log.Debugf("客户端 [%v] 已连接至 [%v]", localConn.RemoteAddr().String(), localAddr)
remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort)
log.Debugf("连接远程主机 [%v] ...", remoteAddr)
remoteConn, err := r.Gateway.SshClient.Dial("tcp", remoteAddr)
if err != nil {
log.Debugf("连接远程主机 [%v] 失败", remoteAddr)
r.err = err
return
}
log.Debugf("连接远程主机 [%v] 成功", remoteAddr)
go copyConn(localConn, remoteConn)
go copyConn(remoteConn, localConn)
log.Debugf("开始转发数据 [%v]->[%v]", localAddr, remoteAddr)
go func() {
<-r.ctx.Done()
_ = r.listener.Close()
_ = localConn.Close()
_ = remoteConn.Close()
log.Debugf("SSH隧道 [%v]-[%v] 已关闭", localAddr, remoteAddr)
}()
}
func (r Tunnel) Close() {
r.cancel()
}
func copyConn(writer, reader net.Conn) {
_, _ = io.Copy(writer, reader)
}

View File

@ -0,0 +1,70 @@
package security
import "sort"
type Security struct {
ID string
Rule string
IP string
Priority int64 // 越小优先级越高
}
type Manager struct {
securities map[string]*Security
values []*Security
Add chan *Security
Del chan string
}
func NewManager() *Manager {
return &Manager{
Add: make(chan *Security),
Del: make(chan string),
securities: map[string]*Security{},
}
}
func (m *Manager) Run() {
for {
select {
case s := <-m.Add:
m.securities[s.ID] = s
m.LoadData()
case s := <-m.Del:
if _, ok := m.securities[s]; ok {
delete(m.securities, s)
m.LoadData()
}
}
}
}
func (m *Manager) Clear() {
m.securities = map[string]*Security{}
}
func (m *Manager) LoadData() {
var values []*Security
for _, security := range m.securities {
values = append(values, security)
}
sort.Slice(values, func(i, j int) bool {
// 优先级数字越小代表优先级越高,因此此处用小于号
return values[i].Priority < values[j].Priority
})
m.values = values
}
func (m Manager) Values() []*Security {
return m.values
}
var GlobalSecurityManager *Manager
func init() {
GlobalSecurityManager = NewManager()
go GlobalSecurityManager.Run()
}

View File

@ -0,0 +1,97 @@
package session
import (
"fmt"
"next-terminal/server/guacd"
"next-terminal/server/term"
"github.com/gorilla/websocket"
)
type Session struct {
ID string
Protocol string
Mode string
WebSocket *websocket.Conn
GuacdTunnel *guacd.Tunnel
NextTerminal *term.NextTerminal
Observer *Manager
}
type Manager struct {
id string
sessions map[string]*Session
Add chan *Session
Del chan string
exit chan bool
}
func NewManager() *Manager {
return &Manager{
Add: make(chan *Session),
Del: make(chan string),
sessions: map[string]*Session{},
exit: make(chan bool, 1),
}
}
func NewObserver(id string) *Manager {
return &Manager{
id: id,
Add: make(chan *Session),
Del: make(chan string),
sessions: map[string]*Session{},
exit: make(chan bool, 1),
}
}
func (m *Manager) Run() {
defer fmt.Printf("Session Manager %v End\n", m.id)
fmt.Printf("Session Manager %v Run\n", m.id)
for {
select {
case s := <-m.Add:
m.sessions[s.ID] = s
case k := <-m.Del:
if _, ok := m.sessions[k]; ok {
ss := m.sessions[k]
if ss.GuacdTunnel != nil {
_ = ss.GuacdTunnel.Close()
}
if ss.NextTerminal != nil {
_ = ss.NextTerminal.Close()
}
if ss.WebSocket != nil {
_ = ss.WebSocket.Close()
}
if ss.Observer != nil {
ss.Observer.Close()
}
delete(m.sessions, k)
}
case <-m.exit:
return
}
}
}
func (m *Manager) Close() {
m.exit <- true
}
func (m Manager) GetById(id string) *Session {
return m.sessions[id]
}
func (m Manager) All() map[string]*Session {
return m.sessions
}
var GlobalSessionManager *Manager
func init() {
GlobalSessionManager = NewManager()
go GlobalSessionManager.Run()
}

304
server/guacd/guacd.go Normal file
View File

@ -0,0 +1,304 @@
package guacd
import (
"bufio"
"errors"
"fmt"
"net"
"strings"
"time"
)
const (
Host = "host"
Port = "port"
EnableRecording = "enable-recording"
RecordingPath = "recording-path"
CreateRecordingPath = "create-recording-path"
FontName = "font-name"
FontSize = "font-size"
ColorScheme = "color-scheme"
Backspace = "backspace"
TerminalType = "terminal-type"
EnableDrive = "enable-drive"
DriveName = "drive-name"
DrivePath = "drive-path"
EnableWallpaper = "enable-wallpaper"
EnableTheming = "enable-theming"
EnableFontSmoothing = "enable-font-smoothing"
EnableFullWindowDrag = "enable-full-window-drag"
EnableDesktopComposition = "enable-desktop-composition"
EnableMenuAnimations = "enable-menu-animations"
DisableBitmapCaching = "disable-bitmap-caching"
DisableOffscreenCaching = "disable-offscreen-caching"
DisableGlyphCaching = "disable-glyph-caching"
Domain = "domain"
RemoteApp = "remote-app"
RemoteAppDir = "remote-app-dir"
RemoteAppArgs = "remote-app-args"
ColorDepth = "color-depth"
Cursor = "cursor"
SwapRedBlue = "swap-red-blue"
DestHost = "dest-host"
DestPort = "dest-port"
UsernameRegex = "username-regex"
PasswordRegex = "password-regex"
LoginSuccessRegex = "login-success-regex"
LoginFailureRegex = "login-failure-regex"
Namespace = "namespace"
Pod = "pod"
Container = "container"
UesSSL = "use-ssl"
ClientCert = "client-cert"
ClientKey = "client-key"
CaCert = "ca-cert"
IgnoreCert = "ignore-cert"
)
const Delimiter = ';'
const Version = "VERSION_1_3_0"
type Configuration struct {
ConnectionID string
Protocol string
Parameters map[string]string
}
func NewConfiguration() (config *Configuration) {
config = &Configuration{}
config.Parameters = make(map[string]string)
return config
}
func (opt *Configuration) SetParameter(name, value string) {
opt.Parameters[name] = value
}
func (opt *Configuration) UnSetParameter(name string) {
delete(opt.Parameters, name)
}
func (opt *Configuration) GetParameter(name string) string {
return opt.Parameters[name]
}
type Instruction struct {
Opcode string
Args []string
ProtocolForm string
}
func NewInstruction(opcode string, args ...string) (ret Instruction) {
ret.Opcode = opcode
ret.Args = args
return ret
}
func (opt *Instruction) String() string {
if len(opt.ProtocolForm) > 0 {
return opt.ProtocolForm
}
opt.ProtocolForm = fmt.Sprintf("%d.%s", len(opt.Opcode), opt.Opcode)
for _, value := range opt.Args {
opt.ProtocolForm += fmt.Sprintf(",%d.%s", len(value), value)
}
opt.ProtocolForm += string(Delimiter)
return opt.ProtocolForm
}
func (opt *Instruction) Parse(content string) Instruction {
if strings.LastIndex(content, ";") > 0 {
content = strings.TrimRight(content, ";")
}
messages := strings.Split(content, ",")
var args = make([]string, len(messages))
for i := range messages {
lm := strings.Split(messages[i], ".")
args[i] = lm[1]
}
return NewInstruction(args[0], args[1:]...)
}
type Tunnel struct {
rw *bufio.ReadWriter
conn net.Conn
UUID string
Config *Configuration
IsOpen bool
}
func NewTunnel(address string, config *Configuration) (ret *Tunnel, err error) {
conn, err := net.DialTimeout("tcp", address, 5*time.Second)
if err != nil {
return
}
ret = &Tunnel{}
ret.conn = conn
ret.rw = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
ret.Config = config
selectArg := config.ConnectionID
if selectArg == "" {
selectArg = config.Protocol
}
if err := ret.WriteInstructionAndFlush(NewInstruction("select", selectArg)); err != nil {
_ = conn.Close()
return nil, err
}
args, err := ret.expect("args")
if err != nil {
_ = conn.Close()
return
}
width := config.GetParameter("width")
height := config.GetParameter("height")
dpi := config.GetParameter("dpi")
// send size
if err := ret.WriteInstructionAndFlush(NewInstruction("size", width, height, dpi)); err != nil {
_ = conn.Close()
return nil, err
}
if err := ret.WriteInstructionAndFlush(NewInstruction("audio", "audio/L8", "audio/L16")); err != nil {
_ = conn.Close()
return nil, err
}
if err := ret.WriteInstructionAndFlush(NewInstruction("video")); err != nil {
_ = conn.Close()
return nil, err
}
if err := ret.WriteInstructionAndFlush(NewInstruction("image", "image/jpeg", "image/png", "image/webp")); err != nil {
_ = conn.Close()
return nil, err
}
if err := ret.WriteInstructionAndFlush(NewInstruction("timezone", "Asia/Shanghai")); err != nil {
_ = conn.Close()
return nil, err
}
parameters := make([]string, len(args.Args))
for i := range args.Args {
argName := args.Args[i]
if strings.Contains(argName, "VERSION") {
parameters[i] = Version
continue
}
parameters[i] = config.GetParameter(argName)
}
// send connect
if err := ret.WriteInstructionAndFlush(NewInstruction("connect", parameters...)); err != nil {
_ = conn.Close()
return nil, err
}
ready, err := ret.expect("ready")
if err != nil {
return
}
if len(ready.Args) == 0 {
_ = conn.Close()
return nil, errors.New("no connection id received")
}
ret.UUID = ready.Args[0]
ret.IsOpen = true
return ret, nil
}
func (opt *Tunnel) WriteInstructionAndFlush(instruction Instruction) error {
if _, err := opt.WriteAndFlush([]byte(instruction.String())); err != nil {
return err
}
return nil
}
func (opt *Tunnel) WriteInstruction(instruction Instruction) error {
if _, err := opt.Write([]byte(instruction.String())); err != nil {
return err
}
return nil
}
func (opt *Tunnel) WriteAndFlush(p []byte) (int, error) {
//fmt.Printf("-> %v\n", string(p))
nn, err := opt.rw.Write(p)
if err != nil {
return nn, err
}
err = opt.rw.Flush()
if err != nil {
return nn, err
}
return nn, nil
}
func (opt *Tunnel) Write(p []byte) (int, error) {
//fmt.Printf("-> %v \n", string(p))
nn, err := opt.rw.Write(p)
if err != nil {
return nn, err
}
return nn, nil
}
func (opt *Tunnel) Flush() error {
return opt.rw.Flush()
}
func (opt *Tunnel) ReadInstruction() (instruction Instruction, err error) {
msg, err := opt.rw.ReadString(Delimiter)
//fmt.Printf("<- %v \n", msg)
if err != nil {
return instruction, err
}
return instruction.Parse(msg), err
}
func (opt *Tunnel) Read() (p []byte, err error) {
p, err = opt.rw.ReadBytes(Delimiter)
//fmt.Printf("<- %v \n", string(p))
s := string(p)
if s == "rate=44100,channels=2;" {
return make([]byte, 0), nil
}
if s == "rate=22050,channels=2;" {
return make([]byte, 0), nil
}
if s == "5.audio,1.1,31.audio/L16;" {
s += "rate=44100,channels=2;"
}
return []byte(s), err
}
func (opt *Tunnel) expect(opcode string) (instruction Instruction, err error) {
instruction, err = opt.ReadInstruction()
if err != nil {
return instruction, err
}
if opcode != instruction.Opcode {
msg := fmt.Sprintf(`expected "%s" instruction but instead received "%s"`, opcode, instruction.Opcode)
return instruction, errors.New(msg)
}
return instruction, nil
}
func (opt *Tunnel) Close() error {
opt.IsOpen = false
_ = opt.rw.Flush()
return opt.conn.Close()
}

266
server/log/logger.go Normal file
View File

@ -0,0 +1,266 @@
package log
import (
"fmt"
"io"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"next-terminal/server/config"
"github.com/labstack/echo/v4"
"github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
)
type Formatter struct{}
func (s *Formatter) Format(entry *logrus.Entry) ([]byte, error) {
timestamp := time.Now().Local().Format("2006-01-02 15:04:05")
var file string
var l int
if entry.HasCaller() {
file = filepath.Base(entry.Caller.Function)
l = entry.Caller.Line
}
msg := fmt.Sprintf("%s %s [%s:%d]%s\n", timestamp, strings.ToUpper(entry.Level.String()), file, l, entry.Message)
return []byte(msg), nil
}
var stdOut = NewLogger()
// Trace logs a message at level Trace on the standard logger.
func Trace(args ...interface{}) {
stdOut.Trace(args...)
}
// Debug logs a message at level Debug on the standard logger.
func Debug(args ...interface{}) {
stdOut.Debug(args...)
}
// Print logs a message at level Info on the standard logger.
func Print(args ...interface{}) {
stdOut.Print(args...)
}
// Info logs a message at level Info on the standard logger.
func Info(args ...interface{}) {
stdOut.Info(args...)
}
// Warn logs a message at level Warn on the standard logger.
func Warn(args ...interface{}) {
stdOut.Warn(args...)
}
// Warning logs a message at level Warn on the standard logger.
func Warning(args ...interface{}) {
stdOut.Warning(args...)
}
// Error logs a message at level Error on the standard logger.
func Error(args ...interface{}) {
stdOut.Error(args...)
}
// Panic logs a message at level Panic on the standard logger.
func Panic(args ...interface{}) {
stdOut.Panic(args...)
}
// Fatal logs a message at level Fatal on the standard logger then the process will exit with status set to 1.
func Fatal(args ...interface{}) {
stdOut.Fatal(args...)
}
// Tracef logs a message at level Trace on the standard logger.
func Tracef(format string, args ...interface{}) {
stdOut.Tracef(format, args...)
}
// Debugf logs a message at level Debug on the standard logger.
func Debugf(format string, args ...interface{}) {
stdOut.Debugf(format, args...)
}
// Printf logs a message at level Info on the standard logger.
func Printf(format string, args ...interface{}) {
stdOut.Printf(format, args...)
}
// Infof logs a message at level Info on the standard logger.
func Infof(format string, args ...interface{}) {
stdOut.Infof(format, args...)
}
// Warnf logs a message at level Warn on the standard logger.
func Warnf(format string, args ...interface{}) {
stdOut.Warnf(format, args...)
}
// Warningf logs a message at level Warn on the standard logger.
func Warningf(format string, args ...interface{}) {
stdOut.Warningf(format, args...)
}
// Errorf logs a message at level Error on the standard logger.
func Errorf(format string, args ...interface{}) {
stdOut.Errorf(format, args...)
}
// Panicf logs a message at level Panic on the standard logger.
func Panicf(format string, args ...interface{}) {
stdOut.Panicf(format, args...)
}
// Fatalf logs a message at level Fatal on the standard logger then the process will exit with status set to 1.
func Fatalf(format string, args ...interface{}) {
stdOut.Fatalf(format, args...)
}
// Traceln logs a message at level Trace on the standard logger.
func Traceln(args ...interface{}) {
stdOut.Traceln(args...)
}
// Debugln logs a message at level Debug on the standard logger.
func Debugln(args ...interface{}) {
stdOut.Debugln(args...)
}
// Println logs a message at level Info on the standard logger.
func Println(args ...interface{}) {
stdOut.Println(args...)
}
// Infoln logs a message at level Info on the standard logger.
func Infoln(args ...interface{}) {
stdOut.Infoln(args...)
}
// Warnln logs a message at level Warn on the standard logger.
func Warnln(args ...interface{}) {
stdOut.Warnln(args...)
}
// Warningln logs a message at level Warn on the standard logger.
func Warningln(args ...interface{}) {
stdOut.Warningln(args...)
}
// Errorln logs a message at level Error on the standard logger.
func Errorln(args ...interface{}) {
stdOut.Errorln(args...)
}
// Panicln logs a message at level Panic on the standard logger.
func Panicln(args ...interface{}) {
stdOut.Panicln(args...)
}
// Fatalln logs a message at level Fatal on the standard logger then the process will exit with status set to 1.
func Fatalln(args ...interface{}) {
stdOut.Fatalln(args...)
}
// WithError creates an entry from the standard logger and adds an error to it, using the value defined in ErrorKey as key.
func WithError(err error) *logrus.Entry {
return stdOut.WithField(logrus.ErrorKey, err)
}
// WithField creates an entry from the standard logger and adds a field to
// it. If you want multiple fields, use `WithFields`.
//
// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal
// or Panic on the Entry it returns.
func WithField(key string, value interface{}) *logrus.Entry {
return stdOut.WithField(key, value)
}
// Logrus : implement log
type Logrus struct {
*logrus.Logger
}
// GetEchoLogger for e.l
func NewLogger() Logrus {
logFilePath := ""
if dir, err := os.Getwd(); err == nil {
logFilePath = dir + "/logs/"
}
if err := os.MkdirAll(logFilePath, 0755); err != nil {
fmt.Println(err.Error())
}
logFileName := "next-terminal.log"
//日志文件
fileName := path.Join(logFilePath, logFileName)
if _, err := os.Stat(fileName); err != nil {
if _, err := os.Create(fileName); err != nil {
fmt.Println(err.Error())
}
}
//实例化
logger := logrus.New()
//设置输出
logger.SetOutput(io.MultiWriter(&lumberjack.Logger{
Filename: fileName,
MaxSize: 100, // megabytes
MaxBackups: 3,
MaxAge: 7, //days
Compress: true, // disabled by default
}, os.Stdout))
logger.SetReportCaller(true)
//设置日志级别
if config.GlobalCfg.Debug {
logger.SetLevel(logrus.DebugLevel)
} else {
logger.SetLevel(logrus.InfoLevel)
}
//设置日志格式
logger.SetFormatter(new(Formatter))
return Logrus{Logger: logger}
}
func logrusMiddlewareHandler(c echo.Context, next echo.HandlerFunc) error {
l := NewLogger()
req := c.Request()
res := c.Response()
start := time.Now()
if err := next(c); err != nil {
c.Error(err)
}
stop := time.Now()
l.Debugf("%s %s %s %s %s %3d %s %13v %s %s",
c.RealIP(),
req.Host,
req.Method,
req.RequestURI,
req.URL.Path,
res.Status,
strconv.FormatInt(res.Size, 10),
stop.Sub(start).String(),
req.Referer(),
req.UserAgent(),
)
return nil
}
func logger(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
return logrusMiddlewareHandler(c, next)
}
}
// Hook is a function to process log.
func Hook() echo.MiddlewareFunc {
return logger
}

View File

@ -0,0 +1,34 @@
package model
import "next-terminal/server/utils"
// AccessGateway 接入网关
type AccessGateway struct {
ID string `gorm:"primary_key,type:varchar(36)" json:"id"`
Name string `gorm:"type:varchar(500)" json:"name"`
IP string `gorm:"type:varchar(500)" json:"ip"`
Port int `gorm:"type:int(5)" json:"port"`
Localhost string `gorm:"type:varchar(200)" json:"localhost"` // 隧道映射到本地的地址
AccountType string `gorm:"type:varchar(50)" json:"accountType"`
Username string `gorm:"type:varchar(200)" json:"username"`
Password string `gorm:"type:varchar(500)" json:"password"`
PrivateKey string `gorm:"type:text" json:"privateKey"`
Passphrase string `gorm:"type:varchar(500)" json:"passphrase"`
Created utils.JsonTime `json:"created"`
}
func (r *AccessGateway) TableName() string {
return "access_gateways"
}
type AccessGatewayForPage struct {
ID string `json:"id"`
Name string `json:"name"`
IP string `json:"ip"`
Port int `json:"port"`
AccountType string `json:"accountType"`
Username string `json:"username"`
Created utils.JsonTime `json:"created"`
Connected bool `json:"connected"`
Message string `json:"message"`
}

View File

@ -1,10 +1,10 @@
package model
type AccessSecurity struct {
ID string `json:"id"`
Rule string `json:"rule"`
IP string `json:"ip"`
Source string `json:"source"`
ID string `gorm:"primary_key,type:varchar(36)" json:"id"`
Rule string `gorm:"type:varchar(20)" json:"rule"`
IP string `gorm:"type:varchar(500)" json:"ip"`
Source string `gorm:"type:varchar(500)" json:"source"`
Priority int64 `json:"priority"` // 越小优先级越高
}

View File

@ -4,39 +4,41 @@ import (
"next-terminal/server/utils"
)
type AssetProto string
type Asset struct {
ID string `gorm:"primary_key " json:"id"`
Name string `json:"name"`
Protocol string `json:"protocol"`
IP string `json:"ip"`
Port int `json:"port"`
AccountType string `json:"accountType"`
Username string `json:"username"`
Password string `json:"password"`
CredentialId string `gorm:"index" json:"credentialId"`
PrivateKey string `json:"privateKey"`
Passphrase string `json:"passphrase"`
Description string `json:"description"`
Active bool `json:"active"`
Created utils.JsonTime `json:"created"`
Tags string `json:"tags"`
Owner string `gorm:"index" json:"owner"`
Encrypted bool `json:"encrypted"`
ID string `gorm:"primary_key,type:varchar(36)" json:"id"`
Name string `gorm:"type:varchar(500)" json:"name"`
Protocol string `gorm:"type:varchar(20)" json:"protocol"`
IP string `gorm:"type:varchar(200)" json:"ip"`
Port int `json:"port"`
AccountType string `gorm:"type:varchar(20)" json:"accountType"`
Username string `gorm:"type:varchar(200)" json:"username"`
Password string `gorm:"type:varchar(500)" json:"password"`
CredentialId string `gorm:"index,type:varchar(36)" json:"credentialId"`
PrivateKey string `gorm:"type:text" json:"privateKey"`
Passphrase string `gorm:"type:varchar(500)" json:"passphrase"`
Description string `json:"description"`
Active bool `json:"active"`
Created utils.JsonTime `json:"created"`
Tags string `json:"tags"`
Owner string `gorm:"index,type:varchar(36)" json:"owner"`
Encrypted bool `json:"encrypted"`
AccessGatewayId string `gorm:"type:varchar(36)" json:"accessGatewayId"`
}
type AssetForPage struct {
ID string `json:"id"`
Name string `json:"name"`
IP string `json:"ip"`
Protocol string `json:"protocol"`
Port int `json:"port"`
Active bool `json:"active"`
Created utils.JsonTime `json:"created"`
Tags string `json:"tags"`
Owner string `json:"owner"`
OwnerName string `json:"ownerName"`
SharerCount int64 `json:"sharerCount"`
SshMode string `json:"sshMode"`
ID string `json:"id"`
Name string `json:"name"`
IP string `json:"ip"`
Protocol string `json:"protocol"`
Port int `json:"port"`
Active bool `json:"active"`
Created utils.JsonTime `json:"created"`
Tags string `json:"tags"`
Owner string `json:"owner"`
OwnerName string `json:"ownerName"`
SshMode string `json:"sshMode"`
}
func (r *Asset) TableName() string {

View File

@ -5,11 +5,11 @@ import (
)
type Command struct {
ID string `gorm:"primary_key" json:"id"`
Name string `json:"name"`
ID string `gorm:"primary_key,type:varchar(36)" json:"id"`
Name string `gorm:"type:varchar(500)" json:"name"`
Content string `json:"content"`
Created utils.JsonTime `json:"created"`
Owner string `gorm:"index" json:"owner"`
Owner string `gorm:"index,type:varchar(36)" json:"owner"`
}
type CommandForPage struct {

View File

@ -5,15 +5,15 @@ import (
)
type Credential struct {
ID string `gorm:"primary_key" json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Username string `json:"username"`
Password string `json:"password"`
PrivateKey string `json:"privateKey"`
Passphrase string `json:"passphrase"`
ID string `gorm:"primary_key,type:varchar(36)" json:"id"`
Name string `gorm:"type:varchar(500)" json:"name"`
Type string `gorm:"type:varchar(50)" json:"type"`
Username string `gorm:"type:varchar(200)" json:"username"`
Password string `gorm:"type:varchar(500)" json:"password"`
PrivateKey string `gorm:"type:text" json:"privateKey"`
Passphrase string `gorm:"type:varchar(500)" json:"passphrase"`
Created utils.JsonTime `json:"created"`
Owner string `gorm:"index" json:"owner"`
Owner string `gorm:"index,type:varchar(36)" json:"owner"`
Encrypted bool `json:"encrypted"`
}

View File

@ -5,14 +5,14 @@ import (
)
type Job struct {
ID string `gorm:"primary_key" json:"id"`
ID string `gorm:"primary_key,type:varchar(36)" json:"id"`
CronJobId int `json:"cronJobId"`
Name string `json:"name"`
Func string `json:"func"`
Cron string `json:"cron"`
Mode string `json:"mode"`
Name string `gorm:"type:varchar(500)" json:"name"`
Func string `gorm:"type:varchar(200)" json:"func"`
Cron string `gorm:"type:varchar(100)" json:"cron"`
Mode string `gorm:"type:varchar(50)" json:"mode"`
ResourceIds string `json:"resourceIds"`
Status string `json:"status"`
Status string `gorm:"type:varchar(20)" json:"status"`
Metadata string `json:"metadata"`
Created utils.JsonTime `json:"created"`
Updated utils.JsonTime `json:"updated"`

View File

@ -5,24 +5,15 @@ import (
)
type LoginLog struct {
ID string `gorm:"primary_key" json:"id"`
UserId string `gorm:"index" json:"userId"`
ClientIP string `json:"clientIp"`
ClientUserAgent string `json:"clientUserAgent"`
LoginTime utils.JsonTime `json:"loginTime"`
LogoutTime utils.JsonTime `json:"logoutTime"`
Remember bool `json:"remember"`
}
type LoginLogForPage struct {
ID string `json:"id"`
UserId string `json:"userId"`
UserName string `json:"userName"`
ClientIP string `json:"clientIp"`
ClientUserAgent string `json:"clientUserAgent"`
ID string `gorm:"primary_key,type:varchar(36)" json:"id"`
Username string `gorm:"index,type:varchar(200)" json:"username"`
ClientIP string `gorm:"type:varchar(200)" json:"clientIp"`
ClientUserAgent string `gorm:"type:varchar(500)" json:"clientUserAgent"`
LoginTime utils.JsonTime `json:"loginTime"`
LogoutTime utils.JsonTime `json:"logoutTime"`
Remember bool `json:"remember"`
State string `gorm:"type:varchar(1)" json:"state"` // 成功 1 失败 0
Reason string `gorm:"type:varchar(500)" json:"reason"`
}
func (r *LoginLog) TableName() string {

View File

@ -1,9 +0,0 @@
package model
type Num struct {
I string `gorm:"primary_key" json:"i"`
}
func (r *Num) TableName() string {
return "nums"
}

View File

@ -1,11 +1,12 @@
package model
type ResourceSharer struct {
ID string `gorm:"primary_key" json:"id"`
ResourceId string `gorm:"index" json:"resourceId"`
ResourceType string `gorm:"index" json:"resourceType"`
UserId string `gorm:"index" json:"userId"`
UserGroupId string `gorm:"index" json:"userGroupId"`
ID string `gorm:"primary_key,type:varchar(36)" json:"id"`
ResourceId string `gorm:"index,type:varchar(36)" json:"resourceId"`
ResourceType string `gorm:"index,type:varchar(36)" json:"resourceType"`
StrategyId string `gorm:"index,type:varchar(36)" json:"strategyId"`
UserId string `gorm:"index,type:varchar(36)" json:"userId"`
UserGroupId string `gorm:"index,type:varchar(36)" json:"userGroupId"`
}
func (r *ResourceSharer) TableName() string {

View File

@ -5,27 +5,35 @@ import (
)
type Session struct {
ID string `gorm:"primary_key" json:"id"`
Protocol string `json:"protocol"`
IP string `json:"ip"`
ID string `gorm:"primary_key,type:varchar(36)" json:"id"`
Protocol string `gorm:"type:varchar(20)" json:"protocol"`
IP string `gorm:"type:varchar(200)" json:"ip"`
Port int `json:"port"`
ConnectionId string `json:"connectionId"`
AssetId string `gorm:"index" json:"assetId"`
Username string `json:"username"`
Password string `json:"password"`
Creator string `gorm:"index" json:"creator"`
ClientIP string `json:"clientIp"`
ConnectionId string `gorm:"type:varchar(50)" json:"connectionId"`
AssetId string `gorm:"index,type:varchar(36)" json:"assetId"`
Username string `gorm:"type:varchar(200)" json:"username"`
Password string `gorm:"type:varchar(500)" json:"password"`
Creator string `gorm:"index,type:varchar(36)" json:"creator"`
ClientIP string `gorm:"type:varchar(200)" json:"clientIp"`
Width int `json:"width"`
Height int `json:"height"`
Status string `gorm:"index" json:"status"`
Recording string `json:"recording"`
PrivateKey string `json:"privateKey"`
Passphrase string `json:"passphrase"`
Status string `gorm:"index,type:varchar(20)" json:"status"`
Recording string `gorm:"type:varchar(1000)" json:"recording"`
PrivateKey string `gorm:"type:text" json:"privateKey"`
Passphrase string `gorm:"type:varchar(500)" json:"passphrase"`
Code int `json:"code"`
Message string `json:"message"`
ConnectedTime utils.JsonTime `json:"connectedTime"`
DisconnectedTime utils.JsonTime `json:"disconnectedTime"`
Mode string `json:"mode"`
Mode string `gorm:"type:varchar(10)" json:"mode"`
Upload string `gorm:"type:varchar(1)" json:"upload"` // 1 = true, 0 = false
Download string `gorm:"type:varchar(1)" json:"download"`
Delete string `gorm:"type:varchar(1)" json:"delete"`
Rename string `gorm:"type:varchar(1)" json:"rename"`
Edit string `gorm:"type:varchar(1)" json:"edit"`
CreateDir string `gorm:"type:varchar(1)" json:"createDir"`
StorageId string `gorm:"type:varchar(36)" json:"storageId"`
AccessGatewayId string `gorm:"type:varchar(36)" json:"accessGatewayId"`
}
func (r *Session) TableName() string {
@ -54,3 +62,12 @@ type SessionForPage struct {
Message string `json:"message"`
Mode string `json:"mode"`
}
type SessionForAccess struct {
AssetId string `json:"assetId"`
Protocol string `json:"protocol"`
IP string `json:"ip"`
Port int `json:"port"`
Username string `json:"username"`
AccessCount int64 `json:"accessCount"`
}

29
server/model/storage.go Normal file
View File

@ -0,0 +1,29 @@
package model
import "next-terminal/server/utils"
type Storage struct {
ID string `gorm:"primary_key,type:varchar(36)" json:"id"`
Name string `gorm:"type:varchar(500)" json:"name"`
IsShare bool `json:"isShare"` // 是否共享
LimitSize int64 `json:"limitSize"` // 大小限制,单位字节
IsDefault bool `json:"isDefault"` // 是否为用户默认的
Owner string `gorm:"index,type:varchar(36)" json:"owner"`
Created utils.JsonTime `json:"created"`
}
func (r *Storage) TableName() string {
return "storages"
}
type StorageForPage struct {
ID string `gorm:"primary_key " json:"id"`
Name string `json:"name"`
IsShare bool `json:"isShare"` // 是否共享
LimitSize int64 `json:"limitSize"` // 大小限制,单位字节
UsedSize int64 `json:"usedSize"`
IsDefault bool `json:"isDefault"` // 是否为用户默认的
Owner string `gorm:"index" json:"owner"`
OwnerName string `json:"ownerName"`
Created utils.JsonTime `json:"created"`
}

19
server/model/strategy.go Normal file
View File

@ -0,0 +1,19 @@
package model
import "next-terminal/server/utils"
type Strategy struct {
ID string `gorm:"primary_key,type:varchar(36)" json:"id"`
Name string `gorm:"type:varchar(500)" json:"name"`
Upload string `gorm:"type:varchar(1)" json:"upload"` // 1 = true, 0 = false
Download string `gorm:"type:varchar(1)" json:"download"`
Delete string `gorm:"type:varchar(1)" json:"delete"`
Rename string `gorm:"type:varchar(1)" json:"rename"`
Edit string `gorm:"type:varchar(1)" json:"edit"`
CreateDir string `gorm:"type:varchar(1)" json:"createDir"`
Created utils.JsonTime `json:"created"`
}
func (r *Strategy) TableName() string {
return "strategies"
}

View File

@ -5,16 +5,16 @@ import (
)
type User struct {
ID string `gorm:"primary_key" json:"id"`
Username string `gorm:"index" json:"username"`
Password string `json:"password"`
Nickname string `json:"nickname"`
ID string `gorm:"primary_key,type:varchar(36)" json:"id"`
Username string `gorm:"index,type:varchar(200)" json:"username"`
Password string `gorm:"type:varchar(500)" json:"password"`
Nickname string `gorm:"type:varchar(500)" json:"nickname"`
TOTPSecret string `json:"-"`
Online bool `json:"online"`
Enabled bool `json:"enabled"`
Created utils.JsonTime `json:"created"`
Type string `json:"type"`
Mail string `json:"mail"`
Type string `gorm:"type:varchar(20)" json:"type"`
Mail string `gorm:"type:varchar(500)" json:"mail"`
}
type UserForPage struct {

View File

@ -5,8 +5,8 @@ import (
)
type UserGroup struct {
ID string `gorm:"primary_key" json:"id"`
Name string `json:"name"`
ID string `gorm:"primary_key,type:varchar(36)" json:"id"`
Name string `gorm:"type:varchar(500)" json:"name"`
Created utils.JsonTime `json:"created"`
}

View File

@ -0,0 +1,85 @@
package repository
import (
"next-terminal/server/model"
"gorm.io/gorm"
)
type AccessGatewayRepository struct {
DB *gorm.DB
}
func NewAccessGatewayRepository(db *gorm.DB) *AccessGatewayRepository {
accessGatewayRepository = &AccessGatewayRepository{DB: db}
return accessGatewayRepository
}
func (r AccessGatewayRepository) Find(pageIndex, pageSize int, ip, name, order, field string) (o []model.AccessGatewayForPage, total int64, err error) {
t := model.AccessGateway{}
db := r.DB.Table(t.TableName())
dbCounter := r.DB.Table(t.TableName())
if len(ip) > 0 {
db = db.Where("ip like ?", "%"+ip+"%")
dbCounter = dbCounter.Where("ip like ?", "%"+ip+"%")
}
if len(name) > 0 {
db = db.Where("name like ?", "%"+name+"%")
dbCounter = dbCounter.Where("name like ?", "%"+name+"%")
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "descend" {
order = "desc"
} else {
order = "asc"
}
if field == "ip" {
field = "ip"
} else if field == "name" {
field = "name"
} else {
field = "created"
}
err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error
if o == nil {
o = make([]model.AccessGatewayForPage, 0)
}
return
}
func (r AccessGatewayRepository) Create(o *model.AccessGateway) error {
return r.DB.Create(o).Error
}
func (r AccessGatewayRepository) UpdateById(o *model.AccessGateway, id string) error {
o.ID = id
return r.DB.Updates(o).Error
}
func (r AccessGatewayRepository) DeleteById(id string) error {
return r.DB.Where("id = ?", id).Delete(model.AccessGateway{}).Error
}
func (r AccessGatewayRepository) FindById(id string) (o model.AccessGateway, err error) {
err = r.DB.Where("id = ?", id).First(&o).Error
return
}
func (r AccessGatewayRepository) FindAll() (o []model.AccessGateway, err error) {
t := model.AccessGateway{}
db := r.DB.Table(t.TableName())
err = db.Find(&o).Error
if o == nil {
o = make([]model.AccessGateway, 0)
}
return
}

View File

@ -5,8 +5,8 @@ import (
"fmt"
"strings"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/server/config"
"next-terminal/server/constant"
"next-terminal/server/model"
"next-terminal/server/utils"
@ -44,11 +44,21 @@ func (r AssetRepository) FindByProtocolAndIds(protocol string, assetIds []string
}
func (r AssetRepository) FindByProtocolAndUser(protocol string, account model.User) (o []model.Asset, err error) {
db := r.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id")
db := r.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags,assets.description, users.nickname as owner_name").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id")
if constant.TypeUser == account.Type {
owner := account.ID
db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner)
// 查询用户所在用户组列表
userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(account.ID)
if err != nil {
return nil, err
}
if len(userGroupIds) > 0 {
db = db.Or("resource_sharers.user_group_id in ?", userGroupIds)
}
}
if len(protocol) > 0 {
@ -59,7 +69,7 @@ func (r AssetRepository) FindByProtocolAndUser(protocol string, account model.Us
}
func (r AssetRepository) Find(pageIndex, pageSize int, name, protocol, tags string, account model.User, owner, sharer, userGroupId, ip, order, field string) (o []model.AssetForPage, total int64, err error) {
db := r.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id")
db := r.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags,assets.description, users.nickname as owner_name").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id")
dbCounter := r.DB.Table("assets").Select("DISTINCT assets.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id")
if constant.TypeUser == account.Type {
@ -111,7 +121,7 @@ func (r AssetRepository) Find(pageIndex, pageSize int, name, protocol, tags stri
if len(tags) > 0 {
tagArr := strings.Split(tags, ",")
for i := range tagArr {
if global.Config.DB == "sqlite" {
if config.GlobalCfg.DB == "sqlite" {
db = db.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%")
dbCounter = dbCounter.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%")
} else {
@ -189,7 +199,7 @@ func (r AssetRepository) Encrypt(item *model.Asset, password []byte) error {
}
func (r AssetRepository) Create(o *model.Asset) (err error) {
if err := r.Encrypt(o, global.Config.EncryptionPassword); err != nil {
if err := r.Encrypt(o, config.GlobalCfg.EncryptionPassword); err != nil {
return err
}
if err = r.DB.Create(o).Error; err != nil {
@ -245,7 +255,7 @@ func (r AssetRepository) Decrypt(item *model.Asset, password []byte) error {
func (r AssetRepository) FindByIdAndDecrypt(id string) (o model.Asset, err error) {
err = r.DB.Where("id = ?", id).First(&o).Error
if err == nil {
err = r.Decrypt(&o, global.Config.EncryptionPassword)
err = r.Decrypt(&o, config.GlobalCfg.EncryptionPassword)
}
return
}
@ -260,8 +270,17 @@ func (r AssetRepository) UpdateActiveById(active bool, id string) error {
return r.DB.Exec(sql, active, id).Error
}
func (r AssetRepository) DeleteById(id string) error {
return r.DB.Where("id = ?", id).Delete(&model.Asset{}).Error
func (r AssetRepository) DeleteById(id string) (err error) {
return r.DB.Transaction(func(tx *gorm.DB) error {
err = tx.Where("id = ?", id).Delete(&model.Asset{}).Error
if err != nil {
return err
}
// 删除资产属性
err = tx.Where("asset_id = ?", id).Delete(&model.AssetAttribute{}).Error
return err
})
}
func (r AssetRepository) Count() (total int64, err error) {
@ -269,6 +288,11 @@ func (r AssetRepository) Count() (total int64, err error) {
return
}
func (r AssetRepository) CountByProtocol(protocol string) (total int64, err error) {
err = r.DB.Find(&model.Asset{}).Where("protocol = ?", protocol).Count(&total).Error
return
}
func (r AssetRepository) CountByUserId(userId string) (total int64, err error) {
db := r.DB.Joins("left join resource_sharers on assets.id = resource_sharers.resource_id")
@ -287,9 +311,27 @@ func (r AssetRepository) CountByUserId(userId string) (total int64, err error) {
return
}
func (r AssetRepository) CountByUserIdAndProtocol(userId, protocol string) (total int64, err error) {
db := r.DB.Joins("left join resource_sharers on assets.id = resource_sharers.resource_id")
db = db.Where("( assets.owner = ? or resource_sharers.user_id = ? ) and assets.protocol = ?", userId, userId, protocol)
// 查询用户所在用户组列表
userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId)
if err != nil {
return 0, err
}
if len(userGroupIds) > 0 {
db = db.Or("resource_sharers.user_group_id in ?", userGroupIds)
}
err = db.Find(&model.Asset{}).Count(&total).Error
return
}
func (r AssetRepository) FindTags() (o []string, err error) {
var assets []model.Asset
err = r.DB.Not("tags = ?", "").Find(&assets).Error
err = r.DB.Not("tags = '' or tags = '-' ").Find(&assets).Error
if err != nil {
return nil, err
}
@ -359,7 +401,7 @@ func (r AssetRepository) FindAttrById(assetId string) (o []model.AssetAttribute,
return o, err
}
func (r AssetRepository) FindAssetAttrMapByAssetId(assetId string) (map[string]interface{}, error) {
func (r AssetRepository) FindAssetAttrMapByAssetId(assetId string) (map[string]string, error) {
asset, err := r.FindById(assetId)
if err != nil {
return nil, err
@ -383,7 +425,7 @@ func (r AssetRepository) FindAssetAttrMapByAssetId(assetId string) (map[string]i
parameterNames = constant.KubernetesParameterNames
}
propertiesMap := propertyRepository.FindAllMap()
var attributeMap = make(map[string]interface{})
var attributeMap = make(map[string]string)
for name := range propertiesMap {
if utils.Contains(parameterNames, name) {
attributeMap[name] = propertiesMap[name]

View File

@ -1,7 +1,7 @@
package repository
import (
"next-terminal/pkg/constant"
"next-terminal/server/constant"
"next-terminal/server/model"
"gorm.io/gorm"
@ -80,3 +80,17 @@ func (r CommandRepository) UpdateById(o *model.Command, id string) error {
func (r CommandRepository) DeleteById(id string) error {
return r.DB.Where("id = ?", id).Delete(&model.Command{}).Error
}
func (r CommandRepository) FindByUser(account model.User) (o []model.CommandForPage, err error) {
db := r.DB.Table("commands").Select("commands.id,commands.name,commands.content,commands.owner,commands.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on commands.owner = users.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id")
if constant.TypeUser == account.Type {
owner := account.ID
db = db.Where("commands.owner = ? or resource_sharers.user_id = ?", owner, owner)
}
err = db.Order("commands.name asc").Find(&o).Error
if o == nil {
o = make([]model.CommandForPage, 0)
}
return
}

View File

@ -3,8 +3,8 @@ package repository
import (
"encoding/base64"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/server/config"
"next-terminal/server/constant"
"next-terminal/server/model"
"next-terminal/server/utils"
@ -69,7 +69,7 @@ func (r CredentialRepository) Find(pageIndex, pageSize int, name, order, field s
}
func (r CredentialRepository) Create(o *model.Credential) (err error) {
if err := r.Encrypt(o, global.Config.EncryptionPassword); err != nil {
if err := r.Encrypt(o, config.GlobalCfg.EncryptionPassword); err != nil {
return err
}
if err = r.DB.Create(o).Error; err != nil {
@ -151,7 +151,7 @@ func (r CredentialRepository) Decrypt(item *model.Credential, password []byte) e
func (r CredentialRepository) FindByIdAndDecrypt(id string) (o model.Credential, err error) {
err = r.DB.Where("id = ?", id).First(&o).Error
if err == nil {
err = r.Decrypt(&o, global.Config.EncryptionPassword)
err = r.Decrypt(&o, config.GlobalCfg.EncryptionPassword)
}
return
}

View File

@ -1,5 +1,8 @@
package repository
/**
* 定义了相关模型的持久化层,方便相互之间调用
*/
var (
userRepository *UserRepository
userGroupRepository *UserGroupRepository
@ -9,9 +12,11 @@ var (
propertyRepository *PropertyRepository
commandRepository *CommandRepository
sessionRepository *SessionRepository
numRepository *NumRepository
accessSecurityRepository *AccessSecurityRepository
accessGatewayRepository *AccessGatewayRepository
jobRepository *JobRepository
jobLogRepository *JobLogRepository
loginLogRepository *LoginLogRepository
storageRepository *StorageRepository
strategyRepository *StrategyRepository
)

View File

@ -44,10 +44,10 @@ func (r JobRepository) Find(pageIndex, pageSize int, name, status, order, field
if field == "name" {
field = "name"
} else if field == "created" {
field = "created"
} else {
} else if field == "updated" {
field = "updated"
} else {
field = "created"
}
err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error
@ -63,6 +63,12 @@ func (r JobRepository) FindByFunc(function string) (o []model.Job, err error) {
return
}
func (r JobRepository) FindAll() (o []model.Job, err error) {
db := r.DB
err = db.Find(&o).Error
return
}
func (r JobRepository) Create(o *model.Job) (err error) {
return r.DB.Create(o).Error
}

View File

@ -1,6 +1,8 @@
package repository
import (
"time"
"next-terminal/server/model"
"gorm.io/gorm"
@ -24,6 +26,20 @@ func (r JobLogRepository) FindByJobId(jobId string) (o []model.JobLog, err error
return
}
func (r JobLogRepository) FindOutTimeLog(dayLimit int) (o []model.JobLog, err error) {
limitTime := time.Now().Add(time.Duration(-dayLimit*24) * time.Hour)
err = r.DB.Where("timestamp < ?", limitTime).Find(&o).Error
return
}
func (r JobLogRepository) DeleteByJobId(jobId string) error {
return r.DB.Where("job_id = ?", jobId).Delete(model.JobLog{}).Error
}
func (r JobLogRepository) DeleteByIdIn(ids []string) error {
return r.DB.Where("id in ?", ids).Delete(&model.JobLog{}).Error
}
func (r JobLogRepository) DeleteById(id string) error {
return r.DB.Where("id = ?", id).Delete(&model.JobLog{}).Error
}

View File

@ -1,6 +1,8 @@
package repository
import (
"time"
"next-terminal/server/model"
"gorm.io/gorm"
@ -15,19 +17,24 @@ func NewLoginLogRepository(db *gorm.DB) *LoginLogRepository {
return loginLogRepository
}
func (r LoginLogRepository) Find(pageIndex, pageSize int, userId, clientIp string) (o []model.LoginLogForPage, total int64, err error) {
func (r LoginLogRepository) Find(pageIndex, pageSize int, username, clientIp, state string) (o []model.LoginLog, total int64, err error) {
m := model.LoginLog{}
db := r.DB.Table(m.TableName())
dbCounter := r.DB.Table(m.TableName())
db := r.DB.Table("login_logs").Select("login_logs.id,login_logs.user_id,login_logs.client_ip,login_logs.client_user_agent,login_logs.login_time, login_logs.logout_time, users.nickname as user_name").Joins("left join users on login_logs.user_id = users.id")
dbCounter := r.DB.Table("login_logs").Select("DISTINCT login_logs.id")
if userId != "" {
db = db.Where("login_logs.user_id = ?", userId)
dbCounter = dbCounter.Where("login_logs.user_id = ?", userId)
if username != "" {
db = db.Where("username like ?", "%"+username+"%")
dbCounter = dbCounter.Where("username like ?", "%"+username+"%")
}
if clientIp != "" {
db = db.Where("login_logs.client_ip like ?", "%"+clientIp+"%")
dbCounter = dbCounter.Where("login_logs.client_ip like ?", "%"+clientIp+"%")
db = db.Where("client_ip like ?", "%"+clientIp+"%")
dbCounter = dbCounter.Where("client_ip like ?", "%"+clientIp+"%")
}
if state != "" {
db = db.Where("state = ?", state)
dbCounter = dbCounter.Where("state = ?", state)
}
err = dbCounter.Count(&total).Error
@ -35,20 +42,26 @@ func (r LoginLogRepository) Find(pageIndex, pageSize int, userId, clientIp strin
return nil, 0, err
}
err = db.Order("login_logs.login_time desc").Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error
err = db.Order("login_time desc").Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error
if o == nil {
o = make([]model.LoginLogForPage, 0)
o = make([]model.LoginLog, 0)
}
return
}
func (r LoginLogRepository) FindAliveLoginLogs() (o []model.LoginLog, err error) {
err = r.DB.Where("logout_time is null").Find(&o).Error
err = r.DB.Where("state = '1' and logout_time is null").Find(&o).Error
return
}
func (r LoginLogRepository) FindAliveLoginLogsByUserId(userId string) (o []model.LoginLog, err error) {
err = r.DB.Where("logout_time is null and user_id = ?", userId).Find(&o).Error
func (r LoginLogRepository) FindAliveLoginLogsByUsername(username string) (o []model.LoginLog, err error) {
err = r.DB.Where("state = '1' and logout_time is null and username = ?", username).Find(&o).Error
return
}
func (r LoginLogRepository) FindOutTimeLog(dayLimit int) (o []model.LoginLog, err error) {
limitTime := time.Now().Add(time.Duration(-dayLimit*24) * time.Hour)
err = r.DB.Where("(state = '0' and login_time < ?) or (state = '1' and logout_time < ?) or (state is null and logout_time < ?)", limitTime, limitTime, limitTime).Find(&o).Error
return
}
@ -60,6 +73,10 @@ func (r LoginLogRepository) DeleteByIdIn(ids []string) (err error) {
return r.DB.Where("id in ?", ids).Delete(&model.LoginLog{}).Error
}
func (r LoginLogRepository) DeleteById(id string) (err error) {
return r.DB.Where("id = ?", id).Delete(&model.LoginLog{}).Error
}
func (r LoginLogRepository) FindById(id string) (o model.LoginLog, err error) {
err = r.DB.Where("id = ?", id).First(&o).Error
return

View File

@ -1,26 +0,0 @@
package repository
import (
"next-terminal/server/model"
"gorm.io/gorm"
)
type NumRepository struct {
DB *gorm.DB
}
func NewNumRepository(db *gorm.DB) *NumRepository {
numRepository = &NumRepository{DB: db}
return numRepository
}
func (r NumRepository) FindAll() (o []model.Num, err error) {
err = r.DB.Find(&o).Error
return
}
func (r NumRepository) Create(o *model.Num) (err error) {
err = r.DB.Create(o).Error
return
}

View File

@ -1,7 +1,6 @@
package repository
import (
"next-terminal/pkg/guacd"
"next-terminal/server/model"
"gorm.io/gorm"
@ -33,6 +32,10 @@ func (r PropertyRepository) UpdateByName(o *model.Property, name string) error {
return r.DB.Updates(o).Error
}
func (r PropertyRepository) DeleteByName(name string) error {
return r.DB.Where("name = ?", name).Delete(model.Property{}).Error
}
func (r PropertyRepository) FindByName(name string) (o model.Property, err error) {
err = r.DB.Where("name = ?", name).First(&o).Error
return
@ -46,19 +49,3 @@ func (r PropertyRepository) FindAllMap() map[string]string {
}
return propertyMap
}
func (r PropertyRepository) GetDrivePath() (string, error) {
property, err := r.FindByName(guacd.DrivePath)
if err != nil {
return "", err
}
return property.Value, nil
}
func (r PropertyRepository) GetRecordingPath() (string, error) {
property, err := r.FindByName(guacd.RecordingPath)
if err != nil {
return "", err
}
return property.Value, nil
}

View File

@ -18,14 +18,6 @@ func NewResourceSharerRepository(db *gorm.DB) *ResourceSharerRepository {
return resourceSharerRepository
}
func (r *ResourceSharerRepository) FindUserIdsByResourceId(resourceId string) (o []string, err error) {
err = r.DB.Table("resource_sharers").Select("user_id").Where("resource_id = ?", resourceId).Find(&o).Error
if o == nil {
o = make([]string, 0)
}
return
}
func (r *ResourceSharerRepository) OverwriteUserIdsByResourceId(resourceId, resourceType string, userIds []string) (err error) {
db := r.DB.Begin()
@ -104,7 +96,7 @@ func (r *ResourceSharerRepository) DeleteResourceSharerByResourceId(resourceId s
return r.DB.Where("resource_id = ?", resourceId).Delete(&model.ResourceSharer{}).Error
}
func (r *ResourceSharerRepository) AddSharerResources(userGroupId, userId, resourceType string, resourceIds []string) error {
func (r *ResourceSharerRepository) AddSharerResources(userGroupId, userId, strategyId, resourceType string, resourceIds []string) error {
return r.DB.Transaction(func(tx *gorm.DB) (err error) {
for i := range resourceIds {
@ -138,11 +130,13 @@ func (r *ResourceSharerRepository) AddSharerResources(userGroupId, userId, resou
return echo.NewHTTPError(400, "参数错误")
}
// 保证同一个资产只能分配给一个用户或者组
id := utils.Sign([]string{resourceId, resourceType, userId, userGroupId})
resource := &model.ResourceSharer{
ID: id,
ResourceId: resourceId,
ResourceType: resourceType,
StrategyId: strategyId,
UserId: userId,
UserGroupId: userGroupId,
}
@ -192,3 +186,35 @@ func (r *ResourceSharerRepository) FindAssetIdsByUserId(userId string) (assetIds
return
}
func (r *ResourceSharerRepository) FindByResourceIdAndUserId(assetId, userId string) (resourceSharers []model.ResourceSharer, err error) {
// 查询其他用户授权给该用户的资产
groupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId)
if err != nil {
return
}
db := r.DB.Where("( resource_id = ? and user_id = ? )", assetId, userId)
if len(groupIds) > 0 {
db = db.Or("user_group_id in ?", groupIds)
}
err = db.Find(&resourceSharers).Error
return
}
func (r *ResourceSharerRepository) Find(resourceId, resourceType, userId, userGroupId string) (resourceSharers []model.ResourceSharer, err error) {
db := r.DB
if resourceId != "" {
db = db.Where("resource_id = ?")
}
if resourceType != "" {
db = db.Where("resource_type = ?")
}
if userId != "" {
db = db.Where("user_id = ?")
}
if userGroupId != "" {
db = db.Where("user_group_id = ?")
}
err = db.Find(&resourceSharers).Error
return
}

View File

@ -6,8 +6,8 @@ import (
"path"
"time"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/server/config"
"next-terminal/server/constant"
"next-terminal/server/model"
"next-terminal/server/utils"
@ -110,7 +110,7 @@ func (r SessionRepository) Decrypt(item *model.Session) error {
if err != nil {
return err
}
decryptedCBC, err := utils.AesDecryptCBC(origData, global.Config.EncryptionPassword)
decryptedCBC, err := utils.AesDecryptCBC(origData, config.GlobalCfg.EncryptionPassword)
if err != nil {
return err
}
@ -121,7 +121,7 @@ func (r SessionRepository) Decrypt(item *model.Session) error {
if err != nil {
return err
}
decryptedCBC, err := utils.AesDecryptCBC(origData, global.Config.EncryptionPassword)
decryptedCBC, err := utils.AesDecryptCBC(origData, config.GlobalCfg.EncryptionPassword)
if err != nil {
return err
}
@ -132,7 +132,7 @@ func (r SessionRepository) Decrypt(item *model.Session) error {
if err != nil {
return err
}
decryptedCBC, err := utils.AesDecryptCBC(origData, global.Config.EncryptionPassword)
decryptedCBC, err := utils.AesDecryptCBC(origData, config.GlobalCfg.EncryptionPassword)
if err != nil {
return err
}
@ -164,12 +164,9 @@ func (r SessionRepository) DeleteById(id string) error {
}
func (r SessionRepository) DeleteByIds(sessionIds []string) error {
drivePath, err := propertyRepository.GetRecordingPath()
if err != nil {
return err
}
recordingPath := config.GlobalCfg.Guacd.Recording
for i := range sessionIds {
if err := os.RemoveAll(path.Join(drivePath, sessionIds[i])); err != nil {
if err := os.RemoveAll(path.Join(recordingPath, sessionIds[i])); err != nil {
return err
}
if err := r.DeleteById(sessionIds[i]); err != nil {
@ -188,35 +185,27 @@ func (r SessionRepository) CountOnlineSession() (total int64, err error) {
return
}
type D struct {
Day string `json:"day"`
Count int `json:"count"`
Protocol string `json:"protocol"`
}
func (r SessionRepository) CountSessionByDay(day int) (results []D, err error) {
today := time.Now().Format("20060102")
sql := "select t1.`day`, count(t2.id) as count\nfrom (\n SELECT @date := DATE_ADD(@date, INTERVAL - 1 DAY) day\n FROM (SELECT @date := DATE_ADD('" + today + "', INTERVAL + 1 DAY) FROM nums) as t0\n LIMIT ?\n )\n as t1\n left join\n (\n select DATE(s.connected_time) as day, s.id\n from sessions as s\n WHERE protocol = ? and DATE(connected_time) <= '" + today + "'\n AND DATE(connected_time) > DATE_SUB('" + today + "', INTERVAL ? DAY)\n ) as t2 on t1.day = t2.day\ngroup by t1.day"
protocols := []string{"rdp", "ssh", "vnc", "telnet"}
for i := range protocols {
var result []D
err = r.DB.Raw(sql, day, protocols[i], day).Scan(&result).Error
if err != nil {
return nil, err
}
for j := range result {
result[j].Protocol = protocols[i]
}
results = append(results, result...)
}
return
}
func (r SessionRepository) EmptyPassword() error {
sql := "update sessions set password = '-',private_key = '-', passphrase = '-' where 1=1"
return r.DB.Exec(sql).Error
}
func (r SessionRepository) CountByStatus(status string) (total int64, err error) {
err = r.DB.Find(&model.Session{}).Where("status = ?", status).Count(&total).Error
return
}
func (r SessionRepository) OverviewAccess(account model.User) (o []model.SessionForAccess, err error) {
db := r.DB
if constant.TypeUser == account.Type {
sql := "SELECT s.asset_id, s.ip, s.port, s.protocol, s.username, count(s.asset_id) AS access_count FROM sessions AS s where s.creator = ? GROUP BY s.asset_id, s.ip, s.port, s.protocol, s.username ORDER BY access_count DESC limit 10"
err = db.Raw(sql, []string{account.ID}).Scan(&o).Error
} else {
sql := "SELECT s.asset_id, s.ip, s.port, s.protocol, s.username, count(s.asset_id) AS access_count FROM sessions AS s GROUP BY s.asset_id, s.ip, s.port, s.protocol, s.username ORDER BY access_count DESC limit 10"
err = db.Raw(sql).Scan(&o).Error
}
if o == nil {
o = make([]model.SessionForAccess, 0)
}
return
}

View File

@ -0,0 +1,87 @@
package repository
import (
"next-terminal/server/model"
"gorm.io/gorm"
)
type StorageRepository struct {
DB *gorm.DB
}
func NewStorageRepository(db *gorm.DB) *StorageRepository {
storageRepository = &StorageRepository{DB: db}
return storageRepository
}
func (r StorageRepository) Find(pageIndex, pageSize int, name, order, field string) (o []model.StorageForPage, total int64, err error) {
m := model.Storage{}
db := r.DB.Table(m.TableName()).Select("storages.id,storages.name,storages.is_share,storages.limit_size,storages.is_default,storages.owner,storages.created, users.nickname as owner_name").Joins("left join users on storages.owner = users.id")
dbCounter := r.DB.Table(m.TableName())
if len(name) > 0 {
db = db.Where("name like ?", "%"+name+"%")
dbCounter = dbCounter.Where("name like ?", "%"+name+"%")
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "name" {
field = "storages.name"
} else {
field = "storages.created"
}
err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error
if o == nil {
o = make([]model.StorageForPage, 0)
}
return
}
func (r StorageRepository) FindShares() (o []model.Storage, err error) {
m := model.Storage{}
db := r.DB.Table(m.TableName()).Where("is_share = 1")
err = db.Find(&o).Error
return
}
func (r StorageRepository) DeleteById(id string) error {
return r.DB.Where("id = ?", id).Delete(model.Storage{}).Error
}
func (r StorageRepository) Create(m *model.Storage) error {
return r.DB.Create(m).Error
}
func (r StorageRepository) UpdateById(o *model.Storage, id string) error {
o.ID = id
return r.DB.Updates(o).Error
}
func (r StorageRepository) FindByOwnerIdAndDefault(owner string, isDefault bool) (m model.Storage, err error) {
err = r.DB.Where("owner = ? and is_default = ?", owner, isDefault).First(&m).Error
return
}
func (r StorageRepository) FindById(id string) (m model.Storage, err error) {
err = r.DB.Where("id = ?", id).First(&m).Error
return
}
func (r StorageRepository) FindAll() (o []model.Storage) {
if r.DB.Find(&o).Error != nil {
return nil
}
return
}

View File

@ -0,0 +1,73 @@
package repository
import (
"next-terminal/server/model"
"gorm.io/gorm"
)
type StrategyRepository struct {
DB *gorm.DB
}
func NewStrategyRepository(db *gorm.DB) *StrategyRepository {
strategyRepository = &StrategyRepository{DB: db}
return strategyRepository
}
func (r StrategyRepository) FindAll() (o []model.Strategy, err error) {
err = r.DB.Order("name desc").Find(&o).Error
return
}
func (r StrategyRepository) Find(pageIndex, pageSize int, name, order, field string) (o []model.Strategy, total int64, err error) {
m := model.Strategy{}
db := r.DB.Table(m.TableName())
dbCounter := r.DB.Table(m.TableName())
if len(name) > 0 {
db = db.Where("name like ?", "%"+name+"%")
dbCounter = dbCounter.Where("name like ?", "%"+name+"%")
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "name" {
field = "name"
} else {
field = "created"
}
err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error
if o == nil {
o = make([]model.Strategy, 0)
}
return
}
func (r StrategyRepository) DeleteById(id string) error {
return r.DB.Where("id = ?", id).Delete(model.Strategy{}).Error
}
func (r StrategyRepository) Create(m *model.Strategy) error {
return r.DB.Create(m).Error
}
func (r StrategyRepository) UpdateById(o *model.Strategy, id string) error {
o.ID = id
return r.DB.Updates(o).Error
}
func (r StrategyRepository) FindById(id string) (m model.Strategy, err error) {
err = r.DB.Where("id = ?", id).First(&m).Error
return
}

View File

@ -1,7 +1,7 @@
package repository
import (
"next-terminal/pkg/constant"
"next-terminal/server/constant"
"next-terminal/server/model"
"gorm.io/gorm"
@ -105,9 +105,9 @@ func (r UserRepository) Update(o *model.User) error {
return r.DB.Updates(o).Error
}
func (r UserRepository) UpdateOnline(id string, online bool) error {
sql := "update users set online = ? where id = ?"
return r.DB.Exec(sql, online, id).Error
func (r UserRepository) UpdateOnlineByUsername(username string, online bool) error {
sql := "update users set online = ? where username = ?"
return r.DB.Exec(sql, online, username).Error
}
func (r UserRepository) DeleteById(id string) error {
@ -135,3 +135,8 @@ func (r UserRepository) CountOnlineUser() (total int64, err error) {
err = r.DB.Where("online = ?", true).Find(&model.User{}).Count(&total).Error
return
}
func (r UserRepository) Count() (total int64, err error) {
err = r.DB.Find(&model.User{}).Count(&total).Error
return
}

View File

@ -0,0 +1,65 @@
package service
import (
"next-terminal/server/global/gateway"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/term"
)
type AccessGatewayService struct {
accessGatewayRepository *repository.AccessGatewayRepository
}
func NewAccessGatewayService(accessGatewayRepository *repository.AccessGatewayRepository) *AccessGatewayService {
accessGatewayService = &AccessGatewayService{accessGatewayRepository: accessGatewayRepository}
return accessGatewayService
}
func (r AccessGatewayService) GetGatewayAndReconnectById(accessGatewayId string) (g *gateway.Gateway, err error) {
g = gateway.GlobalGatewayManager.GetById(accessGatewayId)
if g == nil || !g.Connected {
accessGateway, err := r.accessGatewayRepository.FindById(accessGatewayId)
if err != nil {
return nil, err
}
g = r.ReConnect(&accessGateway)
}
return g, nil
}
func (r AccessGatewayService) GetGatewayById(accessGatewayId string) (g *gateway.Gateway, err error) {
g = gateway.GlobalGatewayManager.GetById(accessGatewayId)
return g, nil
}
func (r AccessGatewayService) ReConnectAll() error {
gateways, err := r.accessGatewayRepository.FindAll()
if err != nil {
return err
}
for i := range gateways {
r.ReConnect(&gateways[i])
}
return nil
}
func (r AccessGatewayService) ReConnect(m *model.AccessGateway) *gateway.Gateway {
log.Debugf("重建接入网关「%v」中...", m.Name)
r.DisconnectById(m.ID)
sshClient, err := term.NewSshClient(m.IP, m.Port, m.Username, m.Password, m.PrivateKey, m.Passphrase)
var g *gateway.Gateway
if err != nil {
g = gateway.NewGateway(m.ID, m.Localhost, false, err.Error(), nil)
} else {
g = gateway.NewGateway(m.ID, m.Localhost, true, "", sshClient)
}
gateway.GlobalGatewayManager.Add <- g
log.Debugf("重建接入网关「%v」完成", m.Name)
return g
}
func (r AccessGatewayService) DisconnectById(accessGatewayId string) {
gateway.GlobalGatewayManager.Del <- accessGatewayId
}

60
server/service/asset.go Normal file
View File

@ -0,0 +1,60 @@
package service
import (
"next-terminal/server/config"
"next-terminal/server/repository"
"next-terminal/server/utils"
)
type AssetService struct {
assetRepository *repository.AssetRepository
}
func NewAssetService(assetRepository *repository.AssetRepository) *AssetService {
return &AssetService{assetRepository: assetRepository}
}
func (r AssetService) Encrypt() error {
items, err := r.assetRepository.FindAll()
if err != nil {
return err
}
for i := range items {
item := items[i]
if item.Encrypted {
continue
}
if err := r.assetRepository.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil {
return err
}
if err := r.assetRepository.UpdateById(&item, item.ID); err != nil {
return err
}
}
return nil
}
func (r AssetService) CheckStatus(accessGatewayId string, ip string, port int) (active bool, err error) {
if accessGatewayId != "" && accessGatewayId != "-" {
g, e1 := accessGatewayService.GetGatewayAndReconnectById(accessGatewayId)
if err != nil {
return false, e1
}
uuid := utils.UUID()
exposedIP, exposedPort, e2 := g.OpenSshTunnel(uuid, ip, port)
if e2 != nil {
return false, e2
}
defer g.CloseSshTunnel(uuid)
if g.Connected {
active, err = utils.Tcping(exposedIP, exposedPort)
} else {
active = false
}
} else {
active, err = utils.Tcping(ip, port)
}
return active, err
}

View File

@ -0,0 +1,34 @@
package service
import (
"next-terminal/server/config"
"next-terminal/server/repository"
)
type CredentialService struct {
credentialRepository *repository.CredentialRepository
}
func NewCredentialService(credentialRepository *repository.CredentialRepository) *CredentialService {
return &CredentialService{credentialRepository: credentialRepository}
}
func (r CredentialService) Encrypt() error {
items, err := r.credentialRepository.FindAll()
if err != nil {
return err
}
for i := range items {
item := items[i]
if item.Encrypted {
continue
}
if err := r.credentialRepository.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil {
return err
}
if err := r.credentialRepository.UpdateById(&item, item.ID); err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,5 @@
package service
var (
accessGatewayService *AccessGatewayService
)

374
server/service/job.go Normal file
View File

@ -0,0 +1,374 @@
package service
import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"next-terminal/server/constant"
"next-terminal/server/global/cron"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/term"
"next-terminal/server/utils"
"gorm.io/gorm"
)
type JobService struct {
jobRepository *repository.JobRepository
jobLogRepository *repository.JobLogRepository
assetRepository *repository.AssetRepository
credentialRepository *repository.CredentialRepository
assetService *AssetService
}
func NewJobService(jobRepository *repository.JobRepository, jobLogRepository *repository.JobLogRepository, assetRepository *repository.AssetRepository, credentialRepository *repository.CredentialRepository, assetService *AssetService) *JobService {
return &JobService{jobRepository: jobRepository, jobLogRepository: jobLogRepository, assetRepository: assetRepository, credentialRepository: credentialRepository, assetService: assetService}
}
func (r JobService) ChangeStatusById(id, status string) error {
job, err := r.jobRepository.FindById(id)
if err != nil {
return err
}
if status == constant.JobStatusRunning {
j, err := getJob(&job, &r)
if err != nil {
return err
}
entryID, err := cron.GlobalCron.AddJob(job.Cron, j)
if err != nil {
return err
}
log.Debugf("开启计划任务「%v」,运行中计划任务数量「%v」", job.Name, len(cron.GlobalCron.Entries()))
jobForUpdate := model.Job{ID: id, Status: constant.JobStatusRunning, CronJobId: int(entryID)}
return r.jobRepository.UpdateById(&jobForUpdate)
} else {
cron.GlobalCron.Remove(cron.JobId(job.CronJobId))
log.Debugf("关闭计划任务「%v」,运行中计划任务数量「%v」", job.Name, len(cron.GlobalCron.Entries()))
jobForUpdate := model.Job{ID: id, Status: constant.JobStatusNotRunning}
return r.jobRepository.UpdateById(&jobForUpdate)
}
}
func getJob(j *model.Job, jobService *JobService) (job cron.Job, err error) {
switch j.Func {
case constant.FuncCheckAssetStatusJob:
job = CheckAssetStatusJob{
ID: j.ID,
Mode: j.Mode,
ResourceIds: j.ResourceIds,
Metadata: j.Metadata,
jobService: jobService,
assetService: jobService.assetService,
}
case constant.FuncShellJob:
job = ShellJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata, jobService: jobService}
default:
return nil, errors.New("未识别的任务")
}
return job, err
}
type CheckAssetStatusJob struct {
ID string
Mode string
ResourceIds string
Metadata string
jobService *JobService
assetService *AssetService
}
func (r CheckAssetStatusJob) Run() {
if r.ID == "" {
return
}
var assets []model.Asset
if r.Mode == constant.JobModeAll {
assets, _ = r.jobService.assetRepository.FindAll()
} else {
assets, _ = r.jobService.assetRepository.FindByIds(strings.Split(r.ResourceIds, ","))
}
if len(assets) == 0 {
return
}
msgChan := make(chan string)
for i := range assets {
asset := assets[i]
go func() {
t1 := time.Now()
var (
msg string
ip = asset.IP
port = asset.Port
)
active, err := r.assetService.CheckStatus(asset.AccessGatewayId, ip, port)
elapsed := time.Since(t1)
if err == nil {
msg = fmt.Sprintf("资产「%v」存活状态检测完成存活「%v」耗时「%v」", asset.Name, active, elapsed)
} else {
msg = fmt.Sprintf("资产「%v」存活状态检测完成存活「%v」耗时「%v」原因 %v", asset.Name, active, elapsed, err.Error())
}
_ = r.jobService.assetRepository.UpdateActiveById(active, asset.ID)
log.Infof(msg)
msgChan <- msg
}()
}
var message = ""
for i := 0; i < len(assets); i++ {
message += <-msgChan + "\n"
}
_ = r.jobService.jobRepository.UpdateLastUpdatedById(r.ID)
jobLog := model.JobLog{
ID: utils.UUID(),
JobId: r.ID,
Timestamp: utils.NowJsonTime(),
Message: message,
}
_ = r.jobService.jobLogRepository.Create(&jobLog)
}
type ShellJob struct {
ID string
Mode string
ResourceIds string
Metadata string
jobService *JobService
}
type MetadataShell struct {
Shell string
}
func (r ShellJob) Run() {
if r.ID == "" {
return
}
var assets []model.Asset
if r.Mode == constant.JobModeAll {
assets, _ = r.jobService.assetRepository.FindByProtocol("ssh")
} else {
assets, _ = r.jobService.assetRepository.FindByProtocolAndIds("ssh", strings.Split(r.ResourceIds, ","))
}
if len(assets) == 0 {
return
}
var metadataShell MetadataShell
err := json.Unmarshal([]byte(r.Metadata), &metadataShell)
if err != nil {
log.Errorf("JSON数据解析失败 %v", err)
return
}
msgChan := make(chan string)
for i := range assets {
asset, err := r.jobService.assetRepository.FindByIdAndDecrypt(assets[i].ID)
if err != nil {
msgChan <- fmt.Sprintf("资产「%v」Shell执行失败查询数据异常「%v」", assets[i].Name, err.Error())
return
}
var (
username = asset.Username
password = asset.Password
privateKey = asset.PrivateKey
passphrase = asset.Passphrase
ip = asset.IP
port = asset.Port
)
if asset.AccountType == "credential" {
credential, err := r.jobService.credentialRepository.FindByIdAndDecrypt(asset.CredentialId)
if err != nil {
msgChan <- fmt.Sprintf("资产「%v」Shell执行失败查询授权凭证数据异常「%v」", assets[i].Name, err.Error())
return
}
if credential.Type == constant.Custom {
username = credential.Username
password = credential.Password
} else {
username = credential.Username
privateKey = credential.PrivateKey
passphrase = credential.Passphrase
}
}
go func() {
t1 := time.Now()
result, err := exec(metadataShell.Shell, asset.AccessGatewayId, ip, port, username, password, privateKey, passphrase)
elapsed := time.Since(t1)
var msg string
if err != nil {
if errors.Is(gorm.ErrRecordNotFound, err) {
msg = fmt.Sprintf("资产「%v」Shell执行失败请检查资产所关联接入网关是否存在耗时「%v」", asset.Name, elapsed)
} else {
msg = fmt.Sprintf("资产「%v」Shell执行失败错误内容为「%v」耗时「%v」", asset.Name, err.Error(), elapsed)
}
log.Infof(msg)
} else {
msg = fmt.Sprintf("资产「%v」Shell执行成功返回值「%v」耗时「%v」", asset.Name, result, elapsed)
log.Infof(msg)
}
msgChan <- msg
}()
}
var message = ""
for i := 0; i < len(assets); i++ {
message += <-msgChan + "\n"
}
_ = r.jobService.jobRepository.UpdateLastUpdatedById(r.ID)
jobLog := model.JobLog{
ID: utils.UUID(),
JobId: r.ID,
Timestamp: utils.NowJsonTime(),
Message: message,
}
_ = r.jobService.jobLogRepository.Create(&jobLog)
}
func exec(shell, accessGatewayId, ip string, port int, username, password, privateKey, passphrase string) (string, error) {
if accessGatewayId != "" && accessGatewayId != "-" {
g, err := accessGatewayService.GetGatewayAndReconnectById(accessGatewayId)
if err != nil {
return "", err
}
uuid := utils.UUID()
exposedIP, exposedPort, err := g.OpenSshTunnel(uuid, ip, port)
if err != nil {
return "", err
}
defer g.CloseSshTunnel(uuid)
return ExecCommandBySSH(shell, exposedIP, exposedPort, username, password, privateKey, passphrase)
} else {
return ExecCommandBySSH(shell, ip, port, username, password, privateKey, passphrase)
}
}
func ExecCommandBySSH(cmd, ip string, port int, username, password, privateKey, passphrase string) (result string, err error) {
sshClient, err := term.NewSshClient(ip, port, username, password, privateKey, passphrase)
if err != nil {
return "", err
}
session, err := sshClient.NewSession()
if err != nil {
return "", err
}
defer session.Close()
//执行远程命令
combo, err := session.CombinedOutput(cmd)
if err != nil {
return "", err
}
return string(combo), nil
}
func (r JobService) ExecJobById(id string) (err error) {
job, err := r.jobRepository.FindById(id)
if err != nil {
return err
}
j, err := getJob(&job, &r)
if err != nil {
return err
}
j.Run()
return nil
}
func (r JobService) InitJob() error {
jobs, _ := r.jobRepository.FindAll()
if len(jobs) == 0 {
job := model.Job{
ID: utils.UUID(),
Name: "资产状态检测",
Func: constant.FuncCheckAssetStatusJob,
Cron: "0 0/10 * * * ?",
Mode: constant.JobModeAll,
Status: constant.JobStatusRunning,
Created: utils.NowJsonTime(),
Updated: utils.NowJsonTime(),
}
if err := r.jobRepository.Create(&job); err != nil {
return err
}
log.Debugf("创建计划任务「%v」cron「%v」", job.Name, job.Cron)
} else {
for i := range jobs {
if jobs[i].Status == constant.JobStatusRunning {
err := r.ChangeStatusById(jobs[i].ID, constant.JobStatusRunning)
if err != nil {
return err
}
log.Debugf("启动计划任务「%v」cron「%v」", jobs[i].Name, jobs[i].Cron)
}
}
}
return nil
}
func (r JobService) Create(o *model.Job) (err error) {
if o.Status == constant.JobStatusRunning {
j, err := getJob(o, &r)
if err != nil {
return err
}
jobId, err := cron.GlobalCron.AddJob(o.Cron, j)
if err != nil {
return err
}
o.CronJobId = int(jobId)
}
return r.jobRepository.Create(o)
}
func (r JobService) DeleteJobById(id string) error {
job, err := r.jobRepository.FindById(id)
if err != nil {
return err
}
if job.Status == constant.JobStatusRunning {
if err := r.ChangeStatusById(id, constant.JobStatusNotRunning); err != nil {
return err
}
}
return r.jobRepository.DeleteJobById(id)
}
func (r JobService) UpdateById(m *model.Job) error {
if err := r.jobRepository.UpdateById(m); err != nil {
return err
}
if err := r.ChangeStatusById(m.ID, constant.JobStatusNotRunning); err != nil {
return err
}
if err := r.ChangeStatusById(m.ID, constant.JobStatusRunning); err != nil {
return err
}
return nil
}

42
server/service/mail.go Normal file
View File

@ -0,0 +1,42 @@
package service
import (
"net/smtp"
"next-terminal/server/constant"
"next-terminal/server/log"
"next-terminal/server/repository"
"github.com/jordan-wright/email"
)
type MailService struct {
propertyRepository *repository.PropertyRepository
}
func NewMailService(propertyRepository *repository.PropertyRepository) *MailService {
return &MailService{propertyRepository: propertyRepository}
}
func (r MailService) SendMail(to, subject, text string) {
propertiesMap := r.propertyRepository.FindAllMap()
host := propertiesMap[constant.MailHost]
port := propertiesMap[constant.MailPort]
username := propertiesMap[constant.MailUsername]
password := propertiesMap[constant.MailPassword]
if host == "" || port == "" || username == "" || password == "" {
log.Debugf("邮箱信息不完整,跳过发送邮件。")
return
}
e := email.NewEmail()
e.From = "Next Terminal <" + username + ">"
e.To = []string{to}
e.Subject = subject
e.Text = []byte(text)
err := e.Send(host+":"+port, smtp.PlainAuth("", username, password, host))
if err != nil {
log.Errorf("邮件发送失败: %v", err.Error())
}
}

180
server/service/property.go Normal file
View File

@ -0,0 +1,180 @@
package service
import (
"next-terminal/server/guacd"
"next-terminal/server/model"
"next-terminal/server/repository"
)
type PropertyService struct {
propertyRepository *repository.PropertyRepository
}
func NewPropertyService(propertyRepository *repository.PropertyRepository) *PropertyService {
return &PropertyService{propertyRepository: propertyRepository}
}
func (r PropertyService) InitProperties() error {
propertyMap := r.propertyRepository.FindAllMap()
if len(propertyMap[guacd.EnableRecording]) == 0 {
property := model.Property{
Name: guacd.EnableRecording,
Value: "true",
}
if err := r.propertyRepository.Create(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.CreateRecordingPath]) == 0 {
property := model.Property{
Name: guacd.CreateRecordingPath,
Value: "true",
}
if err := r.propertyRepository.Create(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.FontName]) == 0 {
property := model.Property{
Name: guacd.FontName,
Value: "menlo",
}
if err := r.propertyRepository.Create(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.FontSize]) == 0 {
property := model.Property{
Name: guacd.FontSize,
Value: "12",
}
if err := r.propertyRepository.Create(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.ColorScheme]) == 0 {
property := model.Property{
Name: guacd.ColorScheme,
Value: "gray-black",
}
if err := r.propertyRepository.Create(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.EnableWallpaper]) == 0 {
property := model.Property{
Name: guacd.EnableWallpaper,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.EnableTheming]) == 0 {
property := model.Property{
Name: guacd.EnableTheming,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.EnableFontSmoothing]) == 0 {
property := model.Property{
Name: guacd.EnableFontSmoothing,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.EnableFullWindowDrag]) == 0 {
property := model.Property{
Name: guacd.EnableFullWindowDrag,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.EnableDesktopComposition]) == 0 {
property := model.Property{
Name: guacd.EnableDesktopComposition,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.EnableMenuAnimations]) == 0 {
property := model.Property{
Name: guacd.EnableMenuAnimations,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.DisableBitmapCaching]) == 0 {
property := model.Property{
Name: guacd.DisableBitmapCaching,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.DisableOffscreenCaching]) == 0 {
property := model.Property{
Name: guacd.DisableOffscreenCaching,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
return err
}
}
if len(propertyMap[guacd.DisableGlyphCaching]) == 0 {
property := model.Property{
Name: guacd.DisableGlyphCaching,
Value: "true",
}
if err := r.propertyRepository.Create(&property); err != nil {
return err
}
}
return nil
}
func (r PropertyService) DeleteDeprecatedProperty() error {
propertyMap := r.propertyRepository.FindAllMap()
if propertyMap[guacd.EnableDrive] != "" {
if err := r.propertyRepository.DeleteByName(guacd.DriveName); err != nil {
return err
}
}
if propertyMap[guacd.DrivePath] != "" {
if err := r.propertyRepository.DeleteByName(guacd.DrivePath); err != nil {
return err
}
}
if propertyMap[guacd.DriveName] != "" {
if err := r.propertyRepository.DeleteByName(guacd.DriveName); err != nil {
return err
}
}
return nil
}

39
server/service/session.go Normal file
View File

@ -0,0 +1,39 @@
package service
import (
"next-terminal/server/constant"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/utils"
)
type SessionService struct {
sessionRepository *repository.SessionRepository
}
func NewSessionService(sessionRepository *repository.SessionRepository) *SessionService {
return &SessionService{sessionRepository: sessionRepository}
}
func (r SessionService) FixSessionState() error {
sessions, err := r.sessionRepository.FindByStatus(constant.Connected)
if err != nil {
return err
}
if len(sessions) > 0 {
for i := range sessions {
session := model.Session{
Status: constant.Disconnected,
DisconnectedTime: utils.NowJsonTime(),
}
_ = r.sessionRepository.UpdateById(&session, sessions[i].ID)
}
}
return nil
}
func (r SessionService) EmptyPassword() error {
return r.sessionRepository.EmptyPassword()
}

151
server/service/storage.go Normal file
View File

@ -0,0 +1,151 @@
package service
import (
"errors"
"io/ioutil"
"os"
"path"
"next-terminal/server/config"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/utils"
"gorm.io/gorm"
)
type StorageService struct {
storageRepository *repository.StorageRepository
userRepository *repository.UserRepository
propertyRepository *repository.PropertyRepository
}
func NewStorageService(storageRepository *repository.StorageRepository, userRepository *repository.UserRepository, propertyRepository *repository.PropertyRepository) *StorageService {
return &StorageService{storageRepository: storageRepository, userRepository: userRepository, propertyRepository: propertyRepository}
}
func (r StorageService) InitStorages() error {
users := r.userRepository.FindAll()
for i := range users {
userId := users[i].ID
_, err := r.storageRepository.FindByOwnerIdAndDefault(userId, true)
if errors.Is(err, gorm.ErrRecordNotFound) {
err = r.CreateStorageByUser(&users[i])
if err != nil {
return err
}
}
}
drivePath := r.GetBaseDrivePath()
storages := r.storageRepository.FindAll()
for i := 0; i < len(storages); i++ {
storage := storages[i]
// 判断是否为遗留的数据:磁盘空间在,但用户已删除
if storage.IsDefault {
var userExist = false
for j := range users {
if storage.ID == users[j].ID {
userExist = true
break
}
}
if !userExist {
if err := r.DeleteStorageById(storage.ID, true); err != nil {
return err
}
}
}
storageDir := path.Join(drivePath, storage.ID)
if !utils.FileExists(storageDir) {
if err := os.MkdirAll(storageDir, os.ModePerm); err != nil {
return err
}
log.Infof("创建storage:「%v」文件夹: %v", storage.Name, storageDir)
}
}
return nil
}
func (r StorageService) CreateStorageByUser(user *model.User) error {
drivePath := r.GetBaseDrivePath()
storage := model.Storage{
ID: user.ID,
Name: user.Nickname + "的默认空间",
IsShare: false,
IsDefault: true,
LimitSize: -1,
Owner: user.ID,
Created: utils.NowJsonTime(),
}
storageDir := path.Join(drivePath, storage.ID)
if err := os.MkdirAll(storageDir, os.ModePerm); err != nil {
return err
}
log.Infof("创建storage:「%v」文件夹: %v", storage.Name, storageDir)
err := r.storageRepository.Create(&storage)
if err != nil {
return err
}
return nil
}
type File struct {
Name string `json:"name"`
Path string `json:"path"`
IsDir bool `json:"isDir"`
Mode string `json:"mode"`
IsLink bool `json:"isLink"`
ModTime utils.JsonTime `json:"modTime"`
Size int64 `json:"size"`
}
func (r StorageService) Ls(drivePath, remoteDir string) ([]File, error) {
fileInfos, err := ioutil.ReadDir(path.Join(drivePath, remoteDir))
if err != nil {
return nil, err
}
var files = make([]File, 0)
for i := range fileInfos {
file := File{
Name: fileInfos[i].Name(),
Path: path.Join(remoteDir, fileInfos[i].Name()),
IsDir: fileInfos[i].IsDir(),
Mode: fileInfos[i].Mode().String(),
IsLink: fileInfos[i].Mode()&os.ModeSymlink == os.ModeSymlink,
ModTime: utils.NewJsonTime(fileInfos[i].ModTime()),
Size: fileInfos[i].Size(),
}
files = append(files, file)
}
return files, nil
}
func (r StorageService) GetBaseDrivePath() string {
return config.GlobalCfg.Guacd.Drive
}
func (r StorageService) DeleteStorageById(id string, force bool) error {
drivePath := r.GetBaseDrivePath()
storage, err := r.storageRepository.FindById(id)
if err != nil {
return err
}
if !force && storage.IsDefault {
return errors.New("默认空间不能删除")
}
// 删除对应的本地目录
if err := os.RemoveAll(path.Join(drivePath, id)); err != nil {
return err
}
if err := r.storageRepository.DeleteById(id); err != nil {
return err
}
return nil
}

129
server/service/user.go Normal file
View File

@ -0,0 +1,129 @@
package service
import (
"strings"
"next-terminal/server/constant"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/utils"
)
type UserService struct {
userRepository *repository.UserRepository
loginLogRepository *repository.LoginLogRepository
}
func NewUserService(userRepository *repository.UserRepository, loginLogRepository *repository.LoginLogRepository) *UserService {
return &UserService{userRepository: userRepository, loginLogRepository: loginLogRepository}
}
func (r UserService) InitUser() (err error) {
users := r.userRepository.FindAll()
if len(users) == 0 {
initPassword := "admin"
var pass []byte
if pass, err = utils.Encoder.Encode([]byte(initPassword)); err != nil {
return err
}
user := model.User{
ID: utils.UUID(),
Username: "admin",
Password: string(pass),
Nickname: "超级管理员",
Type: constant.TypeAdmin,
Created: utils.NowJsonTime(),
}
if err := r.userRepository.Create(&user); err != nil {
return err
}
log.Infof("初始用户创建成功,账号:「%v」密码「%v」", user.Username, initPassword)
} else {
for i := range users {
// 修正默认用户类型为管理员
if users[i].Type == "" {
user := model.User{
Type: constant.TypeAdmin,
ID: users[i].ID,
}
if err := r.userRepository.Update(&user); err != nil {
return err
}
log.Infof("自动修正用户「%v」ID「%v」类型为管理员", users[i].Nickname, users[i].ID)
}
}
}
return nil
}
func (r UserService) FixUserOnlineState() error {
// 修正用户登录状态
onlineUsers, err := r.userRepository.FindOnlineUsers()
if err != nil {
return err
}
if len(onlineUsers) > 0 {
for i := range onlineUsers {
logs, err := r.loginLogRepository.FindAliveLoginLogsByUsername(onlineUsers[i].Username)
if err != nil {
return err
}
if len(logs) == 0 {
if err := r.userRepository.UpdateOnlineByUsername(onlineUsers[i].Username, false); err != nil {
return err
}
}
}
}
return nil
}
func (r UserService) Logout(token string) (err error) {
loginLog, err := r.loginLogRepository.FindById(token)
if err != nil {
log.Warnf("登录日志「%v」获取失败", token)
return
}
loginLogForUpdate := &model.LoginLog{LogoutTime: utils.NowJsonTime(), ID: token}
err = r.loginLogRepository.Update(loginLogForUpdate)
if err != nil {
return err
}
loginLogs, err := r.loginLogRepository.FindAliveLoginLogsByUsername(loginLog.Username)
if err != nil {
return
}
if len(loginLogs) == 0 {
err = r.userRepository.UpdateOnlineByUsername(loginLog.Username, false)
}
return
}
func (r UserService) BuildCacheKeyByToken(token string) string {
cacheKey := strings.Join([]string{constant.Token, token}, ":")
return cacheKey
}
func (r UserService) GetTokenFormCacheKey(cacheKey string) string {
token := strings.Split(cacheKey, ":")[1]
return token
}
func (r UserService) OnEvicted(key string, value interface{}) {
if strings.HasPrefix(key, constant.Token) {
token := r.GetTokenFormCacheKey(key)
log.Debugf("用户Token「%v」过期", token)
err := r.Logout(token)
if err != nil {
log.Errorf("退出登录失败 %v", err)
}
}
}

139
server/task/ticker.go Normal file
View File

@ -0,0 +1,139 @@
package task
import (
"strconv"
"time"
"next-terminal/server/constant"
"next-terminal/server/log"
"next-terminal/server/repository"
)
type Ticker struct {
sessionRepository *repository.SessionRepository
propertyRepository *repository.PropertyRepository
loginLogRepository *repository.LoginLogRepository
jobLogRepository *repository.JobLogRepository
}
func NewTicker(sessionRepository *repository.SessionRepository, propertyRepository *repository.PropertyRepository, loginLogRepository *repository.LoginLogRepository, jobLogRepository *repository.JobLogRepository) *Ticker {
return &Ticker{sessionRepository: sessionRepository, propertyRepository: propertyRepository, loginLogRepository: loginLogRepository, jobLogRepository: jobLogRepository}
}
func (t *Ticker) SetupTicker() {
// 每隔一小时删除一次未使用的会话信息
unUsedSessionTicker := time.NewTicker(time.Minute * 60)
go func() {
for range unUsedSessionTicker.C {
sessions, _ := t.sessionRepository.FindByStatusIn([]string{constant.NoConnect, constant.Connecting})
if len(sessions) > 0 {
now := time.Now()
for i := range sessions {
if now.Sub(sessions[i].ConnectedTime.Time) > time.Hour*1 {
_ = t.sessionRepository.DeleteById(sessions[i].ID)
s := sessions[i].Username + "@" + sessions[i].IP + ":" + strconv.Itoa(sessions[i].Port)
log.Infof("会话「%v」ID「%v」超过1小时未打开已删除。", s, sessions[i].ID)
}
}
}
}
}()
// 每隔6小时删除超过时长限制的会话
timeoutSessionTicker := time.NewTicker(time.Hour * 6)
go func() {
for range timeoutSessionTicker.C {
deleteOutTimeSession(t)
deleteOutTimeLoginLog(t)
deleteOutTimeJobLog(t)
}
}()
}
func deleteOutTimeSession(t *Ticker) {
property, err := t.propertyRepository.FindByName("session-saved-limit")
if err != nil {
return
}
if property.Value == "" || property.Value == "-" {
return
}
limit, err := strconv.Atoi(property.Value)
if err != nil {
return
}
sessions, err := t.sessionRepository.FindOutTimeSessions(limit)
if err != nil {
return
}
if len(sessions) > 0 {
var ids []string
for i := range sessions {
ids = append(ids, sessions[i].ID)
}
err := t.sessionRepository.DeleteByIds(ids)
if err != nil {
log.Errorf("删除离线会话失败 %v", err)
}
}
}
func deleteOutTimeLoginLog(t *Ticker) {
property, err := t.propertyRepository.FindByName("login-log-saved-limit")
if err != nil {
return
}
if property.Value == "" || property.Value == "-" {
return
}
limit, err := strconv.Atoi(property.Value)
if err != nil {
log.Errorf("获取删除登录日志保留时常失败 %v", err)
return
}
loginLogs, err := t.loginLogRepository.FindOutTimeLog(limit)
if err != nil {
log.Errorf("获取登录日志失败 %v", err)
return
}
if len(loginLogs) > 0 {
for i := range loginLogs {
err := t.loginLogRepository.DeleteById(loginLogs[i].ID)
if err != nil {
log.Errorf("删除登录日志失败 %v", err)
}
}
}
}
func deleteOutTimeJobLog(t *Ticker) {
property, err := t.propertyRepository.FindByName("cron-log-saved-limit")
if err != nil {
return
}
if property.Value == "" || property.Value == "-" {
return
}
limit, err := strconv.Atoi(property.Value)
if err != nil {
return
}
jobLogs, err := t.jobLogRepository.FindOutTimeLog(limit)
if err != nil {
return
}
if len(jobLogs) > 0 {
for i := range jobLogs {
err := t.jobLogRepository.DeleteById(jobLogs[i].ID)
if err != nil {
log.Errorf("删除计划日志失败 %v", err)
}
}
}
}

View File

@ -0,0 +1,103 @@
package term
import (
"bufio"
"io"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
type NextTerminal struct {
SshClient *ssh.Client
SshSession *ssh.Session
StdinPipe io.WriteCloser
SftpClient *sftp.Client
Recorder *Recorder
StdoutReader *bufio.Reader
}
func NewNextTerminal(ip string, port int, username, password, privateKey, passphrase string, rows, cols int, recording, term string, pipe bool) (*NextTerminal, error) {
sshClient, err := NewSshClient(ip, port, username, password, privateKey, passphrase)
if err != nil {
return nil, err
}
sshSession, err := sshClient.NewSession()
if err != nil {
return nil, err
}
var stdoutReader *bufio.Reader
if pipe {
stdoutPipe, err := sshSession.StdoutPipe()
if err != nil {
return nil, err
}
stdoutReader = bufio.NewReader(stdoutPipe)
}
var stdinPipe io.WriteCloser
if pipe {
stdinPipe, err = sshSession.StdinPipe()
if err != nil {
return nil, err
}
}
var recorder *Recorder
if recording != "" {
recorder, err = NewRecorder(recording, term, rows, cols)
if err != nil {
return nil, err
}
}
terminal := NextTerminal{
SshClient: sshClient,
SshSession: sshSession,
Recorder: recorder,
StdinPipe: stdinPipe,
StdoutReader: stdoutReader,
}
return &terminal, nil
}
func (ret *NextTerminal) Write(p []byte) (int, error) {
return ret.StdinPipe.Write(p)
}
func (ret *NextTerminal) Close() error {
if ret.SshSession != nil {
return ret.SshSession.Close()
}
if ret.SshClient != nil {
return ret.SshClient.Close()
}
if ret.Recorder != nil {
return ret.Close()
}
return nil
}
func (ret *NextTerminal) WindowChange(h int, w int) error {
return ret.SshSession.WindowChange(h, w)
}
func (ret *NextTerminal) RequestPty(term string, h, w int) error {
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
}
return ret.SshSession.RequestPty(term, h, w, modes)
}
func (ret *NextTerminal) Shell() error {
return ret.SshSession.Shell()
}

113
server/term/recorder.go Normal file
View File

@ -0,0 +1,113 @@
package term
import (
"encoding/json"
"os"
"time"
"next-terminal/server/utils"
)
type Env struct {
Shell string `json:"SHELL"`
Term string `json:"TERM"`
}
type Header struct {
Title string `json:"title"`
Version int `json:"version"`
Height int `json:"height"`
Width int `json:"width"`
Env Env `json:"env"`
Timestamp int `json:"Timestamp"`
}
type Recorder struct {
File *os.File
Timestamp int
}
func (recorder *Recorder) Close() {
if recorder.File != nil {
_ = recorder.File.Close()
}
}
func (recorder *Recorder) WriteHeader(header *Header) (err error) {
var p []byte
if p, err = json.Marshal(header); err != nil {
return
}
if _, err := recorder.File.Write(p); err != nil {
return err
}
if _, err := recorder.File.Write([]byte("\n")); err != nil {
return err
}
recorder.Timestamp = header.Timestamp
return
}
func (recorder *Recorder) WriteData(data string) (err error) {
now := int(time.Now().UnixNano())
delta := float64(now-recorder.Timestamp*1000*1000*1000) / 1000 / 1000 / 1000
row := make([]interface{}, 0)
row = append(row, delta)
row = append(row, "o")
row = append(row, data)
var s []byte
if s, err = json.Marshal(row); err != nil {
return
}
if _, err := recorder.File.Write(s); err != nil {
return err
}
if _, err := recorder.File.Write([]byte("\n")); err != nil {
return err
}
return
}
func NewRecorder(recordingPath, term string, h int, w int) (recorder *Recorder, err error) {
recorder = &Recorder{}
parentDirectory := utils.GetParentDirectory(recordingPath)
if utils.FileExists(parentDirectory) {
if err := os.RemoveAll(parentDirectory); err != nil {
return nil, err
}
}
if err = os.MkdirAll(parentDirectory, 0777); err != nil {
return
}
var file *os.File
file, err = os.Create(recordingPath)
if err != nil {
return nil, err
}
recorder.File = file
header := &Header{
Title: "",
Version: 2,
Height: h,
Width: w,
Env: Env{Shell: "/bin/bash", Term: term},
Timestamp: int(time.Now().Unix()),
}
if err := recorder.WriteHeader(header); err != nil {
return nil, err
}
return recorder, nil
}

54
server/term/ssh.go Normal file
View File

@ -0,0 +1,54 @@
package term
import (
"fmt"
"time"
"golang.org/x/crypto/ssh"
)
func NewSshClient(ip string, port int, username, password, privateKey, passphrase string) (*ssh.Client, error) {
var authMethod ssh.AuthMethod
if username == "-" || username == "" {
username = "root"
}
if password == "-" {
password = ""
}
if privateKey == "-" {
privateKey = ""
}
if passphrase == "-" {
passphrase = ""
}
var err error
if privateKey != "" {
var key ssh.Signer
if len(passphrase) > 0 {
key, err = ssh.ParsePrivateKeyWithPassphrase([]byte(privateKey), []byte(passphrase))
if err != nil {
return nil, err
}
} else {
key, err = ssh.ParsePrivateKey([]byte(privateKey))
if err != nil {
return nil, err
}
}
authMethod = ssh.PublicKeys(key)
} else {
authMethod = ssh.Password(password)
}
config := &ssh.ClientConfig{
Timeout: 3 * time.Second,
User: username,
Auth: []ssh.AuthMethod{authMethod},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
addr := fmt.Sprintf("%s:%d", ip, port)
return ssh.Dial("tcp", addr, config)
}

View File

@ -0,0 +1,175 @@
package main
import (
"fmt"
"io"
"next-terminal/server/log"
"os"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
)
type SSHTerminal struct {
Session *ssh.Session
exitMsg string
stdout io.Reader
stdin io.Writer
stderr io.Reader
}
func main() {
sshConfig := &ssh.ClientConfig{
User: "root",
Auth: []ssh.AuthMethod{
ssh.Password("root"),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
client, err := ssh.Dial("tcp", "172.16.101.32:22", sshConfig)
if err != nil {
log.Error(err)
}
defer client.Close()
err = New(client)
if err != nil {
fmt.Println(err)
}
}
func (t *SSHTerminal) updateTerminalSize() {
go func() {
// SIGWINCH is sent to the process when the window size of the terminal has
// changed.
sigwinchCh := make(chan os.Signal, 1)
//signal.Notify(sigwinchCh, syscall.SIN)
fd := int(os.Stdin.Fd())
termWidth, termHeight, err := terminal.GetSize(fd)
if err != nil {
fmt.Println(err)
}
for {
select {
// The client updated the size of the local PTY. This change needs to occur
// on the server side PTY as well.
case sigwinch := <-sigwinchCh:
if sigwinch == nil {
return
}
currTermWidth, currTermHeight, err := terminal.GetSize(fd)
// Terminal size has not changed, don't do anything.
if currTermHeight == termHeight && currTermWidth == termWidth {
continue
}
err = t.Session.WindowChange(currTermHeight, currTermWidth)
if err != nil {
fmt.Printf("Unable to send window-change reqest: %s.", err)
continue
}
termWidth, termHeight = currTermWidth, currTermHeight
}
}
}()
}
func (t *SSHTerminal) interactiveSession() error {
defer func() {
if t.exitMsg == "" {
log.Info(os.Stdout, "the connection was closed on the remote side on ", time.Now().Format(time.RFC822))
} else {
log.Info(os.Stdout, t.exitMsg)
}
}()
fd := int(os.Stdin.Fd())
state, err := terminal.MakeRaw(fd)
if err != nil {
return err
}
defer terminal.Restore(fd, state)
termWidth, termHeight, err := terminal.GetSize(fd)
if err != nil {
return err
}
termType := os.Getenv("TERM")
if termType == "" {
termType = "xterm-256color"
}
err = t.Session.RequestPty(termType, termHeight, termWidth, ssh.TerminalModes{})
if err != nil {
return err
}
t.updateTerminalSize()
t.stdin, err = t.Session.StdinPipe()
if err != nil {
return err
}
t.stdout, err = t.Session.StdoutPipe()
if err != nil {
return err
}
t.stderr, err = t.Session.StderrPipe()
go io.Copy(os.Stderr, t.stderr)
go io.Copy(os.Stdout, t.stdout)
go func() {
buf := make([]byte, 128)
for {
n, err := os.Stdin.Read(buf)
if err != nil {
fmt.Println(err)
return
}
if n > 0 {
_, err = t.stdin.Write(buf[:n])
if err != nil {
fmt.Println(err)
t.exitMsg = err.Error()
return
}
}
}
}()
err = t.Session.Shell()
if err != nil {
return err
}
err = t.Session.Wait()
if err != nil {
return err
}
return nil
}
func New(client *ssh.Client) error {
session, err := client.NewSession()
if err != nil {
return err
}
defer session.Close()
s := SSHTerminal{
Session: session,
}
return s.interactiveSession()
}

19
server/totp/totp.go Normal file
View File

@ -0,0 +1,19 @@
package totp
import (
otp_t "github.com/pquerna/otp"
totp_t "github.com/pquerna/otp/totp"
)
type GenerateOpts totp_t.GenerateOpts
func NewTOTP(opt GenerateOpts) (*otp_t.Key, error) {
return totp_t.Generate(totp_t.GenerateOpts(opt))
}
func Validate(code string, secret string) bool {
if secret == "" {
return true
}
return totp_t.Validate(code, secret)
}

View File

@ -0,0 +1 @@
package utils

View File

@ -18,19 +18,19 @@ func TestTcping(t *testing.T) {
localhost6 := "::1"
conn, err := net.Listen("tcp", ":9999")
assert.NoError(t, err)
ip4resfalse := utils.Tcping(localhost4, 22)
ip4resfalse, _ := utils.Tcping(localhost4, 22)
assert.Equal(t, false, ip4resfalse)
ip4res := utils.Tcping(localhost4, 9999)
ip4res, _ := utils.Tcping(localhost4, 9999)
assert.Equal(t, true, ip4res)
ip6res := utils.Tcping(localhost6, 9999)
ip6res, _ := utils.Tcping(localhost6, 9999)
assert.Equal(t, true, ip6res)
ip4resWithBracket := utils.Tcping("["+localhost4+"]", 9999)
ip4resWithBracket, _ := utils.Tcping("["+localhost4+"]", 9999)
assert.Equal(t, true, ip4resWithBracket)
ip6resWithBracket := utils.Tcping("["+localhost6+"]", 9999)
ip6resWithBracket, _ := utils.Tcping("["+localhost6+"]", 9999)
assert.Equal(t, true, ip6resWithBracket)
defer func() {
@ -40,10 +40,11 @@ func TestTcping(t *testing.T) {
func TestAesEncryptCBC(t *testing.T) {
origData := []byte("Hello Next Terminal") // 待加密的数据
key := []byte("qwertyuiopasdfgh") // 加密的密钥
encryptedCBC, err := utils.AesEncryptCBC(origData, key)
md5Sum := fmt.Sprintf("%x", md5.Sum([]byte("next-terminal")))
key := []byte(md5Sum) // 加密的密钥
_, err := utils.AesEncryptCBC(origData, key)
assert.NoError(t, err)
assert.Equal(t, "s2xvMRPfZjmttpt+x0MzG9dsWcf1X+h9nt7waLvXpNM=", base64.StdEncoding.EncodeToString(encryptedCBC))
//assert.Equal(t, "s2xvMRPfZjmttpt+x0MzG9dsWcf1X+h9nt7waLvXpNM=", base64.StdEncoding.EncodeToString(encryptedCBC))
}
func TestAesDecryptCBC(t *testing.T) {
@ -77,3 +78,11 @@ func TestAesDecryptCBCWithAnyKey(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "admin", string(decryptCBC))
}
func TestGetAvailablePort(t *testing.T) {
port, err := utils.GetAvailablePort()
if err != nil {
println(err)
}
println(port)
}

View File

@ -2,30 +2,43 @@ package utils
import (
"bytes"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"database/sql/driver"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"image"
"image/png"
"io/ioutil"
"net"
"os"
"path/filepath"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"time"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/crypto/ssh"
"golang.org/x/text/encoding/simplifiedchinese"
"golang.org/x/text/transform"
"github.com/denisbrodbeck/machineid"
"github.com/gofrs/uuid"
errors2 "github.com/pkg/errors"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/pbkdf2"
)
type JsonTime struct {
@ -87,7 +100,7 @@ func UUID() string {
return v4.String()
}
func Tcping(ip string, port int) bool {
func Tcping(ip string, port int) (bool, error) {
var (
conn net.Conn
err error
@ -100,13 +113,13 @@ func Tcping(ip string, port int) bool {
} else {
address = fmt.Sprintf("[%s]:%s", ip, strPort)
}
if conn, err = net.DialTimeout("tcp", address, 2*time.Second); err != nil {
return false
if conn, err = net.DialTimeout("tcp", address, 5*time.Second); err != nil {
return false, err
}
defer func() {
_ = conn.Close()
}()
return true
return true, nil
}
func ImageToBase64Encode(img image.Image) (string, error) {
@ -144,6 +157,16 @@ func GetParentDirectory(directory string) string {
return filepath.Dir(directory)
}
func MkdirP(path string) error {
if !FileExists(path) {
if err := os.MkdirAll(path, os.ModePerm); err != nil {
return err
}
fmt.Printf("创建文件夹: %v \n", path)
}
return nil
}
// 去除重复元素
func Distinct(a []string) []string {
result := make([]string, 0, len(a))
@ -231,6 +254,19 @@ func Check(f func() error) {
}
}
func ParseNetReg(line string, reg *regexp.Regexp, shouldLen, index int) (int64, string, error) {
rx1 := reg.FindStringSubmatch(line)
if len(rx1) != shouldLen {
return 0, "", errors.New("find string length error")
}
i64, err := strconv.ParseInt(rx1[index], 10, 64)
total := rx1[2]
if err != nil {
return 0, "", errors2.Wrap(err, "ParseInt error")
}
return i64, total, nil
}
func PKCS5Padding(ciphertext []byte, blockSize int) []byte {
padding := blockSize - len(ciphertext)%blockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
@ -283,3 +319,167 @@ func Pbkdf2(password string) ([]byte, error) {
dk := pbkdf2.Key([]byte(password), salt, 1, 32, sha256.New)
return dk, nil
}
func DeCryptPassword(cryptPassword string, key []byte) (string, error) {
origData, err := base64.StdEncoding.DecodeString(cryptPassword)
if err != nil {
return "", err
}
decryptedCBC, err := AesDecryptCBC(origData, key)
if err != nil {
return "", err
}
return string(decryptedCBC), nil
}
func RegexpFindSubString(text string, reg *regexp.Regexp) (ret string, err error) {
findErr := errors.New("regexp find failed")
res := reg.FindStringSubmatch(text)
if len(res) != 2 {
return "", findErr
}
return res[1], nil
}
func String2int(s string) (int, error) {
i, err := strconv.Atoi(s)
if err != nil {
return 0, err
}
return i, nil
}
func RunCommand(client *ssh.Client, command string) (stdout string, err error) {
session, err := client.NewSession()
if err != nil {
return "", err
}
defer session.Close()
var buf bytes.Buffer
session.Stdout = &buf
err = session.Run(command)
if err != nil {
return "", err
}
stdout = buf.String()
return
}
func TimeWatcher(name string) {
start := time.Now()
defer func() {
cost := time.Since(start)
fmt.Printf("%s: %v\n", name, cost)
}()
}
func DirSize(path string) (int64, error) {
var size int64
err := filepath.Walk(path, func(_ string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() {
size += info.Size()
}
return err
})
return size, err
}
func Utf8ToGbk(s []byte) ([]byte, error) {
reader := transform.NewReader(bytes.NewReader(s), simplifiedchinese.GBK.NewEncoder())
d, e := ioutil.ReadAll(reader)
if e != nil {
return nil, e
}
return d, nil
}
// SignatureRSA rsa私钥签名
func SignatureRSA(plainText []byte, rsaPrivateKey string) (signed []byte, err error) {
// 使用pem对读取的内容解码得到block
block, _ := pem.Decode([]byte(rsaPrivateKey))
//x509将数据解析得到私钥结构体
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
// 创建一个hash对象
h := sha512.New()
_, _ = h.Write(plainText)
// 计算hash值
hashText := h.Sum(nil)
// 使用rsa函数对散列值签名
signed, err = rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA512, hashText)
if err != nil {
return
}
return signed, nil
}
// VerifyRSA rsa签名认证
func VerifyRSA(plainText, signText []byte, rsaPublicKey string) bool {
// pem解码得到block
block, _ := pem.Decode([]byte(rsaPublicKey))
// x509解析得到接口
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
return false
}
// 对原始明文进行hash运算得到散列值
hashText := sha512.Sum512(plainText)
// 签名认证
err = rsa.VerifyPKCS1v15(publicKey, crypto.SHA512, hashText[:], signText)
return err == nil
}
func GetMachineId() (string, error) {
return machineid.ID()
}
// GetAvailablePort 获取可用端口
func GetAvailablePort() (int, error) {
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
if err != nil {
return 0, err
}
l, err := net.ListenTCP("tcp", addr)
if err != nil {
return 0, err
}
defer func(l *net.TCPListener) {
_ = l.Close()
}(l)
return l.Addr().(*net.TCPAddr).Port, nil
}
func InsertSlice(index int, new []rune, src []rune) (ns []rune) {
ns = append(ns, src[:index]...)
ns = append(ns, new...)
ns = append(ns, src[index:]...)
return ns
}
func GetLocalIp() (string, error) {
addrs, err := net.InterfaceAddrs()
if err != nil {
return "", err
}
for _, address := range addrs {
// 检查ip地址判断是否回环地址
if ipNet, ok := address.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
return ipNet.IP.String(), nil
}
}
}
return "", errors.New("获取本机IP地址失败")
}