diff --git a/pkg/api/routes.go b/pkg/api/routes.go index 220f4ba..76a72e4 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -115,7 +115,7 @@ func SetupRoutes() *echo.Echo { sessions.DELETE("/:id/rm", SessionRmEndpoint) sessions.DELETE("/:id", SessionDeleteEndpoint) sessions.GET("/:id/recording", SessionRecordingEndpoint) - sessions.GET("/:id", SessionGetEndpoint) + sessions.GET("/:id/status", SessionGetStatusEndpoint) } resourceSharers := e.Group("/resource-sharers") diff --git a/pkg/api/session.go b/pkg/api/session.go index a293e58..7b99446 100644 --- a/pkg/api/session.go +++ b/pkg/api/session.go @@ -17,6 +17,7 @@ import ( "path" "strconv" "strings" + "sync" "time" ) @@ -98,13 +99,19 @@ func SessionDiscontentEndpoint(c echo.Context) error { return Success(c, nil) } +var mutex sync.Mutex + func CloseSessionById(sessionId string, code int, reason string) { + mutex.Lock() + defer mutex.Unlock() observable, _ := global.Store.Get(sessionId) if observable != nil { + logrus.Debugf("会话%v创建者退出", observable.Subject.Tunnel.UUID) _ = observable.Subject.Tunnel.Close() for i := 0; i < len(observable.Observers); i++ { _ = observable.Observers[i].Tunnel.Close() CloseWebSocket(observable.Observers[i].WebSocket, code, reason) + logrus.Debugf("强制踢出会话%v的观察者", observable.Observers[i].Tunnel.UUID) } CloseWebSocket(observable.Subject.WebSocket, code, reason) } @@ -515,11 +522,13 @@ func SessionRecordingEndpoint(c echo.Context) error { return c.File(recording) } -func SessionGetEndpoint(c echo.Context) error { +func SessionGetStatusEndpoint(c echo.Context) error { sessionId := c.Param("id") session, err := model.FindSessionById(sessionId) if err != nil { return err } - return Success(c, session) + return Success(c, H{ + "status": session.Status, + }) } diff --git a/pkg/api/tunnel.go b/pkg/api/tunnel.go index 98abf7c..1e07cf8 100644 --- a/pkg/api/tunnel.go +++ b/pkg/api/tunnel.go @@ -2,7 +2,6 @@ package api import ( "errors" - "fmt" "github.com/gorilla/websocket" "github.com/labstack/echo/v4" "github.com/sirupsen/logrus" @@ -40,9 +39,6 @@ func TunEndpoint(c echo.Context) error { intHeight, _ := strconv.Atoi(height) configuration := guacd.NewConfiguration() - configuration.SetParameter("width", width) - configuration.SetParameter("height", height) - configuration.SetParameter("dpi", dpi) propertyMap := model.FindAllPropertiesMap() @@ -59,7 +55,14 @@ func TunEndpoint(c echo.Context) error { return errors.New("会话未在线") } configuration.ConnectionID = connectionId + sessionId = session.ID + configuration.SetParameter("width", strconv.Itoa(session.Width)) + configuration.SetParameter("height", strconv.Itoa(session.Height)) + configuration.SetParameter("dpi", "96") } else { + configuration.SetParameter("width", width) + configuration.SetParameter("height", height) + configuration.SetParameter("dpi", dpi) session, err = model.FindSessionById(sessionId) if err != nil { CloseSessionById(sessionId, NotFoundSession, "会话不存在") @@ -95,7 +98,6 @@ func TunEndpoint(c echo.Context) error { configuration.SetParameter(guacd.DisableBitmapCaching, propertyMap[guacd.DisableBitmapCaching]) configuration.SetParameter(guacd.DisableOffscreenCaching, propertyMap[guacd.DisableOffscreenCaching]) configuration.SetParameter(guacd.DisableGlyphCaching, propertyMap[guacd.DisableGlyphCaching]) - configuration.SetParameter("server-layout", "en-us-qwerty") break case "ssh": if len(session.PrivateKey) > 0 && session.PrivateKey != "-" { @@ -133,8 +135,6 @@ func TunEndpoint(c echo.Context) error { addr := propertyMap[guacd.Host] + ":" + propertyMap[guacd.Port] - logrus.Infof("connect to %v with global: %+v", addr, configuration) - tunnel, err := guacd.NewTunnel(addr, configuration) if err != nil { if connectionId == "" { @@ -150,6 +150,7 @@ func TunEndpoint(c echo.Context) error { } if len(session.ConnectionId) == 0 { + var observers []global.Tun observable := global.Observable{ Subject: &tun, @@ -157,31 +158,33 @@ func TunEndpoint(c echo.Context) error { } global.Store.Set(sessionId, &observable) - // 创建新会话 - session.ConnectionId = tunnel.UUID - session.Width = intWidth - session.Height = intHeight - session.Status = model.Connecting - session.Recording = configuration.GetParameter(guacd.RecordingPath) - if err := model.UpdateSessionById(&session, sessionId); err != nil { + sess := model.Session{ + ConnectionId: tunnel.UUID, + Width: intWidth, + Height: intHeight, + Status: model.Connecting, + Recording: configuration.GetParameter(guacd.RecordingPath), + } + // 创建新会话 + logrus.Debugf("创建新会话 %v", sess.ConnectionId) + if err := model.UpdateSessionById(&sess, sessionId); err != nil { return err } } else { - // TODO 处理监控会话的退出 // 监控会话 observable, ok := global.Store.Get(sessionId) if ok { observers := append(observable.Observers, tun) observable.Observers = observers global.Store.Set(sessionId, observable) + logrus.Debugf("加入会话%v,当前观察者数量为:%v", session.ConnectionId, len(observers)) } } go func() { for true { instruction, err := tunnel.Read() - fmt.Printf("<- %v \n", string(instruction)) if err != nil { if connectionId == "" { CloseSessionById(sessionId, TunnelClosed, "远程连接关闭") diff --git a/pkg/guacd/guacd.go b/pkg/guacd/guacd.go index 4c754eb..5e57de5 100644 --- a/pkg/guacd/guacd.go +++ b/pkg/guacd/guacd.go @@ -140,7 +140,6 @@ func NewTunnel(address string, config Configuration) (ret *Tunnel, err error) { if err := ret.WriteInstructionAndFlush(NewInstruction("size", width, height, dpi)); err != nil { return nil, err } - if err := ret.WriteInstructionAndFlush(NewInstruction("audio")); err != nil { return nil, err } @@ -150,10 +149,9 @@ func NewTunnel(address string, config Configuration) (ret *Tunnel, err error) { if err := ret.WriteInstructionAndFlush(NewInstruction("image")); err != nil { return nil, err } - - //if err := ret.WriteInstructionAndFlush(NewInstruction("timezone", "Asia/Shanghai")); err != nil { - // return nil, err - //} + if err := ret.WriteInstructionAndFlush(NewInstruction("timezone", "Asia/Shanghai")); err != nil { + return nil, err + } parameters := make([]string, len(args.Args)) for i := range args.Args { @@ -198,7 +196,7 @@ func (opt *Tunnel) WriteInstruction(instruction Instruction) error { } func (opt *Tunnel) WriteAndFlush(p []byte) (int, error) { - fmt.Printf("-> %v \n", string(p)) + //fmt.Printf("-> %v \n", string(p)) nn, err := opt.rw.Write(p) if err != nil { return nn, err @@ -211,7 +209,7 @@ func (opt *Tunnel) WriteAndFlush(p []byte) (int, error) { } func (opt *Tunnel) Write(p []byte) (int, error) { - fmt.Printf("-> %v \n", string(p)) + //fmt.Printf("-> %v \n", string(p)) nn, err := opt.rw.Write(p) if err != nil { return nn, err @@ -225,15 +223,17 @@ func (opt *Tunnel) Flush() error { func (opt *Tunnel) ReadInstruction() (instruction Instruction, err error) { msg, err := opt.rw.ReadString(Delimiter) - fmt.Printf("<- %v \n", msg) + //fmt.Printf("<- %v \n", msg) if err != nil { return instruction, err } return instruction.Parse(msg), err } -func (opt *Tunnel) Read() ([]byte, error) { - return opt.rw.ReadBytes(Delimiter) +func (opt *Tunnel) Read() (p []byte, err error) { + p, err = opt.rw.ReadBytes(Delimiter) + //fmt.Printf("<- %v \n", string(p)) + return } func (opt *Tunnel) expect(opcode string) (instruction Instruction, err error) { diff --git a/pkg/model/session.go b/pkg/model/session.go index a7a6e19..2f4a228 100644 --- a/pkg/model/session.go +++ b/pkg/model/session.go @@ -58,6 +58,8 @@ type SessionVo struct { DisconnectedTime utils.JsonTime `json:"disconnectedTime"` AssetName string `json:"assetName"` CreatorName string `json:"creatorName"` + Code int `json:"code"` + Message string `json:"message"` } func FindPageSession(pageIndex, pageSize int, status, userId, clientIp, assetId, protocol string) (results []SessionVo, total int64, err error) { @@ -67,7 +69,7 @@ func FindPageSession(pageIndex, pageSize int, status, userId, clientIp, assetId, params = append(params, status) - itemSql := "SELECT s.id, s.protocol,s.recording, s.connection_id, s.asset_id, s.creator, s.client_ip, s.width, s.height, s.ip, s.port, s.username, s.status, s.connected_time, s.disconnected_time, a.name AS asset_name, u.nickname AS creator_name FROM sessions s LEFT JOIN assets a ON s.asset_id = a.id LEFT JOIN users u ON s.creator = u.id WHERE s.STATUS = ? " + itemSql := "SELECT s.id, s.protocol,s.recording, s.connection_id, s.asset_id, s.creator, s.client_ip, s.width, s.height, s.ip, s.port, s.username, s.status, s.connected_time, s.disconnected_time,s.code, s.message, a.name AS asset_name, u.nickname AS creator_name FROM sessions s LEFT JOIN assets a ON s.asset_id = a.id LEFT JOIN users u ON s.creator = u.id WHERE s.STATUS = ? " countSql := "select count(*) from sessions as s where s.status = ? " if len(userId) > 0 { diff --git a/web/src/components/access/Access.js b/web/src/components/access/Access.js index bd90c01..7353440 100644 --- a/web/src/components/access/Access.js +++ b/web/src/components/access/Access.js @@ -28,7 +28,7 @@ import { CloudUploadOutlined, CopyOutlined, DeleteOutlined, - DesktopOutlined, + DesktopOutlined, ExclamationCircleOutlined, ExpandOutlined, FileZipOutlined, FolderAddOutlined, @@ -146,18 +146,20 @@ class Access extends Component { }) if (this.state.protocol === 'ssh') { if (data.data && data.data.length > 0) { - message.success('您输入的内容已复制到远程服务器上,使用右键将自动粘贴。'); + // message.success('您输入的内容已复制到远程服务器上,使用右键将自动粘贴。'); } } else { if (data.data && data.data.length > 0) { - message.success('您输入的内容已复制到远程服务器上'); + // message.success('您输入的内容已复制到远程服务器上'); } } } onTunnelStateChange = (state) => { - + if(state === Guacamole.Tunnel.State.CLOSED){ + this.showMessage('连接已关闭'); + } }; updateSessionStatus = async (sessionId) => { @@ -281,9 +283,19 @@ class Access extends Component { showMessage(msg) { message.destroy(); - Modal.error({ + Modal.confirm({ title: '提示', + icon: , content: msg, + centered: true, + okText: '重新连接', + cancelText: '关闭页面', + onOk() { + window.location.reload(); + }, + onCancel() { + window.close(); + }, }); } @@ -304,7 +316,7 @@ class Access extends Component { // Set clipboard contents once stream is finished reader.onend = async () => { - message.success('您选择的内容已复制到您的粘贴板中,在右侧的输入框中可同时查看到。'); + // message.success('您选择的内容已复制到您的粘贴板中,在右侧的输入框中可同时查看到。'); this.setState({ clipboardText: data }); @@ -501,7 +513,7 @@ class Access extends Component { keyboard.onkeyup = this.onKeyUp; let stateChecker = setInterval(async () => { - let result = await request.get(`/sessions/${sessionId}`); + let result = await request.get(`/sessions/${sessionId}/status`); if (result['code'] !== 1) { clearInterval(stateChecker); } else {