diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml new file mode 100644 index 0000000..8459891 --- /dev/null +++ b/.github/workflows/golangci-lint.yml @@ -0,0 +1,22 @@ +name: golangci-lint +on: + push: + tags: + - v* + branches: + - master + - be-* + - dev + pull_request: +jobs: + golangci: + name: lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: golangci-lint + uses: golangci/golangci-lint-action@v2 + with: + # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version + version: latest + args: --timeout=5m diff --git a/.gitignore b/.gitignore index 6f4c7f7..520d8be 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,26 @@ web/build *.db .DS_Store .eslintcache -.env \ No newline at end of file +.env + +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# ignore ide config +.idea +.vscode + +# vim +*.swp + + +# playground +playground/data +playground/drive +playground/recording + +/log diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..9afa411 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,23 @@ +linters-settings: + goimports: + local-prefixes: next-terminal + +linters: + disable-all: true + enable: + - deadcode + - errcheck + - gofmt + - goimports + - gosimple + - govet + - ineffassign + - staticcheck + - structcheck + - typecheck + - unused + - varcheck + +run: + skip-files: + - test_ssh.go diff --git a/config.yml b/config.yml index 77d97dd..e538301 100644 --- a/config.yml +++ b/config.yml @@ -2,10 +2,10 @@ debug: true demo: false db: mysql mysql: - hostname: 172.16.101.32 + hostname: localhost port: 3306 - username: root - password: mysql + username: next-terminal + password: next-terminal database: next-terminal sqlite: file: 'next-terminal.db' diff --git a/docs/screenshot.md b/docs/screenshot.md index 43a232a..ae20567 100644 --- a/docs/screenshot.md +++ b/docs/screenshot.md @@ -1,23 +1,69 @@ -资源占用截图 +### 资源占用 + +未使用时资源占用非常小  -资产管理 +### 控制面板 + +更方便的概览系统信息 + + + +### 资产管理 + +支持多种RDP、SSH、Telnet、VNC,Kubernetes等多种协议的资产  -rdp +#### rdp  -vnc +#### vnc  -ssh +#### ssh  +### 授权凭证 + +极为方便的复用资产认证信息 + + + 批量执行命令 - \ No newline at end of file + + +### 在线监控 + +实时监控用户的操作,并可以随时断开该会话 + + + +### 离线回放 + +详细的数据回放,定位任何一个可疑操作 + + + +### 计划任务 + +自定义计划任务 + + + +### 访问安全 + +黑白名单访问控制,支持ip、cidr及连续IP + + + +### 用户组授权 + +灵活的授权策略 + + \ No newline at end of file diff --git a/go.mod b/go.mod index db596f1..3e283fa 100644 --- a/go.mod +++ b/go.mod @@ -10,13 +10,16 @@ require ( github.com/labstack/echo/v4 v4.1.17 github.com/labstack/gommon v0.3.0 github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/pkg/errors v0.9.1 github.com/pkg/sftp v1.12.0 github.com/pquerna/otp v1.3.0 github.com/robfig/cron/v3 v3.0.1 github.com/sirupsen/logrus v1.4.2 github.com/spf13/pflag v1.0.3 github.com/spf13/viper v1.7.1 + github.com/stretchr/testify v1.6.1 golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a + gopkg.in/natefinch/lumberjack.v2 v2.0.0 gorm.io/driver/mysql v1.0.3 gorm.io/driver/sqlite v1.1.4 gorm.io/gorm v1.20.7 diff --git a/go.sum b/go.sum index ede144e..44a95a0 100644 --- a/go.sum +++ b/go.sum @@ -353,6 +353,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/ini.v1 v1.51.0 h1:AQvPpx3LzTDM0AjnIRlVFwFFGC+npRopjZxLJj6gdno= gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= +gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/main.go b/main.go index 939dd7a..597d7d6 100644 --- a/main.go +++ b/main.go @@ -1,32 +1,18 @@ package main import ( - "bytes" "fmt" - nested "github.com/antonfisher/nested-logrus-formatter" - "github.com/labstack/gommon/log" - "github.com/patrickmn/go-cache" - "github.com/robfig/cron/v3" - "github.com/sirupsen/logrus" - "gorm.io/driver/mysql" - "gorm.io/driver/sqlite" - "gorm.io/gorm" - "gorm.io/gorm/logger" - "io" - "next-terminal/pkg/api" + "next-terminal/pkg/config" - "next-terminal/pkg/constant" "next-terminal/pkg/global" - "next-terminal/pkg/handle" - "next-terminal/pkg/model" - "next-terminal/pkg/utils" - "os" - "strconv" - "strings" - "time" + "next-terminal/pkg/task" + "next-terminal/server/api" + "next-terminal/server/repository" + + "github.com/labstack/gommon/log" ) -const Version = "v0.3.3" +const Version = "v0.3.4" func main() { err := Run() @@ -45,270 +31,20 @@ func Run() error { \____|__ /\___ >__/\_ \ |__| |____| \___ >__| |__|_| /__|___| (____ /____/ \/ \/ \/ \/ \/ \/ \/ ` + Version + "\n\n") - var err error - //logrus.SetReportCaller(true) - logrus.SetLevel(logrus.DebugLevel) - logrus.SetFormatter(&nested.Formatter{ - HideKeys: true, - FieldsOrder: []string{"component", "category"}, - }) + // 为了兼容之前调用global包的代码 后期预期会改为调用pgk/config + global.Config = config.GlobalCfg - writer1 := &bytes.Buffer{} - writer2 := os.Stdout - writer3, err := os.OpenFile("next-terminal.log", os.O_WRONLY|os.O_CREATE, 0755) - if err != nil { - log.Fatalf("create file log.txt failed: %v", err) - } - - logrus.SetOutput(io.MultiWriter(writer1, writer2, writer3)) - - global.Config, err = config.SetupConfig() - if err != nil { - return err - } - - var logMode logger.Interface - if global.Config.Debug { - logMode = logger.Default.LogMode(logger.Info) - } else { - logMode = logger.Default.LogMode(logger.Silent) - } - - fmt.Printf("当前数据库模式为:%v\n", global.Config.DB) - if global.Config.DB == "mysql" { - dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", - global.Config.Mysql.Username, - global.Config.Mysql.Password, - global.Config.Mysql.Hostname, - global.Config.Mysql.Port, - global.Config.Mysql.Database, - ) - global.DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ - Logger: logMode, - }) - } else { - global.DB, err = gorm.Open(sqlite.Open(global.Config.Sqlite.File), &gorm.Config{ - Logger: logMode, - }) - } - - if err != nil { - logrus.Errorf("连接数据库异常:%v", err.Error()) - return err - } + global.Cache = api.SetupCache() + db := api.SetupDB() + e := api.SetupRoutes(db) if global.Config.ResetPassword != "" { - user, err := model.FindUserByUsername(global.Config.ResetPassword) - if err != nil { - return err - } - password := "next-terminal" - passwd, err := utils.Encoder.Encode([]byte(password)) - if err != nil { - return err - } - u := &model.User{ - Password: string(passwd), - } - model.UpdateUserById(u, user.ID) - logrus.Debugf("用户「%v」密码初始化为: %v", user.Username, password) - return nil + return api.ResetPassword() } - - if err := global.DB.AutoMigrate(&model.User{}); err != nil { - return err - } - - users := model.FindAllUser() - if len(users) == 0 { - - initPassword := "admin" - var pass []byte - if pass, err = utils.Encoder.Encode([]byte(initPassword)); err != nil { - return err - } - - user := model.User{ - ID: utils.UUID(), - Username: "admin", - Password: string(pass), - Nickname: "超级管理员", - Type: constant.TypeAdmin, - Created: utils.NowJsonTime(), - } - if err := model.CreateNewUser(&user); err != nil { - return err - } - logrus.Infof("初始用户创建成功,账号:「%v」密码:「%v」", user.Username, initPassword) - } else { - for i := range users { - // 修正默认用户类型为管理员 - if users[i].Type == "" { - user := model.User{ - Type: constant.TypeAdmin, - } - model.UpdateUserById(&user, users[i].ID) - logrus.Infof("自动修正用户「%v」ID「%v」类型为管理员", users[i].Nickname, users[i].ID) - } - } - } - - if err := global.DB.AutoMigrate(&model.Asset{}); err != nil { - return err - } - if err := global.DB.AutoMigrate(&model.AssetAttribute{}); err != nil { - return err - } - if err := global.DB.AutoMigrate(&model.Session{}); err != nil { - return err - } - if err := global.DB.AutoMigrate(&model.Command{}); err != nil { - return err - } - if err := global.DB.AutoMigrate(&model.Credential{}); err != nil { - return err - } - if err := global.DB.AutoMigrate(&model.Property{}); err != nil { - return err - } - if err := global.DB.AutoMigrate(&model.ResourceSharer{}); err != nil { - return err - } - if err := global.DB.AutoMigrate(&model.UserGroup{}); err != nil { - return err - } - if err := global.DB.AutoMigrate(&model.UserGroupMember{}); err != nil { - return err - } - if err := global.DB.AutoMigrate(&model.LoginLog{}); err != nil { - return err - } - if err := global.DB.AutoMigrate(&model.Num{}); err != nil { - return err - } - if err := global.DB.AutoMigrate(&model.Job{}); err != nil { - return err - } - if err := global.DB.AutoMigrate(&model.JobLog{}); err != nil { - return err - } - if err := global.DB.AutoMigrate(&model.AccessSecurity{}); err != nil { - return err - } - if err := api.ReloadAccessSecurity(); err != nil { - return err - } - - if len(model.FindAllTemp()) == 0 { - for i := 0; i <= 30; i++ { - if err := model.CreateNewTemp(&model.Num{I: strconv.Itoa(i)}); err != nil { - return err - } - } - } - - // 配置缓存器 - global.Cache = cache.New(5*time.Minute, 10*time.Minute) - global.Cache.OnEvicted(func(key string, value interface{}) { - if strings.HasPrefix(key, api.Token) { - token := api.GetTokenFormCacheKey(key) - logrus.Debugf("用户Token「%v」过期", token) - err := model.Logout(token) - if err != nil { - logrus.Errorf("退出登录失败 %v", err) - } - } - }) - global.Store = global.NewStore() - global.Cron = cron.New(cron.WithSeconds()) //精确到秒 - global.Cron.Start() - - jobs, err := model.FindJobByFunc(constant.FuncCheckAssetStatusJob) - if err != nil { - return err - } - if jobs == nil || len(jobs) == 0 { - job := model.Job{ - ID: utils.UUID(), - Name: "资产状态检测", - Func: constant.FuncCheckAssetStatusJob, - Cron: "0 0 0/1 * * ?", - Mode: constant.JobModeAll, - Status: constant.JobStatusRunning, - Created: utils.NowJsonTime(), - Updated: utils.NowJsonTime(), - } - if err := model.CreateNewJob(&job); err != nil { - return err - } - logrus.Debugf("创建计划任务「%v」cron「%v」", job.Name, job.Cron) - } else { - for i := range jobs { - if jobs[i].Status == constant.JobStatusRunning { - err := model.ChangeJobStatusById(jobs[i].ID, constant.JobStatusRunning) - if err != nil { - return err - } - logrus.Debugf("启动计划任务「%v」cron「%v」", jobs[i].Name, jobs[i].Cron) - } - } - } - - loginLogs, err := model.FindAliveLoginLogs() - if err != nil { - return err - } - - for i := range loginLogs { - loginLog := loginLogs[i] - token := loginLog.ID - user, err := model.FindUserById(loginLog.UserId) - if err != nil { - logrus.Debugf("用户「%v」获取失败,忽略", loginLog.UserId) - continue - } - - authorization := api.Authorization{ - Token: token, - Remember: loginLog.Remember, - User: user, - } - - cacheKey := api.BuildCacheKeyByToken(token) - - if authorization.Remember { - // 记住登录有效期两周 - global.Cache.Set(cacheKey, authorization, api.RememberEffectiveTime) - } else { - global.Cache.Set(cacheKey, authorization, api.NotRememberEffectiveTime) - } - logrus.Debugf("重新加载用户「%v」授权Token「%v」到缓存", user.Nickname, token) - } - - // 修正用户登录状态 - onlineUsers, err := model.FindOnlineUsers() - if err != nil { - return err - } - for i := range onlineUsers { - logs, err := model.FindAliveLoginLogsByUserId(onlineUsers[i].ID) - if err != nil { - return err - } - if len(logs) == 0 { - if err := model.UpdateUserOnline(false, onlineUsers[i].ID); err != nil { - return err - } - } - } - - e := api.SetupRoutes() - if err := handle.InitProperties(); err != nil { - return err - } - // 启动定时任务 - go handle.RunTicker() - go handle.RunDataFix() + sessionRepo := repository.NewSessionRepository(db) + propertyRepo := repository.NewPropertyRepository(db) + ticker := task.NewTicker(sessionRepo, propertyRepo) + ticker.SetupTicker() if global.Config.Server.Cert != "" && global.Config.Server.Key != "" { return e.StartTLS(global.Config.Server.Addr, global.Config.Server.Cert, global.Config.Server.Key) diff --git a/pkg/config/config.go b/pkg/config/config.go index ec37f6a..5fe0868 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,12 +1,15 @@ package config import ( - "github.com/spf13/pflag" "strings" + "github.com/spf13/pflag" + "github.com/spf13/viper" ) +var GlobalCfg *Config + type Config struct { Debug bool Demo bool @@ -35,7 +38,7 @@ type Server struct { Key string } -func SetupConfig() (*Config, error) { +func SetupConfig() *Config { viper.SetConfigName("config") viper.SetConfigType("yml") @@ -83,6 +86,10 @@ func SetupConfig() (*Config, error) { Debug: viper.GetBool("debug"), Demo: viper.GetBool("demo"), } - - return config, nil + GlobalCfg = config + return config +} + +func init() { + GlobalCfg = SetupConfig() } diff --git a/pkg/constant/const.go b/pkg/constant/const.go index cdba074..2a2406e 100644 --- a/pkg/constant/const.go +++ b/pkg/constant/const.go @@ -1,5 +1,7 @@ package constant +import "next-terminal/pkg/guacd" + const ( AccessRuleAllow = "allow" // 允许访问 AccessRuleReject = "reject" // 拒绝访问 @@ -31,3 +33,9 @@ const ( TypeUser = "user" // 普通用户 TypeAdmin = "admin" // 管理员 ) + +var SSHParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, SshMode} +var RDPParameterNames = []string{guacd.Domain, guacd.RemoteApp, guacd.RemoteAppDir, guacd.RemoteAppArgs} +var VNCParameterNames = []string{guacd.ColorDepth, guacd.Cursor, guacd.SwapRedBlue, guacd.DestHost, guacd.DestPort} +var TelnetParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, guacd.UsernameRegex, guacd.PasswordRegex, guacd.LoginSuccessRegex, guacd.LoginFailureRegex} +var KubernetesParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, guacd.Namespace, guacd.Pod, guacd.Container, guacd.UesSSL, guacd.ClientCert, guacd.ClientKey, guacd.CaCert, guacd.IgnoreCert} diff --git a/pkg/global/global.go b/pkg/global/global.go index 95d9ee0..77019f4 100644 --- a/pkg/global/global.go +++ b/pkg/global/global.go @@ -1,14 +1,12 @@ package global import ( + "next-terminal/pkg/config" + "github.com/patrickmn/go-cache" "github.com/robfig/cron/v3" - "gorm.io/gorm" - "next-terminal/pkg/config" ) -var DB *gorm.DB - var Cache *cache.Cache var Config *config.Config @@ -23,3 +21,8 @@ type Security struct { } var Securities []*Security + +func init() { + Cron = cron.New(cron.WithSeconds()) + Cron.Start() +} diff --git a/pkg/global/store.go b/pkg/global/store.go index 6e91a7b..d5a19b8 100644 --- a/pkg/global/store.go +++ b/pkg/global/store.go @@ -1,11 +1,13 @@ package global import ( - "github.com/gorilla/websocket" - "next-terminal/pkg/guacd" - "next-terminal/pkg/term" "strconv" "sync" + + "next-terminal/pkg/guacd" + "next-terminal/pkg/term" + + "github.com/gorilla/websocket" ) type Tun struct { @@ -67,3 +69,7 @@ func NewStore() *TunStore { store := TunStore{sync.Map{}} return &store } + +func init() { + Store = NewStore() +} diff --git a/pkg/log/logger.go b/pkg/log/logger.go index ed3c485..aba71db 100644 --- a/pkg/log/logger.go +++ b/pkg/log/logger.go @@ -1,155 +1,235 @@ package log import ( + "fmt" "io" + "os" + "path" + "path/filepath" "strconv" + "strings" "time" + "next-terminal/pkg/config" + "github.com/labstack/echo/v4" - "github.com/labstack/gommon/log" "github.com/sirupsen/logrus" + "gopkg.in/natefinch/lumberjack.v2" ) -// Logrus : implement Logger +type Formatter struct{} + +func (s *Formatter) Format(entry *logrus.Entry) ([]byte, error) { + timestamp := time.Now().Local().Format("2006-01-02 15:04:05") + var file string + var l int + if entry.HasCaller() { + file = filepath.Base(entry.Caller.Function) + l = entry.Caller.Line + } + + msg := fmt.Sprintf("%s %s [%s:%d]%s\n", timestamp, strings.ToUpper(entry.Level.String()), file, l, entry.Message) + return []byte(msg), nil +} + +var stdOut = NewLogger() + +// Trace logs a message at level Trace on the standard logger. +func Trace(args ...interface{}) { + stdOut.Trace(args...) +} + +// Debug logs a message at level Debug on the standard logger. +func Debug(args ...interface{}) { + stdOut.Debug(args...) +} + +// Print logs a message at level Info on the standard logger. +func Print(args ...interface{}) { + stdOut.Print(args...) +} + +// Info logs a message at level Info on the standard logger. +func Info(args ...interface{}) { + stdOut.Info(args...) +} + +// Warn logs a message at level Warn on the standard logger. +func Warn(args ...interface{}) { + stdOut.Warn(args...) +} + +// Warning logs a message at level Warn on the standard logger. +func Warning(args ...interface{}) { + stdOut.Warning(args...) +} + +// Error logs a message at level Error on the standard logger. +func Error(args ...interface{}) { + stdOut.Error(args...) +} + +// Panic logs a message at level Panic on the standard logger. +func Panic(args ...interface{}) { + stdOut.Panic(args...) +} + +// Fatal logs a message at level Fatal on the standard logger then the process will exit with status set to 1. +func Fatal(args ...interface{}) { + stdOut.Fatal(args...) +} + +// Tracef logs a message at level Trace on the standard logger. +func Tracef(format string, args ...interface{}) { + stdOut.Tracef(format, args...) +} + +// Debugf logs a message at level Debug on the standard logger. +func Debugf(format string, args ...interface{}) { + stdOut.Debugf(format, args...) +} + +// Printf logs a message at level Info on the standard logger. +func Printf(format string, args ...interface{}) { + stdOut.Printf(format, args...) +} + +// Infof logs a message at level Info on the standard logger. +func Infof(format string, args ...interface{}) { + stdOut.Infof(format, args...) +} + +// Warnf logs a message at level Warn on the standard logger. +func Warnf(format string, args ...interface{}) { + stdOut.Warnf(format, args...) +} + +// Warningf logs a message at level Warn on the standard logger. +func Warningf(format string, args ...interface{}) { + stdOut.Warningf(format, args...) +} + +// Errorf logs a message at level Error on the standard logger. +func Errorf(format string, args ...interface{}) { + stdOut.Errorf(format, args...) +} + +// Panicf logs a message at level Panic on the standard logger. +func Panicf(format string, args ...interface{}) { + stdOut.Panicf(format, args...) +} + +// Fatalf logs a message at level Fatal on the standard logger then the process will exit with status set to 1. +func Fatalf(format string, args ...interface{}) { + stdOut.Fatalf(format, args...) +} + +// Traceln logs a message at level Trace on the standard logger. +func Traceln(args ...interface{}) { + stdOut.Traceln(args...) +} + +// Debugln logs a message at level Debug on the standard logger. +func Debugln(args ...interface{}) { + stdOut.Debugln(args...) +} + +// Println logs a message at level Info on the standard logger. +func Println(args ...interface{}) { + stdOut.Println(args...) +} + +// Infoln logs a message at level Info on the standard logger. +func Infoln(args ...interface{}) { + stdOut.Infoln(args...) +} + +// Warnln logs a message at level Warn on the standard logger. +func Warnln(args ...interface{}) { + stdOut.Warnln(args...) +} + +// Warningln logs a message at level Warn on the standard logger. +func Warningln(args ...interface{}) { + stdOut.Warningln(args...) +} + +// Errorln logs a message at level Error on the standard logger. +func Errorln(args ...interface{}) { + stdOut.Errorln(args...) +} + +// Panicln logs a message at level Panic on the standard logger. +func Panicln(args ...interface{}) { + stdOut.Panicln(args...) +} + +// Fatalln logs a message at level Fatal on the standard logger then the process will exit with status set to 1. +func Fatalln(args ...interface{}) { + stdOut.Fatalln(args...) +} + +// WithError creates an entry from the standard logger and adds an error to it, using the value defined in ErrorKey as key. +func WithError(err error) *logrus.Entry { + return stdOut.WithField(logrus.ErrorKey, err) +} + +// WithField creates an entry from the standard logger and adds a field to +// it. If you want multiple fields, use `WithFields`. +// +// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal +// or Panic on the Entry it returns. +func WithField(key string, value interface{}) *logrus.Entry { + return stdOut.WithField(key, value) +} + +// Logrus : implement log type Logrus struct { *logrus.Logger } -// Logger ... -var Logger = logrus.New() - -// GetEchoLogger for e.Logger -func GetEchoLogger() Logrus { - return Logrus{Logger} -} - -// Level returns logger level -func (l Logrus) Level() log.Lvl { - switch l.Logger.Level { - case logrus.DebugLevel: - return log.DEBUG - case logrus.WarnLevel: - return log.WARN - case logrus.ErrorLevel: - return log.ERROR - case logrus.InfoLevel: - return log.INFO - default: - l.Panic("Invalid level") +// GetEchoLogger for e.l +func NewLogger() Logrus { + logFilePath := "" + if dir, err := os.Getwd(); err == nil { + logFilePath = dir + "/logs/" + } + if err := os.MkdirAll(logFilePath, 0755); err != nil { + fmt.Println(err.Error()) + } + logFileName := "next-terminal.log" + //日志文件 + fileName := path.Join(logFilePath, logFileName) + if _, err := os.Stat(fileName); err != nil { + if _, err := os.Create(fileName); err != nil { + fmt.Println(err.Error()) + } } - return log.OFF -} - -// SetHeader is a stub to satisfy interface -// It's controlled by Logger -func (l Logrus) SetHeader(_ string) {} - -// SetPrefix It's controlled by Logger -func (l Logrus) SetPrefix(s string) {} - -// Prefix It's controlled by Logger -func (l Logrus) Prefix() string { - return "" -} - -// SetLevel set level to logger from given log.Lvl -func (l Logrus) SetLevel(lvl log.Lvl) { - switch lvl { - case log.DEBUG: - Logger.SetLevel(logrus.DebugLevel) - case log.WARN: - Logger.SetLevel(logrus.WarnLevel) - case log.ERROR: - Logger.SetLevel(logrus.ErrorLevel) - case log.INFO: - Logger.SetLevel(logrus.InfoLevel) - default: - l.Panic("Invalid level") + //实例化 + logger := logrus.New() + //设置输出 + logger.SetOutput(io.MultiWriter(&lumberjack.Logger{ + Filename: fileName, + MaxSize: 100, // megabytes + MaxBackups: 3, + MaxAge: 7, //days + Compress: true, // disabled by default + }, os.Stdout)) + logger.SetReportCaller(true) + //设置日志级别 + if config.GlobalCfg.Debug { + logger.SetLevel(logrus.DebugLevel) + } else { + logger.SetLevel(logrus.InfoLevel) } -} - -// Output logger output func -func (l Logrus) Output() io.Writer { - return l.Out -} - -// SetOutput change output, default os.Stdout -func (l Logrus) SetOutput(w io.Writer) { - Logger.SetOutput(w) -} - -// Printj print json log -func (l Logrus) Printj(j log.JSON) { - Logger.WithFields(logrus.Fields(j)).Print() -} - -// Debugj debug json log -func (l Logrus) Debugj(j log.JSON) { - Logger.WithFields(logrus.Fields(j)).Debug() -} - -// Infoj info json log -func (l Logrus) Infoj(j log.JSON) { - Logger.WithFields(logrus.Fields(j)).Info() -} - -// Warnj warning json log -func (l Logrus) Warnj(j log.JSON) { - Logger.WithFields(logrus.Fields(j)).Warn() -} - -// Errorj error json log -func (l Logrus) Errorj(j log.JSON) { - Logger.WithFields(logrus.Fields(j)).Error() -} - -// Fatalj fatal json log -func (l Logrus) Fatalj(j log.JSON) { - Logger.WithFields(logrus.Fields(j)).Fatal() -} - -// Panicj panic json log -func (l Logrus) Panicj(j log.JSON) { - Logger.WithFields(logrus.Fields(j)).Panic() -} - -// Print string log -func (l Logrus) Print(i ...interface{}) { - Logger.Print(i[0].(string)) -} - -// Debug string log -func (l Logrus) Debug(i ...interface{}) { - Logger.Debug(i[0].(string)) -} - -// Info string log -func (l Logrus) Info(i ...interface{}) { - Logger.Info(i[0].(string)) -} - -// Warn string log -func (l Logrus) Warn(i ...interface{}) { - Logger.Warn(i[0].(string)) -} - -// Error string log -func (l Logrus) Error(i ...interface{}) { - Logger.Error(i[0].(string)) -} - -// Fatal string log -func (l Logrus) Fatal(i ...interface{}) { - Logger.Fatal(i[0].(string)) -} - -// Panic string log -func (l Logrus) Panic(i ...interface{}) { - Logger.Panic(i[0].(string)) + //设置日志格式 + logger.SetFormatter(new(Formatter)) + return Logrus{Logger: logger} } func logrusMiddlewareHandler(c echo.Context, next echo.HandlerFunc) error { + l := NewLogger() req := c.Request() res := c.Response() start := time.Now() @@ -158,25 +238,18 @@ func logrusMiddlewareHandler(c echo.Context, next echo.HandlerFunc) error { } stop := time.Now() - p := req.URL.Path - - bytesIn := req.Header.Get(echo.HeaderContentLength) - - Logger.WithFields(map[string]interface{}{ - "time_rfc3339": time.Now().Format(time.RFC3339), - "remote_ip": c.RealIP(), - "host": req.Host, - "uri": req.RequestURI, - "method": req.Method, - "path": p, - "referer": req.Referer(), - "user_agent": req.UserAgent(), - "status": res.Status, - "latency": strconv.FormatInt(stop.Sub(start).Nanoseconds()/1000, 10), - "latency_human": stop.Sub(start).String(), - "bytes_in": bytesIn, - "bytes_out": strconv.FormatInt(res.Size, 10), - }).Info("Handled request") + l.Debugf("%s %s %s %s %s %3d %s %13v %s %s", + c.RealIP(), + req.Host, + req.Method, + req.RequestURI, + req.URL.Path, + res.Status, + strconv.FormatInt(res.Size, 10), + stop.Sub(start).String(), + req.Referer(), + req.UserAgent(), + ) return nil } diff --git a/pkg/model/access_security.go b/pkg/model/access_security.go deleted file mode 100644 index 1453bc6..0000000 --- a/pkg/model/access_security.go +++ /dev/null @@ -1,83 +0,0 @@ -package model - -import ( - "next-terminal/pkg/global" -) - -type AccessSecurity struct { - ID string `json:"id"` - Rule string `json:"rule"` - IP string `json:"ip"` - Source string `json:"source"` - Priority int64 `json:"priority"` // 越小优先级越高 -} - -func (r *AccessSecurity) TableName() string { - return "access_securities" -} - -func FindAllAccessSecurities() (o []AccessSecurity, err error) { - db := global.DB - err = db.Order("priority asc").Find(&o).Error - return -} - -func FindPageSecurity(pageIndex, pageSize int, ip, rule, order, field string) (o []AccessSecurity, total int64, err error) { - t := AccessSecurity{} - db := global.DB.Table(t.TableName()) - dbCounter := global.DB.Table(t.TableName()) - - if len(ip) > 0 { - db = db.Where("ip like ?", "%"+ip+"%") - dbCounter = dbCounter.Where("ip like ?", "%"+ip+"%") - } - - if len(rule) > 0 { - db = db.Where("rule = ?", rule) - dbCounter = dbCounter.Where("rule = ?", rule) - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "descend" { - order = "desc" - } else { - order = "asc" - } - - if field == "ip" { - field = "ip" - } else if field == "rule" { - field = "rule" - } else { - field = "priority" - } - - err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error - if o == nil { - o = make([]AccessSecurity, 0) - } - return -} - -func CreateNewSecurity(o *AccessSecurity) error { - return global.DB.Create(o).Error -} - -func UpdateSecurityById(o *AccessSecurity, id string) error { - o.ID = id - return global.DB.Updates(o).Error -} - -func DeleteSecurityById(id string) error { - - return global.DB.Where("id = ?", id).Delete(AccessSecurity{}).Error -} - -func FindSecurityById(id string) (o *AccessSecurity, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} diff --git a/pkg/model/asset.go b/pkg/model/asset.go deleted file mode 100644 index a119510..0000000 --- a/pkg/model/asset.go +++ /dev/null @@ -1,238 +0,0 @@ -package model - -import ( - "next-terminal/pkg/constant" - "next-terminal/pkg/global" - "next-terminal/pkg/utils" - "strings" -) - -type Asset struct { - ID string `gorm:"primary_key " json:"id"` - Name string `json:"name"` - Protocol string `json:"protocol"` - IP string `json:"ip"` - Port int `json:"port"` - AccountType string `json:"accountType"` - Username string `json:"username"` - Password string `json:"password"` - CredentialId string `gorm:"index" json:"credentialId"` - PrivateKey string `json:"privateKey"` - Passphrase string `json:"passphrase"` - Description string `json:"description"` - Active bool `json:"active"` - Created utils.JsonTime `json:"created"` - Tags string `json:"tags"` - Owner string `gorm:"index" json:"owner"` -} - -type AssetVo struct { - ID string `json:"id"` - Name string `json:"name"` - IP string `json:"ip"` - Protocol string `json:"protocol"` - Port int `json:"port"` - Active bool `json:"active"` - Created utils.JsonTime `json:"created"` - Tags string `json:"tags"` - Owner string `json:"owner"` - OwnerName string `json:"ownerName"` - SharerCount int64 `json:"sharerCount"` -} - -func (r *Asset) TableName() string { - return "assets" -} - -func FindAllAsset() (o []Asset, err error) { - err = global.DB.Find(&o).Error - return -} - -func FindAssetByIds(assetIds []string) (o []Asset, err error) { - err = global.DB.Where("id in ?", assetIds).Find(&o).Error - return -} - -func FindAssetByProtocol(protocol string) (o []Asset, err error) { - err = global.DB.Where("protocol = ?", protocol).Find(&o).Error - return -} - -func FindAssetByProtocolAndIds(protocol string, assetIds []string) (o []Asset, err error) { - err = global.DB.Where("protocol = ? and id in ?", protocol, assetIds).Find(&o).Error - return -} - -func FindAssetByConditions(protocol string, account User) (o []Asset, err error) { - db := global.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") - - if constant.TypeUser == account.Type { - owner := account.ID - db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner) - } - - if len(protocol) > 0 { - db = db.Where("assets.protocol = ?", protocol) - } - err = db.Find(&o).Error - return -} - -func FindPageAsset(pageIndex, pageSize int, name, protocol, tags string, account User, owner, sharer, userGroupId, ip, order, field string) (o []AssetVo, total int64, err error) { - db := global.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") - dbCounter := global.DB.Table("assets").Select("DISTINCT assets.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") - - if constant.TypeUser == account.Type { - owner := account.ID - db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner) - dbCounter = dbCounter.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner) - - // 查询用户所在用户组列表 - userGroupIds, err := FindUserGroupIdsByUserId(account.ID) - if err != nil { - return nil, 0, err - } - - if userGroupIds != nil && len(userGroupIds) > 0 { - db = db.Or("resource_sharers.user_group_id in ?", userGroupIds) - dbCounter = dbCounter.Or("resource_sharers.user_group_id in ?", userGroupIds) - } - } else { - if len(owner) > 0 { - db = db.Where("assets.owner = ?", owner) - dbCounter = dbCounter.Where("assets.owner = ?", owner) - } - if len(sharer) > 0 { - db = db.Where("resource_sharers.user_id = ?", sharer) - dbCounter = dbCounter.Where("resource_sharers.user_id = ?", sharer) - } - - if len(userGroupId) > 0 { - db = db.Where("resource_sharers.user_group_id = ?", userGroupId) - dbCounter = dbCounter.Where("resource_sharers.user_group_id = ?", userGroupId) - } - } - - if len(name) > 0 { - db = db.Where("assets.name like ?", "%"+name+"%") - dbCounter = dbCounter.Where("assets.name like ?", "%"+name+"%") - } - - if len(ip) > 0 { - db = db.Where("assets.ip like ?", "%"+ip+"%") - dbCounter = dbCounter.Where("assets.ip like ?", "%"+ip+"%") - } - - if len(protocol) > 0 { - db = db.Where("assets.protocol = ?", protocol) - dbCounter = dbCounter.Where("assets.protocol = ?", protocol) - } - - if len(tags) > 0 { - tagArr := strings.Split(tags, ",") - for i := range tagArr { - if global.Config.DB == "sqlite" { - db = db.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%") - dbCounter = dbCounter.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%") - } else { - db = db.Where("find_in_set(?, assets.tags)", tagArr[i]) - dbCounter = dbCounter.Where("find_in_set(?, assets.tags)", tagArr[i]) - } - } - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "ascend" { - order = "asc" - } else { - order = "desc" - } - - if field == "name" { - field = "name" - } else { - field = "created" - } - - err = db.Order("assets." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error - - if o == nil { - o = make([]AssetVo, 0) - } - return -} - -func CreateNewAsset(o *Asset) (err error) { - if err = global.DB.Create(o).Error; err != nil { - return err - } - return nil -} - -func FindAssetById(id string) (o Asset, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func UpdateAssetById(o *Asset, id string) { - o.ID = id - global.DB.Updates(o) -} - -func UpdateAssetActiveById(active bool, id string) { - sql := "update assets set active = ? where id = ?" - global.DB.Exec(sql, active, id) -} - -func DeleteAssetById(id string) error { - return global.DB.Where("id = ?", id).Delete(&Asset{}).Error -} - -func CountAsset() (total int64, err error) { - err = global.DB.Find(&Asset{}).Count(&total).Error - return -} - -func CountAssetByUserId(userId string) (total int64, err error) { - db := global.DB.Joins("left join resource_sharers on assets.id = resource_sharers.resource_id") - - db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", userId, userId) - - // 查询用户所在用户组列表 - userGroupIds, err := FindUserGroupIdsByUserId(userId) - if err != nil { - return 0, err - } - - if userGroupIds != nil && len(userGroupIds) > 0 { - db = db.Or("resource_sharers.user_group_id in ?", userGroupIds) - } - err = db.Find(&Asset{}).Count(&total).Error - return -} - -func FindAssetTags() (o []string, err error) { - var assets []Asset - err = global.DB.Not("tags = ?", "").Find(&assets).Error - if err != nil { - return nil, err - } - - o = make([]string, 0) - - for i := range assets { - if len(assets[i].Tags) == 0 { - continue - } - split := strings.Split(assets[i].Tags, ",") - - o = append(o, split...) - } - - return utils.Distinct(o), nil -} diff --git a/pkg/model/asset_attribute.go b/pkg/model/asset_attribute.go deleted file mode 100644 index de65465..0000000 --- a/pkg/model/asset_attribute.go +++ /dev/null @@ -1,117 +0,0 @@ -package model - -import ( - "fmt" - "github.com/labstack/echo/v4" - "gorm.io/gorm" - "next-terminal/pkg/constant" - "next-terminal/pkg/global" - "next-terminal/pkg/guacd" - "next-terminal/pkg/utils" -) - -type AssetAttribute struct { - Id string `gorm:"index" json:"id"` - AssetId string `gorm:"index" json:"assetId"` - Name string `gorm:"index" json:"name"` - Value string `json:"value"` -} - -func (r *AssetAttribute) TableName() string { - return "asset_attributes" -} - -var SSHParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, constant.SshMode} -var RDPParameterNames = []string{guacd.Domain, guacd.RemoteApp, guacd.RemoteAppDir, guacd.RemoteAppArgs} -var VNCParameterNames = []string{guacd.ColorDepth, guacd.Cursor, guacd.SwapRedBlue, guacd.DestHost, guacd.DestPort} -var TelnetParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, guacd.UsernameRegex, guacd.PasswordRegex, guacd.LoginSuccessRegex, guacd.LoginFailureRegex} -var KubernetesParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, guacd.Namespace, guacd.Pod, guacd.Container, guacd.UesSSL, guacd.ClientCert, guacd.ClientKey, guacd.CaCert, guacd.IgnoreCert} - -func UpdateAssetAttributes(assetId, protocol string, m echo.Map) error { - var data []AssetAttribute - var parameterNames []string - switch protocol { - case "ssh": - parameterNames = SSHParameterNames - case "rdp": - parameterNames = RDPParameterNames - case "vnc": - parameterNames = VNCParameterNames - case "telnet": - parameterNames = TelnetParameterNames - case "kubernetes": - parameterNames = KubernetesParameterNames - - } - - for i := range parameterNames { - name := parameterNames[i] - if m[name] != nil && m[name] != "" { - data = append(data, genAttribute(assetId, name, m)) - } - } - - return global.DB.Transaction(func(tx *gorm.DB) error { - err := tx.Where("asset_id = ?", assetId).Delete(&AssetAttribute{}).Error - if err != nil { - return err - } - return tx.CreateInBatches(&data, len(data)).Error - }) -} - -func genAttribute(assetId, name string, m echo.Map) AssetAttribute { - value := fmt.Sprintf("%v", m[name]) - attribute := AssetAttribute{ - Id: utils.Sign([]string{assetId, name}), - AssetId: assetId, - Name: name, - Value: value, - } - return attribute -} - -func FindAssetAttributeByAssetId(assetId string) (o []AssetAttribute, err error) { - err = global.DB.Where("asset_id = ?", assetId).Find(&o).Error - if o == nil { - o = make([]AssetAttribute, 0) - } - return o, err -} - -func FindAssetAttrMapByAssetId(assetId string) (map[string]interface{}, error) { - asset, err := FindAssetById(assetId) - if err != nil { - return nil, err - } - attributes, err := FindAssetAttributeByAssetId(assetId) - if err != nil { - return nil, err - } - - var parameterNames []string - switch asset.Protocol { - case "ssh": - parameterNames = SSHParameterNames - case "rdp": - parameterNames = RDPParameterNames - case "vnc": - parameterNames = VNCParameterNames - case "telnet": - parameterNames = TelnetParameterNames - case "kubernetes": - parameterNames = KubernetesParameterNames - } - propertiesMap := FindAllPropertiesMap() - var attributeMap = make(map[string]interface{}) - for name := range propertiesMap { - if utils.Contains(parameterNames, name) { - attributeMap[name] = propertiesMap[name] - } - } - - for i := range attributes { - attributeMap[attributes[i].Name] = attributes[i].Value - } - return attributeMap, nil -} diff --git a/pkg/model/command.go b/pkg/model/command.go deleted file mode 100644 index be2390b..0000000 --- a/pkg/model/command.go +++ /dev/null @@ -1,95 +0,0 @@ -package model - -import ( - "next-terminal/pkg/constant" - "next-terminal/pkg/global" - "next-terminal/pkg/utils" -) - -type Command struct { - ID string `gorm:"primary_key" json:"id"` - Name string `json:"name"` - Content string `json:"content"` - Created utils.JsonTime `json:"created"` - Owner string `gorm:"index" json:"owner"` -} - -type CommandVo struct { - ID string `gorm:"primary_key" json:"id"` - Name string `json:"name"` - Content string `json:"content"` - Created utils.JsonTime `json:"created"` - Owner string `json:"owner"` - OwnerName string `json:"ownerName"` - SharerCount int64 `json:"sharerCount"` -} - -func (r *Command) TableName() string { - return "commands" -} - -func FindPageCommand(pageIndex, pageSize int, name, content, order, field string, account User) (o []CommandVo, total int64, err error) { - - db := global.DB.Table("commands").Select("commands.id,commands.name,commands.content,commands.owner,commands.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on commands.owner = users.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id") - dbCounter := global.DB.Table("commands").Select("DISTINCT commands.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id") - - if constant.TypeUser == account.Type { - owner := account.ID - db = db.Where("commands.owner = ? or resource_sharers.user_id = ?", owner, owner) - dbCounter = dbCounter.Where("commands.owner = ? or resource_sharers.user_id = ?", owner, owner) - } - - if len(name) > 0 { - db = db.Where("commands.name like ?", "%"+name+"%") - dbCounter = dbCounter.Where("commands.name like ?", "%"+name+"%") - } - - if len(content) > 0 { - db = db.Where("commands.content like ?", "%"+content+"%") - dbCounter = dbCounter.Where("commands.content like ?", "%"+content+"%") - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "ascend" { - order = "asc" - } else { - order = "desc" - } - - if field == "name" { - field = "name" - } else { - field = "created" - } - - err = db.Order("commands." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error - if o == nil { - o = make([]CommandVo, 0) - } - return -} - -func CreateNewCommand(o *Command) (err error) { - if err = global.DB.Create(o).Error; err != nil { - return err - } - return nil -} - -func FindCommandById(id string) (o Command, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func UpdateCommandById(o *Command, id string) { - o.ID = id - global.DB.Updates(o) -} - -func DeleteCommandById(id string) error { - return global.DB.Where("id = ?", id).Delete(&Command{}).Error -} diff --git a/pkg/model/credential.go b/pkg/model/credential.go deleted file mode 100644 index c333cd7..0000000 --- a/pkg/model/credential.go +++ /dev/null @@ -1,131 +0,0 @@ -package model - -import ( - "next-terminal/pkg/constant" - "next-terminal/pkg/global" - "next-terminal/pkg/utils" -) - -type Credential struct { - ID string `gorm:"primary_key" json:"id"` - Name string `json:"name"` - Type string `json:"type"` - Username string `json:"username"` - Password string `json:"password"` - PrivateKey string `json:"privateKey"` - Passphrase string `json:"passphrase"` - Created utils.JsonTime `json:"created"` - Owner string `gorm:"index" json:"owner"` -} - -func (r *Credential) TableName() string { - return "credentials" -} - -type CredentialVo struct { - ID string `json:"id"` - Name string `json:"name"` - Type string `json:"type"` - Username string `json:"username"` - Created utils.JsonTime `json:"created"` - Owner string `json:"owner"` - OwnerName string `json:"ownerName"` - SharerCount int64 `json:"sharerCount"` -} - -type CredentialSimpleVo struct { - ID string `json:"id"` - Name string `json:"name"` -} - -func FindAllCredential(account User) (o []CredentialSimpleVo, err error) { - db := global.DB.Table("credentials").Select("DISTINCT credentials.id,credentials.name").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id") - if account.Type == constant.TypeUser { - db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", account.ID, account.ID) - } - err = db.Find(&o).Error - return -} - -func FindPageCredential(pageIndex, pageSize int, name, order, field string, account User) (o []CredentialVo, total int64, err error) { - db := global.DB.Table("credentials").Select("credentials.id,credentials.name,credentials.type,credentials.username,credentials.owner,credentials.created,users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on credentials.owner = users.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id") - dbCounter := global.DB.Table("credentials").Select("DISTINCT credentials.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id") - - if constant.TypeUser == account.Type { - owner := account.ID - db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", owner, owner) - dbCounter = dbCounter.Where("credentials.owner = ? or resource_sharers.user_id = ?", owner, owner) - } - - if len(name) > 0 { - db = db.Where("credentials.name like ?", "%"+name+"%") - dbCounter = dbCounter.Where("credentials.name like ?", "%"+name+"%") - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "ascend" { - order = "asc" - } else { - order = "desc" - } - - if field == "name" { - field = "name" - } else { - field = "created" - } - - err = db.Order("credentials." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error - if o == nil { - o = make([]CredentialVo, 0) - } - return -} - -func CreateNewCredential(o *Credential) (err error) { - if err = global.DB.Create(o).Error; err != nil { - return err - } - return nil -} - -func FindCredentialById(id string) (o Credential, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func UpdateCredentialById(o *Credential, id string) { - o.ID = id - global.DB.Updates(o) -} - -func DeleteCredentialById(id string) error { - return global.DB.Where("id = ?", id).Delete(&Credential{}).Error -} - -func CountCredential() (total int64, err error) { - err = global.DB.Find(&Credential{}).Count(&total).Error - return -} - -func CountCredentialByUserId(userId string) (total int64, err error) { - db := global.DB.Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id") - - db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", userId, userId) - - // 查询用户所在用户组列表 - userGroupIds, err := FindUserGroupIdsByUserId(userId) - if err != nil { - return 0, err - } - - if userGroupIds != nil && len(userGroupIds) > 0 { - db = db.Or("resource_sharers.user_group_id in ?", userGroupIds) - } - err = db.Find(&Credential{}).Count(&total).Error - return -} diff --git a/pkg/model/job_log.go b/pkg/model/job_log.go deleted file mode 100644 index fb7f923..0000000 --- a/pkg/model/job_log.go +++ /dev/null @@ -1,30 +0,0 @@ -package model - -import ( - "next-terminal/pkg/global" - "next-terminal/pkg/utils" -) - -type JobLog struct { - ID string `json:"id"` - Timestamp utils.JsonTime `json:"timestamp"` - JobId string `json:"jobId"` - Message string `json:"message"` -} - -func (r *JobLog) TableName() string { - return "job_logs" -} - -func CreateNewJobLog(o *JobLog) error { - return global.DB.Create(o).Error -} - -func FindJobLogs(jobId string) (o []JobLog, err error) { - err = global.DB.Where("job_id = ?", jobId).Order("timestamp asc").Find(&o).Error - return -} - -func DeleteJobLogByJobId(jobId string) error { - return global.DB.Where("job_id = ?", jobId).Delete(JobLog{}).Error -} diff --git a/pkg/model/login_log.go b/pkg/model/login_log.go deleted file mode 100644 index 44f845a..0000000 --- a/pkg/model/login_log.go +++ /dev/null @@ -1,106 +0,0 @@ -package model - -import ( - "github.com/sirupsen/logrus" - "next-terminal/pkg/global" - "next-terminal/pkg/utils" -) - -type LoginLog struct { - ID string `gorm:"primary_key" json:"id"` - UserId string `gorm:"index" json:"userId"` - ClientIP string `json:"clientIp"` - ClientUserAgent string `json:"clientUserAgent"` - LoginTime utils.JsonTime `json:"loginTime"` - LogoutTime utils.JsonTime `json:"logoutTime"` - Remember bool `json:"remember"` -} - -type LoginLogVo struct { - ID string `json:"id"` - UserId string `json:"userId"` - UserName string `json:"userName"` - ClientIP string `json:"clientIp"` - ClientUserAgent string `json:"clientUserAgent"` - LoginTime utils.JsonTime `json:"loginTime"` - LogoutTime utils.JsonTime `json:"logoutTime"` - Remember bool `json:"remember"` -} - -func (r *LoginLog) TableName() string { - return "login_logs" -} - -func FindPageLoginLog(pageIndex, pageSize int, userId, clientIp string) (o []LoginLogVo, total int64, err error) { - - db := global.DB.Table("login_logs").Select("login_logs.id,login_logs.user_id,login_logs.client_ip,login_logs.client_user_agent,login_logs.login_time, login_logs.logout_time, users.nickname as user_name").Joins("left join users on login_logs.user_id = users.id") - dbCounter := global.DB.Table("login_logs").Select("DISTINCT login_logs.id") - - if userId != "" { - db = db.Where("login_logs.user_id = ?", userId) - dbCounter = dbCounter.Where("login_logs.user_id = ?", userId) - } - - if clientIp != "" { - db = db.Where("login_logs.client_ip like ?", "%"+clientIp+"%") - dbCounter = dbCounter.Where("login_logs.client_ip like ?", "%"+clientIp+"%") - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - err = db.Order("login_logs.login_time desc").Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error - if o == nil { - o = make([]LoginLogVo, 0) - } - return -} - -func FindAliveLoginLogs() (o []LoginLog, err error) { - err = global.DB.Where("logout_time is null").Find(&o).Error - return -} - -func FindAliveLoginLogsByUserId(userId string) (o []LoginLog, err error) { - err = global.DB.Where("logout_time is null and user_id = ?", userId).Find(&o).Error - return -} - -func CreateNewLoginLog(o *LoginLog) (err error) { - return global.DB.Create(o).Error -} - -func DeleteLoginLogByIdIn(ids []string) (err error) { - return global.DB.Where("id in ?", ids).Delete(&LoginLog{}).Error -} - -func FindLoginLogById(id string) (o LoginLog, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func Logout(token string) (err error) { - - loginLog, err := FindLoginLogById(token) - if err != nil { - logrus.Warnf("登录日志「%v」获取失败", token) - return - } - - err = global.DB.Updates(&LoginLog{LogoutTime: utils.NowJsonTime(), ID: token}).Error - if err != nil { - return err - } - - loginLogs, err := FindAliveLoginLogsByUserId(loginLog.UserId) - if err != nil { - return - } - - if len(loginLogs) == 0 { - err = UpdateUserOnline(false, loginLog.UserId) - } - return -} diff --git a/pkg/model/num.go b/pkg/model/num.go deleted file mode 100644 index 653c246..0000000 --- a/pkg/model/num.go +++ /dev/null @@ -1,25 +0,0 @@ -package model - -import ( - "next-terminal/pkg/global" -) - -type Num struct { - I string `gorm:"primary_key" json:"i"` -} - -func (r *Num) TableName() string { - return "nums" -} - -func FindAllTemp() (o []Num) { - if global.DB.Find(&o).Error != nil { - return nil - } - return -} - -func CreateNewTemp(o *Num) (err error) { - err = global.DB.Create(o).Error - return -} diff --git a/pkg/model/property.go b/pkg/model/property.go deleted file mode 100644 index 00febff..0000000 --- a/pkg/model/property.go +++ /dev/null @@ -1,89 +0,0 @@ -package model - -import ( - "github.com/jordan-wright/email" - "github.com/sirupsen/logrus" - "net/smtp" - "next-terminal/pkg/constant" - "next-terminal/pkg/global" - "next-terminal/pkg/guacd" -) - -type Property struct { - Name string `gorm:"primary_key" json:"name"` - Value string `json:"value"` -} - -func (r *Property) TableName() string { - return "properties" -} - -func FindAllProperties() (o []Property) { - if global.DB.Find(&o).Error != nil { - return nil - } - return -} - -func CreateNewProperty(o *Property) (err error) { - err = global.DB.Create(o).Error - return -} - -func UpdatePropertyByName(o *Property, name string) { - o.Name = name - global.DB.Updates(o) -} - -func FindPropertyByName(name string) (o Property, err error) { - err = global.DB.Where("name = ?", name).First(&o).Error - return -} - -func FindAllPropertiesMap() map[string]string { - properties := FindAllProperties() - propertyMap := make(map[string]string) - for i := range properties { - propertyMap[properties[i].Name] = properties[i].Value - } - return propertyMap -} - -func GetDrivePath() (string, error) { - property, err := FindPropertyByName(guacd.DrivePath) - if err != nil { - return "", err - } - return property.Value, nil -} - -func GetRecordingPath() (string, error) { - property, err := FindPropertyByName(guacd.RecordingPath) - if err != nil { - return "", err - } - return property.Value, nil -} - -func SendMail(to, subject, text string) { - propertiesMap := FindAllPropertiesMap() - host := propertiesMap[constant.MailHost] - port := propertiesMap[constant.MailPort] - username := propertiesMap[constant.MailUsername] - password := propertiesMap[constant.MailPassword] - - if host == "" || port == "" || username == "" || password == "" { - logrus.Debugf("邮箱信息不完整,跳过发送邮件。") - return - } - - e := email.NewEmail() - e.From = "Next Terminal <" + username + ">" - e.To = []string{to} - e.Subject = subject - e.Text = []byte(text) - err := e.Send(host+":"+port, smtp.PlainAuth("", username, password, host)) - if err != nil { - logrus.Errorf("邮件发送失败: %v", err.Error()) - } -} diff --git a/pkg/model/session.go b/pkg/model/session.go deleted file mode 100644 index cd7b296..0000000 --- a/pkg/model/session.go +++ /dev/null @@ -1,209 +0,0 @@ -package model - -import ( - "next-terminal/pkg/constant" - "next-terminal/pkg/global" - "next-terminal/pkg/utils" - "os" - "path" - "time" -) - -type Session struct { - ID string `gorm:"primary_key" json:"id"` - Protocol string `json:"protocol"` - IP string `json:"ip"` - Port int `json:"port"` - ConnectionId string `json:"connectionId"` - AssetId string `gorm:"index" json:"assetId"` - Username string `json:"username"` - Password string `json:"password"` - Creator string `gorm:"index" json:"creator"` - ClientIP string `json:"clientIp"` - Width int `json:"width"` - Height int `json:"height"` - Status string `gorm:"index" json:"status"` - Recording string `json:"recording"` - PrivateKey string `json:"privateKey"` - Passphrase string `json:"passphrase"` - Code int `json:"code"` - Message string `json:"message"` - ConnectedTime utils.JsonTime `json:"connectedTime"` - DisconnectedTime utils.JsonTime `json:"disconnectedTime"` - Mode string `json:"mode"` -} - -func (r *Session) TableName() string { - return "sessions" -} - -type SessionVo struct { - ID string `json:"id"` - Protocol string `json:"protocol"` - IP string `json:"ip"` - Port int `json:"port"` - Username string `json:"username"` - ConnectionId string `json:"connectionId"` - AssetId string `json:"assetId"` - Creator string `json:"creator"` - ClientIP string `json:"clientIp"` - Width int `json:"width"` - Height int `json:"height"` - Status string `json:"status"` - Recording string `json:"recording"` - ConnectedTime utils.JsonTime `json:"connectedTime"` - DisconnectedTime utils.JsonTime `json:"disconnectedTime"` - AssetName string `json:"assetName"` - CreatorName string `json:"creatorName"` - Code int `json:"code"` - Message string `json:"message"` - Mode string `json:"mode"` -} - -func FindPageSession(pageIndex, pageSize int, status, userId, clientIp, assetId, protocol string) (results []SessionVo, total int64, err error) { - - db := global.DB - var params []interface{} - - params = append(params, status) - - itemSql := "SELECT s.id,s.mode, s.protocol,s.recording, s.connection_id, s.asset_id, s.creator, s.client_ip, s.width, s.height, s.ip, s.port, s.username, s.status, s.connected_time, s.disconnected_time,s.code, s.message, a.name AS asset_name, u.nickname AS creator_name FROM sessions s LEFT JOIN assets a ON s.asset_id = a.id LEFT JOIN users u ON s.creator = u.id WHERE s.STATUS = ? " - countSql := "select count(*) from sessions as s where s.status = ? " - - if len(userId) > 0 { - itemSql += " and s.creator = ?" - countSql += " and s.creator = ?" - params = append(params, userId) - } - - if len(clientIp) > 0 { - itemSql += " and s.client_ip like ?" - countSql += " and s.client_ip like ?" - params = append(params, "%"+clientIp+"%") - } - - if len(assetId) > 0 { - itemSql += " and s.asset_id = ?" - countSql += " and s.asset_id = ?" - params = append(params, assetId) - } - - if len(protocol) > 0 { - itemSql += " and s.protocol = ?" - countSql += " and s.protocol = ?" - params = append(params, protocol) - } - - params = append(params, (pageIndex-1)*pageSize, pageSize) - itemSql += " order by s.connected_time desc LIMIT ?, ?" - - db.Raw(countSql, params...).Scan(&total) - - err = db.Raw(itemSql, params...).Scan(&results).Error - - if results == nil { - results = make([]SessionVo, 0) - } - return -} - -func FindSessionByStatus(status string) (o []Session, err error) { - err = global.DB.Where("status = ?", status).Find(&o).Error - return -} - -func FindSessionByStatusIn(statuses []string) (o []Session, err error) { - err = global.DB.Where("status in ?", statuses).Find(&o).Error - return -} - -func FindOutTimeSessions(dayLimit int) (o []Session, err error) { - limitTime := time.Now().Add(time.Duration(-dayLimit*24) * time.Hour) - err = global.DB.Where("status = ? and connected_time < ?", constant.Disconnected, limitTime).Find(&o).Error - return -} - -func CreateNewSession(o *Session) (err error) { - err = global.DB.Create(o).Error - return -} - -func FindSessionById(id string) (o Session, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func FindSessionByConnectionId(connectionId string) (o Session, err error) { - err = global.DB.Where("connection_id = ?", connectionId).First(&o).Error - return -} - -func UpdateSessionById(o *Session, id string) error { - o.ID = id - return global.DB.Updates(o).Error -} - -func UpdateSessionWindowSizeById(width, height int, id string) error { - session := Session{} - session.Width = width - session.Height = height - - return UpdateSessionById(&session, id) -} - -func DeleteSessionById(id string) error { - return global.DB.Where("id = ?", id).Delete(&Session{}).Error -} - -func DeleteSessionByIds(sessionIds []string) error { - drivePath, err := GetRecordingPath() - if err != nil { - return err - } - for i := range sessionIds { - if err := os.RemoveAll(path.Join(drivePath, sessionIds[i])); err != nil { - return err - } - if err := DeleteSessionById(sessionIds[i]); err != nil { - return err - } - } - return nil -} - -func DeleteSessionByStatus(status string) { - global.DB.Where("status = ?", status).Delete(&Session{}) -} - -func CountOnlineSession() (total int64, err error) { - err = global.DB.Where("status = ?", constant.Connected).Find(&Session{}).Count(&total).Error - return -} - -type D struct { - Day string `json:"day"` - Count int `json:"count"` - Protocol string `json:"protocol"` -} - -func CountSessionByDay(day int) (results []D, err error) { - - today := time.Now().Format("20060102") - sql := "select t1.`day`, count(t2.id) as count\nfrom (\n SELECT @date := DATE_ADD(@date, INTERVAL - 1 DAY) day\n FROM (SELECT @date := DATE_ADD('" + today + "', INTERVAL + 1 DAY) FROM nums) as t0\n LIMIT ?\n )\n as t1\n left join\n (\n select DATE(s.connected_time) as day, s.id\n from sessions as s\n WHERE protocol = ? and DATE(connected_time) <= '" + today + "'\n AND DATE(connected_time) > DATE_SUB('" + today + "', INTERVAL ? DAY)\n ) as t2 on t1.day = t2.day\ngroup by t1.day" - - protocols := []string{"rdp", "ssh", "vnc", "telnet"} - - for i := range protocols { - var result []D - err = global.DB.Raw(sql, day, protocols[i], day).Scan(&result).Error - if err != nil { - return nil, err - } - for j := range result { - result[j].Protocol = protocols[i] - } - results = append(results, result...) - } - - return -} diff --git a/pkg/model/user-attribute.go b/pkg/model/user-attribute.go deleted file mode 100644 index c97a8f4..0000000 --- a/pkg/model/user-attribute.go +++ /dev/null @@ -1,23 +0,0 @@ -package model - -import "next-terminal/pkg/global" - -type UserAttribute struct { - Id string `gorm:"index" json:"id"` - UserId string `gorm:"index" json:"userId"` - Name string `gorm:"index" json:"name"` - Value string `json:"value"` -} - -func (r *UserAttribute) TableName() string { - return "user_attributes" -} - -func CreateUserAttribute(o *UserAttribute) error { - return global.DB.Create(o).Error -} - -func FindUserAttributeByUserId(userId string) (o []UserAttribute, err error) { - err = global.DB.Where("user_id = ?", userId).Find(&o).Error - return o, err -} diff --git a/pkg/model/user-group-member.go b/pkg/model/user-group-member.go deleted file mode 100644 index 88ea276..0000000 --- a/pkg/model/user-group-member.go +++ /dev/null @@ -1,18 +0,0 @@ -package model - -import "next-terminal/pkg/global" - -type UserGroupMember struct { - ID string `gorm:"primary_key" json:"name"` - UserId string `gorm:"index" json:"userId"` - UserGroupId string `gorm:"index" json:"userGroupId"` -} - -func (r *UserGroupMember) TableName() string { - return "user_group_members" -} - -func FindUserGroupMembersByUserGroupId(id string) (o []string, err error) { - err = global.DB.Table("user_group_members").Select("user_id").Where("user_group_id = ?", id).Find(&o).Error - return -} diff --git a/pkg/model/user-group.go b/pkg/model/user-group.go deleted file mode 100644 index 8a33eab..0000000 --- a/pkg/model/user-group.go +++ /dev/null @@ -1,135 +0,0 @@ -package model - -import ( - "gorm.io/gorm" - "next-terminal/pkg/global" - "next-terminal/pkg/utils" -) - -type UserGroup struct { - ID string `gorm:"primary_key" json:"id"` - Name string `json:"name"` - Created utils.JsonTime `json:"created"` -} - -type UserGroupVo struct { - ID string `json:"id"` - Name string `json:"name"` - Created utils.JsonTime `json:"created"` - AssetCount int64 `json:"assetCount"` -} - -func (r *UserGroup) TableName() string { - return "user_groups" -} - -func FindPageUserGroup(pageIndex, pageSize int, name, order, field string) (o []UserGroupVo, total int64, err error) { - db := global.DB.Table("user_groups").Select("user_groups.id, user_groups.name, user_groups.created, count(resource_sharers.user_group_id) as asset_count").Joins("left join resource_sharers on user_groups.id = resource_sharers.user_group_id and resource_sharers.resource_type = 'asset'").Group("user_groups.id") - dbCounter := global.DB.Table("user_groups") - if len(name) > 0 { - db = db.Where("user_groups.name like ?", "%"+name+"%") - dbCounter = dbCounter.Where("name like ?", "%"+name+"%") - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "ascend" { - order = "asc" - } else { - order = "desc" - } - - if field == "name" { - field = "name" - } else { - field = "created" - } - - err = db.Order("user_groups." + field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error - if o == nil { - o = make([]UserGroupVo, 0) - } - return -} - -func CreateNewUserGroup(o *UserGroup, members []string) (err error) { - return global.DB.Transaction(func(tx *gorm.DB) error { - err = tx.Create(o).Error - if err != nil { - return err - } - - if members != nil { - userGroupId := o.ID - err = AddUserGroupMembers(tx, members, userGroupId) - if err != nil { - return err - } - } - return err - }) -} - -func AddUserGroupMembers(tx *gorm.DB, userIds []string, userGroupId string) error { - for i := range userIds { - userId := userIds[i] - _, err := FindUserById(userId) - if err != nil { - return err - } - - userGroupMember := UserGroupMember{ - ID: utils.Sign([]string{userGroupId, userId}), - UserId: userId, - UserGroupId: userGroupId, - } - err = tx.Create(&userGroupMember).Error - if err != nil { - return err - } - } - return nil -} - -func FindUserGroupById(id string) (o UserGroup, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func FindUserGroupIdsByUserId(userId string) (o []string, err error) { - // 先查询用户所在的用户 - err = global.DB.Table("user_group_members").Select("user_group_id").Where("user_id = ?", userId).Find(&o).Error - return -} - -func UpdateUserGroupById(o *UserGroup, members []string, id string) error { - return global.DB.Transaction(func(tx *gorm.DB) error { - o.ID = id - err := tx.Updates(o).Error - if err != nil { - return err - } - - err = tx.Where("user_group_id = ?", id).Delete(&UserGroupMember{}).Error - if err != nil { - return err - } - if members != nil { - userGroupId := o.ID - err = AddUserGroupMembers(tx, members, userGroupId) - if err != nil { - return err - } - } - return err - }) - -} - -func DeleteUserGroupById(id string) { - global.DB.Where("id = ?", id).Delete(&UserGroup{}) - global.DB.Where("user_group_id = ?", id).Delete(&UserGroupMember{}) -} diff --git a/pkg/model/user.go b/pkg/model/user.go deleted file mode 100644 index 7ed9028..0000000 --- a/pkg/model/user.go +++ /dev/null @@ -1,149 +0,0 @@ -package model - -import ( - "next-terminal/pkg/global" - "next-terminal/pkg/utils" - "reflect" -) - -type User struct { - ID string `gorm:"primary_key" json:"id"` - Username string `gorm:"index" json:"username"` - Password string `json:"password"` - Nickname string `json:"nickname"` - TOTPSecret string `json:"-"` - Online bool `json:"online"` - Enabled bool `json:"enabled"` - Created utils.JsonTime `json:"created"` - Type string `json:"type"` - Mail string `json:"mail"` -} - -type UserVo struct { - ID string `json:"id"` - Username string `json:"username"` - Nickname string `json:"nickname"` - TOTPSecret string `json:"totpSecret"` - Mail string `json:"mail"` - Online bool `json:"online"` - Enabled bool `json:"enabled"` - Created utils.JsonTime `json:"created"` - Type string `json:"type"` - SharerAssetCount int64 `json:"sharerAssetCount"` -} - -func (r *User) TableName() string { - return "users" -} - -func (r *User) IsEmpty() bool { - return reflect.DeepEqual(r, User{}) -} - -func FindAllUser() (o []User) { - if global.DB.Find(&o).Error != nil { - return nil - } - return -} - -func FindPageUser(pageIndex, pageSize int, username, nickname, mail, order, field string) (o []UserVo, total int64, err error) { - db := global.DB.Table("users").Select("users.id,users.username,users.nickname,users.mail,users.online,users.enabled,users.created,users.type, count(resource_sharers.user_id) as sharer_asset_count, users.totp_secret").Joins("left join resource_sharers on users.id = resource_sharers.user_id and resource_sharers.resource_type = 'asset'").Group("users.id") - dbCounter := global.DB.Table("users") - if len(username) > 0 { - db = db.Where("users.username like ?", "%"+username+"%") - dbCounter = dbCounter.Where("username like ?", "%"+username+"%") - } - - if len(nickname) > 0 { - db = db.Where("users.nickname like ?", "%"+nickname+"%") - dbCounter = dbCounter.Where("nickname like ?", "%"+nickname+"%") - } - - if len(mail) > 0 { - db = db.Where("users.mail like ?", "%"+mail+"%") - dbCounter = dbCounter.Where("mail like ?", "%"+mail+"%") - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "ascend" { - order = "asc" - } else { - order = "desc" - } - - if field == "username" { - field = "username" - } else if field == "nickname" { - field = "nickname" - } else { - field = "created" - } - - err = db.Order("users." + field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error - if o == nil { - o = make([]UserVo, 0) - } - - for i := 0; i < len(o); i++ { - if o[i].TOTPSecret == "" || o[i].TOTPSecret == "-" { - o[i].TOTPSecret = "0" - } else { - o[i].TOTPSecret = "1" - } - } - return -} - -func CreateNewUser(o *User) (err error) { - err = global.DB.Create(o).Error - return -} - -func FindUserById(id string) (o User, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func FindUserByIdIn(ids []string) (o []User, err error) { - err = global.DB.Where("id in ?", ids).First(&o).Error - return -} - -func FindUserByUsername(username string) (o User, err error) { - err = global.DB.Where("username = ?", username).First(&o).Error - return -} - -func UpdateUserById(o *User, id string) { - o.ID = id - global.DB.Updates(o) -} - -func UpdateUserOnline(online bool, id string) (err error) { - sql := "update users set online = ? where id = ?" - err = global.DB.Exec(sql, online, id).Error - return -} - -func FindOnlineUsers() (o []User, err error) { - err = global.DB.Where("online = ?", true).Find(&o).Error - return -} - -func DeleteUserById(id string) { - global.DB.Where("id = ?", id).Delete(&User{}) - // 删除用户组中的用户关系 - global.DB.Where("user_id = ?", id).Delete(&UserGroupMember{}) - // 删除用户分享到的资产 - global.DB.Where("user_id = ?", id).Delete(&ResourceSharer{}) -} - -func CountOnlineUser() (total int64, err error) { - err = global.DB.Where("online = ?", true).Find(&User{}).Count(&total).Error - return -} diff --git a/pkg/model/job.go b/pkg/service/job.go similarity index 52% rename from pkg/model/job.go rename to pkg/service/job.go index 7f76888..d923ebc 100644 --- a/pkg/model/job.go +++ b/pkg/service/job.go @@ -1,123 +1,41 @@ -package model +package service import ( "encoding/json" "errors" "fmt" - "github.com/robfig/cron/v3" - "github.com/sirupsen/logrus" - "next-terminal/pkg/constant" - "next-terminal/pkg/global" - "next-terminal/pkg/term" - "next-terminal/pkg/utils" "strings" "time" + + "next-terminal/pkg/constant" + "next-terminal/pkg/global" + "next-terminal/pkg/log" + "next-terminal/pkg/term" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/utils" + + "github.com/robfig/cron/v3" ) -type Job struct { - ID string `gorm:"primary_key" json:"id"` - CronJobId int `json:"cronJobId"` - Name string `json:"name"` - Func string `json:"func"` - Cron string `json:"cron"` - Mode string `json:"mode"` - ResourceIds string `json:"resourceIds"` - Status string `json:"status"` - Metadata string `json:"metadata"` - Created utils.JsonTime `json:"created"` - Updated utils.JsonTime `json:"updated"` +type JobService struct { + jobRepository *repository.JobRepository + jobLogRepository *repository.JobLogRepository + assetRepository *repository.AssetRepository + credentialRepository *repository.CredentialRepository } -func (r *Job) TableName() string { - return "jobs" +func NewJobService(jobRepository *repository.JobRepository, jobLogRepository *repository.JobLogRepository, assetRepository *repository.AssetRepository, credentialRepository *repository.CredentialRepository) *JobService { + return &JobService{jobRepository: jobRepository, jobLogRepository: jobLogRepository, assetRepository: assetRepository, credentialRepository: credentialRepository} } -func FindPageJob(pageIndex, pageSize int, name, status, order, field string) (o []Job, total int64, err error) { - job := Job{} - db := global.DB.Table(job.TableName()) - dbCounter := global.DB.Table(job.TableName()) - - if len(name) > 0 { - db = db.Where("name like ?", "%"+name+"%") - dbCounter = dbCounter.Where("name like ?", "%"+name+"%") - } - - if len(status) > 0 { - db = db.Where("status = ?", status) - dbCounter = dbCounter.Where("status = ?", status) - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "ascend" { - order = "asc" - } else { - order = "desc" - } - - if field == "name" { - field = "name" - } else if field == "created" { - field = "created" - } else { - field = "updated" - } - - err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error - if o == nil { - o = make([]Job, 0) - } - return -} - -func FindJobByFunc(function string) (o []Job, err error) { - db := global.DB - err = db.Where("func = ?", function).Find(&o).Error - return -} - -func CreateNewJob(o *Job) (err error) { - - if o.Status == constant.JobStatusRunning { - j, err := getJob(o) - if err != nil { - return err - } - jobId, err := global.Cron.AddJob(o.Cron, j) - if err != nil { - return err - } - o.CronJobId = int(jobId) - } - - return global.DB.Create(o).Error -} - -func UpdateJobById(o *Job, id string) (err error) { - if o.Status == constant.JobStatusRunning { - return errors.New("请先停止定时任务后再修改") - } - - o.ID = id - return global.DB.Updates(o).Error -} - -func UpdateJonUpdatedById(id string) (err error) { - err = global.DB.Updates(Job{ID: id, Updated: utils.NowJsonTime()}).Error - return -} - -func ChangeJobStatusById(id, status string) (err error) { - var job Job - err = global.DB.Where("id = ?", id).First(&job).Error +func (r JobService) ChangeStatusById(id, status string) error { + job, err := r.jobRepository.FindById(id) if err != nil { return err } if status == constant.JobStatusRunning { - j, err := getJob(&job) + j, err := getJob(&job, &r) if err != nil { return err } @@ -125,53 +43,25 @@ func ChangeJobStatusById(id, status string) (err error) { if err != nil { return err } - logrus.Debugf("开启计划任务「%v」,运行中计划任务数量「%v」", job.Name, len(global.Cron.Entries())) + log.Debugf("开启计划任务「%v」,运行中计划任务数量「%v」", job.Name, len(global.Cron.Entries())) - return global.DB.Updates(Job{ID: id, Status: constant.JobStatusRunning, CronJobId: int(entryID)}).Error + jobForUpdate := model.Job{ID: id, Status: constant.JobStatusRunning, CronJobId: int(entryID)} + + return r.jobRepository.UpdateById(&jobForUpdate) } else { global.Cron.Remove(cron.EntryID(job.CronJobId)) - logrus.Debugf("关闭计划任务「%v」,运行中计划任务数量「%v」", job.Name, len(global.Cron.Entries())) - return global.DB.Updates(Job{ID: id, Status: constant.JobStatusNotRunning}).Error + log.Debugf("关闭计划任务「%v」,运行中计划任务数量「%v」", job.Name, len(global.Cron.Entries())) + jobForUpdate := model.Job{ID: id, Status: constant.JobStatusNotRunning} + return r.jobRepository.UpdateById(&jobForUpdate) } } -func ExecJobById(id string) (err error) { - job, err := FindJobById(id) - if err != nil { - return err - } - j, err := getJob(&job) - if err != nil { - return err - } - j.Run() - return nil -} - -func FindJobById(id string) (o Job, err error) { - err = global.DB.Where("id = ?", id).First(&o).Error - return -} - -func DeleteJobById(id string) error { - job, err := FindJobById(id) - if err != nil { - return err - } - if job.Status == constant.JobStatusRunning { - if err := ChangeJobStatusById(id, constant.JobStatusNotRunning); err != nil { - return err - } - } - return global.DB.Where("id = ?", id).Delete(Job{}).Error -} - -func getJob(j *Job) (job cron.Job, err error) { +func getJob(j *model.Job, jobService *JobService) (job cron.Job, err error) { switch j.Func { case constant.FuncCheckAssetStatusJob: - job = CheckAssetStatusJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata} + job = CheckAssetStatusJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata, jobService: jobService} case constant.FuncShellJob: - job = ShellJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata} + job = ShellJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata, jobService: jobService} default: return nil, errors.New("未识别的任务") } @@ -183,6 +73,7 @@ type CheckAssetStatusJob struct { Mode string ResourceIds string Metadata string + jobService *JobService } func (r CheckAssetStatusJob) Run() { @@ -190,14 +81,14 @@ func (r CheckAssetStatusJob) Run() { return } - var assets []Asset + var assets []model.Asset if r.Mode == constant.JobModeAll { - assets, _ = FindAllAsset() + assets, _ = r.jobService.assetRepository.FindAll() } else { - assets, _ = FindAssetByIds(strings.Split(r.ResourceIds, ",")) + assets, _ = r.jobService.assetRepository.FindByIds(strings.Split(r.ResourceIds, ",")) } - if assets == nil || len(assets) == 0 { + if len(assets) == 0 { return } @@ -210,8 +101,8 @@ func (r CheckAssetStatusJob) Run() { elapsed := time.Since(t1) msg := fmt.Sprintf("资产「%v」存活状态检测完成,存活「%v」,耗时「%v」", asset.Name, active, elapsed) - UpdateAssetActiveById(active, asset.ID) - logrus.Infof(msg) + _ = r.jobService.assetRepository.UpdateActiveById(active, asset.ID) + log.Infof(msg) msgChan <- msg }() } @@ -221,15 +112,15 @@ func (r CheckAssetStatusJob) Run() { message += <-msgChan + "\n" } - _ = UpdateJonUpdatedById(r.ID) - jobLog := JobLog{ + _ = r.jobService.jobRepository.UpdateLastUpdatedById(r.ID) + jobLog := model.JobLog{ ID: utils.UUID(), JobId: r.ID, Timestamp: utils.NowJsonTime(), Message: message, } - _ = CreateNewJobLog(&jobLog) + _ = r.jobService.jobLogRepository.Create(&jobLog) } type ShellJob struct { @@ -237,6 +128,7 @@ type ShellJob struct { Mode string ResourceIds string Metadata string + jobService *JobService } type MetadataShell struct { @@ -248,27 +140,27 @@ func (r ShellJob) Run() { return } - var assets []Asset + var assets []model.Asset if r.Mode == constant.JobModeAll { - assets, _ = FindAssetByProtocol("ssh") + assets, _ = r.jobService.assetRepository.FindByProtocol("ssh") } else { - assets, _ = FindAssetByProtocolAndIds("ssh", strings.Split(r.ResourceIds, ",")) + assets, _ = r.jobService.assetRepository.FindByProtocolAndIds("ssh", strings.Split(r.ResourceIds, ",")) } - if assets == nil || len(assets) == 0 { + if len(assets) == 0 { return } var metadataShell MetadataShell err := json.Unmarshal([]byte(r.Metadata), &metadataShell) if err != nil { - logrus.Errorf("JSON数据解析失败 %v", err) + log.Errorf("JSON数据解析失败 %v", err) return } msgChan := make(chan string) for i := range assets { - asset, err := FindAssetById(assets[i].ID) + asset, err := r.jobService.assetRepository.FindById(assets[i].ID) if err != nil { msgChan <- fmt.Sprintf("资产「%v」Shell执行失败,查询数据异常「%v」", assets[i].Name, err.Error()) return @@ -284,7 +176,7 @@ func (r ShellJob) Run() { ) if asset.AccountType == "credential" { - credential, err := FindCredentialById(asset.CredentialId) + credential, err := r.jobService.credentialRepository.FindById(asset.CredentialId) if err != nil { msgChan <- fmt.Sprintf("资产「%v」Shell执行失败,查询授权凭证数据异常「%v」", assets[i].Name, err.Error()) return @@ -308,10 +200,10 @@ func (r ShellJob) Run() { var msg string if err != nil { msg = fmt.Sprintf("资产「%v」Shell执行失败,返回值「%v」,耗时「%v」", asset.Name, err.Error(), elapsed) - logrus.Infof(msg) + log.Infof(msg) } else { msg = fmt.Sprintf("资产「%v」Shell执行成功,返回值「%v」,耗时「%v」", asset.Name, result, elapsed) - logrus.Infof(msg) + log.Infof(msg) } msgChan <- msg @@ -323,15 +215,15 @@ func (r ShellJob) Run() { message += <-msgChan + "\n" } - _ = UpdateJonUpdatedById(r.ID) - jobLog := JobLog{ + _ = r.jobService.jobRepository.UpdateLastUpdatedById(r.ID) + jobLog := model.JobLog{ ID: utils.UUID(), JobId: r.ID, Timestamp: utils.NowJsonTime(), Message: message, } - _ = CreateNewJobLog(&jobLog) + _ = r.jobService.jobLogRepository.Create(&jobLog) } func ExecCommandBySSH(cmd, ip string, port int, username, password, privateKey, passphrase string) (result string, err error) { @@ -352,3 +244,77 @@ func ExecCommandBySSH(cmd, ip string, port int, username, password, privateKey, } return string(combo), nil } + +func (r JobService) ExecJobById(id string) (err error) { + job, err := r.jobRepository.FindById(id) + if err != nil { + return err + } + j, err := getJob(&job, &r) + if err != nil { + return err + } + j.Run() + return nil +} + +func (r JobService) InitJob() error { + jobs, _ := r.jobRepository.FindByFunc(constant.FuncCheckAssetStatusJob) + if len(jobs) == 0 { + job := model.Job{ + ID: utils.UUID(), + Name: "资产状态检测", + Func: constant.FuncCheckAssetStatusJob, + Cron: "0 0 0/1 * * ?", + Mode: constant.JobModeAll, + Status: constant.JobStatusRunning, + Created: utils.NowJsonTime(), + Updated: utils.NowJsonTime(), + } + if err := r.jobRepository.Create(&job); err != nil { + return err + } + log.Debugf("创建计划任务「%v」cron「%v」", job.Name, job.Cron) + } else { + for i := range jobs { + if jobs[i].Status == constant.JobStatusRunning { + err := r.ChangeStatusById(jobs[i].ID, constant.JobStatusRunning) + if err != nil { + return err + } + log.Debugf("启动计划任务「%v」cron「%v」", jobs[i].Name, jobs[i].Cron) + } + } + } + return nil +} + +func (r JobService) Create(o *model.Job) (err error) { + + if o.Status == constant.JobStatusRunning { + j, err := getJob(o, &r) + if err != nil { + return err + } + jobId, err := global.Cron.AddJob(o.Cron, j) + if err != nil { + return err + } + o.CronJobId = int(jobId) + } + + return r.jobRepository.Create(o) +} + +func (r JobService) DeleteJobById(id string) error { + job, err := r.jobRepository.FindById(id) + if err != nil { + return err + } + if job.Status == constant.JobStatusRunning { + if err := r.ChangeStatusById(id, constant.JobStatusNotRunning); err != nil { + return err + } + } + return r.jobRepository.DeleteJobById(id) +} diff --git a/pkg/service/mail.go b/pkg/service/mail.go new file mode 100644 index 0000000..1c1a7b5 --- /dev/null +++ b/pkg/service/mail.go @@ -0,0 +1,42 @@ +package service + +import ( + "net/smtp" + + "next-terminal/pkg/constant" + "next-terminal/pkg/log" + "next-terminal/server/repository" + + "github.com/jordan-wright/email" +) + +type MailService struct { + propertyRepository *repository.PropertyRepository +} + +func NewMailService(propertyRepository *repository.PropertyRepository) *MailService { + return &MailService{propertyRepository: propertyRepository} +} + +func (r MailService) SendMail(to, subject, text string) { + propertiesMap := r.propertyRepository.FindAllMap() + host := propertiesMap[constant.MailHost] + port := propertiesMap[constant.MailPort] + username := propertiesMap[constant.MailUsername] + password := propertiesMap[constant.MailPassword] + + if host == "" || port == "" || username == "" || password == "" { + log.Debugf("邮箱信息不完整,跳过发送邮件。") + return + } + + e := email.NewEmail() + e.From = "Next Terminal <" + username + ">" + e.To = []string{to} + e.Subject = subject + e.Text = []byte(text) + err := e.Send(host+":"+port, smtp.PlainAuth("", username, password, host)) + if err != nil { + log.Errorf("邮件发送失败: %v", err.Error()) + } +} diff --git a/pkg/service/num.go b/pkg/service/num.go new file mode 100644 index 0000000..04cf7ba --- /dev/null +++ b/pkg/service/num.go @@ -0,0 +1,31 @@ +package service + +import ( + "strconv" + + "next-terminal/server/model" + "next-terminal/server/repository" +) + +type NumService struct { + numRepository *repository.NumRepository +} + +func NewNumService(numRepository *repository.NumRepository) *NumService { + return &NumService{numRepository: numRepository} +} + +func (r NumService) InitNums() error { + nums, err := r.numRepository.FindAll() + if err != nil { + return err + } + if len(nums) == 0 { + for i := 0; i <= 30; i++ { + if err := r.numRepository.Create(&model.Num{I: strconv.Itoa(i)}); err != nil { + return err + } + } + } + return nil +} diff --git a/pkg/handle/runner.go b/pkg/service/property.go similarity index 52% rename from pkg/handle/runner.go rename to pkg/service/property.go index f9ae0c9..8894fb6 100644 --- a/pkg/handle/runner.go +++ b/pkg/service/property.go @@ -1,95 +1,31 @@ -package handle +package service import ( - "github.com/sirupsen/logrus" - "next-terminal/pkg/constant" - "next-terminal/pkg/guacd" - "next-terminal/pkg/model" - "next-terminal/pkg/utils" "os" - "strconv" - "time" + + "next-terminal/pkg/guacd" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/utils" ) -func RunTicker() { - - // 每隔一小时删除一次未使用的会话信息 - unUsedSessionTicker := time.NewTicker(time.Minute * 60) - go func() { - for range unUsedSessionTicker.C { - sessions, _ := model.FindSessionByStatusIn([]string{constant.NoConnect, constant.Connecting}) - if sessions != nil && len(sessions) > 0 { - now := time.Now() - for i := range sessions { - if now.Sub(sessions[i].ConnectedTime.Time) > time.Hour*1 { - _ = model.DeleteSessionById(sessions[i].ID) - s := sessions[i].Username + "@" + sessions[i].IP + ":" + strconv.Itoa(sessions[i].Port) - logrus.Infof("会话「%v」ID「%v」超过1小时未打开,已删除。", s, sessions[i].ID) - } - } - } - } - }() - - // 每日凌晨删除超过时长限制的会话 - timeoutSessionTicker := time.NewTicker(time.Hour * 24) - go func() { - for range timeoutSessionTicker.C { - property, err := model.FindPropertyByName("session-saved-limit") - if err != nil { - return - } - if property.Value == "" || property.Value == "-" { - return - } - limit, err := strconv.Atoi(property.Value) - if err != nil { - return - } - sessions, err := model.FindOutTimeSessions(limit) - if err != nil { - return - } - - if sessions != nil && len(sessions) > 0 { - var sessionIds []string - for i := range sessions { - sessionIds = append(sessionIds, sessions[i].ID) - } - err := model.DeleteSessionByIds(sessionIds) - if err != nil { - logrus.Errorf("删除离线会话失败 %v", err) - } - } - } - }() +type PropertyService struct { + propertyRepository *repository.PropertyRepository } -func RunDataFix() { - sessions, _ := model.FindSessionByStatus(constant.Connected) - if sessions == nil { - return - } - - for i := range sessions { - session := model.Session{ - Status: constant.Disconnected, - DisconnectedTime: utils.NowJsonTime(), - } - - _ = model.UpdateSessionById(&session, sessions[i].ID) - } +func NewPropertyService(propertyRepository *repository.PropertyRepository) *PropertyService { + return &PropertyService{propertyRepository: propertyRepository} } -func InitProperties() error { - propertyMap := model.FindAllPropertiesMap() +func (r PropertyService) InitProperties() error { + propertyMap := r.propertyRepository.FindAllMap() if len(propertyMap[guacd.Host]) == 0 { property := model.Property{ Name: guacd.Host, Value: "127.0.0.1", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -99,7 +35,7 @@ func InitProperties() error { Name: guacd.Port, Value: "4822", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -109,7 +45,7 @@ func InitProperties() error { Name: guacd.EnableRecording, Value: "true", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -125,7 +61,7 @@ func InitProperties() error { return err } } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -135,7 +71,7 @@ func InitProperties() error { Name: guacd.CreateRecordingPath, Value: "true", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -145,7 +81,7 @@ func InitProperties() error { Name: guacd.DriveName, Value: "File-System", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -163,7 +99,7 @@ func InitProperties() error { return err } } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -173,7 +109,7 @@ func InitProperties() error { Name: guacd.FontName, Value: "menlo", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -183,7 +119,7 @@ func InitProperties() error { Name: guacd.FontSize, Value: "12", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -193,7 +129,7 @@ func InitProperties() error { Name: guacd.ColorScheme, Value: "gray-black", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -203,7 +139,7 @@ func InitProperties() error { Name: guacd.EnableDrive, Value: "true", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -213,7 +149,7 @@ func InitProperties() error { Name: guacd.EnableWallpaper, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -223,7 +159,7 @@ func InitProperties() error { Name: guacd.EnableTheming, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -233,7 +169,7 @@ func InitProperties() error { Name: guacd.EnableFontSmoothing, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -243,7 +179,7 @@ func InitProperties() error { Name: guacd.EnableFullWindowDrag, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -253,7 +189,7 @@ func InitProperties() error { Name: guacd.EnableDesktopComposition, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -263,7 +199,7 @@ func InitProperties() error { Name: guacd.EnableMenuAnimations, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -273,7 +209,7 @@ func InitProperties() error { Name: guacd.DisableBitmapCaching, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -283,7 +219,7 @@ func InitProperties() error { Name: guacd.DisableOffscreenCaching, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } @@ -293,7 +229,7 @@ func InitProperties() error { Name: guacd.DisableGlyphCaching, Value: "false", } - if err := model.CreateNewProperty(&property); err != nil { + if err := r.propertyRepository.Create(&property); err != nil { return err } } diff --git a/pkg/service/session.go b/pkg/service/session.go new file mode 100644 index 0000000..ae6b267 --- /dev/null +++ b/pkg/service/session.go @@ -0,0 +1,35 @@ +package service + +import ( + "next-terminal/pkg/constant" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/utils" +) + +type SessionService struct { + sessionRepository *repository.SessionRepository +} + +func NewSessionService(sessionRepository *repository.SessionRepository) *SessionService { + return &SessionService{sessionRepository: sessionRepository} +} + +func (r SessionService) FixSessionState() error { + sessions, err := r.sessionRepository.FindByStatus(constant.Connected) + if err != nil { + return err + } + + if len(sessions) > 0 { + for i := range sessions { + session := model.Session{ + Status: constant.Disconnected, + DisconnectedTime: utils.NowJsonTime(), + } + + _ = r.sessionRepository.UpdateById(&session, sessions[i].ID) + } + } + return nil +} diff --git a/pkg/service/user.go b/pkg/service/user.go new file mode 100644 index 0000000..b57f67c --- /dev/null +++ b/pkg/service/user.go @@ -0,0 +1,106 @@ +package service + +import ( + "next-terminal/pkg/constant" + "next-terminal/pkg/log" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/utils" +) + +type UserService struct { + userRepository *repository.UserRepository + loginLogRepository *repository.LoginLogRepository +} + +func NewUserService(userRepository *repository.UserRepository, loginLogRepository *repository.LoginLogRepository) *UserService { + return &UserService{userRepository: userRepository, loginLogRepository: loginLogRepository} +} + +func (r UserService) InitUser() (err error) { + + users := r.userRepository.FindAll() + + if len(users) == 0 { + initPassword := "admin" + var pass []byte + if pass, err = utils.Encoder.Encode([]byte(initPassword)); err != nil { + return err + } + + user := model.User{ + ID: utils.UUID(), + Username: "admin", + Password: string(pass), + Nickname: "超级管理员", + Type: constant.TypeAdmin, + Created: utils.NowJsonTime(), + } + if err := r.userRepository.Create(&user); err != nil { + return err + } + log.Infof("初始用户创建成功,账号:「%v」密码:「%v」", user.Username, initPassword) + } else { + for i := range users { + // 修正默认用户类型为管理员 + if users[i].Type == "" { + user := model.User{ + Type: constant.TypeAdmin, + ID: users[i].ID, + } + if err := r.userRepository.Update(&user); err != nil { + return err + } + log.Infof("自动修正用户「%v」ID「%v」类型为管理员", users[i].Nickname, users[i].ID) + } + } + } + return nil +} + +func (r UserService) FixedUserOnlineState() error { + // 修正用户登录状态 + onlineUsers, err := r.userRepository.FindOnlineUsers() + if err != nil { + return err + } + if len(onlineUsers) > 0 { + for i := range onlineUsers { + logs, err := r.loginLogRepository.FindAliveLoginLogsByUserId(onlineUsers[i].ID) + if err != nil { + return err + } + if len(logs) == 0 { + if err := r.userRepository.UpdateOnline(onlineUsers[i].ID, false); err != nil { + return err + } + } + } + } + return nil +} + +func (r UserService) Logout(token string) (err error) { + + loginLog, err := r.loginLogRepository.FindById(token) + if err != nil { + log.Warnf("登录日志「%v」获取失败", token) + return + } + + loginLogForUpdate := &model.LoginLog{LogoutTime: utils.NowJsonTime(), ID: token} + err = r.loginLogRepository.Update(loginLogForUpdate) + if err != nil { + return err + } + + loginLogs, err := r.loginLogRepository.FindAliveLoginLogsByUserId(loginLog.UserId) + if err != nil { + return + } + + if len(loginLogs) == 0 { + err = r.userRepository.UpdateOnline(loginLog.UserId, false) + } + return +} diff --git a/pkg/task/ticker.go b/pkg/task/ticker.go new file mode 100644 index 0000000..c7dab7c --- /dev/null +++ b/pkg/task/ticker.go @@ -0,0 +1,73 @@ +package task + +import ( + "strconv" + "time" + + "next-terminal/pkg/constant" + "next-terminal/pkg/log" + "next-terminal/server/repository" +) + +type Ticker struct { + sessionRepository *repository.SessionRepository + propertyRepository *repository.PropertyRepository +} + +func NewTicker(sessionRepository *repository.SessionRepository, propertyRepository *repository.PropertyRepository) *Ticker { + return &Ticker{sessionRepository: sessionRepository, propertyRepository: propertyRepository} +} + +func (t *Ticker) SetupTicker() { + + // 每隔一小时删除一次未使用的会话信息 + unUsedSessionTicker := time.NewTicker(time.Minute * 60) + go func() { + for range unUsedSessionTicker.C { + sessions, _ := t.sessionRepository.FindByStatusIn([]string{constant.NoConnect, constant.Connecting}) + if len(sessions) > 0 { + now := time.Now() + for i := range sessions { + if now.Sub(sessions[i].ConnectedTime.Time) > time.Hour*1 { + _ = t.sessionRepository.DeleteById(sessions[i].ID) + s := sessions[i].Username + "@" + sessions[i].IP + ":" + strconv.Itoa(sessions[i].Port) + log.Infof("会话「%v」ID「%v」超过1小时未打开,已删除。", s, sessions[i].ID) + } + } + } + } + }() + + // 每日凌晨删除超过时长限制的会话 + timeoutSessionTicker := time.NewTicker(time.Hour * 24) + go func() { + for range timeoutSessionTicker.C { + property, err := t.propertyRepository.FindByName("session-saved-limit") + if err != nil { + return + } + if property.Value == "" || property.Value == "-" { + return + } + limit, err := strconv.Atoi(property.Value) + if err != nil { + return + } + sessions, err := t.sessionRepository.FindOutTimeSessions(limit) + if err != nil { + return + } + + if len(sessions) > 0 { + var sessionIds []string + for i := range sessions { + sessionIds = append(sessionIds, sessions[i].ID) + } + err := t.sessionRepository.DeleteByIds(sessionIds) + if err != nil { + log.Errorf("删除离线会话失败 %v", err) + } + } + } + }() +} diff --git a/pkg/term/next_terminal.go b/pkg/term/next_terminal.go index e7e1212..4970698 100644 --- a/pkg/term/next_terminal.go +++ b/pkg/term/next_terminal.go @@ -1,9 +1,10 @@ package term import ( + "io" + "github.com/pkg/sftp" "golang.org/x/crypto/ssh" - "io" ) type NextTerminal struct { diff --git a/pkg/term/recording.go b/pkg/term/recording.go index cd35efd..020269f 100644 --- a/pkg/term/recording.go +++ b/pkg/term/recording.go @@ -2,9 +2,10 @@ package term import ( "encoding/json" - "next-terminal/pkg/utils" "os" "time" + + "next-terminal/server/utils" ) type Env struct { diff --git a/pkg/term/ssh.go b/pkg/term/ssh.go index f27b6d2..3ac5aa8 100644 --- a/pkg/term/ssh.go +++ b/pkg/term/ssh.go @@ -2,8 +2,9 @@ package term import ( "fmt" - "golang.org/x/crypto/ssh" "time" + + "golang.org/x/crypto/ssh" ) func NewSshClient(ip string, port int, username, password, privateKey, passphrase string) (*ssh.Client, error) { diff --git a/pkg/term/test/test_ssh.go b/pkg/term/test/test_ssh.go index fabefdd..533eb3e 100644 --- a/pkg/term/test/test_ssh.go +++ b/pkg/term/test/test_ssh.go @@ -6,6 +6,8 @@ import ( "os" "time" + "next-terminal/pkg/log" + "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/terminal" ) @@ -29,7 +31,7 @@ func main() { client, err := ssh.Dial("tcp", "172.16.101.32:22", sshConfig) if err != nil { - fmt.Println(err) + log.Error(err) } defer client.Close() @@ -68,7 +70,7 @@ func (t *SSHTerminal) updateTerminalSize() { continue } - t.Session.WindowChange(currTermHeight, currTermWidth) + err = t.Session.WindowChange(currTermHeight, currTermWidth) if err != nil { fmt.Printf("Unable to send window-change reqest: %s.", err) continue @@ -86,9 +88,9 @@ func (t *SSHTerminal) interactiveSession() error { defer func() { if t.exitMsg == "" { - fmt.Fprintln(os.Stdout, "the connection was closed on the remote side on ", time.Now().Format(time.RFC822)) + log.Info(os.Stdout, "the connection was closed on the remote side on ", time.Now().Format(time.RFC822)) } else { - fmt.Fprintln(os.Stdout, t.exitMsg) + log.Info(os.Stdout, t.exitMsg) } }() diff --git a/playground/docker-compose.yml b/playground/docker-compose.yml new file mode 100644 index 0000000..392f982 --- /dev/null +++ b/playground/docker-compose.yml @@ -0,0 +1,49 @@ +version: '3.3' +services: + mysql: + image: mysql:8.0 + container_name: mysql + environment: + MYSQL_DATABASE: next-terminal + MYSQL_USER: next-terminal + MYSQL_PASSWORD: next-terminal + MYSQL_ROOT_PASSWORD: next-terminal + volumes: + - ./data/mysql_data:/var/lib/mysql + ports: + - "3306:3306" + restart: + always + networks: + next-terminal: + ipv4_address: 172.77.77.2 + +# next-terminal: +# container_name: next-terminal +# image: "dushixiang/next-terminal:latest" +# environment: +# DB: "mysql" +# MYSQL_HOSTNAME: "mysql" +# MYSQL_PORT: 3306 +# MYSQL_USERNAME: "next-terminal" +# MYSQL_PASSWORD: "next-terminal" +# MYSQL_DATABASE: "next-terminal" +# ports: +# - "8088:8088" +# volumes: +# - ./drive:/usr/local/next-terminal/drive +# - ./recording:/usr/local/next-terminal/recording +# depends_on: +# - mysql +# networks: +# next-terminal: +# ipv4_address: 172.77.77.3 +# restart: +# always + +networks: + next-terminal: + ipam: + driver: default + config: + - subnet: "172.77.77.0/24" diff --git a/screenshot/access.png b/screenshot/access.png new file mode 100644 index 0000000..2befe3c Binary files /dev/null and b/screenshot/access.png differ diff --git a/screenshot/assets.png b/screenshot/assets.png index a124450..7d49195 100644 Binary files a/screenshot/assets.png and b/screenshot/assets.png differ diff --git a/screenshot/command.png b/screenshot/command.png index b082707..2bca616 100644 Binary files a/screenshot/command.png and b/screenshot/command.png differ diff --git a/screenshot/credential.png b/screenshot/credential.png new file mode 100644 index 0000000..9441bef Binary files /dev/null and b/screenshot/credential.png differ diff --git a/screenshot/cron.png b/screenshot/cron.png new file mode 100644 index 0000000..2a86f2f Binary files /dev/null and b/screenshot/cron.png differ diff --git a/screenshot/dashboard.png b/screenshot/dashboard.png new file mode 100644 index 0000000..4afe1bb Binary files /dev/null and b/screenshot/dashboard.png differ diff --git a/screenshot/offline_session.png b/screenshot/offline_session.png new file mode 100644 index 0000000..8252406 Binary files /dev/null and b/screenshot/offline_session.png differ diff --git a/screenshot/online_session.png b/screenshot/online_session.png new file mode 100644 index 0000000..b5dc452 Binary files /dev/null and b/screenshot/online_session.png differ diff --git a/screenshot/rdp.png b/screenshot/rdp.png index 90b55ef..b142f27 100644 Binary files a/screenshot/rdp.png and b/screenshot/rdp.png differ diff --git a/screenshot/ssh.png b/screenshot/ssh.png index cac04f5..57aceb9 100644 Binary files a/screenshot/ssh.png and b/screenshot/ssh.png differ diff --git a/screenshot/user_group.png b/screenshot/user_group.png new file mode 100644 index 0000000..c1699c6 Binary files /dev/null and b/screenshot/user_group.png differ diff --git a/pkg/api/account.go b/server/api/account.go similarity index 89% rename from pkg/api/account.go rename to server/api/account.go index 05c7b27..edc8b7d 100644 --- a/pkg/api/account.go +++ b/server/api/account.go @@ -1,13 +1,14 @@ package api import ( - "next-terminal/pkg/global" - "next-terminal/pkg/model" - "next-terminal/pkg/totp" - "next-terminal/pkg/utils" "strings" "time" + "next-terminal/pkg/global" + "next-terminal/pkg/totp" + "next-terminal/server/model" + "next-terminal/server/utils" + "github.com/labstack/echo/v4" ) @@ -39,13 +40,18 @@ type Authorization struct { User model.User } +// +//type UserServer struct { +// repository.UserRepository +//} + func LoginEndpoint(c echo.Context) error { var loginAccount LoginAccount if err := c.Bind(&loginAccount); err != nil { return err } - user, err := model.FindUserByUsername(loginAccount.Username) + user, err := userRepository.FindByUsername(loginAccount.Username) // 存储登录失败次数信息 loginFailCountKey := loginAccount.Username @@ -110,13 +116,14 @@ func LoginSuccess(c echo.Context, loginAccount LoginAccount, user model.User) (t Remember: authorization.Remember, } - if model.CreateNewLoginLog(&loginLog) != nil { + if loginLogRepository.Create(&loginLog) != nil { return "", err } // 修改登录状态 - model.UpdateUserById(&model.User{Online: true}, user.ID) - return token, nil + err = userRepository.Update(&model.User{Online: true, ID: user.ID}) + + return token, err } func BuildCacheKeyByToken(token string) string { @@ -146,7 +153,7 @@ func loginWithTotpEndpoint(c echo.Context) error { return Fail(c, -1, "登录失败次数过多,请稍后再试") } - user, err := model.FindUserByUsername(loginAccount.Username) + user, err := userRepository.FindByUsername(loginAccount.Username) if err != nil { count++ global.Cache.Set(loginFailCountKey, count, time.Minute*time.Duration(5)) @@ -177,7 +184,7 @@ func LogoutEndpoint(c echo.Context) error { token := GetToken(c) cacheKey := BuildCacheKeyByToken(token) global.Cache.Delete(cacheKey) - err := model.Logout(token) + err := userService.Logout(token) if err != nil { return err } @@ -201,9 +208,12 @@ func ConfirmTOTPEndpoint(c echo.Context) error { u := &model.User{ TOTPSecret: confirmTOTP.Secret, + ID: account.ID, } - model.UpdateUserById(u, account.ID) + if err := userRepository.Update(u); err != nil { + return err + } return Success(c, nil) } @@ -239,8 +249,11 @@ func ResetTOTPEndpoint(c echo.Context) error { account, _ := GetCurrentAccount(c) u := &model.User{ TOTPSecret: "-", + ID: account.ID, + } + if err := userRepository.Update(u); err != nil { + return err } - model.UpdateUserById(u, account.ID) return Success(c, "") } @@ -265,9 +278,12 @@ func ChangePasswordEndpoint(c echo.Context) error { } u := &model.User{ Password: string(passwd), + ID: account.ID, } - model.UpdateUserById(u, account.ID) + if err := userRepository.Update(u); err != nil { + return err + } return LogoutEndpoint(c) } @@ -283,7 +299,7 @@ type AccountInfo struct { func InfoEndpoint(c echo.Context) error { account, _ := GetCurrentAccount(c) - user, err := model.FindUserById(account.ID) + user, err := userRepository.FindById(account.ID) if err != nil { return err } diff --git a/server/api/api.go b/server/api/api.go new file mode 100644 index 0000000..358a906 --- /dev/null +++ b/server/api/api.go @@ -0,0 +1,76 @@ +package api + +import ( + "next-terminal/pkg/constant" + "next-terminal/pkg/global" + "next-terminal/server/model" + + "github.com/labstack/echo/v4" +) + +type H map[string]interface{} + +func Fail(c echo.Context, code int, message string) error { + return c.JSON(200, H{ + "code": code, + "message": message, + }) +} + +func FailWithData(c echo.Context, code int, message string, data interface{}) error { + return c.JSON(200, H{ + "code": code, + "message": message, + "data": data, + }) +} + +func Success(c echo.Context, data interface{}) error { + return c.JSON(200, H{ + "code": 1, + "message": "success", + "data": data, + }) +} + +func NotFound(c echo.Context, message string) error { + return c.JSON(200, H{ + "code": -1, + "message": message, + }) +} + +func GetToken(c echo.Context) string { + token := c.Request().Header.Get(Token) + if len(token) > 0 { + return token + } + return c.QueryParam(Token) +} + +func GetCurrentAccount(c echo.Context) (model.User, bool) { + token := GetToken(c) + cacheKey := BuildCacheKeyByToken(token) + get, b := global.Cache.Get(cacheKey) + if b { + return get.(Authorization).User, true + } + return model.User{}, false +} + +func HasPermission(c echo.Context, owner string) bool { + // 检测是否登录 + account, found := GetCurrentAccount(c) + if !found { + return false + } + // 检测是否为管理人员 + if constant.TypeAdmin == account.Type { + return true + } + // 检测是否为所有者 + if owner == account.ID { + return true + } + return false +} diff --git a/pkg/api/asset.go b/server/api/asset.go similarity index 79% rename from pkg/api/asset.go rename to server/api/asset.go index 0c8b77b..2e4e470 100644 --- a/pkg/api/asset.go +++ b/server/api/asset.go @@ -5,12 +5,14 @@ import ( "encoding/csv" "encoding/json" "errors" - "github.com/labstack/echo/v4" - "next-terminal/pkg/constant" - "next-terminal/pkg/model" - "next-terminal/pkg/utils" "strconv" "strings" + + "next-terminal/pkg/constant" + "next-terminal/server/model" + "next-terminal/server/utils" + + "github.com/labstack/echo/v4" ) func AssetCreateEndpoint(c echo.Context) error { @@ -30,18 +32,18 @@ func AssetCreateEndpoint(c echo.Context) error { item.ID = utils.UUID() item.Created = utils.NowJsonTime() - if err := model.CreateNewAsset(&item); err != nil { + if err := assetRepository.Create(&item); err != nil { return err } - if err := model.UpdateAssetAttributes(item.ID, item.Protocol, m); err != nil { + if err := assetRepository.UpdateAttributes(item.ID, item.Protocol, m); err != nil { return err } // 创建后自动检测资产是否存活 go func() { active := utils.Tcping(item.IP, item.Port) - model.UpdateAssetActiveById(active, item.ID) + _ = assetRepository.UpdateActiveById(active, item.ID) }() return Success(c, item) @@ -96,7 +98,7 @@ func AssetImportEndpoint(c echo.Context) error { Owner: account.ID, } - err := model.CreateNewAsset(&asset) + err := assetRepository.Create(&asset) if err != nil { errorCount++ m[strconv.Itoa(i)] = err.Error() @@ -105,7 +107,7 @@ func AssetImportEndpoint(c echo.Context) error { // 创建后自动检测资产是否存活 go func() { active := utils.Tcping(asset.IP, asset.Port) - model.UpdateAssetActiveById(active, asset.ID) + _ = assetRepository.UpdateActiveById(active, asset.ID) }() } } @@ -133,7 +135,7 @@ func AssetPagingEndpoint(c echo.Context) error { field := c.QueryParam("field") account, _ := GetCurrentAccount(c) - items, total, err := model.FindPageAsset(pageIndex, pageSize, name, protocol, tags, account, owner, sharer, userGroupId, ip, order, field) + items, total, err := assetRepository.Find(pageIndex, pageSize, name, protocol, tags, account, owner, sharer, userGroupId, ip, order, field) if err != nil { return err } @@ -147,7 +149,7 @@ func AssetPagingEndpoint(c echo.Context) error { func AssetAllEndpoint(c echo.Context) error { protocol := c.QueryParam("protocol") account, _ := GetCurrentAccount(c) - items, _ := model.FindAssetByConditions(protocol, account) + items, _ := assetRepository.FindByProtocolAndUser(protocol, account) return Success(c, items) } @@ -197,8 +199,10 @@ func AssetUpdateEndpoint(c echo.Context) error { item.Description = "-" } - model.UpdateAssetById(&item, id) - if err := model.UpdateAssetAttributes(id, item.Protocol, m); err != nil { + if err := assetRepository.UpdateById(&item, id); err != nil { + return err + } + if err := assetRepository.UpdateAttributes(id, item.Protocol, m); err != nil { return err } @@ -212,7 +216,7 @@ func AssetGetAttributeEndpoint(c echo.Context) error { return err } - attributeMap, err := model.FindAssetAttrMapByAssetId(assetId) + attributeMap, err := assetRepository.FindAssetAttrMapByAssetId(assetId) if err != nil { return err } @@ -227,7 +231,7 @@ func AssetUpdateAttributeEndpoint(c echo.Context) error { assetId := c.Param("id") protocol := c.QueryParam("protocol") - err := model.UpdateAssetAttributes(assetId, protocol, m) + err := assetRepository.UpdateAttributes(assetId, protocol, m) if err != nil { return err } @@ -241,11 +245,11 @@ func AssetDeleteEndpoint(c echo.Context) error { if err := PreCheckAssetPermission(c, split[i]); err != nil { return err } - if err := model.DeleteAssetById(split[i]); err != nil { + if err := assetRepository.DeleteById(split[i]); err != nil { return err } // 删除资产与用户的关系 - if err := model.DeleteResourceSharerByResourceId(split[i]); err != nil { + if err := resourceSharerRepository.DeleteResourceSharerByResourceId(split[i]); err != nil { return err } } @@ -260,10 +264,10 @@ func AssetGetEndpoint(c echo.Context) (err error) { } var item model.Asset - if item, err = model.FindAssetById(id); err != nil { + if item, err = assetRepository.FindById(id); err != nil { return err } - attributeMap, err := model.FindAssetAttrMapByAssetId(id) + attributeMap, err := assetRepository.FindAssetAttrMapByAssetId(id) if err != nil { return err } @@ -279,19 +283,21 @@ func AssetTcpingEndpoint(c echo.Context) (err error) { id := c.Param("id") var item model.Asset - if item, err = model.FindAssetById(id); err != nil { + if item, err = assetRepository.FindById(id); err != nil { return err } active := utils.Tcping(item.IP, item.Port) - model.UpdateAssetActiveById(active, item.ID) + if err := assetRepository.UpdateActiveById(active, item.ID); err != nil { + return err + } return Success(c, active) } func AssetTagsEndpoint(c echo.Context) (err error) { var items []string - if items, err = model.FindAssetTags(); err != nil { + if items, err = assetRepository.FindTags(); err != nil { return err } return Success(c, items) @@ -305,12 +311,14 @@ func AssetChangeOwnerEndpoint(c echo.Context) (err error) { } owner := c.QueryParam("owner") - model.UpdateAssetById(&model.Asset{Owner: owner}, id) + if err := assetRepository.UpdateById(&model.Asset{Owner: owner}, id); err != nil { + return err + } return Success(c, "") } func PreCheckAssetPermission(c echo.Context, id string) error { - item, err := model.FindAssetById(id) + item, err := assetRepository.FindById(id) if err != nil { return err } diff --git a/pkg/api/command.go b/server/api/command.go similarity index 74% rename from pkg/api/command.go rename to server/api/command.go index 78ba421..afffffc 100644 --- a/pkg/api/command.go +++ b/server/api/command.go @@ -2,11 +2,13 @@ package api import ( "errors" - "github.com/labstack/echo/v4" - "next-terminal/pkg/model" - "next-terminal/pkg/utils" "strconv" "strings" + + "next-terminal/server/model" + "next-terminal/server/utils" + + "github.com/labstack/echo/v4" ) func CommandCreateEndpoint(c echo.Context) error { @@ -20,7 +22,7 @@ func CommandCreateEndpoint(c echo.Context) error { item.ID = utils.UUID() item.Created = utils.NowJsonTime() - if err := model.CreateNewCommand(&item); err != nil { + if err := commandRepository.Create(&item); err != nil { return err } @@ -37,7 +39,7 @@ func CommandPagingEndpoint(c echo.Context) error { order := c.QueryParam("order") field := c.QueryParam("field") - items, total, err := model.FindPageCommand(pageIndex, pageSize, name, content, order, field, account) + items, total, err := commandRepository.Find(pageIndex, pageSize, name, content, order, field, account) if err != nil { return err } @@ -59,7 +61,9 @@ func CommandUpdateEndpoint(c echo.Context) error { return err } - model.UpdateCommandById(&item, id) + if err := commandRepository.UpdateById(&item, id); err != nil { + return err + } return Success(c, nil) } @@ -71,11 +75,11 @@ func CommandDeleteEndpoint(c echo.Context) error { if err := PreCheckCommandPermission(c, split[i]); err != nil { return err } - if err := model.DeleteCommandById(split[i]); err != nil { + if err := commandRepository.DeleteById(split[i]); err != nil { return err } // 删除资产与用户的关系 - if err := model.DeleteResourceSharerByResourceId(split[i]); err != nil { + if err := resourceSharerRepository.DeleteResourceSharerByResourceId(split[i]); err != nil { return err } } @@ -90,7 +94,7 @@ func CommandGetEndpoint(c echo.Context) (err error) { } var item model.Command - if item, err = model.FindCommandById(id); err != nil { + if item, err = commandRepository.FindById(id); err != nil { return err } return Success(c, item) @@ -104,12 +108,14 @@ func CommandChangeOwnerEndpoint(c echo.Context) (err error) { } owner := c.QueryParam("owner") - model.UpdateCommandById(&model.Command{Owner: owner}, id) + if err := commandRepository.UpdateById(&model.Command{Owner: owner}, id); err != nil { + return err + } return Success(c, "") } func PreCheckCommandPermission(c echo.Context, id string) error { - item, err := model.FindCommandById(id) + item, err := commandRepository.FindById(id) if err != nil { return err } diff --git a/pkg/api/credential.go b/server/api/credential.go similarity index 81% rename from pkg/api/credential.go rename to server/api/credential.go index f2ee42b..25caf1a 100644 --- a/pkg/api/credential.go +++ b/server/api/credential.go @@ -2,17 +2,19 @@ package api import ( "errors" - "github.com/labstack/echo/v4" - "next-terminal/pkg/constant" - "next-terminal/pkg/model" - "next-terminal/pkg/utils" "strconv" "strings" + + "next-terminal/pkg/constant" + "next-terminal/server/model" + "next-terminal/server/utils" + + "github.com/labstack/echo/v4" ) func CredentialAllEndpoint(c echo.Context) error { account, _ := GetCurrentAccount(c) - items, _ := model.FindAllCredential(account) + items, _ := credentialRepository.FindByUser(account) return Success(c, items) } func CredentialCreateEndpoint(c echo.Context) error { @@ -51,7 +53,7 @@ func CredentialCreateEndpoint(c echo.Context) error { return Fail(c, -1, "类型错误") } - if err := model.CreateNewCredential(&item); err != nil { + if err := credentialRepository.Create(&item); err != nil { return err } @@ -67,7 +69,7 @@ func CredentialPagingEndpoint(c echo.Context) error { field := c.QueryParam("field") account, _ := GetCurrentAccount(c) - items, total, err := model.FindPageCredential(pageIndex, pageSize, name, order, field, account) + items, total, err := credentialRepository.Find(pageIndex, pageSize, name, order, field, account) if err != nil { return err } @@ -115,7 +117,9 @@ func CredentialUpdateEndpoint(c echo.Context) error { return Fail(c, -1, "类型错误") } - model.UpdateCredentialById(&item, id) + if err := credentialRepository.UpdateById(&item, id); err != nil { + return err + } return Success(c, nil) } @@ -127,11 +131,11 @@ func CredentialDeleteEndpoint(c echo.Context) error { if err := PreCheckCredentialPermission(c, split[i]); err != nil { return err } - if err := model.DeleteCredentialById(split[i]); err != nil { + if err := credentialRepository.DeleteById(split[i]); err != nil { return err } // 删除资产与用户的关系 - if err := model.DeleteResourceSharerByResourceId(split[i]); err != nil { + if err := resourceSharerRepository.DeleteResourceSharerByResourceId(split[i]); err != nil { return err } } @@ -145,7 +149,7 @@ func CredentialGetEndpoint(c echo.Context) error { return err } - item, err := model.FindCredentialById(id) + item, err := credentialRepository.FindById(id) if err != nil { return err } @@ -165,12 +169,14 @@ func CredentialChangeOwnerEndpoint(c echo.Context) error { } owner := c.QueryParam("owner") - model.UpdateCredentialById(&model.Credential{Owner: owner}, id) + if err := credentialRepository.UpdateById(&model.Credential{Owner: owner}, id); err != nil { + return err + } return Success(c, "") } func PreCheckCredentialPermission(c echo.Context, id string) error { - item, err := model.FindCredentialById(id) + item, err := credentialRepository.FindById(id) if err != nil { return err } diff --git a/pkg/api/job.go b/server/api/job.go similarity index 74% rename from pkg/api/job.go rename to server/api/job.go index 7a406af..6f908f6 100644 --- a/pkg/api/job.go +++ b/server/api/job.go @@ -1,11 +1,13 @@ package api import ( - "github.com/labstack/echo/v4" - "next-terminal/pkg/model" - "next-terminal/pkg/utils" "strconv" "strings" + + "next-terminal/server/model" + "next-terminal/server/utils" + + "github.com/labstack/echo/v4" ) func JobCreateEndpoint(c echo.Context) error { @@ -17,7 +19,7 @@ func JobCreateEndpoint(c echo.Context) error { item.ID = utils.UUID() item.Created = utils.NowJsonTime() - if err := model.CreateNewJob(&item); err != nil { + if err := jobService.Create(&item); err != nil { return err } return Success(c, "") @@ -32,7 +34,7 @@ func JobPagingEndpoint(c echo.Context) error { order := c.QueryParam("order") field := c.QueryParam("field") - items, total, err := model.FindPageJob(pageIndex, pageSize, name, status, order, field) + items, total, err := jobRepository.Find(pageIndex, pageSize, name, status, order, field) if err != nil { return err } @@ -50,8 +52,8 @@ func JobUpdateEndpoint(c echo.Context) error { if err := c.Bind(&item); err != nil { return err } - - if err := model.UpdateJobById(&item, id); err != nil { + item.ID = id + if err := jobRepository.UpdateById(&item); err != nil { return err } @@ -61,7 +63,7 @@ func JobUpdateEndpoint(c echo.Context) error { func JobChangeStatusEndpoint(c echo.Context) error { id := c.Param("id") status := c.QueryParam("status") - if err := model.ChangeJobStatusById(id, status); err != nil { + if err := jobService.ChangeStatusById(id, status); err != nil { return err } return Success(c, "") @@ -69,7 +71,7 @@ func JobChangeStatusEndpoint(c echo.Context) error { func JobExecEndpoint(c echo.Context) error { id := c.Param("id") - if err := model.ExecJobById(id); err != nil { + if err := jobService.ExecJobById(id); err != nil { return err } return Success(c, "") @@ -81,7 +83,7 @@ func JobDeleteEndpoint(c echo.Context) error { split := strings.Split(ids, ",") for i := range split { jobId := split[i] - if err := model.DeleteJobById(jobId); err != nil { + if err := jobRepository.DeleteJobById(jobId); err != nil { return err } } @@ -92,7 +94,7 @@ func JobDeleteEndpoint(c echo.Context) error { func JobGetEndpoint(c echo.Context) error { id := c.Param("id") - item, err := model.FindJobById(id) + item, err := jobRepository.FindById(id) if err != nil { return err } @@ -103,7 +105,7 @@ func JobGetEndpoint(c echo.Context) error { func JobGetLogsEndpoint(c echo.Context) error { id := c.Param("id") - items, err := model.FindJobLogs(id) + items, err := jobLogRepository.FindByJobId(id) if err != nil { return err } @@ -113,7 +115,7 @@ func JobGetLogsEndpoint(c echo.Context) error { func JobDeleteLogsEndpoint(c echo.Context) error { id := c.Param("id") - if err := model.DeleteJobLogByJobId(id); err != nil { + if err := jobLogRepository.DeleteByJobId(id); err != nil { return err } return Success(c, "") diff --git a/pkg/api/login-log.go b/server/api/login-log.go similarity index 70% rename from pkg/api/login-log.go rename to server/api/login-log.go index a50c39e..822f701 100644 --- a/pkg/api/login-log.go +++ b/server/api/login-log.go @@ -1,11 +1,13 @@ package api import ( - "github.com/labstack/echo/v4" - "next-terminal/pkg/global" - "next-terminal/pkg/model" "strconv" "strings" + + "next-terminal/pkg/global" + "next-terminal/pkg/log" + + "github.com/labstack/echo/v4" ) func LoginLogPagingEndpoint(c echo.Context) error { @@ -14,7 +16,7 @@ func LoginLogPagingEndpoint(c echo.Context) error { userId := c.QueryParam("userId") clientIp := c.QueryParam("clientIp") - items, total, err := model.FindPageLoginLog(pageIndex, pageSize, userId, clientIp) + items, total, err := loginLogRepository.Find(pageIndex, pageSize, userId, clientIp) if err != nil { return err @@ -32,9 +34,11 @@ func LoginLogDeleteEndpoint(c echo.Context) error { for i := range split { token := split[i] global.Cache.Delete(token) - model.Logout(token) + if err := userService.Logout(token); err != nil { + log.WithError(err).Error("Cache Delete Failed") + } } - if err := model.DeleteLoginLogByIdIn(split); err != nil { + if err := loginLogRepository.DeleteByIdIn(split); err != nil { return err } diff --git a/pkg/api/middleware.go b/server/api/middleware.go similarity index 99% rename from pkg/api/middleware.go rename to server/api/middleware.go index 55d4635..431f170 100644 --- a/pkg/api/middleware.go +++ b/server/api/middleware.go @@ -2,14 +2,16 @@ package api import ( "fmt" - "github.com/labstack/echo/v4" "net" - "next-terminal/pkg/constant" - "next-terminal/pkg/global" - "next-terminal/pkg/utils" "regexp" "strings" "time" + + "next-terminal/pkg/constant" + "next-terminal/pkg/global" + "next-terminal/server/utils" + + "github.com/labstack/echo/v4" ) func ErrorHandler(next echo.HandlerFunc) echo.HandlerFunc { diff --git a/server/api/monitor.go b/server/api/monitor.go new file mode 100644 index 0000000..8a068ee --- /dev/null +++ b/server/api/monitor.go @@ -0,0 +1,17 @@ +package api + +import ( + "github.com/labstack/echo/v4" +) + +// todo 监控 +func MonitorEndpoint(c echo.Context) (err error) { + //ws, err := UpGrader.Upgrade(c.Response().Writer, c.Request(), nil) + //if err != nil { + // log.Errorf("升级为WebSocket协议失败:%v", err.Error()) + // return err + //} + + return + +} diff --git a/pkg/api/overview.go b/server/api/overview.go similarity index 58% rename from pkg/api/overview.go rename to server/api/overview.go index 174a6a7..ed4fb76 100644 --- a/pkg/api/overview.go +++ b/server/api/overview.go @@ -1,9 +1,10 @@ package api import ( - "github.com/labstack/echo/v4" "next-terminal/pkg/constant" - "next-terminal/pkg/model" + "next-terminal/server/repository" + + "github.com/labstack/echo/v4" ) type Counter struct { @@ -23,15 +24,15 @@ func OverviewCounterEndPoint(c echo.Context) error { asset int64 ) if constant.TypeUser == account.Type { - countUser, _ = model.CountOnlineUser() - countOnlineSession, _ = model.CountOnlineSession() - credential, _ = model.CountCredentialByUserId(account.ID) - asset, _ = model.CountAssetByUserId(account.ID) + countUser, _ = userRepository.CountOnlineUser() + countOnlineSession, _ = sessionRepository.CountOnlineSession() + credential, _ = credentialRepository.CountByUserId(account.ID) + asset, _ = assetRepository.CountByUserId(account.ID) } else { - countUser, _ = model.CountOnlineUser() - countOnlineSession, _ = model.CountOnlineSession() - credential, _ = model.CountCredential() - asset, _ = model.CountAsset() + countUser, _ = userRepository.CountOnlineUser() + countOnlineSession, _ = sessionRepository.CountOnlineSession() + credential, _ = credentialRepository.Count() + asset, _ = assetRepository.Count() } counter := Counter{ User: countUser, @@ -45,11 +46,11 @@ func OverviewCounterEndPoint(c echo.Context) error { func OverviewSessionPoint(c echo.Context) (err error) { d := c.QueryParam("d") - var results []model.D + var results []repository.D if d == "m" { - results, err = model.CountSessionByDay(30) + results, err = sessionRepository.CountSessionByDay(30) } else { - results, err = model.CountSessionByDay(7) + results, err = sessionRepository.CountSessionByDay(7) } if err != nil { return err diff --git a/pkg/api/property.go b/server/api/property.go similarity index 68% rename from pkg/api/property.go rename to server/api/property.go index a7bfaba..f5e6310 100644 --- a/pkg/api/property.go +++ b/server/api/property.go @@ -3,13 +3,15 @@ package api import ( "errors" "fmt" + + "next-terminal/server/model" + "github.com/labstack/echo/v4" "gorm.io/gorm" - "next-terminal/pkg/model" ) func PropertyGetEndpoint(c echo.Context) error { - properties := model.FindAllPropertiesMap() + properties := propertyRepository.FindAllMap() return Success(c, properties) } @@ -30,13 +32,15 @@ func PropertyUpdateEndpoint(c echo.Context) error { Value: value, } - _, err := model.FindPropertyByName(key) + _, err := propertyRepository.FindByName(key) if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { - if err := model.CreateNewProperty(&property); err != nil { + if err := propertyRepository.Create(&property); err != nil { return err } } else { - model.UpdatePropertyByName(&property, key) + if err := propertyRepository.UpdateByName(&property, key); err != nil { + return err + } } } return Success(c, nil) diff --git a/pkg/api/resource-sharer.go b/server/api/resource-sharer.go similarity index 68% rename from pkg/api/resource-sharer.go rename to server/api/resource-sharer.go index da4551a..77a5201 100644 --- a/pkg/api/resource-sharer.go +++ b/server/api/resource-sharer.go @@ -2,7 +2,6 @@ package api import ( "github.com/labstack/echo/v4" - "next-terminal/pkg/model" ) type RU struct { @@ -20,7 +19,7 @@ type UR struct { func RSGetSharersEndPoint(c echo.Context) error { resourceId := c.QueryParam("resourceId") - userIds, err := model.FindUserIdsByResourceId(resourceId) + userIds, err := resourceSharerRepository.FindUserIdsByResourceId(resourceId) if err != nil { return err } @@ -33,7 +32,7 @@ func RSOverwriteSharersEndPoint(c echo.Context) error { return err } - if err := model.OverwriteUserIdsByResourceId(ur.ResourceId, ur.ResourceType, ur.UserIds); err != nil { + if err := resourceSharerRepository.OverwriteUserIdsByResourceId(ur.ResourceId, ur.ResourceType, ur.UserIds); err != nil { return err } @@ -46,7 +45,7 @@ func ResourceRemoveByUserIdAssignEndPoint(c echo.Context) error { return err } - if err := model.DeleteByUserIdAndResourceTypeAndResourceIdIn(ru.UserGroupId, ru.UserId, ru.ResourceType, ru.ResourceIds); err != nil { + if err := resourceSharerRepository.DeleteByUserIdAndResourceTypeAndResourceIdIn(ru.UserGroupId, ru.UserId, ru.ResourceType, ru.ResourceIds); err != nil { return err } @@ -59,7 +58,7 @@ func ResourceAddByUserIdAssignEndPoint(c echo.Context) error { return err } - if err := model.AddSharerResources(ru.UserGroupId, ru.UserId, ru.ResourceType, ru.ResourceIds); err != nil { + if err := resourceSharerRepository.AddSharerResources(ru.UserGroupId, ru.UserId, ru.ResourceType, ru.ResourceIds); err != nil { return err } diff --git a/pkg/api/routes.go b/server/api/routes.go similarity index 51% rename from pkg/api/routes.go rename to server/api/routes.go index b81396f..0a67e1b 100644 --- a/pkg/api/routes.go +++ b/server/api/routes.go @@ -1,24 +1,69 @@ package api import ( + "fmt" "net/http" - "next-terminal/pkg/constant" + "strings" + "time" + "next-terminal/pkg/global" "next-terminal/pkg/log" - "next-terminal/pkg/model" + "next-terminal/pkg/service" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/utils" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" + "github.com/patrickmn/go-cache" + "gorm.io/driver/mysql" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" ) const Token = "X-Auth-Token" -func SetupRoutes() *echo.Echo { +var ( + userRepository *repository.UserRepository + userGroupRepository *repository.UserGroupRepository + resourceSharerRepository *repository.ResourceSharerRepository + assetRepository *repository.AssetRepository + credentialRepository *repository.CredentialRepository + propertyRepository *repository.PropertyRepository + commandRepository *repository.CommandRepository + sessionRepository *repository.SessionRepository + numRepository *repository.NumRepository + accessSecurityRepository *repository.AccessSecurityRepository + jobRepository *repository.JobRepository + jobLogRepository *repository.JobLogRepository + loginLogRepository *repository.LoginLogRepository + + jobService *service.JobService + propertyService *service.PropertyService + userService *service.UserService + sessionService *service.SessionService + mailService *service.MailService + numService *service.NumService +) + +func SetupRoutes(db *gorm.DB) *echo.Echo { + + InitRepository(db) + InitService() + + if err := InitDBData(); err != nil { + log.WithError(err).Error("初始化数据异常") + } + + if err := ReloadData(); err != nil { + return nil + } e := echo.New() e.HideBanner = true - e.Logger = log.GetEchoLogger() - + //e.Logger = log.GetEchoLogger() + e.Use(log.Hook()) e.File("/", "web/build/index.html") e.File("/asciinema.html", "web/build/asciinema.html") e.File("/asciinema-player.js", "web/build/asciinema-player.js") @@ -172,69 +217,138 @@ func SetupRoutes() *echo.Echo { return e } -type H map[string]interface{} +func ReloadData() error { + if err := ReloadAccessSecurity(); err != nil { + return err + } -func Fail(c echo.Context, code int, message string) error { - return c.JSON(200, H{ - "code": code, - "message": message, + if err := ReloadToken(); err != nil { + return err + } + return nil +} + +func InitRepository(db *gorm.DB) { + userRepository = repository.NewUserRepository(db) + userGroupRepository = repository.NewUserGroupRepository(db) + resourceSharerRepository = repository.NewResourceSharerRepository(db) + assetRepository = repository.NewAssetRepository(db) + credentialRepository = repository.NewCredentialRepository(db) + propertyRepository = repository.NewPropertyRepository(db) + commandRepository = repository.NewCommandRepository(db) + sessionRepository = repository.NewSessionRepository(db) + numRepository = repository.NewNumRepository(db) + accessSecurityRepository = repository.NewAccessSecurityRepository(db) + jobRepository = repository.NewJobRepository(db) + jobLogRepository = repository.NewJobLogRepository(db) + loginLogRepository = repository.NewLoginLogRepository(db) +} + +func InitService() { + jobService = service.NewJobService(jobRepository, jobLogRepository, assetRepository, credentialRepository) + propertyService = service.NewPropertyService(propertyRepository) + userService = service.NewUserService(userRepository, loginLogRepository) + sessionService = service.NewSessionService(sessionRepository) + mailService = service.NewMailService(propertyRepository) + numService = service.NewNumService(numRepository) +} + +func InitDBData() (err error) { + if err := propertyService.InitProperties(); err != nil { + return err + } + if err := numService.InitNums(); err != nil { + return err + } + if err := userService.InitUser(); err != nil { + return err + } + if err := jobService.InitJob(); err != nil { + return err + } + if err := userService.FixedUserOnlineState(); err != nil { + return err + } + if err := sessionService.FixSessionState(); err != nil { + return err + } + return nil +} + +func ResetPassword() error { + user, err := userRepository.FindByUsername(global.Config.ResetPassword) + if err != nil { + return err + } + password := "next-terminal" + passwd, err := utils.Encoder.Encode([]byte(password)) + if err != nil { + return err + } + u := &model.User{ + Password: string(passwd), + ID: user.ID, + } + if err := userRepository.Update(u); err != nil { + return err + } + log.Debugf("用户「%v」密码初始化为: %v", user.Username, password) + return nil +} + +func SetupCache() *cache.Cache { + // 配置缓存器 + mCache := cache.New(5*time.Minute, 10*time.Minute) + mCache.OnEvicted(func(key string, value interface{}) { + if strings.HasPrefix(key, Token) { + token := GetTokenFormCacheKey(key) + log.Debugf("用户Token「%v」过期", token) + err := userService.Logout(token) + if err != nil { + log.Errorf("退出登录失败 %v", err) + } + } }) + return mCache } -func FailWithData(c echo.Context, code int, message string, data interface{}) error { - return c.JSON(200, H{ - "code": code, - "message": message, - "data": data, - }) -} +func SetupDB() *gorm.DB { -func Success(c echo.Context, data interface{}) error { - return c.JSON(200, H{ - "code": 1, - "message": "success", - "data": data, - }) -} - -func NotFound(c echo.Context, message string) error { - return c.JSON(200, H{ - "code": -1, - "message": message, - }) -} - -func GetToken(c echo.Context) string { - token := c.Request().Header.Get(Token) - if len(token) > 0 { - return token + var logMode logger.Interface + if global.Config.Debug { + logMode = logger.Default.LogMode(logger.Info) + } else { + logMode = logger.Default.LogMode(logger.Silent) } - return c.QueryParam(Token) -} -func GetCurrentAccount(c echo.Context) (model.User, bool) { - token := GetToken(c) - cacheKey := BuildCacheKeyByToken(token) - get, b := global.Cache.Get(cacheKey) - if b { - return get.(Authorization).User, true + fmt.Printf("当前数据库模式为:%v\n", global.Config.DB) + var err error + var db *gorm.DB + if global.Config.DB == "mysql" { + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", + global.Config.Mysql.Username, + global.Config.Mysql.Password, + global.Config.Mysql.Hostname, + global.Config.Mysql.Port, + global.Config.Mysql.Database, + ) + db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ + Logger: logMode, + }) + } else { + db, err = gorm.Open(sqlite.Open(global.Config.Sqlite.File), &gorm.Config{ + Logger: logMode, + }) } - return model.User{}, false -} -func HasPermission(c echo.Context, owner string) bool { - // 检测是否登录 - account, found := GetCurrentAccount(c) - if !found { - return false + if err != nil { + log.WithError(err).Panic("连接数据库异常") } - // 检测是否为管理人员 - if constant.TypeAdmin == account.Type { - return true + + if err := db.AutoMigrate(&model.User{}, &model.Asset{}, &model.AssetAttribute{}, &model.Session{}, &model.Command{}, + &model.Credential{}, &model.Property{}, &model.ResourceSharer{}, &model.UserGroup{}, &model.UserGroupMember{}, + &model.LoginLog{}, &model.Num{}, &model.Job{}, &model.JobLog{}, &model.AccessSecurity{}); err != nil { + log.WithError(err).Panic("初始化数据库表结构异常") } - // 检测是否为所有者 - if owner == account.ID { - return true - } - return false + return db } diff --git a/pkg/api/security.go b/server/api/security.go similarity index 78% rename from pkg/api/security.go rename to server/api/security.go index 23dccfc..1e09269 100644 --- a/pkg/api/security.go +++ b/server/api/security.go @@ -1,12 +1,14 @@ package api import ( - "github.com/labstack/echo/v4" - "next-terminal/pkg/global" - "next-terminal/pkg/model" - "next-terminal/pkg/utils" "strconv" "strings" + + "next-terminal/pkg/global" + "next-terminal/server/model" + "next-terminal/server/utils" + + "github.com/labstack/echo/v4" ) func SecurityCreateEndpoint(c echo.Context) error { @@ -18,7 +20,7 @@ func SecurityCreateEndpoint(c echo.Context) error { item.ID = utils.UUID() item.Source = "管理员添加" - if err := model.CreateNewSecurity(&item); err != nil { + if err := accessSecurityRepository.Create(&item); err != nil { return err } // 更新内存中的安全规则 @@ -29,11 +31,11 @@ func SecurityCreateEndpoint(c echo.Context) error { } func ReloadAccessSecurity() error { - rules, err := model.FindAllAccessSecurities() + rules, err := accessSecurityRepository.FindAllAccessSecurities() if err != nil { return err } - if rules != nil && len(rules) > 0 { + if len(rules) > 0 { var securities []*global.Security for i := 0; i < len(rules); i++ { rule := global.Security{ @@ -56,7 +58,7 @@ func SecurityPagingEndpoint(c echo.Context) error { order := c.QueryParam("order") field := c.QueryParam("field") - items, total, err := model.FindPageSecurity(pageIndex, pageSize, ip, rule, order, field) + items, total, err := accessSecurityRepository.Find(pageIndex, pageSize, ip, rule, order, field) if err != nil { return err } @@ -75,7 +77,7 @@ func SecurityUpdateEndpoint(c echo.Context) error { return err } - if err := model.UpdateSecurityById(&item, id); err != nil { + if err := accessSecurityRepository.UpdateById(&item, id); err != nil { return err } // 更新内存中的安全规则 @@ -91,7 +93,7 @@ func SecurityDeleteEndpoint(c echo.Context) error { split := strings.Split(ids, ",") for i := range split { jobId := split[i] - if err := model.DeleteSecurityById(jobId); err != nil { + if err := accessSecurityRepository.DeleteById(jobId); err != nil { return err } } @@ -105,7 +107,7 @@ func SecurityDeleteEndpoint(c echo.Context) error { func SecurityGetEndpoint(c echo.Context) error { id := c.Param("id") - item, err := model.FindSecurityById(id) + item, err := accessSecurityRepository.FindById(id) if err != nil { return err } diff --git a/pkg/api/session.go b/server/api/session.go similarity index 86% rename from pkg/api/session.go rename to server/api/session.go index dad4282..9a91d83 100644 --- a/pkg/api/session.go +++ b/server/api/session.go @@ -4,21 +4,23 @@ import ( "bytes" "errors" "fmt" - "github.com/labstack/echo/v4" - "github.com/pkg/sftp" - "github.com/sirupsen/logrus" "io" "io/ioutil" "net/http" - "next-terminal/pkg/constant" - "next-terminal/pkg/global" - "next-terminal/pkg/model" - "next-terminal/pkg/utils" "os" "path" "strconv" "strings" "sync" + + "next-terminal/pkg/constant" + "next-terminal/pkg/global" + "next-terminal/pkg/log" + "next-terminal/server/model" + "next-terminal/server/utils" + + "github.com/labstack/echo/v4" + "github.com/pkg/sftp" ) func SessionPagingEndpoint(c echo.Context) error { @@ -30,7 +32,7 @@ func SessionPagingEndpoint(c echo.Context) error { assetId := c.QueryParam("assetId") protocol := c.QueryParam("protocol") - items, total, err := model.FindPageSession(pageIndex, pageSize, status, userId, clientIp, assetId, protocol) + items, total, err := sessionRepository.Find(pageIndex, pageSize, status, userId, clientIp, assetId, protocol) if err != nil { return err @@ -65,7 +67,7 @@ func SessionPagingEndpoint(c echo.Context) error { func SessionDeleteEndpoint(c echo.Context) error { sessionIds := c.Param("id") split := strings.Split(sessionIds, ",") - err := model.DeleteSessionByIds(split) + err := sessionRepository.DeleteByIds(split) if err != nil { return err } @@ -81,7 +83,7 @@ func SessionConnectEndpoint(c echo.Context) error { session.Status = constant.Connected session.ConnectedTime = utils.NowJsonTime() - if err := model.UpdateSessionById(&session, sessionId); err != nil { + if err := sessionRepository.UpdateById(&session, sessionId); err != nil { return err } return Success(c, nil) @@ -104,17 +106,17 @@ func CloseSessionById(sessionId string, code int, reason string) { defer mutex.Unlock() observable, _ := global.Store.Get(sessionId) if observable != nil { - logrus.Debugf("会话%v创建者退出,原因:%v", sessionId, reason) + log.Debugf("会话%v创建者退出,原因:%v", sessionId, reason) observable.Subject.Close(code, reason) for i := 0; i < len(observable.Observers); i++ { observable.Observers[i].Close(code, reason) - logrus.Debugf("强制踢出会话%v的观察者", sessionId) + log.Debugf("强制踢出会话%v的观察者", sessionId) } } global.Store.Del(sessionId) - s, err := model.FindSessionById(sessionId) + s, err := sessionRepository.FindById(sessionId) if err != nil { return } @@ -125,7 +127,7 @@ func CloseSessionById(sessionId string, code int, reason string) { if s.Status == constant.Connecting { // 会话还未建立成功,无需保留数据 - _ = model.DeleteSessionById(sessionId) + _ = sessionRepository.DeleteById(sessionId) return } @@ -136,7 +138,7 @@ func CloseSessionById(sessionId string, code int, reason string) { session.Code = code session.Message = reason - _ = model.UpdateSessionById(&session, sessionId) + _ = sessionRepository.UpdateById(&session, sessionId) } func SessionResizeEndpoint(c echo.Context) error { @@ -152,7 +154,7 @@ func SessionResizeEndpoint(c echo.Context) error { intHeight, _ := strconv.Atoi(height) - if err := model.UpdateSessionWindowSizeById(intWidth, intHeight, sessionId); err != nil { + if err := sessionRepository.UpdateWindowSizeById(intWidth, intHeight, sessionId); err != nil { return err } return Success(c, "") @@ -172,7 +174,7 @@ func SessionCreateEndpoint(c echo.Context) error { if constant.TypeUser == user.Type { // 检测是否有访问权限 - assetIds, err := model.FindAssetIdsByUserId(user.ID) + assetIds, err := resourceSharerRepository.FindAssetIdsByUserId(user.ID) if err != nil { return err } @@ -182,7 +184,7 @@ func SessionCreateEndpoint(c echo.Context) error { } } - asset, err := model.FindAssetById(assetId) + asset, err := assetRepository.FindById(assetId) if err != nil { return err } @@ -204,7 +206,7 @@ func SessionCreateEndpoint(c echo.Context) error { } if asset.AccountType == "credential" { - credential, err := model.FindCredentialById(asset.CredentialId) + credential, err := credentialRepository.FindById(asset.CredentialId) if err != nil { return err } @@ -219,7 +221,7 @@ func SessionCreateEndpoint(c echo.Context) error { } } - if err := model.CreateNewSession(session); err != nil { + if err := sessionRepository.Create(session); err != nil { return err } @@ -228,7 +230,7 @@ func SessionCreateEndpoint(c echo.Context) error { func SessionUploadEndpoint(c echo.Context) error { sessionId := c.Param("id") - session, err := model.FindSessionById(sessionId) + session, err := sessionRepository.FindById(sessionId) if err != nil { return err } @@ -263,7 +265,7 @@ func SessionUploadEndpoint(c echo.Context) error { n, err := src.Read(buf) if err != nil { if err != io.EOF { - logrus.Warnf("文件上传错误 %v", err) + log.Warnf("文件上传错误 %v", err) } else { break } @@ -278,7 +280,7 @@ func SessionUploadEndpoint(c echo.Context) error { return Fail(c, -1, ":) 您的IP已被记录,请去向管理员自首。") } - drivePath, err := model.GetDrivePath() + drivePath, err := propertyRepository.GetDrivePath() if err != nil { return err } @@ -302,7 +304,7 @@ func SessionUploadEndpoint(c echo.Context) error { func SessionDownloadEndpoint(c echo.Context) error { sessionId := c.Param("id") - session, err := model.FindSessionById(sessionId) + session, err := sessionRepository.FindById(sessionId) if err != nil { return err } @@ -335,7 +337,7 @@ func SessionDownloadEndpoint(c echo.Context) error { SafetyRuleTrigger(c) return Fail(c, -1, ":) 您的IP已被记录,请去向管理员自首。") } - drivePath, err := model.GetDrivePath() + drivePath, err := propertyRepository.GetDrivePath() if err != nil { return err } @@ -357,7 +359,7 @@ type File struct { func SessionLsEndpoint(c echo.Context) error { sessionId := c.Param("id") - session, err := model.FindSessionById(sessionId) + session, err := sessionRepository.FindById(sessionId) if err != nil { return err } @@ -379,7 +381,7 @@ func SessionLsEndpoint(c echo.Context) error { if tun.Subject.NextTerminal.SftpClient == nil { sftpClient, err := sftp.NewClient(tun.Subject.NextTerminal.SshClient) if err != nil { - logrus.Errorf("创建sftp客户端失败:%v", err.Error()) + log.Errorf("创建sftp客户端失败:%v", err.Error()) return err } tun.Subject.NextTerminal.SftpClient = sftpClient @@ -417,7 +419,7 @@ func SessionLsEndpoint(c echo.Context) error { SafetyRuleTrigger(c) return Fail(c, -1, ":) 您的IP已被记录,请去向管理员自首。") } - drivePath, err := model.GetDrivePath() + drivePath, err := propertyRepository.GetDrivePath() if err != nil { return err } @@ -448,7 +450,7 @@ func SessionLsEndpoint(c echo.Context) error { } func SafetyRuleTrigger(c echo.Context) { - logrus.Warnf("IP %v 尝试进行攻击,请ban掉此IP", c.RealIP()) + log.Warnf("IP %v 尝试进行攻击,请ban掉此IP", c.RealIP()) security := model.AccessSecurity{ ID: utils.UUID(), Source: "安全规则触发", @@ -456,12 +458,12 @@ func SafetyRuleTrigger(c echo.Context) { Rule: constant.AccessRuleReject, } - _ = model.CreateNewSecurity(&security) + _ = accessSecurityRepository.Create(&security) } func SessionMkDirEndpoint(c echo.Context) error { sessionId := c.Param("id") - session, err := model.FindSessionById(sessionId) + session, err := sessionRepository.FindById(sessionId) if err != nil { return err } @@ -480,7 +482,7 @@ func SessionMkDirEndpoint(c echo.Context) error { SafetyRuleTrigger(c) return Fail(c, -1, ":) 您的IP已被记录,请去向管理员自首。") } - drivePath, err := model.GetDrivePath() + drivePath, err := propertyRepository.GetDrivePath() if err != nil { return err } @@ -496,7 +498,7 @@ func SessionMkDirEndpoint(c echo.Context) error { func SessionRmEndpoint(c echo.Context) error { sessionId := c.Param("id") - session, err := model.FindSessionById(sessionId) + session, err := sessionRepository.FindById(sessionId) if err != nil { return err } @@ -541,7 +543,7 @@ func SessionRmEndpoint(c echo.Context) error { SafetyRuleTrigger(c) return Fail(c, -1, ":) 您的IP已被记录,请去向管理员自首。") } - drivePath, err := model.GetDrivePath() + drivePath, err := propertyRepository.GetDrivePath() if err != nil { return err } @@ -558,7 +560,7 @@ func SessionRmEndpoint(c echo.Context) error { func SessionRenameEndpoint(c echo.Context) error { sessionId := c.Param("id") - session, err := model.FindSessionById(sessionId) + session, err := sessionRepository.FindById(sessionId) if err != nil { return err } @@ -582,7 +584,7 @@ func SessionRenameEndpoint(c echo.Context) error { SafetyRuleTrigger(c) return Fail(c, -1, ":) 您的IP已被记录,请去向管理员自首。") } - drivePath, err := model.GetDrivePath() + drivePath, err := propertyRepository.GetDrivePath() if err != nil { return err } @@ -598,7 +600,7 @@ func SessionRenameEndpoint(c echo.Context) error { func SessionRecordingEndpoint(c echo.Context) error { sessionId := c.Param("id") - session, err := model.FindSessionById(sessionId) + session, err := sessionRepository.FindById(sessionId) if err != nil { return err } @@ -610,6 +612,6 @@ func SessionRecordingEndpoint(c echo.Context) error { recording = session.Recording + "/recording" } - logrus.Debugf("读取录屏文件:%v,是否存在: %v, 是否为文件: %v", recording, utils.FileExists(recording), utils.IsFile(recording)) + log.Debugf("读取录屏文件:%v,是否存在: %v, 是否为文件: %v", recording, utils.FileExists(recording), utils.IsFile(recording)) return c.File(recording) } diff --git a/pkg/api/ssh.go b/server/api/ssh.go similarity index 84% rename from pkg/api/ssh.go rename to server/api/ssh.go index f7641e1..a76be4b 100644 --- a/pkg/api/ssh.go +++ b/server/api/ssh.go @@ -2,19 +2,21 @@ package api import ( "encoding/json" - "github.com/gorilla/websocket" - "github.com/labstack/echo/v4" - "github.com/sirupsen/logrus" "net/http" - "next-terminal/pkg/constant" - "next-terminal/pkg/global" - "next-terminal/pkg/guacd" - "next-terminal/pkg/model" - "next-terminal/pkg/term" - "next-terminal/pkg/utils" "path" "strconv" "time" + + "next-terminal/pkg/constant" + "next-terminal/pkg/global" + "next-terminal/pkg/guacd" + "next-terminal/pkg/log" + "next-terminal/pkg/term" + "next-terminal/server/model" + "next-terminal/server/utils" + + "github.com/gorilla/websocket" + "github.com/labstack/echo/v4" ) var UpGrader = websocket.Upgrader{ @@ -44,7 +46,7 @@ type WindowSize struct { func SSHEndpoint(c echo.Context) (err error) { ws, err := UpGrader.Upgrade(c.Response().Writer, c.Request(), nil) if err != nil { - logrus.Errorf("升级为WebSocket协议失败:%v", err.Error()) + log.Errorf("升级为WebSocket协议失败:%v", err.Error()) return err } @@ -52,7 +54,7 @@ func SSHEndpoint(c echo.Context) (err error) { cols, _ := strconv.Atoi(c.QueryParam("cols")) rows, _ := strconv.Atoi(c.QueryParam("rows")) - session, err := model.FindSessionById(sessionId) + session, err := sessionRepository.FindById(sessionId) if err != nil { msg := Message{ Type: Closed, @@ -65,7 +67,7 @@ func SSHEndpoint(c echo.Context) (err error) { user, _ := GetCurrentAccount(c) if constant.TypeUser == user.Type { // 检测是否有访问权限 - assetIds, err := model.FindAssetIdsByUserId(user.ID) + assetIds, err := resourceSharerRepository.FindAssetIdsByUserId(user.ID) if err != nil { return err } @@ -89,7 +91,7 @@ func SSHEndpoint(c echo.Context) (err error) { ) recording := "" - propertyMap := model.FindAllPropertiesMap() + propertyMap := propertyRepository.FindAllMap() if propertyMap[guacd.EnableRecording] == "true" { recording = path.Join(propertyMap[guacd.RecordingPath], sessionId, "recording.cast") } @@ -107,7 +109,7 @@ func SSHEndpoint(c echo.Context) (err error) { observers := append(observable.Observers, tun) observable.Observers = observers global.Store.Set(sessionId, observable) - logrus.Debugf("加入会话%v,当前观察者数量为:%v", session.ConnectionId, len(observers)) + log.Debugf("加入会话%v,当前观察者数量为:%v", session.ConnectionId, len(observers)) } return err @@ -116,7 +118,7 @@ func SSHEndpoint(c echo.Context) (err error) { nextTerminal, err := term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, rows, cols, recording) if err != nil { - logrus.Errorf("创建SSH客户端失败:%v", err.Error()) + log.Errorf("创建SSH客户端失败:%v", err.Error()) msg := Message{ Type: Closed, Content: err.Error(), @@ -142,8 +144,8 @@ func SSHEndpoint(c echo.Context) (err error) { Recording: recording, } // 创建新会话 - logrus.Debugf("创建新会话 %v", sess.ConnectionId) - if err := model.UpdateSessionById(&sess, sessionId); err != nil { + log.Debugf("创建新会话 %v", sess.ConnectionId) + if err := sessionRepository.UpdateById(&sess, sessionId); err != nil { return err } @@ -170,7 +172,7 @@ func SSHEndpoint(c echo.Context) (err error) { var msg Message err = json.Unmarshal(message, &msg) if err != nil { - logrus.Warnf("解析Json失败: %v, 原始字符串:%v", err, string(message)) + log.Warnf("解析Json失败: %v, 原始字符串:%v", err, string(message)) continue } @@ -179,17 +181,17 @@ func SSHEndpoint(c echo.Context) (err error) { var winSize WindowSize err = json.Unmarshal([]byte(msg.Content), &winSize) if err != nil { - logrus.Warnf("解析SSH会话窗口大小失败: %v", err) + log.Warnf("解析SSH会话窗口大小失败: %v", err) continue } if err := nextTerminal.WindowChange(winSize.Rows, winSize.Cols); err != nil { - logrus.Warnf("更改SSH会话窗口大小失败: %v", err) + log.Warnf("更改SSH会话窗口大小失败: %v", err) continue } case Data: _, err = nextTerminal.Write([]byte(msg.Content)) if err != nil { - logrus.Debugf("SSH会话写入失败: %v", err) + log.Debugf("SSH会话写入失败: %v", err) msg := Message{ Type: Closed, Content: "the remote connection is closed.", @@ -245,7 +247,7 @@ func WriteMessage(ws *websocket.Conn, msg Message) error { func WriteByteMessage(ws *websocket.Conn, p []byte) { err := ws.WriteMessage(websocket.TextMessage, p) if err != nil { - logrus.Debugf("write: %v", err) + log.Debugf("write: %v", err) } } diff --git a/pkg/api/tunnel.go b/server/api/tunnel.go similarity index 89% rename from pkg/api/tunnel.go rename to server/api/tunnel.go index ea8788c..a1330cc 100644 --- a/pkg/api/tunnel.go +++ b/server/api/tunnel.go @@ -2,15 +2,17 @@ package api import ( "errors" - "github.com/gorilla/websocket" - "github.com/labstack/echo/v4" - "github.com/sirupsen/logrus" + "path" + "strconv" + "next-terminal/pkg/constant" "next-terminal/pkg/global" "next-terminal/pkg/guacd" - "next-terminal/pkg/model" - "path" - "strconv" + "next-terminal/pkg/log" + "next-terminal/server/model" + + "github.com/gorilla/websocket" + "github.com/labstack/echo/v4" ) const ( @@ -25,7 +27,7 @@ func TunEndpoint(c echo.Context) error { ws, err := UpGrader.Upgrade(c.Response().Writer, c.Request(), nil) if err != nil { - logrus.Errorf("升级为WebSocket协议失败:%v", err.Error()) + log.Errorf("升级为WebSocket协议失败:%v", err.Error()) return err } @@ -40,18 +42,18 @@ func TunEndpoint(c echo.Context) error { configuration := guacd.NewConfiguration() - propertyMap := model.FindAllPropertiesMap() + propertyMap := propertyRepository.FindAllMap() var session model.Session if len(connectionId) > 0 { - session, err = model.FindSessionByConnectionId(connectionId) + session, err = sessionRepository.FindByConnectionId(connectionId) if err != nil { - logrus.Warnf("会话不存在") + log.Warnf("会话不存在") return err } if session.Status != constant.Connected { - logrus.Warnf("会话未在线") + log.Warnf("会话未在线") return errors.New("会话未在线") } configuration.ConnectionID = connectionId @@ -63,7 +65,7 @@ func TunEndpoint(c echo.Context) error { configuration.SetParameter("width", width) configuration.SetParameter("height", height) configuration.SetParameter("dpi", dpi) - session, err = model.FindSessionById(sessionId) + session, err = sessionRepository.FindById(sessionId) if err != nil { CloseSessionById(sessionId, NotFoundSession, "会话不存在") return err @@ -98,7 +100,6 @@ func TunEndpoint(c echo.Context) error { configuration.SetParameter(guacd.DisableBitmapCaching, propertyMap[guacd.DisableBitmapCaching]) configuration.SetParameter(guacd.DisableOffscreenCaching, propertyMap[guacd.DisableOffscreenCaching]) configuration.SetParameter(guacd.DisableGlyphCaching, propertyMap[guacd.DisableGlyphCaching]) - break case "ssh": if len(session.PrivateKey) > 0 && session.PrivateKey != "-" { configuration.SetParameter("username", session.Username) @@ -114,11 +115,9 @@ func TunEndpoint(c echo.Context) error { configuration.SetParameter(guacd.ColorScheme, propertyMap[guacd.ColorScheme]) configuration.SetParameter(guacd.Backspace, propertyMap[guacd.Backspace]) configuration.SetParameter(guacd.TerminalType, propertyMap[guacd.TerminalType]) - break case "vnc": configuration.SetParameter("username", session.Username) configuration.SetParameter("password", session.Password) - break case "telnet": configuration.SetParameter("username", session.Username) configuration.SetParameter("password", session.Password) @@ -128,7 +127,6 @@ func TunEndpoint(c echo.Context) error { configuration.SetParameter(guacd.ColorScheme, propertyMap[guacd.ColorScheme]) configuration.SetParameter(guacd.Backspace, propertyMap[guacd.Backspace]) configuration.SetParameter(guacd.TerminalType, propertyMap[guacd.TerminalType]) - break case "kubernetes": configuration.SetParameter(guacd.FontSize, propertyMap[guacd.FontSize]) @@ -136,13 +134,16 @@ func TunEndpoint(c echo.Context) error { configuration.SetParameter(guacd.ColorScheme, propertyMap[guacd.ColorScheme]) configuration.SetParameter(guacd.Backspace, propertyMap[guacd.Backspace]) configuration.SetParameter(guacd.TerminalType, propertyMap[guacd.TerminalType]) + default: + log.WithField("configuration.Protocol", configuration.Protocol).Error("UnSupport Protocol") + return Fail(c, 400, "不支持的协议") } configuration.SetParameter("hostname", session.IP) configuration.SetParameter("port", strconv.Itoa(session.Port)) // 加载资产配置的属性,优先级比全局配置的高,因此最后加载,覆盖掉全局配置 - attributes, _ := model.FindAssetAttributeByAssetId(session.AssetId) + attributes, _ := assetRepository.FindAttrById(session.AssetId) if len(attributes) > 0 { for i := range attributes { attribute := attributes[i] @@ -164,7 +165,7 @@ func TunEndpoint(c echo.Context) error { if connectionId == "" { CloseSessionById(sessionId, NewTunnelError, err.Error()) } - logrus.Printf("建立连接失败: %v", err.Error()) + log.Printf("建立连接失败: %v", err.Error()) return err } @@ -193,8 +194,8 @@ func TunEndpoint(c echo.Context) error { Recording: configuration.GetParameter(guacd.RecordingPath), } // 创建新会话 - logrus.Debugf("创建新会话 %v", sess.ConnectionId) - if err := model.UpdateSessionById(&sess, sessionId); err != nil { + log.Debugf("创建新会话 %v", sess.ConnectionId) + if err := sessionRepository.UpdateById(&sess, sessionId); err != nil { return err } } else { @@ -204,12 +205,12 @@ func TunEndpoint(c echo.Context) error { observers := append(observable.Observers, tun) observable.Observers = observers global.Store.Set(sessionId, observable) - logrus.Debugf("加入会话%v,当前观察者数量为:%v", session.ConnectionId, len(observers)) + log.Debugf("加入会话%v,当前观察者数量为:%v", session.ConnectionId, len(observers)) } } go func() { - for true { + for { instruction, err := tunnel.Read() if err != nil { if connectionId == "" { @@ -230,7 +231,7 @@ func TunEndpoint(c echo.Context) error { } }() - for true { + for { _, message, err := ws.ReadMessage() if err != nil { if connectionId == "" { diff --git a/pkg/api/user-group.go b/server/api/user-group.go similarity index 58% rename from pkg/api/user-group.go rename to server/api/user-group.go index 5d9e840..650f0ac 100644 --- a/pkg/api/user-group.go +++ b/server/api/user-group.go @@ -1,12 +1,13 @@ package api import ( - "github.com/labstack/echo/v4" - "next-terminal/pkg/global" - "next-terminal/pkg/model" - "next-terminal/pkg/utils" "strconv" "strings" + + "next-terminal/server/model" + "next-terminal/server/utils" + + "github.com/labstack/echo/v4" ) type UserGroup struct { @@ -27,7 +28,7 @@ func UserGroupCreateEndpoint(c echo.Context) error { Name: item.Name, } - if err := model.CreateNewUserGroup(&userGroup, item.Members); err != nil { + if err := userGroupRepository.Create(&userGroup, item.Members); err != nil { return err } @@ -42,7 +43,7 @@ func UserGroupPagingEndpoint(c echo.Context) error { order := c.QueryParam("order") field := c.QueryParam("field") - items, total, err := model.FindPageUserGroup(pageIndex, pageSize, name, order, field) + items, total, err := userGroupRepository.Find(pageIndex, pageSize, name, order, field) if err != nil { return err } @@ -64,7 +65,7 @@ func UserGroupUpdateEndpoint(c echo.Context) error { Name: item.Name, } - if err := model.UpdateUserGroupById(&userGroup, item.Members, id); err != nil { + if err := userGroupRepository.Update(&userGroup, item.Members, id); err != nil { return err } @@ -76,7 +77,9 @@ func UserGroupDeleteEndpoint(c echo.Context) error { split := strings.Split(ids, ",") for i := range split { userId := split[i] - model.DeleteUserGroupById(userId) + if err := userGroupRepository.DeleteById(userId); err != nil { + return err + } } return Success(c, nil) @@ -85,12 +88,12 @@ func UserGroupDeleteEndpoint(c echo.Context) error { func UserGroupGetEndpoint(c echo.Context) error { id := c.Param("id") - item, err := model.FindUserGroupById(id) + item, err := userGroupRepository.FindById(id) if err != nil { return err } - members, err := model.FindUserGroupMembersByUserGroupId(id) + members, err := userGroupRepository.FindMembersById(id) if err != nil { return err } @@ -103,32 +106,3 @@ func UserGroupGetEndpoint(c echo.Context) error { return Success(c, userGroup) } - -func UserGroupAddMembersEndpoint(c echo.Context) error { - id := c.Param("id") - - var items []string - if err := c.Bind(&items); err != nil { - return err - } - - if err := model.AddUserGroupMembers(global.DB, items, id); err != nil { - return err - } - return Success(c, "") -} - -func UserGroupDelMembersEndpoint(c echo.Context) (err error) { - id := c.Param("id") - memberIdsStr := c.Param("memberId") - memberIds := strings.Split(memberIdsStr, ",") - for i := range memberIds { - memberId := memberIds[i] - err = global.DB.Where("user_group_id = ? and user_id = ?", id, memberId).Delete(&model.UserGroupMember{}).Error - if err != nil { - return err - } - } - - return Success(c, "") -} diff --git a/pkg/api/user.go b/server/api/user.go similarity index 50% rename from pkg/api/user.go rename to server/api/user.go index 157c2cf..e564dd1 100644 --- a/pkg/api/user.go +++ b/server/api/user.go @@ -1,12 +1,15 @@ package api import ( - "github.com/labstack/echo/v4" - "next-terminal/pkg/global" - "next-terminal/pkg/model" - "next-terminal/pkg/utils" "strconv" "strings" + + "next-terminal/pkg/global" + "next-terminal/pkg/log" + "next-terminal/server/model" + "next-terminal/server/utils" + + "github.com/labstack/echo/v4" ) func UserCreateEndpoint(c echo.Context) error { @@ -26,12 +29,12 @@ func UserCreateEndpoint(c echo.Context) error { item.ID = utils.UUID() item.Created = utils.NowJsonTime() - if err := model.CreateNewUser(&item); err != nil { + if err := userRepository.Create(&item); err != nil { return err } if item.Mail != "" { - go model.SendMail(item.Mail, "[Next Terminal] 注册通知", "你好,"+item.Nickname+"。管理员为你注册了账号:"+item.Username+" 密码:"+password) + go mailService.SendMail(item.Mail, "[Next Terminal] 注册通知", "你好,"+item.Nickname+"。管理员为你注册了账号:"+item.Username+" 密码:"+password) } return Success(c, item) } @@ -46,7 +49,7 @@ func UserPagingEndpoint(c echo.Context) error { order := c.QueryParam("order") field := c.QueryParam("field") - items, total, err := model.FindPageUser(pageIndex, pageSize, username, nickname, mail, order, field) + items, total, err := userRepository.Find(pageIndex, pageSize, username, nickname, mail, order, field) if err != nil { return err } @@ -64,8 +67,11 @@ func UserUpdateEndpoint(c echo.Context) error { if err := c.Bind(&item); err != nil { return err } + item.ID = id - model.UpdateUserById(&item, id) + if err := userRepository.Update(&item); err != nil { + return err + } return Success(c, nil) } @@ -83,18 +89,23 @@ func UserDeleteEndpoint(c echo.Context) error { return Fail(c, -1, "不允许删除自身账户") } // 将用户强制下线 - loginLogs, err := model.FindAliveLoginLogsByUserId(userId) + loginLogs, err := loginLogRepository.FindAliveLoginLogsByUserId(userId) if err != nil { return err } - if loginLogs != nil && len(loginLogs) > 0 { - for j := range loginLogs { - global.Cache.Delete(loginLogs[j].ID) - model.Logout(loginLogs[j].ID) + + for j := range loginLogs { + global.Cache.Delete(loginLogs[j].ID) + if err := userService.Logout(loginLogs[j].ID); err != nil { + log.WithError(err).WithField("id:", loginLogs[j].ID).Error("Cache Deleted Error") + return Fail(c, 500, "强制下线错误") } } + // 删除用户 - model.DeleteUserById(userId) + if err := userRepository.DeleteById(userId); err != nil { + return err + } } return Success(c, nil) @@ -103,7 +114,7 @@ func UserDeleteEndpoint(c echo.Context) error { func UserGetEndpoint(c echo.Context) error { id := c.Param("id") - item, err := model.FindUserById(id) + item, err := userRepository.FindById(id) if err != nil { return err } @@ -115,7 +126,7 @@ func UserChangePasswordEndpoint(c echo.Context) error { id := c.Param("id") password := c.QueryParam("password") - user, err := model.FindUserById(id) + user, err := userRepository.FindById(id) if err != nil { return err } @@ -126,11 +137,14 @@ func UserChangePasswordEndpoint(c echo.Context) error { } u := &model.User{ Password: string(passwd), + ID: id, + } + if err := userRepository.Update(u); err != nil { + return err } - model.UpdateUserById(u, id) if user.Mail != "" { - go model.SendMail(user.Mail, "[Next Terminal] 密码修改通知", "你好,"+user.Nickname+"。管理员已将你的密码修改为:"+password) + go mailService.SendMail(user.Mail, "[Next Terminal] 密码修改通知", "你好,"+user.Nickname+"。管理员已将你的密码修改为:"+password) } return Success(c, "") @@ -140,7 +154,44 @@ func UserResetTotpEndpoint(c echo.Context) error { id := c.Param("id") u := &model.User{ TOTPSecret: "-", + ID: id, + } + if err := userRepository.Update(u); err != nil { + return err } - model.UpdateUserById(u, id) return Success(c, "") } + +func ReloadToken() error { + loginLogs, err := loginLogRepository.FindAliveLoginLogs() + if err != nil { + return err + } + + for i := range loginLogs { + loginLog := loginLogs[i] + token := loginLog.ID + user, err := userRepository.FindById(loginLog.UserId) + if err != nil { + log.Debugf("用户「%v」获取失败,忽略", loginLog.UserId) + continue + } + + authorization := Authorization{ + Token: token, + Remember: loginLog.Remember, + User: user, + } + + cacheKey := BuildCacheKeyByToken(token) + + if authorization.Remember { + // 记住登录有效期两周 + global.Cache.Set(cacheKey, authorization, RememberEffectiveTime) + } else { + global.Cache.Set(cacheKey, authorization, NotRememberEffectiveTime) + } + log.Debugf("重新加载用户「%v」授权Token「%v」到缓存", user.Nickname, token) + } + return nil +} diff --git a/server/model/access_security.go b/server/model/access_security.go new file mode 100644 index 0000000..c111d12 --- /dev/null +++ b/server/model/access_security.go @@ -0,0 +1,13 @@ +package model + +type AccessSecurity struct { + ID string `json:"id"` + Rule string `json:"rule"` + IP string `json:"ip"` + Source string `json:"source"` + Priority int64 `json:"priority"` // 越小优先级越高 +} + +func (r *AccessSecurity) TableName() string { + return "access_securities" +} diff --git a/server/model/asset.go b/server/model/asset.go new file mode 100644 index 0000000..0a69c2f --- /dev/null +++ b/server/model/asset.go @@ -0,0 +1,53 @@ +package model + +import ( + "next-terminal/server/utils" +) + +type Asset struct { + ID string `gorm:"primary_key " json:"id"` + Name string `json:"name"` + Protocol string `json:"protocol"` + IP string `json:"ip"` + Port int `json:"port"` + AccountType string `json:"accountType"` + Username string `json:"username"` + Password string `json:"password"` + CredentialId string `gorm:"index" json:"credentialId"` + PrivateKey string `json:"privateKey"` + Passphrase string `json:"passphrase"` + Description string `json:"description"` + Active bool `json:"active"` + Created utils.JsonTime `json:"created"` + Tags string `json:"tags"` + Owner string `gorm:"index" json:"owner"` +} + +type AssetForPage struct { + ID string `json:"id"` + Name string `json:"name"` + IP string `json:"ip"` + Protocol string `json:"protocol"` + Port int `json:"port"` + Active bool `json:"active"` + Created utils.JsonTime `json:"created"` + Tags string `json:"tags"` + Owner string `json:"owner"` + OwnerName string `json:"ownerName"` + SharerCount int64 `json:"sharerCount"` +} + +func (r *Asset) TableName() string { + return "assets" +} + +type AssetAttribute struct { + Id string `gorm:"index" json:"id"` + AssetId string `gorm:"index" json:"assetId"` + Name string `gorm:"index" json:"name"` + Value string `json:"value"` +} + +func (r *AssetAttribute) TableName() string { + return "asset_attributes" +} diff --git a/server/model/command.go b/server/model/command.go new file mode 100644 index 0000000..821e48a --- /dev/null +++ b/server/model/command.go @@ -0,0 +1,27 @@ +package model + +import ( + "next-terminal/server/utils" +) + +type Command struct { + ID string `gorm:"primary_key" json:"id"` + Name string `json:"name"` + Content string `json:"content"` + Created utils.JsonTime `json:"created"` + Owner string `gorm:"index" json:"owner"` +} + +type CommandForPage struct { + ID string `gorm:"primary_key" json:"id"` + Name string `json:"name"` + Content string `json:"content"` + Created utils.JsonTime `json:"created"` + Owner string `json:"owner"` + OwnerName string `json:"ownerName"` + SharerCount int64 `json:"sharerCount"` +} + +func (r *Command) TableName() string { + return "commands" +} diff --git a/server/model/credential.go b/server/model/credential.go new file mode 100644 index 0000000..a2fb006 --- /dev/null +++ b/server/model/credential.go @@ -0,0 +1,37 @@ +package model + +import ( + "next-terminal/server/utils" +) + +type Credential struct { + ID string `gorm:"primary_key" json:"id"` + Name string `json:"name"` + Type string `json:"type"` + Username string `json:"username"` + Password string `json:"password"` + PrivateKey string `json:"privateKey"` + Passphrase string `json:"passphrase"` + Created utils.JsonTime `json:"created"` + Owner string `gorm:"index" json:"owner"` +} + +func (r *Credential) TableName() string { + return "credentials" +} + +type CredentialForPage struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + Username string `json:"username"` + Created utils.JsonTime `json:"created"` + Owner string `json:"owner"` + OwnerName string `json:"ownerName"` + SharerCount int64 `json:"sharerCount"` +} + +type CredentialSimpleVo struct { + ID string `json:"id"` + Name string `json:"name"` +} diff --git a/server/model/job.go b/server/model/job.go new file mode 100644 index 0000000..35a337f --- /dev/null +++ b/server/model/job.go @@ -0,0 +1,34 @@ +package model + +import ( + "next-terminal/server/utils" +) + +type Job struct { + ID string `gorm:"primary_key" json:"id"` + CronJobId int `json:"cronJobId"` + Name string `json:"name"` + Func string `json:"func"` + Cron string `json:"cron"` + Mode string `json:"mode"` + ResourceIds string `json:"resourceIds"` + Status string `json:"status"` + Metadata string `json:"metadata"` + Created utils.JsonTime `json:"created"` + Updated utils.JsonTime `json:"updated"` +} + +func (r *Job) TableName() string { + return "jobs" +} + +type JobLog struct { + ID string `json:"id"` + Timestamp utils.JsonTime `json:"timestamp"` + JobId string `json:"jobId"` + Message string `json:"message"` +} + +func (r *JobLog) TableName() string { + return "job_logs" +} diff --git a/server/model/login_log.go b/server/model/login_log.go new file mode 100644 index 0000000..bc0319c --- /dev/null +++ b/server/model/login_log.go @@ -0,0 +1,30 @@ +package model + +import ( + "next-terminal/server/utils" +) + +type LoginLog struct { + ID string `gorm:"primary_key" json:"id"` + UserId string `gorm:"index" json:"userId"` + ClientIP string `json:"clientIp"` + ClientUserAgent string `json:"clientUserAgent"` + LoginTime utils.JsonTime `json:"loginTime"` + LogoutTime utils.JsonTime `json:"logoutTime"` + Remember bool `json:"remember"` +} + +type LoginLogForPage struct { + ID string `json:"id"` + UserId string `json:"userId"` + UserName string `json:"userName"` + ClientIP string `json:"clientIp"` + ClientUserAgent string `json:"clientUserAgent"` + LoginTime utils.JsonTime `json:"loginTime"` + LogoutTime utils.JsonTime `json:"logoutTime"` + Remember bool `json:"remember"` +} + +func (r *LoginLog) TableName() string { + return "login_logs" +} diff --git a/server/model/num.go b/server/model/num.go new file mode 100644 index 0000000..fa7bd58 --- /dev/null +++ b/server/model/num.go @@ -0,0 +1,9 @@ +package model + +type Num struct { + I string `gorm:"primary_key" json:"i"` +} + +func (r *Num) TableName() string { + return "nums" +} diff --git a/server/model/property.go b/server/model/property.go new file mode 100644 index 0000000..7f2aae8 --- /dev/null +++ b/server/model/property.go @@ -0,0 +1,10 @@ +package model + +type Property struct { + Name string `gorm:"primary_key" json:"name"` + Value string `json:"value"` +} + +func (r *Property) TableName() string { + return "properties" +} diff --git a/server/model/resource_sharer.go b/server/model/resource_sharer.go new file mode 100644 index 0000000..8a9591b --- /dev/null +++ b/server/model/resource_sharer.go @@ -0,0 +1,13 @@ +package model + +type ResourceSharer struct { + ID string `gorm:"primary_key" json:"id"` + ResourceId string `gorm:"index" json:"resourceId"` + ResourceType string `gorm:"index" json:"resourceType"` + UserId string `gorm:"index" json:"userId"` + UserGroupId string `gorm:"index" json:"userGroupId"` +} + +func (r *ResourceSharer) TableName() string { + return "resource_sharers" +} diff --git a/server/model/session.go b/server/model/session.go new file mode 100644 index 0000000..e39f7a1 --- /dev/null +++ b/server/model/session.go @@ -0,0 +1,56 @@ +package model + +import ( + "next-terminal/server/utils" +) + +type Session struct { + ID string `gorm:"primary_key" json:"id"` + Protocol string `json:"protocol"` + IP string `json:"ip"` + Port int `json:"port"` + ConnectionId string `json:"connectionId"` + AssetId string `gorm:"index" json:"assetId"` + Username string `json:"username"` + Password string `json:"password"` + Creator string `gorm:"index" json:"creator"` + ClientIP string `json:"clientIp"` + Width int `json:"width"` + Height int `json:"height"` + Status string `gorm:"index" json:"status"` + Recording string `json:"recording"` + PrivateKey string `json:"privateKey"` + Passphrase string `json:"passphrase"` + Code int `json:"code"` + Message string `json:"message"` + ConnectedTime utils.JsonTime `json:"connectedTime"` + DisconnectedTime utils.JsonTime `json:"disconnectedTime"` + Mode string `json:"mode"` +} + +func (r *Session) TableName() string { + return "sessions" +} + +type SessionForPage struct { + ID string `json:"id"` + Protocol string `json:"protocol"` + IP string `json:"ip"` + Port int `json:"port"` + Username string `json:"username"` + ConnectionId string `json:"connectionId"` + AssetId string `json:"assetId"` + Creator string `json:"creator"` + ClientIP string `json:"clientIp"` + Width int `json:"width"` + Height int `json:"height"` + Status string `json:"status"` + Recording string `json:"recording"` + ConnectedTime utils.JsonTime `json:"connectedTime"` + DisconnectedTime utils.JsonTime `json:"disconnectedTime"` + AssetName string `json:"assetName"` + CreatorName string `json:"creatorName"` + Code int `json:"code"` + Message string `json:"message"` + Mode string `json:"mode"` +} diff --git a/server/model/user.go b/server/model/user.go new file mode 100644 index 0000000..5326a9b --- /dev/null +++ b/server/model/user.go @@ -0,0 +1,35 @@ +package model + +import ( + "next-terminal/server/utils" +) + +type User struct { + ID string `gorm:"primary_key" json:"id"` + Username string `gorm:"index" json:"username"` + Password string `json:"password"` + Nickname string `json:"nickname"` + TOTPSecret string `json:"-"` + Online bool `json:"online"` + Enabled bool `json:"enabled"` + Created utils.JsonTime `json:"created"` + Type string `json:"type"` + Mail string `json:"mail"` +} + +type UserForPage struct { + ID string `json:"id"` + Username string `json:"username"` + Nickname string `json:"nickname"` + TOTPSecret string `json:"totpSecret"` + Mail string `json:"mail"` + Online bool `json:"online"` + Enabled bool `json:"enabled"` + Created utils.JsonTime `json:"created"` + Type string `json:"type"` + SharerAssetCount int64 `json:"sharerAssetCount"` +} + +func (r *User) TableName() string { + return "users" +} diff --git a/server/model/user_group.go b/server/model/user_group.go new file mode 100644 index 0000000..9fc39b6 --- /dev/null +++ b/server/model/user_group.go @@ -0,0 +1,32 @@ +package model + +import ( + "next-terminal/server/utils" +) + +type UserGroup struct { + ID string `gorm:"primary_key" json:"id"` + Name string `json:"name"` + Created utils.JsonTime `json:"created"` +} + +type UserGroupForPage struct { + ID string `json:"id"` + Name string `json:"name"` + Created utils.JsonTime `json:"created"` + AssetCount int64 `json:"assetCount"` +} + +func (r *UserGroup) TableName() string { + return "user_groups" +} + +type UserGroupMember struct { + ID string `gorm:"primary_key" json:"name"` + UserId string `gorm:"index" json:"userId"` + UserGroupId string `gorm:"index" json:"userGroupId"` +} + +func (r *UserGroupMember) TableName() string { + return "user_group_members" +} diff --git a/server/repository/access_security.go b/server/repository/access_security.go new file mode 100644 index 0000000..c5cb92e --- /dev/null +++ b/server/repository/access_security.go @@ -0,0 +1,81 @@ +package repository + +import ( + "next-terminal/server/model" + + "gorm.io/gorm" +) + +type AccessSecurityRepository struct { + DB *gorm.DB +} + +func NewAccessSecurityRepository(db *gorm.DB) *AccessSecurityRepository { + accessSecurityRepository = &AccessSecurityRepository{DB: db} + return accessSecurityRepository +} + +func (r AccessSecurityRepository) FindAllAccessSecurities() (o []model.AccessSecurity, err error) { + db := r.DB + err = db.Order("priority asc").Find(&o).Error + return +} + +func (r AccessSecurityRepository) Find(pageIndex, pageSize int, ip, rule, order, field string) (o []model.AccessSecurity, total int64, err error) { + t := model.AccessSecurity{} + db := r.DB.Table(t.TableName()) + dbCounter := r.DB.Table(t.TableName()) + + if len(ip) > 0 { + db = db.Where("ip like ?", "%"+ip+"%") + dbCounter = dbCounter.Where("ip like ?", "%"+ip+"%") + } + + if len(rule) > 0 { + db = db.Where("rule = ?", rule) + dbCounter = dbCounter.Where("rule = ?", rule) + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "descend" { + order = "desc" + } else { + order = "asc" + } + + if field == "ip" { + field = "ip" + } else if field == "rule" { + field = "rule" + } else { + field = "priority" + } + + err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error + if o == nil { + o = make([]model.AccessSecurity, 0) + } + return +} + +func (r AccessSecurityRepository) Create(o *model.AccessSecurity) error { + return r.DB.Create(o).Error +} + +func (r AccessSecurityRepository) UpdateById(o *model.AccessSecurity, id string) error { + o.ID = id + return r.DB.Updates(o).Error +} + +func (r AccessSecurityRepository) DeleteById(id string) error { + return r.DB.Where("id = ?", id).Delete(model.AccessSecurity{}).Error +} + +func (r AccessSecurityRepository) FindById(id string) (o *model.AccessSecurity, err error) { + err = r.DB.Where("id = ?", id).First(&o).Error + return +} diff --git a/server/repository/asset.go b/server/repository/asset.go new file mode 100644 index 0000000..8c9b684 --- /dev/null +++ b/server/repository/asset.go @@ -0,0 +1,304 @@ +package repository + +import ( + "fmt" + "strings" + + "next-terminal/pkg/constant" + "next-terminal/pkg/global" + "next-terminal/server/model" + "next-terminal/server/utils" + + "github.com/labstack/echo/v4" + "gorm.io/gorm" +) + +type AssetRepository struct { + DB *gorm.DB +} + +func NewAssetRepository(db *gorm.DB) *AssetRepository { + assetRepository = &AssetRepository{DB: db} + return assetRepository +} + +func (r AssetRepository) FindAll() (o []model.Asset, err error) { + err = r.DB.Find(&o).Error + return +} + +func (r AssetRepository) FindByIds(assetIds []string) (o []model.Asset, err error) { + err = r.DB.Where("id in ?", assetIds).Find(&o).Error + return +} + +func (r AssetRepository) FindByProtocol(protocol string) (o []model.Asset, err error) { + err = r.DB.Where("protocol = ?", protocol).Find(&o).Error + return +} + +func (r AssetRepository) FindByProtocolAndIds(protocol string, assetIds []string) (o []model.Asset, err error) { + err = r.DB.Where("protocol = ? and id in ?", protocol, assetIds).Find(&o).Error + return +} + +func (r AssetRepository) FindByProtocolAndUser(protocol string, account model.User) (o []model.Asset, err error) { + db := r.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") + + if constant.TypeUser == account.Type { + owner := account.ID + db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner) + } + + if len(protocol) > 0 { + db = db.Where("assets.protocol = ?", protocol) + } + err = db.Find(&o).Error + return +} + +func (r AssetRepository) Find(pageIndex, pageSize int, name, protocol, tags string, account model.User, owner, sharer, userGroupId, ip, order, field string) (o []model.AssetForPage, total int64, err error) { + db := r.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") + dbCounter := r.DB.Table("assets").Select("DISTINCT assets.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") + + if constant.TypeUser == account.Type { + owner := account.ID + db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner) + dbCounter = dbCounter.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner) + + // 查询用户所在用户组列表 + userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(account.ID) + if err != nil { + return nil, 0, err + } + + if len(userGroupIds) > 0 { + db = db.Or("resource_sharers.user_group_id in ?", userGroupIds) + dbCounter = dbCounter.Or("resource_sharers.user_group_id in ?", userGroupIds) + } + } else { + if len(owner) > 0 { + db = db.Where("assets.owner = ?", owner) + dbCounter = dbCounter.Where("assets.owner = ?", owner) + } + if len(sharer) > 0 { + db = db.Where("resource_sharers.user_id = ?", sharer) + dbCounter = dbCounter.Where("resource_sharers.user_id = ?", sharer) + } + + if len(userGroupId) > 0 { + db = db.Where("resource_sharers.user_group_id = ?", userGroupId) + dbCounter = dbCounter.Where("resource_sharers.user_group_id = ?", userGroupId) + } + } + + if len(name) > 0 { + db = db.Where("assets.name like ?", "%"+name+"%") + dbCounter = dbCounter.Where("assets.name like ?", "%"+name+"%") + } + + if len(ip) > 0 { + db = db.Where("assets.ip like ?", "%"+ip+"%") + dbCounter = dbCounter.Where("assets.ip like ?", "%"+ip+"%") + } + + if len(protocol) > 0 { + db = db.Where("assets.protocol = ?", protocol) + dbCounter = dbCounter.Where("assets.protocol = ?", protocol) + } + + if len(tags) > 0 { + tagArr := strings.Split(tags, ",") + for i := range tagArr { + if global.Config.DB == "sqlite" { + db = db.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%") + dbCounter = dbCounter.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%") + } else { + db = db.Where("find_in_set(?, assets.tags)", tagArr[i]) + dbCounter = dbCounter.Where("find_in_set(?, assets.tags)", tagArr[i]) + } + } + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "ascend" { + order = "asc" + } else { + order = "desc" + } + + if field == "name" { + field = "name" + } else { + field = "created" + } + + err = db.Order("assets." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error + + if o == nil { + o = make([]model.AssetForPage, 0) + } + return +} + +func (r AssetRepository) Create(o *model.Asset) (err error) { + if err = r.DB.Create(o).Error; err != nil { + return err + } + return nil +} + +func (r AssetRepository) FindById(id string) (o model.Asset, err error) { + err = r.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r AssetRepository) UpdateById(o *model.Asset, id string) error { + o.ID = id + return r.DB.Updates(o).Error +} + +func (r AssetRepository) UpdateActiveById(active bool, id string) error { + sql := "update assets set active = ? where id = ?" + return r.DB.Exec(sql, active, id).Error +} + +func (r AssetRepository) DeleteById(id string) error { + return r.DB.Where("id = ?", id).Delete(&model.Asset{}).Error +} + +func (r AssetRepository) Count() (total int64, err error) { + err = r.DB.Find(&model.Asset{}).Count(&total).Error + return +} + +func (r AssetRepository) CountByUserId(userId string) (total int64, err error) { + db := r.DB.Joins("left join resource_sharers on assets.id = resource_sharers.resource_id") + + db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", userId, userId) + + // 查询用户所在用户组列表 + userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId) + if err != nil { + return 0, err + } + + if len(userGroupIds) > 0 { + db = db.Or("resource_sharers.user_group_id in ?", userGroupIds) + } + err = db.Find(&model.Asset{}).Count(&total).Error + return +} + +func (r AssetRepository) FindTags() (o []string, err error) { + var assets []model.Asset + err = r.DB.Not("tags = ?", "").Find(&assets).Error + if err != nil { + return nil, err + } + + o = make([]string, 0) + + for i := range assets { + if len(assets[i].Tags) == 0 { + continue + } + split := strings.Split(assets[i].Tags, ",") + + o = append(o, split...) + } + + return utils.Distinct(o), nil +} + +func (r AssetRepository) UpdateAttributes(assetId, protocol string, m echo.Map) error { + var data []model.AssetAttribute + var parameterNames []string + switch protocol { + case "ssh": + parameterNames = constant.SSHParameterNames + case "rdp": + parameterNames = constant.RDPParameterNames + case "vnc": + parameterNames = constant.VNCParameterNames + case "telnet": + parameterNames = constant.TelnetParameterNames + case "kubernetes": + parameterNames = constant.KubernetesParameterNames + } + + for i := range parameterNames { + name := parameterNames[i] + if m[name] != nil && m[name] != "" { + data = append(data, genAttribute(assetId, name, m)) + } + } + + return r.DB.Transaction(func(tx *gorm.DB) error { + err := tx.Where("asset_id = ?", assetId).Delete(&model.AssetAttribute{}).Error + if err != nil { + return err + } + return tx.CreateInBatches(&data, len(data)).Error + }) +} + +func genAttribute(assetId, name string, m echo.Map) model.AssetAttribute { + value := fmt.Sprintf("%v", m[name]) + attribute := model.AssetAttribute{ + Id: utils.Sign([]string{assetId, name}), + AssetId: assetId, + Name: name, + Value: value, + } + return attribute +} + +func (r AssetRepository) FindAttrById(assetId string) (o []model.AssetAttribute, err error) { + err = r.DB.Where("asset_id = ?", assetId).Find(&o).Error + if o == nil { + o = make([]model.AssetAttribute, 0) + } + return o, err +} + +func (r AssetRepository) FindAssetAttrMapByAssetId(assetId string) (map[string]interface{}, error) { + asset, err := r.FindById(assetId) + if err != nil { + return nil, err + } + attributes, err := r.FindAttrById(assetId) + if err != nil { + return nil, err + } + + var parameterNames []string + switch asset.Protocol { + case "ssh": + parameterNames = constant.SSHParameterNames + case "rdp": + parameterNames = constant.RDPParameterNames + case "vnc": + parameterNames = constant.VNCParameterNames + case "telnet": + parameterNames = constant.TelnetParameterNames + case "kubernetes": + parameterNames = constant.KubernetesParameterNames + } + propertiesMap := propertyRepository.FindAllMap() + var attributeMap = make(map[string]interface{}) + for name := range propertiesMap { + if utils.Contains(parameterNames, name) { + attributeMap[name] = propertiesMap[name] + } + } + + for i := range attributes { + attributeMap[attributes[i].Name] = attributes[i].Value + } + return attributeMap, nil +} diff --git a/server/repository/command.go b/server/repository/command.go new file mode 100644 index 0000000..c9c483c --- /dev/null +++ b/server/repository/command.go @@ -0,0 +1,82 @@ +package repository + +import ( + "next-terminal/pkg/constant" + "next-terminal/server/model" + + "gorm.io/gorm" +) + +type CommandRepository struct { + DB *gorm.DB +} + +func NewCommandRepository(db *gorm.DB) *CommandRepository { + commandRepository = &CommandRepository{DB: db} + return commandRepository +} + +func (r CommandRepository) Find(pageIndex, pageSize int, name, content, order, field string, account model.User) (o []model.CommandForPage, total int64, err error) { + db := r.DB.Table("commands").Select("commands.id,commands.name,commands.content,commands.owner,commands.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on commands.owner = users.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id") + dbCounter := r.DB.Table("commands").Select("DISTINCT commands.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id") + + if constant.TypeUser == account.Type { + owner := account.ID + db = db.Where("commands.owner = ? or resource_sharers.user_id = ?", owner, owner) + dbCounter = dbCounter.Where("commands.owner = ? or resource_sharers.user_id = ?", owner, owner) + } + + if len(name) > 0 { + db = db.Where("commands.name like ?", "%"+name+"%") + dbCounter = dbCounter.Where("commands.name like ?", "%"+name+"%") + } + + if len(content) > 0 { + db = db.Where("commands.content like ?", "%"+content+"%") + dbCounter = dbCounter.Where("commands.content like ?", "%"+content+"%") + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "ascend" { + order = "asc" + } else { + order = "desc" + } + + if field == "name" { + field = "name" + } else { + field = "created" + } + + err = db.Order("commands." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error + if o == nil { + o = make([]model.CommandForPage, 0) + } + return +} + +func (r CommandRepository) Create(o *model.Command) (err error) { + if err = r.DB.Create(o).Error; err != nil { + return err + } + return nil +} + +func (r CommandRepository) FindById(id string) (o model.Command, err error) { + err = r.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r CommandRepository) UpdateById(o *model.Command, id string) error { + o.ID = id + return r.DB.Updates(o).Error +} + +func (r CommandRepository) DeleteById(id string) error { + return r.DB.Where("id = ?", id).Delete(&model.Command{}).Error +} diff --git a/server/repository/credential.go b/server/repository/credential.go new file mode 100644 index 0000000..0a29b2d --- /dev/null +++ b/server/repository/credential.go @@ -0,0 +1,109 @@ +package repository + +import ( + "next-terminal/pkg/constant" + "next-terminal/server/model" + + "gorm.io/gorm" +) + +type CredentialRepository struct { + DB *gorm.DB +} + +func NewCredentialRepository(db *gorm.DB) *CredentialRepository { + credentialRepository = &CredentialRepository{DB: db} + return credentialRepository +} + +func (r CredentialRepository) FindByUser(account model.User) (o []model.CredentialSimpleVo, err error) { + db := r.DB.Table("credentials").Select("DISTINCT credentials.id,credentials.name").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id") + if account.Type == constant.TypeUser { + db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", account.ID, account.ID) + } + err = db.Find(&o).Error + return +} + +func (r CredentialRepository) Find(pageIndex, pageSize int, name, order, field string, account model.User) (o []model.CredentialForPage, total int64, err error) { + db := r.DB.Table("credentials").Select("credentials.id,credentials.name,credentials.type,credentials.username,credentials.owner,credentials.created,users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on credentials.owner = users.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id") + dbCounter := r.DB.Table("credentials").Select("DISTINCT credentials.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id") + + if constant.TypeUser == account.Type { + owner := account.ID + db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", owner, owner) + dbCounter = dbCounter.Where("credentials.owner = ? or resource_sharers.user_id = ?", owner, owner) + } + + if len(name) > 0 { + db = db.Where("credentials.name like ?", "%"+name+"%") + dbCounter = dbCounter.Where("credentials.name like ?", "%"+name+"%") + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "ascend" { + order = "asc" + } else { + order = "desc" + } + + if field == "name" { + field = "name" + } else { + field = "created" + } + + err = db.Order("credentials." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error + if o == nil { + o = make([]model.CredentialForPage, 0) + } + return +} + +func (r CredentialRepository) Create(o *model.Credential) (err error) { + if err = r.DB.Create(o).Error; err != nil { + return err + } + return nil +} + +func (r CredentialRepository) FindById(id string) (o model.Credential, err error) { + err = r.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r CredentialRepository) UpdateById(o *model.Credential, id string) error { + o.ID = id + return r.DB.Updates(o).Error +} + +func (r CredentialRepository) DeleteById(id string) error { + return r.DB.Where("id = ?", id).Delete(&model.Credential{}).Error +} + +func (r CredentialRepository) Count() (total int64, err error) { + err = r.DB.Find(&model.Credential{}).Count(&total).Error + return +} + +func (r CredentialRepository) CountByUserId(userId string) (total int64, err error) { + db := r.DB.Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id") + + db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", userId, userId) + + // 查询用户所在用户组列表 + userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId) + if err != nil { + return 0, err + } + + if len(userGroupIds) > 0 { + db = db.Or("resource_sharers.user_group_id in ?", userGroupIds) + } + err = db.Find(&model.Credential{}).Count(&total).Error + return +} diff --git a/server/repository/definitions.go b/server/repository/definitions.go new file mode 100644 index 0000000..56686fb --- /dev/null +++ b/server/repository/definitions.go @@ -0,0 +1,17 @@ +package repository + +var ( + userRepository *UserRepository + userGroupRepository *UserGroupRepository + resourceSharerRepository *ResourceSharerRepository + assetRepository *AssetRepository + credentialRepository *CredentialRepository + propertyRepository *PropertyRepository + commandRepository *CommandRepository + sessionRepository *SessionRepository + numRepository *NumRepository + accessSecurityRepository *AccessSecurityRepository + jobRepository *JobRepository + jobLogRepository *JobLogRepository + loginLogRepository *LoginLogRepository +) diff --git a/server/repository/job.go b/server/repository/job.go new file mode 100644 index 0000000..fcfc8f5 --- /dev/null +++ b/server/repository/job.go @@ -0,0 +1,95 @@ +package repository + +import ( + "next-terminal/server/model" + "next-terminal/server/utils" + + "gorm.io/gorm" +) + +type JobRepository struct { + DB *gorm.DB +} + +func NewJobRepository(db *gorm.DB) *JobRepository { + jobRepository = &JobRepository{DB: db} + return jobRepository +} + +func (r JobRepository) Find(pageIndex, pageSize int, name, status, order, field string) (o []model.Job, total int64, err error) { + job := model.Job{} + db := r.DB.Table(job.TableName()) + dbCounter := r.DB.Table(job.TableName()) + + if len(name) > 0 { + db = db.Where("name like ?", "%"+name+"%") + dbCounter = dbCounter.Where("name like ?", "%"+name+"%") + } + + if len(status) > 0 { + db = db.Where("status = ?", status) + dbCounter = dbCounter.Where("status = ?", status) + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "ascend" { + order = "asc" + } else { + order = "desc" + } + + if field == "name" { + field = "name" + } else if field == "created" { + field = "created" + } else { + field = "updated" + } + + err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error + if o == nil { + o = make([]model.Job, 0) + } + return +} + +func (r JobRepository) FindByFunc(function string) (o []model.Job, err error) { + db := r.DB + err = db.Where("func = ?", function).Find(&o).Error + return +} + +func (r JobRepository) Create(o *model.Job) (err error) { + return r.DB.Create(o).Error +} + +func (r JobRepository) UpdateById(o *model.Job) (err error) { + return r.DB.Updates(o).Error +} + +func (r JobRepository) UpdateLastUpdatedById(id string) (err error) { + err = r.DB.Updates(model.Job{ID: id, Updated: utils.NowJsonTime()}).Error + return +} + +func (r JobRepository) FindById(id string) (o model.Job, err error) { + err = r.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r JobRepository) DeleteJobById(id string) error { + //job, err := r.FindById(id) + //if err != nil { + // return err + //} + //if job.Status == constant.JobStatusRunning { + // if err := r.ChangeStatusById(id, constant.JobStatusNotRunning); err != nil { + // return err + // } + //} + return r.DB.Where("id = ?", id).Delete(model.Job{}).Error +} diff --git a/server/repository/job_log.go b/server/repository/job_log.go new file mode 100644 index 0000000..17a0dcd --- /dev/null +++ b/server/repository/job_log.go @@ -0,0 +1,29 @@ +package repository + +import ( + "next-terminal/server/model" + + "gorm.io/gorm" +) + +type JobLogRepository struct { + DB *gorm.DB +} + +func NewJobLogRepository(db *gorm.DB) *JobLogRepository { + jobLogRepository = &JobLogRepository{DB: db} + return jobLogRepository +} + +func (r JobLogRepository) Create(o *model.JobLog) error { + return r.DB.Create(o).Error +} + +func (r JobLogRepository) FindByJobId(jobId string) (o []model.JobLog, err error) { + err = r.DB.Where("job_id = ?", jobId).Order("timestamp asc").Find(&o).Error + return +} + +func (r JobLogRepository) DeleteByJobId(jobId string) error { + return r.DB.Where("job_id = ?", jobId).Delete(model.JobLog{}).Error +} diff --git a/server/repository/login_log.go b/server/repository/login_log.go new file mode 100644 index 0000000..0ff7b58 --- /dev/null +++ b/server/repository/login_log.go @@ -0,0 +1,70 @@ +package repository + +import ( + "next-terminal/server/model" + + "gorm.io/gorm" +) + +type LoginLogRepository struct { + DB *gorm.DB +} + +func NewLoginLogRepository(db *gorm.DB) *LoginLogRepository { + loginLogRepository = &LoginLogRepository{DB: db} + return loginLogRepository +} + +func (r LoginLogRepository) Find(pageIndex, pageSize int, userId, clientIp string) (o []model.LoginLogForPage, total int64, err error) { + + db := r.DB.Table("login_logs").Select("login_logs.id,login_logs.user_id,login_logs.client_ip,login_logs.client_user_agent,login_logs.login_time, login_logs.logout_time, users.nickname as user_name").Joins("left join users on login_logs.user_id = users.id") + dbCounter := r.DB.Table("login_logs").Select("DISTINCT login_logs.id") + + if userId != "" { + db = db.Where("login_logs.user_id = ?", userId) + dbCounter = dbCounter.Where("login_logs.user_id = ?", userId) + } + + if clientIp != "" { + db = db.Where("login_logs.client_ip like ?", "%"+clientIp+"%") + dbCounter = dbCounter.Where("login_logs.client_ip like ?", "%"+clientIp+"%") + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + err = db.Order("login_logs.login_time desc").Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error + if o == nil { + o = make([]model.LoginLogForPage, 0) + } + return +} + +func (r LoginLogRepository) FindAliveLoginLogs() (o []model.LoginLog, err error) { + err = r.DB.Where("logout_time is null").Find(&o).Error + return +} + +func (r LoginLogRepository) FindAliveLoginLogsByUserId(userId string) (o []model.LoginLog, err error) { + err = r.DB.Where("logout_time is null and user_id = ?", userId).Find(&o).Error + return +} + +func (r LoginLogRepository) Create(o *model.LoginLog) (err error) { + return r.DB.Create(o).Error +} + +func (r LoginLogRepository) DeleteByIdIn(ids []string) (err error) { + return r.DB.Where("id in ?", ids).Delete(&model.LoginLog{}).Error +} + +func (r LoginLogRepository) FindById(id string) (o model.LoginLog, err error) { + err = r.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r LoginLogRepository) Update(o *model.LoginLog) error { + return r.DB.Updates(o).Error +} diff --git a/server/repository/num.go b/server/repository/num.go new file mode 100644 index 0000000..1fcd174 --- /dev/null +++ b/server/repository/num.go @@ -0,0 +1,26 @@ +package repository + +import ( + "next-terminal/server/model" + + "gorm.io/gorm" +) + +type NumRepository struct { + DB *gorm.DB +} + +func NewNumRepository(db *gorm.DB) *NumRepository { + numRepository = &NumRepository{DB: db} + return numRepository +} + +func (r NumRepository) FindAll() (o []model.Num, err error) { + err = r.DB.Find(&o).Error + return +} + +func (r NumRepository) Create(o *model.Num) (err error) { + err = r.DB.Create(o).Error + return +} diff --git a/server/repository/property.go b/server/repository/property.go new file mode 100644 index 0000000..8e44641 --- /dev/null +++ b/server/repository/property.go @@ -0,0 +1,64 @@ +package repository + +import ( + "next-terminal/pkg/guacd" + "next-terminal/server/model" + + "gorm.io/gorm" +) + +type PropertyRepository struct { + DB *gorm.DB +} + +func NewPropertyRepository(db *gorm.DB) *PropertyRepository { + propertyRepository = &PropertyRepository{DB: db} + return propertyRepository +} + +func (r PropertyRepository) FindAll() (o []model.Property) { + if r.DB.Find(&o).Error != nil { + return nil + } + return +} + +func (r PropertyRepository) Create(o *model.Property) (err error) { + err = r.DB.Create(o).Error + return +} + +func (r PropertyRepository) UpdateByName(o *model.Property, name string) error { + o.Name = name + return r.DB.Updates(o).Error +} + +func (r PropertyRepository) FindByName(name string) (o model.Property, err error) { + err = r.DB.Where("name = ?", name).First(&o).Error + return +} + +func (r PropertyRepository) FindAllMap() map[string]string { + properties := r.FindAll() + propertyMap := make(map[string]string) + for i := range properties { + propertyMap[properties[i].Name] = properties[i].Value + } + return propertyMap +} + +func (r PropertyRepository) GetDrivePath() (string, error) { + property, err := r.FindByName(guacd.DrivePath) + if err != nil { + return "", err + } + return property.Value, nil +} + +func (r PropertyRepository) GetRecordingPath() (string, error) { + property, err := r.FindByName(guacd.RecordingPath) + if err != nil { + return "", err + } + return property.Value, nil +} diff --git a/pkg/model/resource-sharer.go b/server/repository/resource_sharer.go similarity index 50% rename from pkg/model/resource-sharer.go rename to server/repository/resource_sharer.go index 1139b77..0de9628 100644 --- a/pkg/model/resource-sharer.go +++ b/server/repository/resource_sharer.go @@ -1,49 +1,47 @@ -package model +package repository import ( + "next-terminal/server/model" + "next-terminal/server/utils" + "github.com/labstack/echo/v4" + "github.com/pkg/errors" "gorm.io/gorm" - "next-terminal/pkg/global" - "next-terminal/pkg/utils" ) -type ResourceSharer struct { - ID string `gorm:"primary_key" json:"id"` - ResourceId string `gorm:"index" json:"resourceId"` - ResourceType string `gorm:"index" json:"resourceType"` - UserId string `gorm:"index" json:"userId"` - UserGroupId string `gorm:"index" json:"userGroupId"` +type ResourceSharerRepository struct { + DB *gorm.DB } -func (r *ResourceSharer) TableName() string { - return "resource_sharers" +func NewResourceSharerRepository(db *gorm.DB) *ResourceSharerRepository { + resourceSharerRepository = &ResourceSharerRepository{DB: db} + return resourceSharerRepository } -func FindUserIdsByResourceId(resourceId string) (r []string, err error) { - db := global.DB - err = db.Table("resource_sharers").Select("user_id").Where("resource_id = ?", resourceId).Find(&r).Error - if r == nil { - r = make([]string, 0) +func (r *ResourceSharerRepository) FindUserIdsByResourceId(resourceId string) (o []string, err error) { + err = r.DB.Table("resource_sharers").Select("user_id").Where("resource_id = ?", resourceId).Find(&o).Error + if o == nil { + o = make([]string, 0) } return } -func OverwriteUserIdsByResourceId(resourceId, resourceType string, userIds []string) (err error) { - db := global.DB.Begin() +func (r *ResourceSharerRepository) OverwriteUserIdsByResourceId(resourceId, resourceType string, userIds []string) (err error) { + db := r.DB.Begin() var owner string // 检查资产是否存在 switch resourceType { case "asset": - resource := Asset{} + resource := model.Asset{} err = db.Where("id = ?", resourceId).First(&resource).Error owner = resource.Owner case "command": - resource := Command{} + resource := model.Command{} err = db.Where("id = ?", resourceId).First(&resource).Error owner = resource.Owner case "credential": - resource := Credential{} + resource := model.Credential{} err = db.Where("id = ?", resourceId).First(&resource).Error owner = resource.Owner } @@ -58,7 +56,7 @@ func OverwriteUserIdsByResourceId(resourceId, resourceType string, userIds []str } } - db.Where("resource_id = ?", resourceId).Delete(&ResourceSharer{}) + db.Where("resource_id = ?", resourceId).Delete(&model.ResourceSharer{}) for i := range userIds { userId := userIds[i] @@ -66,7 +64,7 @@ func OverwriteUserIdsByResourceId(resourceId, resourceType string, userIds []str continue } id := utils.Sign([]string{resourceId, resourceType, userId}) - resource := &ResourceSharer{ + resource := &model.ResourceSharer{ ID: id, ResourceId: resourceId, ResourceType: resourceType, @@ -81,8 +79,8 @@ func OverwriteUserIdsByResourceId(resourceId, resourceType string, userIds []str return nil } -func DeleteByUserIdAndResourceTypeAndResourceIdIn(userGroupId, userId, resourceType string, resourceIds []string) error { - db := global.DB +func (r *ResourceSharerRepository) DeleteByUserIdAndResourceTypeAndResourceIdIn(userGroupId, userId, resourceType string, resourceIds []string) error { + db := r.DB if userGroupId != "" { db = db.Where("user_group_id = ?", userGroupId) } @@ -95,19 +93,19 @@ func DeleteByUserIdAndResourceTypeAndResourceIdIn(userGroupId, userId, resourceT db = db.Where("resource_type = ?", resourceType) } - if resourceIds != nil { + if len(resourceIds) > 0 { db = db.Where("resource_id in ?", resourceIds) } - return db.Delete(&ResourceSharer{}).Error + return db.Delete(&model.ResourceSharer{}).Error } -func DeleteResourceSharerByResourceId(resourceId string) error { - return global.DB.Where("resource_id = ?", resourceId).Delete(&ResourceSharer{}).Error +func (r *ResourceSharerRepository) DeleteResourceSharerByResourceId(resourceId string) error { + return r.DB.Where("resource_id = ?", resourceId).Delete(&model.ResourceSharer{}).Error } -func AddSharerResources(userGroupId, userId, resourceType string, resourceIds []string) error { - return global.DB.Transaction(func(tx *gorm.DB) (err error) { +func (r *ResourceSharerRepository) AddSharerResources(userGroupId, userId, resourceType string, resourceIds []string) error { + return r.DB.Transaction(func(tx *gorm.DB) (err error) { for i := range resourceIds { resourceId := resourceIds[i] @@ -116,16 +114,23 @@ func AddSharerResources(userGroupId, userId, resourceType string, resourceIds [] // 检查资产是否存在 switch resourceType { case "asset": - resource := Asset{} - err = tx.Where("id = ?", resourceId).First(&resource).Error + resource := model.Asset{} + if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil { + return errors.Wrap(err, "find asset fail") + } owner = resource.Owner case "command": - resource := Command{} - err = tx.Where("id = ?", resourceId).First(&resource).Error + resource := model.Command{} + if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil { + return errors.Wrap(err, "find command fail") + } owner = resource.Owner case "credential": - resource := Credential{} - err = tx.Where("id = ?", resourceId).First(&resource).Error + resource := model.Credential{} + if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil { + return errors.Wrap(err, "find credential fail") + + } owner = resource.Owner } @@ -134,7 +139,7 @@ func AddSharerResources(userGroupId, userId, resourceType string, resourceIds [] } id := utils.Sign([]string{resourceId, resourceType, userId, userGroupId}) - resource := &ResourceSharer{ + resource := &model.ResourceSharer{ ID: id, ResourceId: resourceId, ResourceType: resourceType, @@ -150,23 +155,23 @@ func AddSharerResources(userGroupId, userId, resourceType string, resourceIds [] }) } -func FindAssetIdsByUserId(userId string) (assetIds []string, err error) { +func (r *ResourceSharerRepository) FindAssetIdsByUserId(userId string) (assetIds []string, err error) { // 查询当前用户创建的资产 var ownerAssetIds, sharerAssetIds []string - asset := Asset{} - err = global.DB.Table(asset.TableName()).Select("id").Where("owner = ?", userId).Find(&ownerAssetIds).Error + asset := model.Asset{} + err = r.DB.Table(asset.TableName()).Select("id").Where("owner = ?", userId).Find(&ownerAssetIds).Error if err != nil { return nil, err } // 查询其他用户授权给该用户的资产 - groupIds, err := FindUserGroupIdsByUserId(userId) + groupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId) if err != nil { return nil, err } - db := global.DB.Table("resource_sharers").Select("resource_id").Where("user_id = ?", userId) - if groupIds != nil && len(groupIds) > 0 { + db := r.DB.Table("resource_sharers").Select("resource_id").Where("user_id = ?", userId) + if len(groupIds) > 0 { db = db.Or("user_group_id in ?", groupIds) } err = db.Find(&sharerAssetIds).Error diff --git a/server/repository/session.go b/server/repository/session.go new file mode 100644 index 0000000..66f082f --- /dev/null +++ b/server/repository/session.go @@ -0,0 +1,169 @@ +package repository + +import ( + "os" + "path" + "time" + + "next-terminal/pkg/constant" + "next-terminal/server/model" + + "gorm.io/gorm" +) + +type SessionRepository struct { + DB *gorm.DB +} + +func NewSessionRepository(db *gorm.DB) *SessionRepository { + sessionRepository = &SessionRepository{DB: db} + return sessionRepository +} + +func (r SessionRepository) Find(pageIndex, pageSize int, status, userId, clientIp, assetId, protocol string) (results []model.SessionForPage, total int64, err error) { + + db := r.DB + var params []interface{} + + params = append(params, status) + + itemSql := "SELECT s.id,s.mode, s.protocol,s.recording, s.connection_id, s.asset_id, s.creator, s.client_ip, s.width, s.height, s.ip, s.port, s.username, s.status, s.connected_time, s.disconnected_time,s.code, s.message, a.name AS asset_name, u.nickname AS creator_name FROM sessions s LEFT JOIN assets a ON s.asset_id = a.id LEFT JOIN users u ON s.creator = u.id WHERE s.STATUS = ? " + countSql := "select count(*) from sessions as s where s.status = ? " + + if len(userId) > 0 { + itemSql += " and s.creator = ?" + countSql += " and s.creator = ?" + params = append(params, userId) + } + + if len(clientIp) > 0 { + itemSql += " and s.client_ip like ?" + countSql += " and s.client_ip like ?" + params = append(params, "%"+clientIp+"%") + } + + if len(assetId) > 0 { + itemSql += " and s.asset_id = ?" + countSql += " and s.asset_id = ?" + params = append(params, assetId) + } + + if len(protocol) > 0 { + itemSql += " and s.protocol = ?" + countSql += " and s.protocol = ?" + params = append(params, protocol) + } + + params = append(params, (pageIndex-1)*pageSize, pageSize) + itemSql += " order by s.connected_time desc LIMIT ?, ?" + + db.Raw(countSql, params...).Scan(&total) + + err = db.Raw(itemSql, params...).Scan(&results).Error + + if results == nil { + results = make([]model.SessionForPage, 0) + } + return +} + +func (r SessionRepository) FindByStatus(status string) (o []model.Session, err error) { + err = r.DB.Where("status = ?", status).Find(&o).Error + return +} + +func (r SessionRepository) FindByStatusIn(statuses []string) (o []model.Session, err error) { + err = r.DB.Where("status in ?", statuses).Find(&o).Error + return +} + +func (r SessionRepository) FindOutTimeSessions(dayLimit int) (o []model.Session, err error) { + limitTime := time.Now().Add(time.Duration(-dayLimit*24) * time.Hour) + err = r.DB.Where("status = ? and connected_time < ?", constant.Disconnected, limitTime).Find(&o).Error + return +} + +func (r SessionRepository) Create(o *model.Session) (err error) { + err = r.DB.Create(o).Error + return +} + +func (r SessionRepository) FindById(id string) (o model.Session, err error) { + err = r.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r SessionRepository) FindByConnectionId(connectionId string) (o model.Session, err error) { + err = r.DB.Where("connection_id = ?", connectionId).First(&o).Error + return +} + +func (r SessionRepository) UpdateById(o *model.Session, id string) error { + o.ID = id + return r.DB.Updates(o).Error +} + +func (r SessionRepository) UpdateWindowSizeById(width, height int, id string) error { + session := model.Session{} + session.Width = width + session.Height = height + + return r.UpdateById(&session, id) +} + +func (r SessionRepository) DeleteById(id string) error { + return r.DB.Where("id = ?", id).Delete(&model.Session{}).Error +} + +func (r SessionRepository) DeleteByIds(sessionIds []string) error { + drivePath, err := propertyRepository.GetRecordingPath() + if err != nil { + return err + } + for i := range sessionIds { + if err := os.RemoveAll(path.Join(drivePath, sessionIds[i])); err != nil { + return err + } + if err := r.DeleteById(sessionIds[i]); err != nil { + return err + } + } + return nil +} + +func (r SessionRepository) DeleteByStatus(status string) error { + return r.DB.Where("status = ?", status).Delete(&model.Session{}).Error +} + +func (r SessionRepository) CountOnlineSession() (total int64, err error) { + err = r.DB.Where("status = ?", constant.Connected).Find(&model.Session{}).Count(&total).Error + return +} + +type D struct { + Day string `json:"day"` + Count int `json:"count"` + Protocol string `json:"protocol"` +} + +func (r SessionRepository) CountSessionByDay(day int) (results []D, err error) { + + today := time.Now().Format("20060102") + sql := "select t1.`day`, count(t2.id) as count\nfrom (\n SELECT @date := DATE_ADD(@date, INTERVAL - 1 DAY) day\n FROM (SELECT @date := DATE_ADD('" + today + "', INTERVAL + 1 DAY) FROM nums) as t0\n LIMIT ?\n )\n as t1\n left join\n (\n select DATE(s.connected_time) as day, s.id\n from sessions as s\n WHERE protocol = ? and DATE(connected_time) <= '" + today + "'\n AND DATE(connected_time) > DATE_SUB('" + today + "', INTERVAL ? DAY)\n ) as t2 on t1.day = t2.day\ngroup by t1.day" + + protocols := []string{"rdp", "ssh", "vnc", "telnet"} + + for i := range protocols { + var result []D + err = r.DB.Raw(sql, day, protocols[i], day).Scan(&result).Error + if err != nil { + return nil, err + } + for j := range result { + result[j].Protocol = protocols[i] + } + results = append(results, result...) + } + + return +} diff --git a/server/repository/user.go b/server/repository/user.go new file mode 100644 index 0000000..4ba8a33 --- /dev/null +++ b/server/repository/user.go @@ -0,0 +1,129 @@ +package repository + +import ( + "next-terminal/server/model" + + "gorm.io/gorm" +) + +type UserRepository struct { + DB *gorm.DB +} + +func NewUserRepository(db *gorm.DB) *UserRepository { + userRepository = &UserRepository{DB: db} + return userRepository +} + +func (r UserRepository) FindAll() (o []model.User) { + if r.DB.Find(&o).Error != nil { + return nil + } + return +} + +func (r UserRepository) Find(pageIndex, pageSize int, username, nickname, mail, order, field string) (o []model.UserForPage, total int64, err error) { + db := r.DB.Table("users").Select("users.id,users.username,users.nickname,users.mail,users.online,users.enabled,users.created,users.type, count(resource_sharers.user_id) as sharer_asset_count, users.totp_secret").Joins("left join resource_sharers on users.id = resource_sharers.user_id and resource_sharers.resource_type = 'asset'").Group("users.id") + dbCounter := r.DB.Table("users") + if len(username) > 0 { + db = db.Where("users.username like ?", "%"+username+"%") + dbCounter = dbCounter.Where("username like ?", "%"+username+"%") + } + + if len(nickname) > 0 { + db = db.Where("users.nickname like ?", "%"+nickname+"%") + dbCounter = dbCounter.Where("nickname like ?", "%"+nickname+"%") + } + + if len(mail) > 0 { + db = db.Where("users.mail like ?", "%"+mail+"%") + dbCounter = dbCounter.Where("mail like ?", "%"+mail+"%") + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "ascend" { + order = "asc" + } else { + order = "desc" + } + + if field == "username" { + field = "username" + } else if field == "nickname" { + field = "nickname" + } else { + field = "created" + } + + err = db.Order("users." + field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error + if o == nil { + o = make([]model.UserForPage, 0) + } + + for i := 0; i < len(o); i++ { + if o[i].TOTPSecret == "" || o[i].TOTPSecret == "-" { + o[i].TOTPSecret = "0" + } else { + o[i].TOTPSecret = "1" + } + } + return +} + +func (r UserRepository) FindById(id string) (o model.User, err error) { + err = r.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r UserRepository) FindByUsername(username string) (o model.User, err error) { + err = r.DB.Where("username = ?", username).First(&o).Error + return +} + +func (r UserRepository) FindOnlineUsers() (o []model.User, err error) { + err = r.DB.Where("online = ?", true).Find(&o).Error + return +} + +func (r UserRepository) Create(o *model.User) error { + return r.DB.Create(o).Error +} + +func (r UserRepository) Update(o *model.User) error { + return r.DB.Updates(o).Error +} + +func (r UserRepository) UpdateOnline(id string, online bool) error { + sql := "update users set online = ? where id = ?" + return r.DB.Exec(sql, online, id).Error +} + +func (r UserRepository) DeleteById(id string) error { + return r.DB.Transaction(func(tx *gorm.DB) (err error) { + // 删除用户 + err = tx.Where("id = ?", id).Delete(&model.User{}).Error + if err != nil { + return err + } + // 删除用户组中的用户关系 + err = tx.Where("user_id = ?", id).Delete(&model.UserGroupMember{}).Error + if err != nil { + return err + } + // 删除用户分享到的资产 + err = tx.Where("user_id = ?", id).Delete(&model.ResourceSharer{}).Error + if err != nil { + return err + } + return nil + }) +} + +func (r UserRepository) CountOnlineUser() (total int64, err error) { + err = r.DB.Where("online = ?", true).Find(&model.User{}).Count(&total).Error + return +} diff --git a/server/repository/user_group.go b/server/repository/user_group.go new file mode 100644 index 0000000..ade6c42 --- /dev/null +++ b/server/repository/user_group.go @@ -0,0 +1,143 @@ +package repository + +import ( + "next-terminal/server/model" + "next-terminal/server/utils" + + "gorm.io/gorm" +) + +type UserGroupRepository struct { + DB *gorm.DB +} + +func NewUserGroupRepository(db *gorm.DB) *UserGroupRepository { + userGroupRepository = &UserGroupRepository{DB: db} + return userGroupRepository +} + +func (r UserGroupRepository) FindAll() (o []model.UserGroup) { + if r.DB.Find(&o).Error != nil { + return nil + } + return +} + +func (r UserGroupRepository) Find(pageIndex, pageSize int, name, order, field string) (o []model.UserGroupForPage, total int64, err error) { + db := r.DB.Table("user_groups").Select("user_groups.id, user_groups.name, user_groups.created, count(resource_sharers.user_group_id) as asset_count").Joins("left join resource_sharers on user_groups.id = resource_sharers.user_group_id and resource_sharers.resource_type = 'asset'").Group("user_groups.id") + dbCounter := r.DB.Table("user_groups") + if len(name) > 0 { + db = db.Where("user_groups.name like ?", "%"+name+"%") + dbCounter = dbCounter.Where("name like ?", "%"+name+"%") + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "ascend" { + order = "asc" + } else { + order = "desc" + } + + if field == "name" { + field = "name" + } else { + field = "created" + } + + err = db.Order("user_groups." + field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error + if o == nil { + o = make([]model.UserGroupForPage, 0) + } + return +} + +func (r UserGroupRepository) FindById(id string) (o model.UserGroup, err error) { + err = r.DB.Where("id = ?", id).First(&o).Error + return +} + +func (r UserGroupRepository) FindUserGroupIdsByUserId(userId string) (o []string, err error) { + // 先查询用户所在的用户 + err = r.DB.Table("user_group_members").Select("user_group_id").Where("user_id = ?", userId).Find(&o).Error + return +} + +func (r UserGroupRepository) FindMembersById(userGroupId string) (o []string, err error) { + err = r.DB.Table("user_group_members").Select("user_id").Where("user_group_id = ?", userGroupId).Find(&o).Error + return +} + +func (r UserGroupRepository) Create(o *model.UserGroup, members []string) (err error) { + return r.DB.Transaction(func(tx *gorm.DB) error { + err = tx.Create(o).Error + if err != nil { + return err + } + + if members != nil { + userGroupId := o.ID + err = AddUserGroupMembers(tx, members, userGroupId) + if err != nil { + return err + } + } + return err + }) +} + +func (r UserGroupRepository) Update(o *model.UserGroup, members []string, id string) error { + return r.DB.Transaction(func(tx *gorm.DB) error { + o.ID = id + err := tx.Updates(o).Error + if err != nil { + return err + } + + err = tx.Where("user_group_id = ?", id).Delete(&model.UserGroupMember{}).Error + if err != nil { + return err + } + if members != nil { + userGroupId := o.ID + err = AddUserGroupMembers(tx, members, userGroupId) + if err != nil { + return err + } + } + return err + }) +} + +func (r UserGroupRepository) DeleteById(id string) (err error) { + err = r.DB.Where("id = ?", id).Delete(&model.UserGroup{}).Error + if err != nil { + return err + } + return r.DB.Where("user_group_id = ?", id).Delete(&model.UserGroupMember{}).Error +} + +func AddUserGroupMembers(tx *gorm.DB, userIds []string, userGroupId string) error { + userRepository := NewUserRepository(tx) + for i := range userIds { + userId := userIds[i] + _, err := userRepository.FindById(userId) + if err != nil { + return err + } + + userGroupMember := model.UserGroupMember{ + ID: utils.Sign([]string{userGroupId, userId}), + UserId: userId, + UserGroupId: userGroupId, + } + err = tx.Create(&userGroupMember).Error + if err != nil { + return err + } + } + return nil +} diff --git a/server/utils/util_test.go b/server/utils/util_test.go new file mode 100644 index 0000000..2f9d0f3 --- /dev/null +++ b/server/utils/util_test.go @@ -0,0 +1,35 @@ +package utils_test + +import ( + "net" + "testing" + + "next-terminal/server/utils" + + "github.com/stretchr/testify/assert" +) + +func TestTcping(t *testing.T) { + localhost4 := "127.0.0.1" + localhost6 := "::1" + conn, err := net.Listen("tcp", ":9999") + assert.NoError(t, err) + ip4resfalse := utils.Tcping(localhost4, 22) + assert.Equal(t, false, ip4resfalse) + + ip4res := utils.Tcping(localhost4, 9999) + assert.Equal(t, true, ip4res) + + ip6res := utils.Tcping(localhost6, 9999) + assert.Equal(t, true, ip6res) + + ip4resWithBracket := utils.Tcping("["+localhost4+"]", 9999) + assert.Equal(t, true, ip4resWithBracket) + + ip6resWithBracket := utils.Tcping("["+localhost6+"]", 9999) + assert.Equal(t, true, ip6resWithBracket) + + defer func() { + _ = conn.Close() + }() +} diff --git a/pkg/utils/utils.go b/server/utils/utils.go similarity index 86% rename from pkg/utils/utils.go rename to server/utils/utils.go index 6dad7bb..d586471 100644 --- a/pkg/utils/utils.go +++ b/server/utils/utils.go @@ -18,6 +18,7 @@ import ( "time" "github.com/gofrs/uuid" + "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" ) @@ -81,13 +82,24 @@ func UUID() string { } func Tcping(ip string, port int) bool { - var conn net.Conn - var err error - - if conn, err = net.DialTimeout("tcp", ip+":"+strconv.Itoa(port), 2*time.Second); err != nil { + var ( + conn net.Conn + err error + address string + ) + strPort := strconv.Itoa(port) + if strings.HasPrefix(ip, "[") && strings.HasSuffix(ip, "]") { + // 如果用户有填写中括号就不再拼接 + address = fmt.Sprintf("%s:%s", ip, strPort) + } else { + address = fmt.Sprintf("[%s]:%s", ip, strPort) + } + if conn, err = net.DialTimeout("tcp", address, 2*time.Second); err != nil { return false } - defer conn.Close() + defer func() { + _ = conn.Close() + }() return true } @@ -103,10 +115,7 @@ func ImageToBase64Encode(img image.Image) (string, error) { func FileExists(path string) bool { _, err := os.Stat(path) //os.Stat获取文件信息 if err != nil { - if os.IsExist(err) { - return true - } - return false + return os.IsExist(err) } return true } @@ -209,3 +218,9 @@ func StringToInt(in string) (out int) { out, _ = strconv.Atoi(in) return } + +func Check(f func() error) { + if err := f(); err != nil { + logrus.Error("Received error:", err) + } +} diff --git a/web/package.json b/web/package.json index b5046c7..ae3f55c 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "next-terminal", - "version": "0.3.3", + "version": "0.3.4", "private": true, "dependencies": { "@ant-design/icons": "^4.3.0", diff --git a/web/src/App.css b/web/src/App.css index 9443a7f..59ae085 100644 --- a/web/src/App.css +++ b/web/src/App.css @@ -33,7 +33,7 @@ } .layout-header { - height: 48px; + height: 60px; align-items: center; padding: 0 16px 0 0; background: #fff; @@ -45,23 +45,19 @@ padding: 0 12px; cursor: pointer; transition: all .3s; - line-height: 48px; - height: 48px; + line-height: 60px; + height: 60px; } .layout-header-right-item { margin: 0 6px; display: inline; - height: 48px; -} - -.layout-header-right-item:hover { - background-color: #eeeeee; + height: 60px; } .nickname { - line-height: 48px; - height: 48px; + line-height: 60px; + height: 60px; width: 125px; text-align: left; padding: 0 5px; diff --git a/web/src/App.js b/web/src/App.js index 0868860..8a7f90b 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -1,7 +1,7 @@ import React, {Component} from 'react'; import 'antd/dist/antd.css'; import './App.css'; -import {Divider, Layout, Menu} from "antd"; +import {Col, Divider, Dropdown, Layout, Menu, Popconfirm, Row, Tooltip} from "antd"; import {Link, Route, Switch} from "react-router-dom"; import Dashboard from "./components/dashboard/Dashboard"; import Asset from "./components/asset/Asset"; @@ -21,15 +21,21 @@ import { DashboardOutlined, DesktopOutlined, DisconnectOutlined, + DownOutlined, + GithubOutlined, IdcardOutlined, LinkOutlined, LoginOutlined, + LogoutOutlined, + QuestionCircleOutlined, SafetyCertificateOutlined, SettingOutlined, SolutionOutlined, TeamOutlined, UserOutlined, - UserSwitchOutlined + UserSwitchOutlined, + MenuUnfoldOutlined, + MenuFoldOutlined, } from '@ant-design/icons'; import Info from "./components/user/Info"; import request from "./common/request"; @@ -37,18 +43,18 @@ import {message} from "antd/es"; import Setting from "./components/setting/Setting"; import BatchCommand from "./components/command/BatchCommand"; import {isEmpty, NT_PACKAGE} from "./utils/utils"; -import {isAdmin} from "./service/permission"; +import {getCurrentUser, isAdmin} from "./service/permission"; import UserGroup from "./components/user/UserGroup"; import LoginLog from "./components/devops/LoginLog"; import Term from "./components/access/Term"; import Job from "./components/devops/Job"; import {Header} from "antd/es/layout/layout"; -import LayoutHeader from "./components/user/LayoutHeader"; import Security from "./components/devops/Security"; const {Footer, Sider} = Layout; const {SubMenu} = Menu; +const headerHeight = 60; class App extends Component { @@ -113,8 +119,54 @@ class App extends Component { sessionStorage.setItem('openKeys', JSON.stringify(openKeys)); } + confirm = async (e) => { + let result = await request.post('/logout'); + if (result['code'] !== 1) { + message.error(result['message']); + } else { + message.success('退出登录成功,即将跳转至登录页面。'); + window.location.reload(); + } + } + render() { + const menu = ( +
+ ); return ( @@ -126,7 +178,7 @@ class App extends Component {