🐶 重构部分用户数据库操作代码

This commit is contained in:
dushixiang 2021-03-18 00:07:30 +08:00
parent e1cd73260a
commit 0150361054
50 changed files with 478 additions and 453 deletions

289
main.go
View File

@ -4,18 +4,19 @@ import (
"bytes"
"fmt"
"io"
"next-terminal/server/repository"
"os"
"strconv"
"strings"
"time"
"next-terminal/pkg/api"
"next-terminal/pkg/config"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/handle"
"next-terminal/pkg/model"
"next-terminal/pkg/utils"
"next-terminal/server/api"
"next-terminal/server/config"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/handle"
"next-terminal/server/model"
"next-terminal/server/utils"
nested "github.com/antonfisher/nested-logrus-formatter"
"github.com/labstack/gommon/log"
@ -30,6 +31,11 @@ import (
const Version = "v0.3.3"
var (
db *gorm.DB
userRepository repository.UserRepository
)
func main() {
err := Run()
if err != nil {
@ -64,151 +70,31 @@ func Run() error {
logrus.SetOutput(io.MultiWriter(writer1, writer2, writer3))
global.Config, err = config.SetupConfig()
if err != nil {
return err
}
global.Config = config.SetupConfig()
db = SetupDB()
var logMode logger.Interface
if global.Config.Debug {
logMode = logger.Default.LogMode(logger.Info)
} else {
logMode = logger.Default.LogMode(logger.Silent)
}
fmt.Printf("当前数据库模式为:%v\n", global.Config.DB)
if global.Config.DB == "mysql" {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
global.Config.Mysql.Username,
global.Config.Mysql.Password,
global.Config.Mysql.Hostname,
global.Config.Mysql.Port,
global.Config.Mysql.Database,
)
global.DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: logMode,
})
} else {
global.DB, err = gorm.Open(sqlite.Open(global.Config.Sqlite.File), &gorm.Config{
Logger: logMode,
})
}
if err != nil {
logrus.Errorf("连接数据库异常:%v", err.Error())
return err
}
// 初始化 repository
global.DB = db
userRepository = repository.UserRepository{DB: db}
if global.Config.ResetPassword != "" {
user, err := model.FindUserByUsername(global.Config.ResetPassword)
if err != nil {
return err
}
password := "next-terminal"
passwd, err := utils.Encoder.Encode([]byte(password))
if err != nil {
return err
}
u := &model.User{
Password: string(passwd),
}
model.UpdateUserById(u, user.ID)
logrus.Debugf("用户「%v」密码初始化为: %v", user.Username, password)
return nil
return ResetPassword()
}
if err := global.DB.AutoMigrate(&model.User{}); err != nil {
if err := global.DB.AutoMigrate(&model.User{}, &model.Asset{}, &model.AssetAttribute{}, &model.Session{}, &model.Command{},
&model.Credential{}, &model.Property{}, &model.ResourceSharer{}, &model.UserGroup{}, &model.UserGroupMember{},
&model.LoginLog{}, &model.Num{}, &model.Job{}, &model.JobLog{}, &model.AccessSecurity{}); err != nil {
return err
}
users := model.FindAllUser()
if len(users) == 0 {
initPassword := "admin"
var pass []byte
if pass, err = utils.Encoder.Encode([]byte(initPassword)); err != nil {
return err
}
user := model.User{
ID: utils.UUID(),
Username: "admin",
Password: string(pass),
Nickname: "超级管理员",
Type: constant.TypeAdmin,
Created: utils.NowJsonTime(),
}
if err := model.CreateNewUser(&user); err != nil {
return err
}
logrus.Infof("初始用户创建成功,账号:「%v」密码「%v」", user.Username, initPassword)
} else {
for i := range users {
// 修正默认用户类型为管理员
if users[i].Type == "" {
user := model.User{
Type: constant.TypeAdmin,
}
model.UpdateUserById(&user, users[i].ID)
logrus.Infof("自动修正用户「%v」ID「%v」类型为管理员", users[i].Nickname, users[i].ID)
}
}
if err := InitDBData(); err != nil {
return err
}
if err := global.DB.AutoMigrate(&model.Asset{}); err != nil {
return err
}
if err := global.DB.AutoMigrate(&model.AssetAttribute{}); err != nil {
return err
}
if err := global.DB.AutoMigrate(&model.Session{}); err != nil {
return err
}
if err := global.DB.AutoMigrate(&model.Command{}); err != nil {
return err
}
if err := global.DB.AutoMigrate(&model.Credential{}); err != nil {
return err
}
if err := global.DB.AutoMigrate(&model.Property{}); err != nil {
return err
}
if err := global.DB.AutoMigrate(&model.ResourceSharer{}); err != nil {
return err
}
if err := global.DB.AutoMigrate(&model.UserGroup{}); err != nil {
return err
}
if err := global.DB.AutoMigrate(&model.UserGroupMember{}); err != nil {
return err
}
if err := global.DB.AutoMigrate(&model.LoginLog{}); err != nil {
return err
}
if err := global.DB.AutoMigrate(&model.Num{}); err != nil {
return err
}
if err := global.DB.AutoMigrate(&model.Job{}); err != nil {
return err
}
if err := global.DB.AutoMigrate(&model.JobLog{}); err != nil {
return err
}
if err := global.DB.AutoMigrate(&model.AccessSecurity{}); err != nil {
return err
}
if err := api.ReloadAccessSecurity(); err != nil {
return err
}
if len(model.FindAllTemp()) == 0 {
for i := 0; i <= 30; i++ {
if err := model.CreateNewTemp(&model.Num{I: strconv.Itoa(i)}); err != nil {
return err
}
}
}
// 配置缓存器
global.Cache = cache.New(5*time.Minute, 10*time.Minute)
global.Cache.OnEvicted(func(key string, value interface{}) {
@ -225,6 +111,68 @@ func Run() error {
global.Cron = cron.New(cron.WithSeconds()) //精确到秒
global.Cron.Start()
e := api.SetupRoutes(userRepository)
if err := handle.InitProperties(); err != nil {
return err
}
// 启动定时任务
go handle.RunTicker()
go handle.RunDataFix()
if global.Config.Server.Cert != "" && global.Config.Server.Key != "" {
return e.StartTLS(global.Config.Server.Addr, global.Config.Server.Cert, global.Config.Server.Key)
} else {
return e.Start(global.Config.Server.Addr)
}
}
func InitDBData() (err error) {
users := userRepository.FindAll()
if len(users) == 0 {
initPassword := "admin"
var pass []byte
if pass, err = utils.Encoder.Encode([]byte(initPassword)); err != nil {
return err
}
user := model.User{
ID: utils.UUID(),
Username: "admin",
Password: string(pass),
Nickname: "超级管理员",
Type: constant.TypeAdmin,
Created: utils.NowJsonTime(),
}
if err := userRepository.Create(&user); err != nil {
return err
}
logrus.Infof("初始用户创建成功,账号:「%v」密码「%v」", user.Username, initPassword)
} else {
for i := range users {
// 修正默认用户类型为管理员
if users[i].Type == "" {
user := model.User{
Type: constant.TypeAdmin,
ID: users[i].ID,
}
if err := userRepository.Update(&user); err != nil {
return err
}
logrus.Infof("自动修正用户「%v」ID「%v」类型为管理员", users[i].Nickname, users[i].ID)
}
}
}
if len(model.FindAllTemp()) == 0 {
for i := 0; i <= 30; i++ {
if err := model.CreateNewTemp(&model.Num{I: strconv.Itoa(i)}); err != nil {
return err
}
}
}
jobs, err := model.FindJobByFunc(constant.FuncCheckAssetStatusJob)
if err != nil {
return err
@ -264,7 +212,7 @@ func Run() error {
for i := range loginLogs {
loginLog := loginLogs[i]
token := loginLog.ID
user, err := model.FindUserById(loginLog.UserId)
user, err := userRepository.FindById(loginLog.UserId)
if err != nil {
logrus.Debugf("用户「%v」获取失败忽略", loginLog.UserId)
continue
@ -288,7 +236,7 @@ func Run() error {
}
// 修正用户登录状态
onlineUsers, err := model.FindOnlineUsers()
onlineUsers, err := userRepository.FindOnlineUsers()
if err != nil {
return err
}
@ -298,24 +246,67 @@ func Run() error {
return err
}
if len(logs) == 0 {
if err := model.UpdateUserOnline(false, onlineUsers[i].ID); err != nil {
if err := userRepository.UpdateOnline(onlineUsers[i].ID, false); err != nil {
return err
}
}
}
e := api.SetupRoutes()
if err := handle.InitProperties(); err != nil {
return nil
}
func ResetPassword() error {
user, err := userRepository.FindByUsername(global.Config.ResetPassword)
if err != nil {
return err
}
// 启动定时任务
go handle.RunTicker()
go handle.RunDataFix()
password := "next-terminal"
passwd, err := utils.Encoder.Encode([]byte(password))
if err != nil {
return err
}
u := &model.User{
Password: string(passwd),
ID: user.ID,
}
if err := userRepository.Update(u); err != nil {
return err
}
logrus.Debugf("用户「%v」密码初始化为: %v", user.Username, password)
return nil
}
if global.Config.Server.Cert != "" && global.Config.Server.Key != "" {
return e.StartTLS(global.Config.Server.Addr, global.Config.Server.Cert, global.Config.Server.Key)
func SetupDB() *gorm.DB {
var logMode logger.Interface
if global.Config.Debug {
logMode = logger.Default.LogMode(logger.Info)
} else {
return e.Start(global.Config.Server.Addr)
logMode = logger.Default.LogMode(logger.Silent)
}
fmt.Printf("当前数据库模式为:%v\n", global.Config.DB)
var err error
var db *gorm.DB
if global.Config.DB == "mysql" {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
global.Config.Mysql.Username,
global.Config.Mysql.Password,
global.Config.Mysql.Hostname,
global.Config.Mysql.Port,
global.Config.Mysql.Database,
)
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: logMode,
})
} else {
db, err = gorm.Open(sqlite.Open(global.Config.Sqlite.File), &gorm.Config{
Logger: logMode,
})
}
if err != nil {
logrus.WithError(err).Panic("连接数据库异常")
}
return db
}

View File

@ -1,150 +0,0 @@
package model
import (
"reflect"
"next-terminal/pkg/global"
"next-terminal/pkg/utils"
)
type User struct {
ID string `gorm:"primary_key" json:"id"`
Username string `gorm:"index" json:"username"`
Password string `json:"password"`
Nickname string `json:"nickname"`
TOTPSecret string `json:"-"`
Online bool `json:"online"`
Enabled bool `json:"enabled"`
Created utils.JsonTime `json:"created"`
Type string `json:"type"`
Mail string `json:"mail"`
}
type UserVo struct {
ID string `json:"id"`
Username string `json:"username"`
Nickname string `json:"nickname"`
TOTPSecret string `json:"totpSecret"`
Mail string `json:"mail"`
Online bool `json:"online"`
Enabled bool `json:"enabled"`
Created utils.JsonTime `json:"created"`
Type string `json:"type"`
SharerAssetCount int64 `json:"sharerAssetCount"`
}
func (r *User) TableName() string {
return "users"
}
func (r *User) IsEmpty() bool {
return reflect.DeepEqual(r, User{})
}
func FindAllUser() (o []User) {
if global.DB.Find(&o).Error != nil {
return nil
}
return
}
func FindPageUser(pageIndex, pageSize int, username, nickname, mail, order, field string) (o []UserVo, total int64, err error) {
db := global.DB.Table("users").Select("users.id,users.username,users.nickname,users.mail,users.online,users.enabled,users.created,users.type, count(resource_sharers.user_id) as sharer_asset_count, users.totp_secret").Joins("left join resource_sharers on users.id = resource_sharers.user_id and resource_sharers.resource_type = 'asset'").Group("users.id")
dbCounter := global.DB.Table("users")
if len(username) > 0 {
db = db.Where("users.username like ?", "%"+username+"%")
dbCounter = dbCounter.Where("username like ?", "%"+username+"%")
}
if len(nickname) > 0 {
db = db.Where("users.nickname like ?", "%"+nickname+"%")
dbCounter = dbCounter.Where("nickname like ?", "%"+nickname+"%")
}
if len(mail) > 0 {
db = db.Where("users.mail like ?", "%"+mail+"%")
dbCounter = dbCounter.Where("mail like ?", "%"+mail+"%")
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "username" {
field = "username"
} else if field == "nickname" {
field = "nickname"
} else {
field = "created"
}
err = db.Order("users." + field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error
if o == nil {
o = make([]UserVo, 0)
}
for i := 0; i < len(o); i++ {
if o[i].TOTPSecret == "" || o[i].TOTPSecret == "-" {
o[i].TOTPSecret = "0"
} else {
o[i].TOTPSecret = "1"
}
}
return
}
func CreateNewUser(o *User) (err error) {
err = global.DB.Create(o).Error
return
}
func FindUserById(id string) (o User, err error) {
err = global.DB.Where("id = ?", id).First(&o).Error
return
}
func FindUserByIdIn(ids []string) (o []User, err error) {
err = global.DB.Where("id in ?", ids).First(&o).Error
return
}
func FindUserByUsername(username string) (o User, err error) {
err = global.DB.Where("username = ?", username).First(&o).Error
return
}
func UpdateUserById(o *User, id string) {
o.ID = id
global.DB.Updates(o)
}
func UpdateUserOnline(online bool, id string) (err error) {
sql := "update users set online = ? where id = ?"
err = global.DB.Exec(sql, online, id).Error
return
}
func FindOnlineUsers() (o []User, err error) {
err = global.DB.Where("online = ?", true).Find(&o).Error
return
}
func DeleteUserById(id string) {
global.DB.Where("id = ?", id).Delete(&User{})
// 删除用户组中的用户关系
global.DB.Where("user_id = ?", id).Delete(&UserGroupMember{})
// 删除用户分享到的资产
global.DB.Where("user_id = ?", id).Delete(&ResourceSharer{})
}
func CountOnlineUser() (total int64, err error) {
err = global.DB.Where("online = ?", true).Find(&User{}).Count(&total).Error
return
}

View File

@ -4,10 +4,10 @@ import (
"strings"
"time"
"next-terminal/pkg/global"
"next-terminal/pkg/model"
"next-terminal/pkg/totp"
"next-terminal/pkg/utils"
"next-terminal/server/global"
"next-terminal/server/model"
"next-terminal/server/totp"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)
@ -46,7 +46,7 @@ func LoginEndpoint(c echo.Context) error {
return err
}
user, err := model.FindUserByUsername(loginAccount.Username)
user, err := userRepository.FindByUsername(loginAccount.Username)
// 存储登录失败次数信息
loginFailCountKey := loginAccount.Username
@ -116,8 +116,9 @@ func LoginSuccess(c echo.Context, loginAccount LoginAccount, user model.User) (t
}
// 修改登录状态
model.UpdateUserById(&model.User{Online: true}, user.ID)
return token, nil
err = userRepository.Update(&model.User{Online: true, ID: user.ID})
return token, err
}
func BuildCacheKeyByToken(token string) string {
@ -147,7 +148,7 @@ func loginWithTotpEndpoint(c echo.Context) error {
return Fail(c, -1, "登录失败次数过多,请稍后再试")
}
user, err := model.FindUserByUsername(loginAccount.Username)
user, err := userRepository.FindByUsername(loginAccount.Username)
if err != nil {
count++
global.Cache.Set(loginFailCountKey, count, time.Minute*time.Duration(5))
@ -202,9 +203,12 @@ func ConfirmTOTPEndpoint(c echo.Context) error {
u := &model.User{
TOTPSecret: confirmTOTP.Secret,
ID: account.ID,
}
model.UpdateUserById(u, account.ID)
if err := userRepository.Update(u); err != nil {
return err
}
return Success(c, nil)
}
@ -240,8 +244,11 @@ func ResetTOTPEndpoint(c echo.Context) error {
account, _ := GetCurrentAccount(c)
u := &model.User{
TOTPSecret: "-",
ID: account.ID,
}
if err := userRepository.Update(u); err != nil {
return err
}
model.UpdateUserById(u, account.ID)
return Success(c, "")
}
@ -266,9 +273,12 @@ func ChangePasswordEndpoint(c echo.Context) error {
}
u := &model.User{
Password: string(passwd),
ID: account.ID,
}
model.UpdateUserById(u, account.ID)
if err := userRepository.Update(u); err != nil {
return err
}
return LogoutEndpoint(c)
}
@ -284,7 +294,7 @@ type AccountInfo struct {
func InfoEndpoint(c echo.Context) error {
account, _ := GetCurrentAccount(c)
user, err := model.FindUserById(account.ID)
user, err := userRepository.FindById(account.ID)
if err != nil {
return err
}

View File

@ -8,9 +8,9 @@ import (
"strconv"
"strings"
"next-terminal/pkg/constant"
"next-terminal/pkg/model"
"next-terminal/pkg/utils"
"next-terminal/server/constant"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)

View File

@ -5,8 +5,8 @@ import (
"strconv"
"strings"
"next-terminal/pkg/model"
"next-terminal/pkg/utils"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)

View File

@ -5,9 +5,9 @@ import (
"strconv"
"strings"
"next-terminal/pkg/constant"
"next-terminal/pkg/model"
"next-terminal/pkg/utils"
"next-terminal/server/constant"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)

View File

@ -4,8 +4,8 @@ import (
"strconv"
"strings"
"next-terminal/pkg/model"
"next-terminal/pkg/utils"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)

View File

@ -4,8 +4,8 @@ import (
"strconv"
"strings"
"next-terminal/pkg/global"
"next-terminal/pkg/model"
"next-terminal/server/global"
"next-terminal/server/model"
"github.com/labstack/echo/v4"
"github.com/sirupsen/logrus"

View File

@ -7,9 +7,9 @@ import (
"strings"
"time"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/utils"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)

View File

@ -1,8 +1,8 @@
package api
import (
"next-terminal/pkg/constant"
"next-terminal/pkg/model"
"next-terminal/server/constant"
"next-terminal/server/model"
"github.com/labstack/echo/v4"
)
@ -24,12 +24,12 @@ func OverviewCounterEndPoint(c echo.Context) error {
asset int64
)
if constant.TypeUser == account.Type {
countUser, _ = model.CountOnlineUser()
countUser, _ = userRepository.CountOnlineUser()
countOnlineSession, _ = model.CountOnlineSession()
credential, _ = model.CountCredentialByUserId(account.ID)
asset, _ = model.CountAssetByUserId(account.ID)
} else {
countUser, _ = model.CountOnlineUser()
countUser, _ = userRepository.CountOnlineUser()
countOnlineSession, _ = model.CountOnlineSession()
credential, _ = model.CountCredential()
asset, _ = model.CountAsset()

View File

@ -4,7 +4,7 @@ import (
"errors"
"fmt"
"next-terminal/pkg/model"
"next-terminal/server/model"
"github.com/labstack/echo/v4"
"gorm.io/gorm"

View File

@ -1,7 +1,7 @@
package api
import (
"next-terminal/pkg/model"
"next-terminal/server/model"
"github.com/labstack/echo/v4"
)

View File

@ -2,11 +2,11 @@ package api
import (
"net/http"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/log"
"next-terminal/pkg/model"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
@ -14,7 +14,12 @@ import (
const Token = "X-Auth-Token"
func SetupRoutes() *echo.Echo {
var (
userRepository repository.UserRepository
)
func SetupRoutes(ur repository.UserRepository) *echo.Echo {
userRepository = ur
e := echo.New()
e.HideBanner = true

View File

@ -4,9 +4,9 @@ import (
"strconv"
"strings"
"next-terminal/pkg/global"
"next-terminal/pkg/model"
"next-terminal/pkg/utils"
"next-terminal/server/global"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)

View File

@ -13,10 +13,10 @@ import (
"strings"
"sync"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/model"
"next-terminal/pkg/utils"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
"github.com/pkg/sftp"

View File

@ -7,12 +7,12 @@ import (
"strconv"
"time"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/guacd"
"next-terminal/pkg/model"
"next-terminal/pkg/term"
"next-terminal/pkg/utils"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/guacd"
"next-terminal/server/model"
"next-terminal/server/term"
"next-terminal/server/utils"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"

View File

@ -5,10 +5,10 @@ import (
"path"
"strconv"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/guacd"
"next-terminal/pkg/model"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/guacd"
"next-terminal/server/model"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"

View File

@ -4,9 +4,9 @@ import (
"strconv"
"strings"
"next-terminal/pkg/global"
"next-terminal/pkg/model"
"next-terminal/pkg/utils"
"next-terminal/server/global"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
)

View File

@ -4,9 +4,9 @@ import (
"strconv"
"strings"
"next-terminal/pkg/global"
"next-terminal/pkg/model"
"next-terminal/pkg/utils"
"next-terminal/server/global"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
"github.com/sirupsen/logrus"
@ -29,7 +29,7 @@ func UserCreateEndpoint(c echo.Context) error {
item.ID = utils.UUID()
item.Created = utils.NowJsonTime()
if err := model.CreateNewUser(&item); err != nil {
if err := userRepository.Create(&item); err != nil {
return err
}
@ -49,7 +49,7 @@ func UserPagingEndpoint(c echo.Context) error {
order := c.QueryParam("order")
field := c.QueryParam("field")
items, total, err := model.FindPageUser(pageIndex, pageSize, username, nickname, mail, order, field)
items, total, err := userRepository.Find(pageIndex, pageSize, username, nickname, mail, order, field)
if err != nil {
return err
}
@ -67,8 +67,11 @@ func UserUpdateEndpoint(c echo.Context) error {
if err := c.Bind(&item); err != nil {
return err
}
item.ID = id
model.UpdateUserById(&item, id)
if err := userRepository.Update(&item); err != nil {
return err
}
return Success(c, nil)
}
@ -100,7 +103,9 @@ func UserDeleteEndpoint(c echo.Context) error {
}
// 删除用户
model.DeleteUserById(userId)
if err := userRepository.DeleteById(userId); err != nil {
return err
}
}
return Success(c, nil)
@ -109,7 +114,7 @@ func UserDeleteEndpoint(c echo.Context) error {
func UserGetEndpoint(c echo.Context) error {
id := c.Param("id")
item, err := model.FindUserById(id)
item, err := userRepository.FindById(id)
if err != nil {
return err
}
@ -121,7 +126,7 @@ func UserChangePasswordEndpoint(c echo.Context) error {
id := c.Param("id")
password := c.QueryParam("password")
user, err := model.FindUserById(id)
user, err := userRepository.FindById(id)
if err != nil {
return err
}
@ -132,8 +137,11 @@ func UserChangePasswordEndpoint(c echo.Context) error {
}
u := &model.User{
Password: string(passwd),
ID: id,
}
if err := userRepository.Update(u); err != nil {
return err
}
model.UpdateUserById(u, id)
if user.Mail != "" {
go model.SendMail(user.Mail, "[Next Terminal] 密码修改通知", "你好,"+user.Nickname+"。管理员已将你的密码修改为:"+password)
@ -146,7 +154,10 @@ func UserResetTotpEndpoint(c echo.Context) error {
id := c.Param("id")
u := &model.User{
TOTPSecret: "-",
ID: id,
}
if err := userRepository.Update(u); err != nil {
return err
}
model.UpdateUserById(u, id)
return Success(c, "")
}

View File

@ -36,7 +36,7 @@ type Server struct {
Key string
}
func SetupConfig() (*Config, error) {
func SetupConfig() *Config {
viper.SetConfigName("config")
viper.SetConfigType("yml")
@ -85,5 +85,5 @@ func SetupConfig() (*Config, error) {
Demo: viper.GetBool("demo"),
}
return config, nil
return config
}

View File

@ -1,7 +1,7 @@
package global
import (
"next-terminal/pkg/config"
"next-terminal/server/config"
"github.com/patrickmn/go-cache"
"github.com/robfig/cron/v3"

View File

@ -4,8 +4,8 @@ import (
"strconv"
"sync"
"next-terminal/pkg/guacd"
"next-terminal/pkg/term"
"next-terminal/server/guacd"
"next-terminal/server/term"
"github.com/gorilla/websocket"
)

View File

@ -5,10 +5,10 @@ import (
"strconv"
"time"
"next-terminal/pkg/constant"
"next-terminal/pkg/guacd"
"next-terminal/pkg/model"
"next-terminal/pkg/utils"
"next-terminal/server/constant"
"next-terminal/server/guacd"
"next-terminal/server/model"
"next-terminal/server/utils"
"github.com/sirupsen/logrus"
)

View File

@ -1,7 +1,7 @@
package model
import (
"next-terminal/pkg/global"
"next-terminal/server/global"
)
type AccessSecurity struct {

View File

@ -3,9 +3,9 @@ package model
import (
"strings"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/utils"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/utils"
)
type Asset struct {

View File

@ -3,10 +3,10 @@ package model
import (
"fmt"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/guacd"
"next-terminal/pkg/utils"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/guacd"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
"gorm.io/gorm"

View File

@ -1,9 +1,9 @@
package model
import (
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/utils"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/utils"
)
type Command struct {

View File

@ -1,9 +1,9 @@
package model
import (
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/utils"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/utils"
)
type Credential struct {

View File

@ -7,10 +7,10 @@ import (
"strings"
"time"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/term"
"next-terminal/pkg/utils"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/term"
"next-terminal/server/utils"
"github.com/robfig/cron/v3"
"github.com/sirupsen/logrus"

View File

@ -1,8 +1,8 @@
package model
import (
"next-terminal/pkg/global"
"next-terminal/pkg/utils"
"next-terminal/server/global"
"next-terminal/server/utils"
)
type JobLog struct {

View File

@ -1,10 +1,8 @@
package model
import (
"next-terminal/pkg/global"
"next-terminal/pkg/utils"
"github.com/sirupsen/logrus"
"next-terminal/server/global"
"next-terminal/server/utils"
)
type LoginLog struct {
@ -83,25 +81,26 @@ func FindLoginLogById(id string) (o LoginLog, err error) {
}
func Logout(token string) (err error) {
loginLog, err := FindLoginLogById(token)
if err != nil {
logrus.Warnf("登录日志「%v」获取失败", token)
return
}
err = global.DB.Updates(&LoginLog{LogoutTime: utils.NowJsonTime(), ID: token}).Error
if err != nil {
return err
}
loginLogs, err := FindAliveLoginLogsByUserId(loginLog.UserId)
if err != nil {
return
}
if len(loginLogs) == 0 {
err = UpdateUserOnline(false, loginLog.UserId)
}
//
//loginLog, err := FindLoginLogById(token)
//if err != nil {
// logrus.Warnf("登录日志「%v」获取失败", token)
// return
//}
//
//err = global.DB.Updates(&LoginLog{LogoutTime: utils.NowJsonTime(), ID: token}).Error
//if err != nil {
// return err
//}
//
//loginLogs, err := FindAliveLoginLogsByUserId(loginLog.UserId)
//if err != nil {
// return
//}
//
//if len(loginLogs) == 0 {
// // TODO
// err = UpdateUserOnline(false, loginLog.UserId)
//}
return
}

View File

@ -1,7 +1,7 @@
package model
import (
"next-terminal/pkg/global"
"next-terminal/server/global"
)
type Num struct {

View File

@ -3,9 +3,9 @@ package model
import (
"net/smtp"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/guacd"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/guacd"
"github.com/jordan-wright/email"
"github.com/sirupsen/logrus"

View File

@ -1,8 +1,8 @@
package model
import (
"next-terminal/pkg/global"
"next-terminal/pkg/utils"
"next-terminal/server/global"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"

View File

@ -5,9 +5,9 @@ import (
"path"
"time"
"next-terminal/pkg/constant"
"next-terminal/pkg/global"
"next-terminal/pkg/utils"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/utils"
)
type Session struct {

View File

@ -1,6 +1,6 @@
package model
import "next-terminal/pkg/global"
import "next-terminal/server/global"
type UserAttribute struct {
Id string `gorm:"index" json:"id"`

View File

@ -1,6 +1,6 @@
package model
import "next-terminal/pkg/global"
import "next-terminal/server/global"
type UserGroupMember struct {
ID string `gorm:"primary_key" json:"name"`

View File

@ -1,8 +1,8 @@
package model
import (
"next-terminal/pkg/global"
"next-terminal/pkg/utils"
"next-terminal/server/global"
"next-terminal/server/utils"
"gorm.io/gorm"
)
@ -75,23 +75,24 @@ func CreateNewUserGroup(o *UserGroup, members []string) (err error) {
}
func AddUserGroupMembers(tx *gorm.DB, userIds []string, userGroupId string) error {
for i := range userIds {
userId := userIds[i]
_, err := FindUserById(userId)
if err != nil {
return err
}
userGroupMember := UserGroupMember{
ID: utils.Sign([]string{userGroupId, userId}),
UserId: userId,
UserGroupId: userGroupId,
}
err = tx.Create(&userGroupMember).Error
if err != nil {
return err
}
}
//for i := range userIds {
// userId := userIds[i]
// // TODO
// _, err := FindUserById(userId)
// if err != nil {
// return err
// }
//
// userGroupMember := UserGroupMember{
// ID: utils.Sign([]string{userGroupId, userId}),
// UserId: userId,
// UserGroupId: userGroupId,
// }
// err = tx.Create(&userGroupMember).Error
// if err != nil {
// return err
// }
//}
return nil
}

35
server/model/user.go Normal file
View File

@ -0,0 +1,35 @@
package model
import (
"next-terminal/server/utils"
)
type User struct {
ID string `gorm:"primary_key" json:"id"`
Username string `gorm:"index" json:"username"`
Password string `json:"password"`
Nickname string `json:"nickname"`
TOTPSecret string `json:"-"`
Online bool `json:"online"`
Enabled bool `json:"enabled"`
Created utils.JsonTime `json:"created"`
Type string `json:"type"`
Mail string `json:"mail"`
}
type UserVo struct {
ID string `json:"id"`
Username string `json:"username"`
Nickname string `json:"nickname"`
TOTPSecret string `json:"totpSecret"`
Mail string `json:"mail"`
Online bool `json:"online"`
Enabled bool `json:"enabled"`
Created utils.JsonTime `json:"created"`
Type string `json:"type"`
SharerAssetCount int64 `json:"sharerAssetCount"`
}
func (r *User) TableName() string {
return "users"
}

123
server/repository/user.go Normal file
View File

@ -0,0 +1,123 @@
package repository
import (
"gorm.io/gorm"
"next-terminal/server/model"
)
type UserRepository struct {
DB *gorm.DB
}
func (r UserRepository) FindAll() (o []model.User) {
if r.DB.Find(&o).Error != nil {
return nil
}
return
}
func (r UserRepository) Find(pageIndex, pageSize int, username, nickname, mail, order, field string) (o []model.UserVo, total int64, err error) {
db := r.DB.Table("users").Select("users.id,users.username,users.nickname,users.mail,users.online,users.enabled,users.created,users.type, count(resource_sharers.user_id) as sharer_asset_count, users.totp_secret").Joins("left join resource_sharers on users.id = resource_sharers.user_id and resource_sharers.resource_type = 'asset'").Group("users.id")
dbCounter := r.DB.Table("users")
if len(username) > 0 {
db = db.Where("users.username like ?", "%"+username+"%")
dbCounter = dbCounter.Where("username like ?", "%"+username+"%")
}
if len(nickname) > 0 {
db = db.Where("users.nickname like ?", "%"+nickname+"%")
dbCounter = dbCounter.Where("nickname like ?", "%"+nickname+"%")
}
if len(mail) > 0 {
db = db.Where("users.mail like ?", "%"+mail+"%")
dbCounter = dbCounter.Where("mail like ?", "%"+mail+"%")
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "username" {
field = "username"
} else if field == "nickname" {
field = "nickname"
} else {
field = "created"
}
err = db.Order("users." + field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error
if o == nil {
o = make([]model.UserVo, 0)
}
for i := 0; i < len(o); i++ {
if o[i].TOTPSecret == "" || o[i].TOTPSecret == "-" {
o[i].TOTPSecret = "0"
} else {
o[i].TOTPSecret = "1"
}
}
return
}
func (r UserRepository) FindById(id string) (o model.User, err error) {
err = r.DB.Where("id = ?", id).First(&o).Error
return
}
func (r UserRepository) FindByUsername(username string) (o model.User, err error) {
err = r.DB.Where("username = ?", username).First(&o).Error
return
}
func (r UserRepository) FindOnlineUsers() (o []model.User, err error) {
err = r.DB.Where("online = ?", true).Find(&o).Error
return
}
func (r UserRepository) Create(o *model.User) error {
return r.DB.Create(o).Error
}
func (r UserRepository) Update(o *model.User) error {
return r.DB.Updates(o).Error
}
func (r UserRepository) UpdateOnline(id string, online bool) error {
sql := "update users set online = ? where id = ?"
return r.DB.Exec(sql, online, id).Error
}
func (r UserRepository) DeleteById(id string) error {
return r.DB.Transaction(func(tx *gorm.DB) (err error) {
// 删除用户
err = tx.Where("id = ?", id).Delete(&model.User{}).Error
if err != nil {
return err
}
// 删除用户组中的用户关系
err = tx.Where("user_id = ?", id).Delete(&model.UserGroupMember{}).Error
if err != nil {
return err
}
// 删除用户分享到的资产
err = tx.Where("user_id = ?", id).Delete(&model.ResourceSharer{}).Error
if err != nil {
return err
}
return nil
})
}
func (r UserRepository) CountOnlineUser() (total int64, err error) {
err = r.DB.Where("online = ?", true).Find(&model.User{}).Count(&total).Error
return
}

View File

@ -5,7 +5,7 @@ import (
"os"
"time"
"next-terminal/pkg/utils"
"next-terminal/server/utils"
)
type Env struct {