增加普通用户访问资产时的校验 close #44

This commit is contained in:
dushixiang 2021-01-26 21:54:44 +08:00
parent 4f1bfa6c5a
commit d771ad6ab6
7 changed files with 71 additions and 13 deletions

View File

@ -180,6 +180,18 @@ func SessionCreateEndpoint(c echo.Context) error {
assetId := c.QueryParam("assetId") assetId := c.QueryParam("assetId")
user, _ := GetCurrentAccount(c) user, _ := GetCurrentAccount(c)
if model.TypeUser == user.Type {
// 检测是否有访问权限
assetIds, err := model.FindAssetIdsByUserId(user.ID)
if err != nil {
return err
}
if !utils.Contains(assetIds, assetId) {
return errors.New("您没有权限访问此资产")
}
}
asset, err := model.FindAssetById(assetId) asset, err := model.FindAssetById(assetId)
if err != nil { if err != nil {
return err return err

View File

@ -72,6 +72,17 @@ func UserDeleteEndpoint(c echo.Context) error {
if account.ID == userId { if account.ID == userId {
return Fail(c, -1, "不允许删除自身账户") return Fail(c, -1, "不允许删除自身账户")
} }
// 将用户强制下线
loginLogs, err := model.FindAliveLoginLogsByUserId(userId)
if err != nil {
return err
}
if loginLogs != nil && len(loginLogs) > 0 {
for j := range loginLogs {
model.Logout(loginLogs[j].ID)
}
}
// 删除用户
model.DeleteUserById(userId) model.DeleteUserById(userId)
} }

View File

@ -81,15 +81,15 @@ func FindLoginLogById(id string) (o LoginLog, err error) {
return return
} }
func Logout(id string) { func Logout(token string) {
loginLog, err := FindLoginLogById(id) loginLog, err := FindLoginLogById(token)
if err != nil { if err != nil {
logrus.Warnf("登录日志「%v」获取失败", id) logrus.Warnf("登录日志「%v」获取失败", token)
return return
} }
global.DB.Table("login_logs").Where("id = ?", id).Update("logout_time", utils.NowJsonTime()) global.DB.Table("login_logs").Where("token = ?", token).Update("logout_time", utils.NowJsonTime())
loginLogs, err := FindAliveLoginLogsByUserId(loginLog.UserId) loginLogs, err := FindAliveLoginLogsByUserId(loginLog.UserId)
if err != nil { if err != nil {

View File

@ -145,3 +145,21 @@ func AddSharerResources(userGroupId, userId, resourceType string, resourceIds []
return nil return nil
}) })
} }
func FindAssetIdsByUserId(userId string) (assetIds []string, err error) {
groupIds, err := FindUserGroupIdsByUserId(userId)
if err != nil {
return nil, err
}
db := global.DB
db = db.Table("resource_sharers").Select("resource_id").Where("user_id = ?", userId)
if groupIds != nil && len(groupIds) > 0 {
db = db.Or("user_group_id in ?", groupIds)
}
err = db.Find(&assetIds).Error
if assetIds == nil {
assetIds = make([]string, 0)
}
return
}

View File

@ -101,6 +101,10 @@ func UpdateUserById(o *User, id string) {
func DeleteUserById(id string) { func DeleteUserById(id string) {
global.DB.Where("id = ?", id).Delete(&User{}) global.DB.Where("id = ?", id).Delete(&User{})
// 删除用户组中的用户关系
global.DB.Where("user_id = ?", id).Delete(&UserGroupMember{})
// 删除用户分享到的资产
global.DB.Where("user_id = ?", id).Delete(&ResourceSharer{})
} }
func CountUser() (total int64, err error) { func CountUser() (total int64, err error) {

View File

@ -143,3 +143,12 @@ func Sign(a []string) string {
has := md5.Sum(data) has := md5.Sum(data)
return fmt.Sprintf("%x", has) return fmt.Sprintf("%x", has)
} }
func Contains(s []string, str string) bool {
for _, v := range s {
if v == str {
return true
}
}
return false
}

View File

@ -28,7 +28,8 @@ import {
CloudUploadOutlined, CloudUploadOutlined,
CopyOutlined, CopyOutlined,
DeleteOutlined, DeleteOutlined,
DesktopOutlined, ExclamationCircleOutlined, DesktopOutlined,
ExclamationCircleOutlined,
ExpandOutlined, ExpandOutlined,
FileZipOutlined, FileZipOutlined,
FolderAddOutlined, FolderAddOutlined,
@ -37,7 +38,7 @@ import {
UploadOutlined UploadOutlined
} from '@ant-design/icons'; } from '@ant-design/icons';
import Upload from "antd/es/upload"; import Upload from "antd/es/upload";
import {download, exitFull, getToken, requestFullScreen} from "../../utils/utils"; import {download, exitFull, getToken, isEmpty, requestFullScreen} from "../../utils/utils";
import './Access.css' import './Access.css'
import Draggable from 'react-draggable'; import Draggable from 'react-draggable';
@ -100,6 +101,9 @@ class Access extends Component {
let assetsId = params.get('assetsId'); let assetsId = params.get('assetsId');
let protocol = params.get('protocol'); let protocol = params.get('protocol');
let sessionId = await this.createSession(assetsId); let sessionId = await this.createSession(assetsId);
if (isEmpty(sessionId)) {
return;
}
this.setState({ this.setState({
sessionId: sessionId, sessionId: sessionId,
@ -428,12 +432,12 @@ class Access extends Component {
async createSession(assetsId) { async createSession(assetsId) {
let result = await request.post(`/sessions?assetId=${assetsId}`); let result = await request.post(`/sessions?assetId=${assetsId}`);
if (result.code !== 1) { if (result['code'] !== 1) {
message.error(result.message, 10); this.showMessage(result['message']);
return; return null;
} }
document.title = result.data['ip'] + ':' + result.data['port']; document.title = result['data']['ip'] + ':' + result['data']['port'];
return result.data['id']; return result['data']['id'];
} }
async renderDisplay(sessionId, protocol) { async renderDisplay(sessionId, protocol) {