完成重构数据库操作代码

This commit is contained in:
dushixiang
2021-03-18 23:36:25 +08:00
parent 0150361054
commit 25b8381a4f
44 changed files with 2292 additions and 2016 deletions

View File

@ -0,0 +1,81 @@
package repository
import (
"gorm.io/gorm"
"next-terminal/server/global"
"next-terminal/server/model"
)
type AccessSecurityRepository struct {
DB *gorm.DB
}
func NewAccessSecurityRepository(db *gorm.DB) *AccessSecurityRepository {
accessSecurityRepository = &AccessSecurityRepository{DB: db}
return accessSecurityRepository
}
func (r AccessSecurityRepository) FindAllAccessSecurities() (o []model.AccessSecurity, err error) {
db := r.DB
err = db.Order("priority asc").Find(&o).Error
return
}
func (r AccessSecurityRepository) Find(pageIndex, pageSize int, ip, rule, order, field string) (o []model.AccessSecurity, total int64, err error) {
t := model.AccessSecurity{}
db := r.DB.Table(t.TableName())
dbCounter := r.DB.Table(t.TableName())
if len(ip) > 0 {
db = db.Where("ip like ?", "%"+ip+"%")
dbCounter = dbCounter.Where("ip like ?", "%"+ip+"%")
}
if len(rule) > 0 {
db = db.Where("rule = ?", rule)
dbCounter = dbCounter.Where("rule = ?", rule)
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "descend" {
order = "desc"
} else {
order = "asc"
}
if field == "ip" {
field = "ip"
} else if field == "rule" {
field = "rule"
} else {
field = "priority"
}
err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error
if o == nil {
o = make([]model.AccessSecurity, 0)
}
return
}
func (r AccessSecurityRepository) Create(o *model.AccessSecurity) error {
return global.DB.Create(o).Error
}
func (r AccessSecurityRepository) UpdateById(o *model.AccessSecurity, id string) error {
o.ID = id
return global.DB.Updates(o).Error
}
func (r AccessSecurityRepository) DeleteById(id string) error {
return global.DB.Where("id = ?", id).Delete(model.AccessSecurity{}).Error
}
func (r AccessSecurityRepository) FindById(id string) (o *model.AccessSecurity, err error) {
err = global.DB.Where("id = ?", id).First(&o).Error
return
}

302
server/repository/asset.go Normal file
View File

@ -0,0 +1,302 @@
package repository
import (
"fmt"
"github.com/labstack/echo/v4"
"gorm.io/gorm"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/model"
"next-terminal/server/utils"
"strings"
)
type AssetRepository struct {
DB *gorm.DB
}
func NewAssetRepository(db *gorm.DB) *AssetRepository {
assetRepository = &AssetRepository{DB: db}
return assetRepository
}
func (r AssetRepository) FindAll() (o []model.Asset, err error) {
err = r.DB.Find(&o).Error
return
}
func (r AssetRepository) FindByIds(assetIds []string) (o []model.Asset, err error) {
err = r.DB.Where("id in ?", assetIds).Find(&o).Error
return
}
func (r AssetRepository) FindByProtocol(protocol string) (o []model.Asset, err error) {
err = r.DB.Where("protocol = ?", protocol).Find(&o).Error
return
}
func (r AssetRepository) FindByProtocolAndIds(protocol string, assetIds []string) (o []model.Asset, err error) {
err = r.DB.Where("protocol = ? and id in ?", protocol, assetIds).Find(&o).Error
return
}
func (r AssetRepository) FindByProtocolAndUser(protocol string, account model.User) (o []model.Asset, err error) {
db := r.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id")
if constant.TypeUser == account.Type {
owner := account.ID
db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner)
}
if len(protocol) > 0 {
db = db.Where("assets.protocol = ?", protocol)
}
err = db.Find(&o).Error
return
}
func (r AssetRepository) Find(pageIndex, pageSize int, name, protocol, tags string, account model.User, owner, sharer, userGroupId, ip, order, field string) (o []model.AssetVo, total int64, err error) {
db := r.DB.Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id")
dbCounter := r.DB.Table("assets").Select("DISTINCT assets.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id")
if constant.TypeUser == account.Type {
owner := account.ID
db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner)
dbCounter = dbCounter.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner)
// 查询用户所在用户组列表
userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(account.ID)
if err != nil {
return nil, 0, err
}
if len(userGroupIds) > 0 {
db = db.Or("resource_sharers.user_group_id in ?", userGroupIds)
dbCounter = dbCounter.Or("resource_sharers.user_group_id in ?", userGroupIds)
}
} else {
if len(owner) > 0 {
db = db.Where("assets.owner = ?", owner)
dbCounter = dbCounter.Where("assets.owner = ?", owner)
}
if len(sharer) > 0 {
db = db.Where("resource_sharers.user_id = ?", sharer)
dbCounter = dbCounter.Where("resource_sharers.user_id = ?", sharer)
}
if len(userGroupId) > 0 {
db = db.Where("resource_sharers.user_group_id = ?", userGroupId)
dbCounter = dbCounter.Where("resource_sharers.user_group_id = ?", userGroupId)
}
}
if len(name) > 0 {
db = db.Where("assets.name like ?", "%"+name+"%")
dbCounter = dbCounter.Where("assets.name like ?", "%"+name+"%")
}
if len(ip) > 0 {
db = db.Where("assets.ip like ?", "%"+ip+"%")
dbCounter = dbCounter.Where("assets.ip like ?", "%"+ip+"%")
}
if len(protocol) > 0 {
db = db.Where("assets.protocol = ?", protocol)
dbCounter = dbCounter.Where("assets.protocol = ?", protocol)
}
if len(tags) > 0 {
tagArr := strings.Split(tags, ",")
for i := range tagArr {
if global.Config.DB == "sqlite" {
db = db.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%")
dbCounter = dbCounter.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%")
} else {
db = db.Where("find_in_set(?, assets.tags)", tagArr[i])
dbCounter = dbCounter.Where("find_in_set(?, assets.tags)", tagArr[i])
}
}
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "name" {
field = "name"
} else {
field = "created"
}
err = db.Order("assets." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error
if o == nil {
o = make([]model.AssetVo, 0)
}
return
}
func (r AssetRepository) Create(o *model.Asset) (err error) {
if err = r.DB.Create(o).Error; err != nil {
return err
}
return nil
}
func (r AssetRepository) FindById(id string) (o model.Asset, err error) {
err = r.DB.Where("id = ?", id).First(&o).Error
return
}
func (r AssetRepository) UpdateById(o *model.Asset, id string) error {
o.ID = id
return r.DB.Updates(o).Error
}
func (r AssetRepository) UpdateActiveById(active bool, id string) error {
sql := "update assets set active = ? where id = ?"
return r.DB.Exec(sql, active, id).Error
}
func (r AssetRepository) DeleteById(id string) error {
return r.DB.Where("id = ?", id).Delete(&model.Asset{}).Error
}
func (r AssetRepository) CountAsset() (total int64, err error) {
err = r.DB.Find(&model.Asset{}).Count(&total).Error
return
}
func (r AssetRepository) CountByUserId(userId string) (total int64, err error) {
db := r.DB.Joins("left join resource_sharers on assets.id = resource_sharers.resource_id")
db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", userId, userId)
// 查询用户所在用户组列表
userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId)
if err != nil {
return 0, err
}
if len(userGroupIds) > 0 {
db = db.Or("resource_sharers.user_group_id in ?", userGroupIds)
}
err = db.Find(&model.Asset{}).Count(&total).Error
return
}
func (r AssetRepository) FindTags() (o []string, err error) {
var assets []model.Asset
err = r.DB.Not("tags = ?", "").Find(&assets).Error
if err != nil {
return nil, err
}
o = make([]string, 0)
for i := range assets {
if len(assets[i].Tags) == 0 {
continue
}
split := strings.Split(assets[i].Tags, ",")
o = append(o, split...)
}
return utils.Distinct(o), nil
}
func (r AssetRepository) UpdateAttributes(assetId, protocol string, m echo.Map) error {
var data []model.AssetAttribute
var parameterNames []string
switch protocol {
case "ssh":
parameterNames = constant.SSHParameterNames
case "rdp":
parameterNames = constant.RDPParameterNames
case "vnc":
parameterNames = constant.VNCParameterNames
case "telnet":
parameterNames = constant.TelnetParameterNames
case "kubernetes":
parameterNames = constant.KubernetesParameterNames
}
for i := range parameterNames {
name := parameterNames[i]
if m[name] != nil && m[name] != "" {
data = append(data, genAttribute(assetId, name, m))
}
}
return r.DB.Transaction(func(tx *gorm.DB) error {
err := tx.Where("asset_id = ?", assetId).Delete(&model.AssetAttribute{}).Error
if err != nil {
return err
}
return tx.CreateInBatches(&data, len(data)).Error
})
}
func genAttribute(assetId, name string, m echo.Map) model.AssetAttribute {
value := fmt.Sprintf("%v", m[name])
attribute := model.AssetAttribute{
Id: utils.Sign([]string{assetId, name}),
AssetId: assetId,
Name: name,
Value: value,
}
return attribute
}
func (r AssetRepository) FindAttrById(assetId string) (o []model.AssetAttribute, err error) {
err = r.DB.Where("asset_id = ?", assetId).Find(&o).Error
if o == nil {
o = make([]model.AssetAttribute, 0)
}
return o, err
}
func (r AssetRepository) FindAssetAttrMapByAssetId(assetId string) (map[string]interface{}, error) {
asset, err := r.FindById(assetId)
if err != nil {
return nil, err
}
attributes, err := r.FindAttrById(assetId)
if err != nil {
return nil, err
}
var parameterNames []string
switch asset.Protocol {
case "ssh":
parameterNames = constant.SSHParameterNames
case "rdp":
parameterNames = constant.RDPParameterNames
case "vnc":
parameterNames = constant.VNCParameterNames
case "telnet":
parameterNames = constant.TelnetParameterNames
case "kubernetes":
parameterNames = constant.KubernetesParameterNames
}
propertiesMap := propertyRepository.FindAllMap()
var attributeMap = make(map[string]interface{})
for name := range propertiesMap {
if utils.Contains(parameterNames, name) {
attributeMap[name] = propertiesMap[name]
}
}
for i := range attributes {
attributeMap[attributes[i].Name] = attributes[i].Value
}
return attributeMap, nil
}

View File

@ -0,0 +1,82 @@
package repository
import (
"gorm.io/gorm"
"next-terminal/server/constant"
"next-terminal/server/global"
"next-terminal/server/model"
)
type CommandRepository struct {
DB *gorm.DB
}
func NewCommandRepository(db *gorm.DB) *CommandRepository {
commandRepository = &CommandRepository{DB: db}
return commandRepository
}
func (r CommandRepository) Find(pageIndex, pageSize int, name, content, order, field string, account model.User) (o []model.CommandVo, total int64, err error) {
db := r.DB.Table("commands").Select("commands.id,commands.name,commands.content,commands.owner,commands.created, users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on commands.owner = users.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id")
dbCounter := r.DB.Table("commands").Select("DISTINCT commands.id").Joins("left join resource_sharers on commands.id = resource_sharers.resource_id").Group("commands.id")
if constant.TypeUser == account.Type {
owner := account.ID
db = db.Where("commands.owner = ? or resource_sharers.user_id = ?", owner, owner)
dbCounter = dbCounter.Where("commands.owner = ? or resource_sharers.user_id = ?", owner, owner)
}
if len(name) > 0 {
db = db.Where("commands.name like ?", "%"+name+"%")
dbCounter = dbCounter.Where("commands.name like ?", "%"+name+"%")
}
if len(content) > 0 {
db = db.Where("commands.content like ?", "%"+content+"%")
dbCounter = dbCounter.Where("commands.content like ?", "%"+content+"%")
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "name" {
field = "name"
} else {
field = "created"
}
err = db.Order("commands." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error
if o == nil {
o = make([]model.CommandVo, 0)
}
return
}
func (r CommandRepository) Create(o *model.Command) (err error) {
if err = r.DB.Create(o).Error; err != nil {
return err
}
return nil
}
func (r CommandRepository) FindById(id string) (o model.Command, err error) {
err = global.DB.Where("id = ?", id).First(&o).Error
return
}
func (r CommandRepository) UpdateById(o *model.Command, id string) {
o.ID = id
global.DB.Updates(o)
}
func (r CommandRepository) DeleteById(id string) error {
return r.DB.Where("id = ?", id).Delete(&model.Command{}).Error
}

View File

@ -0,0 +1,108 @@
package repository
import (
"gorm.io/gorm"
"next-terminal/server/constant"
"next-terminal/server/model"
)
type CredentialRepository struct {
DB *gorm.DB
}
func NewCredentialRepository(db *gorm.DB) *CredentialRepository {
credentialRepository = &CredentialRepository{DB: db}
return credentialRepository
}
func (r CredentialRepository) FindAllByUser(account model.User) (o []model.CredentialSimpleVo, err error) {
db := r.DB.Table("credentials").Select("DISTINCT credentials.id,credentials.name").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id")
if account.Type == constant.TypeUser {
db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", account.ID, account.ID)
}
err = db.Find(&o).Error
return
}
func (r CredentialRepository) Find(pageIndex, pageSize int, name, order, field string, account model.User) (o []model.CredentialVo, total int64, err error) {
db := r.DB.Table("credentials").Select("credentials.id,credentials.name,credentials.type,credentials.username,credentials.owner,credentials.created,users.nickname as owner_name,COUNT(resource_sharers.user_id) as sharer_count").Joins("left join users on credentials.owner = users.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id")
dbCounter := r.DB.Table("credentials").Select("DISTINCT credentials.id").Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id").Group("credentials.id")
if constant.TypeUser == account.Type {
owner := account.ID
db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", owner, owner)
dbCounter = dbCounter.Where("credentials.owner = ? or resource_sharers.user_id = ?", owner, owner)
}
if len(name) > 0 {
db = db.Where("credentials.name like ?", "%"+name+"%")
dbCounter = dbCounter.Where("credentials.name like ?", "%"+name+"%")
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "name" {
field = "name"
} else {
field = "created"
}
err = db.Order("credentials." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error
if o == nil {
o = make([]model.CredentialVo, 0)
}
return
}
func (r CredentialRepository) Create(o *model.Credential) (err error) {
if err = r.DB.Create(o).Error; err != nil {
return err
}
return nil
}
func (r CredentialRepository) FindById(id string) (o model.Credential, err error) {
err = r.DB.Where("id = ?", id).First(&o).Error
return
}
func (r CredentialRepository) UpdateById(o *model.Credential, id string) error {
o.ID = id
return r.DB.Updates(o).Error
}
func (r CredentialRepository) DeleteById(id string) error {
return r.DB.Where("id = ?", id).Delete(&model.Credential{}).Error
}
func (r CredentialRepository) Count() (total int64, err error) {
err = r.DB.Find(&model.Credential{}).Count(&total).Error
return
}
func (r CredentialRepository) CountByUserId(userId string) (total int64, err error) {
db := r.DB.Joins("left join resource_sharers on credentials.id = resource_sharers.resource_id")
db = db.Where("credentials.owner = ? or resource_sharers.user_id = ?", userId, userId)
// 查询用户所在用户组列表
userGroupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId)
if err != nil {
return 0, err
}
if len(userGroupIds) > 0 {
db = db.Or("resource_sharers.user_group_id in ?", userGroupIds)
}
err = db.Find(&model.Credential{}).Count(&total).Error
return
}

View File

@ -0,0 +1,17 @@
package repository
var (
userRepository *UserRepository
userGroupRepository *UserGroupRepository
resourceSharerRepository *ResourceSharerRepository
assetRepository *AssetRepository
credentialRepository *CredentialRepository
propertyRepository *PropertyRepository
commandRepository *CommandRepository
sessionRepository *SessionRepository
numRepository *NumRepository
accessSecurityRepository *AccessSecurityRepository
jobRepository *JobRepository
jobLogRepository *JobLogRepository
loginLogRepository *LoginLogRepository
)

108
server/repository/job.go Normal file
View File

@ -0,0 +1,108 @@
package repository
import (
"gorm.io/gorm"
"next-terminal/server/global"
"next-terminal/server/model"
"next-terminal/server/utils"
)
type JobRepository struct {
DB *gorm.DB
}
func NewJobRepository(db *gorm.DB) *JobRepository {
jobRepository = &JobRepository{DB: db}
return jobRepository
}
func (r JobRepository) Find(pageIndex, pageSize int, name, status, order, field string) (o []model.Job, total int64, err error) {
job := model.Job{}
db := r.DB.Table(job.TableName())
dbCounter := r.DB.Table(job.TableName())
if len(name) > 0 {
db = db.Where("name like ?", "%"+name+"%")
dbCounter = dbCounter.Where("name like ?", "%"+name+"%")
}
if len(status) > 0 {
db = db.Where("status = ?", status)
dbCounter = dbCounter.Where("status = ?", status)
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "name" {
field = "name"
} else if field == "created" {
field = "created"
} else {
field = "updated"
}
err = db.Order(field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error
if o == nil {
o = make([]model.Job, 0)
}
return
}
func (r JobRepository) FindByFunc(function string) (o []model.Job, err error) {
db := r.DB
err = db.Where("func = ?", function).Find(&o).Error
return
}
func (r JobRepository) Create(o *model.Job) (err error) {
//
//if o.Status == constant.JobStatusRunning {
// j, err := getJob(o)
// if err != nil {
// return err
// }
// jobId, err := global.Cron.AddJob(o.Cron, j)
// if err != nil {
// return err
// }
// o.CronJobId = int(jobId)
//}
return r.DB.Create(o).Error
}
func (r JobRepository) UpdateById(o *model.Job) (err error) {
return r.DB.Updates(o).Error
}
func (r JobRepository) UpdateLastUpdatedById(id string) (err error) {
err = r.DB.Updates(model.Job{ID: id, Updated: utils.NowJsonTime()}).Error
return
}
func (r JobRepository) FindById(id string) (o model.Job, err error) {
err = r.DB.Where("id = ?", id).First(&o).Error
return
}
func (r JobRepository) DeleteJobById(id string) error {
//job, err := r.FindById(id)
//if err != nil {
// return err
//}
//if job.Status == constant.JobStatusRunning {
// if err := r.ChangeStatusById(id, constant.JobStatusNotRunning); err != nil {
// return err
// }
//}
return global.DB.Where("id = ?", id).Delete(model.Job{}).Error
}

View File

@ -0,0 +1,28 @@
package repository
import (
"gorm.io/gorm"
"next-terminal/server/model"
)
type JobLogRepository struct {
DB *gorm.DB
}
func NewJobLogRepository(db *gorm.DB) *JobLogRepository {
jobLogRepository = &JobLogRepository{DB: db}
return jobLogRepository
}
func (r JobLogRepository) Create(o *model.JobLog) error {
return r.DB.Create(o).Error
}
func (r JobLogRepository) FindByJobId(jobId string) (o []model.JobLog, err error) {
err = r.DB.Where("job_id = ?", jobId).Order("timestamp asc").Find(&o).Error
return
}
func (r JobLogRepository) DeleteByJobId(jobId string) error {
return r.DB.Where("job_id = ?", jobId).Delete(model.JobLog{}).Error
}

View File

@ -0,0 +1,69 @@
package repository
import (
"gorm.io/gorm"
"next-terminal/server/model"
)
type LoginLogRepository struct {
DB *gorm.DB
}
func NewLoginLogRepository(db *gorm.DB) *LoginLogRepository {
loginLogRepository = &LoginLogRepository{DB: db}
return loginLogRepository
}
func (r LoginLogRepository) Find(pageIndex, pageSize int, userId, clientIp string) (o []model.LoginLogVo, total int64, err error) {
db := r.DB.Table("login_logs").Select("login_logs.id,login_logs.user_id,login_logs.client_ip,login_logs.client_user_agent,login_logs.login_time, login_logs.logout_time, users.nickname as user_name").Joins("left join users on login_logs.user_id = users.id")
dbCounter := r.DB.Table("login_logs").Select("DISTINCT login_logs.id")
if userId != "" {
db = db.Where("login_logs.user_id = ?", userId)
dbCounter = dbCounter.Where("login_logs.user_id = ?", userId)
}
if clientIp != "" {
db = db.Where("login_logs.client_ip like ?", "%"+clientIp+"%")
dbCounter = dbCounter.Where("login_logs.client_ip like ?", "%"+clientIp+"%")
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
err = db.Order("login_logs.login_time desc").Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error
if o == nil {
o = make([]model.LoginLogVo, 0)
}
return
}
func (r LoginLogRepository) FindAliveLoginLogs() (o []model.LoginLog, err error) {
err = r.DB.Where("logout_time is null").Find(&o).Error
return
}
func (r LoginLogRepository) FindAliveLoginLogsByUserId(userId string) (o []model.LoginLog, err error) {
err = r.DB.Where("logout_time is null and user_id = ?", userId).Find(&o).Error
return
}
func (r LoginLogRepository) Create(o *model.LoginLog) (err error) {
return r.DB.Create(o).Error
}
func (r LoginLogRepository) DeleteByIdIn(ids []string) (err error) {
return r.DB.Where("id in ?", ids).Delete(&model.LoginLog{}).Error
}
func (r LoginLogRepository) FindById(id string) (o model.LoginLog, err error) {
err = r.DB.Where("id = ?", id).First(&o).Error
return
}
func (r LoginLogRepository) Update(o *model.LoginLog) error {
return r.DB.Updates(o).Error
}

26
server/repository/num.go Normal file
View File

@ -0,0 +1,26 @@
package repository
import (
"gorm.io/gorm"
"next-terminal/server/global"
"next-terminal/server/model"
)
type NumRepository struct {
DB *gorm.DB
}
func NewNumRepository(db *gorm.DB) *NumRepository {
numRepository = &NumRepository{DB: db}
return numRepository
}
func (r NumRepository) FindAll() (o []model.Num, err error) {
err = r.DB.Find(&o).Error
return
}
func (r NumRepository) Create(o *model.Num) (err error) {
err = global.DB.Create(o).Error
return
}

View File

@ -0,0 +1,90 @@
package repository
import (
"github.com/jordan-wright/email"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
"net/smtp"
"next-terminal/server/constant"
"next-terminal/server/guacd"
"next-terminal/server/model"
)
type PropertyRepository struct {
DB *gorm.DB
}
func NewPropertyRepository(db *gorm.DB) *PropertyRepository {
propertyRepository = &PropertyRepository{DB: db}
return propertyRepository
}
func (r PropertyRepository) FindAll() (o []model.Property) {
if r.DB.Find(&o).Error != nil {
return nil
}
return
}
func (r PropertyRepository) Create(o *model.Property) (err error) {
err = r.DB.Create(o).Error
return
}
func (r PropertyRepository) UpdatePropertyByName(o *model.Property, name string) error {
o.Name = name
return r.DB.Updates(o).Error
}
func (r PropertyRepository) FindByName(name string) (o model.Property, err error) {
err = r.DB.Where("name = ?", name).First(&o).Error
return
}
func (r PropertyRepository) FindAllMap() map[string]string {
properties := r.FindAll()
propertyMap := make(map[string]string)
for i := range properties {
propertyMap[properties[i].Name] = properties[i].Value
}
return propertyMap
}
func (r PropertyRepository) GetDrivePath() (string, error) {
property, err := r.FindByName(guacd.DrivePath)
if err != nil {
return "", err
}
return property.Value, nil
}
func (r PropertyRepository) GetRecordingPath() (string, error) {
property, err := r.FindByName(guacd.RecordingPath)
if err != nil {
return "", err
}
return property.Value, nil
}
func (r PropertyRepository) SendMail(to, subject, text string) {
propertiesMap := r.FindAllMap()
host := propertiesMap[constant.MailHost]
port := propertiesMap[constant.MailPort]
username := propertiesMap[constant.MailUsername]
password := propertiesMap[constant.MailPassword]
if host == "" || port == "" || username == "" || password == "" {
logrus.Debugf("邮箱信息不完整,跳过发送邮件。")
return
}
e := email.NewEmail()
e.From = "Next Terminal <" + username + ">"
e.To = []string{to}
e.Subject = subject
e.Text = []byte(text)
err := e.Send(host+":"+port, smtp.PlainAuth("", username, password, host))
if err != nil {
logrus.Errorf("邮件发送失败: %v", err.Error())
}
}

View File

@ -0,0 +1,193 @@
package repository
import (
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"gorm.io/gorm"
"next-terminal/server/model"
"next-terminal/server/utils"
)
type ResourceSharerRepository struct {
DB *gorm.DB
}
func NewResourceSharerRepository(db *gorm.DB) *ResourceSharerRepository {
resourceSharerRepository = &ResourceSharerRepository{DB: db}
return resourceSharerRepository
}
func (r *ResourceSharerRepository) FindUserIdsByResourceId(resourceId string) (o []string, err error) {
err = r.DB.Table("resource_sharers").Select("user_id").Where("resource_id = ?", resourceId).Find(&o).Error
if o == nil {
o = make([]string, 0)
}
return
}
func (r *ResourceSharerRepository) OverwriteUserIdsByResourceId(resourceId, resourceType string, userIds []string) (err error) {
db := r.DB.Begin()
var owner string
// 检查资产是否存在
switch resourceType {
case "asset":
resource := model.Asset{}
err = db.Where("id = ?", resourceId).First(&resource).Error
owner = resource.Owner
case "command":
resource := model.Command{}
err = db.Where("id = ?", resourceId).First(&resource).Error
owner = resource.Owner
case "credential":
resource := model.Credential{}
err = db.Where("id = ?", resourceId).First(&resource).Error
owner = resource.Owner
}
if err == gorm.ErrRecordNotFound {
return echo.NewHTTPError(404, "资源「"+resourceId+"」不存在")
}
for i := range userIds {
if owner == userIds[i] {
return echo.NewHTTPError(400, "参数错误")
}
}
db.Where("resource_id = ?", resourceId).Delete(&ResourceSharerRepository{})
for i := range userIds {
userId := userIds[i]
if len(userId) == 0 {
continue
}
id := utils.Sign([]string{resourceId, resourceType, userId})
resource := &model.ResourceSharer{
ID: id,
ResourceId: resourceId,
ResourceType: resourceType,
UserId: userId,
}
err = db.Create(resource).Error
if err != nil {
return err
}
}
db.Commit()
return nil
}
func (r *ResourceSharerRepository) DeleteByUserIdAndResourceTypeAndResourceIdIn(userGroupId, userId, resourceType string, resourceIds []string) error {
db := r.DB
if userGroupId != "" {
db = db.Where("user_group_id = ?", userGroupId)
}
if userId != "" {
db = db.Where("user_id = ?", userId)
}
if resourceType != "" {
db = db.Where("resource_type = ?", resourceType)
}
if resourceIds != nil {
db = db.Where("resource_id in ?", resourceIds)
}
return db.Delete(&ResourceSharerRepository{}).Error
}
func (r *ResourceSharerRepository) DeleteResourceSharerByResourceId(resourceId string) error {
return r.DB.Where("resource_id = ?", resourceId).Delete(&ResourceSharerRepository{}).Error
}
func (r *ResourceSharerRepository) AddSharerResources(userGroupId, userId, resourceType string, resourceIds []string) error {
return r.DB.Transaction(func(tx *gorm.DB) (err error) {
for i := range resourceIds {
resourceId := resourceIds[i]
var owner string
// 检查资产是否存在
switch resourceType {
case "asset":
resource := model.Asset{}
if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil {
return errors.Wrap(err, "find asset fail")
}
owner = resource.Owner
case "command":
resource := model.Command{}
if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil {
return errors.Wrap(err, "find command fail")
}
owner = resource.Owner
case "credential":
resource := model.Credential{}
if err = tx.Where("id = ?", resourceId).First(&resource).Error; err != nil {
return errors.Wrap(err, "find credential fail")
}
owner = resource.Owner
}
if owner == userId {
return echo.NewHTTPError(400, "参数错误")
}
id := utils.Sign([]string{resourceId, resourceType, userId, userGroupId})
resource := &model.ResourceSharer{
ID: id,
ResourceId: resourceId,
ResourceType: resourceType,
UserId: userId,
UserGroupId: userGroupId,
}
err = tx.Create(resource).Error
if err != nil {
return err
}
}
return nil
})
}
func (r *ResourceSharerRepository) FindAssetIdsByUserId(userId string) (assetIds []string, err error) {
// 查询当前用户创建的资产
var ownerAssetIds, sharerAssetIds []string
asset := model.Asset{}
err = r.DB.Table(asset.TableName()).Select("id").Where("owner = ?", userId).Find(&ownerAssetIds).Error
if err != nil {
return nil, err
}
// 查询其他用户授权给该用户的资产
groupIds, err := userGroupRepository.FindUserGroupIdsByUserId(userId)
if err != nil {
return nil, err
}
db := r.DB.Table("resource_sharers").Select("resource_id").Where("user_id = ?", userId)
if len(groupIds) > 0 {
db = db.Or("user_group_id in ?", groupIds)
}
err = db.Find(&sharerAssetIds).Error
if err != nil {
return nil, err
}
// 合并查询到的资产ID
assetIds = make([]string, 0)
if ownerAssetIds != nil {
assetIds = append(assetIds, ownerAssetIds...)
}
if sharerAssetIds != nil {
assetIds = append(assetIds, sharerAssetIds...)
}
return
}

View File

@ -0,0 +1,167 @@
package repository
import (
"gorm.io/gorm"
"next-terminal/server/constant"
"next-terminal/server/model"
"os"
"path"
"time"
)
type SessionRepository struct {
DB *gorm.DB
}
func NewSessionRepository(db *gorm.DB) *SessionRepository {
sessionRepository = &SessionRepository{DB: db}
return sessionRepository
}
func (r SessionRepository) Find(pageIndex, pageSize int, status, userId, clientIp, assetId, protocol string) (results []model.SessionVo, total int64, err error) {
db := r.DB
var params []interface{}
params = append(params, status)
itemSql := "SELECT s.id,s.mode, s.protocol,s.recording, s.connection_id, s.asset_id, s.creator, s.client_ip, s.width, s.height, s.ip, s.port, s.username, s.status, s.connected_time, s.disconnected_time,s.code, s.message, a.name AS asset_name, u.nickname AS creator_name FROM sessions s LEFT JOIN assets a ON s.asset_id = a.id LEFT JOIN users u ON s.creator = u.id WHERE s.STATUS = ? "
countSql := "select count(*) from sessions as s where s.status = ? "
if len(userId) > 0 {
itemSql += " and s.creator = ?"
countSql += " and s.creator = ?"
params = append(params, userId)
}
if len(clientIp) > 0 {
itemSql += " and s.client_ip like ?"
countSql += " and s.client_ip like ?"
params = append(params, "%"+clientIp+"%")
}
if len(assetId) > 0 {
itemSql += " and s.asset_id = ?"
countSql += " and s.asset_id = ?"
params = append(params, assetId)
}
if len(protocol) > 0 {
itemSql += " and s.protocol = ?"
countSql += " and s.protocol = ?"
params = append(params, protocol)
}
params = append(params, (pageIndex-1)*pageSize, pageSize)
itemSql += " order by s.connected_time desc LIMIT ?, ?"
db.Raw(countSql, params...).Scan(&total)
err = db.Raw(itemSql, params...).Scan(&results).Error
if results == nil {
results = make([]model.SessionVo, 0)
}
return
}
func (r SessionRepository) FindByStatus(status string) (o []model.Session, err error) {
err = r.DB.Where("status = ?", status).Find(&o).Error
return
}
func (r SessionRepository) FindByStatusIn(statuses []string) (o []model.Session, err error) {
err = r.DB.Where("status in ?", statuses).Find(&o).Error
return
}
func (r SessionRepository) FindOutTimeSessions(dayLimit int) (o []model.Session, err error) {
limitTime := time.Now().Add(time.Duration(-dayLimit*24) * time.Hour)
err = r.DB.Where("status = ? and connected_time < ?", constant.Disconnected, limitTime).Find(&o).Error
return
}
func (r SessionRepository) Create(o *model.Session) (err error) {
err = r.DB.Create(o).Error
return
}
func (r SessionRepository) FindById(id string) (o model.Session, err error) {
err = r.DB.Where("id = ?", id).First(&o).Error
return
}
func (r SessionRepository) FindByConnectionId(connectionId string) (o model.Session, err error) {
err = r.DB.Where("connection_id = ?", connectionId).First(&o).Error
return
}
func (r SessionRepository) UpdateById(o *model.Session, id string) error {
o.ID = id
return r.DB.Updates(o).Error
}
func (r SessionRepository) UpdateWindowSizeById(width, height int, id string) error {
session := model.Session{}
session.Width = width
session.Height = height
return r.UpdateById(&session, id)
}
func (r SessionRepository) DeleteById(id string) error {
return r.DB.Where("id = ?", id).Delete(&model.Session{}).Error
}
func (r SessionRepository) DeleteByIds(sessionIds []string) error {
drivePath, err := propertyRepository.GetRecordingPath()
if err != nil {
return err
}
for i := range sessionIds {
if err := os.RemoveAll(path.Join(drivePath, sessionIds[i])); err != nil {
return err
}
if err := r.DeleteById(sessionIds[i]); err != nil {
return err
}
}
return nil
}
func (r SessionRepository) DeleteByStatus(status string) error {
return r.DB.Where("status = ?", status).Delete(&model.Session{}).Error
}
func (r SessionRepository) CountOnlineSession() (total int64, err error) {
err = r.DB.Where("status = ?", constant.Connected).Find(&model.Session{}).Count(&total).Error
return
}
type D struct {
Day string `json:"day"`
Count int `json:"count"`
Protocol string `json:"protocol"`
}
func (r SessionRepository) CountSessionByDay(day int) (results []D, err error) {
today := time.Now().Format("20060102")
sql := "select t1.`day`, count(t2.id) as count\nfrom (\n SELECT @date := DATE_ADD(@date, INTERVAL - 1 DAY) day\n FROM (SELECT @date := DATE_ADD('" + today + "', INTERVAL + 1 DAY) FROM nums) as t0\n LIMIT ?\n )\n as t1\n left join\n (\n select DATE(s.connected_time) as day, s.id\n from sessions as s\n WHERE protocol = ? and DATE(connected_time) <= '" + today + "'\n AND DATE(connected_time) > DATE_SUB('" + today + "', INTERVAL ? DAY)\n ) as t2 on t1.day = t2.day\ngroup by t1.day"
protocols := []string{"rdp", "ssh", "vnc", "telnet"}
for i := range protocols {
var result []D
err = r.DB.Raw(sql, day, protocols[i], day).Scan(&result).Error
if err != nil {
return nil, err
}
for j := range result {
result[j].Protocol = protocols[i]
}
results = append(results, result...)
}
return
}

View File

@ -9,6 +9,11 @@ type UserRepository struct {
DB *gorm.DB
}
func NewUserRepository(db *gorm.DB) *UserRepository {
userRepository = &UserRepository{DB: db}
return userRepository
}
func (r UserRepository) FindAll() (o []model.User) {
if r.DB.Find(&o).Error != nil {
return nil

View File

@ -0,0 +1,140 @@
package repository
import (
"gorm.io/gorm"
"next-terminal/server/global"
"next-terminal/server/model"
"next-terminal/server/utils"
)
type UserGroupRepository struct {
DB *gorm.DB
}
func NewUserGroupRepository(db *gorm.DB) *UserGroupRepository {
userGroupRepository = &UserGroupRepository{DB: db}
return userGroupRepository
}
func (r UserGroupRepository) FindAll() (o []model.UserGroup) {
if r.DB.Find(&o).Error != nil {
return nil
}
return
}
func (r UserGroupRepository) Find(pageIndex, pageSize int, name, order, field string) (o []model.UserGroup, total int64, err error) {
db := r.DB.Table("user_groups").Select("user_groups.id, user_groups.name, user_groups.created, count(resource_sharers.user_group_id) as asset_count").Joins("left join resource_sharers on user_groups.id = resource_sharers.user_group_id and resource_sharers.resource_type = 'asset'").Group("user_groups.id")
dbCounter := r.DB.Table("user_groups")
if len(name) > 0 {
db = db.Where("user_groups.name like ?", "%"+name+"%")
dbCounter = dbCounter.Where("name like ?", "%"+name+"%")
}
err = dbCounter.Count(&total).Error
if err != nil {
return nil, 0, err
}
if order == "ascend" {
order = "asc"
} else {
order = "desc"
}
if field == "name" {
field = "name"
} else {
field = "created"
}
err = db.Order("user_groups." + field + " " + order).Find(&o).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Error
if o == nil {
o = make([]model.UserGroup, 0)
}
return
}
func (r UserGroupRepository) FindById(id string) (o model.UserGroup, err error) {
err = global.DB.Where("id = ?", id).First(&o).Error
return
}
func (r UserGroupRepository) FindUserGroupIdsByUserId(userId string) (o []string, err error) {
// 先查询用户所在的用户
err = r.DB.Table("user_group_members").Select("user_group_id").Where("user_id = ?", userId).Find(&o).Error
return
}
func (r UserGroupRepository) FindUserGroupMembersByUserGroupId(userGroupId string) (o []string, err error) {
err = r.DB.Table("user_group_members").Select("user_id").Where("user_group_id = ?", userGroupId).Find(&o).Error
return
}
func (r UserGroupRepository) Create(o *model.UserGroup, members []string) (err error) {
return r.DB.Transaction(func(tx *gorm.DB) error {
err = tx.Create(o).Error
if err != nil {
return err
}
if members != nil {
userGroupId := o.ID
err = AddUserGroupMembers(tx, members, userGroupId)
if err != nil {
return err
}
}
return err
})
}
func (r UserGroupRepository) Update(o *model.UserGroup, members []string, id string) error {
return r.DB.Transaction(func(tx *gorm.DB) error {
o.ID = id
err := tx.Updates(o).Error
if err != nil {
return err
}
err = tx.Where("user_group_id = ?", id).Delete(&model.UserGroupMember{}).Error
if err != nil {
return err
}
if members != nil {
userGroupId := o.ID
err = AddUserGroupMembers(tx, members, userGroupId)
if err != nil {
return err
}
}
return err
})
}
func (r UserGroupRepository) DeleteById(id string) {
r.DB.Where("id = ?", id).Delete(&model.UserGroup{})
r.DB.Where("user_group_id = ?", id).Delete(&model.UserGroupMember{})
}
func AddUserGroupMembers(tx *gorm.DB, userIds []string, userGroupId string) error {
userRepository := NewUserRepository(tx)
for i := range userIds {
userId := userIds[i]
_, err := userRepository.FindById(userId)
if err != nil {
return err
}
userGroupMember := model.UserGroupMember{
ID: utils.Sign([]string{userGroupId, userId}),
UserId: userId,
UserGroupId: userGroupId,
}
err = tx.Create(&userGroupMember).Error
if err != nil {
return err
}
}
return nil
}