From 7357cebc34b83de0d8b29c865987452d8893b52b Mon Sep 17 00:00:00 2001 From: dushixiang Date: Fri, 6 May 2022 21:10:29 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/api/guacamole.go | 7 +- server/api/security.go | 8 +- server/api/term.go | 6 +- server/api/term_handler.go | 7 +- server/global/gateway/gateway.go | 50 +++--------- server/global/gateway/manager.go | 43 ++++------- server/global/gateway/tunnel.go | 50 ++++++------ server/global/security/security.go | 62 +++++++-------- server/global/session/session.go | 103 ++++++++++++------------- server/service/backup.go | 2 +- server/service/gateway.go | 4 +- server/service/security.go | 2 +- server/service/session.go | 7 +- server/sshd/ui.go | 3 +- server/sshd/writer.go | 11 +-- web/src/components/asset/AssetModal.js | 2 +- 16 files changed, 159 insertions(+), 208 deletions(-) diff --git a/server/api/guacamole.go b/server/api/guacamole.go index 94c0084..7210ba4 100644 --- a/server/api/guacamole.go +++ b/server/api/guacamole.go @@ -139,8 +139,7 @@ func (api GuacamoleApi) Guacamole(c echo.Context) error { } nextSession.Observer = session.NewObserver(sessionId) - session.GlobalSessionManager.Add <- nextSession - go nextSession.Observer.Start() + session.GlobalSessionManager.Add(nextSession) sess := model.Session{ ConnectionId: guacdTunnel.UUID, Width: intWidth, @@ -254,7 +253,7 @@ func (api GuacamoleApi) GuacamoleMonitor(c echo.Context) error { return nil } nextSession.ID = utils.UUID() - forObsSession.Observer.Add <- nextSession + forObsSession.Observer.Add(nextSession) log.Debugf("[%v:%v] 观察者[%v]加入会话[%v]", sessionId, connectionId, nextSession.ID, s.ConnectionId) guacamoleHandler := NewGuacamoleHandler(ws, guacdTunnel) @@ -269,7 +268,7 @@ func (api GuacamoleApi) GuacamoleMonitor(c echo.Context) error { _ = guacdTunnel.Close() observerId := nextSession.ID - forObsSession.Observer.Del <- observerId + forObsSession.Observer.Del(observerId) log.Debugf("[%v:%v] 观察者[%v]退出会话", sessionId, connectionId, observerId) return nil } diff --git a/server/api/security.go b/server/api/security.go index abed4b9..37d73ee 100644 --- a/server/api/security.go +++ b/server/api/security.go @@ -35,7 +35,7 @@ func (api SecurityApi) SecurityCreateEndpoint(c echo.Context) error { Rule: item.Rule, Priority: item.Priority, } - security.GlobalSecurityManager.Add <- rule + security.GlobalSecurityManager.Add(rule) return Success(c, "") } @@ -72,14 +72,14 @@ func (api SecurityApi) SecurityUpdateEndpoint(c echo.Context) error { return err } // 更新内存中的安全规则 - security.GlobalSecurityManager.Del <- id + security.GlobalSecurityManager.Del(id) rule := &security.Security{ ID: item.ID, IP: item.IP, Rule: item.Rule, Priority: item.Priority, } - security.GlobalSecurityManager.Add <- rule + security.GlobalSecurityManager.Add(rule) return Success(c, nil) } @@ -94,7 +94,7 @@ func (api SecurityApi) SecurityDeleteEndpoint(c echo.Context) error { return err } // 更新内存中的安全规则 - security.GlobalSecurityManager.Del <- id + security.GlobalSecurityManager.Del(id) } return Success(c, nil) diff --git a/server/api/term.go b/server/api/term.go index a09f5ec..9b83215 100644 --- a/server/api/term.go +++ b/server/api/term.go @@ -152,7 +152,7 @@ func (api WebTerminalApi) SshEndpoint(c echo.Context) error { NextTerminal: nextTerminal, Observer: session.NewObserver(s.ID), } - session.GlobalSessionManager.Add <- nextSession + session.GlobalSessionManager.Add(nextSession) termHandler := NewTermHandler(sessionId, isRecording, ws, nextTerminal) termHandler.Start() @@ -239,14 +239,14 @@ func (api WebTerminalApi) SshMonitorEndpoint(c echo.Context) error { Mode: s.Mode, WebSocket: ws, } - nextSession.Observer.Add <- obSession + nextSession.Observer.Add(obSession) log.Debugf("会话 %v 观察者 %v 进入", sessionId, obId) for { _, _, err := ws.ReadMessage() if err != nil { log.Debugf("会话 %v 观察者 %v 退出", sessionId, obId) - nextSession.Observer.Del <- obId + nextSession.Observer.Del(obId) break } } diff --git a/server/api/term_handler.go b/server/api/term_handler.go index 253c713..bb9565d 100644 --- a/server/api/term_handler.go +++ b/server/api/term_handler.go @@ -85,11 +85,10 @@ func (r *TermHandler) writeToWebsocket() { } nextSession := session.GlobalSessionManager.GetById(r.sessionId) // 监控 - if nextSession != nil && len(nextSession.Observer.All()) > 0 { - obs := nextSession.Observer.All() - for _, ob := range obs { + if nextSession != nil && nextSession.Observer != nil { + nextSession.Observer.Range(func(key string, ob *session.Session) { _ = ob.WriteMessage(dto.NewMessage(Data, s)) - } + }) } buf = []byte{} } diff --git a/server/global/gateway/gateway.go b/server/global/gateway/gateway.go index 4d38550..0d8b038 100644 --- a/server/global/gateway/gateway.go +++ b/server/global/gateway/gateway.go @@ -8,8 +8,6 @@ import ( "os" "sync" - "next-terminal/server/log" - "next-terminal/server/utils" "golang.org/x/crypto/ssh" @@ -22,11 +20,7 @@ type Gateway struct { SshClient *ssh.Client Message string // 失败原因 - tunnels *sync.Map - - Add chan *Tunnel - Del chan string - exit chan bool + tunnels sync.Map } func NewGateway(id string, connected bool, message string, client *ssh.Client) *Gateway { @@ -35,42 +29,14 @@ func NewGateway(id string, connected bool, message string, client *ssh.Client) * Connected: connected, Message: message, SshClient: client, - Add: make(chan *Tunnel), - Del: make(chan string), - tunnels: new(sync.Map), - exit: make(chan bool, 1), - } -} - -func (g *Gateway) Run() { - for { - select { - case t := <-g.Add: - g.tunnels.Store(t.ID, t) - log.Info("add tunnel: %s", t.ID) - go t.Open() - case k := <-g.Del: - if val, ok := g.tunnels.Load(k); ok { - if vval, vok := val.(*Tunnel); vok { - vval.Close() - g.tunnels.Delete(k) - } - } - case <-g.exit: - return - } } } func (g *Gateway) Close() { g.tunnels.Range(func(key, value interface{}) bool { - if val, ok := value.(*Tunnel); ok { - val.Close() - } + g.CloseSshTunnel(key.(string)) return true }) - g.exit <- true - } func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, exposedPort int, err error) { @@ -110,11 +76,17 @@ func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, expo cancel: cancel, listener: listener, } - g.Add <- tunnel + go tunnel.Open() + g.tunnels.Store(tunnel.ID, tunnel) return tunnel.LocalHost, tunnel.LocalPort, nil } -func (g Gateway) CloseSshTunnel(id string) { - g.Del <- id +func (g *Gateway) CloseSshTunnel(id string) { + if value, ok := g.tunnels.Load(id); ok { + if tunnel, vok := value.(*Tunnel); vok { + tunnel.Close() + g.tunnels.Delete(id) + } + } } diff --git a/server/global/gateway/manager.go b/server/global/gateway/manager.go index e2b0245..f208c33 100644 --- a/server/global/gateway/manager.go +++ b/server/global/gateway/manager.go @@ -7,49 +7,32 @@ import ( ) type Manager struct { - gateways *sync.Map - - Add chan *Gateway - Del chan string + gateways sync.Map } func NewManager() *Manager { - return &Manager{ - Add: make(chan *Gateway), - Del: make(chan string), - gateways: new(sync.Map), - } + return &Manager{} } -func (m *Manager) Start() { - for { - select { - case g := <-m.Add: - m.gateways.Store(g.ID, g) - log.Info("add gateway: %s", g.ID) - go g.Run() - case k := <-m.Del: - if val, ok := m.gateways.Load(k); ok { - if vv, vok := val.(*Gateway); vok { - vv.Close() - m.gateways.Delete(k) - } - - } - } - } -} - -func (m Manager) GetById(id string) *Gateway { +func (m *Manager) GetById(id string) *Gateway { if val, ok := m.gateways.Load(id); ok { return val.(*Gateway) } return nil } +func (m *Manager) Add(g *Gateway) { + m.gateways.Store(g.ID, g) + log.Infof("add gateway: %s", g.ID) +} + +func (m *Manager) Del(id string) { + m.gateways.Delete(id) + log.Infof("del gateway: %s", id) +} + var GlobalGatewayManager *Manager func init() { GlobalGatewayManager = NewManager() - go GlobalGatewayManager.Start() } diff --git a/server/global/gateway/tunnel.go b/server/global/gateway/tunnel.go index 307231c..385180e 100644 --- a/server/global/gateway/tunnel.go +++ b/server/global/gateway/tunnel.go @@ -10,23 +10,26 @@ import ( ) type Tunnel struct { - ID string // 唯一标识 - LocalHost string // 本地监听地址 - LocalPort int // 本地端口 - RemoteHost string // 远程连接地址 - RemotePort int // 远程端口 - Gateway *Gateway - ctx context.Context - cancel context.CancelFunc - listener net.Listener - err error + ID string // 唯一标识 + LocalHost string // 本地监听地址 + LocalPort int // 本地端口 + RemoteHost string // 远程连接地址 + RemotePort int // 远程端口 + Gateway *Gateway + ctx context.Context + cancel context.CancelFunc + listener net.Listener + localConnections []net.Conn + remoteConnections []net.Conn } func (r *Tunnel) Open() { localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort) + go func() { <-r.ctx.Done() _ = r.listener.Close() + log.Debugf("SSH 隧道 %v 关闭", localAddr) }() for { log.Debugf("等待客户端访问 %v", localAddr) @@ -35,6 +38,7 @@ func (r *Tunnel) Open() { log.Debugf("接受连接失败 %v, 退出循环", err.Error()) return } + r.localConnections = append(r.localConnections, localConn) log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr) remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort) @@ -44,27 +48,27 @@ func (r *Tunnel) Open() { log.Debugf("连接远程主机 %v 失败", remoteAddr) return } + r.remoteConnections = append(r.remoteConnections, remoteConn) log.Debugf("连接远程主机 %v 成功", remoteAddr) - go copyConn(r.ctx, localConn, remoteConn) - go copyConn(r.ctx, remoteConn, localConn) + go copyConn(localConn, remoteConn) + go copyConn(remoteConn, localConn) log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr) } } -func (r Tunnel) Close() { +func (r *Tunnel) Close() { + for i := range r.localConnections { + _ = r.localConnections[i].Close() + } + r.localConnections = nil + for i := range r.remoteConnections { + _ = r.remoteConnections[i].Close() + } + r.remoteConnections = nil r.cancel() } -func copyConn(ctx context.Context, writer, reader net.Conn) { +func copyConn(writer, reader net.Conn) { _, _ = io.Copy(writer, reader) - for { - select { - case <-ctx.Done(): - _ = writer.Close() - _ = reader.Close() - return - } - } - } diff --git a/server/global/security/security.go b/server/global/security/security.go index 9831ed1..86b62a7 100644 --- a/server/global/security/security.go +++ b/server/global/security/security.go @@ -1,6 +1,11 @@ package security -import "sort" +import ( + "sort" + "sync" + + "next-terminal/server/log" +) type Security struct { ID string @@ -10,45 +15,29 @@ type Security struct { } type Manager struct { - securities map[string]*Security + securities sync.Map values []*Security - - Add chan *Security - Del chan string } func NewManager() *Manager { - return &Manager{ - Add: make(chan *Security), - Del: make(chan string), - securities: map[string]*Security{}, - } -} - -func (m *Manager) Start() { - for { - select { - case s := <-m.Add: - m.securities[s.ID] = s - m.LoadData() - case s := <-m.Del: - if _, ok := m.securities[s]; ok { - delete(m.securities, s) - m.LoadData() - } - } - } + return &Manager{} } func (m *Manager) Clear() { - m.securities = map[string]*Security{} + m.securities.Range(func(k, _ interface{}) bool { + m.securities.Delete(k) + return true + }) } func (m *Manager) LoadData() { var values []*Security - for _, security := range m.securities { - values = append(values, security) - } + m.securities.Range(func(key, value interface{}) bool { + if security, ok := value.(*Security); ok { + values = append(values, security) + } + return true + }) sort.Slice(values, func(i, j int) bool { // 优先级数字越小代表优先级越高,因此此处用小于号 @@ -58,13 +47,24 @@ func (m *Manager) LoadData() { m.values = values } -func (m Manager) Values() []*Security { +func (m *Manager) Values() []*Security { return m.values } +func (m *Manager) Add(s *Security) { + m.securities.Store(s.ID, s) + m.LoadData() + log.Infof("add security: %s", s.ID) +} + +func (m *Manager) Del(id string) { + m.securities.Delete(id) + m.LoadData() + log.Infof("del security: %s", id) +} + var GlobalSecurityManager *Manager func init() { GlobalSecurityManager = NewManager() - go GlobalSecurityManager.Start() } diff --git a/server/global/session/session.go b/server/global/session/session.go index e52aa29..fb77828 100644 --- a/server/global/session/session.go +++ b/server/global/session/session.go @@ -1,11 +1,11 @@ package session import ( - "fmt" "sync" "next-terminal/server/dto" "next-terminal/server/guacd" + "next-terminal/server/log" "next-terminal/server/term" "github.com/gorilla/websocket" @@ -42,80 +42,79 @@ func (s *Session) WriteString(str string) error { return s.WebSocket.WriteMessage(websocket.TextMessage, message) } +func (s *Session) Close() { + if s.GuacdTunnel != nil { + _ = s.GuacdTunnel.Close() + } + if s.NextTerminal != nil { + s.NextTerminal.Close() + } + if s.WebSocket != nil { + _ = s.WebSocket.Close() + } +} + type Manager struct { id string - sessions map[string]*Session - - Add chan *Session - Del chan string - exit chan bool + sessions sync.Map } func NewManager() *Manager { - return &Manager{ - Add: make(chan *Session), - Del: make(chan string), - sessions: map[string]*Session{}, - exit: make(chan bool, 1), - } + return &Manager{} } func NewObserver(id string) *Manager { return &Manager{ - id: id, - Add: make(chan *Session), - Del: make(chan string), - sessions: map[string]*Session{}, - exit: make(chan bool, 1), + id: id, } } -func (m *Manager) Start() { - defer fmt.Printf("Session Manager %v End\n", m.id) - fmt.Printf("Session Manager %v Open\n", m.id) - for { - select { - case s := <-m.Add: - m.sessions[s.ID] = s - case k := <-m.Del: - if _, ok := m.sessions[k]; ok { - ss := m.sessions[k] - if ss.GuacdTunnel != nil { - _ = ss.GuacdTunnel.Close() - } - if ss.NextTerminal != nil { - ss.NextTerminal.Close() - } +func (m *Manager) GetById(id string) *Session { + value, ok := m.sessions.Load(id) + if ok { + return value.(*Session) + } + return nil +} - if ss.WebSocket != nil { - _ = ss.WebSocket.Close() - } - if ss.Observer != nil { - ss.Observer.Close() - } - delete(m.sessions, k) - } - case <-m.exit: - return +func (m *Manager) Add(s *Session) { + m.sessions.Store(s.ID, s) + log.Infof("add session: %s", s.ID) +} + +func (m *Manager) Del(id string) { + session := m.GetById(id) + if session != nil { + session.Close() + if session.Observer != nil { + session.Observer.Clear() } } + m.sessions.Delete(id) + log.Infof("del session: %s", id) } -func (m *Manager) Close() { - m.exit <- true +func (m *Manager) Clear() { + m.sessions.Range(func(key, value interface{}) bool { + if session, ok := value.(*Session); ok { + session.Close() + } + m.sessions.Delete(key) + return true + }) } -func (m Manager) GetById(id string) *Session { - return m.sessions[id] -} - -func (m Manager) All() map[string]*Session { - return m.sessions +func (m *Manager) Range(f func(key string, value *Session)) { + m.sessions.Range(func(key, value interface{}) bool { + if session, ok := value.(*Session); ok { + f(key.(string), session) + } + return true + }) } var GlobalSessionManager *Manager func init() { GlobalSessionManager = NewManager() - go GlobalSessionManager.Start() } diff --git a/server/service/backup.go b/server/service/backup.go index 6be6b42..8528d3a 100644 --- a/server/service/backup.go +++ b/server/service/backup.go @@ -220,7 +220,7 @@ func (service backupService) Import(backup *dto.Backup) error { Rule: item.Rule, Priority: item.Priority, } - security.GlobalSecurityManager.Add <- rule + security.GlobalSecurityManager.Add(rule) } } diff --git a/server/service/gateway.go b/server/service/gateway.go index 2a9718c..7830b5f 100644 --- a/server/service/gateway.go +++ b/server/service/gateway.go @@ -60,11 +60,11 @@ func (r gatewayService) ReConnect(m *model.AccessGateway) *gateway.Gateway { } else { g = gateway.NewGateway(m.ID, true, "", sshClient) } - gateway.GlobalGatewayManager.Add <- g + gateway.GlobalGatewayManager.Add(g) log.Debugf("重建接入网关「%v」完成", m.Name) return g } func (r gatewayService) DisconnectById(accessGatewayId string) { - gateway.GlobalGatewayManager.Del <- accessGatewayId + gateway.GlobalGatewayManager.Del(accessGatewayId) } diff --git a/server/service/security.go b/server/service/security.go index 18de5cd..821b3cd 100644 --- a/server/service/security.go +++ b/server/service/security.go @@ -25,7 +25,7 @@ func (service securityService) ReloadAccessSecurity() error { Rule: rules[i].Rule, Priority: rules[i].Priority, } - security.GlobalSecurityManager.Add <- rule + security.GlobalSecurityManager.Add(rule) } } return nil diff --git a/server/service/session.go b/server/service/session.go index 614a5e8..c3cf10e 100644 --- a/server/service/session.go +++ b/server/service/session.go @@ -96,14 +96,13 @@ func (service sessionService) CloseSessionById(sessionId string, code int, reaso service.WriteCloseMessage(nextSession, nextSession.Mode, code, reason) if nextSession.Observer != nil { - obs := nextSession.Observer.All() - for _, ob := range obs { + nextSession.Observer.Range(func(key string, ob *session.Session) { service.WriteCloseMessage(ob, ob.Mode, code, reason) log.Debugf("[%v] 强制踢出会话的观察者: %v", sessionId, ob.ID) - } + }) } } - session.GlobalSessionManager.Del <- sessionId + session.GlobalSessionManager.Del(sessionId) service.DisDBSess(sessionId, code, reason) } diff --git a/server/sshd/ui.go b/server/sshd/ui.go index 6596798..2ec6c0e 100644 --- a/server/sshd/ui.go +++ b/server/sshd/ui.go @@ -271,8 +271,7 @@ func (gui Gui) handleAccessAsset(sess *ssh.Session, sessionId string) (err error NextTerminal: nextTerminal, Observer: session.NewObserver(s.ID), } - go nextSession.Observer.Start() - session.GlobalSessionManager.Add <- nextSession + session.GlobalSessionManager.Add(nextSession) if err := sshSession.Wait(); err != nil { return err diff --git a/server/sshd/writer.go b/server/sshd/writer.go index 4d72bb4..06c07fb 100644 --- a/server/sshd/writer.go +++ b/server/sshd/writer.go @@ -66,12 +66,9 @@ func (w *Writer) Write(p []byte) (n int, err error) { func sendObData(sessionId, s string) { nextSession := session.GlobalSessionManager.GetById(sessionId) - if nextSession != nil { - if nextSession.Observer != nil { - obs := nextSession.Observer.All() - for _, ob := range obs { - _ = ob.WriteMessage(dto.NewMessage(api.Data, s)) - } - } + if nextSession != nil && nextSession.Observer != nil { + nextSession.Observer.Range(func(key string, ob *session.Session) { + _ = ob.WriteMessage(dto.NewMessage(api.Data, s)) + }) } } diff --git a/web/src/components/asset/AssetModal.js b/web/src/components/asset/AssetModal.js index 1ffb257..b2622f9 100644 --- a/web/src/components/asset/AssetModal.js +++ b/web/src/components/asset/AssetModal.js @@ -471,7 +471,7 @@ Windows需要对远程应用程序的名称使用特殊的符号。 setSshMode(value) }}> - +