提交 v1.3.0 beta
This commit is contained in:
@ -3,16 +3,19 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"next-terminal/server/common/nt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"next-terminal/server/config"
|
||||
"next-terminal/server/constant"
|
||||
"next-terminal/server/model"
|
||||
"next-terminal/server/utils"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
var AssetRepository = new(assetRepository)
|
||||
|
||||
type assetRepository struct {
|
||||
baseRepository
|
||||
}
|
||||
@ -28,7 +31,11 @@ func (r assetRepository) FindByIds(c context.Context, assetIds []string) (o []mo
|
||||
}
|
||||
|
||||
func (r assetRepository) FindByProtocol(c context.Context, protocol string) (o []model.Asset, err error) {
|
||||
err = r.GetDB(c).Where("protocol = ?", protocol).Find(&o).Error
|
||||
db := r.GetDB(c)
|
||||
if protocol != "" {
|
||||
db = db.Where("protocol = ?", protocol)
|
||||
}
|
||||
err = db.Order("name asc").Find(&o).Error
|
||||
return
|
||||
}
|
||||
|
||||
@ -37,65 +44,9 @@ func (r assetRepository) FindByProtocolAndIds(c context.Context, protocol string
|
||||
return
|
||||
}
|
||||
|
||||
func (r assetRepository) FindByProtocolAndUser(c context.Context, protocol string, account model.User) (o []model.Asset, err error) {
|
||||
db := r.GetDB(c).Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags,assets.description, users.nickname as owner_name").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id")
|
||||
|
||||
if constant.TypeUser == account.Type {
|
||||
owner := account.ID
|
||||
db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner)
|
||||
|
||||
// 查询用户所在用户组列表
|
||||
userGroupIds, err := UserGroupMemberRepository.FindUserGroupIdsByUserId(c, account.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(userGroupIds) > 0 {
|
||||
db = db.Or("resource_sharers.user_group_id in ?", userGroupIds)
|
||||
}
|
||||
}
|
||||
|
||||
if len(protocol) > 0 {
|
||||
db = db.Where("assets.protocol = ?", protocol)
|
||||
}
|
||||
err = db.Find(&o).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (r assetRepository) Find(c context.Context, pageIndex, pageSize int, name, protocol, tags string, account *model.User, owner, sharer, userGroupId, ip, order, field string) (o []model.AssetForPage, total int64, err error) {
|
||||
db := r.GetDB(c).Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.owner,assets.created,assets.tags,assets.description, users.nickname as owner_name").Joins("left join users on assets.owner = users.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id")
|
||||
dbCounter := r.GetDB(c).Table("assets").Select("DISTINCT assets.id").Joins("left join resource_sharers on assets.id = resource_sharers.resource_id").Group("assets.id")
|
||||
|
||||
if constant.TypeUser == account.Type {
|
||||
owner := account.ID
|
||||
db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner)
|
||||
dbCounter = dbCounter.Where("assets.owner = ? or resource_sharers.user_id = ?", owner, owner)
|
||||
|
||||
// 查询用户所在用户组列表
|
||||
userGroupIds, err := UserGroupMemberRepository.FindUserGroupIdsByUserId(c, account.ID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if len(userGroupIds) > 0 {
|
||||
db = db.Or("resource_sharers.user_group_id in ?", userGroupIds)
|
||||
dbCounter = dbCounter.Or("resource_sharers.user_group_id in ?", userGroupIds)
|
||||
}
|
||||
} else {
|
||||
if len(owner) > 0 {
|
||||
db = db.Where("assets.owner = ?", owner)
|
||||
dbCounter = dbCounter.Where("assets.owner = ?", owner)
|
||||
}
|
||||
if len(sharer) > 0 {
|
||||
db = db.Where("resource_sharers.user_id = ?", sharer)
|
||||
dbCounter = dbCounter.Where("resource_sharers.user_id = ?", sharer)
|
||||
}
|
||||
|
||||
if len(userGroupId) > 0 {
|
||||
db = db.Where("resource_sharers.user_group_id = ?", userGroupId)
|
||||
dbCounter = dbCounter.Where("resource_sharers.user_group_id = ?", userGroupId)
|
||||
}
|
||||
}
|
||||
func (r assetRepository) Find(c context.Context, pageIndex, pageSize int, name, protocol, tags, ip, active, order, field string) (o []model.AssetForPage, total int64, err error) {
|
||||
db := r.GetDB(c).Table("assets").Select("assets.id,assets.name,assets.ip,assets.port,assets.protocol,assets.active,assets.active_message,assets.owner,assets.created,assets.tags,assets.description, users.nickname as owner_name").Joins("left join users on assets.owner = users.id")
|
||||
dbCounter := r.GetDB(c).Table("assets")
|
||||
|
||||
if len(name) > 0 {
|
||||
db = db.Where("assets.name like ?", "%"+name+"%")
|
||||
@ -125,6 +76,14 @@ func (r assetRepository) Find(c context.Context, pageIndex, pageSize int, name,
|
||||
}
|
||||
}
|
||||
|
||||
if active != "" {
|
||||
_active, err := strconv.ParseBool(active)
|
||||
if err == nil {
|
||||
db = db.Where("assets.active = ?", _active)
|
||||
dbCounter = dbCounter.Where("assets.active = ?", _active)
|
||||
}
|
||||
}
|
||||
|
||||
err = dbCounter.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@ -146,22 +105,6 @@ func (r assetRepository) Find(c context.Context, pageIndex, pageSize int, name,
|
||||
|
||||
if o == nil {
|
||||
o = make([]model.AssetForPage, 0)
|
||||
} else {
|
||||
for i := 0; i < len(o); i++ {
|
||||
if o[i].Protocol == "ssh" {
|
||||
attributes, err := r.FindAttrById(c, o[i].ID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for j := range attributes {
|
||||
if attributes[j].Name == constant.SshMode {
|
||||
o[i].SshMode = attributes[j].Value
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -180,9 +123,9 @@ func (r assetRepository) UpdateById(c context.Context, o *model.Asset, id string
|
||||
return r.GetDB(c).Updates(o).Error
|
||||
}
|
||||
|
||||
func (r assetRepository) UpdateActiveById(c context.Context, active bool, id string) error {
|
||||
sql := "update assets set active = ? where id = ?"
|
||||
return r.GetDB(c).Exec(sql, active, id).Error
|
||||
func (r assetRepository) UpdateActiveById(c context.Context, active bool, message, id string) error {
|
||||
sql := "update assets set active = ?, active_message = ? where id = ?"
|
||||
return r.GetDB(c).Exec(sql, active, message, id).Error
|
||||
}
|
||||
|
||||
func (r assetRepository) DeleteById(c context.Context, assetId string) (err error) {
|
||||
@ -198,47 +141,16 @@ func (r assetRepository) Count(c context.Context) (total int64, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (r assetRepository) CountByActive(c context.Context, active bool) (total int64, err error) {
|
||||
err = r.GetDB(c).Find(&model.Asset{}).Where("active = ?", active).Count(&total).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (r assetRepository) CountByProtocol(c context.Context, protocol string) (total int64, err error) {
|
||||
err = r.GetDB(c).Find(&model.Asset{}).Where("protocol = ?", protocol).Count(&total).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (r assetRepository) CountByUserId(c context.Context, userId string) (total int64, err error) {
|
||||
db := r.GetDB(c).Joins("left join resource_sharers on assets.id = resource_sharers.resource_id")
|
||||
|
||||
db = db.Where("assets.owner = ? or resource_sharers.user_id = ?", userId, userId)
|
||||
|
||||
// 查询用户所在用户组列表
|
||||
userGroupIds, err := UserGroupMemberRepository.FindUserGroupIdsByUserId(c, userId)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(userGroupIds) > 0 {
|
||||
db = db.Or("resource_sharers.user_group_id in ?", userGroupIds)
|
||||
}
|
||||
err = db.Find(&model.Asset{}).Count(&total).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (r assetRepository) CountByUserIdAndProtocol(c context.Context, userId, protocol string) (total int64, err error) {
|
||||
db := r.GetDB(c).Joins("left join resource_sharers on assets.id = resource_sharers.resource_id")
|
||||
|
||||
db = db.Where("( assets.owner = ? or resource_sharers.user_id = ? ) and assets.protocol = ?", userId, userId, protocol)
|
||||
|
||||
// 查询用户所在用户组列表
|
||||
userGroupIds, err := UserGroupMemberRepository.FindUserGroupIdsByUserId(c, userId)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(userGroupIds) > 0 {
|
||||
db = db.Or("resource_sharers.user_group_id in ?", userGroupIds)
|
||||
}
|
||||
err = db.Find(&model.Asset{}).Count(&total).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (r assetRepository) FindTags(c context.Context) (o []string, err error) {
|
||||
var assets []model.Asset
|
||||
err = r.GetDB(c).Not("tags = '' or tags = '-' ").Find(&assets).Error
|
||||
@ -265,15 +177,15 @@ func (r assetRepository) UpdateAttributes(c context.Context, assetId, protocol s
|
||||
var parameterNames []string
|
||||
switch protocol {
|
||||
case "ssh":
|
||||
parameterNames = constant.SSHParameterNames
|
||||
parameterNames = nt.SSHParameterNames
|
||||
case "rdp":
|
||||
parameterNames = constant.RDPParameterNames
|
||||
parameterNames = nt.RDPParameterNames
|
||||
case "vnc":
|
||||
parameterNames = constant.VNCParameterNames
|
||||
parameterNames = nt.VNCParameterNames
|
||||
case "telnet":
|
||||
parameterNames = constant.TelnetParameterNames
|
||||
parameterNames = nt.TelnetParameterNames
|
||||
case "kubernetes":
|
||||
parameterNames = constant.KubernetesParameterNames
|
||||
parameterNames = nt.KubernetesParameterNames
|
||||
}
|
||||
|
||||
for i := range parameterNames {
|
||||
@ -322,15 +234,15 @@ func (r assetRepository) FindAssetAttrMapByAssetId(c context.Context, assetId st
|
||||
var parameterNames []string
|
||||
switch asset.Protocol {
|
||||
case "ssh":
|
||||
parameterNames = constant.SSHParameterNames
|
||||
parameterNames = nt.SSHParameterNames
|
||||
case "rdp":
|
||||
parameterNames = constant.RDPParameterNames
|
||||
parameterNames = nt.RDPParameterNames
|
||||
case "vnc":
|
||||
parameterNames = constant.VNCParameterNames
|
||||
parameterNames = nt.VNCParameterNames
|
||||
case "telnet":
|
||||
parameterNames = constant.TelnetParameterNames
|
||||
parameterNames = nt.TelnetParameterNames
|
||||
case "kubernetes":
|
||||
parameterNames = constant.KubernetesParameterNames
|
||||
parameterNames = nt.KubernetesParameterNames
|
||||
}
|
||||
propertiesMap := PropertyRepository.FindAllMap(c)
|
||||
var attributeMap = make(map[string]string)
|
||||
@ -350,3 +262,91 @@ func (r assetRepository) UpdateAttrs(c context.Context, name, value, newValue st
|
||||
sql := "update asset_attributes set value = ? where name = ? and value = ?"
|
||||
return r.GetDB(c).Exec(sql, newValue, name, value).Error
|
||||
}
|
||||
|
||||
func (r assetRepository) ExistById(c context.Context, id string) (bool, error) {
|
||||
m := model.Asset{}
|
||||
var count uint64
|
||||
err := r.GetDB(c).Table(m.TableName()).Select("count(*)").
|
||||
Where("id = ?", id).
|
||||
Find(&count).
|
||||
Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r assetRepository) FindMyAssets(c context.Context, pageIndex, pageSize int, name, protocol, tags string, assetIds []string, order, field string) (o []model.AssetForPage, total int64, err error) {
|
||||
db := r.GetDB(c).Table("assets").Select("assets.id,assets.name,assets.protocol,assets.active,assets.active_message,assets.tags,assets.description").
|
||||
Where("id in ?", assetIds)
|
||||
dbCounter := r.GetDB(c).Table("assets").Where("id in ?", assetIds)
|
||||
|
||||
if len(name) > 0 {
|
||||
db = db.Where("assets.name like ?", "%"+name+"%")
|
||||
dbCounter = dbCounter.Where("assets.name like ?", "%"+name+"%")
|
||||
}
|
||||
|
||||
if len(protocol) > 0 {
|
||||
db = db.Where("assets.protocol = ?", protocol)
|
||||
dbCounter = dbCounter.Where("assets.protocol = ?", protocol)
|
||||
}
|
||||
|
||||
if len(tags) > 0 {
|
||||
tagArr := strings.Split(tags, ",")
|
||||
for i := range tagArr {
|
||||
if config.GlobalCfg.DB == "sqlite" {
|
||||
db = db.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%")
|
||||
dbCounter = dbCounter.Where("(',' || assets.tags || ',') LIKE ?", "%,"+tagArr[i]+",%")
|
||||
} else {
|
||||
db = db.Where("find_in_set(?, assets.tags)", tagArr[i])
|
||||
dbCounter = dbCounter.Where("find_in_set(?, assets.tags)", tagArr[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = dbCounter.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if order == "ascend" {
|
||||
order = "asc"
|
||||
} else {
|
||||
order = "desc"
|
||||
}
|
||||
|
||||
if field == "name" {
|
||||
field = "name"
|
||||
} else {
|
||||
field = "created"
|
||||
}
|
||||
|
||||
err = db.Order("assets." + field + " " + order).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&o).Error
|
||||
|
||||
if o == nil {
|
||||
o = make([]model.AssetForPage, 0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r assetRepository) FindMyAssetTags(c context.Context, assetIds []string) (o []string, err error) {
|
||||
|
||||
var assets []model.Asset
|
||||
err = r.GetDB(c).Not("tags = '' or tags = '-' ").Where("id in ?", assetIds).Find(&assets).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o = make([]string, 0)
|
||||
|
||||
for i := range assets {
|
||||
if len(assets[i].Tags) == 0 {
|
||||
continue
|
||||
}
|
||||
split := strings.Split(assets[i].Tags, ",")
|
||||
|
||||
o = append(o, split...)
|
||||
}
|
||||
|
||||
return utils.Distinct(o), nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user