diff --git a/.gitignore b/.gitignore index dfe54c0..4d40a2a 100644 --- a/.gitignore +++ b/.gitignore @@ -21,13 +21,6 @@ web/build *.swp -# playground -playground/data -playground/drive -playground/recording - -/log - - # next terminal -/recording \ No newline at end of file +/data/ +/logs/ diff --git a/Dockerfile b/Dockerfile index b58fa66..5ac1871 100644 --- a/Dockerfile +++ b/Dockerfile @@ -21,7 +21,6 @@ LABEL MAINTAINER="helloworld1024@foxmail.com" ENV TZ Asia/Shanghai ENV DB sqlite -ENV CONTAINER "true" ENV SQLITE_FILE './data/sqlite/next-terminal.db' ENV SERVER_PORT 8088 ENV SERVER_ADDR 0.0.0.0:$SERVER_PORT diff --git a/go.mod b/go.mod index 33c9970..0b42bfb 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/spf13/viper v1.7.1 github.com/stretchr/testify v1.6.1 golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e + golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 golang.org/x/text v0.3.6 gopkg.in/natefinch/lumberjack.v2 v2.0.0 gorm.io/driver/mysql v1.0.3 @@ -29,12 +30,14 @@ require ( ) require ( + github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect github.com/fsnotify/fsnotify v1.4.7 // indirect + github.com/go-asn1-ber/asn1-ber v1.5.1 // indirect github.com/go-sql-driver/mysql v1.5.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect @@ -56,7 +59,6 @@ require ( github.com/subosito/gotenv v1.2.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.1 // indirect - golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 // indirect golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 // indirect gopkg.in/ini.v1 v1.51.0 // indirect diff --git a/go.sum b/go.sum index c86851b..7cf2708 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqCl cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c h1:/IBSNwUN8+eKzUzbJPqhK839ygXJ82sde8x3ogr6R28= +github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= @@ -44,6 +46,8 @@ github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfc github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/denisbrodbeck/machineid v1.0.1 h1:geKr9qtkB876mXguW2X6TU4ZynleN6ezuMSRhl4D7AQ= +github.com/denisbrodbeck/machineid v1.0.1/go.mod h1:dJUwb7PTidGDeYyUBmXZ2GphQBbjJCrnectwCyxcUSI= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= @@ -53,8 +57,12 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gliderlabs/ssh v0.3.3 h1:mBQ8NiOgDkINJrZtoizkC3nDNYgSaWtxyem6S2XHBtA= github.com/gliderlabs/ssh v0.3.3/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914= +github.com/go-asn1-ber/asn1-ber v1.5.1 h1:pDbRAunXzIUXfx4CB2QJFv5IuPiuoW+sWvr/Us009o8= +github.com/go-asn1-ber/asn1-ber v1.5.1/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-ldap/ldap/v3 v3.4.1 h1:fU/0xli6HY02ocbMuozHAYsaHLcnkLjvho2r5a34BUU= +github.com/go-ldap/ldap/v3 v3.4.1/go.mod h1:iYS1MdmrmceOJ1QOTnRXrIs7i3kloqtmGQjRvjKpyMg= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= @@ -250,6 +258,7 @@ golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnf golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI= diff --git a/main.go b/main.go index 26e1e87..53a97e2 100644 --- a/main.go +++ b/main.go @@ -1,62 +1,14 @@ package main import ( - "encoding/json" - "fmt" - - "next-terminal/server/api" - "next-terminal/server/config" - "next-terminal/server/constant" - "next-terminal/server/repository" - "next-terminal/server/task" + "next-terminal/server/app" "github.com/labstack/gommon/log" ) func main() { - err := Run() + err := app.Run() if err != nil { log.Fatal(err) } } - -func Run() error { - - fmt.Printf(constant.Banner, constant.Version) - - if config.GlobalCfg.Debug { - jsonBytes, err := json.MarshalIndent(config.GlobalCfg, "", " ") - if err != nil { - return err - } - fmt.Printf("当前配置为: %v\n", string(jsonBytes)) - } - - db := api.SetupDB() - e := api.SetupRoutes(db) - - if config.GlobalCfg.ResetPassword != "" { - return api.ResetPassword(config.GlobalCfg.ResetPassword) - } - if config.GlobalCfg.ResetTotp != "" { - return api.ResetTotp(config.GlobalCfg.ResetTotp) - } - - if config.GlobalCfg.NewEncryptionKey != "" { - return api.ChangeEncryptionKey(config.GlobalCfg.EncryptionKey, config.GlobalCfg.NewEncryptionKey) - } - - sessionRepo := repository.NewSessionRepository(db) - propertyRepo := repository.NewPropertyRepository(db) - loginLogRepo := repository.NewLoginLogRepository(db) - jobLogRepo := repository.NewJobLogRepository(db) - ticker := task.NewTicker(sessionRepo, propertyRepo, loginLogRepo, jobLogRepo) - ticker.SetupTicker() - - if config.GlobalCfg.Server.Cert != "" && config.GlobalCfg.Server.Key != "" { - return e.StartTLS(config.GlobalCfg.Server.Addr, config.GlobalCfg.Server.Cert, config.GlobalCfg.Server.Key) - } else { - return e.Start(config.GlobalCfg.Server.Addr) - } - -} diff --git a/playground/docker-compose.yml b/playground/docker-compose.yml index 392f982..1a881bd 100644 --- a/playground/docker-compose.yml +++ b/playground/docker-compose.yml @@ -1,49 +1,30 @@ version: '3.3' services: + guacd: + image: dushixiang/guacd:latest + volumes: + - ../data:/usr/local/next-terminal/data + ports: + - "4822:4822" + restart: + always mysql: image: mysql:8.0 - container_name: mysql environment: MYSQL_DATABASE: next-terminal MYSQL_USER: next-terminal MYSQL_PASSWORD: next-terminal MYSQL_ROOT_PASSWORD: next-terminal volumes: - - ./data/mysql_data:/var/lib/mysql + - ../data/mysql:/var/lib/mysql ports: - "3306:3306" restart: - always - networks: - next-terminal: - ipv4_address: 172.77.77.2 - -# next-terminal: -# container_name: next-terminal -# image: "dushixiang/next-terminal:latest" -# environment: -# DB: "mysql" -# MYSQL_HOSTNAME: "mysql" -# MYSQL_PORT: 3306 -# MYSQL_USERNAME: "next-terminal" -# MYSQL_PASSWORD: "next-terminal" -# MYSQL_DATABASE: "next-terminal" -# ports: -# - "8088:8088" -# volumes: -# - ./drive:/usr/local/next-terminal/drive -# - ./recording:/usr/local/next-terminal/recording -# depends_on: -# - mysql -# networks: -# next-terminal: -# ipv4_address: 172.77.77.3 -# restart: -# always + always networks: next-terminal: ipam: driver: default config: - - subnet: "172.77.77.0/24" + - subnet: 172.77.77.0/24 diff --git a/server/api/access_gateway.go b/server/api/access_gateway.go index 81d1b81..e783f2c 100644 --- a/server/api/access_gateway.go +++ b/server/api/access_gateway.go @@ -1,16 +1,21 @@ package api import ( + "context" "strconv" "strings" "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/service" "next-terminal/server/utils" "github.com/labstack/echo/v4" ) -func AccessGatewayCreateEndpoint(c echo.Context) error { +type AccessGatewayApi struct{} + +func (api AccessGatewayApi) AccessGatewayCreateEndpoint(c echo.Context) error { var item model.AccessGateway if err := c.Bind(&item); err != nil { return err @@ -19,16 +24,16 @@ func AccessGatewayCreateEndpoint(c echo.Context) error { item.ID = utils.UUID() item.Created = utils.NowJsonTime() - if err := accessGatewayRepository.Create(&item); err != nil { + if err := repository.GatewayRepository.Create(context.TODO(), &item); err != nil { return err } // 连接网关 - accessGatewayService.ReConnect(&item) + service.GatewayService.ReConnect(&item) return Success(c, "") } -func AccessGatewayAllEndpoint(c echo.Context) error { - gateways, err := accessGatewayRepository.FindAll() +func (api AccessGatewayApi) AccessGatewayAllEndpoint(c echo.Context) error { + gateways, err := repository.GatewayRepository.FindAll(context.TODO()) if err != nil { return err } @@ -39,7 +44,7 @@ func AccessGatewayAllEndpoint(c echo.Context) error { return Success(c, simpleGateways) } -func AccessGatewayPagingEndpoint(c echo.Context) error { +func (api AccessGatewayApi) AccessGatewayPagingEndpoint(c echo.Context) error { pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex")) pageSize, _ := strconv.Atoi(c.QueryParam("pageSize")) ip := c.QueryParam("ip") @@ -48,12 +53,12 @@ func AccessGatewayPagingEndpoint(c echo.Context) error { order := c.QueryParam("order") field := c.QueryParam("field") - items, total, err := accessGatewayRepository.Find(pageIndex, pageSize, ip, name, order, field) + items, total, err := repository.GatewayRepository.Find(context.TODO(), pageIndex, pageSize, ip, name, order, field) if err != nil { return err } for i := 0; i < len(items); i++ { - g, err := accessGatewayService.GetGatewayById(items[i].ID) + g, err := service.GatewayService.GetGatewayById(items[i].ID) if err != nil { return err } @@ -61,13 +66,13 @@ func AccessGatewayPagingEndpoint(c echo.Context) error { items[i].Message = g.Message } - return Success(c, H{ + return Success(c, Map{ "total": total, "items": items, }) } -func AccessGatewayUpdateEndpoint(c echo.Context) error { +func (api AccessGatewayApi) AccessGatewayUpdateEndpoint(c echo.Context) error { id := c.Param("id") var item model.AccessGateway @@ -75,30 +80,30 @@ func AccessGatewayUpdateEndpoint(c echo.Context) error { return err } - if err := accessGatewayRepository.UpdateById(&item, id); err != nil { + if err := repository.GatewayRepository.UpdateById(context.TODO(), &item, id); err != nil { return err } - accessGatewayService.ReConnect(&item) + service.GatewayService.ReConnect(&item) return Success(c, nil) } -func AccessGatewayDeleteEndpoint(c echo.Context) error { +func (api AccessGatewayApi) AccessGatewayDeleteEndpoint(c echo.Context) error { ids := c.Param("id") split := strings.Split(ids, ",") for i := range split { id := split[i] - if err := accessGatewayRepository.DeleteById(id); err != nil { + if err := repository.GatewayRepository.DeleteById(context.TODO(), id); err != nil { return err } - accessGatewayService.DisconnectById(id) + service.GatewayService.DisconnectById(id) } return Success(c, nil) } -func AccessGatewayGetEndpoint(c echo.Context) error { +func (api AccessGatewayApi) AccessGatewayGetEndpoint(c echo.Context) error { id := c.Param("id") - item, err := accessGatewayRepository.FindById(id) + item, err := repository.GatewayRepository.FindById(context.TODO(), id) if err != nil { return err } @@ -106,13 +111,13 @@ func AccessGatewayGetEndpoint(c echo.Context) error { return Success(c, item) } -func AccessGatewayReconnectEndpoint(c echo.Context) error { +func (api AccessGatewayApi) AccessGatewayReconnectEndpoint(c echo.Context) error { id := c.Param("id") - item, err := accessGatewayRepository.FindById(id) + item, err := repository.GatewayRepository.FindById(context.TODO(), id) if err != nil { return err } - accessGatewayService.ReConnect(&item) + service.GatewayService.ReConnect(&item) return Success(c, "") } diff --git a/server/api/account.go b/server/api/account.go index b4ce7b7..b90a7b4 100644 --- a/server/api/account.go +++ b/server/api/account.go @@ -1,58 +1,34 @@ package api import ( - "next-terminal/server/constant" + "context" "path" "strconv" - "strings" - "time" "next-terminal/server/config" + "next-terminal/server/constant" + "next-terminal/server/dto" "next-terminal/server/global/cache" "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/service" "next-terminal/server/totp" "next-terminal/server/utils" "github.com/labstack/echo/v4" ) -const ( - RememberEffectiveTime = time.Hour * time.Duration(24*14) - NotRememberEffectiveTime = time.Hour * time.Duration(2) -) +type AccountApi struct{} -type LoginAccount struct { - Username string `json:"username"` - Password string `json:"password"` - Remember bool `json:"remember"` - TOTP string `json:"totp"` -} - -type ConfirmTOTP struct { - Secret string `json:"secret"` - TOTP string `json:"totp"` -} - -type ChangePassword struct { - NewPassword string `json:"newPassword"` - OldPassword string `json:"oldPassword"` -} - -type Authorization struct { - Token string - Remember bool - User model.User -} - -func LoginEndpoint(c echo.Context) error { - var loginAccount LoginAccount +func (api AccountApi) LoginEndpoint(c echo.Context) error { + var loginAccount dto.LoginAccount if err := c.Bind(&loginAccount); err != nil { return err } // 存储登录失败次数信息 loginFailCountKey := c.RealIP() + loginAccount.Username - v, ok := cache.GlobalCache.Get(loginFailCountKey) + v, ok := cache.LoginFailedKeyManager.Get(loginFailCountKey) if !ok { v = 1 } @@ -61,12 +37,12 @@ func LoginEndpoint(c echo.Context) error { return Fail(c, -1, "登录失败次数过多,请等待5分钟后再试") } - user, err := userRepository.FindByUsername(loginAccount.Username) + user, err := repository.UserRepository.FindByUsername(context.TODO(), loginAccount.Username) if err != nil { count++ - cache.GlobalCache.Set(loginFailCountKey, count, time.Minute*time.Duration(5)) + cache.LoginFailedKeyManager.Set(loginFailCountKey, count, cache.LoginLockExpiration) // 保存登录日志 - if err := SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "账号或密码不正确"); err != nil { + if err := service.UserService.SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "账号或密码不正确"); err != nil { return err } return FailWithData(c, -1, "您输入的账号或密码不正确", count) @@ -78,9 +54,9 @@ func LoginEndpoint(c echo.Context) error { if err := utils.Encoder.Match([]byte(user.Password), []byte(loginAccount.Password)); err != nil { count++ - cache.GlobalCache.Set(loginFailCountKey, count, time.Minute*time.Duration(5)) + cache.LoginFailedKeyManager.Set(loginFailCountKey, count, cache.LoginLockExpiration) // 保存登录日志 - if err := SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "账号或密码不正确"); err != nil { + if err := service.UserService.SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "账号或密码不正确"); err != nil { return err } return FailWithData(c, -1, "您输入的账号或密码不正确", count) @@ -90,73 +66,49 @@ func LoginEndpoint(c echo.Context) error { return Fail(c, 0, "") } - token, err := LoginSuccess(loginAccount, user) + token, err := api.LoginSuccess(loginAccount, user) if err != nil { return err } // 保存登录日志 - if err := SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, true, loginAccount.Remember, token, ""); err != nil { + if err := service.UserService.SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, true, loginAccount.Remember, token, ""); err != nil { return err } return Success(c, token) } -func SaveLoginLog(clientIP, clientUserAgent string, username string, success, remember bool, id, reason string) error { - loginLog := model.LoginLog{ - Username: username, - ClientIP: clientIP, - ClientUserAgent: clientUserAgent, - LoginTime: utils.NowJsonTime(), - Reason: reason, - Remember: remember, - } - if success { - loginLog.State = "1" - loginLog.ID = id - } else { - loginLog.State = "0" - loginLog.ID = utils.UUID() - } +func (api AccountApi) LoginSuccess(loginAccount dto.LoginAccount, user model.User) (string, error) { + token := utils.LongUUID() - if err := loginLogRepository.Create(&loginLog); err != nil { - return err - } - return nil -} - -func LoginSuccess(loginAccount LoginAccount, user model.User) (token string, err error) { - token = strings.Join([]string{utils.UUID(), utils.UUID(), utils.UUID(), utils.UUID()}, "") - - authorization := Authorization{ + authorization := dto.Authorization{ Token: token, + Type: constant.LoginToken, Remember: loginAccount.Remember, - User: user, + User: &user, } - cacheKey := userService.BuildCacheKeyByToken(token) - if authorization.Remember { // 记住登录有效期两周 - cache.GlobalCache.Set(cacheKey, authorization, RememberEffectiveTime) + cache.TokenManager.Set(token, authorization, cache.RememberMeExpiration) } else { - cache.GlobalCache.Set(cacheKey, authorization, NotRememberEffectiveTime) + cache.TokenManager.Set(token, authorization, cache.NotRememberExpiration) } // 修改登录状态 - err = userRepository.Update(&model.User{Online: true, ID: user.ID}) + err := repository.UserRepository.Update(context.TODO(), &model.User{Online: true, ID: user.ID}) return token, err } -func loginWithTotpEndpoint(c echo.Context) error { - var loginAccount LoginAccount +func (api AccountApi) LoginWithTotpEndpoint(c echo.Context) error { + var loginAccount dto.LoginAccount if err := c.Bind(&loginAccount); err != nil { return err } // 存储登录失败次数信息 loginFailCountKey := c.RealIP() + loginAccount.Username - v, ok := cache.GlobalCache.Get(loginFailCountKey) + v, ok := cache.LoginFailedKeyManager.Get(loginFailCountKey) if !ok { v = 1 } @@ -165,12 +117,12 @@ func loginWithTotpEndpoint(c echo.Context) error { return Fail(c, -1, "登录失败次数过多,请等待5分钟后再试") } - user, err := userRepository.FindByUsername(loginAccount.Username) + user, err := repository.UserRepository.FindByUsername(context.TODO(), loginAccount.Username) if err != nil { count++ - cache.GlobalCache.Set(loginFailCountKey, count, time.Minute*time.Duration(5)) + cache.LoginFailedKeyManager.Set(loginFailCountKey, count, cache.LoginLockExpiration) // 保存登录日志 - if err := SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "账号或密码不正确"); err != nil { + if err := service.UserService.SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "账号或密码不正确"); err != nil { return err } return FailWithData(c, -1, "您输入的账号或密码不正确", count) @@ -182,9 +134,9 @@ func loginWithTotpEndpoint(c echo.Context) error { if err := utils.Encoder.Match([]byte(user.Password), []byte(loginAccount.Password)); err != nil { count++ - cache.GlobalCache.Set(loginFailCountKey, count, time.Minute*time.Duration(5)) + cache.LoginFailedKeyManager.Set(loginFailCountKey, count, cache.LoginLockExpiration) // 保存登录日志 - if err := SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "账号或密码不正确"); err != nil { + if err := service.UserService.SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "账号或密码不正确"); err != nil { return err } return FailWithData(c, -1, "您输入的账号或密码不正确", count) @@ -192,42 +144,42 @@ func loginWithTotpEndpoint(c echo.Context) error { if !totp.Validate(loginAccount.TOTP, user.TOTPSecret) { count++ - cache.GlobalCache.Set(loginFailCountKey, count, time.Minute*time.Duration(5)) + cache.LoginFailedKeyManager.Set(loginFailCountKey, count, cache.LoginLockExpiration) // 保存登录日志 - if err := SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "双因素认证授权码不正确"); err != nil { + if err := service.UserService.SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, false, loginAccount.Remember, "", "双因素认证授权码不正确"); err != nil { return err } return FailWithData(c, -1, "您输入双因素认证授权码不正确", count) } - token, err := LoginSuccess(loginAccount, user) + token, err := api.LoginSuccess(loginAccount, user) if err != nil { return err } // 保存登录日志 - if err := SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, true, loginAccount.Remember, token, ""); err != nil { + if err := service.UserService.SaveLoginLog(c.RealIP(), c.Request().UserAgent(), loginAccount.Username, true, loginAccount.Remember, token, ""); err != nil { return err } return Success(c, token) } -func LogoutEndpoint(c echo.Context) error { +func (api AccountApi) LogoutEndpoint(c echo.Context) error { token := GetToken(c) - err := userService.LogoutByToken(token) + err := service.UserService.LogoutByToken(token) if err != nil { return err } return Success(c, nil) } -func ConfirmTOTPEndpoint(c echo.Context) error { +func (api AccountApi) ConfirmTOTPEndpoint(c echo.Context) error { if config.GlobalCfg.Demo { return Fail(c, 0, "演示模式禁止开启两步验证") } account, _ := GetCurrentAccount(c) - var confirmTOTP ConfirmTOTP + var confirmTOTP dto.ConfirmTOTP if err := c.Bind(&confirmTOTP); err != nil { return err } @@ -241,14 +193,14 @@ func ConfirmTOTPEndpoint(c echo.Context) error { ID: account.ID, } - if err := userRepository.Update(u); err != nil { + if err := repository.UserRepository.Update(context.TODO(), u); err != nil { return err } return Success(c, nil) } -func ReloadTOTPEndpoint(c echo.Context) error { +func (api AccountApi) ReloadTOTPEndpoint(c echo.Context) error { account, _ := GetCurrentAccount(c) key, err := totp.NewTOTP(totp.GenerateOpts{ @@ -275,25 +227,25 @@ func ReloadTOTPEndpoint(c echo.Context) error { }) } -func ResetTOTPEndpoint(c echo.Context) error { +func (api AccountApi) ResetTOTPEndpoint(c echo.Context) error { account, _ := GetCurrentAccount(c) u := &model.User{ TOTPSecret: "-", ID: account.ID, } - if err := userRepository.Update(u); err != nil { + if err := repository.UserRepository.Update(context.TODO(), u); err != nil { return err } return Success(c, "") } -func ChangePasswordEndpoint(c echo.Context) error { +func (api AccountApi) ChangePasswordEndpoint(c echo.Context) error { if config.GlobalCfg.Demo { return Fail(c, 0, "演示模式禁止修改密码") } account, _ := GetCurrentAccount(c) - var changePassword ChangePassword + var changePassword dto.ChangePassword if err := c.Bind(&changePassword); err != nil { return err } @@ -311,11 +263,11 @@ func ChangePasswordEndpoint(c echo.Context) error { ID: account.ID, } - if err := userRepository.Update(u); err != nil { + if err := repository.UserRepository.Update(context.TODO(), u); err != nil { return err } - return LogoutEndpoint(c) + return api.LogoutEndpoint(c) } type AccountInfo struct { @@ -326,10 +278,10 @@ type AccountInfo struct { EnableTotp bool `json:"enableTotp"` } -func InfoEndpoint(c echo.Context) error { +func (api AccountApi) InfoEndpoint(c echo.Context) error { account, _ := GetCurrentAccount(c) - user, err := userRepository.FindById(account.ID) + user, err := repository.UserRepository.FindById(context.TODO(), account.ID) if err != nil { return err } @@ -344,7 +296,7 @@ func InfoEndpoint(c echo.Context) error { return Success(c, info) } -func AccountAssetEndpoint(c echo.Context) error { +func (api AccountApi) AccountAssetEndpoint(c echo.Context) error { pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex")) pageSize, _ := strconv.Atoi(c.QueryParam("pageSize")) name := c.QueryParam("name") @@ -359,26 +311,26 @@ func AccountAssetEndpoint(c echo.Context) error { field := c.QueryParam("field") account, _ := GetCurrentAccount(c) - items, total, err := assetRepository.Find(pageIndex, pageSize, name, protocol, tags, account, owner, sharer, userGroupId, ip, order, field) + items, total, err := repository.AssetRepository.Find(context.TODO(), pageIndex, pageSize, name, protocol, tags, account, owner, sharer, userGroupId, ip, order, field) if err != nil { return err } - return Success(c, H{ + return Success(c, Map{ "total": total, "items": items, }) } -func AccountStorageEndpoint(c echo.Context) error { +func (api AccountApi) AccountStorageEndpoint(c echo.Context) error { account, _ := GetCurrentAccount(c) storageId := account.ID - storage, err := storageRepository.FindById(storageId) + storage, err := repository.StorageRepository.FindById(context.TODO(), storageId) if err != nil { return err } structMap := utils.StructToMap(storage) - drivePath := storageService.GetBaseDrivePath() + drivePath := service.StorageService.GetBaseDrivePath() dirSize, err := utils.DirSize(path.Join(drivePath, storageId)) if err != nil { structMap["usedSize"] = -1 @@ -388,3 +340,20 @@ func AccountStorageEndpoint(c echo.Context) error { return Success(c, structMap) } + +func (api AccountApi) AccessTokenGetEndpoint(c echo.Context) error { + account, _ := GetCurrentAccount(c) + accessToken, err := repository.AccessTokenRepository.FindByUserId(context.TODO(), account.ID) + if err != nil { + return err + } + return Success(c, accessToken) +} + +func (api AccountApi) AccessTokenGenEndpoint(c echo.Context) error { + account, _ := GetCurrentAccount(c) + if err := service.AccessTokenService.GenAccessToken(account.ID); err != nil { + return err + } + return Success(c, nil) +} diff --git a/server/api/api.go b/server/api/api.go index 9668334..5d80160 100644 --- a/server/api/api.go +++ b/server/api/api.go @@ -2,23 +2,24 @@ package api import ( "next-terminal/server/constant" + "next-terminal/server/dto" "next-terminal/server/global/cache" "next-terminal/server/model" "github.com/labstack/echo/v4" ) -type H map[string]interface{} +type Map map[string]interface{} func Fail(c echo.Context, code int, message string) error { - return c.JSON(200, H{ + return c.JSON(200, Map{ "code": code, "message": message, }) } func FailWithData(c echo.Context, code int, message string, data interface{}) error { - return c.JSON(200, H{ + return c.JSON(200, Map{ "code": code, "message": message, "data": data, @@ -26,20 +27,13 @@ func FailWithData(c echo.Context, code int, message string, data interface{}) er } func Success(c echo.Context, data interface{}) error { - return c.JSON(200, H{ + return c.JSON(200, Map{ "code": 1, "message": "success", "data": data, }) } -func NotFound(c echo.Context, message string) error { - return c.JSON(200, H{ - "code": -1, - "message": message, - }) -} - func GetToken(c echo.Context) string { token := c.Request().Header.Get(constant.Token) if len(token) > 0 { @@ -48,14 +42,13 @@ func GetToken(c echo.Context) string { return c.QueryParam(constant.Token) } -func GetCurrentAccount(c echo.Context) (model.User, bool) { +func GetCurrentAccount(c echo.Context) (*model.User, bool) { token := GetToken(c) - cacheKey := userService.BuildCacheKeyByToken(token) - get, b := cache.GlobalCache.Get(cacheKey) + get, b := cache.TokenManager.Get(token) if b { - return get.(Authorization).User, true + return get.(dto.Authorization).User, true } - return model.User{}, false + return nil, false } func HasPermission(c echo.Context, owner string) bool { diff --git a/server/api/asset.go b/server/api/asset.go index 4542c75..726bf23 100644 --- a/server/api/asset.go +++ b/server/api/asset.go @@ -2,58 +2,40 @@ package api import ( "bufio" + "context" "encoding/csv" - "encoding/json" "errors" "strconv" "strings" - "next-terminal/server/config" "next-terminal/server/constant" "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/service" "next-terminal/server/utils" "github.com/labstack/echo/v4" ) -func AssetCreateEndpoint(c echo.Context) error { +type AssetApi struct{} + +func (assetApi AssetApi) AssetCreateEndpoint(c echo.Context) error { m := echo.Map{} if err := c.Bind(&m); err != nil { return err } - data, _ := json.Marshal(m) - var item model.Asset - if err := json.Unmarshal(data, &item); err != nil { - return err - } - account, _ := GetCurrentAccount(c) - item.Owner = account.ID - item.ID = utils.UUID() - item.Created = utils.NowJsonTime() - item.Active = true + m["owner"] = account.ID - if err := assetRepository.Create(&item); err != nil { + if _, err := service.AssetService.Create(m); err != nil { return err } - if err := assetRepository.UpdateAttributes(item.ID, item.Protocol, m); err != nil { - return err - } - - go func() { - active, _ := assetService.CheckStatus(item.AccessGatewayId, item.IP, item.Port) - - if item.Active != active { - _ = assetRepository.UpdateActiveById(active, item.ID) - } - }() - - return Success(c, item) + return Success(c, nil) } -func AssetImportEndpoint(c echo.Context) error { +func (assetApi AssetApi) AssetImportEndpoint(c echo.Context) error { account, _ := GetCurrentAccount(c) file, err := c.FormFile("file") @@ -66,7 +48,9 @@ func AssetImportEndpoint(c echo.Context) error { return err } - defer src.Close() + defer func() { + _ = src.Close() + }() reader := csv.NewReader(bufio.NewReader(src)) records, err := reader.ReadAll() if err != nil { @@ -107,7 +91,7 @@ func AssetImportEndpoint(c echo.Context) error { asset.Tags = tags } - err := assetRepository.Create(&asset) + err := repository.AssetRepository.Create(context.TODO(), &asset) if err != nil { errorCount++ m[strconv.Itoa(i)] = err.Error() @@ -124,7 +108,7 @@ func AssetImportEndpoint(c echo.Context) error { }) } -func AssetPagingEndpoint(c echo.Context) error { +func (assetApi AssetApi) AssetPagingEndpoint(c echo.Context) error { pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex")) pageSize, _ := strconv.Atoi(c.QueryParam("pageSize")) name := c.QueryParam("name") @@ -140,26 +124,26 @@ func AssetPagingEndpoint(c echo.Context) error { account, _ := GetCurrentAccount(c) - items, total, err := assetRepository.Find(pageIndex, pageSize, name, protocol, tags, account, owner, sharer, userGroupId, ip, order, field) + items, total, err := repository.AssetRepository.Find(context.TODO(), pageIndex, pageSize, name, protocol, tags, account, owner, sharer, userGroupId, ip, order, field) if err != nil { return err } - return Success(c, H{ + return Success(c, Map{ "total": total, "items": items, }) } -func AssetAllEndpoint(c echo.Context) error { +func (assetApi AssetApi) AssetAllEndpoint(c echo.Context) error { protocol := c.QueryParam("protocol") - items, _ := assetRepository.FindByProtocol(protocol) + items, _ := repository.AssetRepository.FindByProtocol(context.TODO(), protocol) return Success(c, items) } -func AssetUpdateEndpoint(c echo.Context) error { +func (assetApi AssetApi) AssetUpdateEndpoint(c echo.Context) error { id := c.Param("id") - if err := PreCheckAssetPermission(c, id); err != nil { + if err := assetApi.PreCheckAssetPermission(c, id); err != nil { return err } @@ -167,67 +151,20 @@ func AssetUpdateEndpoint(c echo.Context) error { if err := c.Bind(&m); err != nil { return err } - - data, _ := json.Marshal(m) - var item model.Asset - if err := json.Unmarshal(data, &item); err != nil { + if err := service.AssetService.UpdateById(id, m); err != nil { return err } - - switch item.AccountType { - case "credential": - item.Username = "-" - item.Password = "-" - item.PrivateKey = "-" - item.Passphrase = "-" - case "private-key": - item.Password = "-" - item.CredentialId = "-" - if len(item.Username) == 0 { - item.Username = "-" - } - if len(item.Passphrase) == 0 { - item.Passphrase = "-" - } - case "custom": - item.PrivateKey = "-" - item.Passphrase = "-" - item.CredentialId = "-" - } - - if len(item.Tags) == 0 { - item.Tags = "-" - } - - if item.Description == "" { - item.Description = "-" - } - - if err := assetRepository.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil { - return err - } - if err := assetRepository.UpdateById(&item, id); err != nil { - return err - } - if err := assetRepository.UpdateAttributes(id, item.Protocol, m); err != nil { - return err - } - return Success(c, nil) } -func AssetDeleteEndpoint(c echo.Context) error { +func (assetApi AssetApi) AssetDeleteEndpoint(c echo.Context) error { id := c.Param("id") split := strings.Split(id, ",") for i := range split { - if err := PreCheckAssetPermission(c, split[i]); err != nil { + if err := assetApi.PreCheckAssetPermission(c, split[i]); err != nil { return err } - if err := assetRepository.DeleteById(split[i]); err != nil { - return err - } - // 删除资产与用户的关系 - if err := resourceSharerRepository.DeleteResourceSharerByResourceId(split[i]); err != nil { + if err := service.AssetService.DeleteById(split[i]); err != nil { return err } } @@ -235,17 +172,17 @@ func AssetDeleteEndpoint(c echo.Context) error { return Success(c, nil) } -func AssetGetEndpoint(c echo.Context) (err error) { +func (assetApi AssetApi) AssetGetEndpoint(c echo.Context) (err error) { id := c.Param("id") - if err := PreCheckAssetPermission(c, id); err != nil { + if err := assetApi.PreCheckAssetPermission(c, id); err != nil { return err } var item model.Asset - if item, err = assetRepository.FindByIdAndDecrypt(id); err != nil { + if item, err = service.AssetService.FindByIdAndDecrypt(context.TODO(), id); err != nil { return err } - attributeMap, err := assetRepository.FindAssetAttrMapByAssetId(id) + attributeMap, err := repository.AssetRepository.FindAssetAttrMapByAssetId(context.TODO(), id) if err != nil { return err } @@ -257,18 +194,18 @@ func AssetGetEndpoint(c echo.Context) (err error) { return Success(c, itemMap) } -func AssetTcpingEndpoint(c echo.Context) (err error) { +func (assetApi AssetApi) AssetTcpingEndpoint(c echo.Context) (err error) { id := c.Param("id") var item model.Asset - if item, err = assetRepository.FindById(id); err != nil { + if item, err = repository.AssetRepository.FindById(context.TODO(), id); err != nil { return err } - active, err := assetService.CheckStatus(item.AccessGatewayId, item.IP, item.Port) + active, err := service.AssetService.CheckStatus(item.AccessGatewayId, item.IP, item.Port) if item.Active != active { - if err := assetRepository.UpdateActiveById(active, item.ID); err != nil { + if err := repository.AssetRepository.UpdateActiveById(context.TODO(), active, item.ID); err != nil { return err } } @@ -278,36 +215,36 @@ func AssetTcpingEndpoint(c echo.Context) (err error) { message = err.Error() } - return Success(c, H{ + return Success(c, Map{ "active": active, "message": message, }) } -func AssetTagsEndpoint(c echo.Context) (err error) { +func (assetApi AssetApi) AssetTagsEndpoint(c echo.Context) (err error) { var items []string - if items, err = assetRepository.FindTags(); err != nil { + if items, err = repository.AssetRepository.FindTags(context.TODO()); err != nil { return err } return Success(c, items) } -func AssetChangeOwnerEndpoint(c echo.Context) (err error) { +func (assetApi AssetApi) AssetChangeOwnerEndpoint(c echo.Context) (err error) { id := c.Param("id") - if err := PreCheckAssetPermission(c, id); err != nil { + if err := assetApi.PreCheckAssetPermission(c, id); err != nil { return err } owner := c.QueryParam("owner") - if err := assetRepository.UpdateById(&model.Asset{Owner: owner}, id); err != nil { + if err := repository.AssetRepository.UpdateById(context.TODO(), &model.Asset{Owner: owner}, id); err != nil { return err } return Success(c, "") } -func PreCheckAssetPermission(c echo.Context, id string) error { - item, err := assetRepository.FindById(id) +func (assetApi AssetApi) PreCheckAssetPermission(c echo.Context, id string) error { + item, err := repository.AssetRepository.FindById(context.TODO(), id) if err != nil { return err } diff --git a/server/api/backup.go b/server/api/backup.go index e2c85d6..429f143 100644 --- a/server/api/backup.go +++ b/server/api/backup.go @@ -5,133 +5,21 @@ import ( "encoding/json" "fmt" "net/http" - "strings" "time" - "next-terminal/server/config" - "next-terminal/server/constant" - "next-terminal/server/global/security" - "next-terminal/server/model" - "next-terminal/server/utils" + "next-terminal/server/dto" + "next-terminal/server/service" "github.com/labstack/echo/v4" ) -type Backup struct { - Users []model.User `json:"users"` - UserGroups []model.UserGroup `json:"user_groups"` +type BackupApi struct{} - Storages []model.Storage `json:"storages"` - Strategies []model.Strategy `json:"strategies"` - AccessSecurities []model.AccessSecurity `json:"access_securities"` - AccessGateways []model.AccessGateway `json:"access_gateways"` - Commands []model.Command `json:"commands"` - Credentials []model.Credential `json:"credentials"` - Assets []map[string]interface{} `json:"assets"` - ResourceSharers []model.ResourceSharer `json:"resource_sharers"` - Jobs []model.Job `json:"jobs"` -} - -func BackupExportEndpoint(c echo.Context) error { - users, err := userRepository.FindAll() +func (api BackupApi) BackupExportEndpoint(c echo.Context) error { + err, backup := service.BackupService.Export() if err != nil { return err } - for i := range users { - users[i].Password = "" - } - userGroups, err := userGroupRepository.FindAll() - if err != nil { - return err - } - if len(userGroups) > 0 { - for i := range userGroups { - members, err := userGroupRepository.FindMembersById(userGroups[i].ID) - if err != nil { - return err - } - userGroups[i].Members = members - } - } - - storages, err := storageRepository.FindAll() - if err != nil { - return err - } - - strategies, err := strategyRepository.FindAll() - if err != nil { - return err - } - jobs, err := jobRepository.FindAll() - if err != nil { - return err - } - accessSecurities, err := accessSecurityRepository.FindAll() - if err != nil { - return err - } - accessGateways, err := accessGatewayRepository.FindAll() - if err != nil { - return err - } - commands, err := commandRepository.FindAll() - if err != nil { - return err - } - credentials, err := credentialRepository.FindAll() - if err != nil { - return err - } - if len(credentials) > 0 { - for i := range credentials { - if err := credentialRepository.Decrypt(&credentials[i], config.GlobalCfg.EncryptionPassword); err != nil { - return err - } - } - } - assets, err := assetRepository.FindAll() - if err != nil { - return err - } - var assetMaps = make([]map[string]interface{}, 0) - if len(assets) > 0 { - for i := range assets { - asset := assets[i] - if err := assetRepository.Decrypt(&asset, config.GlobalCfg.EncryptionPassword); err != nil { - return err - } - attributeMap, err := assetRepository.FindAssetAttrMapByAssetId(asset.ID) - if err != nil { - return err - } - itemMap := utils.StructToMap(asset) - for key := range attributeMap { - itemMap[key] = attributeMap[key] - } - itemMap["created"] = asset.Created.Format("2006-01-02 15:04:05") - assetMaps = append(assetMaps, itemMap) - } - } - - resourceSharers, err := resourceSharerRepository.FindAll() - if err != nil { - return err - } - - backup := Backup{ - Users: users, - UserGroups: userGroups, - Storages: storages, - Strategies: strategies, - Jobs: jobs, - AccessSecurities: accessSecurities, - AccessGateways: accessGateways, - Commands: commands, - Credentials: credentials, - Assets: assetMaps, - ResourceSharers: resourceSharers, - } jsonBytes, err := json.Marshal(backup) if err != nil { @@ -141,200 +29,13 @@ func BackupExportEndpoint(c echo.Context) error { return c.Stream(http.StatusOK, echo.MIMEOctetStream, bytes.NewReader(jsonBytes)) } -func BackupImportEndpoint(c echo.Context) error { - var backup Backup +func (api BackupApi) BackupImportEndpoint(c echo.Context) error { + var backup dto.Backup if err := c.Bind(&backup); err != nil { return err } - - var userIdMapping = make(map[string]string, 0) - if len(backup.Users) > 0 { - for _, item := range backup.Users { - if userRepository.ExistByUsername(item.Username) { - continue - } - oldId := item.ID - newId := utils.UUID() - item.ID = newId - item.Password = utils.GenPassword() - if err := userRepository.Create(&item); err != nil { - return err - } - userIdMapping[oldId] = newId - } + if err := service.BackupService.Import(&backup); err != nil { + return err } - - var userGroupIdMapping = make(map[string]string, 0) - if len(backup.UserGroups) > 0 { - for _, item := range backup.UserGroups { - oldId := item.ID - newId := utils.UUID() - item.ID = newId - - var members = make([]string, 0) - if len(item.Members) > 0 { - for _, member := range item.Members { - members = append(members, userIdMapping[member]) - } - } - - if err := userGroupRepository.Create(&item, members); err != nil { - return err - } - userGroupIdMapping[oldId] = newId - } - } - - if len(backup.Storages) > 0 { - for _, item := range backup.Storages { - item.ID = utils.UUID() - item.Owner = userIdMapping[item.Owner] - if err := storageRepository.Create(&item); err != nil { - return err - } - } - } - - var strategyIdMapping = make(map[string]string, 0) - if len(backup.Strategies) > 0 { - for _, item := range backup.Strategies { - oldId := item.ID - newId := utils.UUID() - item.ID = newId - if err := strategyRepository.Create(&item); err != nil { - return err - } - strategyIdMapping[oldId] = newId - } - } - - if len(backup.AccessSecurities) > 0 { - for _, item := range backup.AccessSecurities { - item.ID = utils.UUID() - if err := accessSecurityRepository.Create(&item); err != nil { - return err - } - // 更新内存中的安全规则 - rule := &security.Security{ - ID: item.ID, - IP: item.IP, - Rule: item.Rule, - Priority: item.Priority, - } - security.GlobalSecurityManager.Add <- rule - } - } - - var accessGatewayIdMapping = make(map[string]string, 0) - if len(backup.AccessGateways) > 0 { - for _, item := range backup.AccessGateways { - oldId := item.ID - newId := utils.UUID() - item.ID = newId - if err := accessGatewayRepository.Create(&item); err != nil { - return err - } - accessGatewayIdMapping[oldId] = newId - } - } - - if len(backup.Commands) > 0 { - for _, item := range backup.Commands { - item.ID = utils.UUID() - if err := commandRepository.Create(&item); err != nil { - return err - } - } - } - - var credentialIdMapping = make(map[string]string, 0) - if len(backup.Credentials) > 0 { - for _, item := range backup.Credentials { - oldId := item.ID - newId := utils.UUID() - item.ID = newId - if err := credentialRepository.Create(&item); err != nil { - return err - } - credentialIdMapping[oldId] = newId - } - } - - var assetIdMapping = make(map[string]string, 0) - if len(backup.Assets) > 0 { - for _, m := range backup.Assets { - data, err := json.Marshal(m) - if err != nil { - return err - } - var item model.Asset - if err := json.Unmarshal(data, &item); err != nil { - return err - } - - if item.CredentialId != "" && item.CredentialId != "-" { - item.CredentialId = credentialIdMapping[item.CredentialId] - } - if item.AccessGatewayId != "" && item.AccessGatewayId != "-" { - item.AccessGatewayId = accessGatewayIdMapping[item.AccessGatewayId] - } - - oldId := item.ID - newId := utils.UUID() - item.ID = newId - if err := assetRepository.Create(&item); err != nil { - return err - } - - if err := assetRepository.UpdateAttributes(item.ID, item.Protocol, m); err != nil { - return err - } - - go func() { - active, _ := assetService.CheckStatus(item.AccessGatewayId, item.IP, item.Port) - - if item.Active != active { - _ = assetRepository.UpdateActiveById(active, item.ID) - } - }() - - assetIdMapping[oldId] = newId - } - } - - if len(backup.ResourceSharers) > 0 { - for _, item := range backup.ResourceSharers { - - userGroupId := userGroupIdMapping[item.UserGroupId] - userId := userIdMapping[item.UserId] - strategyId := strategyIdMapping[item.StrategyId] - resourceId := assetIdMapping[item.ResourceId] - - if err := resourceSharerRepository.AddSharerResources(userGroupId, userId, strategyId, item.ResourceType, []string{resourceId}); err != nil { - return err - } - } - } - - if len(backup.Jobs) > 0 { - for _, item := range backup.Jobs { - if item.Func == constant.FuncCheckAssetStatusJob { - continue - } - - resourceIds := strings.Split(item.ResourceIds, ",") - if len(resourceIds) > 0 { - var newResourceIds = make([]string, 0) - for _, resourceId := range resourceIds { - newResourceIds = append(newResourceIds, assetIdMapping[resourceId]) - } - item.ResourceIds = strings.Join(newResourceIds, ",") - } - if err := jobService.Create(&item); err != nil { - return err - } - } - } - return Success(c, "") } diff --git a/server/api/command.go b/server/api/command.go index 6450f28..1576feb 100644 --- a/server/api/command.go +++ b/server/api/command.go @@ -1,17 +1,21 @@ package api import ( + "context" "errors" "strconv" "strings" "next-terminal/server/model" + "next-terminal/server/repository" "next-terminal/server/utils" "github.com/labstack/echo/v4" ) -func CommandCreateEndpoint(c echo.Context) error { +type CommandApi struct{} + +func (api CommandApi) CommandCreateEndpoint(c echo.Context) error { var item model.Command if err := c.Bind(&item); err != nil { return err @@ -22,20 +26,23 @@ func CommandCreateEndpoint(c echo.Context) error { item.ID = utils.UUID() item.Created = utils.NowJsonTime() - if err := commandRepository.Create(&item); err != nil { + if err := repository.CommandRepository.Create(context.TODO(), &item); err != nil { return err } return Success(c, item) } -func CommandAllEndpoint(c echo.Context) error { +func (api CommandApi) CommandAllEndpoint(c echo.Context) error { account, _ := GetCurrentAccount(c) - items, _ := commandRepository.FindByUser(account) + items, err := repository.CommandRepository.FindByUser(context.TODO(), account) + if err != nil { + return err + } return Success(c, items) } -func CommandPagingEndpoint(c echo.Context) error { +func (api CommandApi) CommandPagingEndpoint(c echo.Context) error { pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex")) pageSize, _ := strconv.Atoi(c.QueryParam("pageSize")) name := c.QueryParam("name") @@ -45,20 +52,20 @@ func CommandPagingEndpoint(c echo.Context) error { order := c.QueryParam("order") field := c.QueryParam("field") - items, total, err := commandRepository.Find(pageIndex, pageSize, name, content, order, field, account) + items, total, err := repository.CommandRepository.Find(context.TODO(), pageIndex, pageSize, name, content, order, field, account) if err != nil { return err } - return Success(c, H{ + return Success(c, Map{ "total": total, "items": items, }) } -func CommandUpdateEndpoint(c echo.Context) error { +func (api CommandApi) CommandUpdateEndpoint(c echo.Context) error { id := c.Param("id") - if err := PreCheckCommandPermission(c, id); err != nil { + if err := api.PreCheckCommandPermission(c, id); err != nil { return err } @@ -67,61 +74,57 @@ func CommandUpdateEndpoint(c echo.Context) error { return err } - if err := commandRepository.UpdateById(&item, id); err != nil { + if err := repository.CommandRepository.UpdateById(context.TODO(), &item, id); err != nil { return err } return Success(c, nil) } -func CommandDeleteEndpoint(c echo.Context) error { +func (api CommandApi) CommandDeleteEndpoint(c echo.Context) error { id := c.Param("id") split := strings.Split(id, ",") for i := range split { - if err := PreCheckCommandPermission(c, split[i]); err != nil { + if err := api.PreCheckCommandPermission(c, split[i]); err != nil { return err } - if err := commandRepository.DeleteById(split[i]); err != nil { - return err - } - // 删除资产与用户的关系 - if err := resourceSharerRepository.DeleteResourceSharerByResourceId(split[i]); err != nil { + if err := repository.CommandRepository.DeleteById(context.TODO(), split[i]); err != nil { return err } } return Success(c, nil) } -func CommandGetEndpoint(c echo.Context) (err error) { +func (api CommandApi) CommandGetEndpoint(c echo.Context) (err error) { id := c.Param("id") - if err := PreCheckCommandPermission(c, id); err != nil { + if err := api.PreCheckCommandPermission(c, id); err != nil { return err } var item model.Command - if item, err = commandRepository.FindById(id); err != nil { + if item, err = repository.CommandRepository.FindById(context.TODO(), id); err != nil { return err } return Success(c, item) } -func CommandChangeOwnerEndpoint(c echo.Context) (err error) { +func (api CommandApi) CommandChangeOwnerEndpoint(c echo.Context) (err error) { id := c.Param("id") - if err := PreCheckCommandPermission(c, id); err != nil { + if err := api.PreCheckCommandPermission(c, id); err != nil { return err } owner := c.QueryParam("owner") - if err := commandRepository.UpdateById(&model.Command{Owner: owner}, id); err != nil { + if err := repository.CommandRepository.UpdateById(context.TODO(), &model.Command{Owner: owner}, id); err != nil { return err } return Success(c, "") } -func PreCheckCommandPermission(c echo.Context, id string) error { - item, err := commandRepository.FindById(id) +func (api CommandApi) PreCheckCommandPermission(c echo.Context, id string) error { + item, err := repository.CommandRepository.FindById(context.TODO(), id) if err != nil { return err } diff --git a/server/api/credential.go b/server/api/credential.go index bc5f816..b3709c2 100644 --- a/server/api/credential.go +++ b/server/api/credential.go @@ -1,6 +1,7 @@ package api import ( + "context" "encoding/base64" "errors" "strconv" @@ -9,17 +10,23 @@ import ( "next-terminal/server/config" "next-terminal/server/constant" "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/service" "next-terminal/server/utils" "github.com/labstack/echo/v4" ) -func CredentialAllEndpoint(c echo.Context) error { - account, _ := GetCurrentAccount(c) - items, _ := credentialRepository.FindByUser(account) +type CredentialApi struct{} + +func (api CredentialApi) CredentialAllEndpoint(c echo.Context) error { + items, err := repository.CredentialRepository.FindByUser(context.TODO()) + if err != nil { + return err + } return Success(c, items) } -func CredentialCreateEndpoint(c echo.Context) error { +func (api CredentialApi) CredentialCreateEndpoint(c echo.Context) error { var item model.Credential if err := c.Bind(&item); err != nil { return err @@ -56,14 +63,15 @@ func CredentialCreateEndpoint(c echo.Context) error { } item.Encrypted = true - if err := credentialRepository.Create(&item); err != nil { + + if err := service.CredentialService.Create(&item); err != nil { return err } return Success(c, item) } -func CredentialPagingEndpoint(c echo.Context) error { +func (api CredentialApi) CredentialPagingEndpoint(c echo.Context) error { pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex")) pageSize, _ := strconv.Atoi(c.QueryParam("pageSize")) name := c.QueryParam("name") @@ -72,21 +80,21 @@ func CredentialPagingEndpoint(c echo.Context) error { field := c.QueryParam("field") account, _ := GetCurrentAccount(c) - items, total, err := credentialRepository.Find(pageIndex, pageSize, name, order, field, account) + items, total, err := repository.CredentialRepository.Find(context.TODO(), pageIndex, pageSize, name, order, field, account) if err != nil { return err } - return Success(c, H{ + return Success(c, Map{ "total": total, "items": items, }) } -func CredentialUpdateEndpoint(c echo.Context) error { +func (api CredentialApi) CredentialUpdateEndpoint(c echo.Context) error { id := c.Param("id") - if err := PreCheckCredentialPermission(c, id); err != nil { + if err := api.PreCheckCredentialPermission(c, id); err != nil { return err } @@ -142,25 +150,21 @@ func CredentialUpdateEndpoint(c echo.Context) error { } item.Encrypted = true - if err := credentialRepository.UpdateById(&item, id); err != nil { + if err := repository.CredentialRepository.UpdateById(context.TODO(), &item, id); err != nil { return err } return Success(c, nil) } -func CredentialDeleteEndpoint(c echo.Context) error { +func (api CredentialApi) CredentialDeleteEndpoint(c echo.Context) error { id := c.Param("id") split := strings.Split(id, ",") for i := range split { - if err := PreCheckCredentialPermission(c, split[i]); err != nil { + if err := api.PreCheckCredentialPermission(c, split[i]); err != nil { return err } - if err := credentialRepository.DeleteById(split[i]); err != nil { - return err - } - // 删除资产与用户的关系 - if err := resourceSharerRepository.DeleteResourceSharerByResourceId(split[i]); err != nil { + if err := repository.CredentialRepository.DeleteById(context.TODO(), split[i]); err != nil { return err } } @@ -168,13 +172,13 @@ func CredentialDeleteEndpoint(c echo.Context) error { return Success(c, nil) } -func CredentialGetEndpoint(c echo.Context) error { +func (api CredentialApi) CredentialGetEndpoint(c echo.Context) error { id := c.Param("id") - if err := PreCheckCredentialPermission(c, id); err != nil { + if err := api.PreCheckCredentialPermission(c, id); err != nil { return err } - item, err := credentialRepository.FindByIdAndDecrypt(id) + item, err := service.CredentialService.FindByIdAndDecrypt(context.TODO(), id) if err != nil { return err } @@ -186,22 +190,22 @@ func CredentialGetEndpoint(c echo.Context) error { return Success(c, item) } -func CredentialChangeOwnerEndpoint(c echo.Context) error { +func (api CredentialApi) CredentialChangeOwnerEndpoint(c echo.Context) error { id := c.Param("id") - if err := PreCheckCredentialPermission(c, id); err != nil { + if err := api.PreCheckCredentialPermission(c, id); err != nil { return err } owner := c.QueryParam("owner") - if err := credentialRepository.UpdateById(&model.Credential{Owner: owner}, id); err != nil { + if err := repository.CredentialRepository.UpdateById(context.TODO(), &model.Credential{Owner: owner}, id); err != nil { return err } return Success(c, "") } -func PreCheckCredentialPermission(c echo.Context, id string) error { - item, err := credentialRepository.FindById(id) +func (api CredentialApi) PreCheckCredentialPermission(c echo.Context, id string) error { + item, err := repository.CredentialRepository.FindById(context.TODO(), id) if err != nil { return err } diff --git a/server/api/tunnel.go b/server/api/guacamole.go similarity index 72% rename from server/api/tunnel.go rename to server/api/guacamole.go index 69fbd59..042f3fb 100644 --- a/server/api/tunnel.go +++ b/server/api/guacamole.go @@ -2,11 +2,10 @@ package api import ( "context" - "encoding/base64" "errors" + "net/http" "path" "strconv" - "time" "next-terminal/server/config" "next-terminal/server/constant" @@ -14,6 +13,8 @@ import ( "next-terminal/server/guacd" "next-terminal/server/log" "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/service" "next-terminal/server/utils" "github.com/gorilla/websocket" @@ -31,18 +32,27 @@ const ( AssetNotActive int = 805 ) -func TunEndpoint(c echo.Context) error { +var UpGrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + Subprotocols: []string{"guacamole"}, +} +type GuacamoleApi struct { +} + +func (api GuacamoleApi) Guacamole(c echo.Context) error { ws, err := UpGrader.Upgrade(c.Response().Writer, c.Request(), nil) if err != nil { log.Errorf("升级为WebSocket协议失败:%v", err.Error()) return err } - + ctx := context.TODO() width := c.QueryParam("width") height := c.QueryParam("height") dpi := c.QueryParam("dpi") - sessionId := c.QueryParam("sessionId") + sessionId := c.Param("id") connectionId := c.QueryParam("connectionId") intWidth, _ := strconv.Atoi(width) @@ -50,12 +60,12 @@ func TunEndpoint(c echo.Context) error { configuration := guacd.NewConfiguration() - propertyMap := propertyRepository.FindAllMap() + propertyMap := repository.PropertyRepository.FindAllMap(ctx) var s model.Session if len(connectionId) > 0 { - s, err = sessionRepository.FindByConnectionId(connectionId) + s, err = repository.SessionRepository.FindByConnectionId(ctx, connectionId) if err != nil { return err } @@ -71,28 +81,28 @@ func TunEndpoint(c echo.Context) error { configuration.SetParameter("width", width) configuration.SetParameter("height", height) configuration.SetParameter("dpi", dpi) - s, err = sessionRepository.FindByIdAndDecrypt(sessionId) + s, err = service.SessionService.FindByIdAndDecrypt(ctx, sessionId) if err != nil { return err } - setConfig(propertyMap, s, configuration) + api.setConfig(propertyMap, s, configuration) var ( ip = s.IP port = s.Port ) if s.AccessGatewayId != "" && s.AccessGatewayId != "-" { - g, err := accessGatewayService.GetGatewayAndReconnectById(s.AccessGatewayId) + g, err := service.GatewayService.GetGatewayAndReconnectById(s.AccessGatewayId) if err != nil { - disconnect(ws, AccessGatewayUnAvailable, "获取接入网关失败:"+err.Error()) + utils.Disconnect(ws, AccessGatewayUnAvailable, "获取接入网关失败:"+err.Error()) return nil } if !g.Connected { - disconnect(ws, AccessGatewayUnAvailable, "接入网关不可用:"+g.Message) + utils.Disconnect(ws, AccessGatewayUnAvailable, "接入网关不可用:"+g.Message) return nil } exposedIP, exposedPort, err := g.OpenSshTunnel(s.ID, ip, port) if err != nil { - disconnect(ws, AccessGatewayCreateError, "创建SSH隧道失败:"+err.Error()) + utils.Disconnect(ws, AccessGatewayCreateError, "创建SSH隧道失败:"+err.Error()) return nil } defer g.CloseSshTunnel(s.ID) @@ -101,7 +111,7 @@ func TunEndpoint(c echo.Context) error { } active, err := utils.Tcping(ip, port) if !active { - disconnect(ws, AssetNotActive, "目标资产不在线: "+err.Error()) + utils.Disconnect(ws, AssetNotActive, "目标资产不在线: "+err.Error()) return nil } @@ -109,12 +119,12 @@ func TunEndpoint(c echo.Context) error { configuration.SetParameter("port", strconv.Itoa(port)) // 加载资产配置的属性,优先级比全局配置的高,因此最后加载,覆盖掉全局配置 - attributes, err := assetRepository.FindAssetAttrMapByAssetId(s.AssetId) + attributes, err := repository.AssetRepository.FindAssetAttrMapByAssetId(ctx, s.AssetId) if err != nil { return err } if len(attributes) > 0 { - setAssetConfig(attributes, s, configuration) + api.setAssetConfig(attributes, s, configuration) } } for name := range configuration.Parameters { @@ -130,7 +140,7 @@ func TunEndpoint(c echo.Context) error { guacdTunnel, err := guacd.NewTunnel(addr, configuration) if err != nil { if connectionId == "" { - disconnect(ws, NewTunnelError, err.Error()) + utils.Disconnect(ws, NewTunnelError, err.Error()) } log.Printf("[%v:%v] 建立连接失败: %v", sessionId, connectionId, err.Error()) return err @@ -144,7 +154,7 @@ func TunEndpoint(c echo.Context) error { GuacdTunnel: guacdTunnel, } - if len(s.ConnectionId) == 0 { + if connectionId == "" { if configuration.Protocol == constant.SSH { nextTerminal, err := CreateNextTerminalBySession(s) if err == nil { @@ -168,14 +178,14 @@ func TunEndpoint(c echo.Context) error { } // 创建新会话 log.Debugf("[%v:%v] 创建新会话: %v", sessionId, connectionId, sess.ConnectionId) - if err := sessionRepository.UpdateById(&sess, sessionId); err != nil { + if err := repository.SessionRepository.UpdateById(ctx, &sess, sessionId); err != nil { return err } } else { // 要监控会话 forObsSession := session.GlobalSessionManager.GetById(sessionId) if forObsSession == nil { - disconnect(ws, NotFoundSession, "获取会话失败") + utils.Disconnect(ws, NotFoundSession, "获取会话失败") return nil } nextSession.ID = utils.UUID() @@ -183,56 +193,8 @@ func TunEndpoint(c echo.Context) error { log.Debugf("[%v:%v] 观察者[%v]加入会话[%v]", sessionId, connectionId, nextSession.ID, s.ConnectionId) } - ctx, cancel := context.WithCancel(context.Background()) - tick := time.NewTicker(time.Millisecond * time.Duration(60)) - defer tick.Stop() - var buf []byte - dataChan := make(chan []byte) - - go func() { - GuacdLoop: - for { - select { - case <-ctx.Done(): - log.Debugf("[%v:%v] WebSocket 已关闭,即将关闭 Guacd 连接...", sessionId, connectionId) - break GuacdLoop - default: - instruction, err := guacdTunnel.Read() - if err != nil { - log.Debugf("[%v:%v] Guacd 读取失败,即将退出循环...", sessionId, connectionId) - disconnect(ws, TunnelClosed, "远程连接已关闭") - break GuacdLoop - } - if len(instruction) == 0 { - continue - } - dataChan <- instruction - } - } - log.Debugf("[%v:%v] Guacd 连接已关闭,退出 Guacd 循环。", sessionId, connectionId) - }() - - go func() { - tickLoop: - for { - select { - case <-ctx.Done(): - break tickLoop - case <-tick.C: - if len(buf) > 0 { - err = ws.WriteMessage(websocket.TextMessage, buf) - if err != nil { - log.Debugf("[%v:%v] WebSocket写入失败,即将关闭Guacd连接...", sessionId, connectionId) - break tickLoop - } - buf = []byte{} - } - case data := <-dataChan: - buf = append(buf, data...) - } - } - log.Debugf("[%v:%v] Guacd连接已关闭,退出定时器循环。", sessionId, connectionId) - }() + guacamoleHandler := NewGuacamoleHandler(ws, guacdTunnel) + guacamoleHandler.Start() for { _, message, err := ws.ReadMessage() @@ -250,20 +212,20 @@ func TunEndpoint(c echo.Context) error { log.Debugf("[%v:%v] 观察者[%v]退出会话", sessionId, connectionId, observerId) } } else { - CloseSessionById(sessionId, Normal, "用户正常退出") + service.SessionService.CloseSessionById(sessionId, Normal, "用户正常退出") } - cancel() - break + guacamoleHandler.Stop() + return nil } _, err = guacdTunnel.WriteAndFlush(message) if err != nil { - CloseSessionById(sessionId, TunnelClosed, "远程连接已关闭") + service.SessionService.CloseSessionById(sessionId, TunnelClosed, "远程连接已关闭") + return nil } } - return nil } -func setAssetConfig(attributes map[string]string, s model.Session, configuration *guacd.Configuration) { +func (api GuacamoleApi) setAssetConfig(attributes map[string]string, s model.Session, configuration *guacd.Configuration) { for key, value := range attributes { if guacd.DrivePath == key { // 忽略该参数 @@ -275,7 +237,7 @@ func setAssetConfig(attributes map[string]string, s model.Session, configuration // 默认空间ID和用户ID相同 storageId = s.Creator } - realPath := path.Join(storageService.GetBaseDrivePath(), storageId) + realPath := path.Join(service.StorageService.GetBaseDrivePath(), storageId) configuration.SetParameter(guacd.EnableDrive, "true") configuration.SetParameter(guacd.DriveName, "Next Terminal Filesystem") configuration.SetParameter(guacd.DrivePath, realPath) @@ -286,7 +248,7 @@ func setAssetConfig(attributes map[string]string, s model.Session, configuration } } -func setConfig(propertyMap map[string]string, s model.Session, configuration *guacd.Configuration) { +func (api GuacamoleApi) setConfig(propertyMap map[string]string, s model.Session, configuration *guacd.Configuration) { if propertyMap[guacd.EnableRecording] == "true" { configuration.SetParameter(guacd.RecordingPath, path.Join(config.GlobalCfg.Guacd.Recording, s.ID)) configuration.SetParameter(guacd.CreateRecordingPath, "true") @@ -312,7 +274,8 @@ func setConfig(propertyMap map[string]string, s model.Session, configuration *gu configuration.SetParameter(guacd.EnableMenuAnimations, propertyMap[guacd.EnableMenuAnimations]) configuration.SetParameter(guacd.DisableBitmapCaching, propertyMap[guacd.DisableBitmapCaching]) configuration.SetParameter(guacd.DisableOffscreenCaching, propertyMap[guacd.DisableOffscreenCaching]) - configuration.SetParameter(guacd.DisableGlyphCaching, propertyMap[guacd.DisableGlyphCaching]) + configuration.SetParameter(guacd.ColorDepth, propertyMap[guacd.ColorDepth]) + configuration.SetParameter(guacd.ForceLossless, propertyMap[guacd.ForceLossless]) case "ssh": if len(s.PrivateKey) > 0 && s.PrivateKey != "-" { configuration.SetParameter("username", s.Username) @@ -350,12 +313,3 @@ func setConfig(propertyMap map[string]string, s model.Session, configuration *gu } } - -func disconnect(ws *websocket.Conn, code int, reason string) { - // guacd 无法处理中文字符,所以进行了base64编码。 - encodeReason := base64.StdEncoding.EncodeToString([]byte(reason)) - err := guacd.NewInstruction("error", encodeReason, strconv.Itoa(code)) - _ = ws.WriteMessage(websocket.TextMessage, []byte(err.String())) - disconnect := guacd.NewInstruction("disconnect") - _ = ws.WriteMessage(websocket.TextMessage, []byte(disconnect.String())) -} diff --git a/server/api/guacamole_handler.go b/server/api/guacamole_handler.go new file mode 100644 index 0000000..f81223a --- /dev/null +++ b/server/api/guacamole_handler.go @@ -0,0 +1,84 @@ +package api + +import ( + "context" + "time" + + "next-terminal/server/guacd" + "next-terminal/server/log" + "next-terminal/server/utils" + + "github.com/gorilla/websocket" +) + +type GuacamoleHandler struct { + ws *websocket.Conn + tunnel *guacd.Tunnel + ctx context.Context + cancel context.CancelFunc + dataChan chan []byte + tick *time.Ticker +} + +func NewGuacamoleHandler(ws *websocket.Conn, tunnel *guacd.Tunnel) *GuacamoleHandler { + ctx, cancel := context.WithCancel(context.Background()) + tick := time.NewTicker(time.Millisecond * time.Duration(60)) + return &GuacamoleHandler{ + ws: ws, + tunnel: tunnel, + ctx: ctx, + cancel: cancel, + dataChan: make(chan []byte), + tick: tick, + } +} + +func (r GuacamoleHandler) Start() { + go r.readFormTunnel() + go r.writeToWebsocket() +} + +func (r GuacamoleHandler) Stop() { + r.tick.Stop() + r.cancel() +} + +func (r GuacamoleHandler) readFormTunnel() { + for { + select { + case <-r.ctx.Done(): + return + default: + instruction, err := r.tunnel.Read() + if err != nil { + utils.Disconnect(r.ws, TunnelClosed, "远程连接已关闭") + return + } + if len(instruction) == 0 { + continue + } + r.dataChan <- instruction + } + } +} + +func (r GuacamoleHandler) writeToWebsocket() { + var buf []byte + for { + select { + case <-r.ctx.Done(): + return + case <-r.tick.C: + if len(buf) > 0 { + err := r.ws.WriteMessage(websocket.TextMessage, buf) + if err != nil { + log.Debugf("WebSocket写入失败,即将关闭Guacd连接...") + return + } + buf = []byte{} + } + case data := <-r.dataChan: + buf = append(buf, data...) + } + } +} diff --git a/server/api/job.go b/server/api/job.go index 1d2849e..fd2791d 100644 --- a/server/api/job.go +++ b/server/api/job.go @@ -1,16 +1,22 @@ package api import ( + "context" + "strconv" "strings" "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/service" "next-terminal/server/utils" "github.com/labstack/echo/v4" ) -func JobCreateEndpoint(c echo.Context) error { +type JobApi struct{} + +func (api JobApi) JobCreateEndpoint(c echo.Context) error { var item model.Job if err := c.Bind(&item); err != nil { return err @@ -19,13 +25,13 @@ func JobCreateEndpoint(c echo.Context) error { item.ID = utils.UUID() item.Created = utils.NowJsonTime() - if err := jobService.Create(&item); err != nil { + if err := service.JobService.Create(&item); err != nil { return err } return Success(c, "") } -func JobPagingEndpoint(c echo.Context) error { +func (api JobApi) JobPagingEndpoint(c echo.Context) error { pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex")) pageSize, _ := strconv.Atoi(c.QueryParam("pageSize")) name := c.QueryParam("name") @@ -34,18 +40,18 @@ func JobPagingEndpoint(c echo.Context) error { order := c.QueryParam("order") field := c.QueryParam("field") - items, total, err := jobRepository.Find(pageIndex, pageSize, name, status, order, field) + items, total, err := repository.JobRepository.Find(context.TODO(), pageIndex, pageSize, name, status, order, field) if err != nil { return err } - return Success(c, H{ + return Success(c, Map{ "total": total, "items": items, }) } -func JobUpdateEndpoint(c echo.Context) error { +func (api JobApi) JobUpdateEndpoint(c echo.Context) error { id := c.Param("id") var item model.Job @@ -53,37 +59,37 @@ func JobUpdateEndpoint(c echo.Context) error { return err } item.ID = id - if err := jobService.UpdateById(&item); err != nil { + if err := service.JobService.UpdateById(&item); err != nil { return err } return Success(c, nil) } -func JobChangeStatusEndpoint(c echo.Context) error { +func (api JobApi) JobChangeStatusEndpoint(c echo.Context) error { id := c.Param("id") status := c.QueryParam("status") - if err := jobService.ChangeStatusById(id, status); err != nil { + if err := service.JobService.ChangeStatusById(id, status); err != nil { return err } return Success(c, "") } -func JobExecEndpoint(c echo.Context) error { +func (api JobApi) JobExecEndpoint(c echo.Context) error { id := c.Param("id") - if err := jobService.ExecJobById(id); err != nil { + if err := service.JobService.ExecJobById(id); err != nil { return err } return Success(c, "") } -func JobDeleteEndpoint(c echo.Context) error { +func (api JobApi) JobDeleteEndpoint(c echo.Context) error { ids := c.Param("id") split := strings.Split(ids, ",") for i := range split { jobId := split[i] - if err := jobService.DeleteJobById(jobId); err != nil { + if err := service.JobService.DeleteJobById(jobId); err != nil { return err } } @@ -91,10 +97,10 @@ func JobDeleteEndpoint(c echo.Context) error { return Success(c, nil) } -func JobGetEndpoint(c echo.Context) error { +func (api JobApi) JobGetEndpoint(c echo.Context) error { id := c.Param("id") - item, err := jobRepository.FindById(id) + item, err := repository.JobRepository.FindById(context.TODO(), id) if err != nil { return err } @@ -102,10 +108,10 @@ func JobGetEndpoint(c echo.Context) error { return Success(c, item) } -func JobGetLogsEndpoint(c echo.Context) error { +func (api JobApi) JobGetLogsEndpoint(c echo.Context) error { id := c.Param("id") - items, err := jobLogRepository.FindByJobId(id) + items, err := repository.JobLogRepository.FindByJobId(context.TODO(), id) if err != nil { return err } @@ -113,9 +119,9 @@ func JobGetLogsEndpoint(c echo.Context) error { return Success(c, items) } -func JobDeleteLogsEndpoint(c echo.Context) error { +func (api JobApi) JobDeleteLogsEndpoint(c echo.Context) error { id := c.Param("id") - if err := jobLogRepository.DeleteByJobId(id); err != nil { + if err := repository.JobLogRepository.DeleteByJobId(context.TODO(), id); err != nil { return err } return Success(c, "") diff --git a/server/api/login-log.go b/server/api/login-log.go index cd235c5..eb4207f 100644 --- a/server/api/login-log.go +++ b/server/api/login-log.go @@ -1,43 +1,49 @@ package api import ( + "context" "strconv" "strings" + "next-terminal/server/repository" + "next-terminal/server/service" + "github.com/labstack/echo/v4" ) -func LoginLogPagingEndpoint(c echo.Context) error { +type LoginLogApi struct{} + +func (api LoginLogApi) LoginLogPagingEndpoint(c echo.Context) error { pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex")) pageSize, _ := strconv.Atoi(c.QueryParam("pageSize")) username := c.QueryParam("username") clientIp := c.QueryParam("clientIp") state := c.QueryParam("state") - items, total, err := loginLogRepository.Find(pageIndex, pageSize, username, clientIp, state) + items, total, err := repository.LoginLogRepository.Find(context.TODO(), pageIndex, pageSize, username, clientIp, state) if err != nil { return err } - return Success(c, H{ + return Success(c, Map{ "total": total, "items": items, }) } -func LoginLogDeleteEndpoint(c echo.Context) error { +func (api LoginLogApi) LoginLogDeleteEndpoint(c echo.Context) error { ids := c.Param("id") tokens := strings.Split(ids, ",") - if err := userService.DeleteLoginLogs(tokens); err != nil { + if err := service.UserService.DeleteLoginLogs(tokens); err != nil { return err } return Success(c, nil) } -func LoginLogClearEndpoint(c echo.Context) error { - loginLogs, err := loginLogRepository.FindAllLoginLogs() +func (api LoginLogApi) LoginLogClearEndpoint(c echo.Context) error { + loginLogs, err := repository.LoginLogRepository.FindAllLoginLogs(context.TODO()) if err != nil { return err } @@ -46,7 +52,7 @@ func LoginLogClearEndpoint(c echo.Context) error { tokens = append(tokens, loginLogs[i].ID) } - if err := userService.DeleteLoginLogs(tokens); err != nil { + if err := service.UserService.DeleteLoginLogs(tokens); err != nil { return err } return Success(c, nil) diff --git a/server/api/overview.go b/server/api/overview.go index 752456e..0fbfd86 100644 --- a/server/api/overview.go +++ b/server/api/overview.go @@ -1,39 +1,30 @@ package api import ( + "context" + "next-terminal/server/constant" + "next-terminal/server/dto" + "next-terminal/server/repository" "github.com/labstack/echo/v4" ) -type Counter struct { - User int64 `json:"user"` - Asset int64 `json:"asset"` - Credential int64 `json:"credential"` - OnlineSession int64 `json:"onlineSession"` -} - -func OverviewCounterEndPoint(c echo.Context) error { - account, _ := GetCurrentAccount(c) +type OverviewApi struct{} +func (api OverviewApi) OverviewCounterEndPoint(c echo.Context) error { var ( countUser int64 countOnlineSession int64 credential int64 asset int64 ) - if constant.TypeUser == account.Type { - countUser, _ = userRepository.CountOnlineUser() - countOnlineSession, _ = sessionRepository.CountOnlineSession() - credential, _ = credentialRepository.CountByUserId(account.ID) - asset, _ = assetRepository.CountByUserId(account.ID) - } else { - countUser, _ = userRepository.CountOnlineUser() - countOnlineSession, _ = sessionRepository.CountOnlineSession() - credential, _ = credentialRepository.Count() - asset, _ = assetRepository.Count() - } - counter := Counter{ + countUser, _ = repository.UserRepository.CountOnlineUser(context.TODO()) + countOnlineSession, _ = repository.SessionRepository.CountOnlineSession(context.TODO()) + credential, _ = repository.CredentialRepository.Count(context.TODO()) + asset, _ = repository.AssetRepository.Count(context.TODO()) + + counter := dto.Counter{ User: countUser, OnlineSession: countOnlineSession, Credential: credential, @@ -43,8 +34,7 @@ func OverviewCounterEndPoint(c echo.Context) error { return Success(c, counter) } -func OverviewAssetEndPoint(c echo.Context) error { - account, _ := GetCurrentAccount(c) +func (api OverviewApi) OverviewAssetEndPoint(c echo.Context) error { var ( ssh int64 rdp int64 @@ -52,19 +42,13 @@ func OverviewAssetEndPoint(c echo.Context) error { telnet int64 kubernetes int64 ) - if constant.TypeUser == account.Type { - ssh, _ = assetRepository.CountByUserIdAndProtocol(account.ID, constant.SSH) - rdp, _ = assetRepository.CountByUserIdAndProtocol(account.ID, constant.RDP) - vnc, _ = assetRepository.CountByUserIdAndProtocol(account.ID, constant.VNC) - telnet, _ = assetRepository.CountByUserIdAndProtocol(account.ID, constant.Telnet) - kubernetes, _ = assetRepository.CountByUserIdAndProtocol(account.ID, constant.K8s) - } else { - ssh, _ = assetRepository.CountByProtocol(constant.SSH) - rdp, _ = assetRepository.CountByProtocol(constant.RDP) - vnc, _ = assetRepository.CountByProtocol(constant.VNC) - telnet, _ = assetRepository.CountByProtocol(constant.Telnet) - kubernetes, _ = assetRepository.CountByProtocol(constant.K8s) - } + + ssh, _ = repository.AssetRepository.CountByProtocol(context.TODO(), constant.SSH) + rdp, _ = repository.AssetRepository.CountByProtocol(context.TODO(), constant.RDP) + vnc, _ = repository.AssetRepository.CountByProtocol(context.TODO(), constant.VNC) + telnet, _ = repository.AssetRepository.CountByProtocol(context.TODO(), constant.Telnet) + kubernetes, _ = repository.AssetRepository.CountByProtocol(context.TODO(), constant.K8s) + m := echo.Map{ "ssh": ssh, "rdp": rdp, @@ -75,9 +59,8 @@ func OverviewAssetEndPoint(c echo.Context) error { return Success(c, m) } -func OverviewAccessEndPoint(c echo.Context) error { - account, _ := GetCurrentAccount(c) - access, err := sessionRepository.OverviewAccess(account) +func (api OverviewApi) OverviewAccessEndPoint(c echo.Context) error { + access, err := repository.SessionRepository.OverviewAccess(context.TODO()) if err != nil { return err } diff --git a/server/api/property.go b/server/api/property.go index f5e6310..bb2a71e 100644 --- a/server/api/property.go +++ b/server/api/property.go @@ -1,47 +1,29 @@ package api import ( - "errors" - "fmt" + "context" - "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/service" "github.com/labstack/echo/v4" - "gorm.io/gorm" ) -func PropertyGetEndpoint(c echo.Context) error { - properties := propertyRepository.FindAllMap() +type PropertyApi struct{} + +func (api PropertyApi) PropertyGetEndpoint(c echo.Context) error { + properties := repository.PropertyRepository.FindAllMap(context.TODO()) return Success(c, properties) } -func PropertyUpdateEndpoint(c echo.Context) error { +func (api PropertyApi) PropertyUpdateEndpoint(c echo.Context) error { var item map[string]interface{} if err := c.Bind(&item); err != nil { return err } - for key := range item { - value := fmt.Sprintf("%v", item[key]) - if value == "" { - value = "-" - } - - property := model.Property{ - Name: key, - Value: value, - } - - _, err := propertyRepository.FindByName(key) - if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { - if err := propertyRepository.Create(&property); err != nil { - return err - } - } else { - if err := propertyRepository.UpdateByName(&property, key); err != nil { - return err - } - } + if err := service.PropertyService.Update(item); err != nil { + return err } return Success(c, nil) } diff --git a/server/api/resource-sharer.go b/server/api/resource-sharer.go index 3ecaa91..fa01985 100644 --- a/server/api/resource-sharer.go +++ b/server/api/resource-sharer.go @@ -1,55 +1,48 @@ package api import ( + "context" + + "next-terminal/server/dto" + "next-terminal/server/repository" + "github.com/labstack/echo/v4" ) -type RU struct { - UserGroupId string `json:"userGroupId"` - UserId string `json:"userId"` - StrategyId string `json:"strategyId"` - ResourceType string `json:"resourceType"` - ResourceIds []string `json:"resourceIds"` -} +type ResourceSharerApi struct{} -type UR struct { - ResourceId string `json:"resourceId"` - ResourceType string `json:"resourceType"` - UserIds []string `json:"userIds"` -} - -func RSGetSharersEndPoint(c echo.Context) error { +func (api ResourceSharerApi) RSGetSharersEndPoint(c echo.Context) error { resourceId := c.QueryParam("resourceId") resourceType := c.QueryParam("resourceType") userId := c.QueryParam("userId") userGroupId := c.QueryParam("userGroupId") - userIds, err := resourceSharerRepository.Find(resourceId, resourceType, userId, userGroupId) + userIds, err := repository.ResourceSharerRepository.Find(context.TODO(), resourceId, resourceType, userId, userGroupId) if err != nil { return err } return Success(c, userIds) } -func ResourceRemoveByUserIdAssignEndPoint(c echo.Context) error { - var ru RU +func (api ResourceSharerApi) ResourceRemoveByUserIdAssignEndPoint(c echo.Context) error { + var ru dto.RU if err := c.Bind(&ru); err != nil { return err } - if err := resourceSharerRepository.DeleteByUserIdAndResourceTypeAndResourceIdIn(ru.UserGroupId, ru.UserId, ru.ResourceType, ru.ResourceIds); err != nil { + if err := repository.ResourceSharerRepository.DeleteByUserIdAndResourceTypeAndResourceIdIn(context.TODO(), ru.UserGroupId, ru.UserId, ru.ResourceType, ru.ResourceIds); err != nil { return err } return Success(c, "") } -func ResourceAddByUserIdAssignEndPoint(c echo.Context) error { - var ru RU +func (api ResourceSharerApi) ResourceAddByUserIdAssignEndPoint(c echo.Context) error { + var ru dto.RU if err := c.Bind(&ru); err != nil { return err } - if err := resourceSharerRepository.AddSharerResources(ru.UserGroupId, ru.UserId, ru.StrategyId, ru.ResourceType, ru.ResourceIds); err != nil { + if err := repository.ResourceSharerRepository.AddSharerResources(ru.UserGroupId, ru.UserId, ru.StrategyId, ru.ResourceType, ru.ResourceIds); err != nil { return err } diff --git a/server/api/routes.go b/server/api/routes.go deleted file mode 100644 index f6ecfdd..0000000 --- a/server/api/routes.go +++ /dev/null @@ -1,484 +0,0 @@ -package api - -import ( - "crypto/md5" - "fmt" - "net/http" - "os" - - "next-terminal/server/config" - "next-terminal/server/global/cache" - "next-terminal/server/log" - "next-terminal/server/model" - "next-terminal/server/repository" - "next-terminal/server/service" - "next-terminal/server/utils" - - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - "gorm.io/driver/mysql" - "gorm.io/driver/sqlite" - "gorm.io/gorm" - "gorm.io/gorm/logger" -) - -var ( - userRepository *repository.UserRepository - userGroupRepository *repository.UserGroupRepository - resourceSharerRepository *repository.ResourceSharerRepository - assetRepository *repository.AssetRepository - credentialRepository *repository.CredentialRepository - propertyRepository *repository.PropertyRepository - commandRepository *repository.CommandRepository - sessionRepository *repository.SessionRepository - accessSecurityRepository *repository.AccessSecurityRepository - accessGatewayRepository *repository.AccessGatewayRepository - jobRepository *repository.JobRepository - jobLogRepository *repository.JobLogRepository - loginLogRepository *repository.LoginLogRepository - storageRepository *repository.StorageRepository - strategyRepository *repository.StrategyRepository - - jobService *service.JobService - propertyService *service.PropertyService - userService *service.UserService - sessionService *service.SessionService - mailService *service.MailService - assetService *service.AssetService - credentialService *service.CredentialService - storageService *service.StorageService - accessGatewayService *service.AccessGatewayService -) - -func SetupRoutes(db *gorm.DB) *echo.Echo { - - InitRepository(db) - InitService() - - cache.GlobalCache.OnEvicted(userService.OnEvicted) - - if err := InitDBData(); err != nil { - log.Errorf("初始化数据异常: %v", err.Error()) - os.Exit(0) - } - - if err := ReloadData(); err != nil { - return nil - } - - e := echo.New() - e.HideBanner = true - //e.Logger = log.GetEchoLogger() - //e.Use(log.Hook()) - e.File("/", "web/build/index.html") - e.File("/asciinema.html", "web/build/asciinema.html") - e.File("/", "web/build/index.html") - e.File("/favicon.ico", "web/build/favicon.ico") - e.Static("/static", "web/build/static") - - e.Use(middleware.Recover()) - e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ - Skipper: middleware.DefaultSkipper, - AllowOrigins: []string{"*"}, - AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, - })) - e.Use(ErrorHandler) - e.Use(TcpWall) - e.Use(Auth) - - e.POST("/login", LoginEndpoint) - e.POST("/loginWithTotp", loginWithTotpEndpoint) - - e.GET("/tunnel", TunEndpoint) - e.GET("/ssh", SSHEndpoint) - e.GET("/ssh-monitor", SshMonitor) - e.POST("/logout", LogoutEndpoint) - e.POST("/change-password", ChangePasswordEndpoint) - e.GET("/reload-totp", ReloadTOTPEndpoint) - e.POST("/reset-totp", ResetTOTPEndpoint) - e.POST("/confirm-totp", ConfirmTOTPEndpoint) - e.GET("/info", InfoEndpoint) - - account := e.Group("/account") - { - account.GET("/assets", AccountAssetEndpoint) - account.GET("/storage", AccountStorageEndpoint) - } - - users := e.Group("/users", Admin) - { - users.POST("", UserCreateEndpoint) - users.GET("/paging", UserPagingEndpoint) - users.PUT("/:id", UserUpdateEndpoint) - users.PATCH("/:id/status", UserUpdateStatusEndpoint) - users.DELETE("/:id", UserDeleteEndpoint) - users.GET("/:id", UserGetEndpoint) - users.POST("/:id/change-password", UserChangePasswordEndpoint) - users.POST("/:id/reset-totp", UserResetTotpEndpoint) - } - - userGroups := e.Group("/user-groups", Admin) - { - userGroups.POST("", UserGroupCreateEndpoint) - userGroups.GET("/paging", UserGroupPagingEndpoint) - userGroups.PUT("/:id", UserGroupUpdateEndpoint) - userGroups.DELETE("/:id", UserGroupDeleteEndpoint) - userGroups.GET("/:id", UserGroupGetEndpoint) - } - - assets := e.Group("/assets", Admin) - { - assets.GET("", AssetAllEndpoint) - assets.POST("", AssetCreateEndpoint) - assets.POST("/import", AssetImportEndpoint) - assets.GET("/paging", AssetPagingEndpoint) - assets.POST("/:id/tcping", AssetTcpingEndpoint) - assets.PUT("/:id", AssetUpdateEndpoint) - assets.GET("/:id", AssetGetEndpoint) - assets.DELETE("/:id", AssetDeleteEndpoint) - assets.POST("/:id/change-owner", AssetChangeOwnerEndpoint) - } - - e.GET("/tags", AssetTagsEndpoint) - - commands := e.Group("/commands") - { - commands.GET("", CommandAllEndpoint) - commands.GET("/paging", CommandPagingEndpoint) - commands.POST("", CommandCreateEndpoint) - commands.PUT("/:id", CommandUpdateEndpoint) - commands.DELETE("/:id", CommandDeleteEndpoint) - commands.GET("/:id", CommandGetEndpoint) - commands.POST("/:id/change-owner", CommandChangeOwnerEndpoint, Admin) - } - - credentials := e.Group("/credentials", Admin) - { - credentials.GET("", CredentialAllEndpoint) - credentials.GET("/paging", CredentialPagingEndpoint) - credentials.POST("", CredentialCreateEndpoint) - credentials.PUT("/:id", CredentialUpdateEndpoint) - credentials.DELETE("/:id", CredentialDeleteEndpoint) - credentials.GET("/:id", CredentialGetEndpoint) - credentials.POST("/:id/change-owner", CredentialChangeOwnerEndpoint) - } - - sessions := e.Group("/sessions") - { - sessions.GET("/paging", Admin(SessionPagingEndpoint)) - sessions.POST("/:id/disconnect", Admin(SessionDisconnectEndpoint)) - sessions.DELETE("/:id", Admin(SessionDeleteEndpoint)) - sessions.GET("/:id/recording", Admin(SessionRecordingEndpoint)) - sessions.GET("/:id", Admin(SessionGetEndpoint)) - sessions.POST("/:id/reviewed", Admin(SessionReviewedEndpoint)) - sessions.POST("/:id/unreviewed", Admin(SessionUnViewedEndpoint)) - sessions.POST("/clear", Admin(SessionClearEndpoint)) - sessions.POST("/reviewed", Admin(SessionReviewedAllEndpoint)) - - sessions.POST("", SessionCreateEndpoint) - sessions.POST("/:id/connect", SessionConnectEndpoint) - sessions.POST("/:id/resize", SessionResizeEndpoint) - sessions.GET("/:id/stats", SessionStatsEndpoint) - - sessions.POST("/:id/ls", SessionLsEndpoint) - sessions.GET("/:id/download", SessionDownloadEndpoint) - sessions.POST("/:id/upload", SessionUploadEndpoint) - sessions.POST("/:id/edit", SessionEditEndpoint) - sessions.POST("/:id/mkdir", SessionMkDirEndpoint) - sessions.POST("/:id/rm", SessionRmEndpoint) - sessions.POST("/:id/rename", SessionRenameEndpoint) - } - - resourceSharers := e.Group("/resource-sharers", Admin) - { - resourceSharers.GET("", RSGetSharersEndPoint) - resourceSharers.POST("/remove-resources", ResourceRemoveByUserIdAssignEndPoint) - resourceSharers.POST("/add-resources", ResourceAddByUserIdAssignEndPoint) - } - - loginLogs := e.Group("login-logs", Admin) - { - loginLogs.GET("/paging", LoginLogPagingEndpoint) - loginLogs.DELETE("/:id", LoginLogDeleteEndpoint) - loginLogs.POST("/clear", LoginLogClearEndpoint) - } - - e.GET("/properties", Admin(PropertyGetEndpoint)) - e.PUT("/properties", Admin(PropertyUpdateEndpoint)) - - overview := e.Group("overview", Admin) - { - overview.GET("/counter", OverviewCounterEndPoint) - overview.GET("/asset", OverviewAssetEndPoint) - overview.GET("/access", OverviewAccessEndPoint) - } - - jobs := e.Group("/jobs", Admin) - { - jobs.POST("", JobCreateEndpoint) - jobs.GET("/paging", JobPagingEndpoint) - jobs.PUT("/:id", JobUpdateEndpoint) - jobs.POST("/:id/change-status", JobChangeStatusEndpoint) - jobs.POST("/:id/exec", JobExecEndpoint) - jobs.DELETE("/:id", JobDeleteEndpoint) - jobs.GET("/:id", JobGetEndpoint) - jobs.GET("/:id/logs", JobGetLogsEndpoint) - jobs.DELETE("/:id/logs", JobDeleteLogsEndpoint) - } - - securities := e.Group("/securities", Admin) - { - securities.POST("", SecurityCreateEndpoint) - securities.GET("/paging", SecurityPagingEndpoint) - securities.PUT("/:id", SecurityUpdateEndpoint) - securities.DELETE("/:id", SecurityDeleteEndpoint) - securities.GET("/:id", SecurityGetEndpoint) - } - - storages := e.Group("/storages") - { - storages.GET("/paging", StoragePagingEndpoint, Admin) - storages.POST("", StorageCreateEndpoint, Admin) - storages.DELETE("/:id", StorageDeleteEndpoint, Admin) - storages.PUT("/:id", StorageUpdateEndpoint, Admin) - storages.GET("/shares", StorageSharesEndpoint, Admin) - storages.GET("/:id", StorageGetEndpoint, Admin) - - storages.POST("/:storageId/ls", StorageLsEndpoint) - storages.GET("/:storageId/download", StorageDownloadEndpoint) - storages.POST("/:storageId/upload", StorageUploadEndpoint) - storages.POST("/:storageId/mkdir", StorageMkDirEndpoint) - storages.POST("/:storageId/rm", StorageRmEndpoint) - storages.POST("/:storageId/rename", StorageRenameEndpoint) - storages.POST("/:storageId/edit", StorageEditEndpoint) - } - - strategies := e.Group("/strategies", Admin) - { - strategies.GET("", StrategyAllEndpoint) - strategies.GET("/paging", StrategyPagingEndpoint) - strategies.POST("", StrategyCreateEndpoint) - strategies.DELETE("/:id", StrategyDeleteEndpoint) - strategies.PUT("/:id", StrategyUpdateEndpoint) - } - - accessGateways := e.Group("/access-gateways", Admin) - { - accessGateways.GET("", AccessGatewayAllEndpoint) - accessGateways.POST("", AccessGatewayCreateEndpoint) - accessGateways.GET("/paging", AccessGatewayPagingEndpoint) - accessGateways.PUT("/:id", AccessGatewayUpdateEndpoint) - accessGateways.DELETE("/:id", AccessGatewayDeleteEndpoint) - accessGateways.GET("/:id", AccessGatewayGetEndpoint) - accessGateways.POST("/:id/reconnect", AccessGatewayReconnectEndpoint) - } - - backup := e.Group("/backup", Admin) - { - backup.GET("/export", BackupExportEndpoint) - backup.POST("/import", BackupImportEndpoint) - } - - return e -} - -func ReloadData() error { - if err := ReloadAccessSecurity(); err != nil { - return err - } - - if err := ReloadToken(); err != nil { - return err - } - return nil -} - -func InitRepository(db *gorm.DB) { - userRepository = repository.NewUserRepository(db) - userGroupRepository = repository.NewUserGroupRepository(db) - resourceSharerRepository = repository.NewResourceSharerRepository(db) - assetRepository = repository.NewAssetRepository(db) - credentialRepository = repository.NewCredentialRepository(db) - propertyRepository = repository.NewPropertyRepository(db) - commandRepository = repository.NewCommandRepository(db) - sessionRepository = repository.NewSessionRepository(db) - accessSecurityRepository = repository.NewAccessSecurityRepository(db) - accessGatewayRepository = repository.NewAccessGatewayRepository(db) - jobRepository = repository.NewJobRepository(db) - jobLogRepository = repository.NewJobLogRepository(db) - loginLogRepository = repository.NewLoginLogRepository(db) - storageRepository = repository.NewStorageRepository(db) - strategyRepository = repository.NewStrategyRepository(db) -} - -func InitService() { - propertyService = service.NewPropertyService(propertyRepository) - userService = service.NewUserService(userRepository, loginLogRepository) - sessionService = service.NewSessionService(sessionRepository) - mailService = service.NewMailService(propertyRepository) - assetService = service.NewAssetService(assetRepository) - jobService = service.NewJobService(jobRepository, jobLogRepository, assetRepository, credentialRepository, assetService) - credentialService = service.NewCredentialService(credentialRepository) - storageService = service.NewStorageService(storageRepository, userRepository, propertyRepository) - accessGatewayService = service.NewAccessGatewayService(accessGatewayRepository) -} - -func InitDBData() (err error) { - if err := propertyService.DeleteDeprecatedProperty(); err != nil { - return err - } - if err := accessGatewayService.ReConnectAll(); err != nil { - return err - } - if err := propertyService.InitProperties(); err != nil { - return err - } - if err := userService.InitUser(); err != nil { - return err - } - if err := jobService.InitJob(); err != nil { - return err - } - if err := userService.FixUserOnlineState(); err != nil { - return err - } - if err := sessionService.FixSessionState(); err != nil { - return err - } - if err := sessionService.EmptyPassword(); err != nil { - return err - } - if err := credentialService.Encrypt(); err != nil { - return err - } - if err := assetService.Encrypt(); err != nil { - return err - } - if err := storageService.InitStorages(); err != nil { - return err - } - - return nil -} - -func ResetPassword(username string) error { - user, err := userRepository.FindByUsername(username) - if err != nil { - return err - } - password := "next-terminal" - passwd, err := utils.Encoder.Encode([]byte(password)) - if err != nil { - return err - } - u := &model.User{ - Password: string(passwd), - ID: user.ID, - } - if err := userRepository.Update(u); err != nil { - return err - } - log.Debugf("用户「%v」密码初始化为: %v", user.Username, password) - return nil -} - -func ResetTotp(username string) error { - user, err := userRepository.FindByUsername(username) - if err != nil { - return err - } - u := &model.User{ - TOTPSecret: "-", - ID: user.ID, - } - if err := userRepository.Update(u); err != nil { - return err - } - log.Debugf("用户「%v」已重置TOTP", user.Username) - return nil -} - -func ChangeEncryptionKey(oldEncryptionKey, newEncryptionKey string) error { - - oldPassword := []byte(fmt.Sprintf("%x", md5.Sum([]byte(oldEncryptionKey)))) - newPassword := []byte(fmt.Sprintf("%x", md5.Sum([]byte(newEncryptionKey)))) - - credentials, err := credentialRepository.FindAll() - if err != nil { - return err - } - for i := range credentials { - credential := credentials[i] - if err := credentialRepository.Decrypt(&credential, oldPassword); err != nil { - return err - } - if err := credentialRepository.Encrypt(&credential, newPassword); err != nil { - return err - } - if err := credentialRepository.UpdateById(&credential, credential.ID); err != nil { - return err - } - } - assets, err := assetRepository.FindAll() - if err != nil { - return err - } - for i := range assets { - asset := assets[i] - if err := assetRepository.Decrypt(&asset, oldPassword); err != nil { - return err - } - if err := assetRepository.Encrypt(&asset, newPassword); err != nil { - return err - } - if err := assetRepository.UpdateById(&asset, asset.ID); err != nil { - return err - } - } - log.Infof("encryption key has being changed.") - return nil -} - -func SetupDB() *gorm.DB { - - var logMode logger.Interface - if config.GlobalCfg.Debug { - logMode = logger.Default.LogMode(logger.Info) - } else { - logMode = logger.Default.LogMode(logger.Silent) - } - - fmt.Printf("当前数据库模式为:%v\n", config.GlobalCfg.DB) - var err error - var db *gorm.DB - if config.GlobalCfg.DB == "mysql" { - dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=60s", - config.GlobalCfg.Mysql.Username, - config.GlobalCfg.Mysql.Password, - config.GlobalCfg.Mysql.Hostname, - config.GlobalCfg.Mysql.Port, - config.GlobalCfg.Mysql.Database, - ) - db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ - Logger: logMode, - }) - } else { - db, err = gorm.Open(sqlite.Open(config.GlobalCfg.Sqlite.File), &gorm.Config{ - Logger: logMode, - }) - } - - if err != nil { - log.Errorf("连接数据库异常: %v", err.Error()) - os.Exit(0) - } - - if err := db.AutoMigrate(&model.User{}, &model.Asset{}, &model.AssetAttribute{}, &model.Session{}, &model.Command{}, - &model.Credential{}, &model.Property{}, &model.ResourceSharer{}, &model.UserGroup{}, &model.UserGroupMember{}, - &model.LoginLog{}, &model.Job{}, &model.JobLog{}, &model.AccessSecurity{}, &model.AccessGateway{}, - &model.Storage{}, &model.Strategy{}); err != nil { - log.Errorf("初始化数据库表结构异常: %v", err.Error()) - os.Exit(0) - } - return db -} diff --git a/server/api/security.go b/server/api/security.go index 7378323..abed4b9 100644 --- a/server/api/security.go +++ b/server/api/security.go @@ -1,17 +1,22 @@ package api import ( + "context" + "strconv" "strings" "next-terminal/server/global/security" "next-terminal/server/model" + "next-terminal/server/repository" "next-terminal/server/utils" "github.com/labstack/echo/v4" ) -func SecurityCreateEndpoint(c echo.Context) error { +type SecurityApi struct{} + +func (api SecurityApi) SecurityCreateEndpoint(c echo.Context) error { var item model.AccessSecurity if err := c.Bind(&item); err != nil { return err @@ -20,7 +25,7 @@ func SecurityCreateEndpoint(c echo.Context) error { item.ID = utils.UUID() item.Source = "管理员添加" - if err := accessSecurityRepository.Create(&item); err != nil { + if err := repository.SecurityRepository.Create(context.TODO(), &item); err != nil { return err } // 更新内存中的安全规则 @@ -35,29 +40,7 @@ func SecurityCreateEndpoint(c echo.Context) error { return Success(c, "") } -func ReloadAccessSecurity() error { - rules, err := accessSecurityRepository.FindAll() - if err != nil { - return err - } - if len(rules) > 0 { - // 先清空 - security.GlobalSecurityManager.Clear() - // 再添加到全局的安全管理器中 - for i := 0; i < len(rules); i++ { - rule := &security.Security{ - ID: rules[i].ID, - IP: rules[i].IP, - Rule: rules[i].Rule, - Priority: rules[i].Priority, - } - security.GlobalSecurityManager.Add <- rule - } - } - return nil -} - -func SecurityPagingEndpoint(c echo.Context) error { +func (api SecurityApi) SecurityPagingEndpoint(c echo.Context) error { pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex")) pageSize, _ := strconv.Atoi(c.QueryParam("pageSize")) ip := c.QueryParam("ip") @@ -66,18 +49,18 @@ func SecurityPagingEndpoint(c echo.Context) error { order := c.QueryParam("order") field := c.QueryParam("field") - items, total, err := accessSecurityRepository.Find(pageIndex, pageSize, ip, rule, order, field) + items, total, err := repository.SecurityRepository.Find(context.TODO(), pageIndex, pageSize, ip, rule, order, field) if err != nil { return err } - return Success(c, H{ + return Success(c, Map{ "total": total, "items": items, }) } -func SecurityUpdateEndpoint(c echo.Context) error { +func (api SecurityApi) SecurityUpdateEndpoint(c echo.Context) error { id := c.Param("id") var item model.AccessSecurity @@ -85,7 +68,7 @@ func SecurityUpdateEndpoint(c echo.Context) error { return err } - if err := accessSecurityRepository.UpdateById(&item, id); err != nil { + if err := repository.SecurityRepository.UpdateById(context.TODO(), &item, id); err != nil { return err } // 更新内存中的安全规则 @@ -101,13 +84,13 @@ func SecurityUpdateEndpoint(c echo.Context) error { return Success(c, nil) } -func SecurityDeleteEndpoint(c echo.Context) error { +func (api SecurityApi) SecurityDeleteEndpoint(c echo.Context) error { ids := c.Param("id") split := strings.Split(ids, ",") for i := range split { id := split[i] - if err := accessSecurityRepository.DeleteById(id); err != nil { + if err := repository.SecurityRepository.DeleteById(context.TODO(), id); err != nil { return err } // 更新内存中的安全规则 @@ -117,10 +100,10 @@ func SecurityDeleteEndpoint(c echo.Context) error { return Success(c, nil) } -func SecurityGetEndpoint(c echo.Context) error { +func (api SecurityApi) SecurityGetEndpoint(c echo.Context) error { id := c.Param("id") - item, err := accessSecurityRepository.FindById(id) + item, err := repository.SecurityRepository.FindById(context.TODO(), id) if err != nil { return err } diff --git a/server/api/session.go b/server/api/session.go index 651790b..631e7cd 100644 --- a/server/api/session.go +++ b/server/api/session.go @@ -3,6 +3,7 @@ package api import ( "bufio" "bytes" + "context" "errors" "fmt" "io" @@ -11,23 +12,22 @@ import ( "path" "strconv" "strings" - "sync" "next-terminal/server/constant" "next-terminal/server/global/session" - "next-terminal/server/guacd" "next-terminal/server/log" "next-terminal/server/model" + "next-terminal/server/repository" "next-terminal/server/service" "next-terminal/server/utils" - "github.com/gorilla/websocket" "github.com/labstack/echo/v4" "github.com/pkg/sftp" - "gorm.io/gorm" ) -func SessionPagingEndpoint(c echo.Context) error { +type SessionApi struct{} + +func (api SessionApi) SessionPagingEndpoint(c echo.Context) error { pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex")) pageSize, _ := strconv.Atoi(c.QueryParam("pageSize")) status := c.QueryParam("status") @@ -37,7 +37,7 @@ func SessionPagingEndpoint(c echo.Context) error { protocol := c.QueryParam("protocol") reviewed := c.QueryParam("reviewed") - items, total, err := sessionRepository.Find(pageIndex, pageSize, status, userId, clientIp, assetId, protocol, reviewed) + items, total, err := repository.SessionRepository.Find(context.TODO(), pageIndex, pageSize, status, userId, clientIp, assetId, protocol, reviewed) if err != nil { return err @@ -63,15 +63,15 @@ func SessionPagingEndpoint(c echo.Context) error { } } - return Success(c, H{ + return Success(c, Map{ "total": total, "items": items, }) } -func SessionDeleteEndpoint(c echo.Context) error { +func (api SessionApi) SessionDeleteEndpoint(c echo.Context) error { sessionIds := strings.Split(c.Param("id"), ",") - err := sessionRepository.DeleteByIds(sessionIds) + err := repository.SessionRepository.DeleteByIds(context.TODO(), sessionIds) if err != nil { return err } @@ -79,38 +79,38 @@ func SessionDeleteEndpoint(c echo.Context) error { return Success(c, nil) } -func SessionClearEndpoint(c echo.Context) error { - err := sessionService.ClearOfflineSession() +func (api SessionApi) SessionClearEndpoint(c echo.Context) error { + err := service.SessionService.ClearOfflineSession() if err != nil { return err } return Success(c, nil) } -func SessionReviewedEndpoint(c echo.Context) error { +func (api SessionApi) SessionReviewedEndpoint(c echo.Context) error { sessionIds := strings.Split(c.Param("id"), ",") - if err := sessionRepository.UpdateReadByIds(true, sessionIds); err != nil { + if err := repository.SessionRepository.UpdateReadByIds(context.TODO(), true, sessionIds); err != nil { return err } return Success(c, nil) } -func SessionUnViewedEndpoint(c echo.Context) error { +func (api SessionApi) SessionUnViewedEndpoint(c echo.Context) error { sessionIds := strings.Split(c.Param("id"), ",") - if err := sessionRepository.UpdateReadByIds(false, sessionIds); err != nil { + if err := repository.SessionRepository.UpdateReadByIds(context.TODO(), false, sessionIds); err != nil { return err } return Success(c, nil) } -func SessionReviewedAllEndpoint(c echo.Context) error { - if err := sessionService.ReviewedAll(); err != nil { +func (api SessionApi) SessionReviewedAllEndpoint(c echo.Context) error { + if err := service.SessionService.ReviewedAll(); err != nil { return err } return Success(c, nil) } -func SessionConnectEndpoint(c echo.Context) error { +func (api SessionApi) SessionConnectEndpoint(c echo.Context) error { sessionId := c.Param("id") s := model.Session{} @@ -118,112 +118,37 @@ func SessionConnectEndpoint(c echo.Context) error { s.Status = constant.Connected s.ConnectedTime = utils.NowJsonTime() - if err := sessionRepository.UpdateById(&s, sessionId); err != nil { + if err := repository.SessionRepository.UpdateById(context.TODO(), &s, sessionId); err != nil { return err } - o, err := sessionRepository.FindById(sessionId) + o, err := repository.SessionRepository.FindById(context.TODO(), sessionId) if err != nil { return err } - asset, err := assetRepository.FindById(o.AssetId) + asset, err := repository.AssetRepository.FindById(context.TODO(), o.AssetId) if err != nil { return err } if !asset.Active { asset.Active = true - _ = assetRepository.UpdateById(&asset, asset.ID) + _ = repository.AssetRepository.UpdateById(context.TODO(), &asset, asset.ID) } return Success(c, nil) } -func SessionDisconnectEndpoint(c echo.Context) error { +func (api SessionApi) SessionDisconnectEndpoint(c echo.Context) error { sessionIds := c.Param("id") split := strings.Split(sessionIds, ",") for i := range split { - CloseSessionById(split[i], ForcedDisconnect, "管理员强制关闭了此会话") + service.SessionService.CloseSessionById(split[i], ForcedDisconnect, "管理员强制关闭了此会话") } return Success(c, nil) } -var mutex sync.Mutex - -func CloseSessionById(sessionId string, code int, reason string) { - mutex.Lock() - defer mutex.Unlock() - nextSession := session.GlobalSessionManager.GetById(sessionId) - if nextSession != nil { - log.Debugf("[%v] 会话关闭,原因:%v", sessionId, reason) - WriteCloseMessage(nextSession.WebSocket, nextSession.Mode, code, reason) - - if nextSession.Observer != nil { - obs := nextSession.Observer.All() - for _, ob := range obs { - WriteCloseMessage(ob.WebSocket, ob.Mode, code, reason) - log.Debugf("[%v] 强制踢出会话的观察者: %v", sessionId, ob.ID) - } - } - } - session.GlobalSessionManager.Del <- sessionId - - DisDBSess(sessionId, code, reason) -} - -func WriteCloseMessage(ws *websocket.Conn, mode string, code int, reason string) { - switch mode { - case constant.Guacd: - if ws != nil { - err := guacd.NewInstruction("error", "", strconv.Itoa(code)) - _ = ws.WriteMessage(websocket.TextMessage, []byte(err.String())) - disconnect := guacd.NewInstruction("disconnect") - _ = ws.WriteMessage(websocket.TextMessage, []byte(disconnect.String())) - } - case constant.Naive: - if ws != nil { - msg := `0` + reason - _ = ws.WriteMessage(websocket.TextMessage, []byte(msg)) - } - case constant.Terminal: - // 这里是关闭观察者的ssh会话 - if ws != nil { - msg := `0` + reason - _ = ws.WriteMessage(websocket.TextMessage, []byte(msg)) - } - } -} - -func DisDBSess(sessionId string, code int, reason string) { - s, err := sessionRepository.FindById(sessionId) - if err != nil { - return - } - - if s.Status == constant.Disconnected { - return - } - - if s.Status == constant.Connecting { - // 会话还未建立成功,无需保留数据 - _ = sessionRepository.DeleteById(sessionId) - return - } - - ss := model.Session{} - ss.ID = sessionId - ss.Status = constant.Disconnected - ss.DisconnectedTime = utils.NowJsonTime() - ss.Code = code - ss.Message = reason - ss.Password = "-" - ss.PrivateKey = "-" - ss.Passphrase = "-" - - _ = sessionRepository.UpdateById(&ss, sessionId) -} - -func SessionResizeEndpoint(c echo.Context) error { +func (api SessionApi) SessionResizeEndpoint(c echo.Context) error { width := c.QueryParam("width") height := c.QueryParam("height") sessionId := c.Param("id") @@ -235,13 +160,13 @@ func SessionResizeEndpoint(c echo.Context) error { intWidth, _ := strconv.Atoi(width) intHeight, _ := strconv.Atoi(height) - if err := sessionRepository.UpdateWindowSizeById(intWidth, intHeight, sessionId); err != nil { + if err := repository.SessionRepository.UpdateWindowSizeById(context.TODO(), intWidth, intHeight, sessionId); err != nil { return err } return Success(c, "") } -func SessionCreateEndpoint(c echo.Context) error { +func (api SessionApi) SessionCreateEndpoint(c echo.Context) error { assetId := c.QueryParam("assetId") mode := c.QueryParam("mode") @@ -253,106 +178,11 @@ func SessionCreateEndpoint(c echo.Context) error { user, _ := GetCurrentAccount(c) - asset, err := assetRepository.FindById(assetId) + s, err := service.SessionService.Create(c.RealIP(), assetId, mode, user) if err != nil { return err } - var ( - upload = "1" - download = "1" - _delete = "1" - rename = "1" - edit = "1" - fileSystem = "1" - ) - if asset.Owner != user.ID && constant.TypeUser == user.Type { - // 普通用户访问非自己创建的资产需要校验权限 - resourceSharers, err := resourceSharerRepository.FindByResourceIdAndUserId(assetId, user.ID) - if err != nil { - return err - } - if len(resourceSharers) == 0 { - return errors.New("您没有权限访问此资产") - } - strategyId := resourceSharers[0].StrategyId - if strategyId != "" { - strategy, err := strategyRepository.FindById(strategyId) - if err != nil { - if !errors.Is(gorm.ErrRecordNotFound, err) { - return err - } - } else { - upload = strategy.Upload - download = strategy.Download - _delete = strategy.Delete - rename = strategy.Rename - edit = strategy.Edit - } - } - } - - var storageId = "" - if constant.RDP == asset.Protocol { - attr, err := assetRepository.FindAssetAttrMapByAssetId(assetId) - if err != nil { - return err - } - if "true" == attr[guacd.EnableDrive] { - fileSystem = "1" - storageId = attr[guacd.DrivePath] - if storageId == "" { - storageId = user.ID - } - } else { - fileSystem = "0" - } - } - - s := &model.Session{ - ID: utils.UUID(), - AssetId: asset.ID, - Username: asset.Username, - Password: asset.Password, - PrivateKey: asset.PrivateKey, - Passphrase: asset.Passphrase, - Protocol: asset.Protocol, - IP: asset.IP, - Port: asset.Port, - Status: constant.NoConnect, - Creator: user.ID, - ClientIP: c.RealIP(), - Mode: mode, - Upload: upload, - Download: download, - Delete: _delete, - Rename: rename, - Edit: edit, - StorageId: storageId, - AccessGatewayId: asset.AccessGatewayId, - Reviewed: false, - } - - if asset.AccountType == "credential" { - credential, err := credentialRepository.FindById(asset.CredentialId) - if err != nil { - return err - } - - if credential.Type == constant.Custom { - s.Username = credential.Username - s.Password = credential.Password - } else { - s.Username = credential.Username - s.PrivateKey = credential.PrivateKey - s.Passphrase = credential.Passphrase - } - } - - if err := sessionRepository.Create(s); err != nil { - return err - } - return Success(c, echo.Map{ "id": s.ID, "upload": s.Upload, @@ -361,13 +191,15 @@ func SessionCreateEndpoint(c echo.Context) error { "rename": s.Rename, "edit": s.Edit, "storageId": s.StorageId, - "fileSystem": fileSystem, + "fileSystem": s.FileSystem, + "copy": s.Copy, + "paste": s.Paste, }) } -func SessionUploadEndpoint(c echo.Context) error { +func (api SessionApi) SessionUploadEndpoint(c echo.Context) error { sessionId := c.Param("id") - s, err := sessionRepository.FindById(sessionId) + s, err := repository.SessionRepository.FindById(context.TODO(), sessionId) if err != nil { return err } @@ -414,15 +246,18 @@ func SessionUploadEndpoint(c echo.Context) error { } return Success(c, nil) } else if "rdp" == s.Protocol { - return StorageUpload(c, file, s.StorageId) + if err := service.StorageService.StorageUpload(c, file, s.StorageId); err != nil { + return err + } + return Success(c, nil) } return err } -func SessionEditEndpoint(c echo.Context) error { +func (api SessionApi) SessionEditEndpoint(c echo.Context) error { sessionId := c.Param("id") - s, err := sessionRepository.FindById(sessionId) + s, err := repository.SessionRepository.FindById(context.TODO(), sessionId) if err != nil { return err } @@ -453,14 +288,17 @@ func SessionEditEndpoint(c echo.Context) error { } return Success(c, nil) } else if "rdp" == s.Protocol { - return StorageEdit(c, file, fileContent, s.StorageId) + if err := service.StorageService.StorageEdit(file, fileContent, s.StorageId); err != nil { + return err + } + return Success(c, nil) } return err } -func SessionDownloadEndpoint(c echo.Context) error { +func (api SessionApi) SessionDownloadEndpoint(c echo.Context) error { sessionId := c.Param("id") - s, err := sessionRepository.FindById(sessionId) + s, err := repository.SessionRepository.FindById(context.TODO(), sessionId) if err != nil { return err } @@ -492,15 +330,15 @@ func SessionDownloadEndpoint(c echo.Context) error { return c.Stream(http.StatusOK, echo.MIMEOctetStream, bytes.NewReader(buff.Bytes())) } else if "rdp" == s.Protocol { storageId := s.StorageId - return StorageDownload(c, remoteFile, storageId) + return service.StorageService.StorageDownload(c, remoteFile, storageId) } return err } -func SessionLsEndpoint(c echo.Context) error { +func (api SessionApi) SessionLsEndpoint(c echo.Context) error { sessionId := c.Param("id") - s, err := sessionRepository.FindByIdAndDecrypt(sessionId) + s, err := service.SessionService.FindByIdAndDecrypt(context.TODO(), sessionId) if err != nil { return err } @@ -550,15 +388,19 @@ func SessionLsEndpoint(c echo.Context) error { return Success(c, files) } else if "rdp" == s.Protocol { storageId := s.StorageId - return StorageLs(c, remoteDir, storageId) + err, files := service.StorageService.StorageLs(remoteDir, storageId) + if err != nil { + return err + } + return Success(c, files) } return errors.New("当前协议不支持此操作") } -func SessionMkDirEndpoint(c echo.Context) error { +func (api SessionApi) SessionMkDirEndpoint(c echo.Context) error { sessionId := c.Param("id") - s, err := sessionRepository.FindById(sessionId) + s, err := repository.SessionRepository.FindById(context.TODO(), sessionId) if err != nil { return err } @@ -576,14 +418,18 @@ func SessionMkDirEndpoint(c echo.Context) error { } return Success(c, nil) } else if "rdp" == s.Protocol { - return StorageMkDir(c, remoteDir, s.StorageId) + storageId := s.StorageId + if err := service.StorageService.StorageMkDir(remoteDir, storageId); err != nil { + return err + } + return Success(c, nil) } return errors.New("当前协议不支持此操作") } -func SessionRmEndpoint(c echo.Context) error { +func (api SessionApi) SessionRmEndpoint(c echo.Context) error { sessionId := c.Param("id") - s, err := sessionRepository.FindById(sessionId) + s, err := repository.SessionRepository.FindById(context.TODO(), sessionId) if err != nil { return err } @@ -628,15 +474,19 @@ func SessionRmEndpoint(c echo.Context) error { return Success(c, nil) } else if "rdp" == s.Protocol { - return StorageRm(c, file, s.StorageId) + storageId := s.StorageId + if err := service.StorageService.StorageRm(file, storageId); err != nil { + return err + } + return Success(c, nil) } return errors.New("当前协议不支持此操作") } -func SessionRenameEndpoint(c echo.Context) error { +func (api SessionApi) SessionRenameEndpoint(c echo.Context) error { sessionId := c.Param("id") - s, err := sessionRepository.FindById(sessionId) + s, err := repository.SessionRepository.FindById(context.TODO(), sessionId) if err != nil { return err } @@ -659,14 +509,18 @@ func SessionRenameEndpoint(c echo.Context) error { return Success(c, nil) } else if "rdp" == s.Protocol { - return StorageRename(c, oldName, newName, s.StorageId) + storageId := s.StorageId + if err := service.StorageService.StorageRename(oldName, newName, storageId); err != nil { + return err + } + return Success(c, nil) } return errors.New("当前协议不支持此操作") } -func SessionRecordingEndpoint(c echo.Context) error { +func (api SessionApi) SessionRecordingEndpoint(c echo.Context) error { sessionId := c.Param("id") - s, err := sessionRepository.FindById(sessionId) + s, err := repository.SessionRepository.FindById(context.TODO(), sessionId) if err != nil { return err } @@ -677,24 +531,24 @@ func SessionRecordingEndpoint(c echo.Context) error { } else { recording = s.Recording + "/recording" } - _ = sessionRepository.UpdateReadByIds(true, []string{sessionId}) + _ = repository.SessionRepository.UpdateReadByIds(context.TODO(), true, []string{sessionId}) log.Debugf("读取录屏文件:%v,是否存在: %v, 是否为文件: %v", recording, utils.FileExists(recording), utils.IsFile(recording)) return c.File(recording) } -func SessionGetEndpoint(c echo.Context) error { +func (api SessionApi) SessionGetEndpoint(c echo.Context) error { sessionId := c.Param("id") - s, err := sessionRepository.FindById(sessionId) + s, err := repository.SessionRepository.FindById(context.TODO(), sessionId) if err != nil { return err } return Success(c, s) } -func SessionStatsEndpoint(c echo.Context) error { +func (api SessionApi) SessionStatsEndpoint(c echo.Context) error { sessionId := c.Param("id") - s, err := sessionRepository.FindByIdAndDecrypt(sessionId) + s, err := service.SessionService.FindByIdAndDecrypt(context.TODO(), sessionId) if err != nil { return err } diff --git a/server/api/ssh.go b/server/api/ssh.go deleted file mode 100644 index e4c375d..0000000 --- a/server/api/ssh.go +++ /dev/null @@ -1,435 +0,0 @@ -package api - -import ( - "context" - "encoding/base64" - "encoding/json" - "errors" - "net/http" - "path" - "strconv" - "time" - "unicode/utf8" - - "next-terminal/server/config" - "next-terminal/server/constant" - "next-terminal/server/global/session" - "next-terminal/server/guacd" - "next-terminal/server/log" - "next-terminal/server/model" - "next-terminal/server/term" - "next-terminal/server/utils" - - "github.com/gorilla/websocket" - "github.com/labstack/echo/v4" -) - -var UpGrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, - Subprotocols: []string{"guacamole"}, -} - -const ( - Closed = 0 - Connected = 1 - Data = 2 - Resize = 3 - Ping = 4 -) - -type Message struct { - Type int `json:"type"` - Content string `json:"content"` -} - -func (r Message) ToString() string { - if r.Content != "" { - return strconv.Itoa(r.Type) + r.Content - } else { - return strconv.Itoa(r.Type) - } -} - -func NewMessage(_type int, content string) Message { - return Message{Content: content, Type: _type} -} - -func ParseMessage(value string) (message Message, err error) { - if value == "" { - return - } - - _type, err := strconv.Atoi(value[:1]) - if err != nil { - return - } - var content = value[1:] - message = NewMessage(_type, content) - return -} - -type WindowSize struct { - Cols int `json:"cols"` - Rows int `json:"rows"` -} - -func SSHEndpoint(c echo.Context) (err error) { - ws, err := UpGrader.Upgrade(c.Response().Writer, c.Request(), nil) - if err != nil { - log.Errorf("升级为WebSocket协议失败:%v", err.Error()) - return err - } - - defer ws.Close() - - sessionId := c.QueryParam("sessionId") - cols, _ := strconv.Atoi(c.QueryParam("cols")) - rows, _ := strconv.Atoi(c.QueryParam("rows")) - - s, err := sessionRepository.FindByIdAndDecrypt(sessionId) - if err != nil { - return WriteMessage(ws, NewMessage(Closed, "获取会话失败")) - } - - if err := permissionCheck(c, s.AssetId); err != nil { - return WriteMessage(ws, NewMessage(Closed, err.Error())) - } - - var ( - username = s.Username - password = s.Password - privateKey = s.PrivateKey - passphrase = s.Passphrase - ip = s.IP - port = s.Port - ) - - if s.AccessGatewayId != "" && s.AccessGatewayId != "-" { - g, err := accessGatewayService.GetGatewayAndReconnectById(s.AccessGatewayId) - if err != nil { - return WriteMessage(ws, NewMessage(Closed, "获取接入网关失败:"+err.Error())) - } - if !g.Connected { - return WriteMessage(ws, NewMessage(Closed, "接入网关不可用:"+g.Message)) - } - exposedIP, exposedPort, err := g.OpenSshTunnel(s.ID, ip, port) - if err != nil { - return WriteMessage(ws, NewMessage(Closed, "创建隧道失败:"+err.Error())) - } - defer g.CloseSshTunnel(s.ID) - ip = exposedIP - port = exposedPort - } - - recording := "" - var isRecording = false - property, err := propertyRepository.FindByName(guacd.EnableRecording) - if err == nil && property.Value == "true" { - isRecording = true - } - - if isRecording { - recording = path.Join(config.GlobalCfg.Guacd.Recording, sessionId, "recording.cast") - } - - attributes, err := assetRepository.FindAssetAttrMapByAssetId(s.AssetId) - if err != nil { - return WriteMessage(ws, NewMessage(Closed, "获取资产属性失败:"+err.Error())) - } - - var xterm = "xterm-256color" - var nextTerminal *term.NextTerminal - if "true" == attributes[constant.SocksProxyEnable] { - nextTerminal, err = term.NewNextTerminalUseSocks(ip, port, username, password, privateKey, passphrase, rows, cols, recording, xterm, true, attributes[constant.SocksProxyHost], attributes[constant.SocksProxyPort], attributes[constant.SocksProxyUsername], attributes[constant.SocksProxyPassword]) - } else { - nextTerminal, err = term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, rows, cols, recording, xterm, true) - } - - if err != nil { - return WriteMessage(ws, NewMessage(Closed, "创建SSH客户端失败:"+err.Error())) - } - - if err := nextTerminal.RequestPty(xterm, rows, cols); err != nil { - return err - } - - if err := nextTerminal.Shell(); err != nil { - return err - } - - sess := model.Session{ - ConnectionId: sessionId, - Width: cols, - Height: rows, - Status: constant.Connecting, - Recording: recording, - } - if sess.Recording == "" { - // 未录屏时无需审计 - sess.Reviewed = true - } - // 创建新会话 - log.Debugf("创建新会话 %v", sess.ConnectionId) - if err := sessionRepository.UpdateById(&sess, sessionId); err != nil { - return err - } - - if err := WriteMessage(ws, NewMessage(Connected, "")); err != nil { - return err - } - - nextSession := &session.Session{ - ID: s.ID, - Protocol: s.Protocol, - Mode: s.Mode, - WebSocket: ws, - GuacdTunnel: nil, - NextTerminal: nextTerminal, - Observer: session.NewObserver(s.ID), - } - go nextSession.Observer.Run() - session.GlobalSessionManager.Add <- nextSession - - ctx, cancel := context.WithCancel(context.Background()) - tick := time.NewTicker(time.Millisecond * time.Duration(60)) - defer tick.Stop() - - var buf []byte - dataChan := make(chan rune) - - go func() { - SshLoop: - for { - select { - case <-ctx.Done(): - log.Debugf("WebSocket已关闭,即将关闭SSH连接...") - break SshLoop - default: - r, size, err := nextTerminal.StdoutReader.ReadRune() - if err != nil { - log.Debugf("SSH 读取失败,即将退出循环...") - _ = WriteMessage(ws, NewMessage(Closed, "")) - break SshLoop - } - if size > 0 { - dataChan <- r - } - } - } - log.Debugf("SSH 连接已关闭,退出循环。") - }() - - go func() { - tickLoop: - for { - select { - case <-ctx.Done(): - break tickLoop - case <-tick.C: - if len(buf) > 0 { - s := string(buf) - // 录屏 - if isRecording { - _ = nextTerminal.Recorder.WriteData(s) - } - // 监控 - if len(nextSession.Observer.All()) > 0 { - obs := nextSession.Observer.All() - for _, ob := range obs { - _ = WriteMessage(ob.WebSocket, NewMessage(Data, s)) - } - } - if err := WriteMessage(ws, NewMessage(Data, s)); err != nil { - log.Debugf("WebSocket写入失败,即将退出循环...") - cancel() - } - buf = []byte{} - } - case data := <-dataChan: - if data != utf8.RuneError { - p := make([]byte, utf8.RuneLen(data)) - utf8.EncodeRune(p, data) - buf = append(buf, p...) - } else { - buf = append(buf, []byte("@")...) - } - } - } - log.Debugf("SSH 连接已关闭,退出定时器循环。") - }() - - //var enterKeys []rune - //enterIndex := 0 - for { - _, message, err := ws.ReadMessage() - if err != nil { - // web socket会话关闭后主动关闭ssh会话 - log.Debugf("WebSocket已关闭") - CloseSessionById(sessionId, Normal, "用户正常退出") - cancel() - break - } - - msg, err := ParseMessage(string(message)) - if err != nil { - log.Warnf("消息解码失败: %v, 原始字符串:%v", err, string(message)) - continue - } - - switch msg.Type { - case Resize: - decodeString, err := base64.StdEncoding.DecodeString(msg.Content) - if err != nil { - log.Warnf("Base64解码失败: %v,原始字符串:%v", err, msg.Content) - continue - } - var winSize WindowSize - err = json.Unmarshal(decodeString, &winSize) - if err != nil { - log.Warnf("解析SSH会话窗口大小失败: %v,原始字符串:%v", err, msg.Content) - continue - } - if err := nextTerminal.WindowChange(winSize.Rows, winSize.Cols); err != nil { - log.Warnf("更改SSH会话窗口大小失败: %v", err) - } - _ = sessionRepository.UpdateWindowSizeById(winSize.Rows, winSize.Cols, sessionId) - case Data: - input := []byte(msg.Content) - //hexInput := hex.EncodeToString(input) - //switch hexInput { - //case "0d": // 回车 - // DealCommand(enterKeys) - // // 清空输入的字符 - // enterKeys = enterKeys[:0] - // enterIndex = 0 - //case "7f": // backspace - // enterIndex-- - // if enterIndex < 0 { - // enterIndex = 0 - // } - // temp := enterKeys[:enterIndex] - // if len(enterKeys) > enterIndex { - // enterKeys = append(temp, enterKeys[enterIndex+1:]...) - // } else { - // enterKeys = temp - // } - //case "1b5b337e": // del - // temp := enterKeys[:enterIndex] - // if len(enterKeys) > enterIndex { - // enterKeys = append(temp, enterKeys[enterIndex+1:]...) - // } else { - // enterKeys = temp - // } - // enterIndex-- - // if enterIndex < 0 { - // enterIndex = 0 - // } - //case "1b5b41": - //case "1b5b42": - // break - //case "1b5b43": // -> - // enterIndex++ - // if enterIndex > len(enterKeys) { - // enterIndex = len(enterKeys) - // } - //case "1b5b44": // <- - // enterIndex-- - // if enterIndex < 0 { - // enterIndex = 0 - // } - //default: - // enterKeys = utils.InsertSlice(enterIndex, []rune(msg.Content), enterKeys) - // enterIndex++ - //} - _, err := nextTerminal.Write(input) - if err != nil { - CloseSessionById(sessionId, TunnelClosed, "远程连接已关闭") - } - case Ping: - _, _, err := nextTerminal.SshClient.Conn.SendRequest("helloworld1024@foxmail.com", true, nil) - if err != nil { - CloseSessionById(sessionId, TunnelClosed, "远程连接已关闭") - } else { - _ = WriteMessage(ws, NewMessage(Ping, "")) - } - - } - } - return err -} - -func permissionCheck(c echo.Context, assetId string) error { - user, _ := GetCurrentAccount(c) - if constant.TypeUser == user.Type { - // 检测是否有访问权限 - assetIds, err := resourceSharerRepository.FindAssetIdsByUserId(user.ID) - if err != nil { - return err - } - - if !utils.Contains(assetIds, assetId) { - return errors.New("您没有权限访问此资产") - } - } - return nil -} - -func WriteMessage(ws *websocket.Conn, msg Message) error { - message := []byte(msg.ToString()) - return ws.WriteMessage(websocket.TextMessage, message) -} - -func CreateNextTerminalBySession(session model.Session) (*term.NextTerminal, error) { - var ( - username = session.Username - password = session.Password - privateKey = session.PrivateKey - passphrase = session.Passphrase - ip = session.IP - port = session.Port - ) - return term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, 10, 10, "", "", false) -} - -func SshMonitor(c echo.Context) error { - ws, err := UpGrader.Upgrade(c.Response().Writer, c.Request(), nil) - if err != nil { - log.Errorf("升级为WebSocket协议失败:%v", err.Error()) - return err - } - - defer ws.Close() - - sessionId := c.QueryParam("sessionId") - s, err := sessionRepository.FindById(sessionId) - if err != nil { - return WriteMessage(ws, NewMessage(Closed, "获取会话失败")) - } - - nextSession := session.GlobalSessionManager.GetById(sessionId) - if nextSession == nil { - return WriteMessage(ws, NewMessage(Closed, "会话已离线")) - } - - obId := utils.UUID() - obSession := &session.Session{ - ID: obId, - Protocol: s.Protocol, - Mode: s.Mode, - WebSocket: ws, - } - nextSession.Observer.Add <- obSession - log.Debugf("会话 %v 观察者 %v 进入", sessionId, obId) - - for { - _, _, err := ws.ReadMessage() - if err != nil { - log.Debugf("会话 %v 观察者 %v 退出", sessionId, obId) - nextSession.Observer.Del <- obId - break - } - } - return nil -} diff --git a/server/api/storage.go b/server/api/storage.go index 071724f..6bb5d52 100644 --- a/server/api/storage.go +++ b/server/api/storage.go @@ -1,10 +1,8 @@ package api import ( - "bufio" + "context" "errors" - "io" - "mime/multipart" "os" "path" "strconv" @@ -12,12 +10,16 @@ import ( "next-terminal/server/constant" "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/service" "next-terminal/server/utils" "github.com/labstack/echo/v4" ) -func StoragePagingEndpoint(c echo.Context) error { +type StorageApi struct{} + +func (api StorageApi) StoragePagingEndpoint(c echo.Context) error { pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex")) pageSize, _ := strconv.Atoi(c.QueryParam("pageSize")) name := c.QueryParam("name") @@ -25,12 +27,12 @@ func StoragePagingEndpoint(c echo.Context) error { order := c.QueryParam("order") field := c.QueryParam("field") - items, total, err := storageRepository.Find(pageIndex, pageSize, name, order, field) + items, total, err := repository.StorageRepository.Find(context.TODO(), pageIndex, pageSize, name, order, field) if err != nil { return err } - drivePath := storageService.GetBaseDrivePath() + drivePath := service.StorageService.GetBaseDrivePath() for i := range items { item := items[i] @@ -42,13 +44,13 @@ func StoragePagingEndpoint(c echo.Context) error { } } - return Success(c, H{ + return Success(c, Map{ "total": total, "items": items, }) } -func StorageCreateEndpoint(c echo.Context) error { +func (api StorageApi) StorageCreateEndpoint(c echo.Context) error { var item model.Storage if err := c.Bind(&item); err != nil { return err @@ -60,24 +62,24 @@ func StorageCreateEndpoint(c echo.Context) error { item.Created = utils.NowJsonTime() item.Owner = account.ID // 创建对应的目录文件夹 - drivePath := storageService.GetBaseDrivePath() + drivePath := service.StorageService.GetBaseDrivePath() if err := os.MkdirAll(path.Join(drivePath, item.ID), os.ModePerm); err != nil { return err } - if err := storageRepository.Create(&item); err != nil { + if err := repository.StorageRepository.Create(context.TODO(), &item); err != nil { return err } return Success(c, "") } -func StorageUpdateEndpoint(c echo.Context) error { +func (api StorageApi) StorageUpdateEndpoint(c echo.Context) error { id := c.Param("id") var item model.Storage if err := c.Bind(&item); err != nil { return err } - drivePath := storageService.GetBaseDrivePath() + drivePath := service.StorageService.GetBaseDrivePath() dirSize, err := utils.DirSize(path.Join(drivePath, item.ID)) if err != nil { return err @@ -87,7 +89,7 @@ func StorageUpdateEndpoint(c echo.Context) error { return errors.New("空间大小不能小于已使用大小") } - storage, err := storageRepository.FindById(id) + storage, err := repository.StorageRepository.FindById(context.TODO(), id) if err != nil { return err } @@ -95,20 +97,20 @@ func StorageUpdateEndpoint(c echo.Context) error { storage.LimitSize = item.LimitSize storage.IsShare = item.IsShare - if err := storageRepository.UpdateById(&storage, id); err != nil { + if err := repository.StorageRepository.UpdateById(context.TODO(), &storage, id); err != nil { return err } return Success(c, "") } -func StorageGetEndpoint(c echo.Context) error { +func (api StorageApi) StorageGetEndpoint(c echo.Context) error { storageId := c.Param("id") - storage, err := storageRepository.FindById(storageId) + storage, err := repository.StorageRepository.FindById(context.TODO(), storageId) if err != nil { return err } structMap := utils.StructToMap(storage) - drivePath := storageService.GetBaseDrivePath() + drivePath := service.StorageService.GetBaseDrivePath() dirSize, err := utils.DirSize(path.Join(drivePath, storageId)) if err != nil { structMap["usedSize"] = -1 @@ -119,28 +121,28 @@ func StorageGetEndpoint(c echo.Context) error { return Success(c, structMap) } -func StorageSharesEndpoint(c echo.Context) error { - storages, err := storageRepository.FindShares() +func (api StorageApi) StorageSharesEndpoint(c echo.Context) error { + storages, err := repository.StorageRepository.FindShares(context.TODO()) if err != nil { return err } return Success(c, storages) } -func StorageDeleteEndpoint(c echo.Context) error { +func (api StorageApi) StorageDeleteEndpoint(c echo.Context) error { ids := c.Param("id") split := strings.Split(ids, ",") for i := range split { id := split[i] - if err := storageService.DeleteStorageById(id, false); err != nil { + if err := service.StorageService.DeleteStorageById(id, false); err != nil { return err } } return Success(c, nil) } -func PermissionCheck(c echo.Context, id string) error { - storage, err := storageRepository.FindById(id) +func (api StorageApi) PermissionCheck(c echo.Context, id string) error { + storage, err := repository.StorageRepository.FindById(context.TODO(), id) if err != nil { return err } @@ -153,49 +155,31 @@ func PermissionCheck(c echo.Context, id string) error { return nil } -func StorageLsEndpoint(c echo.Context) error { +func (api StorageApi) StorageLsEndpoint(c echo.Context) error { storageId := c.Param("storageId") - if err := PermissionCheck(c, storageId); err != nil { + if err := api.PermissionCheck(c, storageId); err != nil { return err } remoteDir := c.FormValue("dir") - return StorageLs(c, remoteDir, storageId) -} - -func StorageLs(c echo.Context, remoteDir, storageId string) error { - drivePath := storageService.GetBaseDrivePath() - if strings.Contains(remoteDir, "../") { - return Fail(c, -1, "非法请求 :(") - } - files, err := storageService.Ls(path.Join(drivePath, storageId), remoteDir) + err, files := service.StorageService.StorageLs(remoteDir, storageId) if err != nil { return err } return Success(c, files) } -func StorageDownloadEndpoint(c echo.Context) error { +func (api StorageApi) StorageDownloadEndpoint(c echo.Context) error { storageId := c.Param("storageId") - if err := PermissionCheck(c, storageId); err != nil { + if err := api.PermissionCheck(c, storageId); err != nil { return err } remoteFile := c.QueryParam("file") - return StorageDownload(c, remoteFile, storageId) + return service.StorageService.StorageDownload(c, remoteFile, storageId) } -func StorageDownload(c echo.Context, remoteFile, storageId string) error { - drivePath := storageService.GetBaseDrivePath() - if strings.Contains(remoteFile, "../") { - return Fail(c, -1, "非法请求 :(") - } - // 获取带后缀的文件名称 - filenameWithSuffix := path.Base(remoteFile) - return c.Attachment(path.Join(path.Join(drivePath, storageId), remoteFile), filenameWithSuffix) -} - -func StorageUploadEndpoint(c echo.Context) error { +func (api StorageApi) StorageUploadEndpoint(c echo.Context) error { storageId := c.Param("storageId") - if err := PermissionCheck(c, storageId); err != nil { + if err := api.PermissionCheck(c, storageId); err != nil { return err } file, err := c.FormFile("file") @@ -203,150 +187,58 @@ func StorageUploadEndpoint(c echo.Context) error { return err } - return StorageUpload(c, file, storageId) -} - -func StorageUpload(c echo.Context, file *multipart.FileHeader, storageId string) error { - drivePath := storageService.GetBaseDrivePath() - storage, _ := storageRepository.FindById(storageId) - if storage.LimitSize > 0 { - dirSize, err := utils.DirSize(path.Join(drivePath, storageId)) - if err != nil { - return err - } - if dirSize+file.Size > storage.LimitSize { - return errors.New("可用空间不足") - } - } - - filename := file.Filename - src, err := file.Open() - if err != nil { - return err - } - - remoteDir := c.QueryParam("dir") - remoteFile := path.Join(remoteDir, filename) - - if strings.Contains(remoteDir, "../") { - return Fail(c, -1, "非法请求 :(") - } - if strings.Contains(remoteFile, "../") { - return Fail(c, -1, "非法请求 :(") - } - - // 判断文件夹不存在时自动创建 - dir := path.Join(path.Join(drivePath, storageId), remoteDir) - if !utils.FileExists(dir) { - if err := os.MkdirAll(dir, os.ModePerm); err != nil { - return err - } - } - // Destination - dst, err := os.Create(path.Join(path.Join(drivePath, storageId), remoteFile)) - if err != nil { - return err - } - defer dst.Close() - - // Copy - if _, err = io.Copy(dst, src); err != nil { + if err := service.StorageService.StorageUpload(c, file, storageId); err != nil { return err } return Success(c, nil) } -func StorageMkDirEndpoint(c echo.Context) error { +func (api StorageApi) StorageMkDirEndpoint(c echo.Context) error { storageId := c.Param("storageId") - if err := PermissionCheck(c, storageId); err != nil { + if err := api.PermissionCheck(c, storageId); err != nil { return err } remoteDir := c.QueryParam("dir") - return StorageMkDir(c, remoteDir, storageId) -} - -func StorageMkDir(c echo.Context, remoteDir, storageId string) error { - drivePath := storageService.GetBaseDrivePath() - if strings.Contains(remoteDir, "../") { - return Fail(c, -1, ":) 非法请求") - } - if err := os.MkdirAll(path.Join(path.Join(drivePath, storageId), remoteDir), os.ModePerm); err != nil { + if err := service.StorageService.StorageMkDir(remoteDir, storageId); err != nil { return err } return Success(c, nil) } -func StorageRmEndpoint(c echo.Context) error { +func (api StorageApi) StorageRmEndpoint(c echo.Context) error { storageId := c.Param("storageId") - if err := PermissionCheck(c, storageId); err != nil { + if err := api.PermissionCheck(c, storageId); err != nil { return err } // 文件夹或者文件 file := c.FormValue("file") - return StorageRm(c, file, storageId) -} - -func StorageRm(c echo.Context, file, storageId string) error { - drivePath := storageService.GetBaseDrivePath() - if strings.Contains(file, "../") { - return Fail(c, -1, ":) 非法请求") - } - if err := os.RemoveAll(path.Join(path.Join(drivePath, storageId), file)); err != nil { + if err := service.StorageService.StorageRm(file, storageId); err != nil { return err } return Success(c, nil) } -func StorageRenameEndpoint(c echo.Context) error { +func (api StorageApi) StorageRenameEndpoint(c echo.Context) error { storageId := c.Param("storageId") - if err := PermissionCheck(c, storageId); err != nil { + if err := api.PermissionCheck(c, storageId); err != nil { return err } oldName := c.QueryParam("oldName") newName := c.QueryParam("newName") - return StorageRename(c, oldName, newName, storageId) -} - -func StorageRename(c echo.Context, oldName, newName, storageId string) error { - drivePath := storageService.GetBaseDrivePath() - if strings.Contains(oldName, "../") { - return Fail(c, -1, ":) 非法请求") - } - if strings.Contains(newName, "../") { - return Fail(c, -1, ":) 非法请求") - } - if err := os.Rename(path.Join(path.Join(drivePath, storageId), oldName), path.Join(path.Join(drivePath, storageId), newName)); err != nil { + if err := service.StorageService.StorageRename(oldName, newName, storageId); err != nil { return err } return Success(c, nil) } -func StorageEditEndpoint(c echo.Context) error { +func (api StorageApi) StorageEditEndpoint(c echo.Context) error { storageId := c.Param("storageId") - if err := PermissionCheck(c, storageId); err != nil { + if err := api.PermissionCheck(c, storageId); err != nil { return err } file := c.FormValue("file") fileContent := c.FormValue("fileContent") - return StorageEdit(c, file, fileContent, storageId) -} - -func StorageEdit(c echo.Context, file string, fileContent string, storageId string) error { - drivePath := storageService.GetBaseDrivePath() - if strings.Contains(file, "../") { - return Fail(c, -1, ":) 非法请求") - } - realFilePath := path.Join(path.Join(drivePath, storageId), file) - dstFile, err := os.OpenFile(realFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) - if err != nil { - return err - } - defer dstFile.Close() - write := bufio.NewWriter(dstFile) - if _, err := write.WriteString(fileContent); err != nil { - return err - } - if err := write.Flush(); err != nil { + if err := service.StorageService.StorageEdit(file, fileContent, storageId); err != nil { return err } return Success(c, nil) diff --git a/server/api/strategy.go b/server/api/strategy.go index 9600a59..a837caa 100644 --- a/server/api/strategy.go +++ b/server/api/strategy.go @@ -1,24 +1,29 @@ package api import ( + "context" + "strconv" "strings" "next-terminal/server/model" + "next-terminal/server/repository" "next-terminal/server/utils" "github.com/labstack/echo/v4" ) -func StrategyAllEndpoint(c echo.Context) error { - items, err := strategyRepository.FindAll() +type StrategyApi struct{} + +func (api StrategyApi) StrategyAllEndpoint(c echo.Context) error { + items, err := repository.StrategyRepository.FindAll(context.TODO()) if err != nil { return err } return Success(c, items) } -func StrategyPagingEndpoint(c echo.Context) error { +func (api StrategyApi) StrategyPagingEndpoint(c echo.Context) error { pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex")) pageSize, _ := strconv.Atoi(c.QueryParam("pageSize")) name := c.QueryParam("name") @@ -26,18 +31,18 @@ func StrategyPagingEndpoint(c echo.Context) error { order := c.QueryParam("order") field := c.QueryParam("field") - items, total, err := strategyRepository.Find(pageIndex, pageSize, name, order, field) + items, total, err := repository.StrategyRepository.Find(context.TODO(), pageIndex, pageSize, name, order, field) if err != nil { return err } - return Success(c, H{ + return Success(c, Map{ "total": total, "items": items, }) } -func StrategyCreateEndpoint(c echo.Context) error { +func (api StrategyApi) StrategyCreateEndpoint(c echo.Context) error { var item model.Strategy if err := c.Bind(&item); err != nil { return err @@ -45,32 +50,32 @@ func StrategyCreateEndpoint(c echo.Context) error { item.ID = utils.UUID() item.Created = utils.NowJsonTime() - if err := strategyRepository.Create(&item); err != nil { + if err := repository.StrategyRepository.Create(context.TODO(), &item); err != nil { return err } return Success(c, "") } -func StrategyDeleteEndpoint(c echo.Context) error { +func (api StrategyApi) StrategyDeleteEndpoint(c echo.Context) error { ids := c.Param("id") split := strings.Split(ids, ",") for i := range split { id := split[i] - if err := strategyRepository.DeleteById(id); err != nil { + if err := repository.StrategyRepository.DeleteById(context.TODO(), id); err != nil { return err } } return Success(c, nil) } -func StrategyUpdateEndpoint(c echo.Context) error { +func (api StrategyApi) StrategyUpdateEndpoint(c echo.Context) error { id := c.Param("id") var item model.Strategy if err := c.Bind(&item); err != nil { return err } - if err := strategyRepository.UpdateById(&item, id); err != nil { + if err := repository.StrategyRepository.UpdateById(context.TODO(), &item, id); err != nil { return err } return Success(c, "") diff --git a/server/api/term.go b/server/api/term.go new file mode 100644 index 0000000..0a28bbc --- /dev/null +++ b/server/api/term.go @@ -0,0 +1,288 @@ +package api + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "path" + "strconv" + + "next-terminal/server/config" + "next-terminal/server/constant" + "next-terminal/server/dto" + "next-terminal/server/global/session" + "next-terminal/server/guacd" + "next-terminal/server/log" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/service" + "next-terminal/server/term" + "next-terminal/server/utils" + + "github.com/gorilla/websocket" + "github.com/labstack/echo/v4" +) + +const ( + Closed = 0 + Connected = 1 + Data = 2 + Resize = 3 + Ping = 4 +) + +type WebTerminalApi struct { +} + +func (api WebTerminalApi) SshEndpoint(c echo.Context) error { + ws, err := UpGrader.Upgrade(c.Response().Writer, c.Request(), nil) + if err != nil { + log.Errorf("升级为WebSocket协议失败:%v", err.Error()) + return err + } + + defer func() { + _ = ws.Close() + }() + ctx := context.TODO() + + sessionId := c.QueryParam("sessionId") + cols, _ := strconv.Atoi(c.QueryParam("cols")) + rows, _ := strconv.Atoi(c.QueryParam("rows")) + + s, err := service.SessionService.FindByIdAndDecrypt(ctx, sessionId) + if err != nil { + return WriteMessage(ws, dto.NewMessage(Closed, "获取会话失败")) + } + + if err := api.permissionCheck(c, s.AssetId); err != nil { + return WriteMessage(ws, dto.NewMessage(Closed, err.Error())) + } + + var ( + username = s.Username + password = s.Password + privateKey = s.PrivateKey + passphrase = s.Passphrase + ip = s.IP + port = s.Port + ) + + if s.AccessGatewayId != "" && s.AccessGatewayId != "-" { + g, err := service.GatewayService.GetGatewayAndReconnectById(s.AccessGatewayId) + if err != nil { + return WriteMessage(ws, dto.NewMessage(Closed, "获取接入网关失败:"+err.Error())) + } + if !g.Connected { + return WriteMessage(ws, dto.NewMessage(Closed, "接入网关不可用:"+g.Message)) + } + exposedIP, exposedPort, err := g.OpenSshTunnel(s.ID, ip, port) + if err != nil { + return WriteMessage(ws, dto.NewMessage(Closed, "创建隧道失败:"+err.Error())) + } + defer g.CloseSshTunnel(s.ID) + ip = exposedIP + port = exposedPort + } + + recording := "" + var isRecording = false + property, err := repository.PropertyRepository.FindByName(ctx, guacd.EnableRecording) + if err == nil && property.Value == "true" { + isRecording = true + } + + if isRecording { + recording = path.Join(config.GlobalCfg.Guacd.Recording, sessionId, "recording.cast") + } + + attributes, err := repository.AssetRepository.FindAssetAttrMapByAssetId(ctx, s.AssetId) + if err != nil { + return WriteMessage(ws, dto.NewMessage(Closed, "获取资产属性失败:"+err.Error())) + } + + var xterm = "xterm-256color" + var nextTerminal *term.NextTerminal + if "true" == attributes[constant.SocksProxyEnable] { + nextTerminal, err = term.NewNextTerminalUseSocks(ip, port, username, password, privateKey, passphrase, rows, cols, recording, xterm, true, attributes[constant.SocksProxyHost], attributes[constant.SocksProxyPort], attributes[constant.SocksProxyUsername], attributes[constant.SocksProxyPassword]) + } else { + nextTerminal, err = term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, rows, cols, recording, xterm, true) + } + + if err != nil { + return WriteMessage(ws, dto.NewMessage(Closed, "创建SSH客户端失败:"+err.Error())) + } + + if err := nextTerminal.RequestPty(xterm, rows, cols); err != nil { + return err + } + + if err := nextTerminal.Shell(); err != nil { + return err + } + + sess := model.Session{ + ConnectionId: sessionId, + Width: cols, + Height: rows, + Status: constant.Connecting, + Recording: recording, + } + if sess.Recording == "" { + // 未录屏时无需审计 + sess.Reviewed = true + } + // 创建新会话 + log.Debugf("创建新会话 %v", sess.ConnectionId) + if err := repository.SessionRepository.UpdateById(ctx, &sess, sessionId); err != nil { + return err + } + + if err := WriteMessage(ws, dto.NewMessage(Connected, "")); err != nil { + return err + } + + nextSession := &session.Session{ + ID: s.ID, + Protocol: s.Protocol, + Mode: s.Mode, + WebSocket: ws, + GuacdTunnel: nil, + NextTerminal: nextTerminal, + Observer: session.NewObserver(s.ID), + } + go nextSession.Observer.Run() + session.GlobalSessionManager.Add <- nextSession + + termHandler := NewTermHandler(sessionId, isRecording, ws, nextTerminal) + termHandler.Start() + + for { + _, message, err := ws.ReadMessage() + if err != nil { + // web socket会话关闭后主动关闭ssh会话 + log.Debugf("WebSocket已关闭") + service.SessionService.CloseSessionById(sessionId, Normal, "用户正常退出") + termHandler.Stop() + break + } + + msg, err := dto.ParseMessage(string(message)) + if err != nil { + log.Warnf("消息解码失败: %v, 原始字符串:%v", err, string(message)) + continue + } + + switch msg.Type { + case Resize: + decodeString, err := base64.StdEncoding.DecodeString(msg.Content) + if err != nil { + log.Warnf("Base64解码失败: %v,原始字符串:%v", err, msg.Content) + continue + } + var winSize dto.WindowSize + err = json.Unmarshal(decodeString, &winSize) + if err != nil { + log.Warnf("解析SSH会话窗口大小失败: %v,原始字符串:%v", err, msg.Content) + continue + } + if err := nextTerminal.WindowChange(winSize.Rows, winSize.Cols); err != nil { + log.Warnf("更改SSH会话窗口大小失败: %v", err) + } + _ = repository.SessionRepository.UpdateWindowSizeById(ctx, winSize.Rows, winSize.Cols, sessionId) + case Data: + input := []byte(msg.Content) + _, err := nextTerminal.Write(input) + if err != nil { + service.SessionService.CloseSessionById(sessionId, TunnelClosed, "远程连接已关闭") + } + case Ping: + _, _, err := nextTerminal.SshClient.Conn.SendRequest("helloworld1024@foxmail.com", true, nil) + if err != nil { + service.SessionService.CloseSessionById(sessionId, TunnelClosed, "远程连接已关闭") + } else { + _ = WriteMessage(ws, dto.NewMessage(Ping, "")) + } + + } + } + return err +} + +func (api WebTerminalApi) SshMonitorEndpoint(c echo.Context) error { + ws, err := UpGrader.Upgrade(c.Response().Writer, c.Request(), nil) + if err != nil { + log.Errorf("升级为WebSocket协议失败:%v", err.Error()) + return err + } + + defer func() { + _ = ws.Close() + }() + ctx := context.TODO() + + sessionId := c.QueryParam("sessionId") + s, err := repository.SessionRepository.FindById(ctx, sessionId) + if err != nil { + return WriteMessage(ws, dto.NewMessage(Closed, "获取会话失败")) + } + + nextSession := session.GlobalSessionManager.GetById(sessionId) + if nextSession == nil { + return WriteMessage(ws, dto.NewMessage(Closed, "会话已离线")) + } + + obId := utils.UUID() + obSession := &session.Session{ + ID: obId, + Protocol: s.Protocol, + Mode: s.Mode, + WebSocket: ws, + } + nextSession.Observer.Add <- obSession + log.Debugf("会话 %v 观察者 %v 进入", sessionId, obId) + + for { + _, _, err := ws.ReadMessage() + if err != nil { + log.Debugf("会话 %v 观察者 %v 退出", sessionId, obId) + nextSession.Observer.Del <- obId + break + } + } + return nil +} + +func (api WebTerminalApi) permissionCheck(c echo.Context, assetId string) error { + user, _ := GetCurrentAccount(c) + if constant.TypeUser == user.Type { + // 检测是否有访问权限 + assetIds, err := repository.ResourceSharerRepository.FindAssetIdsByUserId(context.TODO(), user.ID) + if err != nil { + return err + } + + if !utils.Contains(assetIds, assetId) { + return errors.New("您没有权限访问此资产") + } + } + return nil +} + +func WriteMessage(ws *websocket.Conn, msg dto.Message) error { + message := []byte(msg.ToString()) + return ws.WriteMessage(websocket.TextMessage, message) +} + +func CreateNextTerminalBySession(session model.Session) (*term.NextTerminal, error) { + var ( + username = session.Username + password = session.Password + privateKey = session.PrivateKey + passphrase = session.Passphrase + ip = session.IP + port = session.Port + ) + return term.NewNextTerminal(ip, port, username, password, privateKey, passphrase, 10, 10, "", "", false) +} diff --git a/server/api/term_handler.go b/server/api/term_handler.go new file mode 100644 index 0000000..8a0c99a --- /dev/null +++ b/server/api/term_handler.go @@ -0,0 +1,104 @@ +package api + +import ( + "context" + "time" + "unicode/utf8" + + "next-terminal/server/dto" + "next-terminal/server/global/session" + "next-terminal/server/term" + + "github.com/gorilla/websocket" +) + +type TermHandler struct { + sessionId string + isRecording bool + ws *websocket.Conn + nextTerminal *term.NextTerminal + ctx context.Context + cancel context.CancelFunc + dataChan chan rune + tick *time.Ticker +} + +func NewTermHandler(sessionId string, isRecording bool, ws *websocket.Conn, nextTerminal *term.NextTerminal) *TermHandler { + ctx, cancel := context.WithCancel(context.Background()) + tick := time.NewTicker(time.Millisecond * time.Duration(60)) + return &TermHandler{ + sessionId: sessionId, + isRecording: isRecording, + ws: ws, + nextTerminal: nextTerminal, + ctx: ctx, + cancel: cancel, + dataChan: make(chan rune), + tick: tick, + } +} + +func (r TermHandler) Start() { + go r.readFormTunnel() + go r.writeToWebsocket() +} + +func (r TermHandler) Stop() { + r.tick.Stop() + r.cancel() +} + +func (r TermHandler) readFormTunnel() { + for { + select { + case <-r.ctx.Done(): + return + default: + rn, size, err := r.nextTerminal.StdoutReader.ReadRune() + if err != nil { + return + } + if size > 0 { + r.dataChan <- rn + } + } + } +} + +func (r TermHandler) writeToWebsocket() { + var buf []byte + for { + select { + case <-r.ctx.Done(): + return + case <-r.tick.C: + if len(buf) > 0 { + s := string(buf) + if err := WriteMessage(r.ws, dto.NewMessage(Data, s)); err != nil { + return + } + // 录屏 + if r.isRecording { + _ = r.nextTerminal.Recorder.WriteData(s) + } + nextSession := session.GlobalSessionManager.GetById(r.sessionId) + // 监控 + if nextSession != nil && len(nextSession.Observer.All()) > 0 { + obs := nextSession.Observer.All() + for _, ob := range obs { + _ = WriteMessage(ob.WebSocket, dto.NewMessage(Data, s)) + } + } + buf = []byte{} + } + case data := <-r.dataChan: + if data != utf8.RuneError { + p := make([]byte, utf8.RuneLen(data)) + utf8.EncodeRune(p, data) + buf = append(buf, p...) + } else { + buf = append(buf, []byte("@")...) + } + } + } +} diff --git a/server/api/user-group.go b/server/api/user-group.go index 650f0ac..dabe649 100644 --- a/server/api/user-group.go +++ b/server/api/user-group.go @@ -1,41 +1,33 @@ package api import ( + "context" "strconv" "strings" - "next-terminal/server/model" - "next-terminal/server/utils" + "next-terminal/server/dto" + "next-terminal/server/repository" + "next-terminal/server/service" "github.com/labstack/echo/v4" ) -type UserGroup struct { - Id string `json:"id"` - Name string `json:"name"` - Members []string `json:"members"` -} +type UserGroupApi struct{} -func UserGroupCreateEndpoint(c echo.Context) error { - var item UserGroup +func (userGroupApi UserGroupApi) UserGroupCreateEndpoint(c echo.Context) error { + var item dto.UserGroup if err := c.Bind(&item); err != nil { return err } - userGroup := model.UserGroup{ - ID: utils.UUID(), - Created: utils.NowJsonTime(), - Name: item.Name, - } - - if err := userGroupRepository.Create(&userGroup, item.Members); err != nil { + if _, err := service.UserGroupService.Create(item.Name, item.Members); err != nil { return err } return Success(c, item) } -func UserGroupPagingEndpoint(c echo.Context) error { +func (userGroupApi UserGroupApi) UserGroupPagingEndpoint(c echo.Context) error { pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex")) pageSize, _ := strconv.Atoi(c.QueryParam("pageSize")) name := c.QueryParam("name") @@ -43,41 +35,38 @@ func UserGroupPagingEndpoint(c echo.Context) error { order := c.QueryParam("order") field := c.QueryParam("field") - items, total, err := userGroupRepository.Find(pageIndex, pageSize, name, order, field) + items, total, err := repository.UserGroupRepository.Find(context.TODO(), pageIndex, pageSize, name, order, field) if err != nil { return err } - return Success(c, H{ + return Success(c, Map{ "total": total, "items": items, }) } -func UserGroupUpdateEndpoint(c echo.Context) error { +func (userGroupApi UserGroupApi) UserGroupUpdateEndpoint(c echo.Context) error { id := c.Param("id") - var item UserGroup + var item dto.UserGroup if err := c.Bind(&item); err != nil { return err } - userGroup := model.UserGroup{ - Name: item.Name, - } - if err := userGroupRepository.Update(&userGroup, item.Members, id); err != nil { + if err := service.UserGroupService.Update(id, item.Name, item.Members); err != nil { return err } return Success(c, nil) } -func UserGroupDeleteEndpoint(c echo.Context) error { +func (userGroupApi UserGroupApi) UserGroupDeleteEndpoint(c echo.Context) error { ids := c.Param("id") split := strings.Split(ids, ",") for i := range split { userId := split[i] - if err := userGroupRepository.DeleteById(userId); err != nil { + if err := service.UserGroupService.DeleteById(userId); err != nil { return err } } @@ -85,20 +74,20 @@ func UserGroupDeleteEndpoint(c echo.Context) error { return Success(c, nil) } -func UserGroupGetEndpoint(c echo.Context) error { +func (userGroupApi UserGroupApi) UserGroupGetEndpoint(c echo.Context) error { id := c.Param("id") - item, err := userGroupRepository.FindById(id) + item, err := repository.UserGroupRepository.FindById(context.TODO(), id) if err != nil { return err } - members, err := userGroupRepository.FindMembersById(id) + members, err := repository.UserGroupMemberRepository.FindUserIdsByUserGroupId(context.TODO(), id) if err != nil { return err } - userGroup := UserGroup{ + userGroup := dto.UserGroup{ Id: item.ID, Name: item.Name, Members: members, diff --git a/server/api/user.go b/server/api/user.go index e3eac21..80370e0 100644 --- a/server/api/user.go +++ b/server/api/user.go @@ -1,56 +1,35 @@ package api import ( - "errors" + "context" + "strconv" "strings" - "next-terminal/server/constant" - "next-terminal/server/global/cache" - "next-terminal/server/log" "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/service" "next-terminal/server/utils" "github.com/labstack/echo/v4" - "gorm.io/gorm" ) -func UserCreateEndpoint(c echo.Context) (err error) { +type UserApi struct{} + +func (userApi UserApi) UserCreateEndpoint(c echo.Context) (err error) { var item model.User if err := c.Bind(&item); err != nil { return err } - if userRepository.ExistByUsername(item.Username) { - return Fail(c, -1, "username is already in use") - } - password := item.Password - - var pass []byte - if pass, err = utils.Encoder.Encode([]byte(password)); err != nil { - return err - } - item.Password = string(pass) - - item.ID = utils.UUID() - item.Created = utils.NowJsonTime() - item.Status = constant.StatusEnabled - - if err := userRepository.Create(&item); err != nil { - return err - } - err = storageService.CreateStorageByUser(&item) - if err != nil { + if err := service.UserService.CreateUser(item); err != nil { return err } - if item.Mail != "" { - go mailService.SendMail(item.Mail, "[Next Terminal] 注册通知", "你好,"+item.Nickname+"。管理员为你注册了账号:"+item.Username+" 密码:"+password) - } return Success(c, item) } -func UserPagingEndpoint(c echo.Context) error { +func (userApi UserApi) UserPagingEndpoint(c echo.Context) error { pageIndex, _ := strconv.Atoi(c.QueryParam("pageIndex")) pageSize, _ := strconv.Atoi(c.QueryParam("pageSize")) username := c.QueryParam("username") @@ -60,19 +39,18 @@ func UserPagingEndpoint(c echo.Context) error { order := c.QueryParam("order") field := c.QueryParam("field") - account, _ := GetCurrentAccount(c) - items, total, err := userRepository.Find(pageIndex, pageSize, username, nickname, mail, order, field, account) + items, total, err := repository.UserRepository.Find(context.TODO(), pageIndex, pageSize, username, nickname, mail, order, field) if err != nil { return err } - return Success(c, H{ + return Success(c, Map{ "total": total, "items": items, }) } -func UserUpdateEndpoint(c echo.Context) error { +func (userApi UserApi) UserUpdateEndpoint(c echo.Context) error { id := c.Param("id") account, _ := GetCurrentAccount(c) @@ -86,14 +64,14 @@ func UserUpdateEndpoint(c echo.Context) error { } item.ID = id - if err := userRepository.Update(&item); err != nil { + if err := repository.UserRepository.Update(context.TODO(), &item); err != nil { return err } return Success(c, nil) } -func UserUpdateStatusEndpoint(c echo.Context) error { +func (userApi UserApi) UserUpdateStatusEndpoint(c echo.Context) error { id := c.Param("id") status := c.QueryParam("status") account, _ := GetCurrentAccount(c) @@ -101,14 +79,14 @@ func UserUpdateStatusEndpoint(c echo.Context) error { return Fail(c, -1, "不能操作自身账户") } - if err := userService.UpdateStatusById(id, status); err != nil { + if err := service.UserService.UpdateStatusById(id, status); err != nil { return err } return Success(c, nil) } -func UserDeleteEndpoint(c echo.Context) error { +func (userApi UserApi) UserDeleteEndpoint(c echo.Context) error { ids := c.Param("id") account, found := GetCurrentAccount(c) if !found { @@ -120,16 +98,7 @@ func UserDeleteEndpoint(c echo.Context) error { if account.ID == userId { return Fail(c, -1, "不允许删除自身账户") } - // 下线该用户 - if err := userService.LogoutById(userId); err != nil { - return err - } - // 删除用户 - if err := userRepository.DeleteById(userId); err != nil { - return err - } - // 删除用户的默认磁盘空间 - if err := storageService.DeleteStorageById(userId, true); err != nil { + if err := service.UserService.DeleteUserById(userId); err != nil { return err } } @@ -137,10 +106,10 @@ func UserDeleteEndpoint(c echo.Context) error { return Success(c, nil) } -func UserGetEndpoint(c echo.Context) error { +func (userApi UserApi) UserGetEndpoint(c echo.Context) error { id := c.Param("id") - item, err := userRepository.FindById(id) + item, err := repository.UserRepository.FindById(context.TODO(), id) if err != nil { return err } @@ -148,14 +117,14 @@ func UserGetEndpoint(c echo.Context) error { return Success(c, item) } -func UserChangePasswordEndpoint(c echo.Context) error { +func (userApi UserApi) UserChangePasswordEndpoint(c echo.Context) error { id := c.Param("id") password := c.FormValue("password") if password == "" { return Fail(c, -1, "请输入密码") } - user, err := userRepository.FindById(id) + user, err := repository.UserRepository.FindById(context.TODO(), id) if err != nil { return err } @@ -168,61 +137,25 @@ func UserChangePasswordEndpoint(c echo.Context) error { Password: string(passwd), ID: id, } - if err := userRepository.Update(u); err != nil { + if err := repository.UserRepository.Update(context.TODO(), u); err != nil { return err } if user.Mail != "" { - go mailService.SendMail(user.Mail, "[Next Terminal] 密码修改通知", "你好,"+user.Nickname+"。管理员已将你的密码修改为:"+password) + go service.MailService.SendMail(user.Mail, "[Next Terminal] 密码修改通知", "你好,"+user.Nickname+"。管理员已将你的密码修改为:"+password) } return Success(c, "") } -func UserResetTotpEndpoint(c echo.Context) error { +func (userApi UserApi) UserResetTotpEndpoint(c echo.Context) error { id := c.Param("id") u := &model.User{ TOTPSecret: "-", ID: id, } - if err := userRepository.Update(u); err != nil { + if err := repository.UserRepository.Update(context.TODO(), u); err != nil { return err } return Success(c, "") } - -func ReloadToken() error { - loginLogs, err := loginLogRepository.FindAliveLoginLogs() - if err != nil { - return err - } - - for i := range loginLogs { - loginLog := loginLogs[i] - token := loginLog.ID - user, err := userRepository.FindByUsername(loginLog.Username) - if err != nil { - if errors.Is(gorm.ErrRecordNotFound, err) { - _ = loginLogRepository.DeleteById(token) - } - continue - } - - authorization := Authorization{ - Token: token, - Remember: loginLog.Remember, - User: user, - } - - cacheKey := userService.BuildCacheKeyByToken(token) - - if authorization.Remember { - // 记住登录有效期两周 - cache.GlobalCache.Set(cacheKey, authorization, RememberEffectiveTime) - } else { - cache.GlobalCache.Set(cacheKey, authorization, NotRememberEffectiveTime) - } - log.Debugf("重新加载用户「%v」授权Token「%v」到缓存", user.Nickname, token) - } - return nil -} diff --git a/server/app/app.go b/server/app/app.go new file mode 100644 index 0000000..ca5a555 --- /dev/null +++ b/server/app/app.go @@ -0,0 +1,118 @@ +package app + +import ( + "encoding/json" + "fmt" + + "next-terminal/server/cli" + "next-terminal/server/config" + "next-terminal/server/constant" + "next-terminal/server/service" + + "github.com/labstack/echo/v4" +) + +var app *App + +type App struct { + Server *echo.Echo +} + +func newApp() *App { + return &App{} +} + +func init() { + setupCache() + app = newApp() + if err := app.InitDBData(); err != nil { + panic(err) + } + if err := app.ReloadData(); err != nil { + panic(err) + } + app.Server = setupRoutes() +} + +func (app App) InitDBData() (err error) { + if err := service.PropertyService.DeleteDeprecatedProperty(); err != nil { + return err + } + if err := service.GatewayService.ReConnectAll(); err != nil { + return err + } + if err := service.PropertyService.InitProperties(); err != nil { + return err + } + if err := service.UserService.InitUser(); err != nil { + return err + } + if err := service.JobService.InitJob(); err != nil { + return err + } + if err := service.UserService.FixUserOnlineState(); err != nil { + return err + } + if err := service.SessionService.FixSessionState(); err != nil { + return err + } + if err := service.SessionService.EmptyPassword(); err != nil { + return err + } + if err := service.CredentialService.EncryptAll(); err != nil { + return err + } + if err := service.AssetService.EncryptAll(); err != nil { + return err + } + if err := service.StorageService.InitStorages(); err != nil { + return err + } + + return nil +} + +func (app App) ReloadData() error { + if err := service.SecurityService.ReloadAccessSecurity(); err != nil { + return err + } + if err := service.UserService.ReloadToken(); err != nil { + return err + } + if err := service.AccessTokenService.Reload(); err != nil { + return err + } + return nil +} + +func Run() error { + + fmt.Printf(constant.Banner, constant.Version) + + if config.GlobalCfg.Debug { + jsonBytes, err := json.MarshalIndent(config.GlobalCfg, "", " ") + if err != nil { + return err + } + fmt.Printf("当前配置为: %v\n", string(jsonBytes)) + } + + _cli := cli.NewCli() + + if config.GlobalCfg.ResetPassword != "" { + return _cli.ResetPassword(config.GlobalCfg.ResetPassword) + } + if config.GlobalCfg.ResetTotp != "" { + return _cli.ResetTotp(config.GlobalCfg.ResetTotp) + } + + if config.GlobalCfg.NewEncryptionKey != "" { + return _cli.ChangeEncryptionKey(config.GlobalCfg.EncryptionKey, config.GlobalCfg.NewEncryptionKey) + } + + if config.GlobalCfg.Server.Cert != "" && config.GlobalCfg.Server.Key != "" { + return app.Server.StartTLS(config.GlobalCfg.Server.Addr, config.GlobalCfg.Server.Cert, config.GlobalCfg.Server.Key) + } else { + return app.Server.Start(config.GlobalCfg.Server.Addr) + } +} diff --git a/server/app/cache.go b/server/app/cache.go new file mode 100644 index 0000000..e43b7e2 --- /dev/null +++ b/server/app/cache.go @@ -0,0 +1,10 @@ +package app + +import ( + "next-terminal/server/global/cache" + "next-terminal/server/service" +) + +func setupCache() { + cache.TokenManager.OnEvicted(service.UserService.OnEvicted) +} diff --git a/server/api/middleware.go b/server/app/middleware.go similarity index 55% rename from server/api/middleware.go rename to server/app/middleware.go index a80b168..2548784 100644 --- a/server/api/middleware.go +++ b/server/app/middleware.go @@ -1,12 +1,13 @@ -package api +package app import ( "fmt" "net" "strings" - "time" + "next-terminal/server/api" "next-terminal/server/constant" + "next-terminal/server/dto" "next-terminal/server/global/cache" "next-terminal/server/global/security" "next-terminal/server/utils" @@ -21,10 +22,10 @@ func ErrorHandler(next echo.HandlerFunc) echo.HandlerFunc { if he, ok := err.(*echo.HTTPError); ok { message := fmt.Sprintf("%v", he.Message) - return Fail(c, he.Code, message) + return api.Fail(c, he.Code, message) } - return Fail(c, 0, err.Error()) + return api.Fail(c, 0, err.Error()) } return nil } @@ -74,7 +75,7 @@ func TcpWall(next echo.HandlerFunc) echo.HandlerFunc { } if s.Rule == constant.AccessRuleReject { if c.Request().Header.Get("X-Requested-With") != "" || c.Request().Header.Get(constant.Token) != "" { - return Fail(c, 0, "您的访问请求被拒绝 :(") + return api.Fail(c, 0, "您的访问请求被拒绝 :(") } else { return c.HTML(666, "您的访问请求被拒绝 :(") } @@ -85,9 +86,9 @@ func TcpWall(next echo.HandlerFunc) echo.HandlerFunc { } } -func Auth(next echo.HandlerFunc) echo.HandlerFunc { +var anonymousUrls = []string{"/login", "/static", "/favicon.ico", "/logo.svg", "/asciinema"} - anonymousUrls := []string{"/login", "/static", "/favicon.ico", "/logo.svg", "/asciinema"} +func Auth(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -102,21 +103,46 @@ func Auth(next echo.HandlerFunc) echo.HandlerFunc { } } - token := GetToken(c) + token := api.GetToken(c) if token == "" { - return Fail(c, 401, "您的登录信息已失效,请重新登录后再试。") - } - cacheKey := userService.BuildCacheKeyByToken(token) - authorization, found := cache.GlobalCache.Get(cacheKey) - if !found { - return Fail(c, 401, "您的登录信息已失效,请重新登录后再试。") + return api.Fail(c, 401, "您的登录信息已失效,请重新登录后再试。") } - if authorization.(Authorization).Remember { - // 记住登录有效期两周 - cache.GlobalCache.Set(cacheKey, authorization, time.Hour*time.Duration(24*14)) - } else { - cache.GlobalCache.Set(cacheKey, authorization, time.Hour*time.Duration(2)) + v, found := cache.TokenManager.Get(token) + if !found { + return api.Fail(c, 401, "您的登录信息已失效,请重新登录后再试。") + } + + authorization := v.(dto.Authorization) + + if strings.EqualFold(constant.LoginToken, authorization.Type) { + if authorization.Remember { + // 记住登录有效期两周 + cache.TokenManager.Set(token, authorization, cache.RememberMeExpiration) + } else { + cache.TokenManager.Set(token, authorization, cache.NotRememberExpiration) + } + } else if strings.EqualFold(constant.ShareSession, authorization.Type) { + id := c.Param("id") + uri = strings.Split(uri, "?")[0] + allowUrls := []string{ + "/share-sessions/" + id, + "/sessions", + "/sessions/" + id + "/tunnel", + "/sessions/" + id + "/connect", + "/sessions/" + id + "/resize", + "/sessions/" + id + "/stats", + "/sessions/" + id + "/ls", + "/sessions/" + id + "/download", + "/sessions/" + id + "/upload", + "/sessions/" + id + "/edit", + "/sessions/" + id + "/mkdir", + "/sessions/" + id + "/rm", + "/sessions/" + id + "/rename", + } + if !utils.Contains(allowUrls, uri) { + return api.Fail(c, 401, "您的登录信息已失效,请重新登录后再试。") + } } return next(c) @@ -126,13 +152,13 @@ func Auth(next echo.HandlerFunc) echo.HandlerFunc { func Admin(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - account, found := GetCurrentAccount(c) + account, found := api.GetCurrentAccount(c) if !found { - return Fail(c, 401, "您的登录信息已失效,请重新登录后再试。") + return api.Fail(c, 401, "您的登录信息已失效,请重新登录后再试。") } if account.Type != constant.TypeAdmin { - return Fail(c, 403, "permission denied") + return api.Fail(c, 403, "permission denied") } return next(c) diff --git a/server/app/server.go b/server/app/server.go new file mode 100644 index 0000000..9902c98 --- /dev/null +++ b/server/app/server.go @@ -0,0 +1,254 @@ +package app + +import ( + "net/http" + + "next-terminal/server/api" + + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" +) + +func setupRoutes() *echo.Echo { + + e := echo.New() + e.HideBanner = true + //e.Logger = log.GetEchoLogger() + //e.Use(log.Hook()) + e.File("/", "web/build/index.html") + e.File("/asciinema.html", "web/build/asciinema.html") + e.File("/", "web/build/index.html") + e.File("/favicon.ico", "web/build/favicon.ico") + e.File("/logo.png", "web/build/logo.png") + e.Static("/static", "web/build/static") + + e.Use(middleware.Recover()) + e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ + Skipper: middleware.DefaultSkipper, + AllowOrigins: []string{"*"}, + AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, + })) + e.Use(ErrorHandler) + e.Use(TcpWall) + e.Use(Auth) + + accountApi := new(api.AccountApi) + guacamoleApi := new(api.GuacamoleApi) + webTerminalApi := new(api.WebTerminalApi) + UserApi := new(api.UserApi) + UserGroupApi := new(api.UserGroupApi) + AssetApi := new(api.AssetApi) + CommandApi := new(api.CommandApi) + CredentialApi := new(api.CredentialApi) + SessionApi := new(api.SessionApi) + ResourceSharerApi := new(api.ResourceSharerApi) + LoginLogApi := new(api.LoginLogApi) + PropertyApi := new(api.PropertyApi) + OverviewApi := new(api.OverviewApi) + JobApi := new(api.JobApi) + SecurityApi := new(api.SecurityApi) + StorageApi := new(api.StorageApi) + StrategyApi := new(api.StrategyApi) + AccessGatewayApi := new(api.AccessGatewayApi) + BackupApi := new(api.BackupApi) + + e.POST("/login", accountApi.LoginEndpoint) + e.POST("/loginWithTotp", accountApi.LoginWithTotpEndpoint) + + e.GET("/ssh", webTerminalApi.SshEndpoint) + e.GET("/ssh-monitor", webTerminalApi.SshMonitorEndpoint) + + account := e.Group("/account") + { + account.GET("/info", accountApi.InfoEndpoint) + account.GET("/assets", accountApi.AccountAssetEndpoint) + account.GET("/storage", accountApi.AccountStorageEndpoint) + account.POST("/logout", accountApi.LogoutEndpoint) + account.POST("/change-password", accountApi.ChangePasswordEndpoint) + account.GET("/reload-totp", accountApi.ReloadTOTPEndpoint) + account.POST("/reset-totp", accountApi.ResetTOTPEndpoint) + account.POST("/confirm-totp", accountApi.ConfirmTOTPEndpoint) + account.GET("/access-token", accountApi.AccessTokenGetEndpoint) + account.POST("/access-token", accountApi.AccessTokenGenEndpoint) + } + + users := e.Group("/users", Admin) + { + users.POST("", UserApi.UserCreateEndpoint) + users.GET("/paging", UserApi.UserPagingEndpoint) + users.PUT("/:id", UserApi.UserUpdateEndpoint) + users.PATCH("/:id/status", UserApi.UserUpdateStatusEndpoint) + users.DELETE("/:id", UserApi.UserDeleteEndpoint) + users.GET("/:id", UserApi.UserGetEndpoint) + users.POST("/:id/change-password", UserApi.UserChangePasswordEndpoint) + users.POST("/:id/reset-totp", UserApi.UserResetTotpEndpoint) + } + + userGroups := e.Group("/user-groups", Admin) + { + userGroups.POST("", UserGroupApi.UserGroupCreateEndpoint) + userGroups.GET("/paging", UserGroupApi.UserGroupPagingEndpoint) + userGroups.PUT("/:id", UserGroupApi.UserGroupUpdateEndpoint) + userGroups.DELETE("/:id", UserGroupApi.UserGroupDeleteEndpoint) + userGroups.GET("/:id", UserGroupApi.UserGroupGetEndpoint) + } + + assets := e.Group("/assets", Admin) + { + assets.GET("", AssetApi.AssetAllEndpoint) + assets.POST("", AssetApi.AssetCreateEndpoint) + assets.POST("/import", AssetApi.AssetImportEndpoint) + assets.GET("/paging", AssetApi.AssetPagingEndpoint) + assets.POST("/:id/tcping", AssetApi.AssetTcpingEndpoint) + assets.PUT("/:id", AssetApi.AssetUpdateEndpoint) + assets.GET("/:id", AssetApi.AssetGetEndpoint) + assets.DELETE("/:id", AssetApi.AssetDeleteEndpoint) + assets.POST("/:id/change-owner", AssetApi.AssetChangeOwnerEndpoint) + } + + e.GET("/tags", AssetApi.AssetTagsEndpoint) + + commands := e.Group("/commands") + { + commands.GET("", CommandApi.CommandAllEndpoint) + commands.GET("/paging", CommandApi.CommandPagingEndpoint) + commands.POST("", CommandApi.CommandCreateEndpoint) + commands.PUT("/:id", CommandApi.CommandUpdateEndpoint) + commands.DELETE("/:id", CommandApi.CommandDeleteEndpoint) + commands.GET("/:id", CommandApi.CommandGetEndpoint) + commands.POST("/:id/change-owner", CommandApi.CommandChangeOwnerEndpoint, Admin) + } + + credentials := e.Group("/credentials", Admin) + { + credentials.GET("", CredentialApi.CredentialAllEndpoint) + credentials.GET("/paging", CredentialApi.CredentialPagingEndpoint) + credentials.POST("", CredentialApi.CredentialCreateEndpoint) + credentials.PUT("/:id", CredentialApi.CredentialUpdateEndpoint) + credentials.DELETE("/:id", CredentialApi.CredentialDeleteEndpoint) + credentials.GET("/:id", CredentialApi.CredentialGetEndpoint) + credentials.POST("/:id/change-owner", CredentialApi.CredentialChangeOwnerEndpoint) + } + + sessions := e.Group("/sessions") + { + sessions.GET("/paging", Admin(SessionApi.SessionPagingEndpoint)) + sessions.POST("/:id/disconnect", Admin(SessionApi.SessionDisconnectEndpoint)) + sessions.DELETE("/:id", Admin(SessionApi.SessionDeleteEndpoint)) + sessions.GET("/:id/recording", Admin(SessionApi.SessionRecordingEndpoint)) + sessions.GET("/:id", Admin(SessionApi.SessionGetEndpoint)) + sessions.POST("/:id/reviewed", Admin(SessionApi.SessionReviewedEndpoint)) + sessions.POST("/:id/unreviewed", Admin(SessionApi.SessionUnViewedEndpoint)) + sessions.POST("/clear", Admin(SessionApi.SessionClearEndpoint)) + sessions.POST("/reviewed", Admin(SessionApi.SessionReviewedAllEndpoint)) + + sessions.POST("", SessionApi.SessionCreateEndpoint) + sessions.POST("/:id/connect", SessionApi.SessionConnectEndpoint) + sessions.GET("/:id/tunnel", guacamoleApi.Guacamole) + sessions.POST("/:id/resize", SessionApi.SessionResizeEndpoint) + sessions.GET("/:id/stats", SessionApi.SessionStatsEndpoint) + + sessions.POST("/:id/ls", SessionApi.SessionLsEndpoint) + sessions.GET("/:id/download", SessionApi.SessionDownloadEndpoint) + sessions.POST("/:id/upload", SessionApi.SessionUploadEndpoint) + sessions.POST("/:id/edit", SessionApi.SessionEditEndpoint) + sessions.POST("/:id/mkdir", SessionApi.SessionMkDirEndpoint) + sessions.POST("/:id/rm", SessionApi.SessionRmEndpoint) + sessions.POST("/:id/rename", SessionApi.SessionRenameEndpoint) + } + + resourceSharers := e.Group("/resource-sharers", Admin) + { + resourceSharers.GET("", ResourceSharerApi.RSGetSharersEndPoint) + resourceSharers.POST("/remove-resources", ResourceSharerApi.ResourceRemoveByUserIdAssignEndPoint) + resourceSharers.POST("/add-resources", ResourceSharerApi.ResourceAddByUserIdAssignEndPoint) + } + + loginLogs := e.Group("login-logs", Admin) + { + loginLogs.GET("/paging", LoginLogApi.LoginLogPagingEndpoint) + loginLogs.DELETE("/:id", LoginLogApi.LoginLogDeleteEndpoint) + loginLogs.POST("/clear", LoginLogApi.LoginLogClearEndpoint) + } + + properties := e.Group("properties", Admin) + { + properties.GET("", PropertyApi.PropertyGetEndpoint) + properties.PUT("", PropertyApi.PropertyUpdateEndpoint) + } + + overview := e.Group("overview", Admin) + { + overview.GET("/counter", OverviewApi.OverviewCounterEndPoint) + overview.GET("/asset", OverviewApi.OverviewAssetEndPoint) + overview.GET("/access", OverviewApi.OverviewAccessEndPoint) + } + + jobs := e.Group("/jobs", Admin) + { + jobs.POST("", JobApi.JobCreateEndpoint) + jobs.GET("/paging", JobApi.JobPagingEndpoint) + jobs.PUT("/:id", JobApi.JobUpdateEndpoint) + jobs.POST("/:id/change-status", JobApi.JobChangeStatusEndpoint) + jobs.POST("/:id/exec", JobApi.JobExecEndpoint) + jobs.DELETE("/:id", JobApi.JobDeleteEndpoint) + jobs.GET("/:id", JobApi.JobGetEndpoint) + jobs.GET("/:id/logs", JobApi.JobGetLogsEndpoint) + jobs.DELETE("/:id/logs", JobApi.JobDeleteLogsEndpoint) + } + + securities := e.Group("/securities", Admin) + { + securities.POST("", SecurityApi.SecurityCreateEndpoint) + securities.GET("/paging", SecurityApi.SecurityPagingEndpoint) + securities.PUT("/:id", SecurityApi.SecurityUpdateEndpoint) + securities.DELETE("/:id", SecurityApi.SecurityDeleteEndpoint) + securities.GET("/:id", SecurityApi.SecurityGetEndpoint) + } + + storages := e.Group("/storages") + { + storages.GET("/paging", StorageApi.StoragePagingEndpoint, Admin) + storages.POST("", StorageApi.StorageCreateEndpoint, Admin) + storages.DELETE("/:id", StorageApi.StorageDeleteEndpoint, Admin) + storages.PUT("/:id", StorageApi.StorageUpdateEndpoint, Admin) + storages.GET("/shares", StorageApi.StorageSharesEndpoint, Admin) + storages.GET("/:id", StorageApi.StorageGetEndpoint, Admin) + + storages.POST("/:storageId/ls", StorageApi.StorageLsEndpoint) + storages.GET("/:storageId/download", StorageApi.StorageDownloadEndpoint) + storages.POST("/:storageId/upload", StorageApi.StorageUploadEndpoint) + storages.POST("/:storageId/mkdir", StorageApi.StorageMkDirEndpoint) + storages.POST("/:storageId/rm", StorageApi.StorageRmEndpoint) + storages.POST("/:storageId/rename", StorageApi.StorageRenameEndpoint) + storages.POST("/:storageId/edit", StorageApi.StorageEditEndpoint) + } + + strategies := e.Group("/strategies", Admin) + { + strategies.GET("", StrategyApi.StrategyAllEndpoint) + strategies.GET("/paging", StrategyApi.StrategyPagingEndpoint) + strategies.POST("", StrategyApi.StrategyCreateEndpoint) + strategies.DELETE("/:id", StrategyApi.StrategyDeleteEndpoint) + strategies.PUT("/:id", StrategyApi.StrategyUpdateEndpoint) + } + + accessGateways := e.Group("/access-gateways", Admin) + { + accessGateways.GET("", AccessGatewayApi.AccessGatewayAllEndpoint) + accessGateways.POST("", AccessGatewayApi.AccessGatewayCreateEndpoint) + accessGateways.GET("/paging", AccessGatewayApi.AccessGatewayPagingEndpoint) + accessGateways.PUT("/:id", AccessGatewayApi.AccessGatewayUpdateEndpoint) + accessGateways.DELETE("/:id", AccessGatewayApi.AccessGatewayDeleteEndpoint) + accessGateways.GET("/:id", AccessGatewayApi.AccessGatewayGetEndpoint) + accessGateways.POST("/:id/reconnect", AccessGatewayApi.AccessGatewayReconnectEndpoint) + } + + backup := e.Group("/backup", Admin) + { + backup.GET("/export", BackupApi.BackupExportEndpoint) + backup.POST("/import", BackupApi.BackupImportEndpoint) + } + + return e +} diff --git a/server/cli/cli.go b/server/cli/cli.go new file mode 100644 index 0000000..619764a --- /dev/null +++ b/server/cli/cli.go @@ -0,0 +1,104 @@ +package cli + +import ( + "context" + "crypto/md5" + "fmt" + "next-terminal/server/env" + "next-terminal/server/service" + + "next-terminal/server/log" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/utils" + + "gorm.io/gorm" +) + +type Cli struct { +} + +func NewCli() *Cli { + return &Cli{} +} + +func (cli Cli) ResetPassword(username string) error { + user, err := repository.UserRepository.FindByUsername(context.TODO(), username) + if err != nil { + return err + } + password := "next-terminal" + passwd, err := utils.Encoder.Encode([]byte(password)) + if err != nil { + return err + } + u := &model.User{ + Password: string(passwd), + ID: user.ID, + } + if err := repository.UserRepository.Update(context.TODO(), u); err != nil { + return err + } + log.Debugf("用户「%v」密码初始化为: %v", user.Username, password) + return nil +} + +func (cli Cli) ResetTotp(username string) error { + user, err := repository.UserRepository.FindByUsername(context.TODO(), username) + if err != nil { + return err + } + u := &model.User{ + TOTPSecret: "-", + ID: user.ID, + } + if err := repository.UserRepository.Update(context.TODO(), u); err != nil { + return err + } + log.Debugf("用户「%v」已重置TOTP", user.Username) + return nil +} + +func (cli Cli) ChangeEncryptionKey(oldEncryptionKey, newEncryptionKey string) error { + + oldPassword := []byte(fmt.Sprintf("%x", md5.Sum([]byte(oldEncryptionKey)))) + newPassword := []byte(fmt.Sprintf("%x", md5.Sum([]byte(newEncryptionKey)))) + + return env.GetDB().Transaction(func(tx *gorm.DB) error { + c := context.WithValue(context.TODO(), "db", tx) + credentials, err := repository.CredentialRepository.FindAll(c) + if err != nil { + return err + } + for i := range credentials { + credential := credentials[i] + if err := service.CredentialService.Decrypt(&credential, oldPassword); err != nil { + return err + } + if err := service.CredentialService.Encrypt(&credential, newPassword); err != nil { + return err + } + if err := repository.CredentialRepository.UpdateById(c, &credential, credential.ID); err != nil { + return err + } + } + assets, err := repository.AssetRepository.FindAll(c) + if err != nil { + return err + } + for i := range assets { + asset := assets[i] + if err := service.AssetService.Decrypt(&asset, oldPassword); err != nil { + return err + } + if err := service.AssetService.Encrypt(&asset, newPassword); err != nil { + return err + } + if err := repository.AssetRepository.UpdateById(c, &asset, asset.ID); err != nil { + return err + } + } + log.Infof("encryption key has being changed.") + return nil + }) +} diff --git a/server/constant/const.go b/server/constant/const.go index 179f90f..00281e3 100644 --- a/server/constant/const.go +++ b/server/constant/const.go @@ -5,21 +5,23 @@ import ( ) const ( - Version = "v1.2.2" + Version = "v1.2.3" Banner = ` - _______ __ ___________ .__ .__ - \ \ ____ ___ ____/ |_ \__ ___/__________ _____ |__| ____ _____ | | - / | \_/ __ \\ \/ /\ __\ | |_/ __ \_ __ \/ \| |/ \\__ \ | | -/ | \ ___/ > < | | | |\ ___/| | \/ Y Y \ | | \/ __ \| |__ -\____|__ /\___ >__/\_ \ |__| |____| \___ >__| |__|_| /__|___| (____ /____/ - \/ \/ \/ \/ \/ \/ \/ %s - -` + _______ __ ___________ .__ .__ + \ \ ____ ___ ____/ |_ \__ ___/__________ _____ |__| ____ _____ | | + / | \_/ __ \\ \/ /\ __\ | |_/ __ \_ __ \/ \| |/ \\__ \ | | + / | \ ___/ > < | | | |\ ___/| | \/ Y Y \ | | \/ __ \| |__ + \____|__ /\___ >__/\_ \ |__| |____| \___ >__| |__|_| /__|___| (____ /____/ + \/ \/ \/ \/ \/ \/ \/ %s + + ` ) const Token = "X-Auth-Token" const ( + DB = "db" + SSH = "ssh" RDP = "rdp" VNC = "vnc" @@ -57,6 +59,8 @@ const ( TypeUser = "user" // 普通用户 TypeAdmin = "admin" // 管理员 + SourceLdap = "ldap" // 从LDAP同步的用户 + StatusEnabled = "enabled" StatusDisabled = "disabled" @@ -65,10 +69,16 @@ const ( SocksProxyPort = "socks-proxy-port" SocksProxyUsername = "socks-proxy-username" SocksProxyPassword = "socks-proxy-password" + + LoginToken = "login-token" + AccessToken = "access-token" + ShareSession = "share-session" + + Anonymous = "anonymous" ) var SSHParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, SshMode, SocksProxyEnable, SocksProxyHost, SocksProxyPort, SocksProxyUsername, SocksProxyPassword} -var RDPParameterNames = []string{guacd.Domain, guacd.RemoteApp, guacd.RemoteAppDir, guacd.RemoteAppArgs, guacd.EnableDrive, guacd.DrivePath} +var RDPParameterNames = []string{guacd.Domain, guacd.RemoteApp, guacd.RemoteAppDir, guacd.RemoteAppArgs, guacd.EnableDrive, guacd.DrivePath, guacd.ColorDepth, guacd.ForceLossless} var VNCParameterNames = []string{guacd.ColorDepth, guacd.Cursor, guacd.SwapRedBlue, guacd.DestHost, guacd.DestPort} var TelnetParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, guacd.UsernameRegex, guacd.PasswordRegex, guacd.LoginSuccessRegex, guacd.LoginFailureRegex} var KubernetesParameterNames = []string{guacd.FontName, guacd.FontSize, guacd.ColorScheme, guacd.Backspace, guacd.TerminalType, guacd.Namespace, guacd.Pod, guacd.Container, guacd.UesSSL, guacd.ClientCert, guacd.ClientKey, guacd.CaCert, guacd.IgnoreCert} diff --git a/server/constant/errors.go b/server/constant/errors.go new file mode 100644 index 0000000..0f9bf84 --- /dev/null +++ b/server/constant/errors.go @@ -0,0 +1,7 @@ +package constant + +import "errors" + +var ( + ErrNameAlreadyUsed = errors.New("name already used") +) diff --git a/server/dto/auth.go b/server/dto/auth.go new file mode 100644 index 0000000..503717f --- /dev/null +++ b/server/dto/auth.go @@ -0,0 +1,27 @@ +package dto + +import "next-terminal/server/model" + +type Authorization struct { + Token string + Remember bool + Type string // LoginToken: 登录令牌, AccessToken: 授权令牌, ShareSession: 会话分享, AccessSession: 只允许访问特定的会话 + User *model.User +} + +type LoginAccount struct { + Username string `json:"username"` + Password string `json:"password"` + Remember bool `json:"remember"` + TOTP string `json:"totp"` +} + +type ConfirmTOTP struct { + Secret string `json:"secret"` + TOTP string `json:"totp"` +} + +type ChangePassword struct { + NewPassword string `json:"newPassword"` + OldPassword string `json:"oldPassword"` +} diff --git a/server/dto/dashboard.go b/server/dto/dashboard.go new file mode 100644 index 0000000..c3c7371 --- /dev/null +++ b/server/dto/dashboard.go @@ -0,0 +1,8 @@ +package dto + +type Counter struct { + User int64 `json:"user"` + Asset int64 `json:"asset"` + Credential int64 `json:"credential"` + OnlineSession int64 `json:"onlineSession"` +} diff --git a/server/dto/identity.go b/server/dto/identity.go new file mode 100644 index 0000000..9304d3d --- /dev/null +++ b/server/dto/identity.go @@ -0,0 +1,7 @@ +package dto + +type UserGroup struct { + Id string `json:"id"` + Name string `json:"name"` + Members []string `json:"members"` +} diff --git a/server/dto/resource.go b/server/dto/resource.go new file mode 100644 index 0000000..8bfffb2 --- /dev/null +++ b/server/dto/resource.go @@ -0,0 +1,32 @@ +package dto + +import "next-terminal/server/model" + +type RU struct { + UserGroupId string `json:"userGroupId"` + UserId string `json:"userId"` + StrategyId string `json:"strategyId"` + ResourceType string `json:"resourceType"` + ResourceIds []string `json:"resourceIds"` +} + +type UR struct { + ResourceId string `json:"resourceId"` + ResourceType string `json:"resourceType"` + UserIds []string `json:"userIds"` +} + +type Backup struct { + Users []model.User `json:"users"` + UserGroups []model.UserGroup `json:"user_groups"` + + Storages []model.Storage `json:"storages"` + Strategies []model.Strategy `json:"strategies"` + AccessSecurities []model.AccessSecurity `json:"access_securities"` + AccessGateways []model.AccessGateway `json:"access_gateways"` + Commands []model.Command `json:"commands"` + Credentials []model.Credential `json:"credentials"` + Assets []map[string]interface{} `json:"assets"` + ResourceSharers []model.ResourceSharer `json:"resource_sharers"` + Jobs []model.Job `json:"jobs"` +} diff --git a/server/dto/session.go b/server/dto/session.go new file mode 100644 index 0000000..b91d526 --- /dev/null +++ b/server/dto/session.go @@ -0,0 +1,11 @@ +package dto + +type ExternalSession struct { + AssetId string `json:"assetId"` + FileSystem string `json:"fileSystem"` + Upload string `json:"upload"` + Download string `json:"download"` + Delete string `json:"delete"` + Rename string `json:"rename"` + Edit string `json:"edit"` +} diff --git a/server/dto/ssh.go b/server/dto/ssh.go new file mode 100644 index 0000000..16f8364 --- /dev/null +++ b/server/dto/ssh.go @@ -0,0 +1,39 @@ +package dto + +import "strconv" + +type Message struct { + Type int `json:"type"` + Content string `json:"content"` +} + +func (r Message) ToString() string { + if r.Content != "" { + return strconv.Itoa(r.Type) + r.Content + } else { + return strconv.Itoa(r.Type) + } +} + +func NewMessage(_type int, content string) Message { + return Message{Content: content, Type: _type} +} + +func ParseMessage(value string) (message Message, err error) { + if value == "" { + return + } + + _type, err := strconv.Atoi(value[:1]) + if err != nil { + return + } + var content = value[1:] + message = NewMessage(_type, content) + return +} + +type WindowSize struct { + Cols int `json:"cols"` + Rows int `json:"rows"` +} diff --git a/server/env/db.go b/server/env/db.go new file mode 100644 index 0000000..25784d4 --- /dev/null +++ b/server/env/db.go @@ -0,0 +1,55 @@ +package env + +import ( + "fmt" + + "next-terminal/server/config" + "next-terminal/server/model" + + "gorm.io/driver/mysql" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +func setupDB() *gorm.DB { + + var logMode logger.Interface + if config.GlobalCfg.Debug { + logMode = logger.Default.LogMode(logger.Info) + } else { + logMode = logger.Default.LogMode(logger.Silent) + } + + fmt.Printf("当前数据库模式为:%v\n", config.GlobalCfg.DB) + var err error + var db *gorm.DB + if config.GlobalCfg.DB == "mysql" { + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=60s", + config.GlobalCfg.Mysql.Username, + config.GlobalCfg.Mysql.Password, + config.GlobalCfg.Mysql.Hostname, + config.GlobalCfg.Mysql.Port, + config.GlobalCfg.Mysql.Database, + ) + db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ + Logger: logMode, + }) + } else { + db, err = gorm.Open(sqlite.Open(config.GlobalCfg.Sqlite.File), &gorm.Config{ + Logger: logMode, + }) + } + + if err != nil { + panic(fmt.Errorf("连接数据库异常: %v", err.Error())) + } + + if err := db.AutoMigrate(&model.User{}, &model.Asset{}, &model.AssetAttribute{}, &model.Session{}, &model.Command{}, + &model.Credential{}, &model.Property{}, &model.ResourceSharer{}, &model.UserGroup{}, &model.UserGroupMember{}, + &model.LoginLog{}, &model.Job{}, &model.JobLog{}, &model.AccessSecurity{}, &model.AccessGateway{}, + &model.Storage{}, &model.Strategy{}, &model.AccessToken{}); err != nil { + panic(fmt.Errorf("初始化数据库表结构异常: %v", err.Error())) + } + return db +} diff --git a/server/env/env.go b/server/env/env.go new file mode 100644 index 0000000..e736325 --- /dev/null +++ b/server/env/env.go @@ -0,0 +1,19 @@ +package env + +import "gorm.io/gorm" + +var env *Env + +type Env struct { + db *gorm.DB +} + +func init() { + env = &Env{ + db: setupDB(), + } +} + +func GetDB() *gorm.DB { + return env.db +} diff --git a/server/global/cache/cache.go b/server/global/cache/cache.go index 657a6fb..ef34119 100644 --- a/server/global/cache/cache.go +++ b/server/global/cache/cache.go @@ -6,8 +6,17 @@ import ( "github.com/patrickmn/go-cache" ) -var GlobalCache *cache.Cache +const ( + NoExpiration = -1 + RememberMeExpiration = time.Hour * time.Duration(24*14) + NotRememberExpiration = time.Hour * time.Duration(2) + LoginLockExpiration = time.Minute * time.Duration(5) +) + +var TokenManager *cache.Cache +var LoginFailedKeyManager *cache.Cache func init() { - GlobalCache = cache.New(5*time.Minute, 10*time.Minute) + TokenManager = cache.New(5*time.Minute, 10*time.Minute) + LoginFailedKeyManager = cache.New(5*time.Minute, 10*time.Minute) } diff --git a/server/global/gateway/gateway.go b/server/global/gateway/gateway.go index a4be8ed..e63ae71 100644 --- a/server/global/gateway/gateway.go +++ b/server/global/gateway/gateway.go @@ -7,7 +7,6 @@ import ( "net" "os" - "next-terminal/server/config" "next-terminal/server/utils" "golang.org/x/crypto/ssh" @@ -17,7 +16,6 @@ import ( type Gateway struct { ID string // 接入网关ID Connected bool // 是否已连接 - LocalHost string // 隧道映射到本地的IP地址 SshClient *ssh.Client Message string // 失败原因 @@ -28,10 +26,9 @@ type Gateway struct { exit chan bool } -func NewGateway(id, localhost string, connected bool, message string, client *ssh.Client) *Gateway { +func NewGateway(id string, connected bool, message string, client *ssh.Client) *Gateway { return &Gateway{ ID: id, - LocalHost: localhost, Connected: connected, Message: message, SshClient: client, @@ -80,26 +77,13 @@ func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, expo if err != nil { return "", 0, err } - localHost := g.LocalHost - if localHost == "" { - if config.GlobalCfg.Container { - localIp, err := utils.GetLocalIp() - if err != nil { - hostname, err := os.Hostname() - if err != nil { - return "", 0, err - } else { - localHost = hostname - } - } else { - localHost = localIp - } - } else { - localHost = "localhost" - } + + hostname, err := os.Hostname() + if err != nil { + return "", 0, err } - localAddr := fmt.Sprintf("%s:%d", localHost, localPort) + localAddr := fmt.Sprintf("%s:%d", hostname, localPort) listener, err := net.Listen("tcp", localAddr) if err != nil { return "", 0, err @@ -108,7 +92,7 @@ func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, expo ctx, cancel := context.WithCancel(context.Background()) tunnel := &Tunnel{ ID: id, - LocalHost: g.LocalHost, + LocalHost: hostname, LocalPort: localPort, Gateway: g, RemoteHost: ip, diff --git a/server/guacd/guacd.go b/server/guacd/guacd.go index c56ceb6..416193b 100644 --- a/server/guacd/guacd.go +++ b/server/guacd/guacd.go @@ -10,8 +10,6 @@ import ( ) const ( - Host = "host" - Port = "port" EnableRecording = "enable-recording" RecordingPath = "recording-path" CreateRecordingPath = "create-recording-path" @@ -33,7 +31,9 @@ const ( EnableMenuAnimations = "enable-menu-animations" DisableBitmapCaching = "disable-bitmap-caching" DisableOffscreenCaching = "disable-offscreen-caching" - DisableGlyphCaching = "disable-glyph-caching" + // DisableGlyphCaching Deprecated + DisableGlyphCaching = "disable-glyph-caching" + ForceLossless = "force-lossless" Domain = "domain" RemoteApp = "remote-app" @@ -62,7 +62,7 @@ const ( ) const Delimiter = ';' -const Version = "VERSION_1_3_0" +const Version = "VERSION_1_4_0" type Configuration struct { ConnectionID string diff --git a/server/model/access_gateway.go b/server/model/access_gateway.go index b27c683..d760e38 100644 --- a/server/model/access_gateway.go +++ b/server/model/access_gateway.go @@ -8,7 +8,6 @@ type AccessGateway struct { Name string `gorm:"type:varchar(500)" json:"name"` IP string `gorm:"type:varchar(500)" json:"ip"` Port int `gorm:"type:int(5)" json:"port"` - Localhost string `gorm:"type:varchar(200)" json:"localhost"` // 隧道映射到本地的地址 AccountType string `gorm:"type:varchar(50)" json:"accountType"` Username string `gorm:"type:varchar(200)" json:"username"` Password string `gorm:"type:varchar(500)" json:"password"` diff --git a/server/model/access_token.go b/server/model/access_token.go new file mode 100644 index 0000000..722099d --- /dev/null +++ b/server/model/access_token.go @@ -0,0 +1,14 @@ +package model + +import "next-terminal/server/utils" + +type AccessToken struct { + ID string `gorm:"primary_key,type:varchar(36)" json:"id"` + UserId string `gorm:"index,type:varchar(200)" json:"userId"` + Token string `gorm:"index,type:varchar(128)" json:"token"` + Created utils.JsonTime `json:"created"` +} + +func (r *AccessToken) TableName() string { + return "access_token" +} diff --git a/server/model/login_log.go b/server/model/login_log.go index 11b50f0..ae5f61c 100644 --- a/server/model/login_log.go +++ b/server/model/login_log.go @@ -5,7 +5,7 @@ import ( ) type LoginLog struct { - ID string `gorm:"primary_key,type:varchar(36)" json:"id"` + ID string `gorm:"primary_key,type:varchar(128)" json:"id"` Username string `gorm:"index,type:varchar(200)" json:"username"` ClientIP string `gorm:"type:varchar(200)" json:"clientIp"` ClientUserAgent string `gorm:"type:varchar(500)" json:"clientUserAgent"` diff --git a/server/model/session.go b/server/model/session.go index be8349c..d06e493 100644 --- a/server/model/session.go +++ b/server/model/session.go @@ -26,12 +26,15 @@ type Session struct { ConnectedTime utils.JsonTime `json:"connectedTime"` DisconnectedTime utils.JsonTime `json:"disconnectedTime"` Mode string `gorm:"type:varchar(10)" json:"mode"` - Upload string `gorm:"type:varchar(1)" json:"upload"` // 1 = true, 0 = false + FileSystem string `gorm:"type:varchar(1)" json:"fileSystem"` // 1 = true, 0 = false + Upload string `gorm:"type:varchar(1)" json:"upload"` Download string `gorm:"type:varchar(1)" json:"download"` Delete string `gorm:"type:varchar(1)" json:"delete"` Rename string `gorm:"type:varchar(1)" json:"rename"` Edit string `gorm:"type:varchar(1)" json:"edit"` CreateDir string `gorm:"type:varchar(1)" json:"createDir"` + Copy string `gorm:"type:varchar(1)" json:"copy"` + Paste string `gorm:"type:varchar(1)" json:"paste"` StorageId string `gorm:"type:varchar(36)" json:"storageId"` AccessGatewayId string `gorm:"type:varchar(36)" json:"accessGatewayId"` Reviewed bool `gorm:"type:tinyint(1)" json:"reviewed"` diff --git a/server/model/strategy.go b/server/model/strategy.go index b7cb416..92f1c5e 100644 --- a/server/model/strategy.go +++ b/server/model/strategy.go @@ -11,6 +11,8 @@ type Strategy struct { Rename string `gorm:"type:varchar(1)" json:"rename"` Edit string `gorm:"type:varchar(1)" json:"edit"` CreateDir string `gorm:"type:varchar(1)" json:"createDir"` + Copy string `gorm:"type:varchar(1)" json:"copy"` + Paste string `gorm:"type:varchar(1)" json:"paste"` Created utils.JsonTime `json:"created"` } diff --git a/server/model/user.go b/server/model/user.go index ae1588f..649d03f 100644 --- a/server/model/user.go +++ b/server/model/user.go @@ -15,6 +15,7 @@ type User struct { Created utils.JsonTime `json:"created"` Type string `gorm:"type:varchar(20)" json:"type"` Mail string `gorm:"type:varchar(500)" json:"mail"` + Source string `gorm:"type:varchar(20)" json:"source"` } type UserForPage struct { @@ -27,6 +28,7 @@ type UserForPage struct { Status string `json:"status"` Created utils.JsonTime `json:"created"` Type string `json:"type"` + Source string `json:"source"` SharerAssetCount int64 `json:"sharerAssetCount"` } diff --git a/server/repository/access_gateway.go b/server/repository/access_gateway.go deleted file mode 100644 index 57c3871..0000000 --- a/server/repository/access_gateway.go +++ /dev/null @@ -1,80 +0,0 @@ -package repository - -import ( - "next-terminal/server/model" - - "gorm.io/gorm" -) - -type AccessGatewayRepository struct { - DB *gorm.DB -} - -func NewAccessGatewayRepository(db *gorm.DB) *AccessGatewayRepository { - accessGatewayRepository = &AccessGatewayRepository{DB: db} - return accessGatewayRepository -} - -func (r AccessGatewayRepository) Find(pageIndex, pageSize int, ip, name, order, field string) (o []model.AccessGatewayForPage, total int64, err error) { - t := model.AccessGateway{} - db := r.DB.Table(t.TableName()) - dbCounter := r.DB.Table(t.TableName()) - - if len(ip) > 0 { - db = db.Where("ip like ?", "%"+ip+"%") - dbCounter = dbCounter.Where("ip like ?", "%"+ip+"%") - } - - if len(name) > 0 { - db = db.Where("name like ?", "%"+name+"%") - dbCounter = dbCounter.Where("name like ?", "%"+name+"%") - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "descend" { - order = "desc" - } else { - order = "asc" - } - - if field == "ip" { - field = "ip" - } else if field == "name" { - field = "name" - } else { - field = "created" - } - - err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error - if o == nil { - o = make([]model.AccessGatewayForPage, 0) - } - return -} - -func (r AccessGatewayRepository) Create(o *model.AccessGateway) error { - return r.DB.Create(o).Error -} - -func (r AccessGatewayRepository) UpdateById(o *model.AccessGateway, id string) error { - o.ID = id - return r.DB.Updates(o).Error -} - -func (r AccessGatewayRepository) DeleteById(id string) error { - return r.DB.Where("id = ?", id).Delete(model.AccessGateway{}).Error -} - -func (r AccessGatewayRepository) FindById(id string) (o model.AccessGateway, err error) { - err = r.DB.Where("id = ?", id).First(&o).Error - return -} - -func (r AccessGatewayRepository) FindAll() (o []model.AccessGateway, err error) { - err = r.DB.Find(&o).Error - return -} diff --git a/server/repository/access_security.go b/server/repository/access_security.go deleted file mode 100644 index 1e40bec..0000000 --- a/server/repository/access_security.go +++ /dev/null @@ -1,80 +0,0 @@ -package repository - -import ( - "next-terminal/server/model" - - "gorm.io/gorm" -) - -type AccessSecurityRepository struct { - DB *gorm.DB -} - -func NewAccessSecurityRepository(db *gorm.DB) *AccessSecurityRepository { - accessSecurityRepository = &AccessSecurityRepository{DB: db} - return accessSecurityRepository -} - -func (r AccessSecurityRepository) FindAll() (o []model.AccessSecurity, err error) { - err = r.DB.Order("priority asc").Find(&o).Error - return -} - -func (r AccessSecurityRepository) Find(pageIndex, pageSize int, ip, rule, order, field string) (o []model.AccessSecurity, total int64, err error) { - t := model.AccessSecurity{} - db := r.DB.Table(t.TableName()) - dbCounter := r.DB.Table(t.TableName()) - - if len(ip) > 0 { - db = db.Where("ip like ?", "%"+ip+"%") - dbCounter = dbCounter.Where("ip like ?", "%"+ip+"%") - } - - if len(rule) > 0 { - db = db.Where("rule = ?", rule) - dbCounter = dbCounter.Where("rule = ?", rule) - } - - err = dbCounter.Count(&total).Error - if err != nil { - return nil, 0, err - } - - if order == "descend" { - order = "desc" - } else { - order = "asc" - } - - if field == "ip" { - field = "ip" - } else if field == "rule" { - field = "rule" - } else { - field = "priority" - } - - err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error - if o == nil { - o = make([]model.AccessSecurity, 0) - } - return -} - -func (r AccessSecurityRepository) Create(o *model.AccessSecurity) error { - return r.DB.Create(o).Error -} - -func (r AccessSecurityRepository) UpdateById(o *model.AccessSecurity, id string) error { - o.ID = id - return r.DB.Updates(o).Error -} - -func (r AccessSecurityRepository) DeleteById(id string) error { - return r.DB.Where("id = ?", id).Delete(model.AccessSecurity{}).Error -} - -func (r AccessSecurityRepository) FindById(id string) (o *model.AccessSecurity, err error) { - err = r.DB.Where("id = ?", id).First(&o).Error - return -} diff --git a/server/repository/access_token.go b/server/repository/access_token.go new file mode 100644 index 0000000..90d0e88 --- /dev/null +++ b/server/repository/access_token.go @@ -0,0 +1,29 @@ +package repository + +import ( + "context" + + "next-terminal/server/model" +) + +type accessTokenRepository struct { + baseRepository +} + +func (repo accessTokenRepository) FindByUserId(ctx context.Context, userId string) (o model.AccessToken, err error) { + err = repo.GetDB(ctx).Where("user_id = ?", userId).First(&o).Error + return +} + +func (repo accessTokenRepository) DeleteByUserId(ctx context.Context, userId string) error { + return repo.GetDB(ctx).Where("user_id = ?", userId).Delete(&model.AccessToken{}).Error +} + +func (repo accessTokenRepository) Create(ctx context.Context, o *model.AccessToken) error { + return repo.GetDB(ctx).Create(o).Error +} + +func (repo accessTokenRepository) FindAll(ctx context.Context) (o []model.AccessToken, err error) { + err = repo.GetDB(ctx).Find(&o).Error + return +} diff --git a/server/repository/asset.go b/server/repository/asset.go index 0ae018f..c77d64b 100644 --- a/server/repository/asset.go +++ b/server/repository/asset.go @@ -1,7 +1,7 @@ package repository import ( - "encoding/base64" + "context" "fmt" "strings" @@ -11,47 +11,41 @@ import ( "next-terminal/server/utils" "github.com/labstack/echo/v4" - "gorm.io/gorm" ) -type AssetRepository struct { - DB *gorm.DB +type assetRepository struct { + baseRepository } -func NewAssetRepository(db *gorm.DB) *AssetRepository { - assetRepository = &AssetRepository{DB: db} - return assetRepository -} - -func (r AssetRepository) FindAll() (o []model.Asset, err error) { - err = r.DB.Find(&o).Error +func (r assetRepository) FindAll(c context.Context) (o []model.Asset, err error) { + err = r.GetDB(c).Find(&o).Error return } -func (r AssetRepository) FindByIds(assetIds []string) (o []model.Asset, err error) { - err = r.DB.Where("id in ?", assetIds).Find(&o).Error +func (r assetRepository) FindByIds(c context.Context, assetIds []string) (o []model.Asset, err error) { + err = r.GetDB(c).Where("id in ?", assetIds).Find(&o).Error return } -func (r AssetRepository) FindByProtocol(protocol string) (o []model.Asset, err error) { - err = r.DB.Where("protocol = ?", protocol).Find(&o).Error +func (r assetRepository) FindByProtocol(c context.Context, protocol string) (o []model.Asset, err error) { + err = r.GetDB(c).Where("protocol = ?", protocol).Find(&o).Error return } -func (r AssetRepository) FindByProtocolAndIds(protocol string, assetIds []string) (o []model.Asset, err error) { - err = r.DB.Where("protocol = ? and id in ?", protocol, assetIds).Find(&o).Error +func (r assetRepository) FindByProtocolAndIds(c context.Context, protocol string, assetIds []string) (o []model.Asset, err error) { + err = r.GetDB(c).Where("protocol = ? and id in ?", protocol, assetIds).Find(&o).Error return } -func (r AssetRepository) FindByProtocolAndUser(protocol string, account model.User) (o []model.Asset, err error) { - db := r.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags,assets.description, users.nickname as owner_name").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") +func (r assetRepository) FindByProtocolAndUser(c context.Context, protocol string, account model.User) (o []model.Asset, err error) { + db := r.GetDB(c).Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags,assets.description, users.nickname as owner_name").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") if constant.TypeUser == account.Type { owner := account.ID db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner) // 查询用户所在用户组列表 - userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(account.ID) + userGroupIds, err := UserGroupMemberRepository.FindUserGroupIdsByUserId(c, account.ID) if err != nil { return nil, err } @@ -68,9 +62,9 @@ func (r AssetRepository) FindByProtocolAndUser(protocol string, account model.Us return } -func (r AssetRepository) Find(pageIndex, pageSize int, name, protocol, tags string, account model.User, owner, sharer, userGroupId, ip, order, field string) (o []model.AssetForPage, total int64, err error) { - db := r.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags,assets.description, users.nickname as owner_name").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") - dbCounter := r.DB.Table("assets").Select("DISTINCT assets.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") +func (r assetRepository) Find(c context.Context, pageIndex, pageSize int, name, protocol, tags string, account *model.User, owner, sharer, userGroupId, ip, order, field string) (o []model.AssetForPage, total int64, err error) { + db := r.GetDB(c).Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags,assets.description, users.nickname as owner_name").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") + dbCounter := r.GetDB(c).Table("assets").Select("DISTINCT assets.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id") if constant.TypeUser == account.Type { owner := account.ID @@ -78,7 +72,7 @@ func (r AssetRepository) Find(pageIndex, pageSize int, name, protocol, tags stri dbCounter = dbCounter.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner) // 查询用户所在用户组列表 - userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(account.ID) + userGroupIds, err := UserGroupMemberRepository.FindUserGroupIdsByUserId(c, account.ID) if err != nil { return nil, 0, err } @@ -155,7 +149,7 @@ func (r AssetRepository) Find(pageIndex, pageSize int, name, protocol, tags stri } else { for i := 0; i < len(o); i++ { if o[i].Protocol == "ssh" { - attributes, err := r.FindAttrById(o[i].ID) + attributes, err := r.FindAttrById(c, o[i].ID) if err != nil { continue } @@ -172,134 +166,50 @@ func (r AssetRepository) Find(pageIndex, pageSize int, name, protocol, tags stri return } -func (r AssetRepository) Encrypt(item *model.Asset, password []byte) error { - if item.Password != "" && item.Password != "-" { - encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Password), password) - if err != nil { - return err - } - item.Password = base64.StdEncoding.EncodeToString(encryptedCBC) - } - if item.PrivateKey != "" && item.PrivateKey != "-" { - encryptedCBC, err := utils.AesEncryptCBC([]byte(item.PrivateKey), password) - if err != nil { - return err - } - item.PrivateKey = base64.StdEncoding.EncodeToString(encryptedCBC) - } - if item.Passphrase != "" && item.Passphrase != "-" { - encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Passphrase), password) - if err != nil { - return err - } - item.Passphrase = base64.StdEncoding.EncodeToString(encryptedCBC) - } - item.Encrypted = true - return nil +func (r assetRepository) Create(c context.Context, o *model.Asset) (err error) { + return r.GetDB(c).Create(o).Error } -func (r AssetRepository) Create(o *model.Asset) (err error) { - if err := r.Encrypt(o, config.GlobalCfg.EncryptionPassword); err != nil { - return err - } - if err = r.DB.Create(o).Error; err != nil { - return err - } - return nil -} - -func (r AssetRepository) FindById(id string) (o model.Asset, err error) { - err = r.DB.Where("id = ?", id).First(&o).Error +func (r assetRepository) FindById(c context.Context, id string) (o model.Asset, err error) { + err = r.GetDB(c).Where("id = ?", id).First(&o).Error return } -func (r AssetRepository) Decrypt(item *model.Asset, password []byte) error { - if item.Encrypted { - if item.Password != "" && item.Password != "-" { - origData, err := base64.StdEncoding.DecodeString(item.Password) - if err != nil { - return err - } - decryptedCBC, err := utils.AesDecryptCBC(origData, password) - if err != nil { - return err - } - item.Password = string(decryptedCBC) - } - if item.PrivateKey != "" && item.PrivateKey != "-" { - origData, err := base64.StdEncoding.DecodeString(item.PrivateKey) - if err != nil { - return err - } - decryptedCBC, err := utils.AesDecryptCBC(origData, password) - if err != nil { - return err - } - item.PrivateKey = string(decryptedCBC) - } - if item.Passphrase != "" && item.Passphrase != "-" { - origData, err := base64.StdEncoding.DecodeString(item.Passphrase) - if err != nil { - return err - } - decryptedCBC, err := utils.AesDecryptCBC(origData, password) - if err != nil { - return err - } - item.Passphrase = string(decryptedCBC) - } - } - return nil -} - -func (r AssetRepository) FindByIdAndDecrypt(id string) (o model.Asset, err error) { - err = r.DB.Where("id = ?", id).First(&o).Error - if err == nil { - err = r.Decrypt(&o, config.GlobalCfg.EncryptionPassword) - } - return -} - -func (r AssetRepository) UpdateById(o *model.Asset, id string) error { +func (r assetRepository) UpdateById(c context.Context, o *model.Asset, id string) error { o.ID = id - return r.DB.Updates(o).Error + return r.GetDB(c).Updates(o).Error } -func (r AssetRepository) UpdateActiveById(active bool, id string) error { +func (r assetRepository) UpdateActiveById(c context.Context, active bool, id string) error { sql := "update assets set active = ? where id = ?" - return r.DB.Exec(sql, active, id).Error + return r.GetDB(c).Exec(sql, active, id).Error } -func (r AssetRepository) DeleteById(id string) (err error) { - return r.DB.Transaction(func(tx *gorm.DB) error { - err = tx.Where("id = ?", id).Delete(&model.Asset{}).Error - if err != nil { - return err - } - // 删除资产属性 - err = tx.Where("asset_id = ?", id).Delete(&model.AssetAttribute{}).Error - return err - }) - +func (r assetRepository) DeleteById(c context.Context, assetId string) (err error) { + return r.GetDB(c).Where("id = ?", assetId).Delete(&model.Asset{}).Error } -func (r AssetRepository) Count() (total int64, err error) { - err = r.DB.Find(&model.Asset{}).Count(&total).Error +func (r assetRepository) DeleteAttrByAssetId(c context.Context, assetId string) error { + return r.GetDB(c).Where("asset_id = ?", assetId).Delete(&model.AssetAttribute{}).Error +} + +func (r assetRepository) Count(c context.Context) (total int64, err error) { + err = r.GetDB(c).Find(&model.Asset{}).Count(&total).Error return } -func (r AssetRepository) CountByProtocol(protocol string) (total int64, err error) { - err = r.DB.Find(&model.Asset{}).Where("protocol = ?", protocol).Count(&total).Error +func (r assetRepository) CountByProtocol(c context.Context, protocol string) (total int64, err error) { + err = r.GetDB(c).Find(&model.Asset{}).Where("protocol = ?", protocol).Count(&total).Error return } -func (r AssetRepository) CountByUserId(userId string) (total int64, err error) { - db := r.DB.Joins("left join resource_sharers on assets.id = resource_sharers.resource_id") +func (r assetRepository) CountByUserId(c context.Context, userId string) (total int64, err error) { + db := r.GetDB(c).Joins("left join resource_sharers on assets.id = resource_sharers.resource_id") db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", userId, userId) // 查询用户所在用户组列表 - userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId) + userGroupIds, err := UserGroupMemberRepository.FindUserGroupIdsByUserId(c, userId) if err != nil { return 0, err } @@ -311,13 +221,13 @@ func (r AssetRepository) CountByUserId(userId string) (total int64, err error) { return } -func (r AssetRepository) CountByUserIdAndProtocol(userId, protocol string) (total int64, err error) { - db := r.DB.Joins("left join resource_sharers on assets.id = resource_sharers.resource_id") +func (r assetRepository) CountByUserIdAndProtocol(c context.Context, userId, protocol string) (total int64, err error) { + db := r.GetDB(c).Joins("left join resource_sharers on assets.id = resource_sharers.resource_id") db = db.Where("( assets.owner = ? or resource_sharers.user_id = ? ) and assets.protocol = ?", userId, userId, protocol) // 查询用户所在用户组列表 - userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId) + userGroupIds, err := UserGroupMemberRepository.FindUserGroupIdsByUserId(c, userId) if err != nil { return 0, err } @@ -329,9 +239,9 @@ func (r AssetRepository) CountByUserIdAndProtocol(userId, protocol string) (tota return } -func (r AssetRepository) FindTags() (o []string, err error) { +func (r assetRepository) FindTags(c context.Context) (o []string, err error) { var assets []model.Asset - err = r.DB.Not("tags = '' or tags = '-' ").Find(&assets).Error + err = r.GetDB(c).Not("tags = '' or tags = '-' ").Find(&assets).Error if err != nil { return nil, err } @@ -350,7 +260,7 @@ func (r AssetRepository) FindTags() (o []string, err error) { return utils.Distinct(o), nil } -func (r AssetRepository) UpdateAttributes(assetId, protocol string, m echo.Map) error { +func (r assetRepository) UpdateAttributes(c context.Context, assetId, protocol string, m echo.Map) error { var data []model.AssetAttribute var parameterNames []string switch protocol { @@ -373,13 +283,11 @@ func (r AssetRepository) UpdateAttributes(assetId, protocol string, m echo.Map) } } - return r.DB.Transaction(func(tx *gorm.DB) error { - err := tx.Where("asset_id = ?", assetId).Delete(&model.AssetAttribute{}).Error - if err != nil { - return err - } - return tx.CreateInBatches(&data, len(data)).Error - }) + err := r.GetDB(c).Where("asset_id = ?", assetId).Delete(&model.AssetAttribute{}).Error + if err != nil { + return err + } + return r.GetDB(c).CreateInBatches(&data, len(data)).Error } func genAttribute(assetId, name string, m echo.Map) model.AssetAttribute { @@ -393,20 +301,20 @@ func genAttribute(assetId, name string, m echo.Map) model.AssetAttribute { return attribute } -func (r AssetRepository) FindAttrById(assetId string) (o []model.AssetAttribute, err error) { - err = r.DB.Where("asset_id = ?", assetId).Find(&o).Error +func (r assetRepository) FindAttrById(c context.Context, assetId string) (o []model.AssetAttribute, err error) { + err = r.GetDB(c).Where("asset_id = ?", assetId).Find(&o).Error if o == nil { o = make([]model.AssetAttribute, 0) } return o, err } -func (r AssetRepository) FindAssetAttrMapByAssetId(assetId string) (map[string]string, error) { - asset, err := r.FindById(assetId) +func (r assetRepository) FindAssetAttrMapByAssetId(c context.Context, assetId string) (map[string]string, error) { + asset, err := r.FindById(c, assetId) if err != nil { return nil, err } - attributes, err := r.FindAttrById(assetId) + attributes, err := r.FindAttrById(c, assetId) if err != nil { return nil, err } @@ -424,7 +332,7 @@ func (r AssetRepository) FindAssetAttrMapByAssetId(assetId string) (map[string]s case "kubernetes": parameterNames = constant.KubernetesParameterNames } - propertiesMap := propertyRepository.FindAllMap() + propertiesMap := PropertyRepository.FindAllMap(c) var attributeMap = make(map[string]string) for name := range propertiesMap { if utils.Contains(parameterNames, name) { diff --git a/server/repository/base.go b/server/repository/base.go new file mode 100644 index 0000000..3251b3c --- /dev/null +++ b/server/repository/base.go @@ -0,0 +1,26 @@ +package repository + +import ( + "context" + + "next-terminal/server/constant" + "next-terminal/server/env" + + "gorm.io/gorm" +) + +type baseRepository struct { +} + +func (b *baseRepository) GetDB(c context.Context) *gorm.DB { + db := c.Value(constant.DB) + if db == nil { + return env.GetDB() + } + switch val := db.(type) { + case gorm.DB: + return &val + default: + return env.GetDB() + } +} diff --git a/server/repository/command.go b/server/repository/command.go index 7982c52..7b16a0e 100644 --- a/server/repository/command.go +++ b/server/repository/command.go @@ -1,24 +1,19 @@ package repository import ( + "context" + "next-terminal/server/constant" "next-terminal/server/model" - - "gorm.io/gorm" ) -type CommandRepository struct { - DB *gorm.DB +type commandRepository struct { + baseRepository } -func NewCommandRepository(db *gorm.DB) *CommandRepository { - commandRepository = &CommandRepository{DB: db} - return commandRepository -} - -func (r CommandRepository) Find(pageIndex, pageSize int, name, content, order, field string, account model.User) (o []model.CommandForPage, total int64, err error) { - db := r.DB.Table("commands").Select("commands.id,commands.name,commands.content,commands.owner,commands.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on commands.owner = users.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id") - dbCounter := r.DB.Table("commands").Select("DISTINCT commands.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id") +func (r commandRepository) Find(c context.Context, pageIndex, pageSize int, name, content, order, field string, account *model.User) (o []model.CommandForPage, total int64, err error) { + db := r.GetDB(c).Table("commands").Select("commands.id,commands.name,commands.content,commands.owner,commands.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on commands.owner = users.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id") + dbCounter := r.GetDB(c).Table("commands").Select("DISTINCT commands.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id") if constant.TypeUser == account.Type { owner := account.ID @@ -60,29 +55,29 @@ func (r CommandRepository) Find(pageIndex, pageSize int, name, content, order, f return } -func (r CommandRepository) Create(o *model.Command) (err error) { - if err = r.DB.Create(o).Error; err != nil { +func (r commandRepository) Create(c context.Context, o *model.Command) (err error) { + if err = r.GetDB(c).Create(o).Error; err != nil { return err } return nil } -func (r CommandRepository) FindById(id string) (o model.Command, err error) { - err = r.DB.Where("id = ?", id).First(&o).Error +func (r commandRepository) FindById(c context.Context, id string) (o model.Command, err error) { + err = r.GetDB(c).Where("id = ?", id).First(&o).Error return } -func (r CommandRepository) UpdateById(o *model.Command, id string) error { +func (r commandRepository) UpdateById(c context.Context, o *model.Command, id string) error { o.ID = id - return r.DB.Updates(o).Error + return r.GetDB(c).Updates(o).Error } -func (r CommandRepository) DeleteById(id string) error { - return r.DB.Where("id = ?", id).Delete(&model.Command{}).Error +func (r commandRepository) DeleteById(c context.Context, id string) error { + return r.GetDB(c).Where("id = ?", id).Delete(&model.Command{}).Error } -func (r CommandRepository) FindByUser(account model.User) (o []model.CommandForPage, err error) { - db := r.DB.Table("commands").Select("commands.id,commands.name,commands.content,commands.owner,commands.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on commands.owner = users.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id") +func (r commandRepository) FindByUser(c context.Context, account *model.User) (o []model.CommandForPage, err error) { + db := r.GetDB(c).Table("commands").Select("commands.id,commands.name,commands.content,commands.owner,commands.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on commands.owner = users.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id") if constant.TypeUser == account.Type { owner := account.ID @@ -95,7 +90,7 @@ func (r CommandRepository) FindByUser(account model.User) (o []model.CommandForP return } -func (r CommandRepository) FindAll() (o []model.Command, err error) { - err = r.DB.Find(&o).Error +func (r commandRepository) FindAll(c context.Context) (o []model.Command, err error) { + err = r.GetDB(c).Find(&o).Error return } diff --git a/server/repository/credential.go b/server/repository/credential.go index 8e3478b..935359a 100644 --- a/server/repository/credential.go +++ b/server/repository/credential.go @@ -1,37 +1,25 @@ package repository import ( - "encoding/base64" + "context" - "next-terminal/server/config" "next-terminal/server/constant" "next-terminal/server/model" - "next-terminal/server/utils" - - "gorm.io/gorm" ) -type CredentialRepository struct { - DB *gorm.DB +type credentialRepository struct { + baseRepository } -func NewCredentialRepository(db *gorm.DB) *CredentialRepository { - credentialRepository = &CredentialRepository{DB: db} - return credentialRepository -} - -func (r CredentialRepository) FindByUser(account model.User) (o []model.CredentialSimpleVo, err error) { - db := r.DB.Table("credentials").Select("DISTINCT credentials.id,credentials.name").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id") - if account.Type == constant.TypeUser { - db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", account.ID, account.ID) - } +func (r credentialRepository) FindByUser(c context.Context) (o []model.CredentialSimpleVo, err error) { + db := r.GetDB(c).Table("credentials").Select("DISTINCT credentials.id,credentials.name").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id") err = db.Find(&o).Error return } -func (r CredentialRepository) Find(pageIndex, pageSize int, name, order, field string, account model.User) (o []model.CredentialForPage, total int64, err error) { - db := r.DB.Table("credentials").Select("credentials.id,credentials.name,credentials.type,credentials.username,credentials.owner,credentials.created,users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on credentials.owner = users.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id") - dbCounter := r.DB.Table("credentials").Select("DISTINCT credentials.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id") +func (r credentialRepository) Find(c context.Context, pageIndex, pageSize int, name, order, field string, account *model.User) (o []model.CredentialForPage, total int64, err error) { + db := r.GetDB(c).Table("credentials").Select("credentials.id,credentials.name,credentials.type,credentials.username,credentials.owner,credentials.created,users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on credentials.owner = users.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id") + dbCounter := r.GetDB(c).Table("credentials").Select("DISTINCT credentials.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id") if constant.TypeUser == account.Type { owner := account.ID @@ -68,127 +56,48 @@ func (r CredentialRepository) Find(pageIndex, pageSize int, name, order, field s return } -func (r CredentialRepository) Create(o *model.Credential) (err error) { - if err := r.Encrypt(o, config.GlobalCfg.EncryptionPassword); err != nil { - return err - } - if err = r.DB.Create(o).Error; err != nil { - return err - } - return nil +func (r credentialRepository) Create(c context.Context, o *model.Credential) (err error) { + return r.GetDB(c).Create(o).Error } -func (r CredentialRepository) FindById(id string) (o model.Credential, err error) { - err = r.DB.Where("id = ?", id).First(&o).Error +func (r credentialRepository) FindById(c context.Context, id string) (o model.Credential, err error) { + err = r.GetDB(c).Where("id = ?", id).First(&o).Error return } -func (r CredentialRepository) Encrypt(item *model.Credential, password []byte) error { - if item.Password != "-" { - encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Password), password) - if err != nil { - return err - } - item.Password = base64.StdEncoding.EncodeToString(encryptedCBC) - } - if item.PrivateKey != "-" { - encryptedCBC, err := utils.AesEncryptCBC([]byte(item.PrivateKey), password) - if err != nil { - return err - } - item.PrivateKey = base64.StdEncoding.EncodeToString(encryptedCBC) - } - if item.Passphrase != "-" { - encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Passphrase), password) - if err != nil { - return err - } - item.Passphrase = base64.StdEncoding.EncodeToString(encryptedCBC) - } - item.Encrypted = true - return nil -} - -func (r CredentialRepository) Decrypt(item *model.Credential, password []byte) error { - if item.Encrypted { - if item.Password != "" && item.Password != "-" { - origData, err := base64.StdEncoding.DecodeString(item.Password) - if err != nil { - return err - } - decryptedCBC, err := utils.AesDecryptCBC(origData, password) - if err != nil { - return err - } - item.Password = string(decryptedCBC) - } - if item.PrivateKey != "" && item.PrivateKey != "-" { - origData, err := base64.StdEncoding.DecodeString(item.PrivateKey) - if err != nil { - return err - } - decryptedCBC, err := utils.AesDecryptCBC(origData, password) - if err != nil { - return err - } - item.PrivateKey = string(decryptedCBC) - } - if item.Passphrase != "" && item.Passphrase != "-" { - origData, err := base64.StdEncoding.DecodeString(item.Passphrase) - if err != nil { - return err - } - decryptedCBC, err := utils.AesDecryptCBC(origData, password) - if err != nil { - return err - } - item.Passphrase = string(decryptedCBC) - } - } - return nil -} - -func (r CredentialRepository) FindByIdAndDecrypt(id string) (o model.Credential, err error) { - err = r.DB.Where("id = ?", id).First(&o).Error - if err == nil { - err = r.Decrypt(&o, config.GlobalCfg.EncryptionPassword) - } - return -} - -func (r CredentialRepository) UpdateById(o *model.Credential, id string) error { +func (r credentialRepository) UpdateById(c context.Context, o *model.Credential, id string) error { o.ID = id - return r.DB.Updates(o).Error + return r.GetDB(c).Updates(o).Error } -func (r CredentialRepository) DeleteById(id string) error { - return r.DB.Where("id = ?", id).Delete(&model.Credential{}).Error +func (r credentialRepository) DeleteById(c context.Context, id string) error { + return r.GetDB(c).Where("id = ?", id).Delete(&model.Credential{}).Error } -func (r CredentialRepository) Count() (total int64, err error) { - err = r.DB.Find(&model.Credential{}).Count(&total).Error +func (r credentialRepository) Count(c context.Context) (total int64, err error) { + err = r.GetDB(c).Find(&model.Credential{}).Count(&total).Error return } -func (r CredentialRepository) CountByUserId(userId string) (total int64, err error) { - db := r.DB.Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id") +//func (r credentialRepository) CountByUserId(c context.Context, userId string) (total int64, err error) { +// db := r.GetDB(c).Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id") +// +// db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", userId, userId) +// +// // 查询用户所在用户组列表 +// userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(c, userId) +// if err != nil { +// return 0, err +// } +// +// if len(userGroupIds) > 0 { +// db = db.Or("resource_sharers.user_group_id in ?", userGroupIds) +// } +// err = db.Find(&model.Credential{}).Count(&total).Error +// return +//} - db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", userId, userId) - - // 查询用户所在用户组列表 - userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId) - if err != nil { - return 0, err - } - - if len(userGroupIds) > 0 { - db = db.Or("resource_sharers.user_group_id in ?", userGroupIds) - } - err = db.Find(&model.Credential{}).Count(&total).Error - return -} - -func (r CredentialRepository) FindAll() (o []model.Credential, err error) { - err = r.DB.Find(&o).Error +func (r credentialRepository) FindAll(c context.Context) (o []model.Credential, err error) { + err = r.GetDB(c).Find(&o).Error return } diff --git a/server/repository/definitions.go b/server/repository/definitions.go deleted file mode 100644 index ecfb503..0000000 --- a/server/repository/definitions.go +++ /dev/null @@ -1,22 +0,0 @@ -package repository - -/** - * 定义了相关模型的持久化层,方便相互之间调用 - */ -var ( - userRepository *UserRepository - userGroupRepository *UserGroupRepository - resourceSharerRepository *ResourceSharerRepository - assetRepository *AssetRepository - credentialRepository *CredentialRepository - propertyRepository *PropertyRepository - commandRepository *CommandRepository - sessionRepository *SessionRepository - accessSecurityRepository *AccessSecurityRepository - accessGatewayRepository *AccessGatewayRepository - jobRepository *JobRepository - jobLogRepository *JobLogRepository - loginLogRepository *LoginLogRepository - storageRepository *StorageRepository - strategyRepository *StrategyRepository -) diff --git a/server/repository/gateway.go b/server/repository/gateway.go new file mode 100644 index 0000000..d19a619 --- /dev/null +++ b/server/repository/gateway.go @@ -0,0 +1,75 @@ +package repository + +import ( + "context" + + "next-terminal/server/model" +) + +type gatewayRepository struct { + baseRepository +} + +func (r gatewayRepository) Find(c context.Context, pageIndex, pageSize int, ip, name, order, field string) (o []model.AccessGatewayForPage, total int64, err error) { + t := model.AccessGateway{} + db := r.GetDB(c).Table(t.TableName()) + dbCounter := r.GetDB(c).Table(t.TableName()) + + if len(ip) > 0 { + db = db.Where("ip like ?", "%"+ip+"%") + dbCounter = dbCounter.Where("ip like ?", "%"+ip+"%") + } + + if len(name) > 0 { + db = db.Where("name like ?", "%"+name+"%") + dbCounter = dbCounter.Where("name like ?", "%"+name+"%") + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "descend" { + order = "desc" + } else { + order = "asc" + } + + if field == "ip" { + field = "ip" + } else if field == "name" { + field = "name" + } else { + field = "created" + } + + err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error + if o == nil { + o = make([]model.AccessGatewayForPage, 0) + } + return +} + +func (r gatewayRepository) Create(c context.Context, o *model.AccessGateway) error { + return r.GetDB(c).Create(o).Error +} + +func (r gatewayRepository) UpdateById(c context.Context, o *model.AccessGateway, id string) error { + o.ID = id + return r.GetDB(c).Updates(o).Error +} + +func (r gatewayRepository) DeleteById(c context.Context, id string) error { + return r.GetDB(c).Where("id = ?", id).Delete(model.AccessGateway{}).Error +} + +func (r gatewayRepository) FindById(c context.Context, id string) (o model.AccessGateway, err error) { + err = r.GetDB(c).Where("id = ?", id).First(&o).Error + return +} + +func (r gatewayRepository) FindAll(c context.Context) (o []model.AccessGateway, err error) { + err = r.GetDB(c).Find(&o).Error + return +} diff --git a/server/repository/job.go b/server/repository/job.go index 3e84d18..372b71f 100644 --- a/server/repository/job.go +++ b/server/repository/job.go @@ -1,25 +1,20 @@ package repository import ( + "context" + "next-terminal/server/model" "next-terminal/server/utils" - - "gorm.io/gorm" ) -type JobRepository struct { - DB *gorm.DB +type jobRepository struct { + baseRepository } -func NewJobRepository(db *gorm.DB) *JobRepository { - jobRepository = &JobRepository{DB: db} - return jobRepository -} - -func (r JobRepository) Find(pageIndex, pageSize int, name, status, order, field string) (o []model.Job, total int64, err error) { +func (r jobRepository) Find(c context.Context, pageIndex, pageSize int, name, status, order, field string) (o []model.Job, total int64, err error) { job := model.Job{} - db := r.DB.Table(job.TableName()) - dbCounter := r.DB.Table(job.TableName()) + db := r.GetDB(c).Table(job.TableName()) + dbCounter := r.GetDB(c).Table(job.TableName()) if len(name) > 0 { db = db.Where("name like ?", "%"+name+"%") @@ -57,45 +52,36 @@ func (r JobRepository) Find(pageIndex, pageSize int, name, status, order, field return } -func (r JobRepository) FindByFunc(function string) (o []model.Job, err error) { - db := r.DB +func (r jobRepository) FindByFunc(c context.Context, function string) (o []model.Job, err error) { + db := r.GetDB(c) err = db.Where("func = ?", function).Find(&o).Error return } -func (r JobRepository) FindAll() (o []model.Job, err error) { - db := r.DB +func (r jobRepository) FindAll(c context.Context) (o []model.Job, err error) { + db := r.GetDB(c) err = db.Find(&o).Error return } -func (r JobRepository) Create(o *model.Job) (err error) { - return r.DB.Create(o).Error +func (r jobRepository) Create(c context.Context, o *model.Job) (err error) { + return r.GetDB(c).Create(o).Error } -func (r JobRepository) UpdateById(o *model.Job) (err error) { - return r.DB.Updates(o).Error +func (r jobRepository) UpdateById(c context.Context, o *model.Job) (err error) { + return r.GetDB(c).Updates(o).Error } -func (r JobRepository) UpdateLastUpdatedById(id string) (err error) { - err = r.DB.Updates(model.Job{ID: id, Updated: utils.NowJsonTime()}).Error +func (r jobRepository) UpdateLastUpdatedById(c context.Context, id string) (err error) { + err = r.GetDB(c).Updates(model.Job{ID: id, Updated: utils.NowJsonTime()}).Error return } -func (r JobRepository) FindById(id string) (o model.Job, err error) { - err = r.DB.Where("id = ?", id).First(&o).Error +func (r jobRepository) FindById(c context.Context, id string) (o model.Job, err error) { + err = r.GetDB(c).Where("id = ?", id).First(&o).Error return } -func (r JobRepository) DeleteJobById(id string) error { - //job, err := r.FindById(id) - //if err != nil { - // return err - //} - //if job.Status == constant.JobStatusRunning { - // if err := r.ChangeStatusById(id, constant.JobStatusNotRunning); err != nil { - // return err - // } - //} - return r.DB.Where("id = ?", id).Delete(model.Job{}).Error +func (r jobRepository) DeleteJobById(c context.Context, id string) error { + return r.GetDB(c).Where("id = ?", id).Delete(model.Job{}).Error } diff --git a/server/repository/job_log.go b/server/repository/job_log.go index cd5e477..2cee1b7 100644 --- a/server/repository/job_log.go +++ b/server/repository/job_log.go @@ -1,45 +1,39 @@ package repository import ( + "context" "time" "next-terminal/server/model" - - "gorm.io/gorm" ) -type JobLogRepository struct { - DB *gorm.DB +type jobLogRepository struct { + baseRepository } -func NewJobLogRepository(db *gorm.DB) *JobLogRepository { - jobLogRepository = &JobLogRepository{DB: db} - return jobLogRepository +func (r jobLogRepository) Create(c context.Context, o *model.JobLog) error { + return r.GetDB(c).Create(o).Error } -func (r JobLogRepository) Create(o *model.JobLog) error { - return r.DB.Create(o).Error -} - -func (r JobLogRepository) FindByJobId(jobId string) (o []model.JobLog, err error) { - err = r.DB.Where("job_id = ?", jobId).Order("timestamp asc").Find(&o).Error +func (r jobLogRepository) FindByJobId(c context.Context, jobId string) (o []model.JobLog, err error) { + err = r.GetDB(c).Where("job_id = ?", jobId).Order("timestamp asc").Find(&o).Error return } -func (r JobLogRepository) FindOutTimeLog(dayLimit int) (o []model.JobLog, err error) { +func (r jobLogRepository) FindOutTimeLog(c context.Context, dayLimit int) (o []model.JobLog, err error) { limitTime := time.Now().Add(time.Duration(-dayLimit*24) * time.Hour) - err = r.DB.Where("timestamp < ?", limitTime).Find(&o).Error + err = r.GetDB(c).Where("timestamp < ?", limitTime).Find(&o).Error return } -func (r JobLogRepository) DeleteByJobId(jobId string) error { - return r.DB.Where("job_id = ?", jobId).Delete(model.JobLog{}).Error +func (r jobLogRepository) DeleteByJobId(c context.Context, jobId string) error { + return r.GetDB(c).Where("job_id = ?", jobId).Delete(model.JobLog{}).Error } -func (r JobLogRepository) DeleteByIdIn(ids []string) error { - return r.DB.Where("id in ?", ids).Delete(&model.JobLog{}).Error +func (r jobLogRepository) DeleteByIdIn(c context.Context, ids []string) error { + return r.GetDB(c).Where("id in ?", ids).Delete(&model.JobLog{}).Error } -func (r JobLogRepository) DeleteById(id string) error { - return r.DB.Where("id = ?", id).Delete(&model.JobLog{}).Error +func (r jobLogRepository) DeleteById(c context.Context, id string) error { + return r.GetDB(c).Where("id = ?", id).Delete(&model.JobLog{}).Error } diff --git a/server/repository/login_log.go b/server/repository/login_log.go index d1b89b2..fe6e0c5 100644 --- a/server/repository/login_log.go +++ b/server/repository/login_log.go @@ -1,26 +1,20 @@ package repository import ( + "context" "time" "next-terminal/server/model" - - "gorm.io/gorm" ) -type LoginLogRepository struct { - DB *gorm.DB +type loginLogRepository struct { + baseRepository } -func NewLoginLogRepository(db *gorm.DB) *LoginLogRepository { - loginLogRepository = &LoginLogRepository{DB: db} - return loginLogRepository -} - -func (r LoginLogRepository) Find(pageIndex, pageSize int, username, clientIp, state string) (o []model.LoginLog, total int64, err error) { +func (r loginLogRepository) Find(c context.Context, pageIndex, pageSize int, username, clientIp, state string) (o []model.LoginLog, total int64, err error) { m := model.LoginLog{} - db := r.DB.Table(m.TableName()) - dbCounter := r.DB.Table(m.TableName()) + db := r.GetDB(c).Table(m.TableName()) + dbCounter := r.GetDB(c).Table(m.TableName()) if username != "" { db = db.Where("username like ?", "%"+username+"%") @@ -49,44 +43,44 @@ func (r LoginLogRepository) Find(pageIndex, pageSize int, username, clientIp, st return } -func (r LoginLogRepository) FindAliveLoginLogs() (o []model.LoginLog, err error) { - err = r.DB.Where("state = '1' and logout_time is null").Find(&o).Error +func (r loginLogRepository) FindAliveLoginLogs(c context.Context) (o []model.LoginLog, err error) { + err = r.GetDB(c).Where("state = '1' and logout_time is null").Find(&o).Error return } -func (r LoginLogRepository) FindAllLoginLogs() (o []model.LoginLog, err error) { - err = r.DB.Find(&o).Error +func (r loginLogRepository) FindAllLoginLogs(c context.Context) (o []model.LoginLog, err error) { + err = r.GetDB(c).Find(&o).Error return } -func (r LoginLogRepository) FindAliveLoginLogsByUsername(username string) (o []model.LoginLog, err error) { - err = r.DB.Where("state = '1' and logout_time is null and username = ?", username).Find(&o).Error +func (r loginLogRepository) FindAliveLoginLogsByUsername(c context.Context, username string) (o []model.LoginLog, err error) { + err = r.GetDB(c).Where("state = '1' and logout_time is null and username = ?", username).Find(&o).Error return } -func (r LoginLogRepository) FindOutTimeLog(dayLimit int) (o []model.LoginLog, err error) { +func (r loginLogRepository) FindOutTimeLog(c context.Context, dayLimit int) (o []model.LoginLog, err error) { limitTime := time.Now().Add(time.Duration(-dayLimit*24) * time.Hour) - err = r.DB.Where("(state = '0' and login_time < ?) or (state = '1' and logout_time < ?) or (state is null and logout_time < ?)", limitTime, limitTime, limitTime).Find(&o).Error + err = r.GetDB(c).Where("(state = '0' and login_time < ?) or (state = '1' and logout_time < ?) or (state is null and logout_time < ?)", limitTime, limitTime, limitTime).Find(&o).Error return } -func (r LoginLogRepository) Create(o *model.LoginLog) (err error) { - return r.DB.Create(o).Error +func (r loginLogRepository) Create(c context.Context, o *model.LoginLog) (err error) { + return r.GetDB(c).Create(o).Error } -func (r LoginLogRepository) DeleteByIdIn(ids []string) (err error) { - return r.DB.Where("id in ?", ids).Delete(&model.LoginLog{}).Error +func (r loginLogRepository) DeleteByIdIn(c context.Context, ids []string) (err error) { + return r.GetDB(c).Where("id in ?", ids).Delete(&model.LoginLog{}).Error } -func (r LoginLogRepository) DeleteById(id string) (err error) { - return r.DB.Where("id = ?", id).Delete(&model.LoginLog{}).Error +func (r loginLogRepository) DeleteById(c context.Context, id string) (err error) { + return r.GetDB(c).Where("id = ?", id).Delete(&model.LoginLog{}).Error } -func (r LoginLogRepository) FindById(id string) (o model.LoginLog, err error) { - err = r.DB.Where("id = ?", id).First(&o).Error +func (r loginLogRepository) FindById(c context.Context, id string) (o model.LoginLog, err error) { + err = r.GetDB(c).Where("id = ?", id).First(&o).Error return } -func (r LoginLogRepository) Update(o *model.LoginLog) error { - return r.DB.Updates(o).Error +func (r loginLogRepository) Update(c context.Context, o *model.LoginLog) error { + return r.GetDB(c).Updates(o).Error } diff --git a/server/repository/property.go b/server/repository/property.go index c4df192..b94cd87 100644 --- a/server/repository/property.go +++ b/server/repository/property.go @@ -1,48 +1,43 @@ package repository import ( - "next-terminal/server/model" + "context" - "gorm.io/gorm" + "next-terminal/server/model" ) -type PropertyRepository struct { - DB *gorm.DB +type propertyRepository struct { + baseRepository } -func NewPropertyRepository(db *gorm.DB) *PropertyRepository { - propertyRepository = &PropertyRepository{DB: db} - return propertyRepository -} - -func (r PropertyRepository) FindAll() (o []model.Property) { - if r.DB.Find(&o).Error != nil { +func (r propertyRepository) FindAll(c context.Context) (o []model.Property) { + if r.GetDB(c).Find(&o).Error != nil { return nil } return } -func (r PropertyRepository) Create(o *model.Property) (err error) { - err = r.DB.Create(o).Error +func (r propertyRepository) Create(c context.Context, o *model.Property) (err error) { + err = r.GetDB(c).Create(o).Error return } -func (r PropertyRepository) UpdateByName(o *model.Property, name string) error { +func (r propertyRepository) UpdateByName(c context.Context, o *model.Property, name string) error { o.Name = name - return r.DB.Updates(o).Error + return r.GetDB(c).Updates(o).Error } -func (r PropertyRepository) DeleteByName(name string) error { - return r.DB.Where("name = ?", name).Delete(model.Property{}).Error +func (r propertyRepository) DeleteByName(c context.Context, name string) error { + return r.GetDB(c).Where("name = ?", name).Delete(model.Property{}).Error } -func (r PropertyRepository) FindByName(name string) (o model.Property, err error) { - err = r.DB.Where("name = ?", name).First(&o).Error +func (r propertyRepository) FindByName(c context.Context, name string) (o model.Property, err error) { + err = r.GetDB(c).Where("name = ?", name).First(&o).Error return } -func (r PropertyRepository) FindAllMap() map[string]string { - properties := r.FindAll() +func (r propertyRepository) FindAllMap(c context.Context) map[string]string { + properties := r.FindAll(c) propertyMap := make(map[string]string) for i := range properties { propertyMap[properties[i].Name] = properties[i].Value diff --git a/server/repository/resource_sharer.go b/server/repository/resource_sharer.go index c6617aa..e368b0d 100644 --- a/server/repository/resource_sharer.go +++ b/server/repository/resource_sharer.go @@ -1,6 +1,8 @@ package repository import ( + "context" + "next-terminal/server/model" "next-terminal/server/utils" @@ -9,17 +11,12 @@ import ( "gorm.io/gorm" ) -type ResourceSharerRepository struct { - DB *gorm.DB +type resourceSharerRepository struct { + baseRepository } -func NewResourceSharerRepository(db *gorm.DB) *ResourceSharerRepository { - resourceSharerRepository = &ResourceSharerRepository{DB: db} - return resourceSharerRepository -} - -func (r *ResourceSharerRepository) OverwriteUserIdsByResourceId(resourceId, resourceType string, userIds []string) (err error) { - db := r.DB.Begin() +func (r *resourceSharerRepository) OverwriteUserIdsByResourceId(c context.Context, resourceId, resourceType string, userIds []string) (err error) { + db := r.GetDB(c).Begin() var owner string // 检查资产是否存在 @@ -71,8 +68,8 @@ func (r *ResourceSharerRepository) OverwriteUserIdsByResourceId(resourceId, reso return nil } -func (r *ResourceSharerRepository) DeleteByUserIdAndResourceTypeAndResourceIdIn(userGroupId, userId, resourceType string, resourceIds []string) error { - db := r.DB +func (r *resourceSharerRepository) DeleteByUserIdAndResourceTypeAndResourceIdIn(c context.Context, userGroupId, userId, resourceType string, resourceIds []string) error { + db := r.GetDB(c) if userGroupId != "" { db = db.Where("user_group_id = ?", userGroupId) } @@ -92,12 +89,20 @@ func (r *ResourceSharerRepository) DeleteByUserIdAndResourceTypeAndResourceIdIn( return db.Delete(&model.ResourceSharer{}).Error } -func (r *ResourceSharerRepository) DeleteResourceSharerByResourceId(resourceId string) error { - return r.DB.Where("resource_id = ?", resourceId).Delete(&model.ResourceSharer{}).Error +func (r *resourceSharerRepository) DeleteByResourceId(c context.Context, resourceId string) error { + return r.GetDB(c).Where("resource_id = ?", resourceId).Delete(&model.ResourceSharer{}).Error } -func (r *ResourceSharerRepository) AddSharerResources(userGroupId, userId, strategyId, resourceType string, resourceIds []string) error { - return r.DB.Transaction(func(tx *gorm.DB) (err error) { +func (r *resourceSharerRepository) DeleteByUserId(c context.Context, userId string) error { + return r.GetDB(c).Where("user_id = ?", userId).Delete(&model.ResourceSharer{}).Error +} + +func (r *resourceSharerRepository) DeleteByUserGroupId(c context.Context, userGroupId string) error { + return r.GetDB(c).Where("user_group_id = ?", userGroupId).Delete(&model.ResourceSharer{}).Error +} + +func (r *resourceSharerRepository) AddSharerResources(userGroupId, userId, strategyId, resourceType string, resourceIds []string) error { + return r.GetDB(context.TODO()).Transaction(func(tx *gorm.DB) (err error) { for i := range resourceIds { resourceId := resourceIds[i] @@ -149,22 +154,22 @@ func (r *ResourceSharerRepository) AddSharerResources(userGroupId, userId, strat }) } -func (r *ResourceSharerRepository) FindAssetIdsByUserId(userId string) (assetIds []string, err error) { +func (r *resourceSharerRepository) FindAssetIdsByUserId(c context.Context, userId string) (assetIds []string, err error) { // 查询当前用户创建的资产 var ownerAssetIds, sharerAssetIds []string asset := model.Asset{} - err = r.DB.Table(asset.TableName()).Select("id").Where("owner = ?", userId).Find(&ownerAssetIds).Error + err = r.GetDB(c).Table(asset.TableName()).Select("id").Where("owner = ?", userId).Find(&ownerAssetIds).Error if err != nil { return nil, err } // 查询其他用户授权给该用户的资产 - groupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId) + groupIds, err := UserGroupMemberRepository.FindUserGroupIdsByUserId(c, userId) if err != nil { return nil, err } - db := r.DB.Table("resource_sharers").Select("resource_id").Where("user_id = ?", userId) + db := r.GetDB(c).Table("resource_sharers").Select("resource_id").Where("user_id = ?", userId) if len(groupIds) > 0 { db = db.Or("user_group_id in ?", groupIds) } @@ -187,13 +192,13 @@ func (r *ResourceSharerRepository) FindAssetIdsByUserId(userId string) (assetIds return } -func (r *ResourceSharerRepository) FindByResourceIdAndUserId(assetId, userId string) (resourceSharers []model.ResourceSharer, err error) { +func (r *resourceSharerRepository) FindByResourceIdAndUserId(c context.Context, assetId, userId string) (resourceSharers []model.ResourceSharer, err error) { // 查询其他用户授权给该用户的资产 - groupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId) + groupIds, err := UserGroupMemberRepository.FindUserGroupIdsByUserId(c, userId) if err != nil { return } - db := r.DB.Where("( resource_id = ? and user_id = ? )", assetId, userId) + db := r.GetDB(c).Where("( resource_id = ? and user_id = ? )", assetId, userId) if len(groupIds) > 0 { db = db.Or("user_group_id in ?", groupIds) } @@ -201,8 +206,8 @@ func (r *ResourceSharerRepository) FindByResourceIdAndUserId(assetId, userId str return } -func (r *ResourceSharerRepository) Find(resourceId, resourceType, userId, userGroupId string) (resourceSharers []model.ResourceSharer, err error) { - db := r.DB +func (r *resourceSharerRepository) Find(c context.Context, resourceId, resourceType, userId, userGroupId string) (resourceSharers []model.ResourceSharer, err error) { + db := r.GetDB(c) if resourceId != "" { db = db.Where("resource_id = ?") } @@ -219,7 +224,7 @@ func (r *ResourceSharerRepository) Find(resourceId, resourceType, userId, userGr return } -func (r *ResourceSharerRepository) FindAll() (o []model.ResourceSharer, err error) { - err = r.DB.Find(&o).Error +func (r *resourceSharerRepository) FindAll(c context.Context) (o []model.ResourceSharer, err error) { + err = r.GetDB(c).Find(&o).Error return } diff --git a/server/repository/security.go b/server/repository/security.go new file mode 100644 index 0000000..2a89a60 --- /dev/null +++ b/server/repository/security.go @@ -0,0 +1,75 @@ +package repository + +import ( + "context" + + "next-terminal/server/model" +) + +type securityRepository struct { + baseRepository +} + +func (r securityRepository) FindAll(c context.Context) (o []model.AccessSecurity, err error) { + err = r.GetDB(c).Order("priority asc").Find(&o).Error + return +} + +func (r securityRepository) Find(c context.Context, pageIndex, pageSize int, ip, rule, order, field string) (o []model.AccessSecurity, total int64, err error) { + t := model.AccessSecurity{} + db := r.GetDB(c).Table(t.TableName()) + dbCounter := r.GetDB(c).Table(t.TableName()) + + if len(ip) > 0 { + db = db.Where("ip like ?", "%"+ip+"%") + dbCounter = dbCounter.Where("ip like ?", "%"+ip+"%") + } + + if len(rule) > 0 { + db = db.Where("rule = ?", rule) + dbCounter = dbCounter.Where("rule = ?", rule) + } + + err = dbCounter.Count(&total).Error + if err != nil { + return nil, 0, err + } + + if order == "descend" { + order = "desc" + } else { + order = "asc" + } + + if field == "ip" { + field = "ip" + } else if field == "rule" { + field = "rule" + } else { + field = "priority" + } + + err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error + if o == nil { + o = make([]model.AccessSecurity, 0) + } + return +} + +func (r securityRepository) Create(c context.Context, o *model.AccessSecurity) error { + return r.GetDB(c).Create(o).Error +} + +func (r securityRepository) UpdateById(c context.Context, o *model.AccessSecurity, id string) error { + o.ID = id + return r.GetDB(c).Updates(o).Error +} + +func (r securityRepository) DeleteById(c context.Context, id string) error { + return r.GetDB(c).Where("id = ?", id).Delete(model.AccessSecurity{}).Error +} + +func (r securityRepository) FindById(c context.Context, id string) (o *model.AccessSecurity, err error) { + err = r.GetDB(c).Where("id = ?", id).First(&o).Error + return +} diff --git a/server/repository/session.go b/server/repository/session.go index 230a158..7052706 100644 --- a/server/repository/session.go +++ b/server/repository/session.go @@ -1,7 +1,7 @@ package repository import ( - "encoding/base64" + "context" "os" "path" "time" @@ -9,23 +9,15 @@ import ( "next-terminal/server/config" "next-terminal/server/constant" "next-terminal/server/model" - "next-terminal/server/utils" - - "gorm.io/gorm" ) -type SessionRepository struct { - DB *gorm.DB +type sessionRepository struct { + baseRepository } -func NewSessionRepository(db *gorm.DB) *SessionRepository { - sessionRepository = &SessionRepository{DB: db} - return sessionRepository -} +func (r sessionRepository) Find(c context.Context, pageIndex, pageSize int, status, userId, clientIp, assetId, protocol, reviewed string) (results []model.SessionForPage, total int64, err error) { -func (r SessionRepository) Find(pageIndex, pageSize int, status, userId, clientIp, assetId, protocol, reviewed string) (results []model.SessionForPage, total int64, err error) { - - db := r.DB + db := r.GetDB(c) var params []interface{} params = append(params, status) @@ -77,152 +69,102 @@ func (r SessionRepository) Find(pageIndex, pageSize int, status, userId, clientI return } -func (r SessionRepository) FindByStatus(status string) (o []model.Session, err error) { - err = r.DB.Where("status = ?", status).Find(&o).Error +func (r sessionRepository) FindByStatus(c context.Context, status string) (o []model.Session, err error) { + err = r.GetDB(c).Where("status = ?", status).Find(&o).Error return } -func (r SessionRepository) FindByStatusIn(statuses []string) (o []model.Session, err error) { - err = r.DB.Where("status in ?", statuses).Find(&o).Error +func (r sessionRepository) FindByStatusIn(c context.Context, statuses []string) (o []model.Session, err error) { + err = r.GetDB(c).Where("status in ?", statuses).Find(&o).Error return } -func (r SessionRepository) FindOutTimeSessions(dayLimit int) (o []model.Session, err error) { +func (r sessionRepository) FindOutTimeSessions(c context.Context, dayLimit int) (o []model.Session, err error) { limitTime := time.Now().Add(time.Duration(-dayLimit*24) * time.Hour) - err = r.DB.Where("status = ? and connected_time < ?", constant.Disconnected, limitTime).Find(&o).Error + err = r.GetDB(c).Where("status = ? and connected_time < ?", constant.Disconnected, limitTime).Find(&o).Error return } -func (r SessionRepository) Create(o *model.Session) (err error) { - err = r.DB.Create(o).Error +func (r sessionRepository) Create(c context.Context, o *model.Session) (err error) { + err = r.GetDB(c).Create(o).Error return } -func (r SessionRepository) FindById(id string) (o model.Session, err error) { - err = r.DB.Where("id = ?", id).First(&o).Error +func (r sessionRepository) FindById(c context.Context, id string) (o model.Session, err error) { + err = r.GetDB(c).Where("id = ?", id).First(&o).Error return } -func (r SessionRepository) FindByIdAndDecrypt(id string) (o model.Session, err error) { - err = r.DB.Where("id = ?", id).First(&o).Error - if err == nil { - err = r.Decrypt(&o) - } +func (r sessionRepository) FindByConnectionId(c context.Context, connectionId string) (o model.Session, err error) { + err = r.GetDB(c).Where("connection_id = ?", connectionId).First(&o).Error return } -func (r SessionRepository) Decrypt(item *model.Session) error { - if item.Password != "" && item.Password != "-" { - origData, err := base64.StdEncoding.DecodeString(item.Password) - if err != nil { - return err - } - decryptedCBC, err := utils.AesDecryptCBC(origData, config.GlobalCfg.EncryptionPassword) - if err != nil { - return err - } - item.Password = string(decryptedCBC) - } - if item.PrivateKey != "" && item.PrivateKey != "-" { - origData, err := base64.StdEncoding.DecodeString(item.PrivateKey) - if err != nil { - return err - } - decryptedCBC, err := utils.AesDecryptCBC(origData, config.GlobalCfg.EncryptionPassword) - if err != nil { - return err - } - item.PrivateKey = string(decryptedCBC) - } - if item.Passphrase != "" && item.Passphrase != "-" { - origData, err := base64.StdEncoding.DecodeString(item.Passphrase) - if err != nil { - return err - } - decryptedCBC, err := utils.AesDecryptCBC(origData, config.GlobalCfg.EncryptionPassword) - if err != nil { - return err - } - item.Passphrase = string(decryptedCBC) - } - return nil -} - -func (r SessionRepository) FindByConnectionId(connectionId string) (o model.Session, err error) { - err = r.DB.Where("connection_id = ?", connectionId).First(&o).Error - return -} - -func (r SessionRepository) UpdateById(o *model.Session, id string) error { +func (r sessionRepository) UpdateById(c context.Context, o *model.Session, id string) error { o.ID = id - return r.DB.Updates(o).Error + return r.GetDB(c).Updates(o).Error } -func (r SessionRepository) UpdateWindowSizeById(width, height int, id string) error { +func (r sessionRepository) UpdateWindowSizeById(c context.Context, width, height int, id string) error { session := model.Session{} session.Width = width session.Height = height - return r.UpdateById(&session, id) + return r.UpdateById(c, &session, id) } -func (r SessionRepository) DeleteById(id string) error { - return r.DB.Where("id = ?", id).Delete(&model.Session{}).Error +func (r sessionRepository) DeleteById(c context.Context, id string) error { + return r.GetDB(c).Where("id = ?", id).Delete(&model.Session{}).Error } -func (r SessionRepository) DeleteByIds(sessionIds []string) error { +func (r sessionRepository) DeleteByIds(c context.Context, sessionIds []string) error { recordingPath := config.GlobalCfg.Guacd.Recording for i := range sessionIds { if err := os.RemoveAll(path.Join(recordingPath, sessionIds[i])); err != nil { return err } - if err := r.DeleteById(sessionIds[i]); err != nil { + if err := r.DeleteById(c, sessionIds[i]); err != nil { return err } } return nil } -func (r SessionRepository) DeleteByStatus(status string) error { - return r.DB.Where("status = ?", status).Delete(&model.Session{}).Error +func (r sessionRepository) DeleteByStatus(c context.Context, status string) error { + return r.GetDB(c).Where("status = ?", status).Delete(&model.Session{}).Error } -func (r SessionRepository) CountOnlineSession() (total int64, err error) { - err = r.DB.Where("status = ?", constant.Connected).Find(&model.Session{}).Count(&total).Error +func (r sessionRepository) CountOnlineSession(c context.Context) (total int64, err error) { + err = r.GetDB(c).Where("status = ?", constant.Connected).Find(&model.Session{}).Count(&total).Error return } -func (r SessionRepository) EmptyPassword() error { +func (r sessionRepository) EmptyPassword(c context.Context) error { sql := "update sessions set password = '-',private_key = '-', passphrase = '-' where 1=1" - return r.DB.Exec(sql).Error + return r.GetDB(c).Exec(sql).Error } -func (r SessionRepository) CountByStatus(status string) (total int64, err error) { - err = r.DB.Find(&model.Session{}).Where("status = ?", status).Count(&total).Error +func (r sessionRepository) CountByStatus(c context.Context, status string) (total int64, err error) { + err = r.GetDB(c).Find(&model.Session{}).Where("status = ?", status).Count(&total).Error return } -func (r SessionRepository) OverviewAccess(account model.User) (o []model.SessionForAccess, err error) { - db := r.DB - if constant.TypeUser == account.Type { - sql := "SELECT s.asset_id, s.ip, s.port, s.protocol, s.username, count(s.asset_id) AS access_count FROM sessions AS s where s.creator = ? GROUP BY s.asset_id, s.ip, s.port, s.protocol, s.username ORDER BY access_count DESC limit 10" - err = db.Raw(sql, []string{account.ID}).Scan(&o).Error - } else { - sql := "SELECT s.asset_id, s.ip, s.port, s.protocol, s.username, count(s.asset_id) AS access_count FROM sessions AS s GROUP BY s.asset_id, s.ip, s.port, s.protocol, s.username ORDER BY access_count DESC limit 10" - err = db.Raw(sql).Scan(&o).Error - } +func (r sessionRepository) OverviewAccess(c context.Context) (o []model.SessionForAccess, err error) { + db := r.GetDB(c) + sql := "SELECT s.asset_id, s.ip, s.port, s.protocol, s.username, count(s.asset_id) AS access_count FROM sessions AS s GROUP BY s.asset_id, s.ip, s.port, s.protocol, s.username ORDER BY access_count DESC limit 10" + err = db.Raw(sql).Scan(&o).Error if o == nil { o = make([]model.SessionForAccess, 0) } return } -func (r SessionRepository) UpdateReadByIds(reviewed bool, ids []string) error { +func (r sessionRepository) UpdateReadByIds(c context.Context, reviewed bool, ids []string) error { sql := "update sessions set reviewed = ? where id in ?" - return r.DB.Exec(sql, reviewed, ids).Error + return r.GetDB(c).Exec(sql, reviewed, ids).Error } -func (r SessionRepository) FindAllUnReviewed() (o []model.Session, err error) { - err = r.DB.Where("reviewed = false or reviewed is null").Find(&o).Error +func (r sessionRepository) FindAllUnReviewed(c context.Context) (o []model.Session, err error) { + err = r.GetDB(c).Where("reviewed = false or reviewed is null").Find(&o).Error return } diff --git a/server/repository/storage.go b/server/repository/storage.go index ba7df9d..deddb94 100644 --- a/server/repository/storage.go +++ b/server/repository/storage.go @@ -1,24 +1,19 @@ package repository import ( - "next-terminal/server/model" + "context" - "gorm.io/gorm" + "next-terminal/server/model" ) -type StorageRepository struct { - DB *gorm.DB +type storageRepository struct { + baseRepository } -func NewStorageRepository(db *gorm.DB) *StorageRepository { - storageRepository = &StorageRepository{DB: db} - return storageRepository -} - -func (r StorageRepository) Find(pageIndex, pageSize int, name, order, field string) (o []model.StorageForPage, total int64, err error) { +func (r storageRepository) Find(c context.Context, pageIndex, pageSize int, name, order, field string) (o []model.StorageForPage, total int64, err error) { m := model.Storage{} - db := r.DB.Table(m.TableName()).Select("storages.id,storages.name,storages.is_share,storages.limit_size,storages.is_default,storages.owner,storages.created, users.nickname as owner_name").Joins("left join users on storages.owner = users.id") - dbCounter := r.DB.Table(m.TableName()) + db := r.GetDB(c).Table(m.TableName()).Select("storages.id,storages.name,storages.is_share,storages.limit_size,storages.is_default,storages.owner,storages.created, users.nickname as owner_name").Joins("left join users on storages.owner = users.id") + dbCounter := r.GetDB(c).Table(m.TableName()) if len(name) > 0 { db = db.Where("name like ?", "%"+name+"%") @@ -49,37 +44,37 @@ func (r StorageRepository) Find(pageIndex, pageSize int, name, order, field stri return } -func (r StorageRepository) FindShares() (o []model.Storage, err error) { +func (r storageRepository) FindShares(c context.Context) (o []model.Storage, err error) { m := model.Storage{} - db := r.DB.Table(m.TableName()).Where("is_share = 1") + db := r.GetDB(c).Table(m.TableName()).Where("is_share = 1") err = db.Find(&o).Error return } -func (r StorageRepository) DeleteById(id string) error { - return r.DB.Where("id = ?", id).Delete(model.Storage{}).Error +func (r storageRepository) DeleteById(c context.Context, id string) error { + return r.GetDB(c).Where("id = ?", id).Delete(model.Storage{}).Error } -func (r StorageRepository) Create(m *model.Storage) error { - return r.DB.Create(m).Error +func (r storageRepository) Create(c context.Context, m *model.Storage) error { + return r.GetDB(c).Create(m).Error } -func (r StorageRepository) UpdateById(o *model.Storage, id string) error { +func (r storageRepository) UpdateById(c context.Context, o *model.Storage, id string) error { o.ID = id - return r.DB.Updates(o).Error + return r.GetDB(c).Updates(o).Error } -func (r StorageRepository) FindByOwnerIdAndDefault(owner string, isDefault bool) (m model.Storage, err error) { - err = r.DB.Where("owner = ? and is_default = ?", owner, isDefault).First(&m).Error +func (r storageRepository) FindByOwnerIdAndDefault(c context.Context, owner string, isDefault bool) (m model.Storage, err error) { + err = r.GetDB(c).Where("owner = ? and is_default = ?", owner, isDefault).First(&m).Error return } -func (r StorageRepository) FindById(id string) (m model.Storage, err error) { - err = r.DB.Where("id = ?", id).First(&m).Error +func (r storageRepository) FindById(c context.Context, id string) (m model.Storage, err error) { + err = r.GetDB(c).Where("id = ?", id).First(&m).Error return } -func (r StorageRepository) FindAll() (o []model.Storage, err error) { - err = r.DB.Find(&o).Error +func (r storageRepository) FindAll(c context.Context) (o []model.Storage, err error) { + err = r.GetDB(c).Find(&o).Error return } diff --git a/server/repository/strategy.go b/server/repository/strategy.go index 58df818..77c6a6a 100644 --- a/server/repository/strategy.go +++ b/server/repository/strategy.go @@ -1,29 +1,24 @@ package repository import ( - "next-terminal/server/model" + "context" - "gorm.io/gorm" + "next-terminal/server/model" ) -type StrategyRepository struct { - DB *gorm.DB +type strategyRepository struct { + baseRepository } -func NewStrategyRepository(db *gorm.DB) *StrategyRepository { - strategyRepository = &StrategyRepository{DB: db} - return strategyRepository -} - -func (r StrategyRepository) FindAll() (o []model.Strategy, err error) { - err = r.DB.Order("name desc").Find(&o).Error +func (r strategyRepository) FindAll(c context.Context) (o []model.Strategy, err error) { + err = r.GetDB(c).Order("name desc").Find(&o).Error return } -func (r StrategyRepository) Find(pageIndex, pageSize int, name, order, field string) (o []model.Strategy, total int64, err error) { +func (r strategyRepository) Find(c context.Context, pageIndex, pageSize int, name, order, field string) (o []model.Strategy, total int64, err error) { m := model.Strategy{} - db := r.DB.Table(m.TableName()) - dbCounter := r.DB.Table(m.TableName()) + db := r.GetDB(c).Table(m.TableName()) + dbCounter := r.GetDB(c).Table(m.TableName()) if len(name) > 0 { db = db.Where("name like ?", "%"+name+"%") @@ -54,20 +49,20 @@ func (r StrategyRepository) Find(pageIndex, pageSize int, name, order, field str return } -func (r StrategyRepository) DeleteById(id string) error { - return r.DB.Where("id = ?", id).Delete(model.Strategy{}).Error +func (r strategyRepository) DeleteById(c context.Context, id string) error { + return r.GetDB(c).Where("id = ?", id).Delete(model.Strategy{}).Error } -func (r StrategyRepository) Create(m *model.Strategy) error { - return r.DB.Create(m).Error +func (r strategyRepository) Create(c context.Context, m *model.Strategy) error { + return r.GetDB(c).Create(m).Error } -func (r StrategyRepository) UpdateById(o *model.Strategy, id string) error { +func (r strategyRepository) UpdateById(c context.Context, o *model.Strategy, id string) error { o.ID = id - return r.DB.Updates(o).Error + return r.GetDB(c).Updates(o).Error } -func (r StrategyRepository) FindById(id string) (m model.Strategy, err error) { - err = r.DB.Where("id = ?", id).First(&m).Error +func (r strategyRepository) FindById(c context.Context, id string) (m model.Strategy, err error) { + err = r.GetDB(c).Where("id = ?", id).First(&m).Error return } diff --git a/server/repository/user.go b/server/repository/user.go index 22a50e7..5b9cbb0 100644 --- a/server/repository/user.go +++ b/server/repository/user.go @@ -1,35 +1,23 @@ package repository import ( - "next-terminal/server/constant" - "next-terminal/server/model" + "context" - "gorm.io/gorm" + "next-terminal/server/model" ) -type UserRepository struct { - DB *gorm.DB +type userRepository struct { + baseRepository } -func NewUserRepository(db *gorm.DB) *UserRepository { - userRepository = &UserRepository{DB: db} - return userRepository -} - -func (r UserRepository) FindAll() (o []model.User, err error) { - err = r.DB.Find(&o).Error +func (r userRepository) FindAll(c context.Context) (o []model.User, err error) { + err = r.GetDB(c).Find(&o).Error return } -func (r UserRepository) Find(pageIndex, pageSize int, username, nickname, mail, order, field string, account model.User) (o []model.UserForPage, total int64, err error) { - db := r.DB.Table("users").Select("users.id,users.username,users.nickname,users.mail,users.online,users.created,users.type,users.status, count(resource_sharers.user_id) as sharer_asset_count, users.totp_secret").Joins("left join resource_sharers on users.id = resource_sharers.user_id and resource_sharers.resource_type = 'asset'").Group("users.id") - dbCounter := r.DB.Table("users") - - if constant.TypeUser == account.Type { - // 普通用户只能查看到普通用户 - db = db.Where("users.type = ?", constant.TypeUser) - dbCounter = dbCounter.Where("type = ?", constant.TypeUser) - } +func (r userRepository) Find(c context.Context, pageIndex, pageSize int, username, nickname, mail, order, field string) (o []model.UserForPage, total int64, err error) { + db := r.GetDB(c).Table("users").Select("users.id,users.username,users.nickname,users.mail,users.online,users.created,users.type,users.status,users.source, count(resource_sharers.user_id) as sharer_asset_count, users.totp_secret").Joins("left join resource_sharers on users.id = resource_sharers.user_id and resource_sharers.resource_type = 'asset'").Group("users.id") + dbCounter := r.GetDB(c).Table("users") if len(username) > 0 { db = db.Where("users.username like ?", "%"+username+"%") @@ -80,19 +68,19 @@ func (r UserRepository) Find(pageIndex, pageSize int, username, nickname, mail, return } -func (r UserRepository) FindById(id string) (o model.User, err error) { - err = r.DB.Where("id = ?", id).First(&o).Error +func (r userRepository) FindById(c context.Context, id string) (o model.User, err error) { + err = r.GetDB(c).Where("id = ?", id).First(&o).Error return } -func (r UserRepository) FindByUsername(username string) (o model.User, err error) { - err = r.DB.Where("username = ?", username).First(&o).Error +func (r userRepository) FindByUsername(c context.Context, username string) (o model.User, err error) { + err = r.GetDB(c).Where("username = ?", username).First(&o).Error return } -func (r UserRepository) ExistByUsername(username string) (exist bool) { +func (r userRepository) ExistByUsername(c context.Context, username string) (exist bool) { count := int64(0) - err := r.DB.Table("users").Where("username = ?", username).Count(&count).Error + err := r.GetDB(c).Table("users").Where("username = ?", username).Count(&count).Error if err != nil { return false } @@ -100,51 +88,38 @@ func (r UserRepository) ExistByUsername(username string) (exist bool) { return count > 0 } -func (r UserRepository) FindOnlineUsers() (o []model.User, err error) { - err = r.DB.Where("online = ?", true).Find(&o).Error +func (r userRepository) FindOnlineUsers(c context.Context) (o []model.User, err error) { + err = r.GetDB(c).Where("online = ?", true).Find(&o).Error return } -func (r UserRepository) Create(o *model.User) error { - return r.DB.Create(o).Error +func (r userRepository) Create(c context.Context, o *model.User) error { + return r.GetDB(c).Create(o).Error } -func (r UserRepository) Update(o *model.User) error { - return r.DB.Updates(o).Error +func (r userRepository) Update(c context.Context, o *model.User) error { + return r.GetDB(c).Updates(o).Error } -func (r UserRepository) UpdateOnlineByUsername(username string, online bool) error { +func (r userRepository) UpdateOnlineByUsername(c context.Context, username string, online bool) error { sql := "update users set online = ? where username = ?" - return r.DB.Exec(sql, online, username).Error + return r.GetDB(c).Exec(sql, online, username).Error } -func (r UserRepository) DeleteById(id string) error { - return r.DB.Transaction(func(tx *gorm.DB) (err error) { - // 删除用户 - err = tx.Where("id = ?", id).Delete(&model.User{}).Error - if err != nil { - return err - } - // 删除用户组中的用户关系 - err = tx.Where("user_id = ?", id).Delete(&model.UserGroupMember{}).Error - if err != nil { - return err - } - // 删除用户分享到的资产 - err = tx.Where("user_id = ?", id).Delete(&model.ResourceSharer{}).Error - if err != nil { - return err - } - return nil - }) +func (r userRepository) DeleteById(c context.Context, id string) error { + return r.GetDB(c).Where("id = ?", id).Delete(&model.User{}).Error } -func (r UserRepository) CountOnlineUser() (total int64, err error) { - err = r.DB.Where("online = ?", true).Find(&model.User{}).Count(&total).Error +func (r userRepository) DeleteBySource(c context.Context, source string) error { + return r.GetDB(c).Where("source = ?", source).Delete(&model.User{}).Error +} + +func (r userRepository) CountOnlineUser(c context.Context) (total int64, err error) { + err = r.GetDB(c).Where("online = ?", true).Find(&model.User{}).Count(&total).Error return } -func (r UserRepository) Count() (total int64, err error) { - err = r.DB.Find(&model.User{}).Count(&total).Error +func (r userRepository) Count(c context.Context) (total int64, err error) { + err = r.GetDB(c).Find(&model.User{}).Count(&total).Error return } diff --git a/server/repository/user_group.go b/server/repository/user_group.go index 450ccea..54f3675 100644 --- a/server/repository/user_group.go +++ b/server/repository/user_group.go @@ -1,29 +1,23 @@ package repository import ( - "next-terminal/server/model" - "next-terminal/server/utils" + "context" - "gorm.io/gorm" + "next-terminal/server/model" ) -type UserGroupRepository struct { - DB *gorm.DB +type userGroupRepository struct { + baseRepository } -func NewUserGroupRepository(db *gorm.DB) *UserGroupRepository { - userGroupRepository = &UserGroupRepository{DB: db} - return userGroupRepository -} - -func (r UserGroupRepository) FindAll() (o []model.UserGroup, err error) { - err = r.DB.Find(&o).Error +func (r userGroupRepository) FindAll(c context.Context) (o []model.UserGroup, err error) { + err = r.GetDB(c).Find(&o).Error return } -func (r UserGroupRepository) Find(pageIndex, pageSize int, name, order, field string) (o []model.UserGroupForPage, total int64, err error) { - db := r.DB.Table("user_groups").Select("user_groups.id, user_groups.name, user_groups.created, count(resource_sharers.user_group_id) as asset_count").Joins("left join resource_sharers on user_groups.id = resource_sharers.user_group_id and resource_sharers.resource_type = 'asset'").Group("user_groups.id") - dbCounter := r.DB.Table("user_groups") +func (r userGroupRepository) Find(c context.Context, pageIndex, pageSize int, name, order, field string) (o []model.UserGroupForPage, total int64, err error) { + db := r.GetDB(c).Table("user_groups").Select("user_groups.id, user_groups.name, user_groups.created, count(resource_sharers.user_group_id) as asset_count").Joins("left join resource_sharers on user_groups.id = resource_sharers.user_group_id and resource_sharers.resource_type = 'asset'").Group("user_groups.id") + dbCounter := r.GetDB(c).Table("user_groups") if len(name) > 0 { db = db.Where("user_groups.name like ?", "%"+name+"%") dbCounter = dbCounter.Where("name like ?", "%"+name+"%") @@ -53,94 +47,29 @@ func (r UserGroupRepository) Find(pageIndex, pageSize int, name, order, field st return } -func (r UserGroupRepository) FindById(id string) (o model.UserGroup, err error) { - err = r.DB.Where("id = ?", id).First(&o).Error +func (r userGroupRepository) FindById(c context.Context, id string) (o model.UserGroup, err error) { + err = r.GetDB(c).Where("id = ?", id).First(&o).Error return } -func (r UserGroupRepository) FindUserGroupIdsByUserId(userId string) (o []string, err error) { - // 先查询用户所在的用户 - err = r.DB.Table("user_group_members").Select("user_group_id").Where("user_id = ?", userId).Find(&o).Error +func (r userGroupRepository) FindByName(c context.Context, name string) (o model.UserGroup, err error) { + err = r.GetDB(c).Where("name = ?", name).First(&o).Error return } -func (r UserGroupRepository) FindMembersById(userGroupId string) (o []string, err error) { - err = r.DB.Table("user_group_members").Select("user_id").Where("user_group_id = ?", userGroupId).Find(&o).Error - return -} - -func (r UserGroupRepository) Create(o *model.UserGroup, members []string) (err error) { - return r.DB.Transaction(func(tx *gorm.DB) error { - err = tx.Create(o).Error - if err != nil { - return err - } - - if members != nil { - userGroupId := o.ID - err = AddUserGroupMembers(tx, members, userGroupId) - if err != nil { - return err - } - } - return err - }) -} - -func (r UserGroupRepository) Update(o *model.UserGroup, members []string, id string) error { - return r.DB.Transaction(func(tx *gorm.DB) error { - o.ID = id - err := tx.Updates(o).Error - if err != nil { - return err - } - - err = tx.Where("user_group_id = ?", id).Delete(&model.UserGroupMember{}).Error - if err != nil { - return err - } - if members != nil { - userGroupId := o.ID - err = AddUserGroupMembers(tx, members, userGroupId) - if err != nil { - return err - } - } - return err - }) -} - -func (r UserGroupRepository) DeleteById(id string) (err error) { - err = r.DB.Where("id = ?", id).Delete(&model.UserGroup{}).Error - if err != nil { - return err - } - return r.DB.Where("user_group_id = ?", id).Delete(&model.UserGroupMember{}).Error -} - -func AddUserGroupMembers(tx *gorm.DB, userIds []string, userGroupId string) error { - userRepository := NewUserRepository(tx) - for i := range userIds { - userId := userIds[i] - _, err := userRepository.FindById(userId) - if err != nil { - return err - } - - userGroupMember := model.UserGroupMember{ - ID: utils.Sign([]string{userGroupId, userId}), - UserId: userId, - UserGroupId: userGroupId, - } - err = tx.Create(&userGroupMember).Error - if err != nil { - return err - } - } - return nil -} - -func (r UserGroupRepository) FindAllUserGroupMembers() (o []model.UserGroupMember, err error) { - err = r.DB.Find(&o).Error +func (r userGroupRepository) Create(c context.Context, o *model.UserGroup) (err error) { + return r.GetDB(c).Create(o).Error +} + +func (r userGroupRepository) Update(c context.Context, o *model.UserGroup) error { + return r.GetDB(c).Updates(o).Error +} + +func (r userGroupRepository) DeleteById(c context.Context, id string) (err error) { + return r.GetDB(c).Where("id = ?", id).Delete(&model.UserGroup{}).Error +} + +func (r userGroupRepository) FindAllUserGroupMembers() (c context.Context, o []model.UserGroupMember, err error) { + err = r.GetDB(c).Find(&o).Error return } diff --git a/server/repository/user_group_member.go b/server/repository/user_group_member.go new file mode 100644 index 0000000..b24708e --- /dev/null +++ b/server/repository/user_group_member.go @@ -0,0 +1,34 @@ +package repository + +import ( + "context" + + "next-terminal/server/model" +) + +type userGroupMemberRepository struct { + baseRepository +} + +func (r userGroupMemberRepository) FindUserIdsByUserGroupId(c context.Context, userGroupId string) (o []string, err error) { + err = r.GetDB(c).Table("user_group_members").Select("user_id").Where("user_group_id = ?", userGroupId).Find(&o).Error + return +} + +func (r userGroupMemberRepository) FindUserGroupIdsByUserId(c context.Context, userId string) (o []string, err error) { + // 先查询用户所在的用户 + err = r.GetDB(c).Table("user_group_members").Select("user_group_id").Where("user_id = ?", userId).Find(&o).Error + return +} + +func (r userGroupMemberRepository) Create(c context.Context, o *model.UserGroupMember) error { + return r.GetDB(c).Create(o).Error +} + +func (r userGroupMemberRepository) DeleteByUserId(c context.Context, userId string) error { + return r.GetDB(c).Where("user_id = ?", userId).Delete(&model.UserGroupMember{}).Error +} + +func (r userGroupMemberRepository) DeleteByUserGroupId(c context.Context, userGroupId string) error { + return r.GetDB(c).Where("user_group_id = ?", userGroupId).Delete(&model.UserGroupMember{}).Error +} diff --git a/server/repository/var.go b/server/repository/var.go new file mode 100644 index 0000000..ac00080 --- /dev/null +++ b/server/repository/var.go @@ -0,0 +1,21 @@ +package repository + +var ( + PropertyRepository = new(propertyRepository) + UserRepository = new(userRepository) + UserGroupRepository = new(userGroupRepository) + UserGroupMemberRepository = new(userGroupMemberRepository) + ResourceSharerRepository = new(resourceSharerRepository) + AssetRepository = new(assetRepository) + CredentialRepository = new(credentialRepository) + CommandRepository = new(commandRepository) + SessionRepository = new(sessionRepository) + SecurityRepository = new(securityRepository) + GatewayRepository = new(gatewayRepository) + JobRepository = new(jobRepository) + JobLogRepository = new(jobLogRepository) + LoginLogRepository = new(loginLogRepository) + StorageRepository = new(storageRepository) + StrategyRepository = new(strategyRepository) + AccessTokenRepository = new(accessTokenRepository) +) diff --git a/server/service/access_gateway.go b/server/service/access_gateway.go deleted file mode 100644 index 705c60a..0000000 --- a/server/service/access_gateway.go +++ /dev/null @@ -1,75 +0,0 @@ -package service - -import ( - "next-terminal/server/global/gateway" - "next-terminal/server/log" - "next-terminal/server/model" - "next-terminal/server/repository" - "next-terminal/server/term" -) - -type AccessGatewayService struct { - accessGatewayRepository *repository.AccessGatewayRepository -} - -func NewAccessGatewayService(accessGatewayRepository *repository.AccessGatewayRepository) *AccessGatewayService { - accessGatewayService = &AccessGatewayService{accessGatewayRepository: accessGatewayRepository} - return accessGatewayService -} - -func (r AccessGatewayService) GetGatewayAndReconnectById(accessGatewayId string) (g *gateway.Gateway, err error) { - g = gateway.GlobalGatewayManager.GetById(accessGatewayId) - if g == nil || !g.Connected { - accessGateway, err := r.accessGatewayRepository.FindById(accessGatewayId) - if err != nil { - return nil, err - } - g = r.ReConnect(&accessGateway) - } - return g, nil -} - -func (r AccessGatewayService) GetGatewayById(accessGatewayId string) (g *gateway.Gateway, err error) { - g = gateway.GlobalGatewayManager.GetById(accessGatewayId) - if g == nil { - accessGateway, err := r.accessGatewayRepository.FindById(accessGatewayId) - if err != nil { - return nil, err - } - g = r.ReConnect(&accessGateway) - } - return g, nil -} - -func (r AccessGatewayService) ReConnectAll() error { - gateways, err := r.accessGatewayRepository.FindAll() - if err != nil { - return err - } - if len(gateways) > 0 { - for i := range gateways { - r.ReConnect(&gateways[i]) - } - } - - return nil -} - -func (r AccessGatewayService) ReConnect(m *model.AccessGateway) *gateway.Gateway { - log.Debugf("重建接入网关「%v」中...", m.Name) - r.DisconnectById(m.ID) - sshClient, err := term.NewSshClient(m.IP, m.Port, m.Username, m.Password, m.PrivateKey, m.Passphrase) - var g *gateway.Gateway - if err != nil { - g = gateway.NewGateway(m.ID, m.Localhost, false, err.Error(), nil) - } else { - g = gateway.NewGateway(m.ID, m.Localhost, true, "", sshClient) - } - gateway.GlobalGatewayManager.Add <- g - log.Debugf("重建接入网关「%v」完成", m.Name) - return g -} - -func (r AccessGatewayService) DisconnectById(accessGatewayId string) { - gateway.GlobalGatewayManager.Del <- accessGatewayId -} diff --git a/server/service/access_token.go b/server/service/access_token.go new file mode 100644 index 0000000..c10c3df --- /dev/null +++ b/server/service/access_token.go @@ -0,0 +1,86 @@ +package service + +import ( + "context" + "errors" + + "next-terminal/server/constant" + "next-terminal/server/dto" + "next-terminal/server/env" + "next-terminal/server/global/cache" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/utils" + + "gorm.io/gorm" +) + +type accessTokenService struct { + baseService +} + +func (service accessTokenService) FindByUserId(userId string) (model.AccessToken, error) { + return repository.AccessTokenRepository.FindByUserId(context.TODO(), userId) +} + +func (service accessTokenService) GenAccessToken(userId string) error { + return env.GetDB().Transaction(func(tx *gorm.DB) error { + ctx := service.Context(tx) + + user, err := repository.UserRepository.FindById(ctx, userId) + if err != nil { + return err + } + oldAccessToken, err := repository.AccessTokenRepository.FindByUserId(ctx, userId) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + if oldAccessToken.Token != "" { + cache.TokenManager.Delete(oldAccessToken.Token) + } + if err := repository.AccessTokenRepository.DeleteByUserId(ctx, userId); err != nil { + return err + } + + token := "forever-" + utils.UUID() + accessToken := &model.AccessToken{ + ID: utils.UUID(), + UserId: userId, + Token: token, + Created: utils.NowJsonTime(), + } + + authorization := dto.Authorization{ + Token: token, + Remember: false, + Type: constant.AccessToken, + User: &user, + } + + cache.TokenManager.Set(token, authorization, cache.NoExpiration) + + return repository.AccessTokenRepository.Create(ctx, accessToken) + }) +} + +func (service accessTokenService) Reload() error { + accessTokens, err := repository.AccessTokenRepository.FindAll(context.TODO()) + if err != nil { + return err + } + for _, accessToken := range accessTokens { + user, err := repository.UserRepository.FindById(context.TODO(), accessToken.UserId) + if err != nil { + return err + } + authorization := dto.Authorization{ + Token: accessToken.Token, + Remember: false, + Type: constant.AccessToken, + User: &user, + } + + cache.TokenManager.Set(accessToken.Token, authorization, cache.NoExpiration) + } + return nil +} diff --git a/server/service/asset.go b/server/service/asset.go index a25d4a2..c58b564 100644 --- a/server/service/asset.go +++ b/server/service/asset.go @@ -1,21 +1,26 @@ package service import ( + "context" + "encoding/base64" + "encoding/json" + "next-terminal/server/config" + "next-terminal/server/env" + "next-terminal/server/model" "next-terminal/server/repository" "next-terminal/server/utils" + + "github.com/labstack/echo/v4" + "gorm.io/gorm" ) -type AssetService struct { - assetRepository *repository.AssetRepository +type assetService struct { + baseService } -func NewAssetService(assetRepository *repository.AssetRepository) *AssetService { - return &AssetService{assetRepository: assetRepository} -} - -func (r AssetService) Encrypt() error { - items, err := r.assetRepository.FindAll() +func (s assetService) EncryptAll() error { + items, err := repository.AssetRepository.FindAll(context.TODO()) if err != nil { return err } @@ -24,19 +29,95 @@ func (r AssetService) Encrypt() error { if item.Encrypted { continue } - if err := r.assetRepository.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil { + if err := s.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil { return err } - if err := r.assetRepository.UpdateById(&item, item.ID); err != nil { + if err := repository.AssetRepository.UpdateById(context.TODO(), &item, item.ID); err != nil { return err } } return nil } -func (r AssetService) CheckStatus(accessGatewayId string, ip string, port int) (active bool, err error) { +func (s assetService) Decrypt(item *model.Asset, password []byte) error { + if item.Encrypted { + if item.Password != "" && item.Password != "-" { + origData, err := base64.StdEncoding.DecodeString(item.Password) + if err != nil { + return err + } + decryptedCBC, err := utils.AesDecryptCBC(origData, password) + if err != nil { + return err + } + item.Password = string(decryptedCBC) + } + if item.PrivateKey != "" && item.PrivateKey != "-" { + origData, err := base64.StdEncoding.DecodeString(item.PrivateKey) + if err != nil { + return err + } + decryptedCBC, err := utils.AesDecryptCBC(origData, password) + if err != nil { + return err + } + item.PrivateKey = string(decryptedCBC) + } + if item.Passphrase != "" && item.Passphrase != "-" { + origData, err := base64.StdEncoding.DecodeString(item.Passphrase) + if err != nil { + return err + } + decryptedCBC, err := utils.AesDecryptCBC(origData, password) + if err != nil { + return err + } + item.Passphrase = string(decryptedCBC) + } + } + return nil +} + +func (s assetService) Encrypt(item *model.Asset, password []byte) error { + if item.Password != "" && item.Password != "-" { + encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Password), password) + if err != nil { + return err + } + item.Password = base64.StdEncoding.EncodeToString(encryptedCBC) + } + if item.PrivateKey != "" && item.PrivateKey != "-" { + encryptedCBC, err := utils.AesEncryptCBC([]byte(item.PrivateKey), password) + if err != nil { + return err + } + item.PrivateKey = base64.StdEncoding.EncodeToString(encryptedCBC) + } + if item.Passphrase != "" && item.Passphrase != "-" { + encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Passphrase), password) + if err != nil { + return err + } + item.Passphrase = base64.StdEncoding.EncodeToString(encryptedCBC) + } + item.Encrypted = true + return nil +} + +func (s assetService) FindByIdAndDecrypt(c context.Context, id string) (model.Asset, error) { + asset, err := repository.AssetRepository.FindById(c, id) + if err != nil { + return model.Asset{}, err + } + if err := s.Decrypt(&asset, config.GlobalCfg.EncryptionPassword); err != nil { + return model.Asset{}, err + } + return asset, nil +} + +func (s assetService) CheckStatus(accessGatewayId string, ip string, port int) (active bool, err error) { if accessGatewayId != "" && accessGatewayId != "-" { - g, e1 := accessGatewayService.GetGatewayAndReconnectById(accessGatewayId) + g, e1 := GatewayService.GetGatewayAndReconnectById(accessGatewayId) if err != nil { return false, e1 } @@ -58,3 +139,118 @@ func (r AssetService) CheckStatus(accessGatewayId string, ip string, port int) ( } return active, err } + +func (s assetService) Create(m echo.Map) (model.Asset, error) { + + data, err := json.Marshal(m) + if err != nil { + return model.Asset{}, err + } + var item model.Asset + if err := json.Unmarshal(data, &item); err != nil { + return model.Asset{}, err + } + + item.ID = utils.UUID() + item.Created = utils.NowJsonTime() + item.Active = true + + return item, env.GetDB().Transaction(func(tx *gorm.DB) error { + c := s.Context(tx) + + if err := s.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil { + return err + } + if err := repository.AssetRepository.Create(c, &item); err != nil { + return err + } + + if err := repository.AssetRepository.UpdateAttributes(c, item.ID, item.Protocol, m); err != nil { + return err + } + + go func() { + active, _ := s.CheckStatus(item.AccessGatewayId, item.IP, item.Port) + + if item.Active != active { + _ = repository.AssetRepository.UpdateActiveById(context.TODO(), active, item.ID) + } + }() + return nil + }) +} + +func (s assetService) DeleteById(id string) error { + return env.GetDB().Transaction(func(tx *gorm.DB) error { + c := s.Context(tx) + // 删除资产 + if err := repository.AssetRepository.DeleteById(c, id); err != nil { + return err + } + // 删除资产属性 + if err := repository.AssetRepository.DeleteAttrByAssetId(c, id); err != nil { + return err + } + // 删除资产与用户的关系 + if err := repository.ResourceSharerRepository.DeleteByResourceId(c, id); err != nil { + return err + } + return nil + }) +} + +func (s assetService) UpdateById(id string, m echo.Map) error { + data, err := json.Marshal(m) + if err != nil { + return err + } + var item model.Asset + if err := json.Unmarshal(data, &item); err != nil { + return err + } + + switch item.AccountType { + case "credential": + item.Username = "-" + item.Password = "-" + item.PrivateKey = "-" + item.Passphrase = "-" + case "private-key": + item.Password = "-" + item.CredentialId = "-" + if len(item.Username) == 0 { + item.Username = "-" + } + if len(item.Passphrase) == 0 { + item.Passphrase = "-" + } + case "custom": + item.PrivateKey = "-" + item.Passphrase = "-" + item.CredentialId = "-" + } + + if len(item.Tags) == 0 { + item.Tags = "-" + } + + if item.Description == "" { + item.Description = "-" + } + + if err := s.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil { + return err + } + return env.GetDB().Transaction(func(tx *gorm.DB) error { + c := s.Context(tx) + + if err := repository.AssetRepository.UpdateById(c, &item, id); err != nil { + return err + } + if err := repository.AssetRepository.UpdateAttributes(c, id, item.Protocol, m); err != nil { + return err + } + return nil + }) + +} diff --git a/server/service/backup.go b/server/service/backup.go new file mode 100644 index 0000000..4b6ca5b --- /dev/null +++ b/server/service/backup.go @@ -0,0 +1,326 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "strings" + + "next-terminal/server/config" + "next-terminal/server/constant" + "next-terminal/server/dto" + "next-terminal/server/env" + "next-terminal/server/global/security" + "next-terminal/server/repository" + "next-terminal/server/utils" + + "github.com/labstack/echo/v4" + "gorm.io/gorm" +) + +type backupService struct { + baseService +} + +func (service backupService) Export() (error, *dto.Backup) { + ctx := context.TODO() + users, err := repository.UserRepository.FindAll(ctx) + if err != nil { + return err, nil + } + for i := range users { + users[i].Password = "" + } + userGroups, err := repository.UserGroupRepository.FindAll(ctx) + if err != nil { + return err, nil + } + if len(userGroups) > 0 { + for i := range userGroups { + members, err := repository.UserGroupMemberRepository.FindUserIdsByUserGroupId(ctx, userGroups[i].ID) + if err != nil { + return err, nil + } + userGroups[i].Members = members + } + } + + storages, err := repository.StorageRepository.FindAll(ctx) + if err != nil { + return err, nil + } + + strategies, err := repository.StrategyRepository.FindAll(ctx) + if err != nil { + return err, nil + } + jobs, err := repository.JobRepository.FindAll(ctx) + if err != nil { + return err, nil + } + accessSecurities, err := repository.SecurityRepository.FindAll(ctx) + if err != nil { + return err, nil + } + accessGateways, err := repository.GatewayRepository.FindAll(ctx) + if err != nil { + return err, nil + } + commands, err := repository.CommandRepository.FindAll(ctx) + if err != nil { + return err, nil + } + credentials, err := repository.CredentialRepository.FindAll(ctx) + if err != nil { + return err, nil + } + if len(credentials) > 0 { + for i := range credentials { + if err := CredentialService.Decrypt(&credentials[i], config.GlobalCfg.EncryptionPassword); err != nil { + return err, nil + } + } + } + assets, err := repository.AssetRepository.FindAll(ctx) + if err != nil { + return err, nil + } + var assetMaps = make([]map[string]interface{}, 0) + if len(assets) > 0 { + for i := range assets { + asset := assets[i] + if err := AssetService.Decrypt(&asset, config.GlobalCfg.EncryptionPassword); err != nil { + return err, nil + } + attributeMap, err := repository.AssetRepository.FindAssetAttrMapByAssetId(ctx, asset.ID) + if err != nil { + return err, nil + } + itemMap := utils.StructToMap(asset) + for key := range attributeMap { + itemMap[key] = attributeMap[key] + } + itemMap["created"] = asset.Created.Format("2006-01-02 15:04:05") + assetMaps = append(assetMaps, itemMap) + } + } + + resourceSharers, err := repository.ResourceSharerRepository.FindAll(ctx) + if err != nil { + return err, nil + } + + backup := dto.Backup{ + Users: users, + UserGroups: userGroups, + Storages: storages, + Strategies: strategies, + Jobs: jobs, + AccessSecurities: accessSecurities, + AccessGateways: accessGateways, + Commands: commands, + Credentials: credentials, + Assets: assetMaps, + ResourceSharers: resourceSharers, + } + return nil, &backup +} + +func (service backupService) Import(backup *dto.Backup) error { + return env.GetDB().Transaction(func(tx *gorm.DB) error { + c := service.Context(tx) + var userIdMapping = make(map[string]string) + if len(backup.Users) > 0 { + for _, item := range backup.Users { + oldId := item.ID + if repository.UserRepository.ExistByUsername(c, item.Username) { + delete(userIdMapping, oldId) + continue + } + newId := utils.UUID() + item.ID = newId + item.Password = utils.GenPassword() + if err := repository.UserRepository.Create(c, &item); err != nil { + return err + } + userIdMapping[oldId] = newId + } + } + + var userGroupIdMapping = make(map[string]string) + if len(backup.UserGroups) > 0 { + for _, item := range backup.UserGroups { + oldId := item.ID + + var members = make([]string, 0) + if len(item.Members) > 0 { + for _, member := range item.Members { + members = append(members, userIdMapping[member]) + } + } + + userGroup, err := UserGroupService.Create(item.Name, members) + if err != nil { + if errors.Is(constant.ErrNameAlreadyUsed, err) { + // 删除名称重复的用户组 + delete(userGroupIdMapping, oldId) + continue + } else { + return err + } + } + + userGroupIdMapping[oldId] = userGroup.ID + } + } + + if len(backup.Storages) > 0 { + for _, item := range backup.Storages { + owner := userIdMapping[item.Owner] + if owner == "" { + continue + } + item.ID = utils.UUID() + item.Owner = owner + item.Created = utils.NowJsonTime() + if err := repository.StorageRepository.Create(c, &item); err != nil { + return err + } + } + } + + var strategyIdMapping = make(map[string]string) + if len(backup.Strategies) > 0 { + for _, item := range backup.Strategies { + oldId := item.ID + newId := utils.UUID() + item.ID = newId + item.Created = utils.NowJsonTime() + if err := repository.StrategyRepository.Create(c, &item); err != nil { + return err + } + strategyIdMapping[oldId] = newId + } + } + + if len(backup.AccessSecurities) > 0 { + for _, item := range backup.AccessSecurities { + item.ID = utils.UUID() + if err := repository.SecurityRepository.Create(c, &item); err != nil { + return err + } + // 更新内存中的安全规则 + rule := &security.Security{ + ID: item.ID, + IP: item.IP, + Rule: item.Rule, + Priority: item.Priority, + } + security.GlobalSecurityManager.Add <- rule + } + } + + var accessGatewayIdMapping = make(map[string]string, 0) + if len(backup.AccessGateways) > 0 { + for _, item := range backup.AccessGateways { + oldId := item.ID + newId := utils.UUID() + item.ID = newId + item.Created = utils.NowJsonTime() + if err := repository.GatewayRepository.Create(c, &item); err != nil { + return err + } + accessGatewayIdMapping[oldId] = newId + } + } + + if len(backup.Commands) > 0 { + for _, item := range backup.Commands { + item.ID = utils.UUID() + item.Created = utils.NowJsonTime() + if err := repository.CommandRepository.Create(c, &item); err != nil { + return err + } + } + } + + var credentialIdMapping = make(map[string]string, 0) + if len(backup.Credentials) > 0 { + for _, item := range backup.Credentials { + oldId := item.ID + newId := utils.UUID() + item.ID = newId + if err := CredentialService.Create(&item); err != nil { + return err + } + credentialIdMapping[oldId] = newId + } + } + + var assetIdMapping = make(map[string]string, 0) + if len(backup.Assets) > 0 { + for _, m := range backup.Assets { + data, err := json.Marshal(m) + if err != nil { + return err + } + m := echo.Map{} + if err := json.Unmarshal(data, &m); err != nil { + return err + } + credentialId := m["credentialId"].(string) + accessGatewayId := m["accessGatewayId"].(string) + if credentialId != "" && credentialId != "-" { + m["credentialId"] = credentialIdMapping[credentialId] + } + if accessGatewayId != "" && accessGatewayId != "-" { + m["accessGatewayId"] = accessGatewayIdMapping[accessGatewayId] + } + + oldId := m["id"].(string) + asset, err := AssetService.Create(m) + if err != nil { + return err + } + + assetIdMapping[oldId] = asset.ID + } + } + + if len(backup.ResourceSharers) > 0 { + for _, item := range backup.ResourceSharers { + + userGroupId := userGroupIdMapping[item.UserGroupId] + userId := userIdMapping[item.UserId] + strategyId := strategyIdMapping[item.StrategyId] + resourceId := assetIdMapping[item.ResourceId] + + if err := repository.ResourceSharerRepository.AddSharerResources(userGroupId, userId, strategyId, item.ResourceType, []string{resourceId}); err != nil { + return err + } + } + } + + if len(backup.Jobs) > 0 { + for _, item := range backup.Jobs { + if item.Func == constant.FuncCheckAssetStatusJob { + continue + } + + resourceIds := strings.Split(item.ResourceIds, ",") + if len(resourceIds) > 0 { + var newResourceIds = make([]string, 0) + for _, resourceId := range resourceIds { + newResourceIds = append(newResourceIds, assetIdMapping[resourceId]) + } + item.ResourceIds = strings.Join(newResourceIds, ",") + } + if err := JobService.Create(&item); err != nil { + return err + } + } + } + return nil + }) + +} diff --git a/server/service/base.go b/server/service/base.go new file mode 100644 index 0000000..90a675a --- /dev/null +++ b/server/service/base.go @@ -0,0 +1,16 @@ +package service + +import ( + "context" + + "next-terminal/server/constant" + + "gorm.io/gorm" +) + +type baseService struct { +} + +func (service baseService) Context(db *gorm.DB) context.Context { + return context.WithValue(context.TODO(), constant.DB, db) +} diff --git a/server/service/credential.go b/server/service/credential.go index bc766b7..049c507 100644 --- a/server/service/credential.go +++ b/server/service/credential.go @@ -1,20 +1,21 @@ package service import ( + "context" + "encoding/base64" + + "next-terminal/server/model" + "next-terminal/server/utils" + "next-terminal/server/config" "next-terminal/server/repository" ) -type CredentialService struct { - credentialRepository *repository.CredentialRepository +type credentialService struct { } -func NewCredentialService(credentialRepository *repository.CredentialRepository) *CredentialService { - return &CredentialService{credentialRepository: credentialRepository} -} - -func (r CredentialService) Encrypt() error { - items, err := r.credentialRepository.FindAll() +func (s credentialService) EncryptAll() error { + items, err := repository.CredentialRepository.FindAll(context.TODO()) if err != nil { return err } @@ -23,12 +24,96 @@ func (r CredentialService) Encrypt() error { if item.Encrypted { continue } - if err := r.credentialRepository.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil { + if err := s.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil { return err } - if err := r.credentialRepository.UpdateById(&item, item.ID); err != nil { + if err := repository.CredentialRepository.UpdateById(context.TODO(), &item, item.ID); err != nil { return err } } return nil } + +func (s credentialService) Encrypt(item *model.Credential, password []byte) error { + if item.Password != "-" { + encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Password), password) + if err != nil { + return err + } + item.Password = base64.StdEncoding.EncodeToString(encryptedCBC) + } + if item.PrivateKey != "-" { + encryptedCBC, err := utils.AesEncryptCBC([]byte(item.PrivateKey), password) + if err != nil { + return err + } + item.PrivateKey = base64.StdEncoding.EncodeToString(encryptedCBC) + } + if item.Passphrase != "-" { + encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Passphrase), password) + if err != nil { + return err + } + item.Passphrase = base64.StdEncoding.EncodeToString(encryptedCBC) + } + item.Encrypted = true + return nil +} + +func (s credentialService) Decrypt(item *model.Credential, password []byte) error { + if item.Encrypted { + if item.Password != "" && item.Password != "-" { + origData, err := base64.StdEncoding.DecodeString(item.Password) + if err != nil { + return err + } + decryptedCBC, err := utils.AesDecryptCBC(origData, password) + if err != nil { + return err + } + item.Password = string(decryptedCBC) + } + if item.PrivateKey != "" && item.PrivateKey != "-" { + origData, err := base64.StdEncoding.DecodeString(item.PrivateKey) + if err != nil { + return err + } + decryptedCBC, err := utils.AesDecryptCBC(origData, password) + if err != nil { + return err + } + item.PrivateKey = string(decryptedCBC) + } + if item.Passphrase != "" && item.Passphrase != "-" { + origData, err := base64.StdEncoding.DecodeString(item.Passphrase) + if err != nil { + return err + } + decryptedCBC, err := utils.AesDecryptCBC(origData, password) + if err != nil { + return err + } + item.Passphrase = string(decryptedCBC) + } + } + return nil +} + +func (s credentialService) FindByIdAndDecrypt(c context.Context, id string) (o model.Credential, err error) { + credential, err := repository.CredentialRepository.FindById(c, id) + if err != nil { + return o, err + } + if err := s.Decrypt(&credential, config.GlobalCfg.EncryptionPassword); err != nil { + return o, err + } + return credential, nil +} + +func (s credentialService) Create(item *model.Credential) error { + // 加密密码之后进行存储 + if err := s.Encrypt(item, config.GlobalCfg.EncryptionPassword); err != nil { + return err + } + return repository.CredentialRepository.Create(context.TODO(), item) +} diff --git a/server/service/definitions.go b/server/service/definitions.go deleted file mode 100644 index 8342957..0000000 --- a/server/service/definitions.go +++ /dev/null @@ -1,5 +0,0 @@ -package service - -var ( - accessGatewayService *AccessGatewayService -) diff --git a/server/service/gateway.go b/server/service/gateway.go new file mode 100644 index 0000000..2a9718c --- /dev/null +++ b/server/service/gateway.go @@ -0,0 +1,70 @@ +package service + +import ( + "context" + + "next-terminal/server/global/gateway" + "next-terminal/server/log" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/term" +) + +type gatewayService struct{} + +func (r gatewayService) GetGatewayAndReconnectById(accessGatewayId string) (g *gateway.Gateway, err error) { + g = gateway.GlobalGatewayManager.GetById(accessGatewayId) + if g == nil || !g.Connected { + accessGateway, err := repository.GatewayRepository.FindById(context.TODO(), accessGatewayId) + if err != nil { + return nil, err + } + g = r.ReConnect(&accessGateway) + } + return g, nil +} + +func (r gatewayService) GetGatewayById(accessGatewayId string) (g *gateway.Gateway, err error) { + g = gateway.GlobalGatewayManager.GetById(accessGatewayId) + if g == nil { + accessGateway, err := repository.GatewayRepository.FindById(context.TODO(), accessGatewayId) + if err != nil { + return nil, err + } + g = r.ReConnect(&accessGateway) + } + return g, nil +} + +func (r gatewayService) ReConnectAll() error { + gateways, err := repository.GatewayRepository.FindAll(context.TODO()) + if err != nil { + return err + } + if len(gateways) > 0 { + for i := range gateways { + r.ReConnect(&gateways[i]) + } + } + + return nil +} + +func (r gatewayService) ReConnect(m *model.AccessGateway) *gateway.Gateway { + log.Debugf("重建接入网关「%v」中...", m.Name) + r.DisconnectById(m.ID) + sshClient, err := term.NewSshClient(m.IP, m.Port, m.Username, m.Password, m.PrivateKey, m.Passphrase) + var g *gateway.Gateway + if err != nil { + g = gateway.NewGateway(m.ID, false, err.Error(), nil) + } else { + g = gateway.NewGateway(m.ID, true, "", sshClient) + } + gateway.GlobalGatewayManager.Add <- g + log.Debugf("重建接入网关「%v」完成", m.Name) + return g +} + +func (r gatewayService) DisconnectById(accessGatewayId string) { + gateway.GlobalGatewayManager.Del <- accessGatewayId +} diff --git a/server/service/job.go b/server/service/job.go index 85126a7..b36b1d0 100644 --- a/server/service/job.go +++ b/server/service/job.go @@ -1,42 +1,27 @@ package service import ( - "encoding/json" + "context" "errors" - "fmt" - "strings" - "time" "next-terminal/server/constant" "next-terminal/server/global/cron" "next-terminal/server/log" "next-terminal/server/model" "next-terminal/server/repository" - "next-terminal/server/term" "next-terminal/server/utils" - - "gorm.io/gorm" ) -type JobService struct { - jobRepository *repository.JobRepository - jobLogRepository *repository.JobLogRepository - assetRepository *repository.AssetRepository - credentialRepository *repository.CredentialRepository - assetService *AssetService +type jobService struct { } -func NewJobService(jobRepository *repository.JobRepository, jobLogRepository *repository.JobLogRepository, assetRepository *repository.AssetRepository, credentialRepository *repository.CredentialRepository, assetService *AssetService) *JobService { - return &JobService{jobRepository: jobRepository, jobLogRepository: jobLogRepository, assetRepository: assetRepository, credentialRepository: credentialRepository, assetService: assetService} -} - -func (r JobService) ChangeStatusById(id, status string) error { - job, err := r.jobRepository.FindById(id) +func (r jobService) ChangeStatusById(id, status string) error { + job, err := repository.JobRepository.FindById(context.TODO(), id) if err != nil { return err } if status == constant.JobStatusRunning { - j, err := getJob(&job, &r) + j, err := getJob(&job) if err != nil { return err } @@ -48,249 +33,38 @@ func (r JobService) ChangeStatusById(id, status string) error { jobForUpdate := model.Job{ID: id, Status: constant.JobStatusRunning, CronJobId: int(entryID)} - return r.jobRepository.UpdateById(&jobForUpdate) + return repository.JobRepository.UpdateById(context.TODO(), &jobForUpdate) } else { cron.GlobalCron.Remove(cron.JobId(job.CronJobId)) log.Debugf("关闭计划任务「%v」,运行中计划任务数量「%v」", job.Name, len(cron.GlobalCron.Entries())) jobForUpdate := model.Job{ID: id, Status: constant.JobStatusNotRunning} - return r.jobRepository.UpdateById(&jobForUpdate) + return repository.JobRepository.UpdateById(context.TODO(), &jobForUpdate) } } -func getJob(j *model.Job, jobService *JobService) (job cron.Job, err error) { +func getJob(j *model.Job) (job cron.Job, err error) { switch j.Func { case constant.FuncCheckAssetStatusJob: job = CheckAssetStatusJob{ - ID: j.ID, - Mode: j.Mode, - ResourceIds: j.ResourceIds, - Metadata: j.Metadata, - jobService: jobService, - assetService: jobService.assetService, + ID: j.ID, + Mode: j.Mode, + ResourceIds: j.ResourceIds, + Metadata: j.Metadata, } case constant.FuncShellJob: - job = ShellJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata, jobService: jobService} + job = ShellJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata} default: return nil, errors.New("未识别的任务") } return job, err } -type CheckAssetStatusJob struct { - ID string - Mode string - ResourceIds string - Metadata string - jobService *JobService - assetService *AssetService -} - -func (r CheckAssetStatusJob) Run() { - if r.ID == "" { - return - } - - var assets []model.Asset - if r.Mode == constant.JobModeAll { - assets, _ = r.jobService.assetRepository.FindAll() - } else { - assets, _ = r.jobService.assetRepository.FindByIds(strings.Split(r.ResourceIds, ",")) - } - - if len(assets) == 0 { - return - } - - msgChan := make(chan string) - for i := range assets { - asset := assets[i] - go func() { - t1 := time.Now() - var ( - msg string - ip = asset.IP - port = asset.Port - ) - active, err := r.assetService.CheckStatus(asset.AccessGatewayId, ip, port) - - elapsed := time.Since(t1) - if err == nil { - msg = fmt.Sprintf("资产「%v」存活状态检测完成,存活「%v」,耗时「%v」", asset.Name, active, elapsed) - } else { - msg = fmt.Sprintf("资产「%v」存活状态检测完成,存活「%v」,耗时「%v」,原因: %v", asset.Name, active, elapsed, err.Error()) - } - - _ = r.jobService.assetRepository.UpdateActiveById(active, asset.ID) - log.Infof(msg) - msgChan <- msg - }() - } - - var message = "" - for i := 0; i < len(assets); i++ { - message += <-msgChan + "\n" - } - - _ = r.jobService.jobRepository.UpdateLastUpdatedById(r.ID) - jobLog := model.JobLog{ - ID: utils.UUID(), - JobId: r.ID, - Timestamp: utils.NowJsonTime(), - Message: message, - } - - _ = r.jobService.jobLogRepository.Create(&jobLog) -} - -type ShellJob struct { - ID string - Mode string - ResourceIds string - Metadata string - jobService *JobService -} - -type MetadataShell struct { - Shell string -} - -func (r ShellJob) Run() { - if r.ID == "" { - return - } - - var assets []model.Asset - if r.Mode == constant.JobModeAll { - assets, _ = r.jobService.assetRepository.FindByProtocol("ssh") - } else { - assets, _ = r.jobService.assetRepository.FindByProtocolAndIds("ssh", strings.Split(r.ResourceIds, ",")) - } - - if len(assets) == 0 { - return - } - - var metadataShell MetadataShell - err := json.Unmarshal([]byte(r.Metadata), &metadataShell) - if err != nil { - log.Errorf("JSON数据解析失败 %v", err) - return - } - - msgChan := make(chan string) - for i := range assets { - asset, err := r.jobService.assetRepository.FindByIdAndDecrypt(assets[i].ID) - if err != nil { - msgChan <- fmt.Sprintf("资产「%v」Shell执行失败,查询数据异常「%v」", assets[i].Name, err.Error()) - return - } - - var ( - username = asset.Username - password = asset.Password - privateKey = asset.PrivateKey - passphrase = asset.Passphrase - ip = asset.IP - port = asset.Port - ) - - if asset.AccountType == "credential" { - credential, err := r.jobService.credentialRepository.FindByIdAndDecrypt(asset.CredentialId) - if err != nil { - msgChan <- fmt.Sprintf("资产「%v」Shell执行失败,查询授权凭证数据异常「%v」", assets[i].Name, err.Error()) - return - } - - if credential.Type == constant.Custom { - username = credential.Username - password = credential.Password - } else { - username = credential.Username - privateKey = credential.PrivateKey - passphrase = credential.Passphrase - } - } - - go func() { - t1 := time.Now() - result, err := exec(metadataShell.Shell, asset.AccessGatewayId, ip, port, username, password, privateKey, passphrase) - elapsed := time.Since(t1) - var msg string - if err != nil { - if errors.Is(gorm.ErrRecordNotFound, err) { - msg = fmt.Sprintf("资产「%v」Shell执行失败,请检查资产所关联接入网关是否存在,耗时「%v」", asset.Name, elapsed) - } else { - msg = fmt.Sprintf("资产「%v」Shell执行失败,错误内容为:「%v」,耗时「%v」", asset.Name, err.Error(), elapsed) - } - log.Infof(msg) - } else { - msg = fmt.Sprintf("资产「%v」Shell执行成功,返回值「%v」,耗时「%v」", asset.Name, result, elapsed) - log.Infof(msg) - } - - msgChan <- msg - }() - } - - var message = "" - for i := 0; i < len(assets); i++ { - message += <-msgChan + "\n" - } - - _ = r.jobService.jobRepository.UpdateLastUpdatedById(r.ID) - jobLog := model.JobLog{ - ID: utils.UUID(), - JobId: r.ID, - Timestamp: utils.NowJsonTime(), - Message: message, - } - - _ = r.jobService.jobLogRepository.Create(&jobLog) -} - -func exec(shell, accessGatewayId, ip string, port int, username, password, privateKey, passphrase string) (string, error) { - if accessGatewayId != "" && accessGatewayId != "-" { - g, err := accessGatewayService.GetGatewayAndReconnectById(accessGatewayId) - if err != nil { - return "", err - } - uuid := utils.UUID() - exposedIP, exposedPort, err := g.OpenSshTunnel(uuid, ip, port) - if err != nil { - return "", err - } - defer g.CloseSshTunnel(uuid) - return ExecCommandBySSH(shell, exposedIP, exposedPort, username, password, privateKey, passphrase) - } else { - return ExecCommandBySSH(shell, ip, port, username, password, privateKey, passphrase) - } -} - -func ExecCommandBySSH(cmd, ip string, port int, username, password, privateKey, passphrase string) (result string, err error) { - sshClient, err := term.NewSshClient(ip, port, username, password, privateKey, passphrase) - if err != nil { - return "", err - } - - session, err := sshClient.NewSession() - if err != nil { - return "", err - } - defer session.Close() - //执行远程命令 - combo, err := session.CombinedOutput(cmd) - if err != nil { - return "", err - } - return string(combo), nil -} - -func (r JobService) ExecJobById(id string) (err error) { - job, err := r.jobRepository.FindById(id) +func (r jobService) ExecJobById(id string) (err error) { + job, err := repository.JobRepository.FindById(context.TODO(), id) if err != nil { return err } - j, err := getJob(&job, &r) + j, err := getJob(&job) if err != nil { return err } @@ -298,8 +72,8 @@ func (r JobService) ExecJobById(id string) (err error) { return nil } -func (r JobService) InitJob() error { - jobs, _ := r.jobRepository.FindAll() +func (r jobService) InitJob() error { + jobs, _ := repository.JobRepository.FindAll(context.TODO()) if len(jobs) == 0 { job := model.Job{ ID: utils.UUID(), @@ -311,7 +85,7 @@ func (r JobService) InitJob() error { Created: utils.NowJsonTime(), Updated: utils.NowJsonTime(), } - if err := r.jobRepository.Create(&job); err != nil { + if err := repository.JobRepository.Create(context.TODO(), &job); err != nil { return err } log.Debugf("创建计划任务「%v」cron「%v」", job.Name, job.Cron) @@ -329,10 +103,10 @@ func (r JobService) InitJob() error { return nil } -func (r JobService) Create(o *model.Job) (err error) { +func (r jobService) Create(o *model.Job) (err error) { if o.Status == constant.JobStatusRunning { - j, err := getJob(o, &r) + j, err := getJob(o) if err != nil { return err } @@ -343,11 +117,11 @@ func (r JobService) Create(o *model.Job) (err error) { o.CronJobId = int(jobId) } - return r.jobRepository.Create(o) + return repository.JobRepository.Create(context.TODO(), o) } -func (r JobService) DeleteJobById(id string) error { - job, err := r.jobRepository.FindById(id) +func (r jobService) DeleteJobById(id string) error { + job, err := repository.JobRepository.FindById(context.TODO(), id) if err != nil { return err } @@ -356,11 +130,11 @@ func (r JobService) DeleteJobById(id string) error { return err } } - return r.jobRepository.DeleteJobById(id) + return repository.JobRepository.DeleteJobById(context.TODO(), id) } -func (r JobService) UpdateById(m *model.Job) error { - if err := r.jobRepository.UpdateById(m); err != nil { +func (r jobService) UpdateById(m *model.Job) error { + if err := repository.JobRepository.UpdateById(context.TODO(), m); err != nil { return err } diff --git a/server/service/job_check_asset_status.go b/server/service/job_check_asset_status.go new file mode 100644 index 0000000..94ca101 --- /dev/null +++ b/server/service/job_check_asset_status.go @@ -0,0 +1,78 @@ +package service + +import ( + "context" + "fmt" + "strings" + "time" + + "next-terminal/server/constant" + "next-terminal/server/log" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/utils" +) + +type CheckAssetStatusJob struct { + ID string + Mode string + ResourceIds string + Metadata string +} + +func (r CheckAssetStatusJob) Run() { + if r.ID == "" { + return + } + + var assets []model.Asset + if r.Mode == constant.JobModeAll { + assets, _ = repository.AssetRepository.FindAll(context.TODO()) + } else { + assets, _ = repository.AssetRepository.FindByIds(context.TODO(), strings.Split(r.ResourceIds, ",")) + } + + if len(assets) == 0 { + return + } + + msgChan := make(chan string) + for i := range assets { + asset := assets[i] + go func() { + t1 := time.Now() + var ( + msg string + ip = asset.IP + port = asset.Port + ) + active, err := AssetService.CheckStatus(asset.AccessGatewayId, ip, port) + + elapsed := time.Since(t1) + if err == nil { + msg = fmt.Sprintf("资产「%v」存活状态检测完成,存活「%v」,耗时「%v」", asset.Name, active, elapsed) + } else { + msg = fmt.Sprintf("资产「%v」存活状态检测完成,存活「%v」,耗时「%v」,原因: %v", asset.Name, active, elapsed, err.Error()) + } + + _ = repository.AssetRepository.UpdateActiveById(context.TODO(), active, asset.ID) + log.Infof(msg) + msgChan <- msg + }() + } + + var message = "" + for i := 0; i < len(assets); i++ { + message += <-msgChan + "\n" + } + + _ = repository.JobRepository.UpdateLastUpdatedById(context.TODO(), r.ID) + jobLog := model.JobLog{ + ID: utils.UUID(), + JobId: r.ID, + Timestamp: utils.NowJsonTime(), + Message: message, + } + + _ = repository.JobLogRepository.Create(context.TODO(), &jobLog) +} diff --git a/server/service/job_exec_shell.go b/server/service/job_exec_shell.go new file mode 100644 index 0000000..02ce58d --- /dev/null +++ b/server/service/job_exec_shell.go @@ -0,0 +1,163 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "next-terminal/server/constant" + "next-terminal/server/log" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/term" + "next-terminal/server/utils" + + "gorm.io/gorm" +) + +type ShellJob struct { + ID string + Mode string + ResourceIds string + Metadata string +} + +type MetadataShell struct { + Shell string +} + +func (r ShellJob) Run() { + if r.ID == "" { + return + } + + var assets []model.Asset + if r.Mode == constant.JobModeAll { + assets, _ = repository.AssetRepository.FindByProtocol(context.TODO(), "ssh") + } else { + assets, _ = repository.AssetRepository.FindByProtocolAndIds(context.TODO(), "ssh", strings.Split(r.ResourceIds, ",")) + } + + if len(assets) == 0 { + return + } + + var metadataShell MetadataShell + err := json.Unmarshal([]byte(r.Metadata), &metadataShell) + if err != nil { + log.Errorf("JSON数据解析失败 %v", err) + return + } + + msgChan := make(chan string) + for i := range assets { + asset, err := AssetService.FindByIdAndDecrypt(context.TODO(), assets[i].ID) + if err != nil { + msgChan <- fmt.Sprintf("资产「%v」Shell执行失败,查询数据异常「%v」", assets[i].Name, err.Error()) + return + } + + var ( + username = asset.Username + password = asset.Password + privateKey = asset.PrivateKey + passphrase = asset.Passphrase + ip = asset.IP + port = asset.Port + ) + + if asset.AccountType == "credential" { + credential, err := CredentialService.FindByIdAndDecrypt(context.TODO(), asset.CredentialId) + if err != nil { + msgChan <- fmt.Sprintf("资产「%v」Shell执行失败,查询授权凭证数据异常「%v」", assets[i].Name, err.Error()) + return + } + + if credential.Type == constant.Custom { + username = credential.Username + password = credential.Password + } else { + username = credential.Username + privateKey = credential.PrivateKey + passphrase = credential.Passphrase + } + } + + go func() { + t1 := time.Now() + result, err := exec(metadataShell.Shell, asset.AccessGatewayId, ip, port, username, password, privateKey, passphrase) + elapsed := time.Since(t1) + var msg string + if err != nil { + if errors.Is(gorm.ErrRecordNotFound, err) { + msg = fmt.Sprintf("资产「%v」Shell执行失败,请检查资产所关联接入网关是否存在,耗时「%v」", asset.Name, elapsed) + } else { + msg = fmt.Sprintf("资产「%v」Shell执行失败,错误内容为:「%v」,耗时「%v」", asset.Name, err.Error(), elapsed) + } + log.Infof(msg) + } else { + msg = fmt.Sprintf("资产「%v」Shell执行成功,返回值「%v」,耗时「%v」", asset.Name, result, elapsed) + log.Infof(msg) + } + + msgChan <- msg + }() + } + + var message = "" + for i := 0; i < len(assets); i++ { + message += <-msgChan + "\n" + } + + _ = repository.JobRepository.UpdateLastUpdatedById(context.TODO(), r.ID) + jobLog := model.JobLog{ + ID: utils.UUID(), + JobId: r.ID, + Timestamp: utils.NowJsonTime(), + Message: message, + } + + _ = repository.JobLogRepository.Create(context.TODO(), &jobLog) +} + +func exec(shell, accessGatewayId, ip string, port int, username, password, privateKey, passphrase string) (string, error) { + if accessGatewayId != "" && accessGatewayId != "-" { + g, err := GatewayService.GetGatewayAndReconnectById(accessGatewayId) + if err != nil { + return "", err + } + uuid := utils.UUID() + exposedIP, exposedPort, err := g.OpenSshTunnel(uuid, ip, port) + if err != nil { + return "", err + } + defer g.CloseSshTunnel(uuid) + return ExecCommandBySSH(shell, exposedIP, exposedPort, username, password, privateKey, passphrase) + } else { + return ExecCommandBySSH(shell, ip, port, username, password, privateKey, passphrase) + } +} + +func ExecCommandBySSH(cmd, ip string, port int, username, password, privateKey, passphrase string) (result string, err error) { + sshClient, err := term.NewSshClient(ip, port, username, password, privateKey, passphrase) + if err != nil { + return "", err + } + + session, err := sshClient.NewSession() + if err != nil { + return "", err + } + defer func() { + _ = session.Close() + }() + //执行远程命令 + combo, err := session.CombinedOutput(cmd) + if err != nil { + return "", err + } + return string(combo), nil +} diff --git a/server/service/mail.go b/server/service/mail.go index a41abcb..f5a14c5 100644 --- a/server/service/mail.go +++ b/server/service/mail.go @@ -1,6 +1,7 @@ package service import ( + "context" "net/smtp" "next-terminal/server/constant" @@ -10,16 +11,11 @@ import ( "github.com/jordan-wright/email" ) -type MailService struct { - propertyRepository *repository.PropertyRepository +type mailService struct { } -func NewMailService(propertyRepository *repository.PropertyRepository) *MailService { - return &MailService{propertyRepository: propertyRepository} -} - -func (r MailService) SendMail(to, subject, text string) { - propertiesMap := r.propertyRepository.FindAllMap() +func (r mailService) SendMail(to, subject, text string) { + propertiesMap := repository.PropertyRepository.FindAllMap(context.TODO()) host := propertiesMap[constant.MailHost] port := propertiesMap[constant.MailPort] username := propertiesMap[constant.MailUsername] diff --git a/server/service/property.go b/server/service/property.go index 33dc741..a659180 100644 --- a/server/service/property.go +++ b/server/service/property.go @@ -1,28 +1,31 @@ package service import ( + "context" + "errors" + "fmt" + + "next-terminal/server/env" "next-terminal/server/guacd" "next-terminal/server/model" "next-terminal/server/repository" + + "gorm.io/gorm" ) -type PropertyService struct { - propertyRepository *repository.PropertyRepository +type propertyService struct { + baseService } -func NewPropertyService(propertyRepository *repository.PropertyRepository) *PropertyService { - return &PropertyService{propertyRepository: propertyRepository} -} - -func (r PropertyService) InitProperties() error { - propertyMap := r.propertyRepository.FindAllMap() +func (service propertyService) InitProperties() error { + propertyMap := repository.PropertyRepository.FindAllMap(context.TODO()) if len(propertyMap[guacd.EnableRecording]) == 0 { property := model.Property{ Name: guacd.EnableRecording, Value: "true", } - if err := r.propertyRepository.Create(&property); err != nil { + if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil { return err } } @@ -32,7 +35,7 @@ func (r PropertyService) InitProperties() error { Name: guacd.CreateRecordingPath, Value: "true", } - if err := r.propertyRepository.Create(&property); err != nil { + if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil { return err } } @@ -42,7 +45,7 @@ func (r PropertyService) InitProperties() error { Name: guacd.FontName, Value: "menlo", } - if err := r.propertyRepository.Create(&property); err != nil { + if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil { return err } } @@ -52,7 +55,7 @@ func (r PropertyService) InitProperties() error { Name: guacd.FontSize, Value: "12", } - if err := r.propertyRepository.Create(&property); err != nil { + if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil { return err } } @@ -62,7 +65,7 @@ func (r PropertyService) InitProperties() error { Name: guacd.ColorScheme, Value: "gray-black", } - if err := r.propertyRepository.Create(&property); err != nil { + if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil { return err } } @@ -72,7 +75,7 @@ func (r PropertyService) InitProperties() error { Name: guacd.EnableWallpaper, Value: "false", } - if err := r.propertyRepository.Create(&property); err != nil { + if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil { return err } } @@ -82,7 +85,7 @@ func (r PropertyService) InitProperties() error { Name: guacd.EnableTheming, Value: "false", } - if err := r.propertyRepository.Create(&property); err != nil { + if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil { return err } } @@ -92,7 +95,7 @@ func (r PropertyService) InitProperties() error { Name: guacd.EnableFontSmoothing, Value: "false", } - if err := r.propertyRepository.Create(&property); err != nil { + if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil { return err } } @@ -102,7 +105,7 @@ func (r PropertyService) InitProperties() error { Name: guacd.EnableFullWindowDrag, Value: "false", } - if err := r.propertyRepository.Create(&property); err != nil { + if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil { return err } } @@ -112,7 +115,7 @@ func (r PropertyService) InitProperties() error { Name: guacd.EnableDesktopComposition, Value: "false", } - if err := r.propertyRepository.Create(&property); err != nil { + if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil { return err } } @@ -122,7 +125,7 @@ func (r PropertyService) InitProperties() error { Name: guacd.EnableMenuAnimations, Value: "false", } - if err := r.propertyRepository.Create(&property); err != nil { + if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil { return err } } @@ -132,7 +135,7 @@ func (r PropertyService) InitProperties() error { Name: guacd.DisableBitmapCaching, Value: "false", } - if err := r.propertyRepository.Create(&property); err != nil { + if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil { return err } } @@ -142,39 +145,71 @@ func (r PropertyService) InitProperties() error { Name: guacd.DisableOffscreenCaching, Value: "false", } - if err := r.propertyRepository.Create(&property); err != nil { + if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil { return err } } - if len(propertyMap[guacd.DisableGlyphCaching]) == 0 { - property := model.Property{ - Name: guacd.DisableGlyphCaching, - Value: "true", - } - if err := r.propertyRepository.Create(&property); err != nil { + if len(propertyMap[guacd.DisableGlyphCaching]) > 0 { + if err := repository.PropertyRepository.DeleteByName(context.TODO(), guacd.DisableGlyphCaching); err != nil { return err } } return nil } -func (r PropertyService) DeleteDeprecatedProperty() error { - propertyMap := r.propertyRepository.FindAllMap() +func (service propertyService) DeleteDeprecatedProperty() error { + propertyMap := repository.PropertyRepository.FindAllMap(context.TODO()) if propertyMap[guacd.EnableDrive] != "" { - if err := r.propertyRepository.DeleteByName(guacd.DriveName); err != nil { + if err := repository.PropertyRepository.DeleteByName(context.TODO(), guacd.DriveName); err != nil { return err } } if propertyMap[guacd.DrivePath] != "" { - if err := r.propertyRepository.DeleteByName(guacd.DrivePath); err != nil { + if err := repository.PropertyRepository.DeleteByName(context.TODO(), guacd.DrivePath); err != nil { return err } } if propertyMap[guacd.DriveName] != "" { - if err := r.propertyRepository.DeleteByName(guacd.DriveName); err != nil { + if err := repository.PropertyRepository.DeleteByName(context.TODO(), guacd.DriveName); err != nil { return err } } return nil } + +func (service propertyService) Update(item map[string]interface{}) error { + return env.GetDB().Transaction(func(tx *gorm.DB) error { + c := service.Context(tx) + for key := range item { + value := fmt.Sprintf("%v", item[key]) + if value == "" { + value = "-" + } + + property := model.Property{ + Name: key, + Value: value, + } + + if key == "enable-ldap" && value == "false" { + if err := UserService.DeleteALlLdapUser(c); err != nil { + return err + } + } + + _, err := repository.PropertyRepository.FindByName(c, key) + if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { + if err := repository.PropertyRepository.Create(c, &property); err != nil { + return err + } + } else { + if err := repository.PropertyRepository.UpdateByName(c, &property, key); err != nil { + return err + } + } + } + return nil + }) + +} diff --git a/server/service/security.go b/server/service/security.go new file mode 100644 index 0000000..18de5cd --- /dev/null +++ b/server/service/security.go @@ -0,0 +1,32 @@ +package service + +import ( + "context" + + "next-terminal/server/global/security" + "next-terminal/server/repository" +) + +type securityService struct{} + +func (service securityService) ReloadAccessSecurity() error { + rules, err := repository.SecurityRepository.FindAll(context.TODO()) + if err != nil { + return err + } + if len(rules) > 0 { + // 先清空 + security.GlobalSecurityManager.Clear() + // 再添加到全局的安全管理器中 + for i := 0; i < len(rules); i++ { + rule := &security.Security{ + ID: rules[i].ID, + IP: rules[i].IP, + Rule: rules[i].Rule, + Priority: rules[i].Priority, + } + security.GlobalSecurityManager.Add <- rule + } + } + return nil +} diff --git a/server/service/session.go b/server/service/session.go index 02bc7fe..3f91b27 100644 --- a/server/service/session.go +++ b/server/service/session.go @@ -1,45 +1,55 @@ package service import ( + "context" + "encoding/base64" + "errors" + "strconv" + "sync" + + "next-terminal/server/config" "next-terminal/server/constant" + "next-terminal/server/env" + "next-terminal/server/global/session" + "next-terminal/server/guacd" + "next-terminal/server/log" "next-terminal/server/model" "next-terminal/server/repository" "next-terminal/server/utils" + + "github.com/gorilla/websocket" + "gorm.io/gorm" ) -type SessionService struct { - sessionRepository *repository.SessionRepository +type sessionService struct { + baseService } -func NewSessionService(sessionRepository *repository.SessionRepository) *SessionService { - return &SessionService{sessionRepository: sessionRepository} -} - -func (r SessionService) FixSessionState() error { - sessions, err := r.sessionRepository.FindByStatus(constant.Connected) +func (service sessionService) FixSessionState() error { + sessions, err := repository.SessionRepository.FindByStatus(context.TODO(), constant.Connected) if err != nil { return err } if len(sessions) > 0 { for i := range sessions { - session := model.Session{ + s := model.Session{ Status: constant.Disconnected, DisconnectedTime: utils.NowJsonTime(), } - _ = r.sessionRepository.UpdateById(&session, sessions[i].ID) + _ = repository.SessionRepository.UpdateById(context.TODO(), &s, sessions[i].ID) } } return nil } -func (r SessionService) EmptyPassword() error { - return r.sessionRepository.EmptyPassword() +func (service sessionService) EmptyPassword() error { + return repository.SessionRepository.EmptyPassword(context.TODO()) } -func (r SessionService) ClearOfflineSession() error { - sessions, err := r.sessionRepository.FindByStatus(constant.Disconnected) +func (service sessionService) ClearOfflineSession() error { + sessions, err := repository.SessionRepository.FindByStatus(context.TODO(), constant.Disconnected) if err != nil { return err } @@ -47,11 +57,11 @@ func (r SessionService) ClearOfflineSession() error { for i := range sessions { sessionIds = append(sessionIds, sessions[i].ID) } - return r.sessionRepository.DeleteByIds(sessionIds) + return repository.SessionRepository.DeleteByIds(context.TODO(), sessionIds) } -func (r SessionService) ReviewedAll() error { - sessions, err := r.sessionRepository.FindAllUnReviewed() +func (service sessionService) ReviewedAll() error { + sessions, err := repository.SessionRepository.FindAllUnReviewed(context.TODO()) if err != nil { return err } @@ -60,13 +70,13 @@ func (r SessionService) ReviewedAll() error { for i := range sessions { sessionIds = append(sessionIds, sessions[i].ID) if i >= 100 && i%100 == 0 { - if err := r.sessionRepository.UpdateReadByIds(true, sessionIds); err != nil { + if err := repository.SessionRepository.UpdateReadByIds(context.TODO(), true, sessionIds); err != nil { return err } sessionIds = nil } else { if i == total-1 { - if err := r.sessionRepository.UpdateReadByIds(true, sessionIds); err != nil { + if err := repository.SessionRepository.UpdateReadByIds(context.TODO(), true, sessionIds); err != nil { return err } } @@ -75,3 +85,272 @@ func (r SessionService) ReviewedAll() error { } return nil } + +var mutex sync.Mutex + +func (service sessionService) CloseSessionById(sessionId string, code int, reason string) { + mutex.Lock() + defer mutex.Unlock() + nextSession := session.GlobalSessionManager.GetById(sessionId) + if nextSession != nil { + log.Debugf("[%v] 会话关闭,原因:%v", sessionId, reason) + service.WriteCloseMessage(nextSession.WebSocket, nextSession.Mode, code, reason) + + if nextSession.Observer != nil { + obs := nextSession.Observer.All() + for _, ob := range obs { + service.WriteCloseMessage(ob.WebSocket, ob.Mode, code, reason) + log.Debugf("[%v] 强制踢出会话的观察者: %v", sessionId, ob.ID) + } + } + } + session.GlobalSessionManager.Del <- sessionId + + service.DisDBSess(sessionId, code, reason) +} + +func (service sessionService) WriteCloseMessage(ws *websocket.Conn, mode string, code int, reason string) { + switch mode { + case constant.Guacd: + if ws != nil { + err := guacd.NewInstruction("error", "", strconv.Itoa(code)) + _ = ws.WriteMessage(websocket.TextMessage, []byte(err.String())) + disconnect := guacd.NewInstruction("disconnect") + _ = ws.WriteMessage(websocket.TextMessage, []byte(disconnect.String())) + } + case constant.Naive: + if ws != nil { + msg := `0` + reason + _ = ws.WriteMessage(websocket.TextMessage, []byte(msg)) + } + case constant.Terminal: + // 这里是关闭观察者的ssh会话 + if ws != nil { + msg := `0` + reason + _ = ws.WriteMessage(websocket.TextMessage, []byte(msg)) + } + } +} + +func (service sessionService) DisDBSess(sessionId string, code int, reason string) { + _ = env.GetDB().Transaction(func(tx *gorm.DB) error { + c := service.Context(tx) + s, err := repository.SessionRepository.FindById(c, sessionId) + if err != nil { + return err + } + + if s.Status == constant.Disconnected { + return err + } + + if s.Status == constant.Connecting { + // 会话还未建立成功,无需保留数据 + if err := repository.SessionRepository.DeleteById(c, sessionId); err != nil { + return err + } + return nil + } + + ss := model.Session{} + ss.ID = sessionId + ss.Status = constant.Disconnected + ss.DisconnectedTime = utils.NowJsonTime() + ss.Code = code + ss.Message = reason + ss.Password = "-" + ss.PrivateKey = "-" + ss.Passphrase = "-" + + if err := repository.SessionRepository.UpdateById(c, &ss, sessionId); err != nil { + return err + } + + return nil + }) +} + +func (service sessionService) FindByIdAndDecrypt(c context.Context, id string) (o model.Session, err error) { + sess, err := repository.SessionRepository.FindById(c, id) + if err != nil { + return o, err + } + if err := service.Decrypt(&sess); err != nil { + return o, err + } + return sess, nil +} + +func (service sessionService) Decrypt(item *model.Session) error { + if item.Password != "" && item.Password != "-" { + origData, err := base64.StdEncoding.DecodeString(item.Password) + if err != nil { + return err + } + decryptedCBC, err := utils.AesDecryptCBC(origData, config.GlobalCfg.EncryptionPassword) + if err != nil { + return err + } + item.Password = string(decryptedCBC) + } + if item.PrivateKey != "" && item.PrivateKey != "-" { + origData, err := base64.StdEncoding.DecodeString(item.PrivateKey) + if err != nil { + return err + } + decryptedCBC, err := utils.AesDecryptCBC(origData, config.GlobalCfg.EncryptionPassword) + if err != nil { + return err + } + item.PrivateKey = string(decryptedCBC) + } + if item.Passphrase != "" && item.Passphrase != "-" { + origData, err := base64.StdEncoding.DecodeString(item.Passphrase) + if err != nil { + return err + } + decryptedCBC, err := utils.AesDecryptCBC(origData, config.GlobalCfg.EncryptionPassword) + if err != nil { + return err + } + item.Passphrase = string(decryptedCBC) + } + return nil +} + +func (service sessionService) Create(clientIp, assetId, mode string, user *model.User) (*model.Session, error) { + asset, err := repository.AssetRepository.FindById(context.TODO(), assetId) + if err != nil { + return nil, err + } + + var ( + upload = "1" + download = "1" + _delete = "1" + rename = "1" + edit = "1" + fileSystem = "1" + _copy = "1" + paste = "1" + ) + + if asset.Owner != user.ID && constant.TypeUser == user.Type { + // 普通用户访问非自己创建的资产需要校验权限 + resourceSharers, err := repository.ResourceSharerRepository.FindByResourceIdAndUserId(context.TODO(), assetId, user.ID) + if err != nil { + return nil, err + } + if len(resourceSharers) == 0 { + return nil, errors.New("您没有权限访问此资产") + } + strategyId := resourceSharers[0].StrategyId + if strategyId != "" { + strategy, err := repository.StrategyRepository.FindById(context.TODO(), strategyId) + if err != nil { + if !errors.Is(gorm.ErrRecordNotFound, err) { + return nil, err + } + } else { + upload = strategy.Upload + download = strategy.Download + _delete = strategy.Delete + rename = strategy.Rename + edit = strategy.Edit + _copy = strategy.Copy + paste = strategy.Paste + } + } + } + + var storageId = "" + if constant.RDP == asset.Protocol { + attr, err := repository.AssetRepository.FindAssetAttrMapByAssetId(context.TODO(), assetId) + if err != nil { + return nil, err + } + if "true" == attr[guacd.EnableDrive] { + fileSystem = "1" + storageId = attr[guacd.DrivePath] + if storageId == "" { + storageId = user.ID + } + } else { + fileSystem = "0" + } + } + if fileSystem != "1" { + fileSystem = "0" + } + if upload != "1" { + upload = "0" + } + if download != "1" { + download = "0" + } + if _delete != "1" { + _delete = "0" + } + if rename != "1" { + rename = "0" + } + if edit != "1" { + edit = "0" + } + if _copy != "1" { + _copy = "0" + } + if paste != "1" { + paste = "0" + } + + s := &model.Session{ + ID: utils.UUID(), + AssetId: asset.ID, + Username: asset.Username, + Password: asset.Password, + PrivateKey: asset.PrivateKey, + Passphrase: asset.Passphrase, + Protocol: asset.Protocol, + IP: asset.IP, + Port: asset.Port, + Status: constant.NoConnect, + ClientIP: clientIp, + Mode: mode, + FileSystem: fileSystem, + Upload: upload, + Download: download, + Delete: _delete, + Rename: rename, + Edit: edit, + Copy: _copy, + Paste: paste, + StorageId: storageId, + AccessGatewayId: asset.AccessGatewayId, + Reviewed: false, + } + if constant.Anonymous != user.Type { + s.Creator = user.ID + } + + if asset.AccountType == "credential" { + credential, err := repository.CredentialRepository.FindById(context.TODO(), asset.CredentialId) + if err != nil { + return nil, err + } + + if credential.Type == constant.Custom { + s.Username = credential.Username + s.Password = credential.Password + } else { + s.Username = credential.Username + s.PrivateKey = credential.PrivateKey + s.Passphrase = credential.Passphrase + } + } + + if err := repository.SessionRepository.Create(context.TODO(), s); err != nil { + return nil, err + } + return s, nil +} diff --git a/server/service/storage.go b/server/service/storage.go index 9685352..1641300 100644 --- a/server/service/storage.go +++ b/server/service/storage.go @@ -1,10 +1,15 @@ package service import ( + "bufio" + "context" "errors" + "io" "io/ioutil" + "mime/multipart" "os" "path" + "strings" "next-terminal/server/config" "next-terminal/server/log" @@ -12,37 +17,31 @@ import ( "next-terminal/server/repository" "next-terminal/server/utils" + "github.com/labstack/echo/v4" "gorm.io/gorm" ) -type StorageService struct { - storageRepository *repository.StorageRepository - userRepository *repository.UserRepository - propertyRepository *repository.PropertyRepository +type storageService struct { } -func NewStorageService(storageRepository *repository.StorageRepository, userRepository *repository.UserRepository, propertyRepository *repository.PropertyRepository) *StorageService { - return &StorageService{storageRepository: storageRepository, userRepository: userRepository, propertyRepository: propertyRepository} -} - -func (r StorageService) InitStorages() error { - users, err := r.userRepository.FindAll() +func (service storageService) InitStorages() error { + users, err := repository.UserRepository.FindAll(context.TODO()) if err != nil { return err } for i := range users { userId := users[i].ID - _, err := r.storageRepository.FindByOwnerIdAndDefault(userId, true) + _, err := repository.StorageRepository.FindByOwnerIdAndDefault(context.TODO(), userId, true) if errors.Is(err, gorm.ErrRecordNotFound) { - err = r.CreateStorageByUser(&users[i]) + err = service.CreateStorageByUser(&users[i]) if err != nil { return err } } } - drivePath := r.GetBaseDrivePath() - storages, err := r.storageRepository.FindAll() + drivePath := service.GetBaseDrivePath() + storages, err := repository.StorageRepository.FindAll(context.TODO()) if err != nil { return err } @@ -59,7 +58,7 @@ func (r StorageService) InitStorages() error { } if !userExist { - if err := r.DeleteStorageById(storage.ID, true); err != nil { + if err := service.DeleteStorageById(storage.ID, true); err != nil { return err } } @@ -76,8 +75,8 @@ func (r StorageService) InitStorages() error { return nil } -func (r StorageService) CreateStorageByUser(user *model.User) error { - drivePath := r.GetBaseDrivePath() +func (service storageService) CreateStorageByUser(user *model.User) error { + drivePath := service.GetBaseDrivePath() storage := model.Storage{ ID: user.ID, Name: user.Nickname + "的默认空间", @@ -92,7 +91,7 @@ func (r StorageService) CreateStorageByUser(user *model.User) error { return err } log.Infof("创建storage:「%v」文件夹: %v", storage.Name, storageDir) - err := r.storageRepository.Create(&storage) + err := repository.StorageRepository.Create(context.TODO(), &storage) if err != nil { return err } @@ -109,7 +108,7 @@ type File struct { Size int64 `json:"size"` } -func (r StorageService) Ls(drivePath, remoteDir string) ([]File, error) { +func (service storageService) Ls(drivePath, remoteDir string) ([]File, error) { fileInfos, err := ioutil.ReadDir(path.Join(drivePath, remoteDir)) if err != nil { return nil, err @@ -132,13 +131,13 @@ func (r StorageService) Ls(drivePath, remoteDir string) ([]File, error) { return files, nil } -func (r StorageService) GetBaseDrivePath() string { +func (service storageService) GetBaseDrivePath() string { return config.GlobalCfg.Guacd.Drive } -func (r StorageService) DeleteStorageById(id string, force bool) error { - drivePath := r.GetBaseDrivePath() - storage, err := r.storageRepository.FindById(id) +func (service storageService) DeleteStorageById(id string, force bool) error { + drivePath := service.GetBaseDrivePath() + storage, err := repository.StorageRepository.FindById(context.TODO(), id) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil @@ -153,7 +152,136 @@ func (r StorageService) DeleteStorageById(id string, force bool) error { if err := os.RemoveAll(path.Join(drivePath, id)); err != nil { return err } - if err := r.storageRepository.DeleteById(id); err != nil { + if err := repository.StorageRepository.DeleteById(context.TODO(), id); err != nil { + return err + } + return nil +} + +func (service storageService) StorageUpload(c echo.Context, file *multipart.FileHeader, storageId string) error { + drivePath := service.GetBaseDrivePath() + storage, _ := repository.StorageRepository.FindById(context.TODO(), storageId) + if storage.LimitSize > 0 { + dirSize, err := utils.DirSize(path.Join(drivePath, storageId)) + if err != nil { + return err + } + if dirSize+file.Size > storage.LimitSize { + return errors.New("可用空间不足") + } + } + + filename := file.Filename + src, err := file.Open() + if err != nil { + return err + } + + remoteDir := c.QueryParam("dir") + remoteFile := path.Join(remoteDir, filename) + + if strings.Contains(remoteDir, "../") { + return errors.New("非法请求 :(") + } + if strings.Contains(remoteFile, "../") { + return errors.New("非法请求 :(") + } + + // 判断文件夹不存在时自动创建 + dir := path.Join(path.Join(drivePath, storageId), remoteDir) + if !utils.FileExists(dir) { + if err := os.MkdirAll(dir, os.ModePerm); err != nil { + return err + } + } + // Destination + dst, err := os.Create(path.Join(path.Join(drivePath, storageId), remoteFile)) + if err != nil { + return err + } + defer dst.Close() + + // Copy + if _, err = io.Copy(dst, src); err != nil { + return err + } + return nil +} + +func (service storageService) StorageEdit(file string, fileContent string, storageId string) error { + drivePath := service.GetBaseDrivePath() + if strings.Contains(file, "../") { + return errors.New("非法请求 :(") + } + realFilePath := path.Join(path.Join(drivePath, storageId), file) + dstFile, err := os.OpenFile(realFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) + if err != nil { + return err + } + defer dstFile.Close() + write := bufio.NewWriter(dstFile) + if _, err := write.WriteString(fileContent); err != nil { + return err + } + if err := write.Flush(); err != nil { + return err + } + return nil +} + +func (service storageService) StorageDownload(c echo.Context, remoteFile, storageId string) error { + drivePath := service.GetBaseDrivePath() + if strings.Contains(remoteFile, "../") { + return errors.New("非法请求 :(") + } + // 获取带后缀的文件名称 + filenameWithSuffix := path.Base(remoteFile) + return c.Attachment(path.Join(path.Join(drivePath, storageId), remoteFile), filenameWithSuffix) +} + +func (service storageService) StorageLs(remoteDir, storageId string) (error, []File) { + drivePath := service.GetBaseDrivePath() + if strings.Contains(remoteDir, "../") { + return errors.New("非法请求 :("), nil + } + files, err := service.Ls(path.Join(drivePath, storageId), remoteDir) + if err != nil { + return err, nil + } + return nil, files +} + +func (service storageService) StorageMkDir(remoteDir, storageId string) error { + drivePath := service.GetBaseDrivePath() + if strings.Contains(remoteDir, "../") { + return errors.New("非法请求 :(") + } + if err := os.MkdirAll(path.Join(path.Join(drivePath, storageId), remoteDir), os.ModePerm); err != nil { + return err + } + return nil +} + +func (service storageService) StorageRm(file, storageId string) error { + drivePath := service.GetBaseDrivePath() + if strings.Contains(file, "../") { + return errors.New("非法请求 :(") + } + if err := os.RemoveAll(path.Join(path.Join(drivePath, storageId), file)); err != nil { + return err + } + return nil +} + +func (service storageService) StorageRename(oldName, newName, storageId string) error { + drivePath := service.GetBaseDrivePath() + if strings.Contains(oldName, "../") { + return errors.New("非法请求 :(") + } + if strings.Contains(newName, "../") { + return errors.New("非法请求 :(") + } + if err := os.Rename(path.Join(path.Join(drivePath, storageId), oldName), path.Join(path.Join(drivePath, storageId), newName)); err != nil { return err } return nil diff --git a/server/service/user.go b/server/service/user.go index 89b4807..2028cce 100644 --- a/server/service/user.go +++ b/server/service/user.go @@ -1,28 +1,30 @@ package service import ( - "next-terminal/server/global/cache" - "strings" + "errors" + "fmt" "next-terminal/server/constant" + "next-terminal/server/dto" + "next-terminal/server/env" + "next-terminal/server/global/cache" "next-terminal/server/log" "next-terminal/server/model" "next-terminal/server/repository" "next-terminal/server/utils" + "strings" + + "golang.org/x/net/context" + "gorm.io/gorm" ) -type UserService struct { - userRepository *repository.UserRepository - loginLogRepository *repository.LoginLogRepository +type userService struct { + baseService } -func NewUserService(userRepository *repository.UserRepository, loginLogRepository *repository.LoginLogRepository) *UserService { - return &UserService{userRepository: userRepository, loginLogRepository: loginLogRepository} -} +func (service userService) InitUser() (err error) { -func (r UserService) InitUser() (err error) { - - users, err := r.userRepository.FindAll() + users, err := repository.UserRepository.FindAll(context.TODO()) if err != nil { return err } @@ -43,7 +45,7 @@ func (r UserService) InitUser() (err error) { Created: utils.NowJsonTime(), Status: constant.StatusEnabled, } - if err := r.userRepository.Create(&user); err != nil { + if err := repository.UserRepository.Create(context.TODO(), &user); err != nil { return err } @@ -56,7 +58,7 @@ func (r UserService) InitUser() (err error) { Type: constant.TypeAdmin, ID: users[i].ID, } - if err := r.userRepository.Update(&user); err != nil { + if err := repository.UserRepository.Update(context.TODO(), &user); err != nil { return err } log.Infof("自动修正用户「%v」ID「%v」类型为管理员", users[i].Nickname, users[i].ID) @@ -66,20 +68,20 @@ func (r UserService) InitUser() (err error) { return nil } -func (r UserService) FixUserOnlineState() error { +func (service userService) FixUserOnlineState() error { // 修正用户登录状态 - onlineUsers, err := r.userRepository.FindOnlineUsers() + onlineUsers, err := repository.UserRepository.FindOnlineUsers(context.TODO()) if err != nil { return err } if len(onlineUsers) > 0 { for i := range onlineUsers { - logs, err := r.loginLogRepository.FindAliveLoginLogsByUsername(onlineUsers[i].Username) + logs, err := repository.LoginLogRepository.FindAliveLoginLogsByUsername(context.TODO(), onlineUsers[i].Username) if err != nil { return err } if len(logs) == 0 { - if err := r.userRepository.UpdateOnlineByUsername(onlineUsers[i].Username, false); err != nil { + if err := repository.UserRepository.UpdateOnlineByUsername(context.TODO(), onlineUsers[i].Username, false); err != nil { return err } } @@ -88,96 +90,220 @@ func (r UserService) FixUserOnlineState() error { return nil } -func (r UserService) LogoutByToken(token string) (err error) { - loginLog, err := r.loginLogRepository.FindById(token) - if err != nil { - log.Warnf("登录日志「%v」获取失败", token) - return - } - cacheKey := r.BuildCacheKeyByToken(token) - cache.GlobalCache.Delete(cacheKey) +func (service userService) LogoutByToken(token string) (err error) { + return env.GetDB().Transaction(func(tx *gorm.DB) error { + c := service.Context(tx) + loginLog, err := repository.LoginLogRepository.FindById(c, token) + if err != nil { + return err + } + cache.TokenManager.Delete(token) - loginLogForUpdate := &model.LoginLog{LogoutTime: utils.NowJsonTime(), ID: token} - err = r.loginLogRepository.Update(loginLogForUpdate) - if err != nil { + loginLogForUpdate := &model.LoginLog{LogoutTime: utils.NowJsonTime(), ID: token} + err = repository.LoginLogRepository.Update(c, loginLogForUpdate) + if err != nil { + return err + } + + loginLogs, err := repository.LoginLogRepository.FindAliveLoginLogsByUsername(c, loginLog.Username) + if err != nil { + return err + } + + if len(loginLogs) == 0 { + err = repository.UserRepository.UpdateOnlineByUsername(c, loginLog.Username, false) + } return err - } - - loginLogs, err := r.loginLogRepository.FindAliveLoginLogsByUsername(loginLog.Username) - if err != nil { - return - } - - if len(loginLogs) == 0 { - err = r.userRepository.UpdateOnlineByUsername(loginLog.Username, false) - } - return + }) } -func (r UserService) LogoutById(id string) error { - user, err := r.userRepository.FindById(id) +func (service userService) LogoutById(c context.Context, id string) error { + user, err := repository.UserRepository.FindById(c, id) if err != nil { return err } username := user.Username - loginLogs, err := r.loginLogRepository.FindAliveLoginLogsByUsername(username) + loginLogs, err := repository.LoginLogRepository.FindAliveLoginLogsByUsername(c, username) if err != nil { return err } for j := range loginLogs { token := loginLogs[j].ID - if err := r.LogoutByToken(token); err != nil { + if err := service.LogoutByToken(token); err != nil { return err } } return nil } -func (r UserService) BuildCacheKeyByToken(token string) string { - cacheKey := strings.Join([]string{constant.Token, token}, ":") - return cacheKey -} +func (service userService) OnEvicted(token string, value interface{}) { -func (r UserService) GetTokenFormCacheKey(cacheKey string) string { - token := strings.Split(cacheKey, ":")[1] - return token -} - -func (r UserService) OnEvicted(key string, value interface{}) { - if strings.HasPrefix(key, constant.Token) { - token := r.GetTokenFormCacheKey(key) + if strings.HasPrefix(token, "forever") { + log.Debugf("re gen forever token") + } else { log.Debugf("用户Token「%v」过期", token) - err := r.LogoutByToken(token) + err := service.LogoutByToken(token) if err != nil { log.Errorf("退出登录失败 %v", err) } } } -func (r UserService) UpdateStatusById(id string, status string) error { - if constant.StatusDisabled == status { - // 将该用户下线 - if err := r.LogoutById(id); err != nil { - return err +func (service userService) UpdateStatusById(id string, status string) error { + return env.GetDB().Transaction(func(tx *gorm.DB) error { + c := service.Context(tx) + if c.Value(constant.DB) == nil { + c = context.WithValue(c, constant.DB, env.GetDB()) } - } - u := model.User{ - ID: id, - Status: status, - } - return r.userRepository.Update(&u) + if constant.StatusDisabled == status { + // 将该用户下线 + if err := service.LogoutById(c, id); err != nil { + return err + } + } + u := model.User{ + ID: id, + Status: status, + } + return repository.UserRepository.Update(c, &u) + }) + } -func (r UserService) DeleteLoginLogs(tokens []string) error { - for i := range tokens { - token := tokens[i] - if err := r.LogoutByToken(token); err != nil { +func (service userService) ReloadToken() error { + loginLogs, err := repository.LoginLogRepository.FindAliveLoginLogs(context.TODO()) + if err != nil { + return err + } + + for i := range loginLogs { + loginLog := loginLogs[i] + token := loginLog.ID + user, err := repository.UserRepository.FindByUsername(context.TODO(), loginLog.Username) + if err != nil { + if errors.Is(gorm.ErrRecordNotFound, err) { + _ = repository.LoginLogRepository.DeleteById(context.TODO(), token) + } + continue + } + + authorization := dto.Authorization{ + Token: token, + Type: constant.LoginToken, + Remember: loginLog.Remember, + User: &user, + } + + if authorization.Remember { + // 记住登录有效期两周 + cache.TokenManager.Set(token, authorization, cache.RememberMeExpiration) + } else { + cache.TokenManager.Set(token, authorization, cache.NotRememberExpiration) + } + log.Debugf("重新加载用户「%v」授权Token「%v」到缓存", user.Nickname, token) + } + return nil +} + +func (service userService) CreateUser(user model.User) (err error) { + return env.GetDB().Transaction(func(tx *gorm.DB) error { + c := service.Context(tx) + if repository.UserRepository.ExistByUsername(c, user.Username) { + return fmt.Errorf("username %s is already used", user.Username) + } + password := user.Password + + var pass []byte + if pass, err = utils.Encoder.Encode([]byte(password)); err != nil { return err } - if err := r.loginLogRepository.DeleteById(token); err != nil { + user.Password = string(pass) + + user.ID = utils.UUID() + user.Created = utils.NowJsonTime() + user.Status = constant.StatusEnabled + + if err := repository.UserRepository.Create(c, &user); err != nil { return err } + err = StorageService.CreateStorageByUser(&user) + if err != nil { + return err + } + + if user.Mail != "" { + go MailService.SendMail(user.Mail, "[Next Terminal] 注册通知", "你好,"+user.Nickname+"。管理员为你注册了账号:"+user.Username+" 密码:"+password) + } + return nil + }) + +} + +func (service userService) DeleteUserById(userId string) error { + return env.GetDB().Transaction(func(tx *gorm.DB) error { + c := service.Context(tx) + // 下线该用户 + if err := service.LogoutById(c, userId); err != nil { + return err + } + // 删除用户 + if err := repository.UserRepository.DeleteById(c, userId); err != nil { + return err + } + // 删除用户与用户组的关系 + if err := repository.UserGroupMemberRepository.DeleteByUserId(c, userId); err != nil { + return err + } + // 删除用户与资产的关系 + if err := repository.ResourceSharerRepository.DeleteByUserId(c, userId); err != nil { + return err + } + // 删除用户的默认磁盘空间 + if err := StorageService.DeleteStorageById(userId, true); err != nil { + return err + } + return nil + }) +} + +func (service userService) DeleteLoginLogs(tokens []string) error { + if len(tokens) > 0 { + for _, token := range tokens { + if err := service.LogoutByToken(token); err != nil { + return err + } + if err := repository.LoginLogRepository.DeleteById(context.TODO(), token); err != nil { + return err + } + } } return nil } + +func (service userService) SaveLoginLog(clientIP, clientUserAgent string, username string, success, remember bool, id, reason string) error { + loginLog := model.LoginLog{ + Username: username, + ClientIP: clientIP, + ClientUserAgent: clientUserAgent, + LoginTime: utils.NowJsonTime(), + Reason: reason, + Remember: remember, + } + if success { + loginLog.State = "1" + loginLog.ID = id + } else { + loginLog.State = "0" + loginLog.ID = utils.LongUUID() + } + + if err := repository.LoginLogRepository.Create(context.TODO(), &loginLog); err != nil { + return err + } + return nil +} + +func (service userService) DeleteALlLdapUser(ctx context.Context) error { + return repository.UserRepository.DeleteBySource(ctx, constant.SourceLdap) +} diff --git a/server/service/user_group.go b/server/service/user_group.go new file mode 100644 index 0000000..67b618a --- /dev/null +++ b/server/service/user_group.go @@ -0,0 +1,115 @@ +package service + +import ( + "context" + "errors" + + "next-terminal/server/constant" + "next-terminal/server/env" + "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/utils" + + "gorm.io/gorm" +) + +type userGroupService struct { +} + +func (service userGroupService) DeleteById(userGroupId string) error { + return env.GetDB().Transaction(func(tx *gorm.DB) error { + c := context.WithValue(context.TODO(), constant.DB, tx) + // 删除用户组 + if err := repository.UserGroupRepository.DeleteById(c, userGroupId); err != nil { + return err + } + // 删除用户组与用户的关系 + if err := repository.UserGroupMemberRepository.DeleteByUserGroupId(c, userGroupId); err != nil { + return err + } + // 删除用户组与资产的关系 + if err := repository.ResourceSharerRepository.DeleteByUserGroupId(c, userGroupId); err != nil { + return err + } + return nil + }) +} + +func (service userGroupService) Create(name string, members []string) (model.UserGroup, error) { + var err error + _, err = repository.UserGroupRepository.FindByName(context.TODO(), name) + if err == nil { + return model.UserGroup{}, constant.ErrNameAlreadyUsed + } + + if !errors.Is(gorm.ErrRecordNotFound, err) { + return model.UserGroup{}, err + } + + userGroupId := utils.UUID() + userGroup := model.UserGroup{ + ID: userGroupId, + Created: utils.NowJsonTime(), + Name: name, + } + + return userGroup, env.GetDB().Transaction(func(tx *gorm.DB) error { + c := context.WithValue(context.TODO(), constant.DB, tx) + if err := repository.UserGroupRepository.Create(c, &userGroup); err != nil { + return err + } + if len(members) > 0 { + for _, member := range members { + userGroupMember := model.UserGroupMember{ + ID: utils.Sign([]string{userGroupId, member}), + UserId: member, + UserGroupId: userGroupId, + } + if err := repository.UserGroupMemberRepository.Create(c, &userGroupMember); err != nil { + return err + } + } + } + return nil + }) + +} + +func (service userGroupService) Update(userGroupId string, name string, members []string) (err error) { + var userGroup model.UserGroup + userGroup, err = repository.UserGroupRepository.FindByName(context.TODO(), name) + if err == nil && userGroup.ID != userGroupId { + return constant.ErrNameAlreadyUsed + } + + if !errors.Is(gorm.ErrRecordNotFound, err) { + return err + } + + return env.GetDB().Transaction(func(tx *gorm.DB) error { + c := context.WithValue(context.TODO(), constant.DB, tx) + userGroup := model.UserGroup{ + ID: userGroupId, + Name: name, + } + if err := repository.UserGroupRepository.Update(c, &userGroup); err != nil { + return err + } + if err := repository.UserGroupMemberRepository.DeleteByUserGroupId(c, userGroupId); err != nil { + return err + } + if len(members) > 0 { + for _, member := range members { + userGroupMember := model.UserGroupMember{ + ID: utils.Sign([]string{userGroupId, member}), + UserId: member, + UserGroupId: userGroupId, + } + if err := repository.UserGroupMemberRepository.Create(c, &userGroupMember); err != nil { + return err + } + } + } + return nil + }) +} diff --git a/server/service/var.go b/server/service/var.go new file mode 100644 index 0000000..ef4990c --- /dev/null +++ b/server/service/var.go @@ -0,0 +1,17 @@ +package service + +var ( + AssetService = new(assetService) + BackupService = new(backupService) + CredentialService = new(credentialService) + GatewayService = new(gatewayService) + JobService = new(jobService) + MailService = new(mailService) + PropertyService = new(propertyService) + SecurityService = new(securityService) + SessionService = new(sessionService) + StorageService = new(storageService) + UserService = new(userService) + UserGroupService = new(userGroupService) + AccessTokenService = new(accessTokenService) +) diff --git a/server/sshd/sshd.go b/server/sshd/sshd.go new file mode 100644 index 0000000..83a4e6d --- /dev/null +++ b/server/sshd/sshd.go @@ -0,0 +1,146 @@ +package sshd + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "strings" + + "next-terminal/server/config" + "next-terminal/server/constant" + "next-terminal/server/global/security" + "next-terminal/server/log" + "next-terminal/server/repository" + "next-terminal/server/service" + "next-terminal/server/utils" + + "github.com/gliderlabs/ssh" + "gorm.io/gorm" +) + +type Sshd struct { + gui *Gui +} + +func init() { + gui := &Gui{} + sshd := &Sshd{ + gui: gui, + } + go sshd.Serve() +} + +func (sshd Sshd) passwordAuth(ctx ssh.Context, pass string) bool { + username := ctx.User() + remoteAddr := strings.Split(ctx.RemoteAddr().String(), ":")[0] + user, err := repository.UserRepository.FindByUsername(context.TODO(), username) + + if err != nil { + // 保存登录日志 + _ = service.UserService.SaveLoginLog(remoteAddr, "terminal", username, false, false, "", "账号或密码不正确") + return false + } + + if err := utils.Encoder.Match([]byte(user.Password), []byte(pass)); err != nil { + // 保存登录日志 + _ = service.UserService.SaveLoginLog(remoteAddr, "terminal", username, false, false, "", "账号或密码不正确") + return false + } + return true +} + +func (sshd Sshd) connCallback(ctx ssh.Context, conn net.Conn) net.Conn { + securities := security.GlobalSecurityManager.Values() + if len(securities) == 0 { + return conn + } + + ip := strings.Split(conn.RemoteAddr().String(), ":")[0] + + for _, s := range securities { + if strings.Contains(s.IP, "/") { + // CIDR + _, ipNet, err := net.ParseCIDR(s.IP) + if err != nil { + continue + } + if !ipNet.Contains(net.ParseIP(ip)) { + continue + } + } else if strings.Contains(s.IP, "-") { + // 范围段 + split := strings.Split(s.IP, "-") + if len(split) < 2 { + continue + } + start := split[0] + end := split[1] + intReqIP := utils.IpToInt(ip) + if intReqIP < utils.IpToInt(start) || intReqIP > utils.IpToInt(end) { + continue + } + } else { + // IP + if s.IP != ip { + continue + } + } + + if s.Rule == constant.AccessRuleAllow { + return conn + } + if s.Rule == constant.AccessRuleReject { + _, _ = conn.Write([]byte("your access request was denied :(\n")) + return nil + } + } + + return conn +} + +func (sshd Sshd) sessionHandler(sess *ssh.Session) { + defer func() { + _ = (*sess).Close() + }() + + username := (*sess).User() + remoteAddr := strings.Split((*sess).RemoteAddr().String(), ":")[0] + + user, err := repository.UserRepository.FindByUsername(context.TODO(), username) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + _, _ = io.WriteString(*sess, "您输入的账户或密码不正确.\n") + } else { + _, _ = io.WriteString(*sess, err.Error()) + } + return + } + + // 判断是否需要进行双因素认证 + if user.TOTPSecret != "" && user.TOTPSecret != "-" { + sshd.gui.totpUI(sess, user, remoteAddr, username) + } else { + // 保存登录日志 + _ = service.UserService.SaveLoginLog(remoteAddr, "terminal", username, true, false, utils.LongUUID(), "") + sshd.gui.MainUI(sess, user) + } +} + +func (sshd Sshd) Serve() { + ssh.Handle(func(s ssh.Session) { + _, _ = io.WriteString(s, fmt.Sprintf(constant.Banner, constant.Version)) + sshd.sessionHandler(&s) + }) + + fmt.Printf("⇨ sshd server started on %v\n", config.GlobalCfg.Sshd.Addr) + err := ssh.ListenAndServe( + config.GlobalCfg.Sshd.Addr, + nil, + ssh.PasswordAuth(sshd.passwordAuth), + ssh.HostKeyFile(config.GlobalCfg.Sshd.Key), + ssh.WrapConn(sshd.connCallback), + ) + log.Fatal(fmt.Sprintf("启动sshd服务失败: %v", err.Error())) +} diff --git a/server/api/sshd.go b/server/sshd/ui.go similarity index 53% rename from server/api/sshd.go rename to server/sshd/ui.go index 2cc74c9..d87f296 100644 --- a/server/api/sshd.go +++ b/server/sshd/ui.go @@ -1,16 +1,11 @@ -package api +package sshd import ( - "encoding/hex" + "context" "errors" "fmt" "io" - "net" - "next-terminal/server/global/security" - "path" - "strings" - "time" - + "next-terminal/server/api" "next-terminal/server/config" "next-terminal/server/constant" "next-terminal/server/global/cache" @@ -18,99 +13,22 @@ import ( "next-terminal/server/guacd" "next-terminal/server/log" "next-terminal/server/model" + "next-terminal/server/repository" + "next-terminal/server/service" "next-terminal/server/term" "next-terminal/server/totp" "next-terminal/server/utils" + "path" + "strings" "github.com/gliderlabs/ssh" "github.com/manifoldco/promptui" - "gorm.io/gorm" ) -func sessionHandler(sess *ssh.Session) { - defer func() { - _ = (*sess).Close() - }() - - username := (*sess).User() - remoteAddr := strings.Split((*sess).RemoteAddr().String(), ":")[0] - - user, err := userRepository.FindByUsername(username) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - _, _ = io.WriteString(*sess, "您输入的账户或密码不正确.\n") - } else { - _, _ = io.WriteString(*sess, err.Error()) - } - return - } - - // 判断是否需要进行双因素认证 - if user.TOTPSecret != "" && user.TOTPSecret != "-" { - totpUI(sess, user, remoteAddr, username) - } else { - // 保存登录日志 - _ = SaveLoginLog(remoteAddr, "terminal", username, true, false, utils.UUID(), "") - mainUI(sess, user) - } +type Gui struct { } -func totpUI(sess *ssh.Session, user model.User, remoteAddr string, username string) { - - validate := func(input string) error { - if len(input) < 6 { - return errors.New("双因素认证授权码必须为6个数字") - } - return nil - } - - prompt := promptui.Prompt{ - Label: "请输入双因素认证授权码", - Validate: validate, - Mask: '*', - Stdin: *sess, - Stdout: *sess, - } - - var success = false - for i := 0; i < 5; i++ { - result, err := prompt.Run() - if err != nil { - fmt.Printf("Prompt failed %v\n", err) - return - } - loginFailCountKey := remoteAddr + username - - v, ok := cache.GlobalCache.Get(loginFailCountKey) - if !ok { - v = 1 - } - count := v.(int) - if count >= 5 { - _, _ = io.WriteString(*sess, "登录失败次数过多,请等待30秒后再试\r\n") - continue - } - if !totp.Validate(result, user.TOTPSecret) { - count++ - println(count) - cache.GlobalCache.Set(loginFailCountKey, count, time.Second*time.Duration(30)) - // 保存登录日志 - _ = SaveLoginLog(remoteAddr, "terminal", username, false, false, "", "双因素认证授权码不正确") - _, _ = io.WriteString(*sess, "您输入的双因素认证授权码不匹配\r\n") - continue - } - success = true - break - } - - if success { - // 保存登录日志 - _ = SaveLoginLog(remoteAddr, "terminal", username, true, false, utils.UUID(), "") - mainUI(sess, user) - } -} - -func mainUI(sess *ssh.Session, user model.User) { +func (gui Gui) MainUI(sess *ssh.Session, user model.User) { prompt := promptui.Select{ Label: "欢迎使用 Next Terminal,请选择您要使用的功能", Items: []string{"我的资产", "退出系统"}, @@ -127,15 +45,15 @@ MainLoop: } switch result { case "我的资产": - AssetUI(sess, user) + gui.AssetUI(sess, user) case "退出系统": break MainLoop } } } -func AssetUI(sess *ssh.Session, user model.User) { - assets, err := assetRepository.FindByProtocolAndUser(constant.SSH, user) +func (gui Gui) AssetUI(sess *ssh.Session, user model.User) { + assets, err := repository.AssetRepository.FindByProtocolAndUser(context.TODO(), constant.SSH, user) if err != nil { return } @@ -190,17 +108,16 @@ AssetUILoop: case "quit": break AssetUILoop default: - if err := createSession(sess, assets[i].ID, user.ID); err != nil { + if err := gui.createSession(sess, chooseAssetId, user.ID); err != nil { _, _ = io.WriteString(*sess, err.Error()+"\r\n") return } } } - } -func createSession(sess *ssh.Session, assetId, creator string) (err error) { - asset, err := assetRepository.FindById(assetId) +func (gui Gui) createSession(sess *ssh.Session, assetId, creator string) (err error) { + asset, err := repository.AssetRepository.FindById(context.TODO(), assetId) if err != nil { return err } @@ -230,7 +147,7 @@ func createSession(sess *ssh.Session, assetId, creator string) (err error) { } if asset.AccountType == "credential" { - credential, err := credentialRepository.FindById(asset.CredentialId) + credential, err := repository.CredentialRepository.FindById(context.TODO(), asset.CredentialId) if err != nil { return nil } @@ -245,15 +162,15 @@ func createSession(sess *ssh.Session, assetId, creator string) (err error) { } } - if err := sessionRepository.Create(s); err != nil { + if err := repository.SessionRepository.Create(context.TODO(), s); err != nil { return err } - return handleAccessAsset(sess, s.ID) + return gui.handleAccessAsset(sess, s.ID) } -func handleAccessAsset(sess *ssh.Session, sessionId string) (err error) { - s, err := sessionRepository.FindByIdAndDecrypt(sessionId) +func (gui Gui) handleAccessAsset(sess *ssh.Session, sessionId string) (err error) { + s, err := service.SessionService.FindByIdAndDecrypt(context.TODO(), sessionId) if err != nil { return err } @@ -268,7 +185,7 @@ func handleAccessAsset(sess *ssh.Session, sessionId string) (err error) { ) if s.AccessGatewayId != "" && s.AccessGatewayId != "-" { - g, err := accessGatewayService.GetGatewayAndReconnectById(s.AccessGatewayId) + g, err := service.GatewayService.GetGatewayAndReconnectById(s.AccessGatewayId) if err != nil { return errors.New("获取接入网关失败:" + err.Error()) } @@ -290,7 +207,7 @@ func handleAccessAsset(sess *ssh.Session, sessionId string) (err error) { } recording := "" - property, err := propertyRepository.FindByName(guacd.EnableRecording) + property, err := repository.PropertyRepository.FindByName(context.TODO(), guacd.EnableRecording) if err == nil && property.Value == "true" { recording = path.Join(config.GlobalCfg.Guacd.Recording, sessionId, "recording.cast") } @@ -322,7 +239,7 @@ func handleAccessAsset(sess *ssh.Session, sessionId string) (err error) { } log.Debugf("退出窗口大小监控") // ==== 修改数据库中的会话状态为已断开,修复用户直接关闭窗口时会话状态不正确的问题 ==== - CloseSessionById(sessionId, Normal, "用户正常退出") + service.SessionService.CloseSessionById(sessionId, api.Normal, "用户正常退出") // ==== 修改数据库中的会话状态为已断开,修复用户直接关闭窗口时会话状态不正确的问题 ==== }() @@ -338,7 +255,7 @@ func handleAccessAsset(sess *ssh.Session, sessionId string) (err error) { sessionForUpdate.Reviewed = true } - if err := sessionRepository.UpdateById(&sessionForUpdate, sessionId); err != nil { + if err := repository.SessionRepository.UpdateById(context.TODO(), &sessionForUpdate, sessionId); err != nil { return err } // ==== 修改数据库中的会话状态为已连接 ==== @@ -358,168 +275,63 @@ func handleAccessAsset(sess *ssh.Session, sessionId string) (err error) { } // ==== 修改数据库中的会话状态为已断开 ==== - CloseSessionById(sessionId, Normal, "用户正常退出") + service.SessionService.CloseSessionById(sessionId, api.Normal, "用户正常退出") // ==== 修改数据库中的会话状态为已断开 ==== return nil } -func passwordAuth(ctx ssh.Context, pass string) bool { - username := ctx.User() - remoteAddr := strings.Split(ctx.RemoteAddr().String(), ":")[0] - user, err := userRepository.FindByUsername(username) +func (gui Gui) totpUI(sess *ssh.Session, user model.User, remoteAddr string, username string) { - if err != nil { + validate := func(input string) error { + if len(input) < 6 { + return errors.New("双因素认证授权码必须为6个数字") + } + return nil + } + + prompt := promptui.Prompt{ + Label: "请输入双因素认证授权码", + Validate: validate, + Mask: '*', + Stdin: *sess, + Stdout: *sess, + } + + var success = false + for i := 0; i < 5; i++ { + result, err := prompt.Run() + if err != nil { + fmt.Printf("Prompt failed %v\n", err) + return + } + loginFailCountKey := remoteAddr + username + + v, ok := cache.LoginFailedKeyManager.Get(loginFailCountKey) + if !ok { + v = 1 + } + count := v.(int) + if count >= 5 { + _, _ = io.WriteString(*sess, "登录失败次数过多,请等待5分钟后再试\r\n") + continue + } + if !totp.Validate(result, user.TOTPSecret) { + count++ + println(count) + cache.LoginFailedKeyManager.Set(loginFailCountKey, count, cache.LoginLockExpiration) + // 保存登录日志 + _ = service.UserService.SaveLoginLog(remoteAddr, "terminal", username, false, false, "", "双因素认证授权码不正确") + _, _ = io.WriteString(*sess, "您输入的双因素认证授权码不匹配\r\n") + continue + } + success = true + break + } + + if success { // 保存登录日志 - _ = SaveLoginLog(remoteAddr, "terminal", username, false, false, "", "账号或密码不正确") - return false - } - - if err := utils.Encoder.Match([]byte(user.Password), []byte(pass)); err != nil { - // 保存登录日志 - _ = SaveLoginLog(remoteAddr, "terminal", username, false, false, "", "账号或密码不正确") - return false - } - return true -} - -func connCallback(ctx ssh.Context, conn net.Conn) net.Conn { - securities := security.GlobalSecurityManager.Values() - if len(securities) == 0 { - return conn - } - - ip := strings.Split(conn.RemoteAddr().String(), ":")[0] - - for _, s := range securities { - if strings.Contains(s.IP, "/") { - // CIDR - _, ipNet, err := net.ParseCIDR(s.IP) - if err != nil { - continue - } - if !ipNet.Contains(net.ParseIP(ip)) { - continue - } - } else if strings.Contains(s.IP, "-") { - // 范围段 - split := strings.Split(s.IP, "-") - if len(split) < 2 { - continue - } - start := split[0] - end := split[1] - intReqIP := utils.IpToInt(ip) - if intReqIP < utils.IpToInt(start) || intReqIP > utils.IpToInt(end) { - continue - } - } else { - // IP - if s.IP != ip { - continue - } - } - - if s.Rule == constant.AccessRuleAllow { - return conn - } - if s.Rule == constant.AccessRuleReject { - _, _ = conn.Write([]byte("your access request was denied :(\n")) - return nil - } - } - - return conn -} - -func Setup() { - ssh.Handle(func(s ssh.Session) { - _, _ = io.WriteString(s, fmt.Sprintf(constant.Banner, constant.Version)) - defer func() { - if e, ok := recover().(error); ok { - log.Fatal(e) - } - }() - sessionHandler(&s) - }) - - fmt.Printf("⇨ sshd server started on %v\n", config.GlobalCfg.Sshd.Addr) - err := ssh.ListenAndServe( - config.GlobalCfg.Sshd.Addr, - nil, - ssh.PasswordAuth(passwordAuth), - ssh.HostKeyFile(config.GlobalCfg.Sshd.Key), - ssh.WrapConn(connCallback), - ) - log.Fatal(fmt.Sprintf("启动sshd服务失败: %v", err.Error())) -} - -func init() { - if config.GlobalCfg.Sshd.Enable { - go Setup() - } -} - -type Writer struct { - sessionId string - sess *ssh.Session - recorder *term.Recorder - rz bool - sz bool -} - -func NewWriter(sessionId string, sess *ssh.Session, recorder *term.Recorder) *Writer { - return &Writer{sessionId: sessionId, sess: sess, recorder: recorder} -} - -func (w *Writer) Write(p []byte) (n int, err error) { - if w.recorder != nil { - s := string(p) - if !w.sz && !w.rz { - // rz的开头字符 - hexData := hex.EncodeToString(p) - if strings.Contains(hexData, "727a0d2a2a184230303030303030303030303030300d8a11") { - w.sz = true - } else if strings.Contains(hexData, "727a2077616974696e6720746f20726563656976652e2a2a184230313030303030303233626535300d8a11") { - w.rz = true - } - } - - if w.sz { - // sz 会以 OO 结尾 - if "OO" == s { - w.sz = false - } - } else if w.rz { - // rz 最后会显示 Received /home/xxx - if strings.Contains(s, "Received") { - w.rz = false - // 把上传的文件名称也显示一下 - err := w.recorder.WriteData(s) - if err != nil { - return 0, err - } - sendObData(w.sessionId, s) - } - } else { - err := w.recorder.WriteData(s) - if err != nil { - return 0, err - } - sendObData(w.sessionId, s) - } - } - return (*w.sess).Write(p) -} - -func sendObData(sessionId, s string) { - nextSession := session.GlobalSessionManager.GetById(sessionId) - if nextSession != nil { - if nextSession.Observer != nil { - obs := nextSession.Observer.All() - for _, ob := range obs { - _ = WriteMessage(ob.WebSocket, NewMessage(Data, s)) - } - } + _ = service.UserService.SaveLoginLog(remoteAddr, "terminal", username, true, false, utils.UUID(), "") + gui.MainUI(sess, user) } } diff --git a/server/sshd/writer.go b/server/sshd/writer.go new file mode 100644 index 0000000..42d942e --- /dev/null +++ b/server/sshd/writer.go @@ -0,0 +1,77 @@ +package sshd + +import ( + "encoding/hex" + "strings" + + "next-terminal/server/api" + "next-terminal/server/dto" + "next-terminal/server/global/session" + "next-terminal/server/term" + + "github.com/gliderlabs/ssh" +) + +type Writer struct { + sessionId string + sess *ssh.Session + recorder *term.Recorder + rz bool + sz bool +} + +func NewWriter(sessionId string, sess *ssh.Session, recorder *term.Recorder) *Writer { + return &Writer{sessionId: sessionId, sess: sess, recorder: recorder} +} + +func (w *Writer) Write(p []byte) (n int, err error) { + if w.recorder != nil { + s := string(p) + if !w.sz && !w.rz { + // rz的开头字符 + hexData := hex.EncodeToString(p) + if strings.Contains(hexData, "727a0d2a2a184230303030303030303030303030300d8a11") { + w.sz = true + } else if strings.Contains(hexData, "727a2077616974696e6720746f20726563656976652e2a2a184230313030303030303233626535300d8a11") { + w.rz = true + } + } + + if w.sz { + // sz 会以 OO 结尾 + if "OO" == s { + w.sz = false + } + } else if w.rz { + // rz 最后会显示 Received /home/xxx + if strings.Contains(s, "Received") { + w.rz = false + // 把上传的文件名称也显示一下 + err := w.recorder.WriteData(s) + if err != nil { + return 0, err + } + sendObData(w.sessionId, s) + } + } else { + err := w.recorder.WriteData(s) + if err != nil { + return 0, err + } + sendObData(w.sessionId, s) + } + } + return (*w.sess).Write(p) +} + +func sendObData(sessionId, s string) { + nextSession := session.GlobalSessionManager.GetById(sessionId) + if nextSession != nil { + if nextSession.Observer != nil { + obs := nextSession.Observer.All() + for _, ob := range obs { + _ = api.WriteMessage(ob.WebSocket, dto.NewMessage(api.Data, s)) + } + } + } +} diff --git a/server/task/ticker.go b/server/task/ticker.go index ad47ac6..f38de13 100644 --- a/server/task/ticker.go +++ b/server/task/ticker.go @@ -1,6 +1,7 @@ package task import ( + "context" "strconv" "time" @@ -10,14 +11,14 @@ import ( ) type Ticker struct { - sessionRepository *repository.SessionRepository - propertyRepository *repository.PropertyRepository - loginLogRepository *repository.LoginLogRepository - jobLogRepository *repository.JobLogRepository } -func NewTicker(sessionRepository *repository.SessionRepository, propertyRepository *repository.PropertyRepository, loginLogRepository *repository.LoginLogRepository, jobLogRepository *repository.JobLogRepository) *Ticker { - return &Ticker{sessionRepository: sessionRepository, propertyRepository: propertyRepository, loginLogRepository: loginLogRepository, jobLogRepository: jobLogRepository} +func NewTicker() *Ticker { + return &Ticker{} +} +func init() { + ticker := NewTicker() + ticker.SetupTicker() } func (t *Ticker) SetupTicker() { @@ -26,17 +27,7 @@ func (t *Ticker) SetupTicker() { unUsedSessionTicker := time.NewTicker(time.Minute * 60) go func() { for range unUsedSessionTicker.C { - sessions, _ := t.sessionRepository.FindByStatusIn([]string{constant.NoConnect, constant.Connecting}) - if len(sessions) > 0 { - now := time.Now() - for i := range sessions { - if now.Sub(sessions[i].ConnectedTime.Time) > time.Hour*1 { - _ = t.sessionRepository.DeleteById(sessions[i].ID) - s := sessions[i].Username + "@" + sessions[i].IP + ":" + strconv.Itoa(sessions[i].Port) - log.Infof("会话「%v」ID「%v」超过1小时未打开,已删除。", s, sessions[i].ID) - } - } - } + t.deleteUnUsedSession() } }() @@ -44,15 +35,37 @@ func (t *Ticker) SetupTicker() { timeoutSessionTicker := time.NewTicker(time.Hour * 6) go func() { for range timeoutSessionTicker.C { - deleteOutTimeSession(t) - deleteOutTimeLoginLog(t) - deleteOutTimeJobLog(t) + deleteOutTimeSession() + deleteOutTimeLoginLog() + deleteOutTimeJobLog() } }() } -func deleteOutTimeSession(t *Ticker) { - property, err := t.propertyRepository.FindByName("session-saved-limit") +func (t *Ticker) deleteUnUsedSession() { + sessions, err := repository.SessionRepository.FindByStatusIn(context.TODO(), []string{constant.NoConnect, constant.Connecting}) + if err != nil { + log.Errorf("查询会话列表失败: %v", err.Error()) + return + } + if len(sessions) > 0 { + now := time.Now() + for i := range sessions { + if now.Sub(sessions[i].ConnectedTime.Time) > time.Hour*1 { + err := repository.SessionRepository.DeleteById(context.TODO(), sessions[i].ID) + s := sessions[i].Username + "@" + sessions[i].IP + ":" + strconv.Itoa(sessions[i].Port) + if err != nil { + log.Errorf("会话「%v」ID「%v」超过1小时未打开,删除失败: %v", s, sessions[i].ID, err.Error()) + } else { + log.Infof("会话「%v」ID「%v」超过1小时未打开,已删除。", s, sessions[i].ID) + } + } + } + } +} + +func deleteOutTimeSession() { + property, err := repository.PropertyRepository.FindByName(context.TODO(), "session-saved-limit") if err != nil { return } @@ -63,7 +76,7 @@ func deleteOutTimeSession(t *Ticker) { if err != nil { return } - sessions, err := t.sessionRepository.FindOutTimeSessions(limit) + sessions, err := repository.SessionRepository.FindOutTimeSessions(context.TODO(), limit) if err != nil { return } @@ -73,15 +86,15 @@ func deleteOutTimeSession(t *Ticker) { for i := range sessions { ids = append(ids, sessions[i].ID) } - err := t.sessionRepository.DeleteByIds(ids) + err := repository.SessionRepository.DeleteByIds(context.TODO(), ids) if err != nil { log.Errorf("删除离线会话失败 %v", err) } } } -func deleteOutTimeLoginLog(t *Ticker) { - property, err := t.propertyRepository.FindByName("login-log-saved-limit") +func deleteOutTimeLoginLog() { + property, err := repository.PropertyRepository.FindByName(context.TODO(), "login-log-saved-limit") if err != nil { return } @@ -94,7 +107,7 @@ func deleteOutTimeLoginLog(t *Ticker) { return } - loginLogs, err := t.loginLogRepository.FindOutTimeLog(limit) + loginLogs, err := repository.LoginLogRepository.FindOutTimeLog(context.TODO(), limit) if err != nil { log.Errorf("获取登录日志失败 %v", err) return @@ -102,7 +115,7 @@ func deleteOutTimeLoginLog(t *Ticker) { if len(loginLogs) > 0 { for i := range loginLogs { - err := t.loginLogRepository.DeleteById(loginLogs[i].ID) + err := repository.LoginLogRepository.DeleteById(context.TODO(), loginLogs[i].ID) if err != nil { log.Errorf("删除登录日志失败 %v", err) } @@ -110,8 +123,8 @@ func deleteOutTimeLoginLog(t *Ticker) { } } -func deleteOutTimeJobLog(t *Ticker) { - property, err := t.propertyRepository.FindByName("cron-log-saved-limit") +func deleteOutTimeJobLog() { + property, err := repository.PropertyRepository.FindByName(context.TODO(), "cron-log-saved-limit") if err != nil { return } @@ -123,14 +136,14 @@ func deleteOutTimeJobLog(t *Ticker) { return } - jobLogs, err := t.jobLogRepository.FindOutTimeLog(limit) + jobLogs, err := repository.JobLogRepository.FindOutTimeLog(context.TODO(), limit) if err != nil { return } if len(jobLogs) > 0 { for i := range jobLogs { - err := t.jobLogRepository.DeleteById(jobLogs[i].ID) + err := repository.JobLogRepository.DeleteById(context.TODO(), jobLogs[i].ID) if err != nil { log.Errorf("删除计划日志失败 %v", err) } diff --git a/server/term/ssh.go b/server/term/ssh.go index 70a0f3a..f62353d 100644 --- a/server/term/ssh.go +++ b/server/term/ssh.go @@ -2,11 +2,11 @@ package term import ( "fmt" - "golang.org/x/net/proxy" "net" "time" "golang.org/x/crypto/ssh" + "golang.org/x/net/proxy" ) func NewSshClient(ip string, port int, username, password, privateKey, passphrase string) (*ssh.Client, error) { diff --git a/server/utils/guacamole.go b/server/utils/guacamole.go new file mode 100644 index 0000000..3173ce9 --- /dev/null +++ b/server/utils/guacamole.go @@ -0,0 +1,17 @@ +package utils + +import ( + "encoding/base64" + "github.com/gorilla/websocket" + "next-terminal/server/guacd" + "strconv" +) + +func Disconnect(ws *websocket.Conn, code int, reason string) { + // guacd 无法处理中文字符,所以进行了base64编码。 + encodeReason := base64.StdEncoding.EncodeToString([]byte(reason)) + err := guacd.NewInstruction("error", encodeReason, strconv.Itoa(code)) + _ = ws.WriteMessage(websocket.TextMessage, []byte(err.String())) + disconnect := guacd.NewInstruction("disconnect") + _ = ws.WriteMessage(websocket.TextMessage, []byte(disconnect.String())) +} diff --git a/server/utils/jsontime.go b/server/utils/jsontime.go index 670088b..495a914 100644 --- a/server/utils/jsontime.go +++ b/server/utils/jsontime.go @@ -23,7 +23,7 @@ func NowJsonTime() JsonTime { } } -func (j *JsonTime) MarshalJSON() ([]byte, error) { +func (j JsonTime) MarshalJSON() ([]byte, error) { var stamp = fmt.Sprintf("\"%s\"", j.Format("2006-01-02 15:04:05")) return []byte(stamp), nil } diff --git a/server/utils/util_test.go b/server/utils/util_test.go index 8a84801..6b5a406 100644 --- a/server/utils/util_test.go +++ b/server/utils/util_test.go @@ -86,3 +86,11 @@ func TestGetAvailablePort(t *testing.T) { } println(port) } + +func TestAesEncryptCBC2(t *testing.T) { + origData := []byte("{\"id\":\"xxxx\",\"opcode\":0,\"code\":0,\"message\":\"\",\"data\":\"\"}") // 待加密的数据 + key, _ := base64.StdEncoding.DecodeString("aLSlrPelViToZvNF1T45PQ==") + encryptedCBC, err := utils.AesEncryptCBC(origData, key) + assert.NoError(t, err) + assert.Equal(t, "3Tbnz0MYHQNTsN2L6QDGCJumbNFsQcmErrRz/KglYI/IDh88lsyOhVi7mgaAs/bjevvJa2F1JT7jUMLsz9/cpw==", base64.StdEncoding.EncodeToString(encryptedCBC)) +} diff --git a/server/utils/utils.go b/server/utils/utils.go index b818cf1..601c778 100644 --- a/server/utils/utils.go +++ b/server/utils/utils.go @@ -2,17 +2,12 @@ package utils import ( "bytes" - "crypto" "crypto/aes" "crypto/cipher" "crypto/md5" "crypto/rand" - "crypto/rsa" "crypto/sha256" - "crypto/sha512" - "crypto/x509" "encoding/base64" - "encoding/pem" "errors" "fmt" "image" @@ -43,6 +38,11 @@ func UUID() string { return v4.String() } +func LongUUID() string { + longUUID := strings.Join([]string{UUID(), UUID(), UUID(), UUID()}, "") + return strings.ReplaceAll(longUUID, "-", "") +} + func Tcping(ip string, port int) (bool, error) { var ( conn net.Conn @@ -123,7 +123,7 @@ func Distinct(a []string) []string { return result } -// 排序+拼接+摘要 +// Sign 排序+拼接+摘要 func Sign(a []string) string { sort.Strings(a) data := []byte(strings.Join(a, "")) @@ -131,6 +131,11 @@ func Sign(a []string) string { return fmt.Sprintf("%x", has) } +func Md5(s string) string { + has := md5.Sum([]byte(s)) + return fmt.Sprintf("%x", has) +} + func Contains(s []string, str string) bool { for _, v := range s { if v == str { @@ -341,44 +346,6 @@ func Utf8ToGbk(s []byte) ([]byte, error) { return d, nil } -// SignatureRSA rsa私钥签名 -func SignatureRSA(plainText []byte, rsaPrivateKey string) (signed []byte, err error) { - // 使用pem对读取的内容解码得到block - block, _ := pem.Decode([]byte(rsaPrivateKey)) - //x509将数据解析得到私钥结构体 - privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - return nil, err - } - // 创建一个hash对象 - h := sha512.New() - _, _ = h.Write(plainText) - // 计算hash值 - hashText := h.Sum(nil) - // 使用rsa函数对散列值签名 - signed, err = rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA512, hashText) - if err != nil { - return - } - return signed, nil -} - -// VerifyRSA rsa签名认证 -func VerifyRSA(plainText, signText []byte, rsaPublicKey string) bool { - // pem解码得到block - block, _ := pem.Decode([]byte(rsaPublicKey)) - // x509解析得到接口 - publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) - if err != nil { - return false - } - // 对原始明文进行hash运算得到散列值 - hashText := sha512.Sum512(plainText) - // 签名认证 - err = rsa.VerifyPKCS1v15(publicKey, crypto.SHA512, hashText[:], signText) - return err == nil -} - // GetAvailablePort 获取可用端口 func GetAvailablePort() (int, error) { addr, err := net.ResolveTCPAddr("tcp", "localhost:0") diff --git a/web/package-lock.json b/web/package-lock.json index 5958a2c..3454f7f 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "next-terminal", - "version": "1.2.0", + "version": "2.0.5", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "next-terminal", - "version": "1.2.0", + "version": "2.0.5", "dependencies": { "@ant-design/charts": "^1.2.13", "@ant-design/icons": "^4.6.4", @@ -23,6 +23,7 @@ "react-router": "^5.2.0", "react-router-dom": "^5.2.0", "react-scripts": "^4.0.0", + "react-tsparticles": "^1.37.5", "xterm": "^4.9.0", "xterm-addon-fit": "^0.4.0", "xterm-addon-web-links": "^0.4.0" @@ -16838,8 +16839,8 @@ }, "node_modules/react": { "version": "16.14.0", - "resolved": "https://registry.npmmirror.com/react/download/react-16.14.0.tgz", - "integrity": "sha1-lNd23dCqo32j7aj8W2sYpMmjEU0=", + "resolved": "https://registry.npmjs.org/react/-/react-16.14.0.tgz", + "integrity": "sha512-0X2CImDkJGApiAlcf0ODKIneSwBPhqJawOa5wCtKbu7ZECrmS26NvtSILynQ66cgkT/RJ4LidJOc3bUESwmU8g==", "dependencies": { "loose-envify": "^1.1.0", "object-assign": "^4.1.1", @@ -17203,6 +17204,22 @@ "semver": "bin/semver" } }, + "node_modules/react-tsparticles": { + "version": "1.37.5", + "resolved": "https://registry.npmjs.org/react-tsparticles/-/react-tsparticles-1.37.5.tgz", + "integrity": "sha512-Vl+rg3C+vsrek675x7OHBbLLZLJ0dZZjV3OhjT9NiKhdK+NcxNDB2CL3INsSZIv99sln5DRgsxdL3GYrZtqsqg==", + "dependencies": { + "fast-deep-equal": "^3.1.3", + "tsparticles": "^1.37.5" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/matteobruni" + }, + "peerDependencies": { + "react": ">=16" + } + }, "node_modules/read-pkg": { "version": "5.2.0", "resolved": "https://registry.nlark.com/read-pkg/download/read-pkg-5.2.0.tgz?cache=0&sync_timestamp=1628984780649&other_urls=https%3A%2F%2Fregistry.nlark.com%2Fread-pkg%2Fdownload%2Fread-pkg-5.2.0.tgz", @@ -20044,6 +20061,16 @@ "resolved": "https://registry.nlark.com/tslib/download/tslib-2.3.1.tgz", "integrity": "sha1-6KM1rdXOrlGqJh0ypJAVjvBC7wE=" }, + "node_modules/tsparticles": { + "version": "1.37.5", + "resolved": "https://registry.npmjs.org/tsparticles/-/tsparticles-1.37.5.tgz", + "integrity": "sha512-BQBRnnKKhKH2POwxuHzPiuL/zPUjuR5QmMCFQ2vm3fadE2k4vqmH099m1R9R+nt0lTvWL6SCS0hYQQkqUjsMXg==", + "hasInstallScript": true, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/matteobruni" + } + }, "node_modules/tsutils": { "version": "3.21.0", "resolved": "https://registry.npm.taobao.org/tsutils/download/tsutils-3.21.0.tgz?cache=0&sync_timestamp=1615138426726&other_urls=https%3A%2F%2Fregistry.npm.taobao.org%2Ftsutils%2Fdownload%2Ftsutils-3.21.0.tgz", @@ -35787,8 +35814,8 @@ }, "react": { "version": "16.14.0", - "resolved": "https://registry.npmmirror.com/react/download/react-16.14.0.tgz", - "integrity": "sha1-lNd23dCqo32j7aj8W2sYpMmjEU0=", + "resolved": "https://registry.npmjs.org/react/-/react-16.14.0.tgz", + "integrity": "sha512-0X2CImDkJGApiAlcf0ODKIneSwBPhqJawOa5wCtKbu7ZECrmS26NvtSILynQ66cgkT/RJ4LidJOc3bUESwmU8g==", "requires": { "loose-envify": "^1.1.0", "object-assign": "^4.1.1", @@ -36085,6 +36112,15 @@ } } }, + "react-tsparticles": { + "version": "1.37.5", + "resolved": "https://registry.npmjs.org/react-tsparticles/-/react-tsparticles-1.37.5.tgz", + "integrity": "sha512-Vl+rg3C+vsrek675x7OHBbLLZLJ0dZZjV3OhjT9NiKhdK+NcxNDB2CL3INsSZIv99sln5DRgsxdL3GYrZtqsqg==", + "requires": { + "fast-deep-equal": "^3.1.3", + "tsparticles": "^1.37.5" + } + }, "read-pkg": { "version": "5.2.0", "resolved": "https://registry.nlark.com/read-pkg/download/read-pkg-5.2.0.tgz?cache=0&sync_timestamp=1628984780649&other_urls=https%3A%2F%2Fregistry.nlark.com%2Fread-pkg%2Fdownload%2Fread-pkg-5.2.0.tgz", @@ -38425,6 +38461,11 @@ "resolved": "https://registry.nlark.com/tslib/download/tslib-2.3.1.tgz", "integrity": "sha1-6KM1rdXOrlGqJh0ypJAVjvBC7wE=" }, + "tsparticles": { + "version": "1.37.5", + "resolved": "https://registry.npmjs.org/tsparticles/-/tsparticles-1.37.5.tgz", + "integrity": "sha512-BQBRnnKKhKH2POwxuHzPiuL/zPUjuR5QmMCFQ2vm3fadE2k4vqmH099m1R9R+nt0lTvWL6SCS0hYQQkqUjsMXg==" + }, "tsutils": { "version": "3.21.0", "resolved": "https://registry.npm.taobao.org/tsutils/download/tsutils-3.21.0.tgz?cache=0&sync_timestamp=1615138426726&other_urls=https%3A%2F%2Fregistry.npm.taobao.org%2Ftsutils%2Fdownload%2Ftsutils-3.21.0.tgz", diff --git a/web/package.json b/web/package.json index 0ccfa34..f3b9209 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "next-terminal", - "version": "1.2.2", + "version": "1.2.3", "private": true, "dependencies": { "@ant-design/charts": "^1.2.13", @@ -18,6 +18,7 @@ "react-router": "^5.2.0", "react-router-dom": "^5.2.0", "react-scripts": "^4.0.0", + "react-tsparticles": "^1.37.5", "xterm": "^4.9.0", "xterm-addon-fit": "^0.4.0", "xterm-addon-web-links": "^0.4.0" diff --git a/web/public/favicon.ico b/web/public/favicon.ico index f07d4ad..ec65bde 100644 Binary files a/web/public/favicon.ico and b/web/public/favicon.ico differ diff --git a/web/public/static/js/asciinema-player.js b/web/public/static/js/asciinema-player.js index 5ad47e0..68a4bfc 100644 --- a/web/public/static/js/asciinema-player.js +++ b/web/public/static/js/asciinema-player.js @@ -1000,7 +1000,7 @@ g.dispatchEvent=function(a){var b,c=this.ve;if(c)for(b=[];c;c=c.ve)b.push(c);c=t g.nd=function(){qw.Zd.nd.call(this);if(this.Ib){var a=this.Ib,b=0,c;for(c in a.rb){for(var d=a.rb[c],e=0;e { - this.setState({ - collapsed: !this.state.collapsed, - }); + let collapsed = !this.state.collapsed; + if (collapsed) { + this.setState({ + logo: Logo, + logoWidth: 46, + collapsed: collapsed, + }); + } else { + this.setState({ + logo: LogoWithName, + logoWidth: 140, + collapsed: collapsed, + }); + } }; componentDidMount() { @@ -94,7 +110,7 @@ class App extends Component { async getInfo() { - let result = await request.get('/info'); + let result = await request.get('/account/info'); if (result['code'] === 1) { sessionStorage.setItem('user', JSON.stringify(result['data'])); this.setState({ @@ -126,8 +142,8 @@ class App extends Component { sessionStorage.setItem('openKeys', JSON.stringify(openKeys)); } - confirm = async (e) => { - let result = await request.post('/logout'); + confirm = async () => { + let result = await request.post('/account/logout'); if (result['code'] !== 1) { message.error(result['message']); } else { @@ -180,14 +196,7 @@ class App extends Component { <>
- logo - { - !this.state.collapsed ? - <> 

Next Terminal

: - null - } + logo
- - + +
@@ -350,7 +361,8 @@ class App extends Component {
- Next Terminal ©2021 dushixiang Version:{this.state.package['version']} + Copyright © 2020-2022 dushixiang, All Rights Reserved. + Version:{this.state.package['version']}
: @@ -359,11 +371,7 @@ class App extends Component {
- logo - Next Terminal + logo @@ -400,7 +408,8 @@ class App extends Component {
- Next Terminal ©2021 dushixiang Version:{this.state.package['version']} + Copyright © 2020-2022 dushixiang, All Rights Reserved. + Version:{this.state.package['version']}
} diff --git a/web/src/components/Login.css b/web/src/components/Login.css index 6c400c0..e39f8e5 100644 --- a/web/src/components/Login.css +++ b/web/src/components/Login.css @@ -17,8 +17,4 @@ top: 50%; margin-left: -175px; margin-top: -189px; -} - -.login-bg{ - background-image: url(""); } \ No newline at end of file diff --git a/web/src/components/Login.js b/web/src/components/Login.js index 04735c8..ceddcee 100644 --- a/web/src/components/Login.js +++ b/web/src/components/Login.js @@ -5,6 +5,9 @@ import request from "../common/request"; import {message} from "antd/es"; import {withRouter} from "react-router-dom"; import {LockOutlined, OneToOneOutlined, UserOutlined} from '@ant-design/icons'; +import Particles from "react-tsparticles"; +import Background from '../images/bg.png' +import {setToken} from "../utils/utils"; const {Title} = Typography; @@ -56,7 +59,7 @@ class LoginForm extends Component { // 跳转登录 sessionStorage.removeItem('current'); sessionStorage.removeItem('openKeys'); - localStorage.setItem('X-Auth-Token', result['data']); + setToken(result['data']); // this.props.history.push(); window.location.href = "/" } catch (e) { @@ -85,7 +88,7 @@ class LoginForm extends Component { // 跳转登录 sessionStorage.removeItem('current'); sessionStorage.removeItem('openKeys'); - localStorage.setItem('X-Auth-Token', result['data']); + setToken(result['data']); // this.props.history.push(); window.location.href = "/" } catch (e) { @@ -106,7 +109,90 @@ class LoginForm extends Component { render() { return (
+ style={{width: this.state.width, height: this.state.height}}> +
Next Terminal @@ -140,9 +226,6 @@ class LoginForm extends Component { .then(values => { this.handleOk(values); // this.formRef.current.resetFields(); - }) - .catch(info => { - }); }} onCancel={this.handleCancel}> diff --git a/web/src/components/access/Access.css b/web/src/components/access/Access.css index 52f72b2..346fcd7 100644 --- a/web/src/components/access/Access.css +++ b/web/src/components/access/Access.css @@ -1,3 +1,3 @@ -.container div { +.container > div { margin: 0 auto; } \ No newline at end of file diff --git a/web/src/components/access/Access.js b/web/src/components/access/Access.js index 1d57418..08da867 100644 --- a/web/src/components/access/Access.js +++ b/web/src/components/access/Access.js @@ -13,7 +13,7 @@ import { LineChartOutlined, WindowsOutlined } from '@ant-design/icons'; -import {exitFull, getToken, isEmpty, requestFullScreen} from "../../utils/utils"; +import {exitFull, getToken, isEmpty, requestFullScreen, setToken} from "../../utils/utils"; import './Access.css' import Draggable from 'react-draggable'; import FileSystem from "../devops/FileSystem"; @@ -38,13 +38,14 @@ class Access extends Component { state = { session: {}, sessionId: '', - client: {}, + client: undefined, + scale: 1, clientState: STATE_IDLE, clipboardVisible: false, clipboardText: '', containerOverflow: 'hidden', - containerWidth: 0, - containerHeight: 0, + containerWidth: 1024, + containerHeight: 768, uploadAction: '', uploadHeaders: {}, keyboard: {}, @@ -57,15 +58,39 @@ class Access extends Component { fullScreenBtnText: '进入全屏', sink: undefined, commands: [], - showFileSystem: false + showFileSystem: false, + external: false, + fixedSize: false, }; async componentDidMount() { - let urlParams = new URLSearchParams(this.props.location.search); let assetId = urlParams.get('assetId'); document.title = urlParams.get('assetName'); let protocol = urlParams.get('protocol'); + let width = urlParams.get('width'); + let height = urlParams.get('height'); + let fixedSize = false; + + if (width && height) { + fixedSize = true + } else { + width = window.innerWidth; + height = window.innerHeight; + } + + let shareSessionId = urlParams.get('shareSessionId'); + let external = false; + if (shareSessionId && shareSessionId !== '') { + setToken(shareSessionId); + external = true; + let shareSession = await this.getShareSession(shareSessionId); + if (!shareSession) { + return + } + assetId = shareSession['assetId']; + } + let session = await this.createSession(assetId); if (!session) { return; @@ -79,10 +104,14 @@ class Access extends Component { session: session, sessionId: sessionId, protocol: protocol, - showFileSystem: session['fileSystem'] === '1' + showFileSystem: session['fileSystem'] === '1', + external: external, + fixedSize: fixedSize, + containerWidth: width, + containerHeight: height, }); - this.renderDisplay(sessionId, protocol); + this.renderDisplay(sessionId, protocol, width, height); window.addEventListener('resize', this.onWindowResize); window.onfocus = this.onWindowFocus; @@ -95,6 +124,10 @@ class Access extends Component { } sendClipboard(data) { + if (this.state.session['paste'] === '0') { + message.warn('禁止粘贴'); + return + } let writer; // Create stream with proper mimetype @@ -133,6 +166,7 @@ class Access extends Component { } onTunnelStateChange = (state) => { + console.log(state) if (state === Guacamole.Tunnel.State.CLOSED) { console.log('web socket 已关闭'); } @@ -175,12 +209,13 @@ class Access extends Component { break; case STATE_CONNECTED: this.onWindowResize(null); + Modal.destroyAll(); message.destroy(); message.success('连接成功'); // 向后台发送请求,更新会话的状态 this.updateSessionStatus(this.state.sessionId).then(_ => { }) - if (this.state.protocol === 'ssh') { + if (this.state.protocol === 'ssh' && !this.state.external) { // 加载指令 this.getCommands(); } @@ -300,9 +335,11 @@ class Access extends Component { } clientClipboardReceived = (stream, mimetype) => { - console.log('clientClipboardReceived', mimetype) + if (this.state.session['copy'] === '0') { + message.warn('禁止复制'); + return + } let reader; - // If the received data is text, read it as a simple string if (/^text\//.exec(mimetype)) { reader = new Guacamole.StringReader(stream); @@ -378,9 +415,18 @@ class Access extends Component { return result['data']; } - async renderDisplay(sessionId, protocol) { + async getShareSession(shareSessionId) { + let result = await request.get(`/share-sessions/${shareSessionId}`); + if (result['code'] !== 1) { + this.showMessage(result['message']); + return undefined; + } + return result['data']; + } - let tunnel = new Guacamole.WebSocketTunnel(wsServer + '/tunnel'); + async renderDisplay(sessionId, protocol, width, height) { + + let tunnel = new Guacamole.WebSocketTunnel(`${wsServer}/sessions/${sessionId}/tunnel`); tunnel.onstatechange = this.onTunnelStateChange; // Get new client instance @@ -404,17 +450,16 @@ class Access extends Component { const element = client.getDisplay().getElement(); display.appendChild(element); - let width = window.innerWidth; - let height = window.innerHeight; + let scale = 1; let dpi = 96; if (protocol === 'ssh' || protocol === 'telnet') { dpi = dpi * 2; + scale = 0.5; } let token = getToken(); let params = { - 'sessionId': sessionId, 'width': width, 'height': height, 'dpi': dpi, @@ -439,13 +484,9 @@ class Access extends Component { }; mouse.onmousemove = function (mouseState) { - if (protocol === 'ssh' || protocol === 'telnet') { - mouseState.x = mouseState.x * 2; - mouseState.y = mouseState.y * 2; - client.sendMouseState(mouseState); - } else { - client.sendMouseState(mouseState); - } + mouseState.x = mouseState.x / scale; + mouseState.y = mouseState.y / scale; + client.sendMouseState(mouseState); }; const sink = new Guacamole.InputSink(); @@ -460,8 +501,7 @@ class Access extends Component { this.setState({ client: client, - containerWidth: width, - containerHeight: height, + scale: scale, keyboard: keyboard, sink: sink }); @@ -469,19 +509,14 @@ class Access extends Component { onWindowResize = (e) => { - if (this.state.client) { + if (this.state.client && !this.state.fixedSize) { const display = this.state.client.getDisplay(); + let scale = this.state.scale; + display.scale(scale); + let width = window.innerWidth; + let height = window.innerHeight; - const width = window.innerWidth; - const height = window.innerHeight; - - if (this.state.protocol === 'ssh' || this.state.protocol === 'telnet') { - let r = 2; - display.scale(1 / r); - this.state.client.sendSize(width * r, height * r); - } else { - this.state.client.sendSize(width, height); - } + this.state.client.sendSize(width / scale, height / scale); this.setState({ containerWidth: width, @@ -502,47 +537,17 @@ class Access extends Component { onWindowFocus = (e) => { if (navigator.clipboard && this.state.clientState === STATE_CONNECTED) { - navigator.clipboard.readText().then((text) => { - this.sendClipboard({ - 'data': text, - 'type': 'text/plain' - }); - }) - } - }; - - onPaste = (e) => { - const cbd = e.clipboardData; - const ua = window.navigator.userAgent; - - // 如果是 Safari 直接 return - if (!(e.clipboardData && e.clipboardData.items)) { - return; - } - - // Mac平台下Chrome49版本以下 复制Finder中的文件的Bug Hack掉 - if (cbd.items && cbd.items.length === 2 && cbd.items[0].kind === "string" && cbd.items[1].kind === "file" && - cbd.types && cbd.types.length === 2 && cbd.types[0] === "text/plain" && cbd.types[1] === "Files" && - ua.match(/Macintosh/i) && Number(ua.match(/Chrome\/(\d{2})/i)[1]) < 49) { - return; - } - - for (let i = 0; i < cbd.items.length; i++) { - let item = cbd.items[i]; - if (item.kind === "file") { - let blob = item.getAsFile(); - if (blob.size === 0) { - return; - } - // blob 就是从剪切板获得的文件 可以进行上传或其他操作 - } else if (item.kind === 'string') { - item.getAsString((str) => { + try { + navigator.clipboard.readText().then((text) => { this.sendClipboard({ - 'data': str, + 'data': text, 'type': 'text/plain' }); }) + } catch (e) { + // console.error(e); } + } }; @@ -616,7 +621,8 @@ class Access extends Component {
@@ -630,16 +636,20 @@ class Access extends Component { - - - - - + + + 修改密码 +
+ + + + + + this.onNewPasswordChange(value)} style={{width: 240}}/> + + + this.onNewPassword2Change(value)} style={{width: 240}}/> + + + + +
+ + + + + 授权信息 + + + {this.state.accessToken.token} + + + {this.state.accessToken.created} + + + + + + + +
- 双因素认证
{ - return {'key': item['id'], ...item} - }) this.setState({ - items: items, + items: data.items, total: data.total, queryParams: queryParams, loading: false @@ -95,8 +92,7 @@ class UserGroup extends Component { queryParams: queryParams }); - this.loadTableData(queryParams).then(r => { - }) + this.loadTableData(queryParams); }; showDeleteConfirm(id, content) { @@ -139,7 +135,6 @@ class UserGroup extends Component { } await this.handleSearchByNickname(''); - console.log(model) this.setState({ model: model, modalVisible: true, @@ -147,7 +142,7 @@ class UserGroup extends Component { }); }; - handleCancelModal = e => { + handleCancelModal = () => { this.setState({ modalVisible: false, modalTitle: '', @@ -161,37 +156,43 @@ class UserGroup extends Component { modalConfirmLoading: true }); - if (formData.id) { - // 向后台提交数据 - const result = await request.put('/user-groups/' + formData.id, formData); - if (result.code === 1) { - message.success('操作成功', 3); + try { + if (formData.id) { + // 向后台提交数据 + const result = await request.put('/user-groups/' + formData.id, formData); + if (result.code === 1) { + message.success('操作成功', 3); - this.setState({ - modalVisible: false - }); - await this.loadTableData(this.state.queryParams); + this.setState({ + modalVisible: false + }); + await this.loadTableData(this.state.queryParams); + return true; + } else { + message.error(result.message, 10); + return false; + } } else { - message.error(result.message, 10); - } - } else { - // 向后台提交数据 - const result = await request.post('/user-groups', formData); - if (result.code === 1) { - message.success('操作成功', 3); + // 向后台提交数据 + const result = await request.post('/user-groups', formData); + if (result.code === 1) { + message.success('操作成功', 3); - this.setState({ - modalVisible: false - }); - await this.loadTableData(this.state.queryParams); - } else { - message.error(result.message, 10); + this.setState({ + modalVisible: false + }); + await this.loadTableData(this.state.queryParams); + return true; + } else { + message.error(result.message, 10); + return false; + } } + } finally { + this.setState({ + modalConfirmLoading: false + }); } - - this.setState({ - modalConfirmLoading: false - }); }; handleSearchByName = name => { @@ -280,7 +281,7 @@ class UserGroup extends Component { title: '授权资产', dataIndex: 'assetCount', key: 'assetCount', - render: (text, record, index) => { + render: (text, record) => { return
{ form .validateFields() - .then(values => { - form.resetFields(); - handleOk(values); - }) - .catch(info => { + .then(async values => { + let ok = await handleOk(values); + if (ok) { + form.resetFields(); + } }); }} onCancel={handleCancel} diff --git a/web/src/images/bg.png b/web/src/images/bg.png new file mode 100644 index 0000000..718cb36 Binary files /dev/null and b/web/src/images/bg.png differ diff --git a/web/src/images/logo-with-name.svg b/web/src/images/logo-with-name.svg new file mode 100644 index 0000000..a904f40 --- /dev/null +++ b/web/src/images/logo-with-name.svg @@ -0,0 +1,14 @@ + + + + + + + + + + NEXT + TERMINAL + + + \ No newline at end of file diff --git a/web/src/images/logo.svg b/web/src/images/logo.svg new file mode 100644 index 0000000..323520e --- /dev/null +++ b/web/src/images/logo.svg @@ -0,0 +1,10 @@ + + + + + + + + + + \ No newline at end of file diff --git a/web/src/utils/utils.js b/web/src/utils/utils.js index c7e7d02..a5be35b 100644 --- a/web/src/utils/utils.js +++ b/web/src/utils/utils.js @@ -5,6 +5,10 @@ export const sleep = function (ms) { return new Promise(resolve => setTimeout(resolve, ms)) } +export const setToken = function (token) { + localStorage.setItem('X-Auth-Token', token); +} + export const getToken = function () { return localStorage.getItem('X-Auth-Token'); }