From 25b8381a4f41a2cdef7e35faec700d814e670ce1 Mon Sep 17 00:00:00 2001 From: dushixiang Date: Thu, 18 Mar 2021 23:36:25 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E9=87=8D=E6=9E=84=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=93=8D=E4=BD=9C=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.go | 256 +------------ server/api/account.go | 4 +- server/api/api.go | 75 ++++ server/api/asset.go | 46 +-- server/api/login-log.go | 10 +- server/api/routes.go | 181 ++++++---- server/api/security.go | 12 +- server/api/ticker.go | 62 ++++ server/config/db.go | 53 +++ server/constant/const.go | 8 + server/global/global.go | 3 - server/model/access_security.go | 70 ---- server/model/asset.go | 200 +---------- server/model/asset_attribute.go | 119 ------- server/model/command.go | 68 ---- server/model/credential.go | 94 ----- server/model/job.go | 336 +----------------- server/model/job_log.go | 30 -- server/model/login_log.go | 76 ---- server/model/num.go | 16 - server/model/property.go | 81 ----- server/model/resource-sharer.go | 185 ---------- server/model/session.go | 154 -------- server/model/user-attribute.go | 23 -- server/model/user-group-member.go | 18 - server/model/user-group.go | 127 +------ server/repository/access_security.go | 81 +++++ server/repository/asset.go | 302 ++++++++++++++++ server/repository/command.go | 82 +++++ server/repository/credential.go | 108 ++++++ server/repository/definitions.go | 17 + server/repository/job.go | 108 ++++++ server/repository/job_log.go | 28 ++ server/repository/login_log.go | 69 ++++ server/repository/num.go | 26 ++ server/repository/property.go | 90 +++++ server/repository/resource_sharer.go | 193 ++++++++++ server/repository/session.go | 167 +++++++++ server/repository/user.go | 5 + server/repository/user_group.go | 140 ++++++++ server/service/job.go | 324 +++++++++++++++++ .../{handle/runner.go => service/property.go} | 125 ++----- server/service/session.go | 32 ++ server/service/user.go | 104 ++++++ 44 files changed, 2292 insertions(+), 2016 deletions(-) create mode 100644 server/api/api.go create mode 100644 server/api/ticker.go create mode 100644 server/config/db.go delete mode 100644 server/model/asset_attribute.go delete mode 100644 server/model/job_log.go delete mode 100644 server/model/user-attribute.go delete mode 100644 server/model/user-group-member.go create mode 100644 server/repository/access_security.go create mode 100644 server/repository/asset.go create mode 100644 server/repository/command.go create mode 100644 server/repository/credential.go create mode 100644 server/repository/definitions.go create mode 100644 server/repository/job.go create mode 100644 server/repository/job_log.go create mode 100644 server/repository/login_log.go create mode 100644 server/repository/num.go create mode 100644 server/repository/property.go create mode 100644 server/repository/resource_sharer.go create mode 100644 server/repository/session.go create mode 100644 server/repository/user_group.go create mode 100644 server/service/job.go rename server/{handle/runner.go => service/property.go} (53%) create mode 100644 server/service/session.go create mode 100644 server/service/user.go diff --git a/main.go b/main.go index c7e4568..6231d76 100644 --- a/main.go +++ b/main.go @@ -3,37 +3,22 @@ package main import ( "bytes" "fmt" - "io" - "next-terminal/server/repository" - "os" - "strconv" - "strings" - "time" - - "next-terminal/server/api" - "next-terminal/server/config" - "next-terminal/server/constant" - "next-terminal/server/global" - "next-terminal/server/handle" - "next-terminal/server/model" - "next-terminal/server/utils" - nested "github.com/antonfisher/nested-logrus-formatter" "github.com/labstack/gommon/log" - "github.com/patrickmn/go-cache" "github.com/robfig/cron/v3" "github.com/sirupsen/logrus" - "gorm.io/driver/mysql" - "gorm.io/driver/sqlite" "gorm.io/gorm" - "gorm.io/gorm/logger" + "io" + "next-terminal/server/api" + "next-terminal/server/config" + "next-terminal/server/global" + "os" ) const Version = "v0.3.3" var ( - db *gorm.DB - userRepository repository.UserRepository + db *gorm.DB ) func main() { @@ -71,52 +56,25 @@ func Run() error { logrus.SetOutput(io.MultiWriter(writer1, writer2, writer3)) global.Config = config.SetupConfig() - db = SetupDB() + db = config.SetupDB() - // 初始化 repository - global.DB = db - userRepository = repository.UserRepository{DB: db} - - if global.Config.ResetPassword != "" { - return ResetPassword() - } - - if err := global.DB.AutoMigrate(&model.User{}, &model.Asset{}, &model.AssetAttribute{}, &model.Session{}, &model.Command{}, - &model.Credential{}, &model.Property{}, &model.ResourceSharer{}, &model.UserGroup{}, &model.UserGroupMember{}, - &model.LoginLog{}, &model.Num{}, &model.Job{}, &model.JobLog{}, &model.AccessSecurity{}); err != nil { - return err - } - - if err := InitDBData(); err != nil { - return err - } + //if global.Config.ResetPassword != "" { + // return ResetPassword() + //} if err := api.ReloadAccessSecurity(); err != nil { return err } - // 配置缓存器 - global.Cache = cache.New(5*time.Minute, 10*time.Minute) - global.Cache.OnEvicted(func(key string, value interface{}) { - if strings.HasPrefix(key, api.Token) { - token := api.GetTokenFormCacheKey(key) - logrus.Debugf("用户Token「%v」过期", token) - err := model.Logout(token) - if err != nil { - logrus.Errorf("退出登录失败 %v", err) - } - } - }) global.Store = global.NewStore() global.Cron = cron.New(cron.WithSeconds()) //精确到秒 global.Cron.Start() - e := api.SetupRoutes(userRepository) - if err := handle.InitProperties(); err != nil { - return err - } - // 启动定时任务 - go handle.RunTicker() + e := api.SetupRoutes(db) + global.Cache = api.SetupCache() + + api.SetupTicker() + go handle.RunDataFix() if global.Config.Server.Cert != "" && global.Config.Server.Key != "" { @@ -126,187 +84,3 @@ func Run() error { } } - -func InitDBData() (err error) { - users := userRepository.FindAll() - - if len(users) == 0 { - initPassword := "admin" - var pass []byte - if pass, err = utils.Encoder.Encode([]byte(initPassword)); err != nil { - return err - } - - user := model.User{ - ID: utils.UUID(), - Username: "admin", - Password: string(pass), - Nickname: "超级管理员", - Type: constant.TypeAdmin, - Created: utils.NowJsonTime(), - } - if err := userRepository.Create(&user); err != nil { - return err - } - logrus.Infof("初始用户创建成功,账号:「%v」密码:「%v」", user.Username, initPassword) - } else { - for i := range users { - // 修正默认用户类型为管理员 - if users[i].Type == "" { - user := model.User{ - Type: constant.TypeAdmin, - ID: users[i].ID, - } - if err := userRepository.Update(&user); err != nil { - return err - } - logrus.Infof("自动修正用户「%v」ID「%v」类型为管理员", users[i].Nickname, users[i].ID) - } - } - } - - if len(model.FindAllTemp()) == 0 { - for i := 0; i <= 30; i++ { - if err := model.CreateNewTemp(&model.Num{I: strconv.Itoa(i)}); err != nil { - return err - } - } - } - - jobs, err := model.FindJobByFunc(constant.FuncCheckAssetStatusJob) - if err != nil { - return err - } - if len(jobs) == 0 { - job := model.Job{ - ID: utils.UUID(), - Name: "资产状态检测", - Func: constant.FuncCheckAssetStatusJob, - Cron: "0 0 0/1 * * ?", - Mode: constant.JobModeAll, - Status: constant.JobStatusRunning, - Created: utils.NowJsonTime(), - Updated: utils.NowJsonTime(), - } - if err := model.CreateNewJob(&job); err != nil { - return err - } - logrus.Debugf("创建计划任务「%v」cron「%v」", job.Name, job.Cron) - } else { - for i := range jobs { - if jobs[i].Status == constant.JobStatusRunning { - err := model.ChangeJobStatusById(jobs[i].ID, constant.JobStatusRunning) - if err != nil { - return err - } - logrus.Debugf("启动计划任务「%v」cron「%v」", jobs[i].Name, jobs[i].Cron) - } - } - } - - loginLogs, err := model.FindAliveLoginLogs() - if err != nil { - return err - } - - for i := range loginLogs { - loginLog := loginLogs[i] - token := loginLog.ID - user, err := userRepository.FindById(loginLog.UserId) - if err != nil { - logrus.Debugf("用户「%v」获取失败,忽略", loginLog.UserId) - continue - } - - authorization := api.Authorization{ - Token: token, - Remember: loginLog.Remember, - User: user, - } - - cacheKey := api.BuildCacheKeyByToken(token) - - if authorization.Remember { - // 记住登录有效期两周 - global.Cache.Set(cacheKey, authorization, api.RememberEffectiveTime) - } else { - global.Cache.Set(cacheKey, authorization, api.NotRememberEffectiveTime) - } - logrus.Debugf("重新加载用户「%v」授权Token「%v」到缓存", user.Nickname, token) - } - - // 修正用户登录状态 - onlineUsers, err := userRepository.FindOnlineUsers() - if err != nil { - return err - } - for i := range onlineUsers { - logs, err := model.FindAliveLoginLogsByUserId(onlineUsers[i].ID) - if err != nil { - return err - } - if len(logs) == 0 { - if err := userRepository.UpdateOnline(onlineUsers[i].ID, false); err != nil { - return err - } - } - } - - return nil -} - -func ResetPassword() error { - user, err := userRepository.FindByUsername(global.Config.ResetPassword) - if err != nil { - return err - } - password := "next-terminal" - passwd, err := utils.Encoder.Encode([]byte(password)) - if err != nil { - return err - } - u := &model.User{ - Password: string(passwd), - ID: user.ID, - } - if err := userRepository.Update(u); err != nil { - return err - } - logrus.Debugf("用户「%v」密码初始化为: %v", user.Username, password) - return nil -} - -func SetupDB() *gorm.DB { - - var logMode logger.Interface - if global.Config.Debug { - logMode = logger.Default.LogMode(logger.Info) - } else { - logMode = logger.Default.LogMode(logger.Silent) - } - - fmt.Printf("当前数据库模式为:%v\n", global.Config.DB) - var err error - var db *gorm.DB - if global.Config.DB == "mysql" { - dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", - global.Config.Mysql.Username, - global.Config.Mysql.Password, - global.Config.Mysql.Hostname, - global.Config.Mysql.Port, - global.Config.Mysql.Database, - ) - db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ - Logger: logMode, - }) - } else { - db, err = gorm.Open(sqlite.Open(global.Config.Sqlite.File), &gorm.Config{ - Logger: logMode, - }) - } - - if err != nil { - logrus.WithError(err).Panic("连接数据库异常") - } - return db -} diff --git a/server/api/account.go b/server/api/account.go index 3836d9c..237faa6 100644 --- a/server/api/account.go +++ b/server/api/account.go @@ -111,7 +111,7 @@ func LoginSuccess(c echo.Context, loginAccount LoginAccount, user model.User) (t Remember: authorization.Remember, } - if model.CreateNewLoginLog(&loginLog) != nil { + if loginLogRepository.Create(&loginLog) != nil { return "", err } @@ -179,7 +179,7 @@ func LogoutEndpoint(c echo.Context) error { token := GetToken(c) cacheKey := BuildCacheKeyByToken(token) global.Cache.Delete(cacheKey) - err := model.Logout(token) + err := userService.Logout(token) if err != nil { return err } diff --git a/server/api/api.go b/server/api/api.go new file mode 100644 index 0000000..0c4db84 --- /dev/null +++ b/server/api/api.go @@ -0,0 +1,75 @@ +package api + +import ( + "github.com/labstack/echo/v4" + "next-terminal/server/constant" + "next-terminal/server/global" + "next-terminal/server/model" +) + +type H map[string]interface{} + +func Fail(c echo.Context, code int, message string) error { + return c.JSON(200, H{ + "code": code, + "message": message, + }) +} + +func FailWithData(c echo.Context, code int, message string, data interface{}) error { + return c.JSON(200, H{ + "code": code, + "message": message, + "data": data, + }) +} + +func Success(c echo.Context, data interface{}) error { + return c.JSON(200, H{ + "code": 1, + "message": "success", + "data": data, + }) +} + +func NotFound(c echo.Context, message string) error { + return c.JSON(200, H{ + "code": -1, + "message": message, + }) +} + +func GetToken(c echo.Context) string { + token := c.Request().Header.Get(Token) + if len(token) > 0 { + return token + } + return c.QueryParam(Token) +} + +func GetCurrentAccount(c echo.Context) (model.User, bool) { + token := GetToken(c) + cacheKey := BuildCacheKeyByToken(token) + get, b := global.Cache.Get(cacheKey) + if b { + return get.(Authorization).User, true + } + return model.User{}, false +} + +func HasPermission(c echo.Context, owner string) bool { + // 检测是否登录 + account, found := GetCurrentAccount(c) + if !found { + return false + } + // 检测是否为管理人员 + if constant.TypeAdmin == account.Type { + return true + } + // 检测是否为所有者 + if owner == account.ID { + return true + } + return false +} diff --git a/server/api/asset.go b/server/api/asset.go index f160cfe..56ff2d2 100644 --- a/server/api/asset.go +++ b/server/api/asset.go @@ -32,18 +32,18 @@ func AssetCreateEndpoint(c echo.Context) error { item.ID = utils.UUID() item.Created = utils.NowJsonTime() - if err := model.CreateNewAsset(&item); err != nil { + if err := assetRepository.Create(&item); err != nil { return err } - if err := model.UpdateAssetAttributes(item.ID, item.Protocol, m); err != nil { + if err := assetRepository.UpdateAttributes(item.ID, item.Protocol, m); err != nil { return err } // 创建后自动检测资产是否存活 go func() { active := utils.Tcping(item.IP, item.Port) - model.UpdateAssetActiveById(active, item.ID) + assetRepository.UpdateActiveById(active, item.ID) }() return Success(c, item) @@ -98,7 +98,7 @@ func AssetImportEndpoint(c echo.Context) error { Owner: account.ID, } - err := model.CreateNewAsset(&asset) + err := assetRepository.Create(&asset) if err != nil { errorCount++ m[strconv.Itoa(i)] = err.Error() @@ -107,7 +107,7 @@ func AssetImportEndpoint(c echo.Context) error { // 创建后自动检测资产是否存活 go func() { active := utils.Tcping(asset.IP, asset.Port) - model.UpdateAssetActiveById(active, asset.ID) + assetRepository.UpdateActiveById(active, asset.ID) }() } } @@ -135,7 +135,7 @@ func AssetPagingEndpoint(c echo.Context) error { field := c.QueryParam("field") account, _ := GetCurrentAccount(c) - items, total, err := model.FindPageAsset(pageIndex, pageSize, name, protocol, tags, account, owner, sharer, userGroupId, ip, order, field) + items, total, err := assetRepository.Find(pageIndex, pageSize, name, protocol, tags, account, owner, sharer, userGroupId, ip, order, field) if err != nil { return err } @@ -149,7 +149,7 @@ func AssetPagingEndpoint(c echo.Context) error { func AssetAllEndpoint(c echo.Context) error { protocol := c.QueryParam("protocol") account, _ := GetCurrentAccount(c) - items, _ := model.FindAssetByConditions(protocol, account) + items, _ := assetRepository.FindByProtocolAndUser(protocol, account) return Success(c, items) } @@ -199,8 +199,10 @@ func AssetUpdateEndpoint(c echo.Context) error { item.Description = "-" } - model.UpdateAssetById(&item, id) - if err := model.UpdateAssetAttributes(id, item.Protocol, m); err != nil { + if err := assetRepository.UpdateById(&item, id); err != nil { + return err + } + if err := assetRepository.UpdateAttributes(id, item.Protocol, m); err != nil { return err } @@ -214,7 +216,7 @@ func AssetGetAttributeEndpoint(c echo.Context) error { return err } - attributeMap, err := model.FindAssetAttrMapByAssetId(assetId) + attributeMap, err := assetRepository.FindAssetAttrMapByAssetId(assetId) if err != nil { return err } @@ -229,7 +231,7 @@ func AssetUpdateAttributeEndpoint(c echo.Context) error { assetId := c.Param("id") protocol := c.QueryParam("protocol") - err := model.UpdateAssetAttributes(assetId, protocol, m) + err := assetRepository.UpdateAttributes(assetId, protocol, m) if err != nil { return err } @@ -243,11 +245,11 @@ func AssetDeleteEndpoint(c echo.Context) error { if err := PreCheckAssetPermission(c, split[i]); err != nil { return err } - if err := model.DeleteAssetById(split[i]); err != nil { + if err := assetRepository.DeleteById(split[i]); err != nil { return err } // 删除资产与用户的关系 - if err := model.DeleteResourceSharerByResourceId(split[i]); err != nil { + if err := resourceSharerRepository.DeleteResourceSharerByResourceId(split[i]); err != nil { return err } } @@ -262,10 +264,10 @@ func AssetGetEndpoint(c echo.Context) (err error) { } var item model.Asset - if item, err = model.FindAssetById(id); err != nil { + if item, err = assetRepository.FindById(id); err != nil { return err } - attributeMap, err := model.FindAssetAttrMapByAssetId(id) + attributeMap, err := assetRepository.FindAssetAttrMapByAssetId(id) if err != nil { return err } @@ -281,19 +283,21 @@ func AssetTcpingEndpoint(c echo.Context) (err error) { id := c.Param("id") var item model.Asset - if item, err = model.FindAssetById(id); err != nil { + if item, err = assetRepository.FindById(id); err != nil { return err } active := utils.Tcping(item.IP, item.Port) - model.UpdateAssetActiveById(active, item.ID) + if err := assetRepository.UpdateActiveById(active, item.ID); err != nil { + return err + } return Success(c, active) } func AssetTagsEndpoint(c echo.Context) (err error) { var items []string - if items, err = model.FindAssetTags(); err != nil { + if items, err = assetRepository.FindTags(); err != nil { return err } return Success(c, items) @@ -307,12 +311,14 @@ func AssetChangeOwnerEndpoint(c echo.Context) (err error) { } owner := c.QueryParam("owner") - model.UpdateAssetById(&model.Asset{Owner: owner}, id) + if err := assetRepository.UpdateById(&model.Asset{Owner: owner}, id); err != nil { + return err + } return Success(c, "") } func PreCheckAssetPermission(c echo.Context, id string) error { - item, err := model.FindAssetById(id) + item, err := assetRepository.FindById(id) if err != nil { return err } diff --git a/server/api/login-log.go b/server/api/login-log.go index 4c278ff..ec3e6df 100644 --- a/server/api/login-log.go +++ b/server/api/login-log.go @@ -4,11 +4,9 @@ import ( "strconv" "strings" - "next-terminal/server/global" - "next-terminal/server/model" - "github.com/labstack/echo/v4" "github.com/sirupsen/logrus" + "next-terminal/server/global" ) func LoginLogPagingEndpoint(c echo.Context) error { @@ -17,7 +15,7 @@ func LoginLogPagingEndpoint(c echo.Context) error { userId := c.QueryParam("userId") clientIp := c.QueryParam("clientIp") - items, total, err := model.FindPageLoginLog(pageIndex, pageSize, userId, clientIp) + items, total, err := loginLogRepository.Find(pageIndex, pageSize, userId, clientIp) if err != nil { return err @@ -35,11 +33,11 @@ func LoginLogDeleteEndpoint(c echo.Context) error { for i := range split { token := split[i] global.Cache.Delete(token) - if err := model.Logout(token); err != nil { + if err := userService.Logout(token); err != nil { logrus.WithError(err).Error("Cache Delete Failed") } } - if err := model.DeleteLoginLogByIdIn(split); err != nil { + if err := loginLogRepository.DeleteByIdIn(split); err != nil { return err } diff --git a/server/api/routes.go b/server/api/routes.go index 5249ffa..e907648 100644 --- a/server/api/routes.go +++ b/server/api/routes.go @@ -1,12 +1,19 @@ package api import ( + "github.com/patrickmn/go-cache" + "github.com/sirupsen/logrus" + "gorm.io/gorm" "net/http" - "next-terminal/server/constant" "next-terminal/server/global" "next-terminal/server/log" "next-terminal/server/model" "next-terminal/server/repository" + "next-terminal/server/service" + "next-terminal/server/utils" + "strconv" + "strings" + "time" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" @@ -15,11 +22,32 @@ import ( const Token = "X-Auth-Token" var ( - userRepository repository.UserRepository + userRepository *repository.UserRepository + userGroupRepository *repository.UserGroupRepository + resourceSharerRepository *repository.ResourceSharerRepository + assetRepository *repository.AssetRepository + credentialRepository *repository.CredentialRepository + propertyRepository *repository.PropertyRepository + commandRepository *repository.CommandRepository + sessionRepository *repository.SessionRepository + numRepository *repository.NumRepository + accessSecurityRepository *repository.AccessSecurityRepository + jobRepository *repository.JobRepository + jobLogRepository *repository.JobLogRepository + loginLogRepository *repository.LoginLogRepository + + jobService *service.JobService + propertyService *service.PropertyService + userService *service.UserService + sessionService *service.SessionService ) -func SetupRoutes(ur repository.UserRepository) *echo.Echo { - userRepository = ur +func SetupRoutes(db *gorm.DB) *echo.Echo { + + InitRepository(db) + InitService() + + InitDBData() e := echo.New() e.HideBanner = true @@ -178,69 +206,88 @@ func SetupRoutes(ur repository.UserRepository) *echo.Echo { return e } -type H map[string]interface{} +func InitRepository(db *gorm.DB) { + userRepository = repository.NewUserRepository(db) + userGroupRepository = repository.NewUserGroupRepository(db) + resourceSharerRepository = repository.NewResourceSharerRepository(db) + assetRepository = repository.NewAssetRepository(db) + credentialRepository = repository.NewCredentialRepository(db) + propertyRepository = repository.NewPropertyRepository(db) + commandRepository = repository.NewCommandRepository(db) + sessionRepository = repository.NewSessionRepository(db) + numRepository = repository.NewNumRepository(db) + accessSecurityRepository = repository.NewAccessSecurityRepository(db) + jobRepository = repository.NewJobRepository(db) + jobLogRepository = repository.NewJobLogRepository(db) + loginLogRepository = repository.NewLoginLogRepository(db) +} -func Fail(c echo.Context, code int, message string) error { - return c.JSON(200, H{ - "code": code, - "message": message, +func InitService() { + jobService = service.NewJobService(jobRepository, jobLogRepository, assetRepository, credentialRepository) + propertyService = service.NewPropertyService(propertyRepository) + userService = service.NewUserService(userRepository) + sessionService = service.NewSessionService(sessionRepository) +} + +func InitDBData() (err error) { + if err := propertyService.InitProperties(); err != nil { + return err + } + if err := userService.InitUser(); err != nil { + return err + } + if err := userService.FixedOnlineState(); err != nil { + return err + } + if err := jobService.InitJob(); err != nil { + return err + } + + sessionService.Fix() + nums, _ := numRepository.FindAll() + if nums == nil || len(nums) == 0 { + for i := 0; i <= 30; i++ { + if err := numRepository.Create(&model.Num{I: strconv.Itoa(i)}); err != nil { + return err + } + } + } + return nil +} + +func ResetPassword() error { + user, err := userRepository.FindByUsername(global.Config.ResetPassword) + if err != nil { + return err + } + password := "next-terminal" + passwd, err := utils.Encoder.Encode([]byte(password)) + if err != nil { + return err + } + u := &model.User{ + Password: string(passwd), + ID: user.ID, + } + if err := userRepository.Update(u); err != nil { + return err + } + logrus.Debugf("用户「%v」密码初始化为: %v", user.Username, password) + return nil +} + +func SetupCache() *cache.Cache { + // 配置缓存器 + mCache := cache.New(5*time.Minute, 10*time.Minute) + mCache.OnEvicted(func(key string, value interface{}) { + if strings.HasPrefix(key, Token) { + token := GetTokenFormCacheKey(key) + logrus.Debugf("用户Token「%v」过期", token) + err := userService.Logout(token) + if err != nil { + logrus.Errorf("退出登录失败 %v", err) + } + } }) -} - -func FailWithData(c echo.Context, code int, message string, data interface{}) error { - return c.JSON(200, H{ - "code": code, - "message": message, - "data": data, - }) -} - -func Success(c echo.Context, data interface{}) error { - return c.JSON(200, H{ - "code": 1, - "message": "success", - "data": data, - }) -} - -func NotFound(c echo.Context, message string) error { - return c.JSON(200, H{ - "code": -1, - "message": message, - }) -} - -func GetToken(c echo.Context) string { - token := c.Request().Header.Get(Token) - if len(token) > 0 { - return token - } - return c.QueryParam(Token) -} - -func GetCurrentAccount(c echo.Context) (model.User, bool) { - token := GetToken(c) - cacheKey := BuildCacheKeyByToken(token) - get, b := global.Cache.Get(cacheKey) - if b { - return get.(Authorization).User, true - } - return model.User{}, false -} - -func HasPermission(c echo.Context, owner string) bool { - // 检测是否登录 - account, found := GetCurrentAccount(c) - if !found { - return false - } - // 检测是否为管理人员 - if constant.TypeAdmin == account.Type { - return true - } - // 检测是否为所有者 - if owner == account.ID { - return true - } - return false + return mCache } diff --git a/server/api/security.go b/server/api/security.go index fd80b29..a879713 100644 --- a/server/api/security.go +++ b/server/api/security.go @@ -20,7 +20,7 @@ func SecurityCreateEndpoint(c echo.Context) error { item.ID = utils.UUID() item.Source = "管理员添加" - if err := model.CreateNewSecurity(&item); err != nil { + if err := accessSecurityRepository.Create(&item); err != nil { return err } // 更新内存中的安全规则 @@ -31,7 +31,7 @@ func SecurityCreateEndpoint(c echo.Context) error { } func ReloadAccessSecurity() error { - rules, err := model.FindAllAccessSecurities() + rules, err := accessSecurityRepository.FindAllAccessSecurities() if err != nil { return err } @@ -58,7 +58,7 @@ func SecurityPagingEndpoint(c echo.Context) error { order := c.QueryParam("order") field := c.QueryParam("field") - items, total, err := model.FindPageSecurity(pageIndex, pageSize, ip, rule, order, field) + items, total, err := accessSecurityRepository.Find(pageIndex, pageSize, ip, rule, order, field) if err != nil { return err } @@ -77,7 +77,7 @@ func SecurityUpdateEndpoint(c echo.Context) error { return err } - if err := model.UpdateSecurityById(&item, id); err != nil { + if err := accessSecurityRepository.UpdateById(&item, id); err != nil { return err } // 更新内存中的安全规则 @@ -93,7 +93,7 @@ func SecurityDeleteEndpoint(c echo.Context) error { split := strings.Split(ids, ",") for i := range split { jobId := split[i] - if err := model.DeleteSecurityById(jobId); err != nil { + if err := accessSecurityRepository.DeleteById(jobId); err != nil { return err } } @@ -107,7 +107,7 @@ func SecurityDeleteEndpoint(c echo.Context) error { func SecurityGetEndpoint(c echo.Context) error { id := c.Param("id") - item, err := model.FindSecurityById(id) + item, err := accessSecurityRepository.FindById(id) if err != nil { return err } diff --git a/server/api/ticker.go b/server/api/ticker.go new file mode 100644 index 0000000..fa9d2f8 --- /dev/null +++ b/server/api/ticker.go @@ -0,0 +1,62 @@ +package api + +import ( + "github.com/sirupsen/logrus" + "next-terminal/server/constant" + "strconv" + "time" +) + +func SetupTicker() { + + // 每隔一小时删除一次未使用的会话信息 + unUsedSessionTicker := time.NewTicker(time.Minute * 60) + go func() { + for range unUsedSessionTicker.C { + sessions, _ := sessionRepository.FindByStatusIn([]string{constant.NoConnect, constant.Connecting}) + if len(sessions) > 0 { + now := time.Now() + for i := range sessions { + if now.Sub(sessions[i].ConnectedTime.Time) > time.Hour*1 { + _ = sessionRepository.DeleteById(sessions[i].ID) + s := sessions[i].Username + "@" + sessions[i].IP + ":" + strconv.Itoa(sessions[i].Port) + logrus.Infof("会话「%v」ID「%v」超过1小时未打开,已删除。", s, sessions[i].ID) + } + } + } + } + }() + + // 每日凌晨删除超过时长限制的会话 + timeoutSessionTicker := time.NewTicker(time.Hour * 24) + go func() { + for range timeoutSessionTicker.C { + property, err := propertyRepository.FindByName("session-saved-limit") + if err != nil { + return + } + if property.Value == "" || property.Value == "-" { + return + } + limit, err := strconv.Atoi(property.Value) + if err != nil { + return + } + sessions, err := sessionRepository.FindOutTimeSessions(limit) + if err != nil { + return + } + + if len(sessions) > 0 { + var sessionIds []string + for i := range sessions { + sessionIds = append(sessionIds, sessions[i].ID) + } + err := sessionRepository.DeleteByIds(sessionIds) + if err != nil { + logrus.Errorf("删除离线会话失败 %v", err) + } + } + } + }() +} diff --git a/server/config/db.go b/server/config/db.go new file mode 100644 index 0000000..3ef9428 --- /dev/null +++ b/server/config/db.go @@ -0,0 +1,53 @@ +package config + +import ( + "fmt" + "github.com/sirupsen/logrus" + "gorm.io/driver/mysql" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "next-terminal/server/global" + "next-terminal/server/model" +) + +func SetupDB() *gorm.DB { + + var logMode logger.Interface + if global.Config.Debug { + logMode = logger.Default.LogMode(logger.Info) + } else { + logMode = logger.Default.LogMode(logger.Silent) + } + + fmt.Printf("当前数据库模式为:%v\n", global.Config.DB) + var err error + var db *gorm.DB + if global.Config.DB == "mysql" { + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", + global.Config.Mysql.Username, + global.Config.Mysql.Password, + global.Config.Mysql.Hostname, + global.Config.Mysql.Port, + global.Config.Mysql.Database, + ) + db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ + Logger: logMode, + }) + } else { + db, err = gorm.Open(sqlite.Open(global.Config.Sqlite.File), &gorm.Config{ + Logger: logMode, + }) + } + + if err != nil { + logrus.WithError(err).Panic("连接数据库异常") + } + + if err := db.AutoMigrate(&model.User{}, &model.Asset{}, &model.AssetAttribute{}, &model.Session{}, &model.Command{}, + &model.Credential{}, &model.Property{}, &model.ResourceSharer{}, &model.UserGroup{}, &model.UserGroupMember{}, + &model.LoginLog{}, &model.Num{}, &model.Job{}, &model.JobLog{}, &model.AccessSecurity{}); err != nil { + logrus.WithError(err).Panic("初始化数据库表结构异常") + } + return db +} diff --git a/server/constant/const.go b/server/constant/const.go index cdba074..0865990 100644 --- a/server/constant/const.go +++ b/server/constant/const.go @@ -1,5 +1,7 @@ package constant +import "next-terminal/server/guacd" + const ( AccessRuleAllow = "allow" // 允许访问 AccessRuleReject = "reject" // 拒绝访问 @@ -31,3 +33,9 @@ const ( TypeUser = "user" // 普通用户 TypeAdmin = "admin" // 管理员 ) + +var SSHParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, SshMode} +var RDPParameterNames = []string{guacd.Domain, guacd.RemoteApp, guacd.RemoteAppDir, guacd.RemoteAppArgs} +var VNCParameterNames = []string{guacd.ColorDepth, guacd.Cursor, guacd.SwapRedBlue, guacd.DestHost, guacd.DestPort} +var TelnetParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, guacd.UsernameRegex, guacd.PasswordRegex, guacd.LoginSuccessRegex, guacd.LoginFailureRegex} +var KubernetesParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, guacd.Namespace, guacd.Pod, guacd.Container, guacd.UesSSL, guacd.ClientCert, guacd.ClientKey, guacd.CaCert, guacd.IgnoreCert} diff --git a/server/global/global.go b/server/global/global.go index e4ea8bb..1903e6f 100644 --- a/server/global/global.go +++ b/server/global/global.go @@ -5,11 +5,8 @@ import ( "github.com/patrickmn/go-cache" "github.com/robfig/cron/v3" - "gorm.io/gorm" ) -var DB *gorm.DB - var Cache *cache.Cache var Config *config.Config diff --git a/server/model/access_security.go b/server/model/access_security.go index db351cb..c111d12 100644 --- a/server/model/access_security.go +++ b/server/model/access_security.go @@ -1,9 +1,5 @@ package model -import ( - "next-terminal/server/global" -) - type AccessSecurity struct { ID string `json:"id"` Rule string `json:"rule"` @@ -15,69 +11,3 @@ type AccessSecurity struct { func (r *AccessSecurity) TableName() string { return "access_securities" } - -func FindAllAccessSecurities() (o []AccessSecurity, err error) { - db := global.DB - err = db.Order("priority asc").Find(&o).Error - return -} - -func FindPageSecurity(pageIndex, pageSize int, ip, rule, order, field string) (o []AccessSecurity, total int64, err error) { - t := AccessSecurity{} - db := global.DB.Table(t.TableName()) - dbCounter := global.DB.Table(t.TableName()) - - if len(ip) > 0 { - db = db.Where("ip like ?", "%"+ip+"%") - dbCounter = dbCounter.Where("ip like ?", "%"+ip+"%") - } - - if len(rule) > 0 { - db = db.Where("rule = ?", rule) - dbCounter = dbCounter.Where("rule = ?", rule) - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "descend" { - order = "desc" - } else { - order = "asc" - } - - if field == "ip" { - field = "ip" - } else if field == "rule" { - field = "rule" - } else { - field = "priority" - } - - err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error - if o == nil { - o = make([]AccessSecurity, 0) - } - return -} - -func CreateNewSecurity(o *AccessSecurity) error { - return global.DB.Create(o).Error -} - -func UpdateSecurityById(o *AccessSecurity, id string) error { - o.ID = id - return global.DB.Updates(o).Error -} - -func DeleteSecurityById(id string) error { - - return global.DB.Where("id = ?", id).Delete(AccessSecurity{}).Error -} - -func FindSecurityById(id string) (o *AccessSecurity, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} diff --git a/server/model/asset.go b/server/model/asset.go index 6e37ca1..8b8a7b3 100644 --- a/server/model/asset.go +++ b/server/model/asset.go @@ -1,10 +1,6 @@ package model import ( - "strings" - - "next-terminal/server/constant" - "next-terminal/server/global" "next-terminal/server/utils" ) @@ -45,195 +41,13 @@ func (r *Asset) TableName() string { return "assets" } -func FindAllAsset() (o []Asset, err error) { - err = global.DB.Find(&o).Error - return +type AssetAttribute struct { + Id string `gorm:"index" json:"id"` + AssetId string `gorm:"index" json:"assetId"` + Name string `gorm:"index" json:"name"` + Value string `json:"value"` } -func FindAssetByIds(assetIds []string) (o []Asset, err error) { - err = global.DB.Where("id in ?", assetIds).Find(&o).Error - return -} - -func FindAssetByProtocol(protocol string) (o []Asset, err error) { - err = global.DB.Where("protocol = ?", protocol).Find(&o).Error - return -} - -func FindAssetByProtocolAndIds(protocol string, assetIds []string) (o []Asset, err error) { - err = global.DB.Where("protocol = ? and id in ?", protocol, assetIds).Find(&o).Error - return -} - -func FindAssetByConditions(protocol string, account User) (o []Asset, err error) { - db := global.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") - - if constant.TypeUser == account.Type { - owner := account.ID - db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner) - } - - if len(protocol) > 0 { - db = db.Where("assets.protocol = ?", protocol) - } - err = db.Find(&o).Error - return -} - -func FindPageAsset(pageIndex, pageSize int, name, protocol, tags string, account User, owner, sharer, userGroupId, ip, order, field string) (o []AssetVo, total int64, err error) { - db := global.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") - dbCounter := global.DB.Table("assets").Select("DISTINCT assets.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") - - if constant.TypeUser == account.Type { - owner := account.ID - db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner) - dbCounter = dbCounter.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner) - - // 查询用户所在用户组列表 - userGroupIds, err := FindUserGroupIdsByUserId(account.ID) - if err != nil { - return nil, 0, err - } - - if len(userGroupIds) > 0 { - db = db.Or("resource_sharers.user_group_id in ?", userGroupIds) - dbCounter = dbCounter.Or("resource_sharers.user_group_id in ?", userGroupIds) - } - } else { - if len(owner) > 0 { - db = db.Where("assets.owner = ?", owner) - dbCounter = dbCounter.Where("assets.owner = ?", owner) - } - if len(sharer) > 0 { - db = db.Where("resource_sharers.user_id = ?", sharer) - dbCounter = dbCounter.Where("resource_sharers.user_id = ?", sharer) - } - - if len(userGroupId) > 0 { - db = db.Where("resource_sharers.user_group_id = ?", userGroupId) - dbCounter = dbCounter.Where("resource_sharers.user_group_id = ?", userGroupId) - } - } - - if len(name) > 0 { - db = db.Where("assets.name like ?", "%"+name+"%") - dbCounter = dbCounter.Where("assets.name like ?", "%"+name+"%") - } - - if len(ip) > 0 { - db = db.Where("assets.ip like ?", "%"+ip+"%") - dbCounter = dbCounter.Where("assets.ip like ?", "%"+ip+"%") - } - - if len(protocol) > 0 { - db = db.Where("assets.protocol = ?", protocol) - dbCounter = dbCounter.Where("assets.protocol = ?", protocol) - } - - if len(tags) > 0 { - tagArr := strings.Split(tags, ",") - for i := range tagArr { - if global.Config.DB == "sqlite" { - db = db.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%") - dbCounter = dbCounter.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%") - } else { - db = db.Where("find_in_set(?, assets.tags)", tagArr[i]) - dbCounter = dbCounter.Where("find_in_set(?, assets.tags)", tagArr[i]) - } - } - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "ascend" { - order = "asc" - } else { - order = "desc" - } - - if field == "name" { - field = "name" - } else { - field = "created" - } - - err = db.Order("assets." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error - - if o == nil { - o = make([]AssetVo, 0) - } - return -} - -func CreateNewAsset(o *Asset) (err error) { - if err = global.DB.Create(o).Error; err != nil { - return err - } - return nil -} - -func FindAssetById(id string) (o Asset, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func UpdateAssetById(o *Asset, id string) { - o.ID = id - global.DB.Updates(o) -} - -func UpdateAssetActiveById(active bool, id string) { - sql := "update assets set active = ? where id = ?" - global.DB.Exec(sql, active, id) -} - -func DeleteAssetById(id string) error { - return global.DB.Where("id = ?", id).Delete(&Asset{}).Error -} - -func CountAsset() (total int64, err error) { - err = global.DB.Find(&Asset{}).Count(&total).Error - return -} - -func CountAssetByUserId(userId string) (total int64, err error) { - db := global.DB.Joins("left join resource_sharers on assets.id = resource_sharers.resource_id") - - db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", userId, userId) - - // 查询用户所在用户组列表 - userGroupIds, err := FindUserGroupIdsByUserId(userId) - if err != nil { - return 0, err - } - - if len(userGroupIds) > 0 { - db = db.Or("resource_sharers.user_group_id in ?", userGroupIds) - } - err = db.Find(&Asset{}).Count(&total).Error - return -} - -func FindAssetTags() (o []string, err error) { - var assets []Asset - err = global.DB.Not("tags = ?", "").Find(&assets).Error - if err != nil { - return nil, err - } - - o = make([]string, 0) - - for i := range assets { - if len(assets[i].Tags) == 0 { - continue - } - split := strings.Split(assets[i].Tags, ",") - - o = append(o, split...) - } - - return utils.Distinct(o), nil +func (r *AssetAttribute) TableName() string { + return "asset_attributes" } diff --git a/server/model/asset_attribute.go b/server/model/asset_attribute.go deleted file mode 100644 index 21db88c..0000000 --- a/server/model/asset_attribute.go +++ /dev/null @@ -1,119 +0,0 @@ -package model - -import ( - "fmt" - - "next-terminal/server/constant" - "next-terminal/server/global" - "next-terminal/server/guacd" - "next-terminal/server/utils" - - "github.com/labstack/echo/v4" - "gorm.io/gorm" -) - -type AssetAttribute struct { - Id string `gorm:"index" json:"id"` - AssetId string `gorm:"index" json:"assetId"` - Name string `gorm:"index" json:"name"` - Value string `json:"value"` -} - -func (r *AssetAttribute) TableName() string { - return "asset_attributes" -} - -var SSHParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, constant.SshMode} -var RDPParameterNames = []string{guacd.Domain, guacd.RemoteApp, guacd.RemoteAppDir, guacd.RemoteAppArgs} -var VNCParameterNames = []string{guacd.ColorDepth, guacd.Cursor, guacd.SwapRedBlue, guacd.DestHost, guacd.DestPort} -var TelnetParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, guacd.UsernameRegex, guacd.PasswordRegex, guacd.LoginSuccessRegex, guacd.LoginFailureRegex} -var KubernetesParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, guacd.Namespace, guacd.Pod, guacd.Container, guacd.UesSSL, guacd.ClientCert, guacd.ClientKey, guacd.CaCert, guacd.IgnoreCert} - -func UpdateAssetAttributes(assetId, protocol string, m echo.Map) error { - var data []AssetAttribute - var parameterNames []string - switch protocol { - case "ssh": - parameterNames = SSHParameterNames - case "rdp": - parameterNames = RDPParameterNames - case "vnc": - parameterNames = VNCParameterNames - case "telnet": - parameterNames = TelnetParameterNames - case "kubernetes": - parameterNames = KubernetesParameterNames - - } - - for i := range parameterNames { - name := parameterNames[i] - if m[name] != nil && m[name] != "" { - data = append(data, genAttribute(assetId, name, m)) - } - } - - return global.DB.Transaction(func(tx *gorm.DB) error { - err := tx.Where("asset_id = ?", assetId).Delete(&AssetAttribute{}).Error - if err != nil { - return err - } - return tx.CreateInBatches(&data, len(data)).Error - }) -} - -func genAttribute(assetId, name string, m echo.Map) AssetAttribute { - value := fmt.Sprintf("%v", m[name]) - attribute := AssetAttribute{ - Id: utils.Sign([]string{assetId, name}), - AssetId: assetId, - Name: name, - Value: value, - } - return attribute -} - -func FindAssetAttributeByAssetId(assetId string) (o []AssetAttribute, err error) { - err = global.DB.Where("asset_id = ?", assetId).Find(&o).Error - if o == nil { - o = make([]AssetAttribute, 0) - } - return o, err -} - -func FindAssetAttrMapByAssetId(assetId string) (map[string]interface{}, error) { - asset, err := FindAssetById(assetId) - if err != nil { - return nil, err - } - attributes, err := FindAssetAttributeByAssetId(assetId) - if err != nil { - return nil, err - } - - var parameterNames []string - switch asset.Protocol { - case "ssh": - parameterNames = SSHParameterNames - case "rdp": - parameterNames = RDPParameterNames - case "vnc": - parameterNames = VNCParameterNames - case "telnet": - parameterNames = TelnetParameterNames - case "kubernetes": - parameterNames = KubernetesParameterNames - } - propertiesMap := FindAllPropertiesMap() - var attributeMap = make(map[string]interface{}) - for name := range propertiesMap { - if utils.Contains(parameterNames, name) { - attributeMap[name] = propertiesMap[name] - } - } - - for i := range attributes { - attributeMap[attributes[i].Name] = attributes[i].Value - } - return attributeMap, nil -} diff --git a/server/model/command.go b/server/model/command.go index 22d82f6..b085ea3 100644 --- a/server/model/command.go +++ b/server/model/command.go @@ -1,8 +1,6 @@ package model import ( - "next-terminal/server/constant" - "next-terminal/server/global" "next-terminal/server/utils" ) @@ -27,69 +25,3 @@ type CommandVo struct { func (r *Command) TableName() string { return "commands" } - -func FindPageCommand(pageIndex, pageSize int, name, content, order, field string, account User) (o []CommandVo, total int64, err error) { - - db := global.DB.Table("commands").Select("commands.id,commands.name,commands.content,commands.owner,commands.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on commands.owner = users.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id") - dbCounter := global.DB.Table("commands").Select("DISTINCT commands.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id") - - if constant.TypeUser == account.Type { - owner := account.ID - db = db.Where("commands.owner = ? or resource_sharers.user_id = ?", owner, owner) - dbCounter = dbCounter.Where("commands.owner = ? or resource_sharers.user_id = ?", owner, owner) - } - - if len(name) > 0 { - db = db.Where("commands.name like ?", "%"+name+"%") - dbCounter = dbCounter.Where("commands.name like ?", "%"+name+"%") - } - - if len(content) > 0 { - db = db.Where("commands.content like ?", "%"+content+"%") - dbCounter = dbCounter.Where("commands.content like ?", "%"+content+"%") - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "ascend" { - order = "asc" - } else { - order = "desc" - } - - if field == "name" { - field = "name" - } else { - field = "created" - } - - err = db.Order("commands." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error - if o == nil { - o = make([]CommandVo, 0) - } - return -} - -func CreateNewCommand(o *Command) (err error) { - if err = global.DB.Create(o).Error; err != nil { - return err - } - return nil -} - -func FindCommandById(id string) (o Command, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func UpdateCommandById(o *Command, id string) { - o.ID = id - global.DB.Updates(o) -} - -func DeleteCommandById(id string) error { - return global.DB.Where("id = ?", id).Delete(&Command{}).Error -} diff --git a/server/model/credential.go b/server/model/credential.go index 4a5570a..ddebdb0 100644 --- a/server/model/credential.go +++ b/server/model/credential.go @@ -1,8 +1,6 @@ package model import ( - "next-terminal/server/constant" - "next-terminal/server/global" "next-terminal/server/utils" ) @@ -37,95 +35,3 @@ type CredentialSimpleVo struct { ID string `json:"id"` Name string `json:"name"` } - -func FindAllCredential(account User) (o []CredentialSimpleVo, err error) { - db := global.DB.Table("credentials").Select("DISTINCT credentials.id,credentials.name").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id") - if account.Type == constant.TypeUser { - db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", account.ID, account.ID) - } - err = db.Find(&o).Error - return -} - -func FindPageCredential(pageIndex, pageSize int, name, order, field string, account User) (o []CredentialVo, total int64, err error) { - db := global.DB.Table("credentials").Select("credentials.id,credentials.name,credentials.type,credentials.username,credentials.owner,credentials.created,users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on credentials.owner = users.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id") - dbCounter := global.DB.Table("credentials").Select("DISTINCT credentials.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id") - - if constant.TypeUser == account.Type { - owner := account.ID - db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", owner, owner) - dbCounter = dbCounter.Where("credentials.owner = ? or resource_sharers.user_id = ?", owner, owner) - } - - if len(name) > 0 { - db = db.Where("credentials.name like ?", "%"+name+"%") - dbCounter = dbCounter.Where("credentials.name like ?", "%"+name+"%") - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "ascend" { - order = "asc" - } else { - order = "desc" - } - - if field == "name" { - field = "name" - } else { - field = "created" - } - - err = db.Order("credentials." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error - if o == nil { - o = make([]CredentialVo, 0) - } - return -} - -func CreateNewCredential(o *Credential) (err error) { - if err = global.DB.Create(o).Error; err != nil { - return err - } - return nil -} - -func FindCredentialById(id string) (o Credential, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func UpdateCredentialById(o *Credential, id string) { - o.ID = id - global.DB.Updates(o) -} - -func DeleteCredentialById(id string) error { - return global.DB.Where("id = ?", id).Delete(&Credential{}).Error -} - -func CountCredential() (total int64, err error) { - err = global.DB.Find(&Credential{}).Count(&total).Error - return -} - -func CountCredentialByUserId(userId string) (total int64, err error) { - db := global.DB.Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id") - - db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", userId, userId) - - // 查询用户所在用户组列表 - userGroupIds, err := FindUserGroupIdsByUserId(userId) - if err != nil { - return 0, err - } - - if len(userGroupIds) > 0 { - db = db.Or("resource_sharers.user_group_id in ?", userGroupIds) - } - err = db.Find(&Credential{}).Count(&total).Error - return -} diff --git a/server/model/job.go b/server/model/job.go index 03d21c6..35a337f 100644 --- a/server/model/job.go +++ b/server/model/job.go @@ -1,19 +1,7 @@ package model import ( - "encoding/json" - "errors" - "fmt" - "strings" - "time" - - "next-terminal/server/constant" - "next-terminal/server/global" - "next-terminal/server/term" "next-terminal/server/utils" - - "github.com/robfig/cron/v3" - "github.com/sirupsen/logrus" ) type Job struct { @@ -34,323 +22,13 @@ func (r *Job) TableName() string { return "jobs" } -func FindPageJob(pageIndex, pageSize int, name, status, order, field string) (o []Job, total int64, err error) { - job := Job{} - db := global.DB.Table(job.TableName()) - dbCounter := global.DB.Table(job.TableName()) - - if len(name) > 0 { - db = db.Where("name like ?", "%"+name+"%") - dbCounter = dbCounter.Where("name like ?", "%"+name+"%") - } - - if len(status) > 0 { - db = db.Where("status = ?", status) - dbCounter = dbCounter.Where("status = ?", status) - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "ascend" { - order = "asc" - } else { - order = "desc" - } - - if field == "name" { - field = "name" - } else if field == "created" { - field = "created" - } else { - field = "updated" - } - - err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error - if o == nil { - o = make([]Job, 0) - } - return +type JobLog struct { + ID string `json:"id"` + Timestamp utils.JsonTime `json:"timestamp"` + JobId string `json:"jobId"` + Message string `json:"message"` } -func FindJobByFunc(function string) (o []Job, err error) { - db := global.DB - err = db.Where("func = ?", function).Find(&o).Error - return -} - -func CreateNewJob(o *Job) (err error) { - - if o.Status == constant.JobStatusRunning { - j, err := getJob(o) - if err != nil { - return err - } - jobId, err := global.Cron.AddJob(o.Cron, j) - if err != nil { - return err - } - o.CronJobId = int(jobId) - } - - return global.DB.Create(o).Error -} - -func UpdateJobById(o *Job, id string) (err error) { - if o.Status == constant.JobStatusRunning { - return errors.New("请先停止定时任务后再修改") - } - - o.ID = id - return global.DB.Updates(o).Error -} - -func UpdateJonUpdatedById(id string) (err error) { - err = global.DB.Updates(Job{ID: id, Updated: utils.NowJsonTime()}).Error - return -} - -func ChangeJobStatusById(id, status string) (err error) { - var job Job - err = global.DB.Where("id = ?", id).First(&job).Error - if err != nil { - return err - } - if status == constant.JobStatusRunning { - j, err := getJob(&job) - if err != nil { - return err - } - entryID, err := global.Cron.AddJob(job.Cron, j) - if err != nil { - return err - } - logrus.Debugf("开启计划任务「%v」,运行中计划任务数量「%v」", job.Name, len(global.Cron.Entries())) - - return global.DB.Updates(Job{ID: id, Status: constant.JobStatusRunning, CronJobId: int(entryID)}).Error - } else { - global.Cron.Remove(cron.EntryID(job.CronJobId)) - logrus.Debugf("关闭计划任务「%v」,运行中计划任务数量「%v」", job.Name, len(global.Cron.Entries())) - return global.DB.Updates(Job{ID: id, Status: constant.JobStatusNotRunning}).Error - } -} - -func ExecJobById(id string) (err error) { - job, err := FindJobById(id) - if err != nil { - return err - } - j, err := getJob(&job) - if err != nil { - return err - } - j.Run() - return nil -} - -func FindJobById(id string) (o Job, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func DeleteJobById(id string) error { - job, err := FindJobById(id) - if err != nil { - return err - } - if job.Status == constant.JobStatusRunning { - if err := ChangeJobStatusById(id, constant.JobStatusNotRunning); err != nil { - return err - } - } - return global.DB.Where("id = ?", id).Delete(Job{}).Error -} - -func getJob(j *Job) (job cron.Job, err error) { - switch j.Func { - case constant.FuncCheckAssetStatusJob: - job = CheckAssetStatusJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata} - case constant.FuncShellJob: - job = ShellJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata} - default: - return nil, errors.New("未识别的任务") - } - return job, err -} - -type CheckAssetStatusJob struct { - ID string - Mode string - ResourceIds string - Metadata string -} - -func (r CheckAssetStatusJob) Run() { - if r.ID == "" { - return - } - - var assets []Asset - if r.Mode == constant.JobModeAll { - assets, _ = FindAllAsset() - } else { - assets, _ = FindAssetByIds(strings.Split(r.ResourceIds, ",")) - } - - if len(assets) == 0 { - return - } - - msgChan := make(chan string) - for i := range assets { - asset := assets[i] - go func() { - t1 := time.Now() - active := utils.Tcping(asset.IP, asset.Port) - elapsed := time.Since(t1) - msg := fmt.Sprintf("资产「%v」存活状态检测完成,存活「%v」,耗时「%v」", asset.Name, active, elapsed) - - UpdateAssetActiveById(active, asset.ID) - logrus.Infof(msg) - msgChan <- msg - }() - } - - var message = "" - for i := 0; i < len(assets); i++ { - message += <-msgChan + "\n" - } - - _ = UpdateJonUpdatedById(r.ID) - jobLog := JobLog{ - ID: utils.UUID(), - JobId: r.ID, - Timestamp: utils.NowJsonTime(), - Message: message, - } - - _ = CreateNewJobLog(&jobLog) -} - -type ShellJob struct { - ID string - Mode string - ResourceIds string - Metadata string -} - -type MetadataShell struct { - Shell string -} - -func (r ShellJob) Run() { - if r.ID == "" { - return - } - - var assets []Asset - if r.Mode == constant.JobModeAll { - assets, _ = FindAssetByProtocol("ssh") - } else { - assets, _ = FindAssetByProtocolAndIds("ssh", strings.Split(r.ResourceIds, ",")) - } - - if len(assets) == 0 { - return - } - - var metadataShell MetadataShell - err := json.Unmarshal([]byte(r.Metadata), &metadataShell) - if err != nil { - logrus.Errorf("JSON数据解析失败 %v", err) - return - } - - msgChan := make(chan string) - for i := range assets { - asset, err := FindAssetById(assets[i].ID) - if err != nil { - msgChan <- fmt.Sprintf("资产「%v」Shell执行失败,查询数据异常「%v」", assets[i].Name, err.Error()) - return - } - - var ( - username = asset.Username - password = asset.Password - privateKey = asset.PrivateKey - passphrase = asset.Passphrase - ip = asset.IP - port = asset.Port - ) - - if asset.AccountType == "credential" { - credential, err := FindCredentialById(asset.CredentialId) - if err != nil { - msgChan <- fmt.Sprintf("资产「%v」Shell执行失败,查询授权凭证数据异常「%v」", assets[i].Name, err.Error()) - return - } - - if credential.Type == constant.Custom { - username = credential.Username - password = credential.Password - } else { - username = credential.Username - privateKey = credential.PrivateKey - passphrase = credential.Passphrase - } - } - - go func() { - - t1 := time.Now() - result, err := ExecCommandBySSH(metadataShell.Shell, ip, port, username, password, privateKey, passphrase) - elapsed := time.Since(t1) - var msg string - if err != nil { - msg = fmt.Sprintf("资产「%v」Shell执行失败,返回值「%v」,耗时「%v」", asset.Name, err.Error(), elapsed) - logrus.Infof(msg) - } else { - msg = fmt.Sprintf("资产「%v」Shell执行成功,返回值「%v」,耗时「%v」", asset.Name, result, elapsed) - logrus.Infof(msg) - } - - msgChan <- msg - }() - } - - var message = "" - for i := 0; i < len(assets); i++ { - message += <-msgChan + "\n" - } - - _ = UpdateJonUpdatedById(r.ID) - jobLog := JobLog{ - ID: utils.UUID(), - JobId: r.ID, - Timestamp: utils.NowJsonTime(), - Message: message, - } - - _ = CreateNewJobLog(&jobLog) -} - -func ExecCommandBySSH(cmd, ip string, port int, username, password, privateKey, passphrase string) (result string, err error) { - sshClient, err := term.NewSshClient(ip, port, username, password, privateKey, passphrase) - if err != nil { - return "", err - } - - session, err := sshClient.NewSession() - if err != nil { - return "", err - } - defer session.Close() - //执行远程命令 - combo, err := session.CombinedOutput(cmd) - if err != nil { - return "", err - } - return string(combo), nil +func (r *JobLog) TableName() string { + return "job_logs" } diff --git a/server/model/job_log.go b/server/model/job_log.go deleted file mode 100644 index 9946cc1..0000000 --- a/server/model/job_log.go +++ /dev/null @@ -1,30 +0,0 @@ -package model - -import ( - "next-terminal/server/global" - "next-terminal/server/utils" -) - -type JobLog struct { - ID string `json:"id"` - Timestamp utils.JsonTime `json:"timestamp"` - JobId string `json:"jobId"` - Message string `json:"message"` -} - -func (r *JobLog) TableName() string { - return "job_logs" -} - -func CreateNewJobLog(o *JobLog) error { - return global.DB.Create(o).Error -} - -func FindJobLogs(jobId string) (o []JobLog, err error) { - err = global.DB.Where("job_id = ?", jobId).Order("timestamp asc").Find(&o).Error - return -} - -func DeleteJobLogByJobId(jobId string) error { - return global.DB.Where("job_id = ?", jobId).Delete(JobLog{}).Error -} diff --git a/server/model/login_log.go b/server/model/login_log.go index 8aaf95f..fcf7856 100644 --- a/server/model/login_log.go +++ b/server/model/login_log.go @@ -1,7 +1,6 @@ package model import ( - "next-terminal/server/global" "next-terminal/server/utils" ) @@ -29,78 +28,3 @@ type LoginLogVo struct { func (r *LoginLog) TableName() string { return "login_logs" } - -func FindPageLoginLog(pageIndex, pageSize int, userId, clientIp string) (o []LoginLogVo, total int64, err error) { - - db := global.DB.Table("login_logs").Select("login_logs.id,login_logs.user_id,login_logs.client_ip,login_logs.client_user_agent,login_logs.login_time, login_logs.logout_time, users.nickname as user_name").Joins("left join users on login_logs.user_id = users.id") - dbCounter := global.DB.Table("login_logs").Select("DISTINCT login_logs.id") - - if userId != "" { - db = db.Where("login_logs.user_id = ?", userId) - dbCounter = dbCounter.Where("login_logs.user_id = ?", userId) - } - - if clientIp != "" { - db = db.Where("login_logs.client_ip like ?", "%"+clientIp+"%") - dbCounter = dbCounter.Where("login_logs.client_ip like ?", "%"+clientIp+"%") - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - err = db.Order("login_logs.login_time desc").Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error - if o == nil { - o = make([]LoginLogVo, 0) - } - return -} - -func FindAliveLoginLogs() (o []LoginLog, err error) { - err = global.DB.Where("logout_time is null").Find(&o).Error - return -} - -func FindAliveLoginLogsByUserId(userId string) (o []LoginLog, err error) { - err = global.DB.Where("logout_time is null and user_id = ?", userId).Find(&o).Error - return -} - -func CreateNewLoginLog(o *LoginLog) (err error) { - return global.DB.Create(o).Error -} - -func DeleteLoginLogByIdIn(ids []string) (err error) { - return global.DB.Where("id in ?", ids).Delete(&LoginLog{}).Error -} - -func FindLoginLogById(id string) (o LoginLog, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func Logout(token string) (err error) { - // - //loginLog, err := FindLoginLogById(token) - //if err != nil { - // logrus.Warnf("登录日志「%v」获取失败", token) - // return - //} - // - //err = global.DB.Updates(&LoginLog{LogoutTime: utils.NowJsonTime(), ID: token}).Error - //if err != nil { - // return err - //} - // - //loginLogs, err := FindAliveLoginLogsByUserId(loginLog.UserId) - //if err != nil { - // return - //} - // - //if len(loginLogs) == 0 { - // // TODO - // err = UpdateUserOnline(false, loginLog.UserId) - //} - return -} diff --git a/server/model/num.go b/server/model/num.go index 78481fb..fa7bd58 100644 --- a/server/model/num.go +++ b/server/model/num.go @@ -1,9 +1,5 @@ package model -import ( - "next-terminal/server/global" -) - type Num struct { I string `gorm:"primary_key" json:"i"` } @@ -11,15 +7,3 @@ type Num struct { func (r *Num) TableName() string { return "nums" } - -func FindAllTemp() (o []Num) { - if global.DB.Find(&o).Error != nil { - return nil - } - return -} - -func CreateNewTemp(o *Num) (err error) { - err = global.DB.Create(o).Error - return -} diff --git a/server/model/property.go b/server/model/property.go index 0bd7d93..7f2aae8 100644 --- a/server/model/property.go +++ b/server/model/property.go @@ -1,16 +1,5 @@ package model -import ( - "net/smtp" - - "next-terminal/server/constant" - "next-terminal/server/global" - "next-terminal/server/guacd" - - "github.com/jordan-wright/email" - "github.com/sirupsen/logrus" -) - type Property struct { Name string `gorm:"primary_key" json:"name"` Value string `json:"value"` @@ -19,73 +8,3 @@ type Property struct { func (r *Property) TableName() string { return "properties" } - -func FindAllProperties() (o []Property) { - if global.DB.Find(&o).Error != nil { - return nil - } - return -} - -func CreateNewProperty(o *Property) (err error) { - err = global.DB.Create(o).Error - return -} - -func UpdatePropertyByName(o *Property, name string) { - o.Name = name - global.DB.Updates(o) -} - -func FindPropertyByName(name string) (o Property, err error) { - err = global.DB.Where("name = ?", name).First(&o).Error - return -} - -func FindAllPropertiesMap() map[string]string { - properties := FindAllProperties() - propertyMap := make(map[string]string) - for i := range properties { - propertyMap[properties[i].Name] = properties[i].Value - } - return propertyMap -} - -func GetDrivePath() (string, error) { - property, err := FindPropertyByName(guacd.DrivePath) - if err != nil { - return "", err - } - return property.Value, nil -} - -func GetRecordingPath() (string, error) { - property, err := FindPropertyByName(guacd.RecordingPath) - if err != nil { - return "", err - } - return property.Value, nil -} - -func SendMail(to, subject, text string) { - propertiesMap := FindAllPropertiesMap() - host := propertiesMap[constant.MailHost] - port := propertiesMap[constant.MailPort] - username := propertiesMap[constant.MailUsername] - password := propertiesMap[constant.MailPassword] - - if host == "" || port == "" || username == "" || password == "" { - logrus.Debugf("邮箱信息不完整,跳过发送邮件。") - return - } - - e := email.NewEmail() - e.From = "Next Terminal <" + username + ">" - e.To = []string{to} - e.Subject = subject - e.Text = []byte(text) - err := e.Send(host+":"+port, smtp.PlainAuth("", username, password, host)) - if err != nil { - logrus.Errorf("邮件发送失败: %v", err.Error()) - } -} diff --git a/server/model/resource-sharer.go b/server/model/resource-sharer.go index 43dcea7..8a9591b 100644 --- a/server/model/resource-sharer.go +++ b/server/model/resource-sharer.go @@ -1,14 +1,5 @@ package model -import ( - "next-terminal/server/global" - "next-terminal/server/utils" - - "github.com/labstack/echo/v4" - "github.com/pkg/errors" - "gorm.io/gorm" -) - type ResourceSharer struct { ID string `gorm:"primary_key" json:"id"` ResourceId string `gorm:"index" json:"resourceId"` @@ -20,179 +11,3 @@ type ResourceSharer struct { func (r *ResourceSharer) TableName() string { return "resource_sharers" } - -func FindUserIdsByResourceId(resourceId string) (r []string, err error) { - db := global.DB - err = db.Table("resource_sharers").Select("user_id").Where("resource_id = ?", resourceId).Find(&r).Error - if r == nil { - r = make([]string, 0) - } - return -} - -func OverwriteUserIdsByResourceId(resourceId, resourceType string, userIds []string) (err error) { - db := global.DB.Begin() - - var owner string - // 检查资产是否存在 - switch resourceType { - case "asset": - resource := Asset{} - err = db.Where("id = ?", resourceId).First(&resource).Error - owner = resource.Owner - case "command": - resource := Command{} - err = db.Where("id = ?", resourceId).First(&resource).Error - owner = resource.Owner - case "credential": - resource := Credential{} - err = db.Where("id = ?", resourceId).First(&resource).Error - owner = resource.Owner - } - - if err == gorm.ErrRecordNotFound { - return echo.NewHTTPError(404, "资源「"+resourceId+"」不存在") - } - - for i := range userIds { - if owner == userIds[i] { - return echo.NewHTTPError(400, "参数错误") - } - } - - db.Where("resource_id = ?", resourceId).Delete(&ResourceSharer{}) - - for i := range userIds { - userId := userIds[i] - if len(userId) == 0 { - continue - } - id := utils.Sign([]string{resourceId, resourceType, userId}) - resource := &ResourceSharer{ - ID: id, - ResourceId: resourceId, - ResourceType: resourceType, - UserId: userId, - } - err = db.Create(resource).Error - if err != nil { - return err - } - } - db.Commit() - return nil -} - -func DeleteByUserIdAndResourceTypeAndResourceIdIn(userGroupId, userId, resourceType string, resourceIds []string) error { - db := global.DB - if userGroupId != "" { - db = db.Where("user_group_id = ?", userGroupId) - } - - if userId != "" { - db = db.Where("user_id = ?", userId) - } - - if resourceType != "" { - db = db.Where("resource_type = ?", resourceType) - } - - if resourceIds != nil { - db = db.Where("resource_id in ?", resourceIds) - } - - return db.Delete(&ResourceSharer{}).Error -} - -func DeleteResourceSharerByResourceId(resourceId string) error { - return global.DB.Where("resource_id = ?", resourceId).Delete(&ResourceSharer{}).Error -} - -func AddSharerResources(userGroupId, userId, resourceType string, resourceIds []string) error { - return global.DB.Transaction(func(tx *gorm.DB) (err error) { - - for i := range resourceIds { - resourceId := resourceIds[i] - - var owner string - // 检查资产是否存在 - switch resourceType { - case "asset": - resource := Asset{} - if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil { - return errors.Wrap(err, "find asset fail") - } - owner = resource.Owner - case "command": - resource := Command{} - if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil { - return errors.Wrap(err, "find command fail") - } - owner = resource.Owner - case "credential": - resource := Credential{} - if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil { - return errors.Wrap(err, "find credential fail") - - } - owner = resource.Owner - } - - if owner == userId { - return echo.NewHTTPError(400, "参数错误") - } - - id := utils.Sign([]string{resourceId, resourceType, userId, userGroupId}) - resource := &ResourceSharer{ - ID: id, - ResourceId: resourceId, - ResourceType: resourceType, - UserId: userId, - UserGroupId: userGroupId, - } - err = tx.Create(resource).Error - if err != nil { - return err - } - } - return nil - }) -} - -func FindAssetIdsByUserId(userId string) (assetIds []string, err error) { - // 查询当前用户创建的资产 - var ownerAssetIds, sharerAssetIds []string - asset := Asset{} - err = global.DB.Table(asset.TableName()).Select("id").Where("owner = ?", userId).Find(&ownerAssetIds).Error - if err != nil { - return nil, err - } - - // 查询其他用户授权给该用户的资产 - groupIds, err := FindUserGroupIdsByUserId(userId) - if err != nil { - return nil, err - } - - db := global.DB.Table("resource_sharers").Select("resource_id").Where("user_id = ?", userId) - if len(groupIds) > 0 { - db = db.Or("user_group_id in ?", groupIds) - } - err = db.Find(&sharerAssetIds).Error - if err != nil { - return nil, err - } - - // 合并查询到的资产ID - assetIds = make([]string, 0) - - if ownerAssetIds != nil { - assetIds = append(assetIds, ownerAssetIds...) - } - - if sharerAssetIds != nil { - assetIds = append(assetIds, sharerAssetIds...) - } - - return -} diff --git a/server/model/session.go b/server/model/session.go index 55c8c33..dd3a639 100644 --- a/server/model/session.go +++ b/server/model/session.go @@ -1,12 +1,6 @@ package model import ( - "os" - "path" - "time" - - "next-terminal/server/constant" - "next-terminal/server/global" "next-terminal/server/utils" ) @@ -60,151 +54,3 @@ type SessionVo struct { Message string `json:"message"` Mode string `json:"mode"` } - -func FindPageSession(pageIndex, pageSize int, status, userId, clientIp, assetId, protocol string) (results []SessionVo, total int64, err error) { - - db := global.DB - var params []interface{} - - params = append(params, status) - - itemSql := "SELECT s.id,s.mode, s.protocol,s.recording, s.connection_id, s.asset_id, s.creator, s.client_ip, s.width, s.height, s.ip, s.port, s.username, s.status, s.connected_time, s.disconnected_time,s.code, s.message, a.name AS asset_name, u.nickname AS creator_name FROM sessions s LEFT JOIN assets a ON s.asset_id = a.id LEFT JOIN users u ON s.creator = u.id WHERE s.STATUS = ? " - countSql := "select count(*) from sessions as s where s.status = ? " - - if len(userId) > 0 { - itemSql += " and s.creator = ?" - countSql += " and s.creator = ?" - params = append(params, userId) - } - - if len(clientIp) > 0 { - itemSql += " and s.client_ip like ?" - countSql += " and s.client_ip like ?" - params = append(params, "%"+clientIp+"%") - } - - if len(assetId) > 0 { - itemSql += " and s.asset_id = ?" - countSql += " and s.asset_id = ?" - params = append(params, assetId) - } - - if len(protocol) > 0 { - itemSql += " and s.protocol = ?" - countSql += " and s.protocol = ?" - params = append(params, protocol) - } - - params = append(params, (pageIndex-1)*pageSize, pageSize) - itemSql += " order by s.connected_time desc LIMIT ?, ?" - - db.Raw(countSql, params...).Scan(&total) - - err = db.Raw(itemSql, params...).Scan(&results).Error - - if results == nil { - results = make([]SessionVo, 0) - } - return -} - -func FindSessionByStatus(status string) (o []Session, err error) { - err = global.DB.Where("status = ?", status).Find(&o).Error - return -} - -func FindSessionByStatusIn(statuses []string) (o []Session, err error) { - err = global.DB.Where("status in ?", statuses).Find(&o).Error - return -} - -func FindOutTimeSessions(dayLimit int) (o []Session, err error) { - limitTime := time.Now().Add(time.Duration(-dayLimit*24) * time.Hour) - err = global.DB.Where("status = ? and connected_time < ?", constant.Disconnected, limitTime).Find(&o).Error - return -} - -func CreateNewSession(o *Session) (err error) { - err = global.DB.Create(o).Error - return -} - -func FindSessionById(id string) (o Session, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func FindSessionByConnectionId(connectionId string) (o Session, err error) { - err = global.DB.Where("connection_id = ?", connectionId).First(&o).Error - return -} - -func UpdateSessionById(o *Session, id string) error { - o.ID = id - return global.DB.Updates(o).Error -} - -func UpdateSessionWindowSizeById(width, height int, id string) error { - session := Session{} - session.Width = width - session.Height = height - - return UpdateSessionById(&session, id) -} - -func DeleteSessionById(id string) error { - return global.DB.Where("id = ?", id).Delete(&Session{}).Error -} - -func DeleteSessionByIds(sessionIds []string) error { - drivePath, err := GetRecordingPath() - if err != nil { - return err - } - for i := range sessionIds { - if err := os.RemoveAll(path.Join(drivePath, sessionIds[i])); err != nil { - return err - } - if err := DeleteSessionById(sessionIds[i]); err != nil { - return err - } - } - return nil -} - -func DeleteSessionByStatus(status string) { - global.DB.Where("status = ?", status).Delete(&Session{}) -} - -func CountOnlineSession() (total int64, err error) { - err = global.DB.Where("status = ?", constant.Connected).Find(&Session{}).Count(&total).Error - return -} - -type D struct { - Day string `json:"day"` - Count int `json:"count"` - Protocol string `json:"protocol"` -} - -func CountSessionByDay(day int) (results []D, err error) { - - today := time.Now().Format("20060102") - sql := "select t1.`day`, count(t2.id) as count\nfrom (\n SELECT @date := DATE_ADD(@date, INTERVAL - 1 DAY) day\n FROM (SELECT @date := DATE_ADD('" + today + "', INTERVAL + 1 DAY) FROM nums) as t0\n LIMIT ?\n )\n as t1\n left join\n (\n select DATE(s.connected_time) as day, s.id\n from sessions as s\n WHERE protocol = ? and DATE(connected_time) <= '" + today + "'\n AND DATE(connected_time) > DATE_SUB('" + today + "', INTERVAL ? DAY)\n ) as t2 on t1.day = t2.day\ngroup by t1.day" - - protocols := []string{"rdp", "ssh", "vnc", "telnet"} - - for i := range protocols { - var result []D - err = global.DB.Raw(sql, day, protocols[i], day).Scan(&result).Error - if err != nil { - return nil, err - } - for j := range result { - result[j].Protocol = protocols[i] - } - results = append(results, result...) - } - - return -} diff --git a/server/model/user-attribute.go b/server/model/user-attribute.go deleted file mode 100644 index 3fa870e..0000000 --- a/server/model/user-attribute.go +++ /dev/null @@ -1,23 +0,0 @@ -package model - -import "next-terminal/server/global" - -type UserAttribute struct { - Id string `gorm:"index" json:"id"` - UserId string `gorm:"index" json:"userId"` - Name string `gorm:"index" json:"name"` - Value string `json:"value"` -} - -func (r *UserAttribute) TableName() string { - return "user_attributes" -} - -func CreateUserAttribute(o *UserAttribute) error { - return global.DB.Create(o).Error -} - -func FindUserAttributeByUserId(userId string) (o []UserAttribute, err error) { - err = global.DB.Where("user_id = ?", userId).Find(&o).Error - return o, err -} diff --git a/server/model/user-group-member.go b/server/model/user-group-member.go deleted file mode 100644 index 5cfadf6..0000000 --- a/server/model/user-group-member.go +++ /dev/null @@ -1,18 +0,0 @@ -package model - -import "next-terminal/server/global" - -type UserGroupMember struct { - ID string `gorm:"primary_key" json:"name"` - UserId string `gorm:"index" json:"userId"` - UserGroupId string `gorm:"index" json:"userGroupId"` -} - -func (r *UserGroupMember) TableName() string { - return "user_group_members" -} - -func FindUserGroupMembersByUserGroupId(id string) (o []string, err error) { - err = global.DB.Table("user_group_members").Select("user_id").Where("user_group_id = ?", id).Find(&o).Error - return -} diff --git a/server/model/user-group.go b/server/model/user-group.go index 8f2797e..cdf2592 100644 --- a/server/model/user-group.go +++ b/server/model/user-group.go @@ -1,137 +1,26 @@ package model import ( - "next-terminal/server/global" "next-terminal/server/utils" - - "gorm.io/gorm" ) type UserGroup struct { - ID string `gorm:"primary_key" json:"id"` - Name string `json:"name"` - Created utils.JsonTime `json:"created"` -} - -type UserGroupVo struct { - ID string `json:"id"` + ID string `gorm:"primary_key" json:"id"` Name string `json:"name"` Created utils.JsonTime `json:"created"` - AssetCount int64 `json:"assetCount"` + AssetCount int64 `gorm:"-" json:"assetCount"` } func (r *UserGroup) TableName() string { return "user_groups" } -func FindPageUserGroup(pageIndex, pageSize int, name, order, field string) (o []UserGroupVo, total int64, err error) { - db := global.DB.Table("user_groups").Select("user_groups.id, user_groups.name, user_groups.created, count(resource_sharers.user_group_id) as asset_count").Joins("left join resource_sharers on user_groups.id = resource_sharers.user_group_id and resource_sharers.resource_type = 'asset'").Group("user_groups.id") - dbCounter := global.DB.Table("user_groups") - if len(name) > 0 { - db = db.Where("user_groups.name like ?", "%"+name+"%") - dbCounter = dbCounter.Where("name like ?", "%"+name+"%") - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "ascend" { - order = "asc" - } else { - order = "desc" - } - - if field == "name" { - field = "name" - } else { - field = "created" - } - - err = db.Order("user_groups." + field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error - if o == nil { - o = make([]UserGroupVo, 0) - } - return +type UserGroupMember struct { + ID string `gorm:"primary_key" json:"name"` + UserId string `gorm:"index" json:"userId"` + UserGroupId string `gorm:"index" json:"userGroupId"` } -func CreateNewUserGroup(o *UserGroup, members []string) (err error) { - return global.DB.Transaction(func(tx *gorm.DB) error { - err = tx.Create(o).Error - if err != nil { - return err - } - - if members != nil { - userGroupId := o.ID - err = AddUserGroupMembers(tx, members, userGroupId) - if err != nil { - return err - } - } - return err - }) -} - -func AddUserGroupMembers(tx *gorm.DB, userIds []string, userGroupId string) error { - //for i := range userIds { - // userId := userIds[i] - // // TODO - // _, err := FindUserById(userId) - // if err != nil { - // return err - // } - // - // userGroupMember := UserGroupMember{ - // ID: utils.Sign([]string{userGroupId, userId}), - // UserId: userId, - // UserGroupId: userGroupId, - // } - // err = tx.Create(&userGroupMember).Error - // if err != nil { - // return err - // } - //} - return nil -} - -func FindUserGroupById(id string) (o UserGroup, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func FindUserGroupIdsByUserId(userId string) (o []string, err error) { - // 先查询用户所在的用户 - err = global.DB.Table("user_group_members").Select("user_group_id").Where("user_id = ?", userId).Find(&o).Error - return -} - -func UpdateUserGroupById(o *UserGroup, members []string, id string) error { - return global.DB.Transaction(func(tx *gorm.DB) error { - o.ID = id - err := tx.Updates(o).Error - if err != nil { - return err - } - - err = tx.Where("user_group_id = ?", id).Delete(&UserGroupMember{}).Error - if err != nil { - return err - } - if members != nil { - userGroupId := o.ID - err = AddUserGroupMembers(tx, members, userGroupId) - if err != nil { - return err - } - } - return err - }) - -} - -func DeleteUserGroupById(id string) { - global.DB.Where("id = ?", id).Delete(&UserGroup{}) - global.DB.Where("user_group_id = ?", id).Delete(&UserGroupMember{}) +func (r *UserGroupMember) TableName() string { + return "user_group_members" } diff --git a/server/repository/access_security.go b/server/repository/access_security.go new file mode 100644 index 0000000..469675f --- /dev/null +++ b/server/repository/access_security.go @@ -0,0 +1,81 @@ +package repository + +import ( + "gorm.io/gorm" + "next-terminal/server/global" + "next-terminal/server/model" +) + +type AccessSecurityRepository struct { + DB *gorm.DB +} + +func NewAccessSecurityRepository(db *gorm.DB) *AccessSecurityRepository { + accessSecurityRepository = &AccessSecurityRepository{DB: db} + return accessSecurityRepository +} + +func (r AccessSecurityRepository) FindAllAccessSecurities() (o []model.AccessSecurity, err error) { + db := r.DB + err = db.Order("priority asc").Find(&o).Error + return +} + +func (r AccessSecurityRepository) Find(pageIndex, pageSize int, ip, rule, order, field string) (o []model.AccessSecurity, total int64, err error) { + t := model.AccessSecurity{} + db := r.DB.Table(t.TableName()) + dbCounter := r.DB.Table(t.TableName()) + + if len(ip) > 0 { + db = db.Where("ip like ?", "%"+ip+"%") + dbCounter = dbCounter.Where("ip like ?", "%"+ip+"%") + } + + if len(rule) > 0 { + db = db.Where("rule = ?", rule) + dbCounter = dbCounter.Where("rule = ?", rule) + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "descend" { + order = "desc" + } else { + order = "asc" + } + + if field == "ip" { + field = "ip" + } else if field == "rule" { + field = "rule" + } else { + field = "priority" + } + + err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error + if o == nil { + o = make([]model.AccessSecurity, 0) + } + return +} + +func (r AccessSecurityRepository) Create(o *model.AccessSecurity) error { + return global.DB.Create(o).Error +} + +func (r AccessSecurityRepository) UpdateById(o *model.AccessSecurity, id string) error { + o.ID = id + return global.DB.Updates(o).Error +} + +func (r AccessSecurityRepository) DeleteById(id string) error { + return global.DB.Where("id = ?", id).Delete(model.AccessSecurity{}).Error +} + +func (r AccessSecurityRepository) FindById(id string) (o *model.AccessSecurity, err error) { + err = global.DB.Where("id = ?", id).First(&o).Error + return +} diff --git a/server/repository/asset.go b/server/repository/asset.go new file mode 100644 index 0000000..e0b383c --- /dev/null +++ b/server/repository/asset.go @@ -0,0 +1,302 @@ +package repository + +import ( + "fmt" + "github.com/labstack/echo/v4" + "gorm.io/gorm" + "next-terminal/server/constant" + "next-terminal/server/global" + "next-terminal/server/model" + "next-terminal/server/utils" + "strings" +) + +type AssetRepository struct { + DB *gorm.DB +} + +func NewAssetRepository(db *gorm.DB) *AssetRepository { + assetRepository = &AssetRepository{DB: db} + return assetRepository +} + +func (r AssetRepository) FindAll() (o []model.Asset, err error) { + err = r.DB.Find(&o).Error + return +} + +func (r AssetRepository) FindByIds(assetIds []string) (o []model.Asset, err error) { + err = r.DB.Where("id in ?", assetIds).Find(&o).Error + return +} + +func (r AssetRepository) FindByProtocol(protocol string) (o []model.Asset, err error) { + err = r.DB.Where("protocol = ?", protocol).Find(&o).Error + return +} + +func (r AssetRepository) FindByProtocolAndIds(protocol string, assetIds []string) (o []model.Asset, err error) { + err = r.DB.Where("protocol = ? and id in ?", protocol, assetIds).Find(&o).Error + return +} + +func (r AssetRepository) FindByProtocolAndUser(protocol string, account model.User) (o []model.Asset, err error) { + db := r.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") + + if constant.TypeUser == account.Type { + owner := account.ID + db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner) + } + + if len(protocol) > 0 { + db = db.Where("assets.protocol = ?", protocol) + } + err = db.Find(&o).Error + return +} + +func (r AssetRepository) Find(pageIndex, pageSize int, name, protocol, tags string, account model.User, owner, sharer, userGroupId, ip, order, field string) (o []model.AssetVo, total int64, err error) { + db := r.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") + dbCounter := r.DB.Table("assets").Select("DISTINCT assets.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") + + if constant.TypeUser == account.Type { + owner := account.ID + db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner) + dbCounter = dbCounter.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner) + + // 查询用户所在用户组列表 + userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(account.ID) + if err != nil { + return nil, 0, err + } + + if len(userGroupIds) > 0 { + db = db.Or("resource_sharers.user_group_id in ?", userGroupIds) + dbCounter = dbCounter.Or("resource_sharers.user_group_id in ?", userGroupIds) + } + } else { + if len(owner) > 0 { + db = db.Where("assets.owner = ?", owner) + dbCounter = dbCounter.Where("assets.owner = ?", owner) + } + if len(sharer) > 0 { + db = db.Where("resource_sharers.user_id = ?", sharer) + dbCounter = dbCounter.Where("resource_sharers.user_id = ?", sharer) + } + + if len(userGroupId) > 0 { + db = db.Where("resource_sharers.user_group_id = ?", userGroupId) + dbCounter = dbCounter.Where("resource_sharers.user_group_id = ?", userGroupId) + } + } + + if len(name) > 0 { + db = db.Where("assets.name like ?", "%"+name+"%") + dbCounter = dbCounter.Where("assets.name like ?", "%"+name+"%") + } + + if len(ip) > 0 { + db = db.Where("assets.ip like ?", "%"+ip+"%") + dbCounter = dbCounter.Where("assets.ip like ?", "%"+ip+"%") + } + + if len(protocol) > 0 { + db = db.Where("assets.protocol = ?", protocol) + dbCounter = dbCounter.Where("assets.protocol = ?", protocol) + } + + if len(tags) > 0 { + tagArr := strings.Split(tags, ",") + for i := range tagArr { + if global.Config.DB == "sqlite" { + db = db.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%") + dbCounter = dbCounter.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%") + } else { + db = db.Where("find_in_set(?, assets.tags)", tagArr[i]) + dbCounter = dbCounter.Where("find_in_set(?, assets.tags)", tagArr[i]) + } + } + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "ascend" { + order = "asc" + } else { + order = "desc" + } + + if field == "name" { + field = "name" + } else { + field = "created" + } + + err = db.Order("assets." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error + + if o == nil { + o = make([]model.AssetVo, 0) + } + return +} + +func (r AssetRepository) Create(o *model.Asset) (err error) { + if err = r.DB.Create(o).Error; err != nil { + return err + } + return nil +} + +func (r AssetRepository) FindById(id string) (o model.Asset, err error) { + err = r.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r AssetRepository) UpdateById(o *model.Asset, id string) error { + o.ID = id + return r.DB.Updates(o).Error +} + +func (r AssetRepository) UpdateActiveById(active bool, id string) error { + sql := "update assets set active = ? where id = ?" + return r.DB.Exec(sql, active, id).Error +} + +func (r AssetRepository) DeleteById(id string) error { + return r.DB.Where("id = ?", id).Delete(&model.Asset{}).Error +} + +func (r AssetRepository) CountAsset() (total int64, err error) { + err = r.DB.Find(&model.Asset{}).Count(&total).Error + return +} + +func (r AssetRepository) CountByUserId(userId string) (total int64, err error) { + db := r.DB.Joins("left join resource_sharers on assets.id = resource_sharers.resource_id") + + db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", userId, userId) + + // 查询用户所在用户组列表 + userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId) + if err != nil { + return 0, err + } + + if len(userGroupIds) > 0 { + db = db.Or("resource_sharers.user_group_id in ?", userGroupIds) + } + err = db.Find(&model.Asset{}).Count(&total).Error + return +} + +func (r AssetRepository) FindTags() (o []string, err error) { + var assets []model.Asset + err = r.DB.Not("tags = ?", "").Find(&assets).Error + if err != nil { + return nil, err + } + + o = make([]string, 0) + + for i := range assets { + if len(assets[i].Tags) == 0 { + continue + } + split := strings.Split(assets[i].Tags, ",") + + o = append(o, split...) + } + + return utils.Distinct(o), nil +} + +func (r AssetRepository) UpdateAttributes(assetId, protocol string, m echo.Map) error { + var data []model.AssetAttribute + var parameterNames []string + switch protocol { + case "ssh": + parameterNames = constant.SSHParameterNames + case "rdp": + parameterNames = constant.RDPParameterNames + case "vnc": + parameterNames = constant.VNCParameterNames + case "telnet": + parameterNames = constant.TelnetParameterNames + case "kubernetes": + parameterNames = constant.KubernetesParameterNames + } + + for i := range parameterNames { + name := parameterNames[i] + if m[name] != nil && m[name] != "" { + data = append(data, genAttribute(assetId, name, m)) + } + } + + return r.DB.Transaction(func(tx *gorm.DB) error { + err := tx.Where("asset_id = ?", assetId).Delete(&model.AssetAttribute{}).Error + if err != nil { + return err + } + return tx.CreateInBatches(&data, len(data)).Error + }) +} + +func genAttribute(assetId, name string, m echo.Map) model.AssetAttribute { + value := fmt.Sprintf("%v", m[name]) + attribute := model.AssetAttribute{ + Id: utils.Sign([]string{assetId, name}), + AssetId: assetId, + Name: name, + Value: value, + } + return attribute +} + +func (r AssetRepository) FindAttrById(assetId string) (o []model.AssetAttribute, err error) { + err = r.DB.Where("asset_id = ?", assetId).Find(&o).Error + if o == nil { + o = make([]model.AssetAttribute, 0) + } + return o, err +} + +func (r AssetRepository) FindAssetAttrMapByAssetId(assetId string) (map[string]interface{}, error) { + asset, err := r.FindById(assetId) + if err != nil { + return nil, err + } + attributes, err := r.FindAttrById(assetId) + if err != nil { + return nil, err + } + + var parameterNames []string + switch asset.Protocol { + case "ssh": + parameterNames = constant.SSHParameterNames + case "rdp": + parameterNames = constant.RDPParameterNames + case "vnc": + parameterNames = constant.VNCParameterNames + case "telnet": + parameterNames = constant.TelnetParameterNames + case "kubernetes": + parameterNames = constant.KubernetesParameterNames + } + propertiesMap := propertyRepository.FindAllMap() + var attributeMap = make(map[string]interface{}) + for name := range propertiesMap { + if utils.Contains(parameterNames, name) { + attributeMap[name] = propertiesMap[name] + } + } + + for i := range attributes { + attributeMap[attributes[i].Name] = attributes[i].Value + } + return attributeMap, nil +} diff --git a/server/repository/command.go b/server/repository/command.go new file mode 100644 index 0000000..a47112c --- /dev/null +++ b/server/repository/command.go @@ -0,0 +1,82 @@ +package repository + +import ( + "gorm.io/gorm" + "next-terminal/server/constant" + "next-terminal/server/global" + "next-terminal/server/model" +) + +type CommandRepository struct { + DB *gorm.DB +} + +func NewCommandRepository(db *gorm.DB) *CommandRepository { + commandRepository = &CommandRepository{DB: db} + return commandRepository +} + +func (r CommandRepository) Find(pageIndex, pageSize int, name, content, order, field string, account model.User) (o []model.CommandVo, total int64, err error) { + db := r.DB.Table("commands").Select("commands.id,commands.name,commands.content,commands.owner,commands.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on commands.owner = users.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id") + dbCounter := r.DB.Table("commands").Select("DISTINCT commands.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id") + + if constant.TypeUser == account.Type { + owner := account.ID + db = db.Where("commands.owner = ? or resource_sharers.user_id = ?", owner, owner) + dbCounter = dbCounter.Where("commands.owner = ? or resource_sharers.user_id = ?", owner, owner) + } + + if len(name) > 0 { + db = db.Where("commands.name like ?", "%"+name+"%") + dbCounter = dbCounter.Where("commands.name like ?", "%"+name+"%") + } + + if len(content) > 0 { + db = db.Where("commands.content like ?", "%"+content+"%") + dbCounter = dbCounter.Where("commands.content like ?", "%"+content+"%") + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "ascend" { + order = "asc" + } else { + order = "desc" + } + + if field == "name" { + field = "name" + } else { + field = "created" + } + + err = db.Order("commands." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error + if o == nil { + o = make([]model.CommandVo, 0) + } + return +} + +func (r CommandRepository) Create(o *model.Command) (err error) { + if err = r.DB.Create(o).Error; err != nil { + return err + } + return nil +} + +func (r CommandRepository) FindById(id string) (o model.Command, err error) { + err = global.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r CommandRepository) UpdateById(o *model.Command, id string) { + o.ID = id + global.DB.Updates(o) +} + +func (r CommandRepository) DeleteById(id string) error { + return r.DB.Where("id = ?", id).Delete(&model.Command{}).Error +} diff --git a/server/repository/credential.go b/server/repository/credential.go new file mode 100644 index 0000000..25d5915 --- /dev/null +++ b/server/repository/credential.go @@ -0,0 +1,108 @@ +package repository + +import ( + "gorm.io/gorm" + "next-terminal/server/constant" + "next-terminal/server/model" +) + +type CredentialRepository struct { + DB *gorm.DB +} + +func NewCredentialRepository(db *gorm.DB) *CredentialRepository { + credentialRepository = &CredentialRepository{DB: db} + return credentialRepository +} + +func (r CredentialRepository) FindAllByUser(account model.User) (o []model.CredentialSimpleVo, err error) { + db := r.DB.Table("credentials").Select("DISTINCT credentials.id,credentials.name").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id") + if account.Type == constant.TypeUser { + db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", account.ID, account.ID) + } + err = db.Find(&o).Error + return +} + +func (r CredentialRepository) Find(pageIndex, pageSize int, name, order, field string, account model.User) (o []model.CredentialVo, total int64, err error) { + db := r.DB.Table("credentials").Select("credentials.id,credentials.name,credentials.type,credentials.username,credentials.owner,credentials.created,users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on credentials.owner = users.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id") + dbCounter := r.DB.Table("credentials").Select("DISTINCT credentials.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id") + + if constant.TypeUser == account.Type { + owner := account.ID + db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", owner, owner) + dbCounter = dbCounter.Where("credentials.owner = ? or resource_sharers.user_id = ?", owner, owner) + } + + if len(name) > 0 { + db = db.Where("credentials.name like ?", "%"+name+"%") + dbCounter = dbCounter.Where("credentials.name like ?", "%"+name+"%") + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "ascend" { + order = "asc" + } else { + order = "desc" + } + + if field == "name" { + field = "name" + } else { + field = "created" + } + + err = db.Order("credentials." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error + if o == nil { + o = make([]model.CredentialVo, 0) + } + return +} + +func (r CredentialRepository) Create(o *model.Credential) (err error) { + if err = r.DB.Create(o).Error; err != nil { + return err + } + return nil +} + +func (r CredentialRepository) FindById(id string) (o model.Credential, err error) { + err = r.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r CredentialRepository) UpdateById(o *model.Credential, id string) error { + o.ID = id + return r.DB.Updates(o).Error +} + +func (r CredentialRepository) DeleteById(id string) error { + return r.DB.Where("id = ?", id).Delete(&model.Credential{}).Error +} + +func (r CredentialRepository) Count() (total int64, err error) { + err = r.DB.Find(&model.Credential{}).Count(&total).Error + return +} + +func (r CredentialRepository) CountByUserId(userId string) (total int64, err error) { + db := r.DB.Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id") + + db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", userId, userId) + + // 查询用户所在用户组列表 + userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId) + if err != nil { + return 0, err + } + + if len(userGroupIds) > 0 { + db = db.Or("resource_sharers.user_group_id in ?", userGroupIds) + } + err = db.Find(&model.Credential{}).Count(&total).Error + return +} diff --git a/server/repository/definitions.go b/server/repository/definitions.go new file mode 100644 index 0000000..56686fb --- /dev/null +++ b/server/repository/definitions.go @@ -0,0 +1,17 @@ +package repository + +var ( + userRepository *UserRepository + userGroupRepository *UserGroupRepository + resourceSharerRepository *ResourceSharerRepository + assetRepository *AssetRepository + credentialRepository *CredentialRepository + propertyRepository *PropertyRepository + commandRepository *CommandRepository + sessionRepository *SessionRepository + numRepository *NumRepository + accessSecurityRepository *AccessSecurityRepository + jobRepository *JobRepository + jobLogRepository *JobLogRepository + loginLogRepository *LoginLogRepository +) diff --git a/server/repository/job.go b/server/repository/job.go new file mode 100644 index 0000000..08086cb --- /dev/null +++ b/server/repository/job.go @@ -0,0 +1,108 @@ +package repository + +import ( + "gorm.io/gorm" + "next-terminal/server/global" + "next-terminal/server/model" + "next-terminal/server/utils" +) + +type JobRepository struct { + DB *gorm.DB +} + +func NewJobRepository(db *gorm.DB) *JobRepository { + jobRepository = &JobRepository{DB: db} + return jobRepository +} + +func (r JobRepository) Find(pageIndex, pageSize int, name, status, order, field string) (o []model.Job, total int64, err error) { + job := model.Job{} + db := r.DB.Table(job.TableName()) + dbCounter := r.DB.Table(job.TableName()) + + if len(name) > 0 { + db = db.Where("name like ?", "%"+name+"%") + dbCounter = dbCounter.Where("name like ?", "%"+name+"%") + } + + if len(status) > 0 { + db = db.Where("status = ?", status) + dbCounter = dbCounter.Where("status = ?", status) + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "ascend" { + order = "asc" + } else { + order = "desc" + } + + if field == "name" { + field = "name" + } else if field == "created" { + field = "created" + } else { + field = "updated" + } + + err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error + if o == nil { + o = make([]model.Job, 0) + } + return +} + +func (r JobRepository) FindByFunc(function string) (o []model.Job, err error) { + db := r.DB + err = db.Where("func = ?", function).Find(&o).Error + return +} + +func (r JobRepository) Create(o *model.Job) (err error) { + // + //if o.Status == constant.JobStatusRunning { + // j, err := getJob(o) + // if err != nil { + // return err + // } + // jobId, err := global.Cron.AddJob(o.Cron, j) + // if err != nil { + // return err + // } + // o.CronJobId = int(jobId) + //} + + return r.DB.Create(o).Error +} + +func (r JobRepository) UpdateById(o *model.Job) (err error) { + return r.DB.Updates(o).Error +} + +func (r JobRepository) UpdateLastUpdatedById(id string) (err error) { + err = r.DB.Updates(model.Job{ID: id, Updated: utils.NowJsonTime()}).Error + return +} + +func (r JobRepository) FindById(id string) (o model.Job, err error) { + err = r.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r JobRepository) DeleteJobById(id string) error { + //job, err := r.FindById(id) + //if err != nil { + // return err + //} + //if job.Status == constant.JobStatusRunning { + // if err := r.ChangeStatusById(id, constant.JobStatusNotRunning); err != nil { + // return err + // } + //} + return global.DB.Where("id = ?", id).Delete(model.Job{}).Error +} diff --git a/server/repository/job_log.go b/server/repository/job_log.go new file mode 100644 index 0000000..10ca91d --- /dev/null +++ b/server/repository/job_log.go @@ -0,0 +1,28 @@ +package repository + +import ( + "gorm.io/gorm" + "next-terminal/server/model" +) + +type JobLogRepository struct { + DB *gorm.DB +} + +func NewJobLogRepository(db *gorm.DB) *JobLogRepository { + jobLogRepository = &JobLogRepository{DB: db} + return jobLogRepository +} + +func (r JobLogRepository) Create(o *model.JobLog) error { + return r.DB.Create(o).Error +} + +func (r JobLogRepository) FindByJobId(jobId string) (o []model.JobLog, err error) { + err = r.DB.Where("job_id = ?", jobId).Order("timestamp asc").Find(&o).Error + return +} + +func (r JobLogRepository) DeleteByJobId(jobId string) error { + return r.DB.Where("job_id = ?", jobId).Delete(model.JobLog{}).Error +} diff --git a/server/repository/login_log.go b/server/repository/login_log.go new file mode 100644 index 0000000..ebae58c --- /dev/null +++ b/server/repository/login_log.go @@ -0,0 +1,69 @@ +package repository + +import ( + "gorm.io/gorm" + "next-terminal/server/model" +) + +type LoginLogRepository struct { + DB *gorm.DB +} + +func NewLoginLogRepository(db *gorm.DB) *LoginLogRepository { + loginLogRepository = &LoginLogRepository{DB: db} + return loginLogRepository +} + +func (r LoginLogRepository) Find(pageIndex, pageSize int, userId, clientIp string) (o []model.LoginLogVo, total int64, err error) { + + db := r.DB.Table("login_logs").Select("login_logs.id,login_logs.user_id,login_logs.client_ip,login_logs.client_user_agent,login_logs.login_time, login_logs.logout_time, users.nickname as user_name").Joins("left join users on login_logs.user_id = users.id") + dbCounter := r.DB.Table("login_logs").Select("DISTINCT login_logs.id") + + if userId != "" { + db = db.Where("login_logs.user_id = ?", userId) + dbCounter = dbCounter.Where("login_logs.user_id = ?", userId) + } + + if clientIp != "" { + db = db.Where("login_logs.client_ip like ?", "%"+clientIp+"%") + dbCounter = dbCounter.Where("login_logs.client_ip like ?", "%"+clientIp+"%") + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + err = db.Order("login_logs.login_time desc").Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error + if o == nil { + o = make([]model.LoginLogVo, 0) + } + return +} + +func (r LoginLogRepository) FindAliveLoginLogs() (o []model.LoginLog, err error) { + err = r.DB.Where("logout_time is null").Find(&o).Error + return +} + +func (r LoginLogRepository) FindAliveLoginLogsByUserId(userId string) (o []model.LoginLog, err error) { + err = r.DB.Where("logout_time is null and user_id = ?", userId).Find(&o).Error + return +} + +func (r LoginLogRepository) Create(o *model.LoginLog) (err error) { + return r.DB.Create(o).Error +} + +func (r LoginLogRepository) DeleteByIdIn(ids []string) (err error) { + return r.DB.Where("id in ?", ids).Delete(&model.LoginLog{}).Error +} + +func (r LoginLogRepository) FindById(id string) (o model.LoginLog, err error) { + err = r.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r LoginLogRepository) Update(o *model.LoginLog) error { + return r.DB.Updates(o).Error +} diff --git a/server/repository/num.go b/server/repository/num.go new file mode 100644 index 0000000..7309127 --- /dev/null +++ b/server/repository/num.go @@ -0,0 +1,26 @@ +package repository + +import ( + "gorm.io/gorm" + "next-terminal/server/global" + "next-terminal/server/model" +) + +type NumRepository struct { + DB *gorm.DB +} + +func NewNumRepository(db *gorm.DB) *NumRepository { + numRepository = &NumRepository{DB: db} + return numRepository +} + +func (r NumRepository) FindAll() (o []model.Num, err error) { + err = r.DB.Find(&o).Error + return +} + +func (r NumRepository) Create(o *model.Num) (err error) { + err = global.DB.Create(o).Error + return +} diff --git a/server/repository/property.go b/server/repository/property.go new file mode 100644 index 0000000..d8320af --- /dev/null +++ b/server/repository/property.go @@ -0,0 +1,90 @@ +package repository + +import ( + "github.com/jordan-wright/email" + "github.com/sirupsen/logrus" + "gorm.io/gorm" + "net/smtp" + "next-terminal/server/constant" + "next-terminal/server/guacd" + "next-terminal/server/model" +) + +type PropertyRepository struct { + DB *gorm.DB +} + +func NewPropertyRepository(db *gorm.DB) *PropertyRepository { + propertyRepository = &PropertyRepository{DB: db} + return propertyRepository +} + +func (r PropertyRepository) FindAll() (o []model.Property) { + if r.DB.Find(&o).Error != nil { + return nil + } + return +} + +func (r PropertyRepository) Create(o *model.Property) (err error) { + err = r.DB.Create(o).Error + return +} + +func (r PropertyRepository) UpdatePropertyByName(o *model.Property, name string) error { + o.Name = name + return r.DB.Updates(o).Error +} + +func (r PropertyRepository) FindByName(name string) (o model.Property, err error) { + err = r.DB.Where("name = ?", name).First(&o).Error + return +} + +func (r PropertyRepository) FindAllMap() map[string]string { + properties := r.FindAll() + propertyMap := make(map[string]string) + for i := range properties { + propertyMap[properties[i].Name] = properties[i].Value + } + return propertyMap +} + +func (r PropertyRepository) GetDrivePath() (string, error) { + property, err := r.FindByName(guacd.DrivePath) + if err != nil { + return "", err + } + return property.Value, nil +} + +func (r PropertyRepository) GetRecordingPath() (string, error) { + property, err := r.FindByName(guacd.RecordingPath) + if err != nil { + return "", err + } + return property.Value, nil +} + +func (r PropertyRepository) SendMail(to, subject, text string) { + propertiesMap := r.FindAllMap() + host := propertiesMap[constant.MailHost] + port := propertiesMap[constant.MailPort] + username := propertiesMap[constant.MailUsername] + password := propertiesMap[constant.MailPassword] + + if host == "" || port == "" || username == "" || password == "" { + logrus.Debugf("邮箱信息不完整,跳过发送邮件。") + return + } + + e := email.NewEmail() + e.From = "Next Terminal <" + username + ">" + e.To = []string{to} + e.Subject = subject + e.Text = []byte(text) + err := e.Send(host+":"+port, smtp.PlainAuth("", username, password, host)) + if err != nil { + logrus.Errorf("邮件发送失败: %v", err.Error()) + } +} diff --git a/server/repository/resource_sharer.go b/server/repository/resource_sharer.go new file mode 100644 index 0000000..ee3e56c --- /dev/null +++ b/server/repository/resource_sharer.go @@ -0,0 +1,193 @@ +package repository + +import ( + "github.com/labstack/echo/v4" + "github.com/pkg/errors" + "gorm.io/gorm" + "next-terminal/server/model" + "next-terminal/server/utils" +) + +type ResourceSharerRepository struct { + DB *gorm.DB +} + +func NewResourceSharerRepository(db *gorm.DB) *ResourceSharerRepository { + resourceSharerRepository = &ResourceSharerRepository{DB: db} + return resourceSharerRepository +} + +func (r *ResourceSharerRepository) FindUserIdsByResourceId(resourceId string) (o []string, err error) { + err = r.DB.Table("resource_sharers").Select("user_id").Where("resource_id = ?", resourceId).Find(&o).Error + if o == nil { + o = make([]string, 0) + } + return +} + +func (r *ResourceSharerRepository) OverwriteUserIdsByResourceId(resourceId, resourceType string, userIds []string) (err error) { + db := r.DB.Begin() + + var owner string + // 检查资产是否存在 + switch resourceType { + case "asset": + resource := model.Asset{} + err = db.Where("id = ?", resourceId).First(&resource).Error + owner = resource.Owner + case "command": + resource := model.Command{} + err = db.Where("id = ?", resourceId).First(&resource).Error + owner = resource.Owner + case "credential": + resource := model.Credential{} + err = db.Where("id = ?", resourceId).First(&resource).Error + owner = resource.Owner + } + + if err == gorm.ErrRecordNotFound { + return echo.NewHTTPError(404, "资源「"+resourceId+"」不存在") + } + + for i := range userIds { + if owner == userIds[i] { + return echo.NewHTTPError(400, "参数错误") + } + } + + db.Where("resource_id = ?", resourceId).Delete(&ResourceSharerRepository{}) + + for i := range userIds { + userId := userIds[i] + if len(userId) == 0 { + continue + } + id := utils.Sign([]string{resourceId, resourceType, userId}) + resource := &model.ResourceSharer{ + ID: id, + ResourceId: resourceId, + ResourceType: resourceType, + UserId: userId, + } + err = db.Create(resource).Error + if err != nil { + return err + } + } + db.Commit() + return nil +} + +func (r *ResourceSharerRepository) DeleteByUserIdAndResourceTypeAndResourceIdIn(userGroupId, userId, resourceType string, resourceIds []string) error { + db := r.DB + if userGroupId != "" { + db = db.Where("user_group_id = ?", userGroupId) + } + + if userId != "" { + db = db.Where("user_id = ?", userId) + } + + if resourceType != "" { + db = db.Where("resource_type = ?", resourceType) + } + + if resourceIds != nil { + db = db.Where("resource_id in ?", resourceIds) + } + + return db.Delete(&ResourceSharerRepository{}).Error +} + +func (r *ResourceSharerRepository) DeleteResourceSharerByResourceId(resourceId string) error { + return r.DB.Where("resource_id = ?", resourceId).Delete(&ResourceSharerRepository{}).Error +} + +func (r *ResourceSharerRepository) AddSharerResources(userGroupId, userId, resourceType string, resourceIds []string) error { + return r.DB.Transaction(func(tx *gorm.DB) (err error) { + + for i := range resourceIds { + resourceId := resourceIds[i] + + var owner string + // 检查资产是否存在 + switch resourceType { + case "asset": + resource := model.Asset{} + if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil { + return errors.Wrap(err, "find asset fail") + } + owner = resource.Owner + case "command": + resource := model.Command{} + if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil { + return errors.Wrap(err, "find command fail") + } + owner = resource.Owner + case "credential": + resource := model.Credential{} + if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil { + return errors.Wrap(err, "find credential fail") + + } + owner = resource.Owner + } + + if owner == userId { + return echo.NewHTTPError(400, "参数错误") + } + + id := utils.Sign([]string{resourceId, resourceType, userId, userGroupId}) + resource := &model.ResourceSharer{ + ID: id, + ResourceId: resourceId, + ResourceType: resourceType, + UserId: userId, + UserGroupId: userGroupId, + } + err = tx.Create(resource).Error + if err != nil { + return err + } + } + return nil + }) +} + +func (r *ResourceSharerRepository) FindAssetIdsByUserId(userId string) (assetIds []string, err error) { + // 查询当前用户创建的资产 + var ownerAssetIds, sharerAssetIds []string + asset := model.Asset{} + err = r.DB.Table(asset.TableName()).Select("id").Where("owner = ?", userId).Find(&ownerAssetIds).Error + if err != nil { + return nil, err + } + + // 查询其他用户授权给该用户的资产 + groupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId) + if err != nil { + return nil, err + } + + db := r.DB.Table("resource_sharers").Select("resource_id").Where("user_id = ?", userId) + if len(groupIds) > 0 { + db = db.Or("user_group_id in ?", groupIds) + } + err = db.Find(&sharerAssetIds).Error + if err != nil { + return nil, err + } + + // 合并查询到的资产ID + assetIds = make([]string, 0) + + if ownerAssetIds != nil { + assetIds = append(assetIds, ownerAssetIds...) + } + + if sharerAssetIds != nil { + assetIds = append(assetIds, sharerAssetIds...) + } + + return +} diff --git a/server/repository/session.go b/server/repository/session.go new file mode 100644 index 0000000..fc3e8cf --- /dev/null +++ b/server/repository/session.go @@ -0,0 +1,167 @@ +package repository + +import ( + "gorm.io/gorm" + "next-terminal/server/constant" + "next-terminal/server/model" + "os" + "path" + "time" +) + +type SessionRepository struct { + DB *gorm.DB +} + +func NewSessionRepository(db *gorm.DB) *SessionRepository { + sessionRepository = &SessionRepository{DB: db} + return sessionRepository +} + +func (r SessionRepository) Find(pageIndex, pageSize int, status, userId, clientIp, assetId, protocol string) (results []model.SessionVo, total int64, err error) { + + db := r.DB + var params []interface{} + + params = append(params, status) + + itemSql := "SELECT s.id,s.mode, s.protocol,s.recording, s.connection_id, s.asset_id, s.creator, s.client_ip, s.width, s.height, s.ip, s.port, s.username, s.status, s.connected_time, s.disconnected_time,s.code, s.message, a.name AS asset_name, u.nickname AS creator_name FROM sessions s LEFT JOIN assets a ON s.asset_id = a.id LEFT JOIN users u ON s.creator = u.id WHERE s.STATUS = ? " + countSql := "select count(*) from sessions as s where s.status = ? " + + if len(userId) > 0 { + itemSql += " and s.creator = ?" + countSql += " and s.creator = ?" + params = append(params, userId) + } + + if len(clientIp) > 0 { + itemSql += " and s.client_ip like ?" + countSql += " and s.client_ip like ?" + params = append(params, "%"+clientIp+"%") + } + + if len(assetId) > 0 { + itemSql += " and s.asset_id = ?" + countSql += " and s.asset_id = ?" + params = append(params, assetId) + } + + if len(protocol) > 0 { + itemSql += " and s.protocol = ?" + countSql += " and s.protocol = ?" + params = append(params, protocol) + } + + params = append(params, (pageIndex-1)*pageSize, pageSize) + itemSql += " order by s.connected_time desc LIMIT ?, ?" + + db.Raw(countSql, params...).Scan(&total) + + err = db.Raw(itemSql, params...).Scan(&results).Error + + if results == nil { + results = make([]model.SessionVo, 0) + } + return +} + +func (r SessionRepository) FindByStatus(status string) (o []model.Session, err error) { + err = r.DB.Where("status = ?", status).Find(&o).Error + return +} + +func (r SessionRepository) FindByStatusIn(statuses []string) (o []model.Session, err error) { + err = r.DB.Where("status in ?", statuses).Find(&o).Error + return +} + +func (r SessionRepository) FindOutTimeSessions(dayLimit int) (o []model.Session, err error) { + limitTime := time.Now().Add(time.Duration(-dayLimit*24) * time.Hour) + err = r.DB.Where("status = ? and connected_time < ?", constant.Disconnected, limitTime).Find(&o).Error + return +} + +func (r SessionRepository) Create(o *model.Session) (err error) { + err = r.DB.Create(o).Error + return +} + +func (r SessionRepository) FindById(id string) (o model.Session, err error) { + err = r.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r SessionRepository) FindByConnectionId(connectionId string) (o model.Session, err error) { + err = r.DB.Where("connection_id = ?", connectionId).First(&o).Error + return +} + +func (r SessionRepository) UpdateById(o *model.Session, id string) error { + o.ID = id + return r.DB.Updates(o).Error +} + +func (r SessionRepository) UpdateWindowSizeById(width, height int, id string) error { + session := model.Session{} + session.Width = width + session.Height = height + + return r.UpdateById(&session, id) +} + +func (r SessionRepository) DeleteById(id string) error { + return r.DB.Where("id = ?", id).Delete(&model.Session{}).Error +} + +func (r SessionRepository) DeleteByIds(sessionIds []string) error { + drivePath, err := propertyRepository.GetRecordingPath() + if err != nil { + return err + } + for i := range sessionIds { + if err := os.RemoveAll(path.Join(drivePath, sessionIds[i])); err != nil { + return err + } + if err := r.DeleteById(sessionIds[i]); err != nil { + return err + } + } + return nil +} + +func (r SessionRepository) DeleteByStatus(status string) error { + return r.DB.Where("status = ?", status).Delete(&model.Session{}).Error +} + +func (r SessionRepository) CountOnlineSession() (total int64, err error) { + err = r.DB.Where("status = ?", constant.Connected).Find(&model.Session{}).Count(&total).Error + return +} + +type D struct { + Day string `json:"day"` + Count int `json:"count"` + Protocol string `json:"protocol"` +} + +func (r SessionRepository) CountSessionByDay(day int) (results []D, err error) { + + today := time.Now().Format("20060102") + sql := "select t1.`day`, count(t2.id) as count\nfrom (\n SELECT @date := DATE_ADD(@date, INTERVAL - 1 DAY) day\n FROM (SELECT @date := DATE_ADD('" + today + "', INTERVAL + 1 DAY) FROM nums) as t0\n LIMIT ?\n )\n as t1\n left join\n (\n select DATE(s.connected_time) as day, s.id\n from sessions as s\n WHERE protocol = ? and DATE(connected_time) <= '" + today + "'\n AND DATE(connected_time) > DATE_SUB('" + today + "', INTERVAL ? DAY)\n ) as t2 on t1.day = t2.day\ngroup by t1.day" + + protocols := []string{"rdp", "ssh", "vnc", "telnet"} + + for i := range protocols { + var result []D + err = r.DB.Raw(sql, day, protocols[i], day).Scan(&result).Error + if err != nil { + return nil, err + } + for j := range result { + result[j].Protocol = protocols[i] + } + results = append(results, result...) + } + + return +} diff --git a/server/repository/user.go b/server/repository/user.go index b133875..7391644 100644 --- a/server/repository/user.go +++ b/server/repository/user.go @@ -9,6 +9,11 @@ type UserRepository struct { DB *gorm.DB } +func NewUserRepository(db *gorm.DB) *UserRepository { + userRepository = &UserRepository{DB: db} + return userRepository +} + func (r UserRepository) FindAll() (o []model.User) { if r.DB.Find(&o).Error != nil { return nil diff --git a/server/repository/user_group.go b/server/repository/user_group.go new file mode 100644 index 0000000..64323c1 --- /dev/null +++ b/server/repository/user_group.go @@ -0,0 +1,140 @@ +package repository + +import ( + "gorm.io/gorm" + "next-terminal/server/global" + "next-terminal/server/model" + "next-terminal/server/utils" +) + +type UserGroupRepository struct { + DB *gorm.DB +} + +func NewUserGroupRepository(db *gorm.DB) *UserGroupRepository { + userGroupRepository = &UserGroupRepository{DB: db} + return userGroupRepository +} + +func (r UserGroupRepository) FindAll() (o []model.UserGroup) { + if r.DB.Find(&o).Error != nil { + return nil + } + return +} + +func (r UserGroupRepository) Find(pageIndex, pageSize int, name, order, field string) (o []model.UserGroup, total int64, err error) { + db := r.DB.Table("user_groups").Select("user_groups.id, user_groups.name, user_groups.created, count(resource_sharers.user_group_id) as asset_count").Joins("left join resource_sharers on user_groups.id = resource_sharers.user_group_id and resource_sharers.resource_type = 'asset'").Group("user_groups.id") + dbCounter := r.DB.Table("user_groups") + if len(name) > 0 { + db = db.Where("user_groups.name like ?", "%"+name+"%") + dbCounter = dbCounter.Where("name like ?", "%"+name+"%") + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "ascend" { + order = "asc" + } else { + order = "desc" + } + + if field == "name" { + field = "name" + } else { + field = "created" + } + + err = db.Order("user_groups." + field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error + if o == nil { + o = make([]model.UserGroup, 0) + } + return +} + +func (r UserGroupRepository) FindById(id string) (o model.UserGroup, err error) { + err = global.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r UserGroupRepository) FindUserGroupIdsByUserId(userId string) (o []string, err error) { + // 先查询用户所在的用户 + err = r.DB.Table("user_group_members").Select("user_group_id").Where("user_id = ?", userId).Find(&o).Error + return +} + +func (r UserGroupRepository) FindUserGroupMembersByUserGroupId(userGroupId string) (o []string, err error) { + err = r.DB.Table("user_group_members").Select("user_id").Where("user_group_id = ?", userGroupId).Find(&o).Error + return +} + +func (r UserGroupRepository) Create(o *model.UserGroup, members []string) (err error) { + return r.DB.Transaction(func(tx *gorm.DB) error { + err = tx.Create(o).Error + if err != nil { + return err + } + + if members != nil { + userGroupId := o.ID + err = AddUserGroupMembers(tx, members, userGroupId) + if err != nil { + return err + } + } + return err + }) +} + +func (r UserGroupRepository) Update(o *model.UserGroup, members []string, id string) error { + return r.DB.Transaction(func(tx *gorm.DB) error { + o.ID = id + err := tx.Updates(o).Error + if err != nil { + return err + } + + err = tx.Where("user_group_id = ?", id).Delete(&model.UserGroupMember{}).Error + if err != nil { + return err + } + if members != nil { + userGroupId := o.ID + err = AddUserGroupMembers(tx, members, userGroupId) + if err != nil { + return err + } + } + return err + }) +} + +func (r UserGroupRepository) DeleteById(id string) { + r.DB.Where("id = ?", id).Delete(&model.UserGroup{}) + r.DB.Where("user_group_id = ?", id).Delete(&model.UserGroupMember{}) +} + +func AddUserGroupMembers(tx *gorm.DB, userIds []string, userGroupId string) error { + userRepository := NewUserRepository(tx) + for i := range userIds { + userId := userIds[i] + _, err := userRepository.FindById(userId) + if err != nil { + return err + } + + userGroupMember := model.UserGroupMember{ + ID: utils.Sign([]string{userGroupId, userId}), + UserId: userId, + UserGroupId: userGroupId, + } + err = tx.Create(&userGroupMember).Error + if err != nil { + return err + } + } + return nil +} diff --git a/server/service/job.go b/server/service/job.go new file mode 100644 index 0000000..731f5a9 --- /dev/null +++ b/server/service/job.go @@ -0,0 +1,324 @@ +package service + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/robfig/cron/v3" + "github.com/sirupsen/logrus" + "next-terminal/server/api" + "next-terminal/server/constant" + "next-terminal/server/global" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/term" + "next-terminal/server/utils" + "strings" + "time" +) + +type JobService struct { + jobRepository *repository.JobRepository + jobLogRepository *repository.JobLogRepository + assetRepository *repository.AssetRepository + credentialRepository *repository.CredentialRepository +} + +func NewJobService(jobRepository *repository.JobRepository, jobLogRepository *repository.JobLogRepository, assetRepository *repository.AssetRepository, credentialRepository *repository.CredentialRepository) *JobService { + return &JobService{jobRepository: jobRepository, jobLogRepository: jobLogRepository, assetRepository: assetRepository, credentialRepository: credentialRepository} +} + +func (r JobService) ChangeJobStatusById(id, status string) error { + job, err := r.jobRepository.FindById(id) + if err != nil { + return err + } + if status == constant.JobStatusRunning { + j, err := getJob(&job, &r) + if err != nil { + return err + } + entryID, err := global.Cron.AddJob(job.Cron, j) + if err != nil { + return err + } + logrus.Debugf("开启计划任务「%v」,运行中计划任务数量「%v」", job.Name, len(global.Cron.Entries())) + + jobForUpdate := model.Job{ID: id, Status: constant.JobStatusRunning, CronJobId: int(entryID)} + + return r.jobRepository.UpdateById(&jobForUpdate) + } else { + global.Cron.Remove(cron.EntryID(job.CronJobId)) + logrus.Debugf("关闭计划任务「%v」,运行中计划任务数量「%v」", job.Name, len(global.Cron.Entries())) + jobForUpdate := model.Job{ID: id, Status: constant.JobStatusNotRunning} + return r.jobRepository.UpdateById(&jobForUpdate) + } +} + +func getJob(j *model.Job, jobService *JobService) (job cron.Job, err error) { + switch j.Func { + case constant.FuncCheckAssetStatusJob: + job = CheckAssetStatusJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata, jobService: jobService} + case constant.FuncShellJob: + job = ShellJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata, jobService: jobService} + default: + return nil, errors.New("未识别的任务") + } + return job, err +} + +type CheckAssetStatusJob struct { + ID string + Mode string + ResourceIds string + Metadata string + jobService *JobService +} + +func (r CheckAssetStatusJob) Run() { + if r.ID == "" { + return + } + + var assets []model.Asset + if r.Mode == constant.JobModeAll { + assets, _ = r.jobService.assetRepository.FindAll() + } else { + assets, _ = r.jobService.assetRepository.FindByIds(strings.Split(r.ResourceIds, ",")) + } + + if len(assets) == 0 { + return + } + + msgChan := make(chan string) + for i := range assets { + asset := assets[i] + go func() { + t1 := time.Now() + active := utils.Tcping(asset.IP, asset.Port) + elapsed := time.Since(t1) + msg := fmt.Sprintf("资产「%v」存活状态检测完成,存活「%v」,耗时「%v」", asset.Name, active, elapsed) + + _ = r.jobService.assetRepository.UpdateActiveById(active, asset.ID) + logrus.Infof(msg) + msgChan <- msg + }() + } + + var message = "" + for i := 0; i < len(assets); i++ { + message += <-msgChan + "\n" + } + + _ = r.jobService.jobRepository.UpdateLastUpdatedById(r.ID) + jobLog := model.JobLog{ + ID: utils.UUID(), + JobId: r.ID, + Timestamp: utils.NowJsonTime(), + Message: message, + } + + _ = r.jobService.jobLogRepository.Create(&jobLog) +} + +type ShellJob struct { + ID string + Mode string + ResourceIds string + Metadata string + jobService *JobService +} + +type MetadataShell struct { + Shell string +} + +func (r ShellJob) Run() { + if r.ID == "" { + return + } + + var assets []model.Asset + if r.Mode == constant.JobModeAll { + assets, _ = r.jobService.assetRepository.FindByProtocol("ssh") + } else { + assets, _ = r.jobService.assetRepository.FindByProtocolAndIds("ssh", strings.Split(r.ResourceIds, ",")) + } + + if len(assets) == 0 { + return + } + + var metadataShell MetadataShell + err := json.Unmarshal([]byte(r.Metadata), &metadataShell) + if err != nil { + logrus.Errorf("JSON数据解析失败 %v", err) + return + } + + msgChan := make(chan string) + for i := range assets { + asset, err := r.jobService.assetRepository.FindById(assets[i].ID) + if err != nil { + msgChan <- fmt.Sprintf("资产「%v」Shell执行失败,查询数据异常「%v」", assets[i].Name, err.Error()) + return + } + + var ( + username = asset.Username + password = asset.Password + privateKey = asset.PrivateKey + passphrase = asset.Passphrase + ip = asset.IP + port = asset.Port + ) + + if asset.AccountType == "credential" { + credential, err := r.jobService.credentialRepository.FindById(asset.CredentialId) + if err != nil { + msgChan <- fmt.Sprintf("资产「%v」Shell执行失败,查询授权凭证数据异常「%v」", assets[i].Name, err.Error()) + return + } + + if credential.Type == constant.Custom { + username = credential.Username + password = credential.Password + } else { + username = credential.Username + privateKey = credential.PrivateKey + passphrase = credential.Passphrase + } + } + + go func() { + + t1 := time.Now() + result, err := ExecCommandBySSH(metadataShell.Shell, ip, port, username, password, privateKey, passphrase) + elapsed := time.Since(t1) + var msg string + if err != nil { + msg = fmt.Sprintf("资产「%v」Shell执行失败,返回值「%v」,耗时「%v」", asset.Name, err.Error(), elapsed) + logrus.Infof(msg) + } else { + msg = fmt.Sprintf("资产「%v」Shell执行成功,返回值「%v」,耗时「%v」", asset.Name, result, elapsed) + logrus.Infof(msg) + } + + msgChan <- msg + }() + } + + var message = "" + for i := 0; i < len(assets); i++ { + message += <-msgChan + "\n" + } + + _ = r.jobService.jobRepository.UpdateLastUpdatedById(r.ID) + jobLog := model.JobLog{ + ID: utils.UUID(), + JobId: r.ID, + Timestamp: utils.NowJsonTime(), + Message: message, + } + + _ = r.jobService.jobLogRepository.Create(&jobLog) +} + +func ExecCommandBySSH(cmd, ip string, port int, username, password, privateKey, passphrase string) (result string, err error) { + sshClient, err := term.NewSshClient(ip, port, username, password, privateKey, passphrase) + if err != nil { + return "", err + } + + session, err := sshClient.NewSession() + if err != nil { + return "", err + } + defer session.Close() + //执行远程命令 + combo, err := session.CombinedOutput(cmd) + if err != nil { + return "", err + } + return string(combo), nil +} + +func (r JobService) ExecJobById(id string) (err error) { + job, err := r.jobRepository.FindById(id) + if err != nil { + return err + } + j, err := getJob(&job, &r) + if err != nil { + return err + } + j.Run() + return nil +} + +func (r JobService) InitJob() error { + jobs, _ := r.jobRepository.FindByFunc(constant.FuncCheckAssetStatusJob) + if jobs == nil || len(jobs) == 0 { + job := model.Job{ + ID: utils.UUID(), + Name: "资产状态检测", + Func: constant.FuncCheckAssetStatusJob, + Cron: "0 0 0/1 * * ?", + Mode: constant.JobModeAll, + Status: constant.JobStatusRunning, + Created: utils.NowJsonTime(), + Updated: utils.NowJsonTime(), + } + if err := r.jobRepository.Create(&job); err != nil { + return err + } + logrus.Debugf("创建计划任务「%v」cron「%v」", job.Name, job.Cron) + } else { + for i := range jobs { + if jobs[i].Status == constant.JobStatusRunning { + err := r.ChangeJobStatusById(jobs[i].ID, constant.JobStatusRunning) + if err != nil { + return err + } + logrus.Debugf("启动计划任务「%v」cron「%v」", jobs[i].Name, jobs[i].Cron) + } + } + } + return nil +} + +// TODO 可能存在循环引用 +func (r UserService) ReloadToken() error { + loginLogs, err := r.loginLogRepository.FindAliveLoginLogs() + if err != nil { + return err + } + + for i := range loginLogs { + loginLog := loginLogs[i] + token := loginLog.ID + user, err := r.userRepository.FindById(loginLog.UserId) + if err != nil { + logrus.Debugf("用户「%v」获取失败,忽略", loginLog.UserId) + continue + } + + authorization := api.Authorization{ + Token: token, + Remember: loginLog.Remember, + User: user, + } + + cacheKey := api.BuildCacheKeyByToken(token) + + if authorization.Remember { + // 记住登录有效期两周 + global.Cache.Set(cacheKey, authorization, api.RememberEffectiveTime) + } else { + global.Cache.Set(cacheKey, authorization, api.NotRememberEffectiveTime) + } + logrus.Debugf("重新加载用户「%v」授权Token「%v」到缓存", user.Nickname, token) + } + return nil +} diff --git a/server/handle/runner.go b/server/service/property.go similarity index 53% rename from server/handle/runner.go rename to server/service/property.go index 1afcb76..62f5119 100644 --- a/server/handle/runner.go +++ b/server/service/property.go @@ -1,97 +1,30 @@ -package handle +package service import ( - "os" - "strconv" - "time" - - "next-terminal/server/constant" "next-terminal/server/guacd" "next-terminal/server/model" + "next-terminal/server/repository" "next-terminal/server/utils" - - "github.com/sirupsen/logrus" + "os" ) -func RunTicker() { - - // 每隔一小时删除一次未使用的会话信息 - unUsedSessionTicker := time.NewTicker(time.Minute * 60) - go func() { - for range unUsedSessionTicker.C { - sessions, _ := model.FindSessionByStatusIn([]string{constant.NoConnect, constant.Connecting}) - if len(sessions) > 0 { - now := time.Now() - for i := range sessions { - if now.Sub(sessions[i].ConnectedTime.Time) > time.Hour*1 { - _ = model.DeleteSessionById(sessions[i].ID) - s := sessions[i].Username + "@" + sessions[i].IP + ":" + strconv.Itoa(sessions[i].Port) - logrus.Infof("会话「%v」ID「%v」超过1小时未打开,已删除。", s, sessions[i].ID) - } - } - } - } - }() - - // 每日凌晨删除超过时长限制的会话 - timeoutSessionTicker := time.NewTicker(time.Hour * 24) - go func() { - for range timeoutSessionTicker.C { - property, err := model.FindPropertyByName("session-saved-limit") - if err != nil { - return - } - if property.Value == "" || property.Value == "-" { - return - } - limit, err := strconv.Atoi(property.Value) - if err != nil { - return - } - sessions, err := model.FindOutTimeSessions(limit) - if err != nil { - return - } - - if len(sessions) > 0 { - var sessionIds []string - for i := range sessions { - sessionIds = append(sessionIds, sessions[i].ID) - } - err := model.DeleteSessionByIds(sessionIds) - if err != nil { - logrus.Errorf("删除离线会话失败 %v", err) - } - } - } - }() +type PropertyService struct { + propertyRepository *repository.PropertyRepository } -func RunDataFix() { - sessions, _ := model.FindSessionByStatus(constant.Connected) - if sessions == nil { - return - } - - for i := range sessions { - session := model.Session{ - Status: constant.Disconnected, - DisconnectedTime: utils.NowJsonTime(), - } - - _ = model.UpdateSessionById(&session, sessions[i].ID) - } +func NewPropertyService(propertyRepository *repository.PropertyRepository) *PropertyService { + return &PropertyService{propertyRepository: propertyRepository} } -func InitProperties() error { - propertyMap := model.FindAllPropertiesMap() +func (r PropertyService) InitProperties() error { + propertyMap := r.propertyRepository.FindAllMap() if len(propertyMap[guacd.Host]) == 0 { property := model.Property{ Name: guacd.Host, Value: "127.0.0.1", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -101,7 +34,7 @@ func InitProperties() error { Name: guacd.Port, Value: "4822", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -111,7 +44,7 @@ func InitProperties() error { Name: guacd.EnableRecording, Value: "true", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -127,7 +60,7 @@ func InitProperties() error { return err } } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -137,7 +70,7 @@ func InitProperties() error { Name: guacd.CreateRecordingPath, Value: "true", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -147,7 +80,7 @@ func InitProperties() error { Name: guacd.DriveName, Value: "File-System", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -165,7 +98,7 @@ func InitProperties() error { return err } } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -175,7 +108,7 @@ func InitProperties() error { Name: guacd.FontName, Value: "menlo", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -185,7 +118,7 @@ func InitProperties() error { Name: guacd.FontSize, Value: "12", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -195,7 +128,7 @@ func InitProperties() error { Name: guacd.ColorScheme, Value: "gray-black", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -205,7 +138,7 @@ func InitProperties() error { Name: guacd.EnableDrive, Value: "true", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -215,7 +148,7 @@ func InitProperties() error { Name: guacd.EnableWallpaper, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -225,7 +158,7 @@ func InitProperties() error { Name: guacd.EnableTheming, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -235,7 +168,7 @@ func InitProperties() error { Name: guacd.EnableFontSmoothing, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -245,7 +178,7 @@ func InitProperties() error { Name: guacd.EnableFullWindowDrag, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -255,7 +188,7 @@ func InitProperties() error { Name: guacd.EnableDesktopComposition, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -265,7 +198,7 @@ func InitProperties() error { Name: guacd.EnableMenuAnimations, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -275,7 +208,7 @@ func InitProperties() error { Name: guacd.DisableBitmapCaching, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -285,7 +218,7 @@ func InitProperties() error { Name: guacd.DisableOffscreenCaching, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -295,7 +228,7 @@ func InitProperties() error { Name: guacd.DisableGlyphCaching, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } diff --git a/server/service/session.go b/server/service/session.go new file mode 100644 index 0000000..672db5a --- /dev/null +++ b/server/service/session.go @@ -0,0 +1,32 @@ +package service + +import ( + "next-terminal/server/constant" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/utils" +) + +type SessionService struct { + sessionRepository *repository.SessionRepository +} + +func NewSessionService(sessionRepository *repository.SessionRepository) *SessionService { + return &SessionService{sessionRepository: sessionRepository} +} + +func (r SessionService) Fix() { + sessions, _ := r.sessionRepository.FindByStatus(constant.Connected) + if sessions == nil { + return + } + + for i := range sessions { + session := model.Session{ + Status: constant.Disconnected, + DisconnectedTime: utils.NowJsonTime(), + } + + _ = r.sessionRepository.UpdateById(&session, sessions[i].ID) + } +} diff --git a/server/service/user.go b/server/service/user.go new file mode 100644 index 0000000..25a4ed1 --- /dev/null +++ b/server/service/user.go @@ -0,0 +1,104 @@ +package service + +import ( + "github.com/sirupsen/logrus" + "next-terminal/server/constant" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/utils" +) + +type UserService struct { + userRepository *repository.UserRepository + loginLogRepository *repository.LoginLogRepository +} + +func NewUserService(userRepository *repository.UserRepository) *UserService { + return &UserService{userRepository: userRepository} +} + +func (r UserService) InitUser() (err error) { + + users := r.userRepository.FindAll() + + if len(users) == 0 { + initPassword := "admin" + var pass []byte + if pass, err = utils.Encoder.Encode([]byte(initPassword)); err != nil { + return err + } + + user := model.User{ + ID: utils.UUID(), + Username: "admin", + Password: string(pass), + Nickname: "超级管理员", + Type: constant.TypeAdmin, + Created: utils.NowJsonTime(), + } + if err := r.userRepository.Create(&user); err != nil { + return err + } + logrus.Infof("初始用户创建成功,账号:「%v」密码:「%v」", user.Username, initPassword) + } else { + for i := range users { + // 修正默认用户类型为管理员 + if users[i].Type == "" { + user := model.User{ + Type: constant.TypeAdmin, + ID: users[i].ID, + } + if err := r.userRepository.Update(&user); err != nil { + return err + } + logrus.Infof("自动修正用户「%v」ID「%v」类型为管理员", users[i].Nickname, users[i].ID) + } + } + } + return nil +} + +func (r UserService) FixedOnlineState() error { + // 修正用户登录状态 + onlineUsers, err := r.userRepository.FindOnlineUsers() + if err != nil { + return err + } + for i := range onlineUsers { + logs, err := r.loginLogRepository.FindAliveLoginLogsByUserId(onlineUsers[i].ID) + if err != nil { + return err + } + if len(logs) == 0 { + if err := r.userRepository.UpdateOnline(onlineUsers[i].ID, false); err != nil { + return err + } + } + } + return nil +} + +func (r UserService) Logout(token string) (err error) { + + loginLog, err := r.loginLogRepository.FindById(token) + if err != nil { + logrus.Warnf("登录日志「%v」获取失败", token) + return + } + + loginLogForUpdate := &model.LoginLog{LogoutTime: utils.NowJsonTime(), ID: token} + err = r.loginLogRepository.Update(loginLogForUpdate) + if err != nil { + return err + } + + loginLogs, err := r.loginLogRepository.FindAliveLoginLogsByUserId(loginLog.UserId) + if err != nil { + return + } + + if len(loginLogs) == 0 { + err = r.userRepository.UpdateOnline(loginLog.UserId, false) + } + return +}