优化代码

This commit is contained in:
dushixiang 2022-05-06 21:10:29 +08:00
parent c58776fa24
commit 7357cebc34
16 changed files with 159 additions and 208 deletions

View File

@ -139,8 +139,7 @@ func (api GuacamoleApi) Guacamole(c echo.Context) error {
} }
nextSession.Observer = session.NewObserver(sessionId) nextSession.Observer = session.NewObserver(sessionId)
session.GlobalSessionManager.Add <- nextSession session.GlobalSessionManager.Add(nextSession)
go nextSession.Observer.Start()
sess := model.Session{ sess := model.Session{
ConnectionId: guacdTunnel.UUID, ConnectionId: guacdTunnel.UUID,
Width: intWidth, Width: intWidth,
@ -254,7 +253,7 @@ func (api GuacamoleApi) GuacamoleMonitor(c echo.Context) error {
return nil return nil
} }
nextSession.ID = utils.UUID() nextSession.ID = utils.UUID()
forObsSession.Observer.Add <- nextSession forObsSession.Observer.Add(nextSession)
log.Debugf("[%v:%v] 观察者[%v]加入会话[%v]", sessionId, connectionId, nextSession.ID, s.ConnectionId) log.Debugf("[%v:%v] 观察者[%v]加入会话[%v]", sessionId, connectionId, nextSession.ID, s.ConnectionId)
guacamoleHandler := NewGuacamoleHandler(ws, guacdTunnel) guacamoleHandler := NewGuacamoleHandler(ws, guacdTunnel)
@ -269,7 +268,7 @@ func (api GuacamoleApi) GuacamoleMonitor(c echo.Context) error {
_ = guacdTunnel.Close() _ = guacdTunnel.Close()
observerId := nextSession.ID observerId := nextSession.ID
forObsSession.Observer.Del <- observerId forObsSession.Observer.Del(observerId)
log.Debugf("[%v:%v] 观察者[%v]退出会话", sessionId, connectionId, observerId) log.Debugf("[%v:%v] 观察者[%v]退出会话", sessionId, connectionId, observerId)
return nil return nil
} }

View File

@ -35,7 +35,7 @@ func (api SecurityApi) SecurityCreateEndpoint(c echo.Context) error {
Rule: item.Rule, Rule: item.Rule,
Priority: item.Priority, Priority: item.Priority,
} }
security.GlobalSecurityManager.Add <- rule security.GlobalSecurityManager.Add(rule)
return Success(c, "") return Success(c, "")
} }
@ -72,14 +72,14 @@ func (api SecurityApi) SecurityUpdateEndpoint(c echo.Context) error {
return err return err
} }
// 更新内存中的安全规则 // 更新内存中的安全规则
security.GlobalSecurityManager.Del <- id security.GlobalSecurityManager.Del(id)
rule := &security.Security{ rule := &security.Security{
ID: item.ID, ID: item.ID,
IP: item.IP, IP: item.IP,
Rule: item.Rule, Rule: item.Rule,
Priority: item.Priority, Priority: item.Priority,
} }
security.GlobalSecurityManager.Add <- rule security.GlobalSecurityManager.Add(rule)
return Success(c, nil) return Success(c, nil)
} }
@ -94,7 +94,7 @@ func (api SecurityApi) SecurityDeleteEndpoint(c echo.Context) error {
return err return err
} }
// 更新内存中的安全规则 // 更新内存中的安全规则
security.GlobalSecurityManager.Del <- id security.GlobalSecurityManager.Del(id)
} }
return Success(c, nil) return Success(c, nil)

View File

@ -152,7 +152,7 @@ func (api WebTerminalApi) SshEndpoint(c echo.Context) error {
NextTerminal: nextTerminal, NextTerminal: nextTerminal,
Observer: session.NewObserver(s.ID), Observer: session.NewObserver(s.ID),
} }
session.GlobalSessionManager.Add <- nextSession session.GlobalSessionManager.Add(nextSession)
termHandler := NewTermHandler(sessionId, isRecording, ws, nextTerminal) termHandler := NewTermHandler(sessionId, isRecording, ws, nextTerminal)
termHandler.Start() termHandler.Start()
@ -239,14 +239,14 @@ func (api WebTerminalApi) SshMonitorEndpoint(c echo.Context) error {
Mode: s.Mode, Mode: s.Mode,
WebSocket: ws, WebSocket: ws,
} }
nextSession.Observer.Add <- obSession nextSession.Observer.Add(obSession)
log.Debugf("会话 %v 观察者 %v 进入", sessionId, obId) log.Debugf("会话 %v 观察者 %v 进入", sessionId, obId)
for { for {
_, _, err := ws.ReadMessage() _, _, err := ws.ReadMessage()
if err != nil { if err != nil {
log.Debugf("会话 %v 观察者 %v 退出", sessionId, obId) log.Debugf("会话 %v 观察者 %v 退出", sessionId, obId)
nextSession.Observer.Del <- obId nextSession.Observer.Del(obId)
break break
} }
} }

View File

@ -85,11 +85,10 @@ func (r *TermHandler) writeToWebsocket() {
} }
nextSession := session.GlobalSessionManager.GetById(r.sessionId) nextSession := session.GlobalSessionManager.GetById(r.sessionId)
// 监控 // 监控
if nextSession != nil && len(nextSession.Observer.All()) > 0 { if nextSession != nil && nextSession.Observer != nil {
obs := nextSession.Observer.All() nextSession.Observer.Range(func(key string, ob *session.Session) {
for _, ob := range obs {
_ = ob.WriteMessage(dto.NewMessage(Data, s)) _ = ob.WriteMessage(dto.NewMessage(Data, s))
} })
} }
buf = []byte{} buf = []byte{}
} }

View File

@ -8,8 +8,6 @@ import (
"os" "os"
"sync" "sync"
"next-terminal/server/log"
"next-terminal/server/utils" "next-terminal/server/utils"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -22,11 +20,7 @@ type Gateway struct {
SshClient *ssh.Client SshClient *ssh.Client
Message string // 失败原因 Message string // 失败原因
tunnels *sync.Map tunnels sync.Map
Add chan *Tunnel
Del chan string
exit chan bool
} }
func NewGateway(id string, connected bool, message string, client *ssh.Client) *Gateway { func NewGateway(id string, connected bool, message string, client *ssh.Client) *Gateway {
@ -35,42 +29,14 @@ func NewGateway(id string, connected bool, message string, client *ssh.Client) *
Connected: connected, Connected: connected,
Message: message, Message: message,
SshClient: client, SshClient: client,
Add: make(chan *Tunnel),
Del: make(chan string),
tunnels: new(sync.Map),
exit: make(chan bool, 1),
}
}
func (g *Gateway) Run() {
for {
select {
case t := <-g.Add:
g.tunnels.Store(t.ID, t)
log.Info("add tunnel: %s", t.ID)
go t.Open()
case k := <-g.Del:
if val, ok := g.tunnels.Load(k); ok {
if vval, vok := val.(*Tunnel); vok {
vval.Close()
g.tunnels.Delete(k)
}
}
case <-g.exit:
return
}
} }
} }
func (g *Gateway) Close() { func (g *Gateway) Close() {
g.tunnels.Range(func(key, value interface{}) bool { g.tunnels.Range(func(key, value interface{}) bool {
if val, ok := value.(*Tunnel); ok { g.CloseSshTunnel(key.(string))
val.Close()
}
return true return true
}) })
g.exit <- true
} }
func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, exposedPort int, err error) { func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, exposedPort int, err error) {
@ -110,11 +76,17 @@ func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, expo
cancel: cancel, cancel: cancel,
listener: listener, listener: listener,
} }
g.Add <- tunnel go tunnel.Open()
g.tunnels.Store(tunnel.ID, tunnel)
return tunnel.LocalHost, tunnel.LocalPort, nil return tunnel.LocalHost, tunnel.LocalPort, nil
} }
func (g Gateway) CloseSshTunnel(id string) { func (g *Gateway) CloseSshTunnel(id string) {
g.Del <- id if value, ok := g.tunnels.Load(id); ok {
if tunnel, vok := value.(*Tunnel); vok {
tunnel.Close()
g.tunnels.Delete(id)
}
}
} }

View File

@ -7,49 +7,32 @@ import (
) )
type Manager struct { type Manager struct {
gateways *sync.Map gateways sync.Map
Add chan *Gateway
Del chan string
} }
func NewManager() *Manager { func NewManager() *Manager {
return &Manager{ return &Manager{}
Add: make(chan *Gateway),
Del: make(chan string),
gateways: new(sync.Map),
}
} }
func (m *Manager) Start() { func (m *Manager) GetById(id string) *Gateway {
for {
select {
case g := <-m.Add:
m.gateways.Store(g.ID, g)
log.Info("add gateway: %s", g.ID)
go g.Run()
case k := <-m.Del:
if val, ok := m.gateways.Load(k); ok {
if vv, vok := val.(*Gateway); vok {
vv.Close()
m.gateways.Delete(k)
}
}
}
}
}
func (m Manager) GetById(id string) *Gateway {
if val, ok := m.gateways.Load(id); ok { if val, ok := m.gateways.Load(id); ok {
return val.(*Gateway) return val.(*Gateway)
} }
return nil return nil
} }
func (m *Manager) Add(g *Gateway) {
m.gateways.Store(g.ID, g)
log.Infof("add gateway: %s", g.ID)
}
func (m *Manager) Del(id string) {
m.gateways.Delete(id)
log.Infof("del gateway: %s", id)
}
var GlobalGatewayManager *Manager var GlobalGatewayManager *Manager
func init() { func init() {
GlobalGatewayManager = NewManager() GlobalGatewayManager = NewManager()
go GlobalGatewayManager.Start()
} }

View File

@ -10,23 +10,26 @@ import (
) )
type Tunnel struct { type Tunnel struct {
ID string // 唯一标识 ID string // 唯一标识
LocalHost string // 本地监听地址 LocalHost string // 本地监听地址
LocalPort int // 本地端口 LocalPort int // 本地端口
RemoteHost string // 远程连接地址 RemoteHost string // 远程连接地址
RemotePort int // 远程端口 RemotePort int // 远程端口
Gateway *Gateway Gateway *Gateway
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
listener net.Listener listener net.Listener
err error localConnections []net.Conn
remoteConnections []net.Conn
} }
func (r *Tunnel) Open() { func (r *Tunnel) Open() {
localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort) localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort)
go func() { go func() {
<-r.ctx.Done() <-r.ctx.Done()
_ = r.listener.Close() _ = r.listener.Close()
log.Debugf("SSH 隧道 %v 关闭", localAddr)
}() }()
for { for {
log.Debugf("等待客户端访问 %v", localAddr) log.Debugf("等待客户端访问 %v", localAddr)
@ -35,6 +38,7 @@ func (r *Tunnel) Open() {
log.Debugf("接受连接失败 %v, 退出循环", err.Error()) log.Debugf("接受连接失败 %v, 退出循环", err.Error())
return return
} }
r.localConnections = append(r.localConnections, localConn)
log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr) log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr)
remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort) remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort)
@ -44,27 +48,27 @@ func (r *Tunnel) Open() {
log.Debugf("连接远程主机 %v 失败", remoteAddr) log.Debugf("连接远程主机 %v 失败", remoteAddr)
return return
} }
r.remoteConnections = append(r.remoteConnections, remoteConn)
log.Debugf("连接远程主机 %v 成功", remoteAddr) log.Debugf("连接远程主机 %v 成功", remoteAddr)
go copyConn(r.ctx, localConn, remoteConn) go copyConn(localConn, remoteConn)
go copyConn(r.ctx, remoteConn, localConn) go copyConn(remoteConn, localConn)
log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr) log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr)
} }
} }
func (r Tunnel) Close() { func (r *Tunnel) Close() {
for i := range r.localConnections {
_ = r.localConnections[i].Close()
}
r.localConnections = nil
for i := range r.remoteConnections {
_ = r.remoteConnections[i].Close()
}
r.remoteConnections = nil
r.cancel() r.cancel()
} }
func copyConn(ctx context.Context, writer, reader net.Conn) { func copyConn(writer, reader net.Conn) {
_, _ = io.Copy(writer, reader) _, _ = io.Copy(writer, reader)
for {
select {
case <-ctx.Done():
_ = writer.Close()
_ = reader.Close()
return
}
}
} }

View File

@ -1,6 +1,11 @@
package security package security
import "sort" import (
"sort"
"sync"
"next-terminal/server/log"
)
type Security struct { type Security struct {
ID string ID string
@ -10,45 +15,29 @@ type Security struct {
} }
type Manager struct { type Manager struct {
securities map[string]*Security securities sync.Map
values []*Security values []*Security
Add chan *Security
Del chan string
} }
func NewManager() *Manager { func NewManager() *Manager {
return &Manager{ return &Manager{}
Add: make(chan *Security),
Del: make(chan string),
securities: map[string]*Security{},
}
}
func (m *Manager) Start() {
for {
select {
case s := <-m.Add:
m.securities[s.ID] = s
m.LoadData()
case s := <-m.Del:
if _, ok := m.securities[s]; ok {
delete(m.securities, s)
m.LoadData()
}
}
}
} }
func (m *Manager) Clear() { func (m *Manager) Clear() {
m.securities = map[string]*Security{} m.securities.Range(func(k, _ interface{}) bool {
m.securities.Delete(k)
return true
})
} }
func (m *Manager) LoadData() { func (m *Manager) LoadData() {
var values []*Security var values []*Security
for _, security := range m.securities { m.securities.Range(func(key, value interface{}) bool {
values = append(values, security) if security, ok := value.(*Security); ok {
} values = append(values, security)
}
return true
})
sort.Slice(values, func(i, j int) bool { sort.Slice(values, func(i, j int) bool {
// 优先级数字越小代表优先级越高,因此此处用小于号 // 优先级数字越小代表优先级越高,因此此处用小于号
@ -58,13 +47,24 @@ func (m *Manager) LoadData() {
m.values = values m.values = values
} }
func (m Manager) Values() []*Security { func (m *Manager) Values() []*Security {
return m.values return m.values
} }
func (m *Manager) Add(s *Security) {
m.securities.Store(s.ID, s)
m.LoadData()
log.Infof("add security: %s", s.ID)
}
func (m *Manager) Del(id string) {
m.securities.Delete(id)
m.LoadData()
log.Infof("del security: %s", id)
}
var GlobalSecurityManager *Manager var GlobalSecurityManager *Manager
func init() { func init() {
GlobalSecurityManager = NewManager() GlobalSecurityManager = NewManager()
go GlobalSecurityManager.Start()
} }

View File

@ -1,11 +1,11 @@
package session package session
import ( import (
"fmt"
"sync" "sync"
"next-terminal/server/dto" "next-terminal/server/dto"
"next-terminal/server/guacd" "next-terminal/server/guacd"
"next-terminal/server/log"
"next-terminal/server/term" "next-terminal/server/term"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -42,80 +42,79 @@ func (s *Session) WriteString(str string) error {
return s.WebSocket.WriteMessage(websocket.TextMessage, message) return s.WebSocket.WriteMessage(websocket.TextMessage, message)
} }
func (s *Session) Close() {
if s.GuacdTunnel != nil {
_ = s.GuacdTunnel.Close()
}
if s.NextTerminal != nil {
s.NextTerminal.Close()
}
if s.WebSocket != nil {
_ = s.WebSocket.Close()
}
}
type Manager struct { type Manager struct {
id string id string
sessions map[string]*Session sessions sync.Map
Add chan *Session
Del chan string
exit chan bool
} }
func NewManager() *Manager { func NewManager() *Manager {
return &Manager{ return &Manager{}
Add: make(chan *Session),
Del: make(chan string),
sessions: map[string]*Session{},
exit: make(chan bool, 1),
}
} }
func NewObserver(id string) *Manager { func NewObserver(id string) *Manager {
return &Manager{ return &Manager{
id: id, id: id,
Add: make(chan *Session),
Del: make(chan string),
sessions: map[string]*Session{},
exit: make(chan bool, 1),
} }
} }
func (m *Manager) Start() { func (m *Manager) GetById(id string) *Session {
defer fmt.Printf("Session Manager %v End\n", m.id) value, ok := m.sessions.Load(id)
fmt.Printf("Session Manager %v Open\n", m.id) if ok {
for { return value.(*Session)
select { }
case s := <-m.Add: return nil
m.sessions[s.ID] = s }
case k := <-m.Del:
if _, ok := m.sessions[k]; ok {
ss := m.sessions[k]
if ss.GuacdTunnel != nil {
_ = ss.GuacdTunnel.Close()
}
if ss.NextTerminal != nil {
ss.NextTerminal.Close()
}
if ss.WebSocket != nil { func (m *Manager) Add(s *Session) {
_ = ss.WebSocket.Close() m.sessions.Store(s.ID, s)
} log.Infof("add session: %s", s.ID)
if ss.Observer != nil { }
ss.Observer.Close()
} func (m *Manager) Del(id string) {
delete(m.sessions, k) session := m.GetById(id)
} if session != nil {
case <-m.exit: session.Close()
return if session.Observer != nil {
session.Observer.Clear()
} }
} }
m.sessions.Delete(id)
log.Infof("del session: %s", id)
} }
func (m *Manager) Close() { func (m *Manager) Clear() {
m.exit <- true m.sessions.Range(func(key, value interface{}) bool {
if session, ok := value.(*Session); ok {
session.Close()
}
m.sessions.Delete(key)
return true
})
} }
func (m Manager) GetById(id string) *Session { func (m *Manager) Range(f func(key string, value *Session)) {
return m.sessions[id] m.sessions.Range(func(key, value interface{}) bool {
} if session, ok := value.(*Session); ok {
f(key.(string), session)
func (m Manager) All() map[string]*Session { }
return m.sessions return true
})
} }
var GlobalSessionManager *Manager var GlobalSessionManager *Manager
func init() { func init() {
GlobalSessionManager = NewManager() GlobalSessionManager = NewManager()
go GlobalSessionManager.Start()
} }

View File

@ -220,7 +220,7 @@ func (service backupService) Import(backup *dto.Backup) error {
Rule: item.Rule, Rule: item.Rule,
Priority: item.Priority, Priority: item.Priority,
} }
security.GlobalSecurityManager.Add <- rule security.GlobalSecurityManager.Add(rule)
} }
} }

View File

@ -60,11 +60,11 @@ func (r gatewayService) ReConnect(m *model.AccessGateway) *gateway.Gateway {
} else { } else {
g = gateway.NewGateway(m.ID, true, "", sshClient) g = gateway.NewGateway(m.ID, true, "", sshClient)
} }
gateway.GlobalGatewayManager.Add <- g gateway.GlobalGatewayManager.Add(g)
log.Debugf("重建接入网关「%v」完成", m.Name) log.Debugf("重建接入网关「%v」完成", m.Name)
return g return g
} }
func (r gatewayService) DisconnectById(accessGatewayId string) { func (r gatewayService) DisconnectById(accessGatewayId string) {
gateway.GlobalGatewayManager.Del <- accessGatewayId gateway.GlobalGatewayManager.Del(accessGatewayId)
} }

View File

@ -25,7 +25,7 @@ func (service securityService) ReloadAccessSecurity() error {
Rule: rules[i].Rule, Rule: rules[i].Rule,
Priority: rules[i].Priority, Priority: rules[i].Priority,
} }
security.GlobalSecurityManager.Add <- rule security.GlobalSecurityManager.Add(rule)
} }
} }
return nil return nil

View File

@ -96,14 +96,13 @@ func (service sessionService) CloseSessionById(sessionId string, code int, reaso
service.WriteCloseMessage(nextSession, nextSession.Mode, code, reason) service.WriteCloseMessage(nextSession, nextSession.Mode, code, reason)
if nextSession.Observer != nil { if nextSession.Observer != nil {
obs := nextSession.Observer.All() nextSession.Observer.Range(func(key string, ob *session.Session) {
for _, ob := range obs {
service.WriteCloseMessage(ob, ob.Mode, code, reason) service.WriteCloseMessage(ob, ob.Mode, code, reason)
log.Debugf("[%v] 强制踢出会话的观察者: %v", sessionId, ob.ID) log.Debugf("[%v] 强制踢出会话的观察者: %v", sessionId, ob.ID)
} })
} }
} }
session.GlobalSessionManager.Del <- sessionId session.GlobalSessionManager.Del(sessionId)
service.DisDBSess(sessionId, code, reason) service.DisDBSess(sessionId, code, reason)
} }

View File

@ -271,8 +271,7 @@ func (gui Gui) handleAccessAsset(sess *ssh.Session, sessionId string) (err error
NextTerminal: nextTerminal, NextTerminal: nextTerminal,
Observer: session.NewObserver(s.ID), Observer: session.NewObserver(s.ID),
} }
go nextSession.Observer.Start() session.GlobalSessionManager.Add(nextSession)
session.GlobalSessionManager.Add <- nextSession
if err := sshSession.Wait(); err != nil { if err := sshSession.Wait(); err != nil {
return err return err

View File

@ -66,12 +66,9 @@ func (w *Writer) Write(p []byte) (n int, err error) {
func sendObData(sessionId, s string) { func sendObData(sessionId, s string) {
nextSession := session.GlobalSessionManager.GetById(sessionId) nextSession := session.GlobalSessionManager.GetById(sessionId)
if nextSession != nil { if nextSession != nil && nextSession.Observer != nil {
if nextSession.Observer != nil { nextSession.Observer.Range(func(key string, ob *session.Session) {
obs := nextSession.Observer.All() _ = ob.WriteMessage(dto.NewMessage(api.Data, s))
for _, ob := range obs { })
_ = ob.WriteMessage(dto.NewMessage(api.Data, s))
}
}
} }
} }

View File

@ -471,7 +471,7 @@ Windows需要对远程应用程序的名称使用特殊的符号。
setSshMode(value) setSshMode(value)
}}> }}>
<Option value="">guacd</Option> <Option value="">guacd</Option>
<Option value="naive">原生</Option> <Option value="native">原生</Option>
</Select> </Select>
</Form.Item> </Form.Item>
</Panel> </Panel>