next-terminal/server/sshd/sshd.go

147 lines
3.4 KiB
Go

package sshd
import (
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"next-terminal/server/config"
"next-terminal/server/constant"
"next-terminal/server/global/security"
"next-terminal/server/log"
"next-terminal/server/repository"
"next-terminal/server/service"
"next-terminal/server/utils"
"github.com/gliderlabs/ssh"
"gorm.io/gorm"
)
// Sshd is the built-in SSH front end. It authenticates logins, applies the
// IP access rules, and hands each authenticated session to gui for the
// interactive terminal UI.
type Sshd struct {
	gui *Gui // renders the TOTP prompt and the main UI for a session
}
func init() {
gui := &Gui{}
sshd := &Sshd{
gui: gui,
}
go sshd.Serve()
}
// passwordAuth is the ssh.PasswordAuth callback. It accepts a login only when
// a user named ctx.User() exists and pass matches the stored user.Password
// according to utils.Encoder.Match.
//
// Every rejection is recorded as a failed "terminal" login. An unknown user
// and a wrong password are logged with the same message so the log does not
// reveal which one it was. Successful logins are not logged here.
func (sshd Sshd) passwordAuth(ctx ssh.Context, pass string) bool {
	username := ctx.User()

	// SplitHostPort handles both "1.2.3.4:22" and "[::1]:22"; naive splitting
	// on ":" would turn every IPv6 peer into "[".
	remoteAddr := ctx.RemoteAddr().String()
	if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
		remoteAddr = host
	}

	// ssh.Context is a context.Context, so the lookup is tied to this
	// connection instead of running unbounded under context.TODO().
	user, err := repository.UserRepository.FindByUsername(ctx, username)
	if err != nil {
		// Record the failed login ("账号或密码不正确" = "incorrect account or password").
		_ = service.UserService.SaveLoginLog(remoteAddr, "terminal", username, false, false, "", "账号或密码不正确")
		return false
	}
	if err := utils.Encoder.Match([]byte(user.Password), []byte(pass)); err != nil {
		// Record the failed login with the same message as an unknown user.
		_ = service.UserService.SaveLoginLog(remoteAddr, "terminal", username, false, false, "", "账号或密码不正确")
		return false
	}
	return true
}
// connCallback is installed with ssh.WrapConn and runs for every accepted
// connection before the SSH handshake. It walks the global access-security
// rules in order, and the first rule whose IP pattern matches the peer and
// whose Rule is Allow or Reject decides the outcome:
//   - constant.AccessRuleAllow returns conn unchanged;
//   - constant.AccessRuleReject writes a short notice and returns nil, which
//     tells the ssh server to drop the connection.
//
// Matching rules with any other Rule value are skipped. A peer that no rule
// decides is allowed.
//
// Supported IP patterns are a CIDR block ("10.0.0.0/8"), an inclusive range
// ("10.0.0.1-10.0.0.9", compared via utils.IpToInt), or an exact address.
func (sshd Sshd) connCallback(ctx ssh.Context, conn net.Conn) net.Conn {
	securities := security.GlobalSecurityManager.Values()
	if len(securities) == 0 {
		return conn
	}

	// SplitHostPort handles both "1.2.3.4:22" and "[::1]:22"; naive splitting
	// on ":" turned IPv6 peers into "[", which no rule could ever match.
	ip := conn.RemoteAddr().String()
	if host, _, err := net.SplitHostPort(ip); err == nil {
		ip = host
	}
	reqIP := net.ParseIP(ip) // parsed once; nil if ip is not a valid address

	for _, s := range securities {
		matched := false
		switch {
		case strings.Contains(s.IP, "/"):
			// CIDR block; a malformed pattern never matches.
			_, ipNet, err := net.ParseCIDR(s.IP)
			matched = err == nil && ipNet.Contains(reqIP)
		case strings.Contains(s.IP, "-"):
			// Inclusive range "start-end".
			bounds := strings.Split(s.IP, "-")
			if len(bounds) >= 2 {
				n := utils.IpToInt(ip)
				matched = n >= utils.IpToInt(bounds[0]) && n <= utils.IpToInt(bounds[1])
			}
		default:
			// Exact address.
			matched = s.IP == ip
		}
		if !matched {
			continue
		}

		switch s.Rule {
		case constant.AccessRuleAllow:
			return conn
		case constant.AccessRuleReject:
			_, _ = conn.Write([]byte("your access request was denied :(\n"))
			return nil
		}
	}
	return conn
}
// sessionHandler runs one SSH session after password authentication has
// succeeded, and always closes the session when it returns.
//
// It reloads the user record for (*sess).User(). If the lookup fails, it
// writes an error to the client and returns. Otherwise:
//   - if user.TOTPSecret is neither "" nor "-", the TOTP prompt is shown via
//     gui.totpUI;
//   - if not, a successful "terminal" login is recorded and gui.MainUI opens.
//
// sess points at the session interface; it is passed on to the gui as-is.
func (sshd Sshd) sessionHandler(sess *ssh.Session) {
	defer func() {
		_ = (*sess).Close()
	}()

	username := (*sess).User()

	// SplitHostPort handles both "1.2.3.4:22" and "[::1]:22"; naive splitting
	// on ":" would turn every IPv6 peer into "[".
	remoteAddr := (*sess).RemoteAddr().String()
	if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
		remoteAddr = host
	}

	user, err := repository.UserRepository.FindByUsername(context.TODO(), username)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			// "您输入的账户或密码不正确." = "The account or password you entered is incorrect."
			_, _ = io.WriteString(*sess, "您输入的账户或密码不正确.\n")
		} else {
			// NOTE(review): this sends raw internal (DB) error text to the
			// client; consider logging it server-side and sending a generic message.
			_, _ = io.WriteString(*sess, err.Error())
		}
		return
	}

	// Decide whether two-factor (TOTP) authentication is required.
	if user.TOTPSecret != "" && user.TOTPSecret != "-" {
		sshd.gui.totpUI(sess, user, remoteAddr, username)
		return
	}

	// Record the successful login, then show the main UI.
	_ = service.UserService.SaveLoginLog(remoteAddr, "terminal", username, true, false, utils.LongUUID(), "")
	sshd.gui.MainUI(sess, user)
}
// Serve starts the SSH server on config.GlobalCfg.Sshd.Addr with the host key
// file config.GlobalCfg.Sshd.Key and blocks until the listener fails, at which
// point it calls log.Fatal.
//
// Each session first receives the version banner and is then handed to
// sessionHandler. passwordAuth authenticates logins, and connCallback applies
// the IP access rules before the handshake.
//
// The handler is passed straight to ssh.ListenAndServe instead of being
// registered with ssh.Handle, so the package-global ssh.DefaultHandler is not
// mutated.
func (sshd Sshd) Serve() {
	handler := func(s ssh.Session) {
		_, _ = io.WriteString(s, fmt.Sprintf(constant.Banner, constant.Version))
		sshd.sessionHandler(&s)
	}

	// Printed before the listener is bound; a bind failure is reported below.
	fmt.Printf("⇨ sshd server started on %v\n", config.GlobalCfg.Sshd.Addr)
	err := ssh.ListenAndServe(
		config.GlobalCfg.Sshd.Addr,
		handler,
		ssh.PasswordAuth(sshd.passwordAuth),
		ssh.HostKeyFile(config.GlobalCfg.Sshd.Key),
		ssh.WrapConn(sshd.connCallback),
	)
	// Format err itself rather than err.Error() so a nil error cannot panic.
	// "启动sshd服务失败" = "failed to start the sshd service".
	log.Fatal(fmt.Sprintf("启动sshd服务失败: %v", err))
}