diff --git a/server/api/term.go b/server/api/term.go index adff3c2..8101eba 100644 --- a/server/api/term.go +++ b/server/api/term.go @@ -202,7 +202,7 @@ func (api WebTerminalApi) SshEndpoint(c echo.Context) error { if err != nil { service.SessionService.CloseSessionById(sessionId, TunnelClosed, "远程连接已关闭") } else { - _ = WriteMessage(ws, dto.NewMessage(Ping, "")) + _ = termHandler.WriteMessage(dto.NewMessage(Ping, "")) } } diff --git a/server/api/term_handler.go b/server/api/term_handler.go index 8a0c99a..253c713 100644 --- a/server/api/term_handler.go +++ b/server/api/term_handler.go @@ -2,6 +2,7 @@ package api import ( "context" + "sync" "time" "unicode/utf8" @@ -15,12 +16,13 @@ import ( type TermHandler struct { sessionId string isRecording bool - ws *websocket.Conn + webSocket *websocket.Conn nextTerminal *term.NextTerminal ctx context.Context cancel context.CancelFunc dataChan chan rune tick *time.Ticker + mutex sync.Mutex } func NewTermHandler(sessionId string, isRecording bool, ws *websocket.Conn, nextTerminal *term.NextTerminal) *TermHandler { @@ -29,7 +31,7 @@ func NewTermHandler(sessionId string, isRecording bool, ws *websocket.Conn, next return &TermHandler{ sessionId: sessionId, isRecording: isRecording, - ws: ws, + webSocket: ws, nextTerminal: nextTerminal, ctx: ctx, cancel: cancel, @@ -38,17 +40,17 @@ func NewTermHandler(sessionId string, isRecording bool, ws *websocket.Conn, next } } -func (r TermHandler) Start() { +func (r *TermHandler) Start() { go r.readFormTunnel() go r.writeToWebsocket() } -func (r TermHandler) Stop() { +func (r *TermHandler) Stop() { r.tick.Stop() r.cancel() } -func (r TermHandler) readFormTunnel() { +func (r *TermHandler) readFormTunnel() { for { select { case <-r.ctx.Done(): @@ -65,7 +67,7 @@ func (r TermHandler) readFormTunnel() { } } -func (r TermHandler) writeToWebsocket() { +func (r *TermHandler) writeToWebsocket() { var buf []byte for { select { @@ -74,7 +76,7 @@ func (r TermHandler) writeToWebsocket() { case <-r.tick.C: if len(buf) > 0 { s := string(buf) - if err := WriteMessage(r.ws, dto.NewMessage(Data, s)); err != nil { + if err := r.WriteMessage(dto.NewMessage(Data, s)); err != nil { return } // 录屏 @@ -86,7 +88,7 @@ func (r TermHandler) writeToWebsocket() { if nextSession != nil && len(nextSession.Observer.All()) > 0 { obs := nextSession.Observer.All() for _, ob := range obs { - _ = WriteMessage(ob.WebSocket, dto.NewMessage(Data, s)) + _ = ob.WriteMessage(dto.NewMessage(Data, s)) } } buf = []byte{} @@ -102,3 +104,13 @@ func (r TermHandler) writeToWebsocket() { } } } + +func (r *TermHandler) WriteMessage(msg dto.Message) error { + if r.webSocket == nil { + return nil + } + defer r.mutex.Unlock() + r.mutex.Lock() + message := []byte(msg.ToString()) + return r.webSocket.WriteMessage(websocket.TextMessage, message) +} diff --git a/server/global/session/session.go b/server/global/session/session.go index f5d7e9a..40858a8 100644 --- a/server/global/session/session.go +++ b/server/global/session/session.go @@ -2,7 +2,9 @@ package session import ( "fmt" + "sync" + "next-terminal/server/dto" "next-terminal/server/guacd" "next-terminal/server/term" @@ -17,6 +19,27 @@ type Session struct { GuacdTunnel *guacd.Tunnel NextTerminal *term.NextTerminal Observer *Manager + mutex sync.Mutex +} + +func (s *Session) WriteMessage(msg dto.Message) error { + if s.WebSocket == nil { + return nil + } + defer s.mutex.Unlock() + s.mutex.Lock() + message := []byte(msg.ToString()) + return s.WebSocket.WriteMessage(websocket.TextMessage, message) +} + +func (s *Session) WriteString(str string) error { + if s.WebSocket == nil { + return nil + } + defer s.mutex.Unlock() + s.mutex.Lock() + message := []byte(str) + return s.WebSocket.WriteMessage(websocket.TextMessage, message) } type Manager struct { diff --git a/server/repository/resource_sharer.go b/server/repository/resource_sharer.go index 72a33f1..022a2dc 100644 --- a/server/repository/resource_sharer.go +++ b/server/repository/resource_sharer.go @@ -159,16 +159,16 @@ func (r *resourceSharerRepository) FindByResourceIdAndUserId(c context.Context, func (r *resourceSharerRepository) Find(c context.Context, resourceId, resourceType, userId, userGroupId string) (resourceSharers []model.ResourceSharer, err error) { db := r.GetDB(c) if resourceId != "" { - db = db.Where("resource_id = ?") + db = db.Where("resource_id = ?", resourceId) } if resourceType != "" { - db = db.Where("resource_type = ?") + db = db.Where("resource_type = ?", resourceType) } if userId != "" { - db = db.Where("user_id = ?") + db = db.Where("user_id = ?", userId) } if userGroupId != "" { - db = db.Where("user_group_id = ?") + db = db.Where("user_group_id = ?", userGroupId) } err = db.Find(&resourceSharers).Error return diff --git a/server/service/session.go b/server/service/session.go index c138bfc..614a5e8 100644 --- a/server/service/session.go +++ b/server/service/session.go @@ -17,7 +17,6 @@ import ( "next-terminal/server/repository" "next-terminal/server/utils" - "github.com/gorilla/websocket" "gorm.io/gorm" ) @@ -94,12 +93,12 @@ func (service sessionService) CloseSessionById(sessionId string, code int, reaso nextSession := session.GlobalSessionManager.GetById(sessionId) if nextSession != nil { log.Debugf("[%v] 会话关闭,原因:%v", sessionId, reason) - service.WriteCloseMessage(nextSession.WebSocket, nextSession.Mode, code, reason) + service.WriteCloseMessage(nextSession, nextSession.Mode, code, reason) if nextSession.Observer != nil { obs := nextSession.Observer.All() for _, ob := range obs { - service.WriteCloseMessage(ob.WebSocket, ob.Mode, code, reason) + service.WriteCloseMessage(ob, ob.Mode, code, reason) log.Debugf("[%v] 强制踢出会话的观察者: %v", sessionId, ob.ID) } } @@ -109,26 +108,16 @@ func (service sessionService) CloseSessionById(sessionId string, code int, reaso service.DisDBSess(sessionId, code, reason) } -func (service sessionService) WriteCloseMessage(ws *websocket.Conn, mode string, code int, reason string) { +func (service sessionService) WriteCloseMessage(sess *session.Session, mode string, code int, reason string) { switch mode { case constant.Guacd: - if ws != nil { - err := guacd.NewInstruction("error", "", strconv.Itoa(code)) - _ = ws.WriteMessage(websocket.TextMessage, []byte(err.String())) - disconnect := guacd.NewInstruction("disconnect") - _ = ws.WriteMessage(websocket.TextMessage, []byte(disconnect.String())) - } - case constant.Native: - if ws != nil { - msg := `0` + reason - _ = ws.WriteMessage(websocket.TextMessage, []byte(msg)) - } - case constant.Terminal: - // 这里是关闭观察者的ssh会话 - if ws != nil { - msg := `0` + reason - _ = ws.WriteMessage(websocket.TextMessage, []byte(msg)) - } + err := guacd.NewInstruction("error", "", strconv.Itoa(code)) + _ = sess.WriteString(err.String()) + disconnect := guacd.NewInstruction("disconnect") + _ = sess.WriteString(disconnect.String()) + case constant.Native, constant.Terminal: + msg := `0` + reason + _ = sess.WriteString(msg) } } diff --git a/server/sshd/writer.go b/server/sshd/writer.go index 42d942e..4d72bb4 100644 --- a/server/sshd/writer.go +++ b/server/sshd/writer.go @@ -70,7 +70,7 @@ func sendObData(sessionId, s string) { if nextSession.Observer != nil { obs := nextSession.Observer.All() for _, ob := range obs { - _ = api.WriteMessage(ob.WebSocket, dto.NewMessage(api.Data, s)) + _ = ob.WriteMessage(dto.NewMessage(api.Data, s)) } } } diff --git a/web/src/images/bg.png b/web/src/images/bg.png deleted file mode 100644 index 718cb36..0000000 Binary files a/web/src/images/bg.png and /dev/null differ