250 lines
6.1 KiB
Go
250 lines
6.1 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"net"
|
||
"strings"
|
||
"time"
|
||
|
||
"next-terminal/server/model"
|
||
"next-terminal/server/repository"
|
||
"next-terminal/server/utils"
|
||
)
|
||
|
||
var LoginPolicyService = new(loginPolicyService)
|
||
|
||
type loginPolicyService struct {
|
||
baseService
|
||
}
|
||
|
||
func (s loginPolicyService) Create(c context.Context, m *model.LoginPolicy) error {
|
||
return s.Transaction(c, func(ctx context.Context) error {
|
||
if err := repository.LoginPolicyRepository.Create(ctx, m); err != nil {
|
||
return err
|
||
}
|
||
|
||
if len(m.TimePeriod) > 0 {
|
||
for i := range m.TimePeriod {
|
||
m.TimePeriod[i].ID = utils.UUID()
|
||
m.TimePeriod[i].LoginPolicyId = m.ID
|
||
}
|
||
if err := repository.TimePeriodRepository.CreateInBatches(ctx, m.TimePeriod); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
})
|
||
}
|
||
|
||
func (s loginPolicyService) DeleteByIds(ctx context.Context, ids []string) error {
|
||
return s.Transaction(ctx, func(ctx context.Context) error {
|
||
for _, id := range ids {
|
||
if err := repository.LoginPolicyRepository.DeleteById(ctx, id); err != nil {
|
||
return err
|
||
}
|
||
if err := repository.LoginPolicyUserRefRepository.DeleteByLoginPolicyId(ctx, id); err != nil {
|
||
return err
|
||
}
|
||
if err := repository.TimePeriodRepository.DeleteByLoginPolicyId(ctx, id); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
func (s loginPolicyService) UpdateById(ctx context.Context, m *model.LoginPolicy, id string) error {
|
||
return s.Transaction(ctx, func(ctx context.Context) error {
|
||
if err := repository.LoginPolicyRepository.UpdateById(ctx, m, id); err != nil {
|
||
return err
|
||
}
|
||
if err := repository.TimePeriodRepository.DeleteByLoginPolicyId(ctx, id); err != nil {
|
||
return err
|
||
}
|
||
if len(m.TimePeriod) > 0 {
|
||
for i := range m.TimePeriod {
|
||
m.TimePeriod[i].ID = utils.UUID()
|
||
m.TimePeriod[i].LoginPolicyId = m.ID
|
||
}
|
||
if err := repository.TimePeriodRepository.CreateInBatches(ctx, m.TimePeriod); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
func (s loginPolicyService) FindById(ctx context.Context, id string) (*model.LoginPolicy, error) {
|
||
policy, err := repository.LoginPolicyRepository.FindById(ctx, id)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
timePeriods, err := repository.TimePeriodRepository.FindByLoginPolicyId(ctx, id)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
policy.TimePeriod = timePeriods
|
||
|
||
return &policy, nil
|
||
}
|
||
|
||
func (s loginPolicyService) Check(userId, clientIp string) error {
|
||
ctx := context.Background()
|
||
// 按照优先级倒排进行查询
|
||
policies, err := repository.LoginPolicyRepository.FindByUserId(ctx, userId)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if len(policies) == 0 {
|
||
return nil
|
||
}
|
||
|
||
if err := s.checkClientIp(policies, clientIp); err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := s.checkWeekDay(policies); err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (s loginPolicyService) checkClientIp(policies []model.LoginPolicy, clientIp string) error {
|
||
var pass = true
|
||
// 优先级低的先进行判断
|
||
for _, policy := range policies {
|
||
if !policy.Enabled {
|
||
continue
|
||
}
|
||
ipGroups := strings.Split(policy.IpGroup, ",")
|
||
for _, group := range ipGroups {
|
||
if strings.Contains(group, "/") {
|
||
// CIDR
|
||
_, ipNet, err := net.ParseCIDR(group)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
if !ipNet.Contains(net.ParseIP(clientIp)) {
|
||
continue
|
||
}
|
||
} else if strings.Contains(group, "-") {
|
||
// 范围段
|
||
split := strings.Split(group, "-")
|
||
if len(split) < 2 {
|
||
continue
|
||
}
|
||
start := split[0]
|
||
end := split[1]
|
||
intReqIP := utils.IpToInt(clientIp)
|
||
if intReqIP < utils.IpToInt(start) || intReqIP > utils.IpToInt(end) {
|
||
continue
|
||
}
|
||
} else {
|
||
// IP
|
||
if group != clientIp {
|
||
continue
|
||
}
|
||
}
|
||
pass = policy.Rule == "allow"
|
||
}
|
||
}
|
||
|
||
if !pass {
|
||
return errors.New("非常抱歉,您当前使用的IP地址不允许进行登录。")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (s loginPolicyService) checkWeekDay(policies []model.LoginPolicy) error {
|
||
// 获取当前日期是星期几
|
||
now := time.Now()
|
||
weekday := int(now.Weekday())
|
||
hwc := now.Format("15:04")
|
||
|
||
var timePass = true
|
||
|
||
// 优先级低的先进行判断
|
||
for _, policy := range policies {
|
||
if !policy.Enabled {
|
||
continue
|
||
}
|
||
timePeriods, err := repository.TimePeriodRepository.FindByLoginPolicyId(context.Background(), policy.ID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
for _, period := range timePeriods {
|
||
if weekday != period.Key {
|
||
continue
|
||
}
|
||
if period.Value == "" {
|
||
continue
|
||
}
|
||
// 只处理对应天的数据
|
||
times := strings.Split(period.Value, "、")
|
||
for _, t := range times {
|
||
timeRange := strings.Split(t, "~")
|
||
start := timeRange[0]
|
||
end := timeRange[1]
|
||
if (start == "00:00" && end == "00:00") || (start <= hwc && hwc <= end) {
|
||
timePass = policy.Rule == "allow"
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if !timePass {
|
||
return errors.New("非常抱歉,当前时段不允许您进行登录。")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (s loginPolicyService) Bind(ctx context.Context, loginPolicyId string, items []model.LoginPolicyUserRef) error {
|
||
return s.Transaction(ctx, func(ctx context.Context) error {
|
||
var results []model.LoginPolicyUserRef
|
||
for i := range items {
|
||
if items[i].UserId == "" {
|
||
continue
|
||
}
|
||
exist, err := repository.UserRepository.ExistById(ctx, items[i].UserId)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
if !exist {
|
||
continue
|
||
}
|
||
refId := utils.Sign([]string{items[i].UserId, loginPolicyId})
|
||
if err := repository.LoginPolicyUserRefRepository.DeleteId(ctx, refId); err != nil {
|
||
return err
|
||
}
|
||
results = append(results, model.LoginPolicyUserRef{
|
||
ID: refId,
|
||
UserId: items[i].UserId,
|
||
LoginPolicyId: loginPolicyId,
|
||
})
|
||
}
|
||
if len(results) == 0 {
|
||
return nil
|
||
}
|
||
|
||
return repository.LoginPolicyUserRefRepository.CreateInBatches(ctx, results)
|
||
})
|
||
}
|
||
|
||
func (s loginPolicyService) Unbind(ctx context.Context, loginPolicyId string, items []model.LoginPolicyUserRef) error {
|
||
return s.Transaction(ctx, func(ctx context.Context) error {
|
||
for i := range items {
|
||
if items[i].UserId == "" {
|
||
continue
|
||
}
|
||
if err := repository.LoginPolicyUserRefRepository.DeleteByLoginPolicyIdAndUserId(ctx, loginPolicyId, items[i].UserId); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
}
|