diff --git a/server/api/guacamole.go b/server/api/guacamole.go
index 94c0084..7210ba4 100644
--- a/server/api/guacamole.go
+++ b/server/api/guacamole.go
@@ -139,8 +139,7 @@ func (api GuacamoleApi) Guacamole(c echo.Context) error {
}
nextSession.Observer = session.NewObserver(sessionId)
- session.GlobalSessionManager.Add <- nextSession
- go nextSession.Observer.Start()
+ session.GlobalSessionManager.Add(nextSession)
sess := model.Session{
ConnectionId: guacdTunnel.UUID,
Width: intWidth,
@@ -254,7 +253,7 @@ func (api GuacamoleApi) GuacamoleMonitor(c echo.Context) error {
return nil
}
nextSession.ID = utils.UUID()
- forObsSession.Observer.Add <- nextSession
+ forObsSession.Observer.Add(nextSession)
log.Debugf("[%v:%v] 观察者[%v]加入会话[%v]", sessionId, connectionId, nextSession.ID, s.ConnectionId)
guacamoleHandler := NewGuacamoleHandler(ws, guacdTunnel)
@@ -269,7 +268,7 @@ func (api GuacamoleApi) GuacamoleMonitor(c echo.Context) error {
_ = guacdTunnel.Close()
observerId := nextSession.ID
- forObsSession.Observer.Del <- observerId
+ forObsSession.Observer.Del(observerId)
log.Debugf("[%v:%v] 观察者[%v]退出会话", sessionId, connectionId, observerId)
return nil
}
diff --git a/server/api/security.go b/server/api/security.go
index abed4b9..37d73ee 100644
--- a/server/api/security.go
+++ b/server/api/security.go
@@ -35,7 +35,7 @@ func (api SecurityApi) SecurityCreateEndpoint(c echo.Context) error {
Rule: item.Rule,
Priority: item.Priority,
}
- security.GlobalSecurityManager.Add <- rule
+ security.GlobalSecurityManager.Add(rule)
return Success(c, "")
}
@@ -72,14 +72,14 @@ func (api SecurityApi) SecurityUpdateEndpoint(c echo.Context) error {
return err
}
// 更新内存中的安全规则
- security.GlobalSecurityManager.Del <- id
+ security.GlobalSecurityManager.Del(id)
rule := &security.Security{
ID: item.ID,
IP: item.IP,
Rule: item.Rule,
Priority: item.Priority,
}
- security.GlobalSecurityManager.Add <- rule
+ security.GlobalSecurityManager.Add(rule)
return Success(c, nil)
}
@@ -94,7 +94,7 @@ func (api SecurityApi) SecurityDeleteEndpoint(c echo.Context) error {
return err
}
// 更新内存中的安全规则
- security.GlobalSecurityManager.Del <- id
+ security.GlobalSecurityManager.Del(id)
}
return Success(c, nil)
diff --git a/server/api/term.go b/server/api/term.go
index a09f5ec..9b83215 100644
--- a/server/api/term.go
+++ b/server/api/term.go
@@ -152,7 +152,7 @@ func (api WebTerminalApi) SshEndpoint(c echo.Context) error {
NextTerminal: nextTerminal,
Observer: session.NewObserver(s.ID),
}
- session.GlobalSessionManager.Add <- nextSession
+ session.GlobalSessionManager.Add(nextSession)
termHandler := NewTermHandler(sessionId, isRecording, ws, nextTerminal)
termHandler.Start()
@@ -239,14 +239,14 @@ func (api WebTerminalApi) SshMonitorEndpoint(c echo.Context) error {
Mode: s.Mode,
WebSocket: ws,
}
- nextSession.Observer.Add <- obSession
+ nextSession.Observer.Add(obSession)
log.Debugf("会话 %v 观察者 %v 进入", sessionId, obId)
for {
_, _, err := ws.ReadMessage()
if err != nil {
log.Debugf("会话 %v 观察者 %v 退出", sessionId, obId)
- nextSession.Observer.Del <- obId
+ nextSession.Observer.Del(obId)
break
}
}
diff --git a/server/api/term_handler.go b/server/api/term_handler.go
index 253c713..bb9565d 100644
--- a/server/api/term_handler.go
+++ b/server/api/term_handler.go
@@ -85,11 +85,10 @@ func (r *TermHandler) writeToWebsocket() {
}
nextSession := session.GlobalSessionManager.GetById(r.sessionId)
// 监控
- if nextSession != nil && len(nextSession.Observer.All()) > 0 {
- obs := nextSession.Observer.All()
- for _, ob := range obs {
+ if nextSession != nil && nextSession.Observer != nil {
+ nextSession.Observer.Range(func(key string, ob *session.Session) {
_ = ob.WriteMessage(dto.NewMessage(Data, s))
- }
+ })
}
buf = []byte{}
}
diff --git a/server/global/gateway/gateway.go b/server/global/gateway/gateway.go
index 4d38550..0d8b038 100644
--- a/server/global/gateway/gateway.go
+++ b/server/global/gateway/gateway.go
@@ -8,8 +8,6 @@ import (
"os"
"sync"
- "next-terminal/server/log"
-
"next-terminal/server/utils"
"golang.org/x/crypto/ssh"
@@ -22,11 +20,7 @@ type Gateway struct {
SshClient *ssh.Client
Message string // 失败原因
- tunnels *sync.Map
-
- Add chan *Tunnel
- Del chan string
- exit chan bool
+ tunnels sync.Map
}
func NewGateway(id string, connected bool, message string, client *ssh.Client) *Gateway {
@@ -35,42 +29,14 @@ func NewGateway(id string, connected bool, message string, client *ssh.Client) *
Connected: connected,
Message: message,
SshClient: client,
- Add: make(chan *Tunnel),
- Del: make(chan string),
- tunnels: new(sync.Map),
- exit: make(chan bool, 1),
- }
-}
-
-func (g *Gateway) Run() {
- for {
- select {
- case t := <-g.Add:
- g.tunnels.Store(t.ID, t)
- log.Info("add tunnel: %s", t.ID)
- go t.Open()
- case k := <-g.Del:
- if val, ok := g.tunnels.Load(k); ok {
- if vval, vok := val.(*Tunnel); vok {
- vval.Close()
- g.tunnels.Delete(k)
- }
- }
- case <-g.exit:
- return
- }
}
}
func (g *Gateway) Close() {
g.tunnels.Range(func(key, value interface{}) bool {
- if val, ok := value.(*Tunnel); ok {
- val.Close()
- }
+ g.CloseSshTunnel(key.(string))
return true
})
- g.exit <- true
-
}
func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, exposedPort int, err error) {
@@ -110,11 +76,17 @@ func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, expo
cancel: cancel,
listener: listener,
}
- g.Add <- tunnel
+ go tunnel.Open()
+ g.tunnels.Store(tunnel.ID, tunnel)
return tunnel.LocalHost, tunnel.LocalPort, nil
}
-func (g Gateway) CloseSshTunnel(id string) {
- g.Del <- id
+func (g *Gateway) CloseSshTunnel(id string) {
+ if value, ok := g.tunnels.Load(id); ok {
+ if tunnel, vok := value.(*Tunnel); vok {
+ tunnel.Close()
+ g.tunnels.Delete(id)
+ }
+ }
}
diff --git a/server/global/gateway/manager.go b/server/global/gateway/manager.go
index e2b0245..f208c33 100644
--- a/server/global/gateway/manager.go
+++ b/server/global/gateway/manager.go
@@ -7,49 +7,32 @@ import (
)
type Manager struct {
- gateways *sync.Map
-
- Add chan *Gateway
- Del chan string
+ gateways sync.Map
}
func NewManager() *Manager {
- return &Manager{
- Add: make(chan *Gateway),
- Del: make(chan string),
- gateways: new(sync.Map),
- }
+ return &Manager{}
}
-func (m *Manager) Start() {
- for {
- select {
- case g := <-m.Add:
- m.gateways.Store(g.ID, g)
- log.Info("add gateway: %s", g.ID)
- go g.Run()
- case k := <-m.Del:
- if val, ok := m.gateways.Load(k); ok {
- if vv, vok := val.(*Gateway); vok {
- vv.Close()
- m.gateways.Delete(k)
- }
-
- }
- }
- }
-}
-
-func (m Manager) GetById(id string) *Gateway {
+func (m *Manager) GetById(id string) *Gateway {
if val, ok := m.gateways.Load(id); ok {
return val.(*Gateway)
}
return nil
}
+func (m *Manager) Add(g *Gateway) {
+ m.gateways.Store(g.ID, g)
+ log.Infof("add gateway: %s", g.ID)
+}
+
+func (m *Manager) Del(id string) {
+ m.gateways.Delete(id)
+ log.Infof("del gateway: %s", id)
+}
+
var GlobalGatewayManager *Manager
func init() {
GlobalGatewayManager = NewManager()
- go GlobalGatewayManager.Start()
}
diff --git a/server/global/gateway/tunnel.go b/server/global/gateway/tunnel.go
index 307231c..385180e 100644
--- a/server/global/gateway/tunnel.go
+++ b/server/global/gateway/tunnel.go
@@ -10,23 +10,26 @@ import (
)
type Tunnel struct {
- ID string // 唯一标识
- LocalHost string // 本地监听地址
- LocalPort int // 本地端口
- RemoteHost string // 远程连接地址
- RemotePort int // 远程端口
- Gateway *Gateway
- ctx context.Context
- cancel context.CancelFunc
- listener net.Listener
- err error
+ ID string // 唯一标识
+ LocalHost string // 本地监听地址
+ LocalPort int // 本地端口
+ RemoteHost string // 远程连接地址
+ RemotePort int // 远程端口
+ Gateway *Gateway
+ ctx context.Context
+ cancel context.CancelFunc
+ listener net.Listener
+ localConnections []net.Conn
+ remoteConnections []net.Conn
}
func (r *Tunnel) Open() {
localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort)
+
go func() {
<-r.ctx.Done()
_ = r.listener.Close()
+ log.Debugf("SSH 隧道 %v 关闭", localAddr)
}()
for {
log.Debugf("等待客户端访问 %v", localAddr)
@@ -35,6 +38,7 @@ func (r *Tunnel) Open() {
log.Debugf("接受连接失败 %v, 退出循环", err.Error())
return
}
+ r.localConnections = append(r.localConnections, localConn)
log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr)
remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort)
@@ -44,27 +48,27 @@ func (r *Tunnel) Open() {
log.Debugf("连接远程主机 %v 失败", remoteAddr)
return
}
+ r.remoteConnections = append(r.remoteConnections, remoteConn)
log.Debugf("连接远程主机 %v 成功", remoteAddr)
- go copyConn(r.ctx, localConn, remoteConn)
- go copyConn(r.ctx, remoteConn, localConn)
+ go copyConn(localConn, remoteConn)
+ go copyConn(remoteConn, localConn)
log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr)
}
}
-func (r Tunnel) Close() {
+func (r *Tunnel) Close() {
+ for i := range r.localConnections {
+ _ = r.localConnections[i].Close()
+ }
+ r.localConnections = nil
+ for i := range r.remoteConnections {
+ _ = r.remoteConnections[i].Close()
+ }
+ r.remoteConnections = nil
r.cancel()
}
-func copyConn(ctx context.Context, writer, reader net.Conn) {
+func copyConn(writer, reader net.Conn) {
_, _ = io.Copy(writer, reader)
- for {
- select {
- case <-ctx.Done():
- _ = writer.Close()
- _ = reader.Close()
- return
- }
- }
-
}
diff --git a/server/global/security/security.go b/server/global/security/security.go
index 9831ed1..86b62a7 100644
--- a/server/global/security/security.go
+++ b/server/global/security/security.go
@@ -1,6 +1,11 @@
package security
-import "sort"
+import (
+ "sort"
+ "sync"
+
+ "next-terminal/server/log"
+)
type Security struct {
ID string
@@ -10,45 +15,29 @@ type Security struct {
}
type Manager struct {
- securities map[string]*Security
+ securities sync.Map
values []*Security
-
- Add chan *Security
- Del chan string
}
func NewManager() *Manager {
- return &Manager{
- Add: make(chan *Security),
- Del: make(chan string),
- securities: map[string]*Security{},
- }
-}
-
-func (m *Manager) Start() {
- for {
- select {
- case s := <-m.Add:
- m.securities[s.ID] = s
- m.LoadData()
- case s := <-m.Del:
- if _, ok := m.securities[s]; ok {
- delete(m.securities, s)
- m.LoadData()
- }
- }
- }
+ return &Manager{}
}
func (m *Manager) Clear() {
- m.securities = map[string]*Security{}
+ m.securities.Range(func(k, _ interface{}) bool {
+ m.securities.Delete(k)
+ return true
+ })
}
func (m *Manager) LoadData() {
var values []*Security
- for _, security := range m.securities {
- values = append(values, security)
- }
+ m.securities.Range(func(key, value interface{}) bool {
+ if security, ok := value.(*Security); ok {
+ values = append(values, security)
+ }
+ return true
+ })
sort.Slice(values, func(i, j int) bool {
// 优先级数字越小代表优先级越高,因此此处用小于号
@@ -58,13 +47,24 @@ func (m *Manager) LoadData() {
m.values = values
}
-func (m Manager) Values() []*Security {
+func (m *Manager) Values() []*Security {
return m.values
}
+func (m *Manager) Add(s *Security) {
+ m.securities.Store(s.ID, s)
+ m.LoadData()
+ log.Infof("add security: %s", s.ID)
+}
+
+func (m *Manager) Del(id string) {
+ m.securities.Delete(id)
+ m.LoadData()
+ log.Infof("del security: %s", id)
+}
+
var GlobalSecurityManager *Manager
func init() {
GlobalSecurityManager = NewManager()
- go GlobalSecurityManager.Start()
}
diff --git a/server/global/session/session.go b/server/global/session/session.go
index e52aa29..fb77828 100644
--- a/server/global/session/session.go
+++ b/server/global/session/session.go
@@ -1,11 +1,11 @@
package session
import (
- "fmt"
"sync"
"next-terminal/server/dto"
"next-terminal/server/guacd"
+ "next-terminal/server/log"
"next-terminal/server/term"
"github.com/gorilla/websocket"
@@ -42,80 +42,79 @@ func (s *Session) WriteString(str string) error {
return s.WebSocket.WriteMessage(websocket.TextMessage, message)
}
+func (s *Session) Close() {
+ if s.GuacdTunnel != nil {
+ _ = s.GuacdTunnel.Close()
+ }
+ if s.NextTerminal != nil {
+ s.NextTerminal.Close()
+ }
+ if s.WebSocket != nil {
+ _ = s.WebSocket.Close()
+ }
+}
+
type Manager struct {
id string
- sessions map[string]*Session
-
- Add chan *Session
- Del chan string
- exit chan bool
+ sessions sync.Map
}
func NewManager() *Manager {
- return &Manager{
- Add: make(chan *Session),
- Del: make(chan string),
- sessions: map[string]*Session{},
- exit: make(chan bool, 1),
- }
+ return &Manager{}
}
func NewObserver(id string) *Manager {
return &Manager{
- id: id,
- Add: make(chan *Session),
- Del: make(chan string),
- sessions: map[string]*Session{},
- exit: make(chan bool, 1),
+ id: id,
}
}
-func (m *Manager) Start() {
- defer fmt.Printf("Session Manager %v End\n", m.id)
- fmt.Printf("Session Manager %v Open\n", m.id)
- for {
- select {
- case s := <-m.Add:
- m.sessions[s.ID] = s
- case k := <-m.Del:
- if _, ok := m.sessions[k]; ok {
- ss := m.sessions[k]
- if ss.GuacdTunnel != nil {
- _ = ss.GuacdTunnel.Close()
- }
- if ss.NextTerminal != nil {
- ss.NextTerminal.Close()
- }
+func (m *Manager) GetById(id string) *Session {
+ value, ok := m.sessions.Load(id)
+ if ok {
+ return value.(*Session)
+ }
+ return nil
+}
- if ss.WebSocket != nil {
- _ = ss.WebSocket.Close()
- }
- if ss.Observer != nil {
- ss.Observer.Close()
- }
- delete(m.sessions, k)
- }
- case <-m.exit:
- return
+func (m *Manager) Add(s *Session) {
+ m.sessions.Store(s.ID, s)
+ log.Infof("add session: %s", s.ID)
+}
+
+func (m *Manager) Del(id string) {
+ session := m.GetById(id)
+ if session != nil {
+ session.Close()
+ if session.Observer != nil {
+ session.Observer.Clear()
}
}
+ m.sessions.Delete(id)
+ log.Infof("del session: %s", id)
}
-func (m *Manager) Close() {
- m.exit <- true
+func (m *Manager) Clear() {
+ m.sessions.Range(func(key, value interface{}) bool {
+ if session, ok := value.(*Session); ok {
+ session.Close()
+ }
+ m.sessions.Delete(key)
+ return true
+ })
}
-func (m Manager) GetById(id string) *Session {
- return m.sessions[id]
-}
-
-func (m Manager) All() map[string]*Session {
- return m.sessions
+func (m *Manager) Range(f func(key string, value *Session)) {
+ m.sessions.Range(func(key, value interface{}) bool {
+ if session, ok := value.(*Session); ok {
+ f(key.(string), session)
+ }
+ return true
+ })
}
var GlobalSessionManager *Manager
func init() {
GlobalSessionManager = NewManager()
- go GlobalSessionManager.Start()
}
diff --git a/server/service/backup.go b/server/service/backup.go
index 6be6b42..8528d3a 100644
--- a/server/service/backup.go
+++ b/server/service/backup.go
@@ -220,7 +220,7 @@ func (service backupService) Import(backup *dto.Backup) error {
Rule: item.Rule,
Priority: item.Priority,
}
- security.GlobalSecurityManager.Add <- rule
+ security.GlobalSecurityManager.Add(rule)
}
}
diff --git a/server/service/gateway.go b/server/service/gateway.go
index 2a9718c..7830b5f 100644
--- a/server/service/gateway.go
+++ b/server/service/gateway.go
@@ -60,11 +60,11 @@ func (r gatewayService) ReConnect(m *model.AccessGateway) *gateway.Gateway {
} else {
g = gateway.NewGateway(m.ID, true, "", sshClient)
}
- gateway.GlobalGatewayManager.Add <- g
+ gateway.GlobalGatewayManager.Add(g)
log.Debugf("重建接入网关「%v」完成", m.Name)
return g
}
func (r gatewayService) DisconnectById(accessGatewayId string) {
- gateway.GlobalGatewayManager.Del <- accessGatewayId
+ gateway.GlobalGatewayManager.Del(accessGatewayId)
}
diff --git a/server/service/security.go b/server/service/security.go
index 18de5cd..821b3cd 100644
--- a/server/service/security.go
+++ b/server/service/security.go
@@ -25,7 +25,7 @@ func (service securityService) ReloadAccessSecurity() error {
Rule: rules[i].Rule,
Priority: rules[i].Priority,
}
- security.GlobalSecurityManager.Add <- rule
+ security.GlobalSecurityManager.Add(rule)
}
}
return nil
diff --git a/server/service/session.go b/server/service/session.go
index 614a5e8..c3cf10e 100644
--- a/server/service/session.go
+++ b/server/service/session.go
@@ -96,14 +96,13 @@ func (service sessionService) CloseSessionById(sessionId string, code int, reaso
service.WriteCloseMessage(nextSession, nextSession.Mode, code, reason)
if nextSession.Observer != nil {
- obs := nextSession.Observer.All()
- for _, ob := range obs {
+ nextSession.Observer.Range(func(key string, ob *session.Session) {
service.WriteCloseMessage(ob, ob.Mode, code, reason)
log.Debugf("[%v] 强制踢出会话的观察者: %v", sessionId, ob.ID)
- }
+ })
}
}
- session.GlobalSessionManager.Del <- sessionId
+ session.GlobalSessionManager.Del(sessionId)
service.DisDBSess(sessionId, code, reason)
}
diff --git a/server/sshd/ui.go b/server/sshd/ui.go
index 6596798..2ec6c0e 100644
--- a/server/sshd/ui.go
+++ b/server/sshd/ui.go
@@ -271,8 +271,7 @@ func (gui Gui) handleAccessAsset(sess *ssh.Session, sessionId string) (err error
NextTerminal: nextTerminal,
Observer: session.NewObserver(s.ID),
}
- go nextSession.Observer.Start()
- session.GlobalSessionManager.Add <- nextSession
+ session.GlobalSessionManager.Add(nextSession)
if err := sshSession.Wait(); err != nil {
return err
diff --git a/server/sshd/writer.go b/server/sshd/writer.go
index 4d72bb4..06c07fb 100644
--- a/server/sshd/writer.go
+++ b/server/sshd/writer.go
@@ -66,12 +66,9 @@ func (w *Writer) Write(p []byte) (n int, err error) {
func sendObData(sessionId, s string) {
nextSession := session.GlobalSessionManager.GetById(sessionId)
- if nextSession != nil {
- if nextSession.Observer != nil {
- obs := nextSession.Observer.All()
- for _, ob := range obs {
- _ = ob.WriteMessage(dto.NewMessage(api.Data, s))
- }
- }
+ if nextSession != nil && nextSession.Observer != nil {
+ nextSession.Observer.Range(func(key string, ob *session.Session) {
+ _ = ob.WriteMessage(dto.NewMessage(api.Data, s))
+ })
}
}
diff --git a/web/src/components/asset/AssetModal.js b/web/src/components/asset/AssetModal.js
index 1ffb257..b2622f9 100644
--- a/web/src/components/asset/AssetModal.js
+++ b/web/src/components/asset/AssetModal.js
@@ -471,7 +471,7 @@ Windows需要对远程应用程序的名称使用特殊的符号。
setSshMode(value)
}}>
-
+