修复 「1.2.2 用户管理-用户列表勾选单一用户会全选 」 close #216

This commit is contained in:
dushixiang
2022-01-23 17:53:22 +08:00
parent 29c066ca3a
commit d35b348a33
130 changed files with 5467 additions and 4554 deletions

View File

@ -1,75 +0,0 @@
package service
import (
"next-terminal/server/global/gateway"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/term"
)
type AccessGatewayService struct {
accessGatewayRepository *repository.AccessGatewayRepository
}
func NewAccessGatewayService(accessGatewayRepository *repository.AccessGatewayRepository) *AccessGatewayService {
accessGatewayService = &AccessGatewayService{accessGatewayRepository: accessGatewayRepository}
return accessGatewayService
}
func (r AccessGatewayService) GetGatewayAndReconnectById(accessGatewayId string) (g *gateway.Gateway, err error) {
g = gateway.GlobalGatewayManager.GetById(accessGatewayId)
if g == nil || !g.Connected {
accessGateway, err := r.accessGatewayRepository.FindById(accessGatewayId)
if err != nil {
return nil, err
}
g = r.ReConnect(&accessGateway)
}
return g, nil
}
func (r AccessGatewayService) GetGatewayById(accessGatewayId string) (g *gateway.Gateway, err error) {
g = gateway.GlobalGatewayManager.GetById(accessGatewayId)
if g == nil {
accessGateway, err := r.accessGatewayRepository.FindById(accessGatewayId)
if err != nil {
return nil, err
}
g = r.ReConnect(&accessGateway)
}
return g, nil
}
func (r AccessGatewayService) ReConnectAll() error {
gateways, err := r.accessGatewayRepository.FindAll()
if err != nil {
return err
}
if len(gateways) > 0 {
for i := range gateways {
r.ReConnect(&gateways[i])
}
}
return nil
}
func (r AccessGatewayService) ReConnect(m *model.AccessGateway) *gateway.Gateway {
log.Debugf("重建接入网关「%v」中...", m.Name)
r.DisconnectById(m.ID)
sshClient, err := term.NewSshClient(m.IP, m.Port, m.Username, m.Password, m.PrivateKey, m.Passphrase)
var g *gateway.Gateway
if err != nil {
g = gateway.NewGateway(m.ID, m.Localhost, false, err.Error(), nil)
} else {
g = gateway.NewGateway(m.ID, m.Localhost, true, "", sshClient)
}
gateway.GlobalGatewayManager.Add <- g
log.Debugf("重建接入网关「%v」完成", m.Name)
return g
}
func (r AccessGatewayService) DisconnectById(accessGatewayId string) {
gateway.GlobalGatewayManager.Del <- accessGatewayId
}

View File

@ -0,0 +1,86 @@
package service
import (
"context"
"errors"
"next-terminal/server/constant"
"next-terminal/server/dto"
"next-terminal/server/env"
"next-terminal/server/global/cache"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/utils"
"gorm.io/gorm"
)
type accessTokenService struct {
baseService
}
func (service accessTokenService) FindByUserId(userId string) (model.AccessToken, error) {
return repository.AccessTokenRepository.FindByUserId(context.TODO(), userId)
}
func (service accessTokenService) GenAccessToken(userId string) error {
return env.GetDB().Transaction(func(tx *gorm.DB) error {
ctx := service.Context(tx)
user, err := repository.UserRepository.FindById(ctx, userId)
if err != nil {
return err
}
oldAccessToken, err := repository.AccessTokenRepository.FindByUserId(ctx, userId)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
if oldAccessToken.Token != "" {
cache.TokenManager.Delete(oldAccessToken.Token)
}
if err := repository.AccessTokenRepository.DeleteByUserId(ctx, userId); err != nil {
return err
}
token := "forever-" + utils.UUID()
accessToken := &model.AccessToken{
ID: utils.UUID(),
UserId: userId,
Token: token,
Created: utils.NowJsonTime(),
}
authorization := dto.Authorization{
Token: token,
Remember: false,
Type: constant.AccessToken,
User: &user,
}
cache.TokenManager.Set(token, authorization, cache.NoExpiration)
return repository.AccessTokenRepository.Create(ctx, accessToken)
})
}
func (service accessTokenService) Reload() error {
accessTokens, err := repository.AccessTokenRepository.FindAll(context.TODO())
if err != nil {
return err
}
for _, accessToken := range accessTokens {
user, err := repository.UserRepository.FindById(context.TODO(), accessToken.UserId)
if err != nil {
return err
}
authorization := dto.Authorization{
Token: accessToken.Token,
Remember: false,
Type: constant.AccessToken,
User: &user,
}
cache.TokenManager.Set(accessToken.Token, authorization, cache.NoExpiration)
}
return nil
}

View File

@ -1,21 +1,26 @@
package service
import (
"context"
"encoding/base64"
"encoding/json"
"next-terminal/server/config"
"next-terminal/server/env"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
"gorm.io/gorm"
)
type AssetService struct {
assetRepository *repository.AssetRepository
type assetService struct {
baseService
}
func NewAssetService(assetRepository *repository.AssetRepository) *AssetService {
return &AssetService{assetRepository: assetRepository}
}
func (r AssetService) Encrypt() error {
items, err := r.assetRepository.FindAll()
func (s assetService) EncryptAll() error {
items, err := repository.AssetRepository.FindAll(context.TODO())
if err != nil {
return err
}
@ -24,19 +29,95 @@ func (r AssetService) Encrypt() error {
if item.Encrypted {
continue
}
if err := r.assetRepository.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil {
if err := s.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil {
return err
}
if err := r.assetRepository.UpdateById(&item, item.ID); err != nil {
if err := repository.AssetRepository.UpdateById(context.TODO(), &item, item.ID); err != nil {
return err
}
}
return nil
}
func (r AssetService) CheckStatus(accessGatewayId string, ip string, port int) (active bool, err error) {
func (s assetService) Decrypt(item *model.Asset, password []byte) error {
if item.Encrypted {
if item.Password != "" && item.Password != "-" {
origData, err := base64.StdEncoding.DecodeString(item.Password)
if err != nil {
return err
}
decryptedCBC, err := utils.AesDecryptCBC(origData, password)
if err != nil {
return err
}
item.Password = string(decryptedCBC)
}
if item.PrivateKey != "" && item.PrivateKey != "-" {
origData, err := base64.StdEncoding.DecodeString(item.PrivateKey)
if err != nil {
return err
}
decryptedCBC, err := utils.AesDecryptCBC(origData, password)
if err != nil {
return err
}
item.PrivateKey = string(decryptedCBC)
}
if item.Passphrase != "" && item.Passphrase != "-" {
origData, err := base64.StdEncoding.DecodeString(item.Passphrase)
if err != nil {
return err
}
decryptedCBC, err := utils.AesDecryptCBC(origData, password)
if err != nil {
return err
}
item.Passphrase = string(decryptedCBC)
}
}
return nil
}
func (s assetService) Encrypt(item *model.Asset, password []byte) error {
if item.Password != "" && item.Password != "-" {
encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Password), password)
if err != nil {
return err
}
item.Password = base64.StdEncoding.EncodeToString(encryptedCBC)
}
if item.PrivateKey != "" && item.PrivateKey != "-" {
encryptedCBC, err := utils.AesEncryptCBC([]byte(item.PrivateKey), password)
if err != nil {
return err
}
item.PrivateKey = base64.StdEncoding.EncodeToString(encryptedCBC)
}
if item.Passphrase != "" && item.Passphrase != "-" {
encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Passphrase), password)
if err != nil {
return err
}
item.Passphrase = base64.StdEncoding.EncodeToString(encryptedCBC)
}
item.Encrypted = true
return nil
}
func (s assetService) FindByIdAndDecrypt(c context.Context, id string) (model.Asset, error) {
asset, err := repository.AssetRepository.FindById(c, id)
if err != nil {
return model.Asset{}, err
}
if err := s.Decrypt(&asset, config.GlobalCfg.EncryptionPassword); err != nil {
return model.Asset{}, err
}
return asset, nil
}
func (s assetService) CheckStatus(accessGatewayId string, ip string, port int) (active bool, err error) {
if accessGatewayId != "" && accessGatewayId != "-" {
g, e1 := accessGatewayService.GetGatewayAndReconnectById(accessGatewayId)
g, e1 := GatewayService.GetGatewayAndReconnectById(accessGatewayId)
if err != nil {
return false, e1
}
@ -58,3 +139,118 @@ func (r AssetService) CheckStatus(accessGatewayId string, ip string, port int) (
}
return active, err
}
func (s assetService) Create(m echo.Map) (model.Asset, error) {
data, err := json.Marshal(m)
if err != nil {
return model.Asset{}, err
}
var item model.Asset
if err := json.Unmarshal(data, &item); err != nil {
return model.Asset{}, err
}
item.ID = utils.UUID()
item.Created = utils.NowJsonTime()
item.Active = true
return item, env.GetDB().Transaction(func(tx *gorm.DB) error {
c := s.Context(tx)
if err := s.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil {
return err
}
if err := repository.AssetRepository.Create(c, &item); err != nil {
return err
}
if err := repository.AssetRepository.UpdateAttributes(c, item.ID, item.Protocol, m); err != nil {
return err
}
go func() {
active, _ := s.CheckStatus(item.AccessGatewayId, item.IP, item.Port)
if item.Active != active {
_ = repository.AssetRepository.UpdateActiveById(context.TODO(), active, item.ID)
}
}()
return nil
})
}
func (s assetService) DeleteById(id string) error {
return env.GetDB().Transaction(func(tx *gorm.DB) error {
c := s.Context(tx)
// 删除资产
if err := repository.AssetRepository.DeleteById(c, id); err != nil {
return err
}
// 删除资产属性
if err := repository.AssetRepository.DeleteAttrByAssetId(c, id); err != nil {
return err
}
// 删除资产与用户的关系
if err := repository.ResourceSharerRepository.DeleteByResourceId(c, id); err != nil {
return err
}
return nil
})
}
func (s assetService) UpdateById(id string, m echo.Map) error {
data, err := json.Marshal(m)
if err != nil {
return err
}
var item model.Asset
if err := json.Unmarshal(data, &item); err != nil {
return err
}
switch item.AccountType {
case "credential":
item.Username = "-"
item.Password = "-"
item.PrivateKey = "-"
item.Passphrase = "-"
case "private-key":
item.Password = "-"
item.CredentialId = "-"
if len(item.Username) == 0 {
item.Username = "-"
}
if len(item.Passphrase) == 0 {
item.Passphrase = "-"
}
case "custom":
item.PrivateKey = "-"
item.Passphrase = "-"
item.CredentialId = "-"
}
if len(item.Tags) == 0 {
item.Tags = "-"
}
if item.Description == "" {
item.Description = "-"
}
if err := s.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil {
return err
}
return env.GetDB().Transaction(func(tx *gorm.DB) error {
c := s.Context(tx)
if err := repository.AssetRepository.UpdateById(c, &item, id); err != nil {
return err
}
if err := repository.AssetRepository.UpdateAttributes(c, id, item.Protocol, m); err != nil {
return err
}
return nil
})
}

326
server/service/backup.go Normal file
View File

@ -0,0 +1,326 @@
package service
import (
"context"
"encoding/json"
"errors"
"strings"
"next-terminal/server/config"
"next-terminal/server/constant"
"next-terminal/server/dto"
"next-terminal/server/env"
"next-terminal/server/global/security"
"next-terminal/server/repository"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
"gorm.io/gorm"
)
type backupService struct {
baseService
}
func (service backupService) Export() (error, *dto.Backup) {
ctx := context.TODO()
users, err := repository.UserRepository.FindAll(ctx)
if err != nil {
return err, nil
}
for i := range users {
users[i].Password = ""
}
userGroups, err := repository.UserGroupRepository.FindAll(ctx)
if err != nil {
return err, nil
}
if len(userGroups) > 0 {
for i := range userGroups {
members, err := repository.UserGroupMemberRepository.FindUserIdsByUserGroupId(ctx, userGroups[i].ID)
if err != nil {
return err, nil
}
userGroups[i].Members = members
}
}
storages, err := repository.StorageRepository.FindAll(ctx)
if err != nil {
return err, nil
}
strategies, err := repository.StrategyRepository.FindAll(ctx)
if err != nil {
return err, nil
}
jobs, err := repository.JobRepository.FindAll(ctx)
if err != nil {
return err, nil
}
accessSecurities, err := repository.SecurityRepository.FindAll(ctx)
if err != nil {
return err, nil
}
accessGateways, err := repository.GatewayRepository.FindAll(ctx)
if err != nil {
return err, nil
}
commands, err := repository.CommandRepository.FindAll(ctx)
if err != nil {
return err, nil
}
credentials, err := repository.CredentialRepository.FindAll(ctx)
if err != nil {
return err, nil
}
if len(credentials) > 0 {
for i := range credentials {
if err := CredentialService.Decrypt(&credentials[i], config.GlobalCfg.EncryptionPassword); err != nil {
return err, nil
}
}
}
assets, err := repository.AssetRepository.FindAll(ctx)
if err != nil {
return err, nil
}
var assetMaps = make([]map[string]interface{}, 0)
if len(assets) > 0 {
for i := range assets {
asset := assets[i]
if err := AssetService.Decrypt(&asset, config.GlobalCfg.EncryptionPassword); err != nil {
return err, nil
}
attributeMap, err := repository.AssetRepository.FindAssetAttrMapByAssetId(ctx, asset.ID)
if err != nil {
return err, nil
}
itemMap := utils.StructToMap(asset)
for key := range attributeMap {
itemMap[key] = attributeMap[key]
}
itemMap["created"] = asset.Created.Format("2006-01-02 15:04:05")
assetMaps = append(assetMaps, itemMap)
}
}
resourceSharers, err := repository.ResourceSharerRepository.FindAll(ctx)
if err != nil {
return err, nil
}
backup := dto.Backup{
Users: users,
UserGroups: userGroups,
Storages: storages,
Strategies: strategies,
Jobs: jobs,
AccessSecurities: accessSecurities,
AccessGateways: accessGateways,
Commands: commands,
Credentials: credentials,
Assets: assetMaps,
ResourceSharers: resourceSharers,
}
return nil, &backup
}
func (service backupService) Import(backup *dto.Backup) error {
return env.GetDB().Transaction(func(tx *gorm.DB) error {
c := service.Context(tx)
var userIdMapping = make(map[string]string)
if len(backup.Users) > 0 {
for _, item := range backup.Users {
oldId := item.ID
if repository.UserRepository.ExistByUsername(c, item.Username) {
delete(userIdMapping, oldId)
continue
}
newId := utils.UUID()
item.ID = newId
item.Password = utils.GenPassword()
if err := repository.UserRepository.Create(c, &item); err != nil {
return err
}
userIdMapping[oldId] = newId
}
}
var userGroupIdMapping = make(map[string]string)
if len(backup.UserGroups) > 0 {
for _, item := range backup.UserGroups {
oldId := item.ID
var members = make([]string, 0)
if len(item.Members) > 0 {
for _, member := range item.Members {
members = append(members, userIdMapping[member])
}
}
userGroup, err := UserGroupService.Create(item.Name, members)
if err != nil {
if errors.Is(constant.ErrNameAlreadyUsed, err) {
// 删除名称重复的用户组
delete(userGroupIdMapping, oldId)
continue
} else {
return err
}
}
userGroupIdMapping[oldId] = userGroup.ID
}
}
if len(backup.Storages) > 0 {
for _, item := range backup.Storages {
owner := userIdMapping[item.Owner]
if owner == "" {
continue
}
item.ID = utils.UUID()
item.Owner = owner
item.Created = utils.NowJsonTime()
if err := repository.StorageRepository.Create(c, &item); err != nil {
return err
}
}
}
var strategyIdMapping = make(map[string]string)
if len(backup.Strategies) > 0 {
for _, item := range backup.Strategies {
oldId := item.ID
newId := utils.UUID()
item.ID = newId
item.Created = utils.NowJsonTime()
if err := repository.StrategyRepository.Create(c, &item); err != nil {
return err
}
strategyIdMapping[oldId] = newId
}
}
if len(backup.AccessSecurities) > 0 {
for _, item := range backup.AccessSecurities {
item.ID = utils.UUID()
if err := repository.SecurityRepository.Create(c, &item); err != nil {
return err
}
// 更新内存中的安全规则
rule := &security.Security{
ID: item.ID,
IP: item.IP,
Rule: item.Rule,
Priority: item.Priority,
}
security.GlobalSecurityManager.Add <- rule
}
}
var accessGatewayIdMapping = make(map[string]string, 0)
if len(backup.AccessGateways) > 0 {
for _, item := range backup.AccessGateways {
oldId := item.ID
newId := utils.UUID()
item.ID = newId
item.Created = utils.NowJsonTime()
if err := repository.GatewayRepository.Create(c, &item); err != nil {
return err
}
accessGatewayIdMapping[oldId] = newId
}
}
if len(backup.Commands) > 0 {
for _, item := range backup.Commands {
item.ID = utils.UUID()
item.Created = utils.NowJsonTime()
if err := repository.CommandRepository.Create(c, &item); err != nil {
return err
}
}
}
var credentialIdMapping = make(map[string]string, 0)
if len(backup.Credentials) > 0 {
for _, item := range backup.Credentials {
oldId := item.ID
newId := utils.UUID()
item.ID = newId
if err := CredentialService.Create(&item); err != nil {
return err
}
credentialIdMapping[oldId] = newId
}
}
var assetIdMapping = make(map[string]string, 0)
if len(backup.Assets) > 0 {
for _, m := range backup.Assets {
data, err := json.Marshal(m)
if err != nil {
return err
}
m := echo.Map{}
if err := json.Unmarshal(data, &m); err != nil {
return err
}
credentialId := m["credentialId"].(string)
accessGatewayId := m["accessGatewayId"].(string)
if credentialId != "" && credentialId != "-" {
m["credentialId"] = credentialIdMapping[credentialId]
}
if accessGatewayId != "" && accessGatewayId != "-" {
m["accessGatewayId"] = accessGatewayIdMapping[accessGatewayId]
}
oldId := m["id"].(string)
asset, err := AssetService.Create(m)
if err != nil {
return err
}
assetIdMapping[oldId] = asset.ID
}
}
if len(backup.ResourceSharers) > 0 {
for _, item := range backup.ResourceSharers {
userGroupId := userGroupIdMapping[item.UserGroupId]
userId := userIdMapping[item.UserId]
strategyId := strategyIdMapping[item.StrategyId]
resourceId := assetIdMapping[item.ResourceId]
if err := repository.ResourceSharerRepository.AddSharerResources(userGroupId, userId, strategyId, item.ResourceType, []string{resourceId}); err != nil {
return err
}
}
}
if len(backup.Jobs) > 0 {
for _, item := range backup.Jobs {
if item.Func == constant.FuncCheckAssetStatusJob {
continue
}
resourceIds := strings.Split(item.ResourceIds, ",")
if len(resourceIds) > 0 {
var newResourceIds = make([]string, 0)
for _, resourceId := range resourceIds {
newResourceIds = append(newResourceIds, assetIdMapping[resourceId])
}
item.ResourceIds = strings.Join(newResourceIds, ",")
}
if err := JobService.Create(&item); err != nil {
return err
}
}
}
return nil
})
}

16
server/service/base.go Normal file
View File

@ -0,0 +1,16 @@
package service
import (
"context"
"next-terminal/server/constant"
"gorm.io/gorm"
)
type baseService struct {
}
func (service baseService) Context(db *gorm.DB) context.Context {
return context.WithValue(context.TODO(), constant.DB, db)
}

View File

@ -1,20 +1,21 @@
package service
import (
"context"
"encoding/base64"
"next-terminal/server/model"
"next-terminal/server/utils"
"next-terminal/server/config"
"next-terminal/server/repository"
)
type CredentialService struct {
credentialRepository *repository.CredentialRepository
type credentialService struct {
}
func NewCredentialService(credentialRepository *repository.CredentialRepository) *CredentialService {
return &CredentialService{credentialRepository: credentialRepository}
}
func (r CredentialService) Encrypt() error {
items, err := r.credentialRepository.FindAll()
func (s credentialService) EncryptAll() error {
items, err := repository.CredentialRepository.FindAll(context.TODO())
if err != nil {
return err
}
@ -23,12 +24,96 @@ func (r CredentialService) Encrypt() error {
if item.Encrypted {
continue
}
if err := r.credentialRepository.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil {
if err := s.Encrypt(&item, config.GlobalCfg.EncryptionPassword); err != nil {
return err
}
if err := r.credentialRepository.UpdateById(&item, item.ID); err != nil {
if err := repository.CredentialRepository.UpdateById(context.TODO(), &item, item.ID); err != nil {
return err
}
}
return nil
}
func (s credentialService) Encrypt(item *model.Credential, password []byte) error {
if item.Password != "-" {
encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Password), password)
if err != nil {
return err
}
item.Password = base64.StdEncoding.EncodeToString(encryptedCBC)
}
if item.PrivateKey != "-" {
encryptedCBC, err := utils.AesEncryptCBC([]byte(item.PrivateKey), password)
if err != nil {
return err
}
item.PrivateKey = base64.StdEncoding.EncodeToString(encryptedCBC)
}
if item.Passphrase != "-" {
encryptedCBC, err := utils.AesEncryptCBC([]byte(item.Passphrase), password)
if err != nil {
return err
}
item.Passphrase = base64.StdEncoding.EncodeToString(encryptedCBC)
}
item.Encrypted = true
return nil
}
func (s credentialService) Decrypt(item *model.Credential, password []byte) error {
if item.Encrypted {
if item.Password != "" && item.Password != "-" {
origData, err := base64.StdEncoding.DecodeString(item.Password)
if err != nil {
return err
}
decryptedCBC, err := utils.AesDecryptCBC(origData, password)
if err != nil {
return err
}
item.Password = string(decryptedCBC)
}
if item.PrivateKey != "" && item.PrivateKey != "-" {
origData, err := base64.StdEncoding.DecodeString(item.PrivateKey)
if err != nil {
return err
}
decryptedCBC, err := utils.AesDecryptCBC(origData, password)
if err != nil {
return err
}
item.PrivateKey = string(decryptedCBC)
}
if item.Passphrase != "" && item.Passphrase != "-" {
origData, err := base64.StdEncoding.DecodeString(item.Passphrase)
if err != nil {
return err
}
decryptedCBC, err := utils.AesDecryptCBC(origData, password)
if err != nil {
return err
}
item.Passphrase = string(decryptedCBC)
}
}
return nil
}
func (s credentialService) FindByIdAndDecrypt(c context.Context, id string) (o model.Credential, err error) {
credential, err := repository.CredentialRepository.FindById(c, id)
if err != nil {
return o, err
}
if err := s.Decrypt(&credential, config.GlobalCfg.EncryptionPassword); err != nil {
return o, err
}
return credential, nil
}
func (s credentialService) Create(item *model.Credential) error {
// 加密密码之后进行存储
if err := s.Encrypt(item, config.GlobalCfg.EncryptionPassword); err != nil {
return err
}
return repository.CredentialRepository.Create(context.TODO(), item)
}

View File

@ -1,5 +0,0 @@
package service
var (
accessGatewayService *AccessGatewayService
)

70
server/service/gateway.go Normal file
View File

@ -0,0 +1,70 @@
package service
import (
"context"
"next-terminal/server/global/gateway"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/term"
)
type gatewayService struct{}
func (r gatewayService) GetGatewayAndReconnectById(accessGatewayId string) (g *gateway.Gateway, err error) {
g = gateway.GlobalGatewayManager.GetById(accessGatewayId)
if g == nil || !g.Connected {
accessGateway, err := repository.GatewayRepository.FindById(context.TODO(), accessGatewayId)
if err != nil {
return nil, err
}
g = r.ReConnect(&accessGateway)
}
return g, nil
}
func (r gatewayService) GetGatewayById(accessGatewayId string) (g *gateway.Gateway, err error) {
g = gateway.GlobalGatewayManager.GetById(accessGatewayId)
if g == nil {
accessGateway, err := repository.GatewayRepository.FindById(context.TODO(), accessGatewayId)
if err != nil {
return nil, err
}
g = r.ReConnect(&accessGateway)
}
return g, nil
}
func (r gatewayService) ReConnectAll() error {
gateways, err := repository.GatewayRepository.FindAll(context.TODO())
if err != nil {
return err
}
if len(gateways) > 0 {
for i := range gateways {
r.ReConnect(&gateways[i])
}
}
return nil
}
func (r gatewayService) ReConnect(m *model.AccessGateway) *gateway.Gateway {
log.Debugf("重建接入网关「%v」中...", m.Name)
r.DisconnectById(m.ID)
sshClient, err := term.NewSshClient(m.IP, m.Port, m.Username, m.Password, m.PrivateKey, m.Passphrase)
var g *gateway.Gateway
if err != nil {
g = gateway.NewGateway(m.ID, false, err.Error(), nil)
} else {
g = gateway.NewGateway(m.ID, true, "", sshClient)
}
gateway.GlobalGatewayManager.Add <- g
log.Debugf("重建接入网关「%v」完成", m.Name)
return g
}
func (r gatewayService) DisconnectById(accessGatewayId string) {
gateway.GlobalGatewayManager.Del <- accessGatewayId
}

View File

@ -1,42 +1,27 @@
package service
import (
"encoding/json"
"context"
"errors"
"fmt"
"strings"
"time"
"next-terminal/server/constant"
"next-terminal/server/global/cron"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/term"
"next-terminal/server/utils"
"gorm.io/gorm"
)
type JobService struct {
jobRepository *repository.JobRepository
jobLogRepository *repository.JobLogRepository
assetRepository *repository.AssetRepository
credentialRepository *repository.CredentialRepository
assetService *AssetService
type jobService struct {
}
func NewJobService(jobRepository *repository.JobRepository, jobLogRepository *repository.JobLogRepository, assetRepository *repository.AssetRepository, credentialRepository *repository.CredentialRepository, assetService *AssetService) *JobService {
return &JobService{jobRepository: jobRepository, jobLogRepository: jobLogRepository, assetRepository: assetRepository, credentialRepository: credentialRepository, assetService: assetService}
}
func (r JobService) ChangeStatusById(id, status string) error {
job, err := r.jobRepository.FindById(id)
func (r jobService) ChangeStatusById(id, status string) error {
job, err := repository.JobRepository.FindById(context.TODO(), id)
if err != nil {
return err
}
if status == constant.JobStatusRunning {
j, err := getJob(&job, &r)
j, err := getJob(&job)
if err != nil {
return err
}
@ -48,249 +33,38 @@ func (r JobService) ChangeStatusById(id, status string) error {
jobForUpdate := model.Job{ID: id, Status: constant.JobStatusRunning, CronJobId: int(entryID)}
return r.jobRepository.UpdateById(&jobForUpdate)
return repository.JobRepository.UpdateById(context.TODO(), &jobForUpdate)
} else {
cron.GlobalCron.Remove(cron.JobId(job.CronJobId))
log.Debugf("关闭计划任务「%v」,运行中计划任务数量「%v」", job.Name, len(cron.GlobalCron.Entries()))
jobForUpdate := model.Job{ID: id, Status: constant.JobStatusNotRunning}
return r.jobRepository.UpdateById(&jobForUpdate)
return repository.JobRepository.UpdateById(context.TODO(), &jobForUpdate)
}
}
func getJob(j *model.Job, jobService *JobService) (job cron.Job, err error) {
func getJob(j *model.Job) (job cron.Job, err error) {
switch j.Func {
case constant.FuncCheckAssetStatusJob:
job = CheckAssetStatusJob{
ID: j.ID,
Mode: j.Mode,
ResourceIds: j.ResourceIds,
Metadata: j.Metadata,
jobService: jobService,
assetService: jobService.assetService,
ID: j.ID,
Mode: j.Mode,
ResourceIds: j.ResourceIds,
Metadata: j.Metadata,
}
case constant.FuncShellJob:
job = ShellJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata, jobService: jobService}
job = ShellJob{ID: j.ID, Mode: j.Mode, ResourceIds: j.ResourceIds, Metadata: j.Metadata}
default:
return nil, errors.New("未识别的任务")
}
return job, err
}
type CheckAssetStatusJob struct {
ID string
Mode string
ResourceIds string
Metadata string
jobService *JobService
assetService *AssetService
}
func (r CheckAssetStatusJob) Run() {
if r.ID == "" {
return
}
var assets []model.Asset
if r.Mode == constant.JobModeAll {
assets, _ = r.jobService.assetRepository.FindAll()
} else {
assets, _ = r.jobService.assetRepository.FindByIds(strings.Split(r.ResourceIds, ","))
}
if len(assets) == 0 {
return
}
msgChan := make(chan string)
for i := range assets {
asset := assets[i]
go func() {
t1 := time.Now()
var (
msg string
ip = asset.IP
port = asset.Port
)
active, err := r.assetService.CheckStatus(asset.AccessGatewayId, ip, port)
elapsed := time.Since(t1)
if err == nil {
msg = fmt.Sprintf("资产「%v」存活状态检测完成存活「%v」耗时「%v」", asset.Name, active, elapsed)
} else {
msg = fmt.Sprintf("资产「%v」存活状态检测完成存活「%v」耗时「%v」原因 %v", asset.Name, active, elapsed, err.Error())
}
_ = r.jobService.assetRepository.UpdateActiveById(active, asset.ID)
log.Infof(msg)
msgChan <- msg
}()
}
var message = ""
for i := 0; i < len(assets); i++ {
message += <-msgChan + "\n"
}
_ = r.jobService.jobRepository.UpdateLastUpdatedById(r.ID)
jobLog := model.JobLog{
ID: utils.UUID(),
JobId: r.ID,
Timestamp: utils.NowJsonTime(),
Message: message,
}
_ = r.jobService.jobLogRepository.Create(&jobLog)
}
type ShellJob struct {
ID string
Mode string
ResourceIds string
Metadata string
jobService *JobService
}
type MetadataShell struct {
Shell string
}
func (r ShellJob) Run() {
if r.ID == "" {
return
}
var assets []model.Asset
if r.Mode == constant.JobModeAll {
assets, _ = r.jobService.assetRepository.FindByProtocol("ssh")
} else {
assets, _ = r.jobService.assetRepository.FindByProtocolAndIds("ssh", strings.Split(r.ResourceIds, ","))
}
if len(assets) == 0 {
return
}
var metadataShell MetadataShell
err := json.Unmarshal([]byte(r.Metadata), &metadataShell)
if err != nil {
log.Errorf("JSON数据解析失败 %v", err)
return
}
msgChan := make(chan string)
for i := range assets {
asset, err := r.jobService.assetRepository.FindByIdAndDecrypt(assets[i].ID)
if err != nil {
msgChan <- fmt.Sprintf("资产「%v」Shell执行失败查询数据异常「%v」", assets[i].Name, err.Error())
return
}
var (
username = asset.Username
password = asset.Password
privateKey = asset.PrivateKey
passphrase = asset.Passphrase
ip = asset.IP
port = asset.Port
)
if asset.AccountType == "credential" {
credential, err := r.jobService.credentialRepository.FindByIdAndDecrypt(asset.CredentialId)
if err != nil {
msgChan <- fmt.Sprintf("资产「%v」Shell执行失败查询授权凭证数据异常「%v」", assets[i].Name, err.Error())
return
}
if credential.Type == constant.Custom {
username = credential.Username
password = credential.Password
} else {
username = credential.Username
privateKey = credential.PrivateKey
passphrase = credential.Passphrase
}
}
go func() {
t1 := time.Now()
result, err := exec(metadataShell.Shell, asset.AccessGatewayId, ip, port, username, password, privateKey, passphrase)
elapsed := time.Since(t1)
var msg string
if err != nil {
if errors.Is(gorm.ErrRecordNotFound, err) {
msg = fmt.Sprintf("资产「%v」Shell执行失败请检查资产所关联接入网关是否存在耗时「%v」", asset.Name, elapsed)
} else {
msg = fmt.Sprintf("资产「%v」Shell执行失败错误内容为「%v」耗时「%v」", asset.Name, err.Error(), elapsed)
}
log.Infof(msg)
} else {
msg = fmt.Sprintf("资产「%v」Shell执行成功返回值「%v」耗时「%v」", asset.Name, result, elapsed)
log.Infof(msg)
}
msgChan <- msg
}()
}
var message = ""
for i := 0; i < len(assets); i++ {
message += <-msgChan + "\n"
}
_ = r.jobService.jobRepository.UpdateLastUpdatedById(r.ID)
jobLog := model.JobLog{
ID: utils.UUID(),
JobId: r.ID,
Timestamp: utils.NowJsonTime(),
Message: message,
}
_ = r.jobService.jobLogRepository.Create(&jobLog)
}
func exec(shell, accessGatewayId, ip string, port int, username, password, privateKey, passphrase string) (string, error) {
if accessGatewayId != "" && accessGatewayId != "-" {
g, err := accessGatewayService.GetGatewayAndReconnectById(accessGatewayId)
if err != nil {
return "", err
}
uuid := utils.UUID()
exposedIP, exposedPort, err := g.OpenSshTunnel(uuid, ip, port)
if err != nil {
return "", err
}
defer g.CloseSshTunnel(uuid)
return ExecCommandBySSH(shell, exposedIP, exposedPort, username, password, privateKey, passphrase)
} else {
return ExecCommandBySSH(shell, ip, port, username, password, privateKey, passphrase)
}
}
func ExecCommandBySSH(cmd, ip string, port int, username, password, privateKey, passphrase string) (result string, err error) {
sshClient, err := term.NewSshClient(ip, port, username, password, privateKey, passphrase)
if err != nil {
return "", err
}
session, err := sshClient.NewSession()
if err != nil {
return "", err
}
defer session.Close()
//执行远程命令
combo, err := session.CombinedOutput(cmd)
if err != nil {
return "", err
}
return string(combo), nil
}
func (r JobService) ExecJobById(id string) (err error) {
job, err := r.jobRepository.FindById(id)
func (r jobService) ExecJobById(id string) (err error) {
job, err := repository.JobRepository.FindById(context.TODO(), id)
if err != nil {
return err
}
j, err := getJob(&job, &r)
j, err := getJob(&job)
if err != nil {
return err
}
@ -298,8 +72,8 @@ func (r JobService) ExecJobById(id string) (err error) {
return nil
}
func (r JobService) InitJob() error {
jobs, _ := r.jobRepository.FindAll()
func (r jobService) InitJob() error {
jobs, _ := repository.JobRepository.FindAll(context.TODO())
if len(jobs) == 0 {
job := model.Job{
ID: utils.UUID(),
@ -311,7 +85,7 @@ func (r JobService) InitJob() error {
Created: utils.NowJsonTime(),
Updated: utils.NowJsonTime(),
}
if err := r.jobRepository.Create(&job); err != nil {
if err := repository.JobRepository.Create(context.TODO(), &job); err != nil {
return err
}
log.Debugf("创建计划任务「%v」cron「%v」", job.Name, job.Cron)
@ -329,10 +103,10 @@ func (r JobService) InitJob() error {
return nil
}
func (r JobService) Create(o *model.Job) (err error) {
func (r jobService) Create(o *model.Job) (err error) {
if o.Status == constant.JobStatusRunning {
j, err := getJob(o, &r)
j, err := getJob(o)
if err != nil {
return err
}
@ -343,11 +117,11 @@ func (r JobService) Create(o *model.Job) (err error) {
o.CronJobId = int(jobId)
}
return r.jobRepository.Create(o)
return repository.JobRepository.Create(context.TODO(), o)
}
func (r JobService) DeleteJobById(id string) error {
job, err := r.jobRepository.FindById(id)
func (r jobService) DeleteJobById(id string) error {
job, err := repository.JobRepository.FindById(context.TODO(), id)
if err != nil {
return err
}
@ -356,11 +130,11 @@ func (r JobService) DeleteJobById(id string) error {
return err
}
}
return r.jobRepository.DeleteJobById(id)
return repository.JobRepository.DeleteJobById(context.TODO(), id)
}
func (r JobService) UpdateById(m *model.Job) error {
if err := r.jobRepository.UpdateById(m); err != nil {
func (r jobService) UpdateById(m *model.Job) error {
if err := repository.JobRepository.UpdateById(context.TODO(), m); err != nil {
return err
}

View File

@ -0,0 +1,78 @@
package service
import (
"context"
"fmt"
"strings"
"time"
"next-terminal/server/constant"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/utils"
)
type CheckAssetStatusJob struct {
ID string
Mode string
ResourceIds string
Metadata string
}
func (r CheckAssetStatusJob) Run() {
if r.ID == "" {
return
}
var assets []model.Asset
if r.Mode == constant.JobModeAll {
assets, _ = repository.AssetRepository.FindAll(context.TODO())
} else {
assets, _ = repository.AssetRepository.FindByIds(context.TODO(), strings.Split(r.ResourceIds, ","))
}
if len(assets) == 0 {
return
}
msgChan := make(chan string)
for i := range assets {
asset := assets[i]
go func() {
t1 := time.Now()
var (
msg string
ip = asset.IP
port = asset.Port
)
active, err := AssetService.CheckStatus(asset.AccessGatewayId, ip, port)
elapsed := time.Since(t1)
if err == nil {
msg = fmt.Sprintf("资产「%v」存活状态检测完成存活「%v」耗时「%v」", asset.Name, active, elapsed)
} else {
msg = fmt.Sprintf("资产「%v」存活状态检测完成存活「%v」耗时「%v」原因 %v", asset.Name, active, elapsed, err.Error())
}
_ = repository.AssetRepository.UpdateActiveById(context.TODO(), active, asset.ID)
log.Infof(msg)
msgChan <- msg
}()
}
var message = ""
for i := 0; i < len(assets); i++ {
message += <-msgChan + "\n"
}
_ = repository.JobRepository.UpdateLastUpdatedById(context.TODO(), r.ID)
jobLog := model.JobLog{
ID: utils.UUID(),
JobId: r.ID,
Timestamp: utils.NowJsonTime(),
Message: message,
}
_ = repository.JobLogRepository.Create(context.TODO(), &jobLog)
}

View File

@ -0,0 +1,163 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"next-terminal/server/constant"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/term"
"next-terminal/server/utils"
"gorm.io/gorm"
)
type ShellJob struct {
ID string
Mode string
ResourceIds string
Metadata string
}
type MetadataShell struct {
Shell string
}
func (r ShellJob) Run() {
if r.ID == "" {
return
}
var assets []model.Asset
if r.Mode == constant.JobModeAll {
assets, _ = repository.AssetRepository.FindByProtocol(context.TODO(), "ssh")
} else {
assets, _ = repository.AssetRepository.FindByProtocolAndIds(context.TODO(), "ssh", strings.Split(r.ResourceIds, ","))
}
if len(assets) == 0 {
return
}
var metadataShell MetadataShell
err := json.Unmarshal([]byte(r.Metadata), &metadataShell)
if err != nil {
log.Errorf("JSON数据解析失败 %v", err)
return
}
msgChan := make(chan string)
for i := range assets {
asset, err := AssetService.FindByIdAndDecrypt(context.TODO(), assets[i].ID)
if err != nil {
msgChan <- fmt.Sprintf("资产「%v」Shell执行失败查询数据异常「%v」", assets[i].Name, err.Error())
return
}
var (
username = asset.Username
password = asset.Password
privateKey = asset.PrivateKey
passphrase = asset.Passphrase
ip = asset.IP
port = asset.Port
)
if asset.AccountType == "credential" {
credential, err := CredentialService.FindByIdAndDecrypt(context.TODO(), asset.CredentialId)
if err != nil {
msgChan <- fmt.Sprintf("资产「%v」Shell执行失败查询授权凭证数据异常「%v」", assets[i].Name, err.Error())
return
}
if credential.Type == constant.Custom {
username = credential.Username
password = credential.Password
} else {
username = credential.Username
privateKey = credential.PrivateKey
passphrase = credential.Passphrase
}
}
go func() {
t1 := time.Now()
result, err := exec(metadataShell.Shell, asset.AccessGatewayId, ip, port, username, password, privateKey, passphrase)
elapsed := time.Since(t1)
var msg string
if err != nil {
if errors.Is(gorm.ErrRecordNotFound, err) {
msg = fmt.Sprintf("资产「%v」Shell执行失败请检查资产所关联接入网关是否存在耗时「%v」", asset.Name, elapsed)
} else {
msg = fmt.Sprintf("资产「%v」Shell执行失败错误内容为「%v」耗时「%v」", asset.Name, err.Error(), elapsed)
}
log.Infof(msg)
} else {
msg = fmt.Sprintf("资产「%v」Shell执行成功返回值「%v」耗时「%v」", asset.Name, result, elapsed)
log.Infof(msg)
}
msgChan <- msg
}()
}
var message = ""
for i := 0; i < len(assets); i++ {
message += <-msgChan + "\n"
}
_ = repository.JobRepository.UpdateLastUpdatedById(context.TODO(), r.ID)
jobLog := model.JobLog{
ID: utils.UUID(),
JobId: r.ID,
Timestamp: utils.NowJsonTime(),
Message: message,
}
_ = repository.JobLogRepository.Create(context.TODO(), &jobLog)
}
func exec(shell, accessGatewayId, ip string, port int, username, password, privateKey, passphrase string) (string, error) {
if accessGatewayId != "" && accessGatewayId != "-" {
g, err := GatewayService.GetGatewayAndReconnectById(accessGatewayId)
if err != nil {
return "", err
}
uuid := utils.UUID()
exposedIP, exposedPort, err := g.OpenSshTunnel(uuid, ip, port)
if err != nil {
return "", err
}
defer g.CloseSshTunnel(uuid)
return ExecCommandBySSH(shell, exposedIP, exposedPort, username, password, privateKey, passphrase)
} else {
return ExecCommandBySSH(shell, ip, port, username, password, privateKey, passphrase)
}
}
func ExecCommandBySSH(cmd, ip string, port int, username, password, privateKey, passphrase string) (result string, err error) {
sshClient, err := term.NewSshClient(ip, port, username, password, privateKey, passphrase)
if err != nil {
return "", err
}
session, err := sshClient.NewSession()
if err != nil {
return "", err
}
defer func() {
_ = session.Close()
}()
//执行远程命令
combo, err := session.CombinedOutput(cmd)
if err != nil {
return "", err
}
return string(combo), nil
}

View File

@ -1,6 +1,7 @@
package service
import (
"context"
"net/smtp"
"next-terminal/server/constant"
@ -10,16 +11,11 @@ import (
"github.com/jordan-wright/email"
)
type MailService struct {
propertyRepository *repository.PropertyRepository
type mailService struct {
}
func NewMailService(propertyRepository *repository.PropertyRepository) *MailService {
return &MailService{propertyRepository: propertyRepository}
}
func (r MailService) SendMail(to, subject, text string) {
propertiesMap := r.propertyRepository.FindAllMap()
func (r mailService) SendMail(to, subject, text string) {
propertiesMap := repository.PropertyRepository.FindAllMap(context.TODO())
host := propertiesMap[constant.MailHost]
port := propertiesMap[constant.MailPort]
username := propertiesMap[constant.MailUsername]

View File

@ -1,28 +1,31 @@
package service
import (
"context"
"errors"
"fmt"
"next-terminal/server/env"
"next-terminal/server/guacd"
"next-terminal/server/model"
"next-terminal/server/repository"
"gorm.io/gorm"
)
type PropertyService struct {
propertyRepository *repository.PropertyRepository
type propertyService struct {
baseService
}
func NewPropertyService(propertyRepository *repository.PropertyRepository) *PropertyService {
return &PropertyService{propertyRepository: propertyRepository}
}
func (r PropertyService) InitProperties() error {
propertyMap := r.propertyRepository.FindAllMap()
func (service propertyService) InitProperties() error {
propertyMap := repository.PropertyRepository.FindAllMap(context.TODO())
if len(propertyMap[guacd.EnableRecording]) == 0 {
property := model.Property{
Name: guacd.EnableRecording,
Value: "true",
}
if err := r.propertyRepository.Create(&property); err != nil {
if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil {
return err
}
}
@ -32,7 +35,7 @@ func (r PropertyService) InitProperties() error {
Name: guacd.CreateRecordingPath,
Value: "true",
}
if err := r.propertyRepository.Create(&property); err != nil {
if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil {
return err
}
}
@ -42,7 +45,7 @@ func (r PropertyService) InitProperties() error {
Name: guacd.FontName,
Value: "menlo",
}
if err := r.propertyRepository.Create(&property); err != nil {
if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil {
return err
}
}
@ -52,7 +55,7 @@ func (r PropertyService) InitProperties() error {
Name: guacd.FontSize,
Value: "12",
}
if err := r.propertyRepository.Create(&property); err != nil {
if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil {
return err
}
}
@ -62,7 +65,7 @@ func (r PropertyService) InitProperties() error {
Name: guacd.ColorScheme,
Value: "gray-black",
}
if err := r.propertyRepository.Create(&property); err != nil {
if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil {
return err
}
}
@ -72,7 +75,7 @@ func (r PropertyService) InitProperties() error {
Name: guacd.EnableWallpaper,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil {
return err
}
}
@ -82,7 +85,7 @@ func (r PropertyService) InitProperties() error {
Name: guacd.EnableTheming,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil {
return err
}
}
@ -92,7 +95,7 @@ func (r PropertyService) InitProperties() error {
Name: guacd.EnableFontSmoothing,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil {
return err
}
}
@ -102,7 +105,7 @@ func (r PropertyService) InitProperties() error {
Name: guacd.EnableFullWindowDrag,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil {
return err
}
}
@ -112,7 +115,7 @@ func (r PropertyService) InitProperties() error {
Name: guacd.EnableDesktopComposition,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil {
return err
}
}
@ -122,7 +125,7 @@ func (r PropertyService) InitProperties() error {
Name: guacd.EnableMenuAnimations,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil {
return err
}
}
@ -132,7 +135,7 @@ func (r PropertyService) InitProperties() error {
Name: guacd.DisableBitmapCaching,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil {
return err
}
}
@ -142,39 +145,71 @@ func (r PropertyService) InitProperties() error {
Name: guacd.DisableOffscreenCaching,
Value: "false",
}
if err := r.propertyRepository.Create(&property); err != nil {
if err := repository.PropertyRepository.Create(context.TODO(), &property); err != nil {
return err
}
}
if len(propertyMap[guacd.DisableGlyphCaching]) == 0 {
property := model.Property{
Name: guacd.DisableGlyphCaching,
Value: "true",
}
if err := r.propertyRepository.Create(&property); err != nil {
if len(propertyMap[guacd.DisableGlyphCaching]) > 0 {
if err := repository.PropertyRepository.DeleteByName(context.TODO(), guacd.DisableGlyphCaching); err != nil {
return err
}
}
return nil
}
func (r PropertyService) DeleteDeprecatedProperty() error {
propertyMap := r.propertyRepository.FindAllMap()
func (service propertyService) DeleteDeprecatedProperty() error {
propertyMap := repository.PropertyRepository.FindAllMap(context.TODO())
if propertyMap[guacd.EnableDrive] != "" {
if err := r.propertyRepository.DeleteByName(guacd.DriveName); err != nil {
if err := repository.PropertyRepository.DeleteByName(context.TODO(), guacd.DriveName); err != nil {
return err
}
}
if propertyMap[guacd.DrivePath] != "" {
if err := r.propertyRepository.DeleteByName(guacd.DrivePath); err != nil {
if err := repository.PropertyRepository.DeleteByName(context.TODO(), guacd.DrivePath); err != nil {
return err
}
}
if propertyMap[guacd.DriveName] != "" {
if err := r.propertyRepository.DeleteByName(guacd.DriveName); err != nil {
if err := repository.PropertyRepository.DeleteByName(context.TODO(), guacd.DriveName); err != nil {
return err
}
}
return nil
}
func (service propertyService) Update(item map[string]interface{}) error {
return env.GetDB().Transaction(func(tx *gorm.DB) error {
c := service.Context(tx)
for key := range item {
value := fmt.Sprintf("%v", item[key])
if value == "" {
value = "-"
}
property := model.Property{
Name: key,
Value: value,
}
if key == "enable-ldap" && value == "false" {
if err := UserService.DeleteALlLdapUser(c); err != nil {
return err
}
}
_, err := repository.PropertyRepository.FindByName(c, key)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
if err := repository.PropertyRepository.Create(c, &property); err != nil {
return err
}
} else {
if err := repository.PropertyRepository.UpdateByName(c, &property, key); err != nil {
return err
}
}
}
return nil
})
}

View File

@ -0,0 +1,32 @@
package service
import (
"context"
"next-terminal/server/global/security"
"next-terminal/server/repository"
)
type securityService struct{}
func (service securityService) ReloadAccessSecurity() error {
rules, err := repository.SecurityRepository.FindAll(context.TODO())
if err != nil {
return err
}
if len(rules) > 0 {
// 先清空
security.GlobalSecurityManager.Clear()
// 再添加到全局的安全管理器中
for i := 0; i < len(rules); i++ {
rule := &security.Security{
ID: rules[i].ID,
IP: rules[i].IP,
Rule: rules[i].Rule,
Priority: rules[i].Priority,
}
security.GlobalSecurityManager.Add <- rule
}
}
return nil
}

View File

@ -1,45 +1,55 @@
package service
import (
"context"
"encoding/base64"
"errors"
"strconv"
"sync"
"next-terminal/server/config"
"next-terminal/server/constant"
"next-terminal/server/env"
"next-terminal/server/global/session"
"next-terminal/server/guacd"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/utils"
"github.com/gorilla/websocket"
"gorm.io/gorm"
)
type SessionService struct {
sessionRepository *repository.SessionRepository
type sessionService struct {
baseService
}
func NewSessionService(sessionRepository *repository.SessionRepository) *SessionService {
return &SessionService{sessionRepository: sessionRepository}
}
func (r SessionService) FixSessionState() error {
sessions, err := r.sessionRepository.FindByStatus(constant.Connected)
func (service sessionService) FixSessionState() error {
sessions, err := repository.SessionRepository.FindByStatus(context.TODO(), constant.Connected)
if err != nil {
return err
}
if len(sessions) > 0 {
for i := range sessions {
session := model.Session{
s := model.Session{
Status: constant.Disconnected,
DisconnectedTime: utils.NowJsonTime(),
}
_ = r.sessionRepository.UpdateById(&session, sessions[i].ID)
_ = repository.SessionRepository.UpdateById(context.TODO(), &s, sessions[i].ID)
}
}
return nil
}
func (r SessionService) EmptyPassword() error {
return r.sessionRepository.EmptyPassword()
func (service sessionService) EmptyPassword() error {
return repository.SessionRepository.EmptyPassword(context.TODO())
}
func (r SessionService) ClearOfflineSession() error {
sessions, err := r.sessionRepository.FindByStatus(constant.Disconnected)
func (service sessionService) ClearOfflineSession() error {
sessions, err := repository.SessionRepository.FindByStatus(context.TODO(), constant.Disconnected)
if err != nil {
return err
}
@ -47,11 +57,11 @@ func (r SessionService) ClearOfflineSession() error {
for i := range sessions {
sessionIds = append(sessionIds, sessions[i].ID)
}
return r.sessionRepository.DeleteByIds(sessionIds)
return repository.SessionRepository.DeleteByIds(context.TODO(), sessionIds)
}
func (r SessionService) ReviewedAll() error {
sessions, err := r.sessionRepository.FindAllUnReviewed()
func (service sessionService) ReviewedAll() error {
sessions, err := repository.SessionRepository.FindAllUnReviewed(context.TODO())
if err != nil {
return err
}
@ -60,13 +70,13 @@ func (r SessionService) ReviewedAll() error {
for i := range sessions {
sessionIds = append(sessionIds, sessions[i].ID)
if i >= 100 && i%100 == 0 {
if err := r.sessionRepository.UpdateReadByIds(true, sessionIds); err != nil {
if err := repository.SessionRepository.UpdateReadByIds(context.TODO(), true, sessionIds); err != nil {
return err
}
sessionIds = nil
} else {
if i == total-1 {
if err := r.sessionRepository.UpdateReadByIds(true, sessionIds); err != nil {
if err := repository.SessionRepository.UpdateReadByIds(context.TODO(), true, sessionIds); err != nil {
return err
}
}
@ -75,3 +85,272 @@ func (r SessionService) ReviewedAll() error {
}
return nil
}
var mutex sync.Mutex
func (service sessionService) CloseSessionById(sessionId string, code int, reason string) {
mutex.Lock()
defer mutex.Unlock()
nextSession := session.GlobalSessionManager.GetById(sessionId)
if nextSession != nil {
log.Debugf("[%v] 会话关闭,原因:%v", sessionId, reason)
service.WriteCloseMessage(nextSession.WebSocket, nextSession.Mode, code, reason)
if nextSession.Observer != nil {
obs := nextSession.Observer.All()
for _, ob := range obs {
service.WriteCloseMessage(ob.WebSocket, ob.Mode, code, reason)
log.Debugf("[%v] 强制踢出会话的观察者: %v", sessionId, ob.ID)
}
}
}
session.GlobalSessionManager.Del <- sessionId
service.DisDBSess(sessionId, code, reason)
}
func (service sessionService) WriteCloseMessage(ws *websocket.Conn, mode string, code int, reason string) {
switch mode {
case constant.Guacd:
if ws != nil {
err := guacd.NewInstruction("error", "", strconv.Itoa(code))
_ = ws.WriteMessage(websocket.TextMessage, []byte(err.String()))
disconnect := guacd.NewInstruction("disconnect")
_ = ws.WriteMessage(websocket.TextMessage, []byte(disconnect.String()))
}
case constant.Naive:
if ws != nil {
msg := `0` + reason
_ = ws.WriteMessage(websocket.TextMessage, []byte(msg))
}
case constant.Terminal:
// 这里是关闭观察者的ssh会话
if ws != nil {
msg := `0` + reason
_ = ws.WriteMessage(websocket.TextMessage, []byte(msg))
}
}
}
func (service sessionService) DisDBSess(sessionId string, code int, reason string) {
_ = env.GetDB().Transaction(func(tx *gorm.DB) error {
c := service.Context(tx)
s, err := repository.SessionRepository.FindById(c, sessionId)
if err != nil {
return err
}
if s.Status == constant.Disconnected {
return err
}
if s.Status == constant.Connecting {
// 会话还未建立成功,无需保留数据
if err := repository.SessionRepository.DeleteById(c, sessionId); err != nil {
return err
}
return nil
}
ss := model.Session{}
ss.ID = sessionId
ss.Status = constant.Disconnected
ss.DisconnectedTime = utils.NowJsonTime()
ss.Code = code
ss.Message = reason
ss.Password = "-"
ss.PrivateKey = "-"
ss.Passphrase = "-"
if err := repository.SessionRepository.UpdateById(c, &ss, sessionId); err != nil {
return err
}
return nil
})
}
func (service sessionService) FindByIdAndDecrypt(c context.Context, id string) (o model.Session, err error) {
sess, err := repository.SessionRepository.FindById(c, id)
if err != nil {
return o, err
}
if err := service.Decrypt(&sess); err != nil {
return o, err
}
return sess, nil
}
func (service sessionService) Decrypt(item *model.Session) error {
if item.Password != "" && item.Password != "-" {
origData, err := base64.StdEncoding.DecodeString(item.Password)
if err != nil {
return err
}
decryptedCBC, err := utils.AesDecryptCBC(origData, config.GlobalCfg.EncryptionPassword)
if err != nil {
return err
}
item.Password = string(decryptedCBC)
}
if item.PrivateKey != "" && item.PrivateKey != "-" {
origData, err := base64.StdEncoding.DecodeString(item.PrivateKey)
if err != nil {
return err
}
decryptedCBC, err := utils.AesDecryptCBC(origData, config.GlobalCfg.EncryptionPassword)
if err != nil {
return err
}
item.PrivateKey = string(decryptedCBC)
}
if item.Passphrase != "" && item.Passphrase != "-" {
origData, err := base64.StdEncoding.DecodeString(item.Passphrase)
if err != nil {
return err
}
decryptedCBC, err := utils.AesDecryptCBC(origData, config.GlobalCfg.EncryptionPassword)
if err != nil {
return err
}
item.Passphrase = string(decryptedCBC)
}
return nil
}
func (service sessionService) Create(clientIp, assetId, mode string, user *model.User) (*model.Session, error) {
asset, err := repository.AssetRepository.FindById(context.TODO(), assetId)
if err != nil {
return nil, err
}
var (
upload = "1"
download = "1"
_delete = "1"
rename = "1"
edit = "1"
fileSystem = "1"
_copy = "1"
paste = "1"
)
if asset.Owner != user.ID && constant.TypeUser == user.Type {
// 普通用户访问非自己创建的资产需要校验权限
resourceSharers, err := repository.ResourceSharerRepository.FindByResourceIdAndUserId(context.TODO(), assetId, user.ID)
if err != nil {
return nil, err
}
if len(resourceSharers) == 0 {
return nil, errors.New("您没有权限访问此资产")
}
strategyId := resourceSharers[0].StrategyId
if strategyId != "" {
strategy, err := repository.StrategyRepository.FindById(context.TODO(), strategyId)
if err != nil {
if !errors.Is(gorm.ErrRecordNotFound, err) {
return nil, err
}
} else {
upload = strategy.Upload
download = strategy.Download
_delete = strategy.Delete
rename = strategy.Rename
edit = strategy.Edit
_copy = strategy.Copy
paste = strategy.Paste
}
}
}
var storageId = ""
if constant.RDP == asset.Protocol {
attr, err := repository.AssetRepository.FindAssetAttrMapByAssetId(context.TODO(), assetId)
if err != nil {
return nil, err
}
if "true" == attr[guacd.EnableDrive] {
fileSystem = "1"
storageId = attr[guacd.DrivePath]
if storageId == "" {
storageId = user.ID
}
} else {
fileSystem = "0"
}
}
if fileSystem != "1" {
fileSystem = "0"
}
if upload != "1" {
upload = "0"
}
if download != "1" {
download = "0"
}
if _delete != "1" {
_delete = "0"
}
if rename != "1" {
rename = "0"
}
if edit != "1" {
edit = "0"
}
if _copy != "1" {
_copy = "0"
}
if paste != "1" {
paste = "0"
}
s := &model.Session{
ID: utils.UUID(),
AssetId: asset.ID,
Username: asset.Username,
Password: asset.Password,
PrivateKey: asset.PrivateKey,
Passphrase: asset.Passphrase,
Protocol: asset.Protocol,
IP: asset.IP,
Port: asset.Port,
Status: constant.NoConnect,
ClientIP: clientIp,
Mode: mode,
FileSystem: fileSystem,
Upload: upload,
Download: download,
Delete: _delete,
Rename: rename,
Edit: edit,
Copy: _copy,
Paste: paste,
StorageId: storageId,
AccessGatewayId: asset.AccessGatewayId,
Reviewed: false,
}
if constant.Anonymous != user.Type {
s.Creator = user.ID
}
if asset.AccountType == "credential" {
credential, err := repository.CredentialRepository.FindById(context.TODO(), asset.CredentialId)
if err != nil {
return nil, err
}
if credential.Type == constant.Custom {
s.Username = credential.Username
s.Password = credential.Password
} else {
s.Username = credential.Username
s.PrivateKey = credential.PrivateKey
s.Passphrase = credential.Passphrase
}
}
if err := repository.SessionRepository.Create(context.TODO(), s); err != nil {
return nil, err
}
return s, nil
}

View File

@ -1,10 +1,15 @@
package service
import (
"bufio"
"context"
"errors"
"io"
"io/ioutil"
"mime/multipart"
"os"
"path"
"strings"
"next-terminal/server/config"
"next-terminal/server/log"
@ -12,37 +17,31 @@ import (
"next-terminal/server/repository"
"next-terminal/server/utils"
"github.com/labstack/echo/v4"
"gorm.io/gorm"
)
type StorageService struct {
storageRepository *repository.StorageRepository
userRepository *repository.UserRepository
propertyRepository *repository.PropertyRepository
type storageService struct {
}
func NewStorageService(storageRepository *repository.StorageRepository, userRepository *repository.UserRepository, propertyRepository *repository.PropertyRepository) *StorageService {
return &StorageService{storageRepository: storageRepository, userRepository: userRepository, propertyRepository: propertyRepository}
}
func (r StorageService) InitStorages() error {
users, err := r.userRepository.FindAll()
func (service storageService) InitStorages() error {
users, err := repository.UserRepository.FindAll(context.TODO())
if err != nil {
return err
}
for i := range users {
userId := users[i].ID
_, err := r.storageRepository.FindByOwnerIdAndDefault(userId, true)
_, err := repository.StorageRepository.FindByOwnerIdAndDefault(context.TODO(), userId, true)
if errors.Is(err, gorm.ErrRecordNotFound) {
err = r.CreateStorageByUser(&users[i])
err = service.CreateStorageByUser(&users[i])
if err != nil {
return err
}
}
}
drivePath := r.GetBaseDrivePath()
storages, err := r.storageRepository.FindAll()
drivePath := service.GetBaseDrivePath()
storages, err := repository.StorageRepository.FindAll(context.TODO())
if err != nil {
return err
}
@ -59,7 +58,7 @@ func (r StorageService) InitStorages() error {
}
if !userExist {
if err := r.DeleteStorageById(storage.ID, true); err != nil {
if err := service.DeleteStorageById(storage.ID, true); err != nil {
return err
}
}
@ -76,8 +75,8 @@ func (r StorageService) InitStorages() error {
return nil
}
func (r StorageService) CreateStorageByUser(user *model.User) error {
drivePath := r.GetBaseDrivePath()
func (service storageService) CreateStorageByUser(user *model.User) error {
drivePath := service.GetBaseDrivePath()
storage := model.Storage{
ID: user.ID,
Name: user.Nickname + "的默认空间",
@ -92,7 +91,7 @@ func (r StorageService) CreateStorageByUser(user *model.User) error {
return err
}
log.Infof("创建storage:「%v」文件夹: %v", storage.Name, storageDir)
err := r.storageRepository.Create(&storage)
err := repository.StorageRepository.Create(context.TODO(), &storage)
if err != nil {
return err
}
@ -109,7 +108,7 @@ type File struct {
Size int64 `json:"size"`
}
func (r StorageService) Ls(drivePath, remoteDir string) ([]File, error) {
func (service storageService) Ls(drivePath, remoteDir string) ([]File, error) {
fileInfos, err := ioutil.ReadDir(path.Join(drivePath, remoteDir))
if err != nil {
return nil, err
@ -132,13 +131,13 @@ func (r StorageService) Ls(drivePath, remoteDir string) ([]File, error) {
return files, nil
}
func (r StorageService) GetBaseDrivePath() string {
func (service storageService) GetBaseDrivePath() string {
return config.GlobalCfg.Guacd.Drive
}
func (r StorageService) DeleteStorageById(id string, force bool) error {
drivePath := r.GetBaseDrivePath()
storage, err := r.storageRepository.FindById(id)
func (service storageService) DeleteStorageById(id string, force bool) error {
drivePath := service.GetBaseDrivePath()
storage, err := repository.StorageRepository.FindById(context.TODO(), id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil
@ -153,7 +152,136 @@ func (r StorageService) DeleteStorageById(id string, force bool) error {
if err := os.RemoveAll(path.Join(drivePath, id)); err != nil {
return err
}
if err := r.storageRepository.DeleteById(id); err != nil {
if err := repository.StorageRepository.DeleteById(context.TODO(), id); err != nil {
return err
}
return nil
}
func (service storageService) StorageUpload(c echo.Context, file *multipart.FileHeader, storageId string) error {
drivePath := service.GetBaseDrivePath()
storage, _ := repository.StorageRepository.FindById(context.TODO(), storageId)
if storage.LimitSize > 0 {
dirSize, err := utils.DirSize(path.Join(drivePath, storageId))
if err != nil {
return err
}
if dirSize+file.Size > storage.LimitSize {
return errors.New("可用空间不足")
}
}
filename := file.Filename
src, err := file.Open()
if err != nil {
return err
}
remoteDir := c.QueryParam("dir")
remoteFile := path.Join(remoteDir, filename)
if strings.Contains(remoteDir, "../") {
return errors.New("非法请求 :(")
}
if strings.Contains(remoteFile, "../") {
return errors.New("非法请求 :(")
}
// 判断文件夹不存在时自动创建
dir := path.Join(path.Join(drivePath, storageId), remoteDir)
if !utils.FileExists(dir) {
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
return err
}
}
// Destination
dst, err := os.Create(path.Join(path.Join(drivePath, storageId), remoteFile))
if err != nil {
return err
}
defer dst.Close()
// Copy
if _, err = io.Copy(dst, src); err != nil {
return err
}
return nil
}
func (service storageService) StorageEdit(file string, fileContent string, storageId string) error {
drivePath := service.GetBaseDrivePath()
if strings.Contains(file, "../") {
return errors.New("非法请求 :(")
}
realFilePath := path.Join(path.Join(drivePath, storageId), file)
dstFile, err := os.OpenFile(realFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
if err != nil {
return err
}
defer dstFile.Close()
write := bufio.NewWriter(dstFile)
if _, err := write.WriteString(fileContent); err != nil {
return err
}
if err := write.Flush(); err != nil {
return err
}
return nil
}
func (service storageService) StorageDownload(c echo.Context, remoteFile, storageId string) error {
drivePath := service.GetBaseDrivePath()
if strings.Contains(remoteFile, "../") {
return errors.New("非法请求 :(")
}
// 获取带后缀的文件名称
filenameWithSuffix := path.Base(remoteFile)
return c.Attachment(path.Join(path.Join(drivePath, storageId), remoteFile), filenameWithSuffix)
}
func (service storageService) StorageLs(remoteDir, storageId string) (error, []File) {
drivePath := service.GetBaseDrivePath()
if strings.Contains(remoteDir, "../") {
return errors.New("非法请求 :("), nil
}
files, err := service.Ls(path.Join(drivePath, storageId), remoteDir)
if err != nil {
return err, nil
}
return nil, files
}
func (service storageService) StorageMkDir(remoteDir, storageId string) error {
drivePath := service.GetBaseDrivePath()
if strings.Contains(remoteDir, "../") {
return errors.New("非法请求 :(")
}
if err := os.MkdirAll(path.Join(path.Join(drivePath, storageId), remoteDir), os.ModePerm); err != nil {
return err
}
return nil
}
func (service storageService) StorageRm(file, storageId string) error {
drivePath := service.GetBaseDrivePath()
if strings.Contains(file, "../") {
return errors.New("非法请求 :(")
}
if err := os.RemoveAll(path.Join(path.Join(drivePath, storageId), file)); err != nil {
return err
}
return nil
}
func (service storageService) StorageRename(oldName, newName, storageId string) error {
drivePath := service.GetBaseDrivePath()
if strings.Contains(oldName, "../") {
return errors.New("非法请求 :(")
}
if strings.Contains(newName, "../") {
return errors.New("非法请求 :(")
}
if err := os.Rename(path.Join(path.Join(drivePath, storageId), oldName), path.Join(path.Join(drivePath, storageId), newName)); err != nil {
return err
}
return nil

View File

@ -1,28 +1,30 @@
package service
import (
"next-terminal/server/global/cache"
"strings"
"errors"
"fmt"
"next-terminal/server/constant"
"next-terminal/server/dto"
"next-terminal/server/env"
"next-terminal/server/global/cache"
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/utils"
"strings"
"golang.org/x/net/context"
"gorm.io/gorm"
)
type UserService struct {
userRepository *repository.UserRepository
loginLogRepository *repository.LoginLogRepository
type userService struct {
baseService
}
func NewUserService(userRepository *repository.UserRepository, loginLogRepository *repository.LoginLogRepository) *UserService {
return &UserService{userRepository: userRepository, loginLogRepository: loginLogRepository}
}
func (service userService) InitUser() (err error) {
func (r UserService) InitUser() (err error) {
users, err := r.userRepository.FindAll()
users, err := repository.UserRepository.FindAll(context.TODO())
if err != nil {
return err
}
@ -43,7 +45,7 @@ func (r UserService) InitUser() (err error) {
Created: utils.NowJsonTime(),
Status: constant.StatusEnabled,
}
if err := r.userRepository.Create(&user); err != nil {
if err := repository.UserRepository.Create(context.TODO(), &user); err != nil {
return err
}
@ -56,7 +58,7 @@ func (r UserService) InitUser() (err error) {
Type: constant.TypeAdmin,
ID: users[i].ID,
}
if err := r.userRepository.Update(&user); err != nil {
if err := repository.UserRepository.Update(context.TODO(), &user); err != nil {
return err
}
log.Infof("自动修正用户「%v」ID「%v」类型为管理员", users[i].Nickname, users[i].ID)
@ -66,20 +68,20 @@ func (r UserService) InitUser() (err error) {
return nil
}
func (r UserService) FixUserOnlineState() error {
func (service userService) FixUserOnlineState() error {
// 修正用户登录状态
onlineUsers, err := r.userRepository.FindOnlineUsers()
onlineUsers, err := repository.UserRepository.FindOnlineUsers(context.TODO())
if err != nil {
return err
}
if len(onlineUsers) > 0 {
for i := range onlineUsers {
logs, err := r.loginLogRepository.FindAliveLoginLogsByUsername(onlineUsers[i].Username)
logs, err := repository.LoginLogRepository.FindAliveLoginLogsByUsername(context.TODO(), onlineUsers[i].Username)
if err != nil {
return err
}
if len(logs) == 0 {
if err := r.userRepository.UpdateOnlineByUsername(onlineUsers[i].Username, false); err != nil {
if err := repository.UserRepository.UpdateOnlineByUsername(context.TODO(), onlineUsers[i].Username, false); err != nil {
return err
}
}
@ -88,96 +90,220 @@ func (r UserService) FixUserOnlineState() error {
return nil
}
func (r UserService) LogoutByToken(token string) (err error) {
loginLog, err := r.loginLogRepository.FindById(token)
if err != nil {
log.Warnf("登录日志「%v」获取失败", token)
return
}
cacheKey := r.BuildCacheKeyByToken(token)
cache.GlobalCache.Delete(cacheKey)
func (service userService) LogoutByToken(token string) (err error) {
return env.GetDB().Transaction(func(tx *gorm.DB) error {
c := service.Context(tx)
loginLog, err := repository.LoginLogRepository.FindById(c, token)
if err != nil {
return err
}
cache.TokenManager.Delete(token)
loginLogForUpdate := &model.LoginLog{LogoutTime: utils.NowJsonTime(), ID: token}
err = r.loginLogRepository.Update(loginLogForUpdate)
if err != nil {
loginLogForUpdate := &model.LoginLog{LogoutTime: utils.NowJsonTime(), ID: token}
err = repository.LoginLogRepository.Update(c, loginLogForUpdate)
if err != nil {
return err
}
loginLogs, err := repository.LoginLogRepository.FindAliveLoginLogsByUsername(c, loginLog.Username)
if err != nil {
return err
}
if len(loginLogs) == 0 {
err = repository.UserRepository.UpdateOnlineByUsername(c, loginLog.Username, false)
}
return err
}
loginLogs, err := r.loginLogRepository.FindAliveLoginLogsByUsername(loginLog.Username)
if err != nil {
return
}
if len(loginLogs) == 0 {
err = r.userRepository.UpdateOnlineByUsername(loginLog.Username, false)
}
return
})
}
func (r UserService) LogoutById(id string) error {
user, err := r.userRepository.FindById(id)
func (service userService) LogoutById(c context.Context, id string) error {
user, err := repository.UserRepository.FindById(c, id)
if err != nil {
return err
}
username := user.Username
loginLogs, err := r.loginLogRepository.FindAliveLoginLogsByUsername(username)
loginLogs, err := repository.LoginLogRepository.FindAliveLoginLogsByUsername(c, username)
if err != nil {
return err
}
for j := range loginLogs {
token := loginLogs[j].ID
if err := r.LogoutByToken(token); err != nil {
if err := service.LogoutByToken(token); err != nil {
return err
}
}
return nil
}
func (r UserService) BuildCacheKeyByToken(token string) string {
cacheKey := strings.Join([]string{constant.Token, token}, ":")
return cacheKey
}
func (service userService) OnEvicted(token string, value interface{}) {
func (r UserService) GetTokenFormCacheKey(cacheKey string) string {
token := strings.Split(cacheKey, ":")[1]
return token
}
func (r UserService) OnEvicted(key string, value interface{}) {
if strings.HasPrefix(key, constant.Token) {
token := r.GetTokenFormCacheKey(key)
if strings.HasPrefix(token, "forever") {
log.Debugf("re gen forever token")
} else {
log.Debugf("用户Token「%v」过期", token)
err := r.LogoutByToken(token)
err := service.LogoutByToken(token)
if err != nil {
log.Errorf("退出登录失败 %v", err)
}
}
}
func (r UserService) UpdateStatusById(id string, status string) error {
if constant.StatusDisabled == status {
// 将该用户下线
if err := r.LogoutById(id); err != nil {
return err
func (service userService) UpdateStatusById(id string, status string) error {
return env.GetDB().Transaction(func(tx *gorm.DB) error {
c := service.Context(tx)
if c.Value(constant.DB) == nil {
c = context.WithValue(c, constant.DB, env.GetDB())
}
}
u := model.User{
ID: id,
Status: status,
}
return r.userRepository.Update(&u)
if constant.StatusDisabled == status {
// 将该用户下线
if err := service.LogoutById(c, id); err != nil {
return err
}
}
u := model.User{
ID: id,
Status: status,
}
return repository.UserRepository.Update(c, &u)
})
}
func (r UserService) DeleteLoginLogs(tokens []string) error {
for i := range tokens {
token := tokens[i]
if err := r.LogoutByToken(token); err != nil {
func (service userService) ReloadToken() error {
loginLogs, err := repository.LoginLogRepository.FindAliveLoginLogs(context.TODO())
if err != nil {
return err
}
for i := range loginLogs {
loginLog := loginLogs[i]
token := loginLog.ID
user, err := repository.UserRepository.FindByUsername(context.TODO(), loginLog.Username)
if err != nil {
if errors.Is(gorm.ErrRecordNotFound, err) {
_ = repository.LoginLogRepository.DeleteById(context.TODO(), token)
}
continue
}
authorization := dto.Authorization{
Token: token,
Type: constant.LoginToken,
Remember: loginLog.Remember,
User: &user,
}
if authorization.Remember {
// 记住登录有效期两周
cache.TokenManager.Set(token, authorization, cache.RememberMeExpiration)
} else {
cache.TokenManager.Set(token, authorization, cache.NotRememberExpiration)
}
log.Debugf("重新加载用户「%v」授权Token「%v」到缓存", user.Nickname, token)
}
return nil
}
func (service userService) CreateUser(user model.User) (err error) {
return env.GetDB().Transaction(func(tx *gorm.DB) error {
c := service.Context(tx)
if repository.UserRepository.ExistByUsername(c, user.Username) {
return fmt.Errorf("username %s is already used", user.Username)
}
password := user.Password
var pass []byte
if pass, err = utils.Encoder.Encode([]byte(password)); err != nil {
return err
}
if err := r.loginLogRepository.DeleteById(token); err != nil {
user.Password = string(pass)
user.ID = utils.UUID()
user.Created = utils.NowJsonTime()
user.Status = constant.StatusEnabled
if err := repository.UserRepository.Create(c, &user); err != nil {
return err
}
err = StorageService.CreateStorageByUser(&user)
if err != nil {
return err
}
if user.Mail != "" {
go MailService.SendMail(user.Mail, "[Next Terminal] 注册通知", "你好,"+user.Nickname+"。管理员为你注册了账号:"+user.Username+" 密码:"+password)
}
return nil
})
}
func (service userService) DeleteUserById(userId string) error {
return env.GetDB().Transaction(func(tx *gorm.DB) error {
c := service.Context(tx)
// 下线该用户
if err := service.LogoutById(c, userId); err != nil {
return err
}
// 删除用户
if err := repository.UserRepository.DeleteById(c, userId); err != nil {
return err
}
// 删除用户与用户组的关系
if err := repository.UserGroupMemberRepository.DeleteByUserId(c, userId); err != nil {
return err
}
// 删除用户与资产的关系
if err := repository.ResourceSharerRepository.DeleteByUserId(c, userId); err != nil {
return err
}
// 删除用户的默认磁盘空间
if err := StorageService.DeleteStorageById(userId, true); err != nil {
return err
}
return nil
})
}
func (service userService) DeleteLoginLogs(tokens []string) error {
if len(tokens) > 0 {
for _, token := range tokens {
if err := service.LogoutByToken(token); err != nil {
return err
}
if err := repository.LoginLogRepository.DeleteById(context.TODO(), token); err != nil {
return err
}
}
}
return nil
}
func (service userService) SaveLoginLog(clientIP, clientUserAgent string, username string, success, remember bool, id, reason string) error {
loginLog := model.LoginLog{
Username: username,
ClientIP: clientIP,
ClientUserAgent: clientUserAgent,
LoginTime: utils.NowJsonTime(),
Reason: reason,
Remember: remember,
}
if success {
loginLog.State = "1"
loginLog.ID = id
} else {
loginLog.State = "0"
loginLog.ID = utils.LongUUID()
}
if err := repository.LoginLogRepository.Create(context.TODO(), &loginLog); err != nil {
return err
}
return nil
}
func (service userService) DeleteALlLdapUser(ctx context.Context) error {
return repository.UserRepository.DeleteBySource(ctx, constant.SourceLdap)
}

View File

@ -0,0 +1,115 @@
package service
import (
"context"
"errors"
"next-terminal/server/constant"
"next-terminal/server/env"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/utils"
"gorm.io/gorm"
)
type userGroupService struct {
}
func (service userGroupService) DeleteById(userGroupId string) error {
return env.GetDB().Transaction(func(tx *gorm.DB) error {
c := context.WithValue(context.TODO(), constant.DB, tx)
// 删除用户组
if err := repository.UserGroupRepository.DeleteById(c, userGroupId); err != nil {
return err
}
// 删除用户组与用户的关系
if err := repository.UserGroupMemberRepository.DeleteByUserGroupId(c, userGroupId); err != nil {
return err
}
// 删除用户组与资产的关系
if err := repository.ResourceSharerRepository.DeleteByUserGroupId(c, userGroupId); err != nil {
return err
}
return nil
})
}
func (service userGroupService) Create(name string, members []string) (model.UserGroup, error) {
var err error
_, err = repository.UserGroupRepository.FindByName(context.TODO(), name)
if err == nil {
return model.UserGroup{}, constant.ErrNameAlreadyUsed
}
if !errors.Is(gorm.ErrRecordNotFound, err) {
return model.UserGroup{}, err
}
userGroupId := utils.UUID()
userGroup := model.UserGroup{
ID: userGroupId,
Created: utils.NowJsonTime(),
Name: name,
}
return userGroup, env.GetDB().Transaction(func(tx *gorm.DB) error {
c := context.WithValue(context.TODO(), constant.DB, tx)
if err := repository.UserGroupRepository.Create(c, &userGroup); err != nil {
return err
}
if len(members) > 0 {
for _, member := range members {
userGroupMember := model.UserGroupMember{
ID: utils.Sign([]string{userGroupId, member}),
UserId: member,
UserGroupId: userGroupId,
}
if err := repository.UserGroupMemberRepository.Create(c, &userGroupMember); err != nil {
return err
}
}
}
return nil
})
}
func (service userGroupService) Update(userGroupId string, name string, members []string) (err error) {
var userGroup model.UserGroup
userGroup, err = repository.UserGroupRepository.FindByName(context.TODO(), name)
if err == nil && userGroup.ID != userGroupId {
return constant.ErrNameAlreadyUsed
}
if !errors.Is(gorm.ErrRecordNotFound, err) {
return err
}
return env.GetDB().Transaction(func(tx *gorm.DB) error {
c := context.WithValue(context.TODO(), constant.DB, tx)
userGroup := model.UserGroup{
ID: userGroupId,
Name: name,
}
if err := repository.UserGroupRepository.Update(c, &userGroup); err != nil {
return err
}
if err := repository.UserGroupMemberRepository.DeleteByUserGroupId(c, userGroupId); err != nil {
return err
}
if len(members) > 0 {
for _, member := range members {
userGroupMember := model.UserGroupMember{
ID: utils.Sign([]string{userGroupId, member}),
UserId: member,
UserGroupId: userGroupId,
}
if err := repository.UserGroupMemberRepository.Create(c, &userGroupMember); err != nil {
return err
}
}
}
return nil
})
}

17
server/service/var.go Normal file
View File

@ -0,0 +1,17 @@
package service
var (
AssetService = new(assetService)
BackupService = new(backupService)
CredentialService = new(credentialService)
GatewayService = new(gatewayService)
JobService = new(jobService)
MailService = new(mailService)
PropertyService = new(propertyService)
SecurityService = new(securityService)
SessionService = new(sessionService)
StorageService = new(storageService)
UserService = new(userService)
UserGroupService = new(userGroupService)
AccessTokenService = new(accessTokenService)
)