优化接入网关,解决协程泄漏的问题

This commit is contained in:
dushixiang 2022-05-07 16:48:05 +08:00
parent 41768cbec9
commit 5695d6b2e2
16 changed files with 158 additions and 191 deletions

View File

@ -5,6 +5,7 @@ import (
"strconv"
"strings"
"next-terminal/server/global/gateway"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/service"
@ -28,7 +29,7 @@ func (api AccessGatewayApi) AccessGatewayCreateEndpoint(c echo.Context) error {
return err
}
// 连接网关
service.GatewayService.ReConnect(&item)
service.GatewayService.ReLoad(&item)
return Success(c, "")
}
@ -58,12 +59,11 @@ func (api AccessGatewayApi) AccessGatewayPagingEndpoint(c echo.Context) error {
return err
}
for i := 0; i < len(items); i++ {
g, err := service.GatewayService.GetGatewayById(items[i].ID)
if err != nil {
return err
g := gateway.GlobalGatewayManager.GetById(items[i].ID)
if g != nil {
items[i].Connected = g.Connected
items[i].Message = g.Message
}
items[i].Connected = g.Connected
items[i].Message = g.Message
}
return Success(c, Map{
@ -83,7 +83,7 @@ func (api AccessGatewayApi) AccessGatewayUpdateEndpoint(c echo.Context) error {
if err := repository.GatewayRepository.UpdateById(context.TODO(), &item, id); err != nil {
return err
}
service.GatewayService.ReConnect(&item)
service.GatewayService.ReLoad(&item)
return Success(c, nil)
}
@ -110,14 +110,3 @@ func (api AccessGatewayApi) AccessGatewayGetEndpoint(c echo.Context) error {
return Success(c, item)
}
func (api AccessGatewayApi) AccessGatewayReconnectEndpoint(c echo.Context) error {
id := c.Param("id")
item, err := repository.GatewayRepository.FindById(context.TODO(), id)
if err != nil {
return err
}
service.GatewayService.ReConnect(&item)
return Success(c, "")
}

View File

@ -72,15 +72,13 @@ func (api GuacamoleApi) Guacamole(c echo.Context) error {
api.setConfig(propertyMap, s, configuration)
if s.AccessGatewayId != "" && s.AccessGatewayId != "-" {
g, err := service.GatewayService.GetGatewayAndReconnectById(s.AccessGatewayId)
g, err := service.GatewayService.GetGatewayById(s.AccessGatewayId)
if err != nil {
utils.Disconnect(ws, AccessGatewayUnAvailable, "获取接入网关失败:"+err.Error())
return nil
}
if !g.Connected {
utils.Disconnect(ws, AccessGatewayUnAvailable, "接入网关不可用:"+g.Message)
return nil
}
defer g.CloseSshTunnel(s.ID)
exposedIP, exposedPort, err := g.OpenSshTunnel(s.ID, s.IP, s.Port)
if err != nil {
utils.Disconnect(ws, AccessGatewayCreateError, "创建SSH隧道失败"+err.Error())
@ -88,7 +86,6 @@ func (api GuacamoleApi) Guacamole(c echo.Context) error {
}
s.IP = exposedIP
s.Port = exposedPort
defer g.CloseSshTunnel(s.ID)
}
configuration.SetParameter("hostname", s.IP)

View File

@ -70,18 +70,16 @@ func (api WebTerminalApi) SshEndpoint(c echo.Context) error {
)
if s.AccessGatewayId != "" && s.AccessGatewayId != "-" {
g, err := service.GatewayService.GetGatewayAndReconnectById(s.AccessGatewayId)
g, err := service.GatewayService.GetGatewayById(s.AccessGatewayId)
if err != nil {
return WriteMessage(ws, dto.NewMessage(Closed, "获取接入网关失败:"+err.Error()))
}
if !g.Connected {
return WriteMessage(ws, dto.NewMessage(Closed, "接入网关不可用:"+g.Message))
}
defer g.CloseSshTunnel(s.ID)
exposedIP, exposedPort, err := g.OpenSshTunnel(s.ID, ip, port)
if err != nil {
return WriteMessage(ws, dto.NewMessage(Closed, "创建隧道失败:"+err.Error()))
}
defer g.CloseSshTunnel(s.ID)
ip = exposedIP
port = exposedPort
}

View File

@ -36,7 +36,7 @@ func (app App) InitDBData() (err error) {
if err := service.PropertyService.DeleteDeprecatedProperty(); err != nil {
return err
}
if err := service.GatewayService.ReConnectAll(); err != nil {
if err := service.GatewayService.LoadAll(); err != nil {
return err
}
if err := service.PropertyService.InitProperties(); err != nil {

View File

@ -271,7 +271,6 @@ func setupRoutes() *echo.Echo {
accessGateways.PUT("/:id", AccessGatewayApi.AccessGatewayUpdateEndpoint)
accessGateways.DELETE("/:id", AccessGatewayApi.AccessGatewayDeleteEndpoint)
accessGateways.GET("/:id", AccessGatewayApi.AccessGatewayGetEndpoint)
accessGateways.POST("/:id/reconnect", AccessGatewayApi.AccessGatewayReconnectEndpoint)
}
backup := e.Group("/backup", Admin)

View File

@ -1,13 +1,13 @@
package gateway
import (
"context"
"errors"
"fmt"
"net"
"os"
"sync"
"next-terminal/server/term"
"next-terminal/server/utils"
"golang.org/x/crypto/ssh"
@ -15,33 +15,35 @@ import (
// Gateway 接入网关
type Gateway struct {
ID string // 接入网关ID
Connected bool // 是否已连接
SshClient *ssh.Client
Message string // 失败原因
ID string // 接入网关ID
IP string
Port int
Username string
Password string
PrivateKey string
Passphrase string
Connected bool // 是否已连接
Message string // 失败原因
SshClient *ssh.Client
tunnels sync.Map
}
func NewGateway(id string, connected bool, message string, client *ssh.Client) *Gateway {
return &Gateway{
ID: id,
Connected: connected,
Message: message,
SshClient: client,
}
}
func (g *Gateway) Close() {
g.tunnels.Range(func(key, value interface{}) bool {
g.CloseSshTunnel(key.(string))
return true
})
mutex sync.Mutex
tunnels map[string]*Tunnel
}
func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, exposedPort int, err error) {
g.mutex.Lock()
defer g.mutex.Unlock()
if !g.Connected {
return "", 0, errors.New(g.Message)
sshClient, err := term.NewSshClient(g.IP, g.Port, g.Username, g.Password, g.PrivateKey, g.Passphrase)
if err != nil {
g.Connected = false
g.Message = "接入网关不可用:" + err.Error()
return "", 0, errors.New(g.Message)
} else {
g.Connected = true
g.SshClient = sshClient
g.Message = "使用中"
}
}
localPort, err := utils.GetAvailablePort()
@ -63,30 +65,39 @@ func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, expo
return "", 0, err
}
ctx, cancel := context.WithCancel(context.Background())
tunnel := &Tunnel{
ID: id,
LocalHost: hostname,
//LocalHost: "docker.for.mac.host.internal",
LocalPort: localPort,
Gateway: g,
RemoteHost: ip,
RemotePort: port,
ctx: ctx,
cancel: cancel,
id: id,
localHost: hostname,
//localHost: "docker.for.mac.host.internal",
localPort: localPort,
remoteHost: ip,
remotePort: port,
listener: listener,
}
go tunnel.Open()
g.tunnels.Store(tunnel.ID, tunnel)
go tunnel.Open(g.SshClient)
g.tunnels[tunnel.id] = tunnel
return tunnel.LocalHost, tunnel.LocalPort, nil
return tunnel.localHost, tunnel.localPort, nil
}
func (g *Gateway) CloseSshTunnel(id string) {
if value, ok := g.tunnels.Load(id); ok {
if tunnel, vok := value.(*Tunnel); vok {
tunnel.Close()
g.tunnels.Delete(id)
}
g.mutex.Lock()
defer g.mutex.Unlock()
t := g.tunnels[id]
if t != nil {
t.Close()
delete(g.tunnels, id)
}
if len(g.tunnels) == 0 {
_ = g.SshClient.Close()
g.Connected = false
g.Message = "暂未使用"
}
}
func (g *Gateway) Close() {
for id := range g.tunnels {
g.CloseSshTunnel(id)
}
}

View File

@ -4,35 +4,50 @@ import (
"sync"
"next-terminal/server/log"
"next-terminal/server/model"
)
type Manager struct {
type manager struct {
gateways sync.Map
}
func NewManager() *Manager {
return &Manager{}
}
func (m *Manager) GetById(id string) *Gateway {
func (m *manager) GetById(id string) *Gateway {
if val, ok := m.gateways.Load(id); ok {
return val.(*Gateway)
}
return nil
}
func (m *Manager) Add(g *Gateway) {
func (m *manager) Add(model *model.AccessGateway) *Gateway {
g := &Gateway{
ID: model.ID,
IP: model.IP,
Port: model.Port,
Username: model.Username,
Password: model.Password,
PrivateKey: model.PrivateKey,
Passphrase: model.Passphrase,
Connected: false,
SshClient: nil,
Message: "暂未使用",
tunnels: make(map[string]*Tunnel),
}
m.gateways.Store(g.ID, g)
log.Infof("add gateway: %s", g.ID)
log.Infof("add Gateway: %s", g.ID)
return g
}
func (m *Manager) Del(id string) {
func (m *manager) Del(id string) {
g := m.GetById(id)
if g != nil {
g.Close()
}
m.gateways.Delete(id)
log.Infof("del gateway: %s", id)
log.Infof("del Gateway: %s", id)
}
var GlobalGatewayManager *Manager
var GlobalGatewayManager *manager
func init() {
GlobalGatewayManager = NewManager()
GlobalGatewayManager = &manager{}
}

View File

@ -1,59 +1,55 @@
package gateway
import (
"context"
"fmt"
"io"
"net"
"next-terminal/server/log"
"golang.org/x/crypto/ssh"
)
type Tunnel struct {
ID string // 唯一标识
LocalHost string // 本地监听地址
LocalPort int // 本地端口
RemoteHost string // 远程连接地址
RemotePort int // 远程端口
Gateway *Gateway
ctx context.Context
cancel context.CancelFunc
id string // 唯一标识
localHost string // 本地监听地址
localPort int // 本地端口
remoteHost string // 远程连接地址
remotePort int // 远程端口
listener net.Listener
localConnections []net.Conn
remoteConnections []net.Conn
}
func (r *Tunnel) Open() {
localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort)
func (r *Tunnel) Open(sshClient *ssh.Client) {
localAddr := fmt.Sprintf("%s:%d", r.localHost, r.localPort)
go func() {
<-r.ctx.Done()
_ = r.listener.Close()
log.Debugf("SSH 隧道 %v 关闭", localAddr)
}()
for {
log.Debugf("等待客户端访问 %v", localAddr)
log.Debugf("隧道 %v 等待客户端访问 %v", r.id, localAddr)
localConn, err := r.listener.Accept()
if err != nil {
log.Debugf("接受连接失败 %v, 退出循环", err.Error())
log.Debugf("隧道 %v 接受连接失败 %v, 退出循环", r.id, err.Error())
log.Debug("-------------------------------------------------")
return
}
r.localConnections = append(r.localConnections, localConn)
log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr)
remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort)
log.Debugf("连接远程主机 %v ...", remoteAddr)
remoteConn, err := r.Gateway.SshClient.Dial("tcp", remoteAddr)
log.Debugf("隧道 %v 新增本地连接 %v", r.id, localConn.RemoteAddr().String())
remoteAddr := fmt.Sprintf("%s:%d", r.remoteHost, r.remotePort)
log.Debugf("隧道 %v 连接远程地址 %v ...", r.id, remoteAddr)
remoteConn, err := sshClient.Dial("tcp", remoteAddr)
if err != nil {
log.Debugf("连接远程主机 %v 失败", remoteAddr)
log.Debugf("隧道 %v 连接远程地址 %v, 退出循环", r.id, err.Error())
log.Debug("-------------------------------------------------")
return
}
r.remoteConnections = append(r.remoteConnections, remoteConn)
log.Debugf("连接远程主机 %v 成功", remoteAddr)
log.Debugf("隧道 %v 连接远程主机成功", r.id)
go copyConn(localConn, remoteConn)
go copyConn(remoteConn, localConn)
log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr)
log.Debugf("隧道 %v 开始转发数据 [%v]->[%v]", r.id, localAddr, remoteAddr)
log.Debug("~~~~~~~~~~~~~~~~~~~~分割线~~~~~~~~~~~~~~~~~~~~~~~~")
}
}
@ -66,7 +62,8 @@ func (r *Tunnel) Close() {
_ = r.remoteConnections[i].Close()
}
r.remoteConnections = nil
r.cancel()
_ = r.listener.Close()
log.Debugf("隧道 %v 监听器关闭", r.id)
}
func copyConn(writer, reader net.Conn) {

View File

@ -116,29 +116,24 @@ func (s assetService) FindByIdAndDecrypt(c context.Context, id string) (model.As
return asset, nil
}
func (s assetService) CheckStatus(accessGatewayId string, ip string, port int) (active bool, err error) {
func (s assetService) CheckStatus(accessGatewayId string, ip string, port int) (bool, error) {
if accessGatewayId != "" && accessGatewayId != "-" {
g, e1 := GatewayService.GetGatewayAndReconnectById(accessGatewayId)
if e1 != nil {
return false, e1
g, err := GatewayService.GetGatewayById(accessGatewayId)
if err != nil {
return false, err
}
uuid := utils.UUID()
exposedIP, exposedPort, e2 := g.OpenSshTunnel(uuid, ip, port)
if e2 != nil {
return false, e2
}
defer g.CloseSshTunnel(uuid)
if g.Connected {
active, err = utils.Tcping(exposedIP, exposedPort)
} else {
active = false
exposedIP, exposedPort, err := g.OpenSshTunnel(uuid, ip, port)
if err != nil {
return false, err
}
} else {
active, err = utils.Tcping(ip, port)
return utils.Tcping(exposedIP, exposedPort)
}
return active, err
return utils.Tcping(ip, port)
}
func (s assetService) Create(ctx context.Context, m echo.Map) (model.Asset, error) {
@ -182,7 +177,7 @@ func (s assetService) create(c context.Context, item model.Asset, m echo.Map) er
// active, _ := s.CheckStatus(item.AccessGatewayId, item.IP, item.Port)
//
// if item.Active != active {
// _ = repository.AssetRepository.UpdateActiveById(context.TODO(), active, item.ID)
// _ = repository.AssetRepository.UpdateActiveById(context.TODO(), active, item.id)
// }
//}()
return nil

View File

@ -7,23 +7,10 @@ import (
"next-terminal/server/log"
"next-terminal/server/model"
"next-terminal/server/repository"
"next-terminal/server/term"
)
type gatewayService struct{}
func (r gatewayService) GetGatewayAndReconnectById(accessGatewayId string) (g *gateway.Gateway, err error) {
g = gateway.GlobalGatewayManager.GetById(accessGatewayId)
if g == nil || !g.Connected {
accessGateway, err := repository.GatewayRepository.FindById(context.TODO(), accessGatewayId)
if err != nil {
return nil, err
}
g = r.ReConnect(&accessGateway)
}
return g, nil
}
func (r gatewayService) GetGatewayById(accessGatewayId string) (g *gateway.Gateway, err error) {
g = gateway.GlobalGatewayManager.GetById(accessGatewayId)
if g == nil {
@ -31,40 +18,32 @@ func (r gatewayService) GetGatewayById(accessGatewayId string) (g *gateway.Gatew
if err != nil {
return nil, err
}
g = r.ReConnect(&accessGateway)
g = r.ReLoad(&accessGateway)
}
return g, nil
}
func (r gatewayService) ReConnectAll() error {
func (r gatewayService) LoadAll() error {
gateways, err := repository.GatewayRepository.FindAll(context.TODO())
if err != nil {
return err
}
if len(gateways) > 0 {
for i := range gateways {
r.ReConnect(&gateways[i])
r.ReLoad(&gateways[i])
}
}
return nil
}
func (r gatewayService) ReConnect(m *model.AccessGateway) *gateway.Gateway {
func (r gatewayService) ReLoad(m *model.AccessGateway) *gateway.Gateway {
log.Debugf("重建接入网关「%v」中...", m.Name)
r.DisconnectById(m.ID)
sshClient, err := term.NewSshClient(m.IP, m.Port, m.Username, m.Password, m.PrivateKey, m.Passphrase)
var g *gateway.Gateway
if err != nil {
g = gateway.NewGateway(m.ID, false, err.Error(), nil)
} else {
g = gateway.NewGateway(m.ID, true, "", sshClient)
}
gateway.GlobalGatewayManager.Add(g)
g := gateway.GlobalGatewayManager.Add(m)
log.Debugf("重建接入网关「%v」完成", m.Name)
return g
}
func (r gatewayService) DisconnectById(accessGatewayId string) {
gateway.GlobalGatewayManager.Del(accessGatewayId)
func (r gatewayService) DisconnectById(id string) {
gateway.GlobalGatewayManager.Del(id)
}

View File

@ -125,16 +125,16 @@ func (r ShellJob) Run() {
func exec(shell, accessGatewayId, ip string, port int, username, password, privateKey, passphrase string) (string, error) {
if accessGatewayId != "" && accessGatewayId != "-" {
g, err := GatewayService.GetGatewayAndReconnectById(accessGatewayId)
g, err := GatewayService.GetGatewayById(accessGatewayId)
if err != nil {
return "", err
}
uuid := utils.UUID()
defer g.CloseSshTunnel(uuid)
exposedIP, exposedPort, err := g.OpenSshTunnel(uuid, ip, port)
if err != nil {
return "", err
}
defer g.CloseSshTunnel(uuid)
return ExecCommandBySSH(shell, exposedIP, exposedPort, username, password, privateKey, passphrase)
} else {
return ExecCommandBySSH(shell, ip, port, username, password, privateKey, passphrase)

View File

@ -61,7 +61,7 @@ func (service userService) InitUser() (err error) {
if err := repository.UserRepository.Update(context.TODO(), &user); err != nil {
return err
}
log.Infof("自动修正用户「%v」ID「%v」类型为管理员", users[i].Nickname, users[i].ID)
log.Infof("自动修正用户「%v」id「%v」类型为管理员", users[i].Nickname, users[i].ID)
}
}
}

View File

@ -189,18 +189,16 @@ func (gui Gui) handleAccessAsset(sess *ssh.Session, sessionId string) (err error
)
if s.AccessGatewayId != "" && s.AccessGatewayId != "-" {
g, err := service.GatewayService.GetGatewayAndReconnectById(s.AccessGatewayId)
g, err := service.GatewayService.GetGatewayById(s.AccessGatewayId)
if err != nil {
return errors.New("获取接入网关失败:" + err.Error())
}
if !g.Connected {
return errors.New("接入网关不可用:" + g.Message)
}
defer g.CloseSshTunnel(s.ID)
exposedIP, exposedPort, err := g.OpenSshTunnel(s.ID, ip, port)
if err != nil {
return errors.New("开启SSH隧道失败" + err.Error())
}
defer g.CloseSshTunnel(s.ID)
ip = exposedIP
port = exposedPort
}

View File

@ -55,9 +55,9 @@ func (t *Ticker) deleteUnUsedSession() {
err := repository.SessionRepository.DeleteById(context.TODO(), sessions[i].ID)
s := sessions[i].Username + "@" + sessions[i].IP + ":" + strconv.Itoa(sessions[i].Port)
if err != nil {
log.Errorf("会话「%v」ID「%v」超过1小时未打开删除失败: %v", s, sessions[i].ID, err.Error())
log.Errorf("会话「%v」id「%v」超过1小时未打开删除失败: %v", s, sessions[i].ID, err.Error())
} else {
log.Infof("会话「%v」ID「%v」超过1小时未打开已删除。", s, sessions[i].ID)
log.Infof("会话「%v」id「%v」超过1小时未打开已删除。", s, sessions[i].ID)
}
}
}

View File

@ -0,0 +1,15 @@
package utils
import "sync"
type KeyedMutex struct {
mutexes sync.Map // Zero value is empty and ready for use
}
func (m *KeyedMutex) Lock(key string) func() {
value, _ := m.mutexes.LoadOrStore(key, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
return func() { mtx.Unlock() }
}

View File

@ -141,29 +141,6 @@ class AccessGateway extends Component {
await this.showModal('更新接入网关', result.data);
}
async reconnect(id, index) {
let items = this.state.items;
try {
items[index]['reconnectLoading'] = true;
this.setState({
items: items
});
message.info({content: '正在重连中...', key: id, duration: 5});
let result = await request.post(`/access-gateways/${id}/reconnect`);
if (result.code !== 1) {
message.error({content: result.message, key: id, duration: 10});
return;
}
message.success({content: '重连完成。', key: id, duration: 3});
this.loadTableData(this.state.queryParams);
} finally {
items[index]['reconnectLoading'] = false;
this.setState({
items: items
});
}
}
showModal(title, obj) {
this.setState({
modalTitle: title,
@ -357,9 +334,6 @@ class AccessGateway extends Component {
<Button type="link" size='small'
onClick={() => this.update(record['id'])}>编辑</Button>
<Button type="link" size='small' loading={this.state.items[index]['reconnectLoading']}
onClick={() => this.reconnect(record['id'], index)}>重连</Button>
<Button type="text" size='small' danger
onClick={() => this.showDeleteConfirm(record.id, record.name)}>删除</Button>