- 修复查询分享的资源参数无效的问题

- 修复ssh连接并发写入websocket的问题
This commit is contained in:
dushixiang 2022-05-05 09:07:11 +08:00
parent 03b59d6a2f
commit b73bef0c08
7 changed files with 59 additions and 35 deletions

View File

@ -202,7 +202,7 @@ func (api WebTerminalApi) SshEndpoint(c echo.Context) error {
if err != nil { if err != nil {
service.SessionService.CloseSessionById(sessionId, TunnelClosed, "远程连接已关闭") service.SessionService.CloseSessionById(sessionId, TunnelClosed, "远程连接已关闭")
} else { } else {
_ = WriteMessage(ws, dto.NewMessage(Ping, "")) _ = termHandler.WriteMessage(dto.NewMessage(Ping, ""))
} }
} }

View File

@ -2,6 +2,7 @@ package api
import ( import (
"context" "context"
"sync"
"time" "time"
"unicode/utf8" "unicode/utf8"
@ -15,12 +16,13 @@ import (
type TermHandler struct { type TermHandler struct {
sessionId string sessionId string
isRecording bool isRecording bool
ws *websocket.Conn webSocket *websocket.Conn
nextTerminal *term.NextTerminal nextTerminal *term.NextTerminal
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
dataChan chan rune dataChan chan rune
tick *time.Ticker tick *time.Ticker
mutex sync.Mutex
} }
func NewTermHandler(sessionId string, isRecording bool, ws *websocket.Conn, nextTerminal *term.NextTerminal) *TermHandler { func NewTermHandler(sessionId string, isRecording bool, ws *websocket.Conn, nextTerminal *term.NextTerminal) *TermHandler {
@ -29,7 +31,7 @@ func NewTermHandler(sessionId string, isRecording bool, ws *websocket.Conn, next
return &TermHandler{ return &TermHandler{
sessionId: sessionId, sessionId: sessionId,
isRecording: isRecording, isRecording: isRecording,
ws: ws, webSocket: ws,
nextTerminal: nextTerminal, nextTerminal: nextTerminal,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
@ -38,17 +40,17 @@ func NewTermHandler(sessionId string, isRecording bool, ws *websocket.Conn, next
} }
} }
func (r TermHandler) Start() { func (r *TermHandler) Start() {
go r.readFormTunnel() go r.readFormTunnel()
go r.writeToWebsocket() go r.writeToWebsocket()
} }
func (r TermHandler) Stop() { func (r *TermHandler) Stop() {
r.tick.Stop() r.tick.Stop()
r.cancel() r.cancel()
} }
func (r TermHandler) readFormTunnel() { func (r *TermHandler) readFormTunnel() {
for { for {
select { select {
case <-r.ctx.Done(): case <-r.ctx.Done():
@ -65,7 +67,7 @@ func (r TermHandler) readFormTunnel() {
} }
} }
func (r TermHandler) writeToWebsocket() { func (r *TermHandler) writeToWebsocket() {
var buf []byte var buf []byte
for { for {
select { select {
@ -74,7 +76,7 @@ func (r TermHandler) writeToWebsocket() {
case <-r.tick.C: case <-r.tick.C:
if len(buf) > 0 { if len(buf) > 0 {
s := string(buf) s := string(buf)
if err := WriteMessage(r.ws, dto.NewMessage(Data, s)); err != nil { if err := r.WriteMessage(dto.NewMessage(Data, s)); err != nil {
return return
} }
// 录屏 // 录屏
@ -86,7 +88,7 @@ func (r TermHandler) writeToWebsocket() {
if nextSession != nil && len(nextSession.Observer.All()) > 0 { if nextSession != nil && len(nextSession.Observer.All()) > 0 {
obs := nextSession.Observer.All() obs := nextSession.Observer.All()
for _, ob := range obs { for _, ob := range obs {
_ = WriteMessage(ob.WebSocket, dto.NewMessage(Data, s)) _ = ob.WriteMessage(dto.NewMessage(Data, s))
} }
} }
buf = []byte{} buf = []byte{}
@ -102,3 +104,13 @@ func (r TermHandler) writeToWebsocket() {
} }
} }
} }
func (r *TermHandler) WriteMessage(msg dto.Message) error {
if r.webSocket == nil {
return nil
}
defer r.mutex.Unlock()
r.mutex.Lock()
message := []byte(msg.ToString())
return r.webSocket.WriteMessage(websocket.TextMessage, message)
}

View File

@ -2,7 +2,9 @@ package session
import ( import (
"fmt" "fmt"
"sync"
"next-terminal/server/dto"
"next-terminal/server/guacd" "next-terminal/server/guacd"
"next-terminal/server/term" "next-terminal/server/term"
@ -17,6 +19,27 @@ type Session struct {
GuacdTunnel *guacd.Tunnel GuacdTunnel *guacd.Tunnel
NextTerminal *term.NextTerminal NextTerminal *term.NextTerminal
Observer *Manager Observer *Manager
mutex sync.Mutex
}
func (s *Session) WriteMessage(msg dto.Message) error {
if s.WebSocket == nil {
return nil
}
defer s.mutex.Unlock()
s.mutex.Lock()
message := []byte(msg.ToString())
return s.WebSocket.WriteMessage(websocket.TextMessage, message)
}
func (s *Session) WriteString(str string) error {
if s.WebSocket == nil {
return nil
}
defer s.mutex.Unlock()
s.mutex.Lock()
message := []byte(str)
return s.WebSocket.WriteMessage(websocket.TextMessage, message)
} }
type Manager struct { type Manager struct {

View File

@ -159,16 +159,16 @@ func (r *resourceSharerRepository) FindByResourceIdAndUserId(c context.Context,
func (r *resourceSharerRepository) Find(c context.Context, resourceId, resourceType, userId, userGroupId string) (resourceSharers []model.ResourceSharer, err error) { func (r *resourceSharerRepository) Find(c context.Context, resourceId, resourceType, userId, userGroupId string) (resourceSharers []model.ResourceSharer, err error) {
db := r.GetDB(c) db := r.GetDB(c)
if resourceId != "" { if resourceId != "" {
db = db.Where("resource_id = ?") db = db.Where("resource_id = ?", resourceId)
} }
if resourceType != "" { if resourceType != "" {
db = db.Where("resource_type = ?") db = db.Where("resource_type = ?", resourceType)
} }
if userId != "" { if userId != "" {
db = db.Where("user_id = ?") db = db.Where("user_id = ?", userId)
} }
if userGroupId != "" { if userGroupId != "" {
db = db.Where("user_group_id = ?") db = db.Where("user_group_id = ?", userGroupId)
} }
err = db.Find(&resourceSharers).Error err = db.Find(&resourceSharers).Error
return return

View File

@ -17,7 +17,6 @@ import (
"next-terminal/server/repository" "next-terminal/server/repository"
"next-terminal/server/utils" "next-terminal/server/utils"
"github.com/gorilla/websocket"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -94,12 +93,12 @@ func (service sessionService) CloseSessionById(sessionId string, code int, reaso
nextSession := session.GlobalSessionManager.GetById(sessionId) nextSession := session.GlobalSessionManager.GetById(sessionId)
if nextSession != nil { if nextSession != nil {
log.Debugf("[%v] 会话关闭,原因:%v", sessionId, reason) log.Debugf("[%v] 会话关闭,原因:%v", sessionId, reason)
service.WriteCloseMessage(nextSession.WebSocket, nextSession.Mode, code, reason) service.WriteCloseMessage(nextSession, nextSession.Mode, code, reason)
if nextSession.Observer != nil { if nextSession.Observer != nil {
obs := nextSession.Observer.All() obs := nextSession.Observer.All()
for _, ob := range obs { for _, ob := range obs {
service.WriteCloseMessage(ob.WebSocket, ob.Mode, code, reason) service.WriteCloseMessage(ob, ob.Mode, code, reason)
log.Debugf("[%v] 强制踢出会话的观察者: %v", sessionId, ob.ID) log.Debugf("[%v] 强制踢出会话的观察者: %v", sessionId, ob.ID)
} }
} }
@ -109,26 +108,16 @@ func (service sessionService) CloseSessionById(sessionId string, code int, reaso
service.DisDBSess(sessionId, code, reason) service.DisDBSess(sessionId, code, reason)
} }
func (service sessionService) WriteCloseMessage(ws *websocket.Conn, mode string, code int, reason string) { func (service sessionService) WriteCloseMessage(sess *session.Session, mode string, code int, reason string) {
switch mode { switch mode {
case constant.Guacd: case constant.Guacd:
if ws != nil {
err := guacd.NewInstruction("error", "", strconv.Itoa(code)) err := guacd.NewInstruction("error", "", strconv.Itoa(code))
_ = ws.WriteMessage(websocket.TextMessage, []byte(err.String())) _ = sess.WriteString(err.String())
disconnect := guacd.NewInstruction("disconnect") disconnect := guacd.NewInstruction("disconnect")
_ = ws.WriteMessage(websocket.TextMessage, []byte(disconnect.String())) _ = sess.WriteString(disconnect.String())
} case constant.Native, constant.Terminal:
case constant.Native:
if ws != nil {
msg := `0` + reason msg := `0` + reason
_ = ws.WriteMessage(websocket.TextMessage, []byte(msg)) _ = sess.WriteString(msg)
}
case constant.Terminal:
// 这里是关闭观察者的ssh会话
if ws != nil {
msg := `0` + reason
_ = ws.WriteMessage(websocket.TextMessage, []byte(msg))
}
} }
} }

View File

@ -70,7 +70,7 @@ func sendObData(sessionId, s string) {
if nextSession.Observer != nil { if nextSession.Observer != nil {
obs := nextSession.Observer.All() obs := nextSession.Observer.All()
for _, ob := range obs { for _, ob := range obs {
_ = api.WriteMessage(ob.WebSocket, dto.NewMessage(api.Data, s)) _ = ob.WriteMessage(dto.NewMessage(api.Data, s))
} }
} }
} }

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.1 MiB