diff --git a/pkg/api/session.go b/pkg/api/session.go index 64f8432..a293e58 100644 --- a/pkg/api/session.go +++ b/pkg/api/session.go @@ -36,7 +36,7 @@ func SessionPagingEndpoint(c echo.Context) error { } for i := 0; i < len(items); i++ { - if len(items[i].Recording) > 0 { + if status == model.Disconnected && len(items[i].Recording) > 0 { recording := items[i].Recording + "/recording" if utils.FileExists(recording) { @@ -82,7 +82,9 @@ func SessionContentEndpoint(c echo.Context) error { session.Status = model.Connected session.ConnectedTime = utils.NowJsonTime() - model.UpdateSessionById(&session, sessionId) + if err := model.UpdateSessionById(&session, sessionId); err != nil { + return err + } return Success(c, nil) } @@ -91,16 +93,20 @@ func SessionDiscontentEndpoint(c echo.Context) error { split := strings.Split(sessionIds, ",") for i := range split { - CloseSessionById(split[i], ForcedDisconnect, "管理员强制关闭了此次接入。") + CloseSessionById(split[i], ForcedDisconnect, "强制断开") } return Success(c, nil) } func CloseSessionById(sessionId string, code int, reason string) { - tun, _ := global.Store.Get(sessionId) - if tun != nil { - _ = tun.Tun.Close() - CloseSessionByWebSocket(tun.WebSocket, code, reason) + observable, _ := global.Store.Get(sessionId) + if observable != nil { + _ = observable.Subject.Tunnel.Close() + for i := 0; i < len(observable.Observers); i++ { + _ = observable.Observers[i].Tunnel.Close() + CloseWebSocket(observable.Observers[i].WebSocket, code, reason) + } + CloseWebSocket(observable.Subject.WebSocket, code, reason) } global.Store.Del(sessionId) @@ -126,10 +132,10 @@ func CloseSessionById(sessionId string, code int, reason string) { session.Code = code session.Message = reason - model.UpdateSessionById(&session, sessionId) + _ = model.UpdateSessionById(&session, sessionId) } -func CloseSessionByWebSocket(ws *websocket.Conn, c int, t string) { +func CloseWebSocket(ws *websocket.Conn, c int, t string) { if ws == nil { return } @@ -157,13 +163,10 @@ func SessionResizeEndpoint(c echo.Context) error { intHeight, _ := strconv.Atoi(height) - session := model.Session{} - session.ID = sessionId - session.Width = intWidth - session.Height = intHeight - - model.UpdateSessionById(&session, sessionId) - return Success(c, session) + if err := model.UpdateSessionWindowSizeById(intWidth, intHeight, sessionId); err != nil { + return err + } + return Success(c, "") } func SessionCreateEndpoint(c echo.Context) error { @@ -239,7 +242,7 @@ func SessionUploadEndpoint(c echo.Context) error { return errors.New("获取sftp客户端失败") } - dstFile, err := tun.SftpClient.Create(remoteFile) + dstFile, err := tun.Subject.SftpClient.Create(remoteFile) defer dstFile.Close() if err != nil { return err @@ -292,7 +295,7 @@ func SessionDownloadEndpoint(c echo.Context) error { return errors.New("获取sftp客户端失败") } - dstFile, err := tun.SftpClient.Open(remoteFile) + dstFile, err := tun.Subject.SftpClient.Open(remoteFile) if err != nil { return err } @@ -341,16 +344,16 @@ func SessionLsEndpoint(c echo.Context) error { return errors.New("获取sftp客户端失败") } - if tun.SftpClient == nil { + if tun.Subject.SftpClient == nil { sftpClient, err := CreateSftpClient(session.AssetId) if err != nil { logrus.Errorf("创建sftp客户端失败:%v", err.Error()) return err } - tun.SftpClient = sftpClient + tun.Subject.SftpClient = sftpClient } - fileInfos, err := tun.SftpClient.ReadDir(remoteDir) + fileInfos, err := tun.Subject.SftpClient.ReadDir(remoteDir) if err != nil { return err } @@ -410,7 +413,7 @@ func SessionMkDirEndpoint(c echo.Context) error { if !ok { return errors.New("获取sftp客户端失败") } - if err := tun.SftpClient.Mkdir(remoteDir); err != nil { + if err := tun.Subject.SftpClient.Mkdir(remoteDir); err != nil { return err } return Success(c, nil) @@ -441,18 +444,18 @@ func SessionRmDirEndpoint(c echo.Context) error { if !ok { return errors.New("获取sftp客户端失败") } - fileInfos, err := tun.SftpClient.ReadDir(remoteDir) + fileInfos, err := tun.Subject.SftpClient.ReadDir(remoteDir) if err != nil { return err } for i := range fileInfos { - if err := tun.SftpClient.Remove(path.Join(remoteDir, fileInfos[i].Name())); err != nil { + if err := tun.Subject.SftpClient.Remove(path.Join(remoteDir, fileInfos[i].Name())); err != nil { return err } } - if err := tun.SftpClient.RemoveDirectory(remoteDir); err != nil { + if err := tun.Subject.SftpClient.RemoveDirectory(remoteDir); err != nil { return err } return Success(c, nil) @@ -483,7 +486,7 @@ func SessionRmEndpoint(c echo.Context) error { if !ok { return errors.New("获取sftp客户端失败") } - if err := tun.SftpClient.Remove(remoteFile); err != nil { + if err := tun.Subject.SftpClient.Remove(remoteFile); err != nil { return err } return Success(c, nil) diff --git a/pkg/api/tunnel.go b/pkg/api/tunnel.go index 2b3dece..98abf7c 100644 --- a/pkg/api/tunnel.go +++ b/pkg/api/tunnel.go @@ -1,6 +1,8 @@ package api import ( + "errors" + "fmt" "github.com/gorilla/websocket" "github.com/labstack/echo/v4" "github.com/sirupsen/logrus" @@ -49,9 +51,13 @@ func TunEndpoint(c echo.Context) error { if len(connectionId) > 0 { session, err = model.FindSessionByConnectionId(connectionId) if err != nil { - CloseSessionById(sessionId, NotFoundSession, "会话不存在") + CloseWebSocket(ws, NotFoundSession, "会话不存在") return err } + if session.Status != model.Connected { + CloseWebSocket(ws, NotFoundSession, "会话未在线") + return errors.New("会话未在线") + } configuration.ConnectionID = connectionId } else { session, err = model.FindSessionById(sessionId) @@ -131,38 +137,62 @@ func TunEndpoint(c echo.Context) error { tunnel, err := guacd.NewTunnel(addr, configuration) if err != nil { - CloseSessionById(sessionId, NewTunnelError, err.Error()) + if connectionId == "" { + CloseSessionById(sessionId, NewTunnelError, err.Error()) + } logrus.Printf("建立连接失败: %v", err.Error()) return err } tun := global.Tun{ - Tun: tunnel, + Tunnel: tunnel, WebSocket: ws, } - global.Store.Set(sessionId, &tun) - if len(session.ConnectionId) == 0 { + var observers []global.Tun + observable := global.Observable{ + Subject: &tun, + Observers: observers, + } + + global.Store.Set(sessionId, &observable) + // 创建新会话 session.ConnectionId = tunnel.UUID session.Width = intWidth session.Height = intHeight session.Status = model.Connecting session.Recording = configuration.GetParameter(guacd.RecordingPath) - model.UpdateSessionById(&session, sessionId) + if err := model.UpdateSessionById(&session, sessionId); err != nil { + return err + } + } else { + // TODO 处理监控会话的退出 + // 监控会话 + observable, ok := global.Store.Get(sessionId) + if ok { + observers := append(observable.Observers, tun) + observable.Observers = observers + global.Store.Set(sessionId, observable) + } } go func() { for true { instruction, err := tunnel.Read() + fmt.Printf("<- %v \n", string(instruction)) if err != nil { - CloseSessionById(sessionId, TunnelClosed, "隧道已关闭") + if connectionId == "" { + CloseSessionById(sessionId, TunnelClosed, "远程连接关闭") + } break } err = ws.WriteMessage(websocket.TextMessage, instruction) if err != nil { - CloseSessionById(sessionId, TunnelClosed, "隧道已关闭") + if connectionId == "" { + CloseSessionById(sessionId, Normal, "正常退出") + } break } } @@ -171,12 +201,16 @@ func TunEndpoint(c echo.Context) error { for true { _, message, err := ws.ReadMessage() if err != nil { - CloseSessionById(sessionId, Normal, "用户主动关闭了会话") + if connectionId == "" { + CloseSessionById(sessionId, Normal, "正常退出") + } break } _, err = tunnel.WriteAndFlush(message) if err != nil { - CloseSessionById(sessionId, Normal, "用户主动关闭了会话") + if connectionId == "" { + CloseSessionById(sessionId, TunnelClosed, "远程连接关闭") + } break } } diff --git a/pkg/global/store.go b/pkg/global/store.go index 0b1bd34..fc62b7f 100644 --- a/pkg/global/store.go +++ b/pkg/global/store.go @@ -8,16 +8,21 @@ import ( ) type Tun struct { - Tun *guacd.Tunnel + Tunnel *guacd.Tunnel SftpClient *sftp.Client WebSocket *websocket.Conn } +type Observable struct { + Subject *Tun + Observers []Tun +} + type TunStore struct { m sync.Map } -func (s *TunStore) Set(k string, v *Tun) { +func (s *TunStore) Set(k string, v *Observable) { s.m.Store(k, v) } @@ -25,10 +30,10 @@ func (s *TunStore) Del(k string) { s.m.Delete(k) } -func (s *TunStore) Get(k string) (item *Tun, ok bool) { +func (s *TunStore) Get(k string) (item *Observable, ok bool) { value, ok := s.m.Load(k) if ok { - return value.(*Tun), true + return value.(*Observable), true } return item, false } diff --git a/pkg/guacd/guacd.go b/pkg/guacd/guacd.go index fc4d680..4c754eb 100644 --- a/pkg/guacd/guacd.go +++ b/pkg/guacd/guacd.go @@ -151,9 +151,9 @@ func NewTunnel(address string, config Configuration) (ret *Tunnel, err error) { return nil, err } - if err := ret.WriteInstructionAndFlush(NewInstruction("timezone", "Asia/Shanghai")); err != nil { - return nil, err - } + //if err := ret.WriteInstructionAndFlush(NewInstruction("timezone", "Asia/Shanghai")); err != nil { + // return nil, err + //} parameters := make([]string, len(args.Args)) for i := range args.Args { @@ -198,7 +198,7 @@ func (opt *Tunnel) WriteInstruction(instruction Instruction) error { } func (opt *Tunnel) WriteAndFlush(p []byte) (int, error) { - //fmt.Printf("-> %v \n", string(p)) + fmt.Printf("-> %v \n", string(p)) nn, err := opt.rw.Write(p) if err != nil { return nn, err @@ -211,7 +211,7 @@ func (opt *Tunnel) WriteAndFlush(p []byte) (int, error) { } func (opt *Tunnel) Write(p []byte) (int, error) { - //fmt.Printf("-> %v \n", string(p)) + fmt.Printf("-> %v \n", string(p)) nn, err := opt.rw.Write(p) if err != nil { return nn, err @@ -225,6 +225,7 @@ func (opt *Tunnel) Flush() error { func (opt *Tunnel) ReadInstruction() (instruction Instruction, err error) { msg, err := opt.rw.ReadString(Delimiter) + fmt.Printf("<- %v \n", msg) if err != nil { return instruction, err } diff --git a/pkg/handle/runner.go b/pkg/handle/runner.go index 2307b37..109635c 100644 --- a/pkg/handle/runner.go +++ b/pkg/handle/runner.go @@ -3,8 +3,6 @@ package handle import ( "github.com/robfig/cron/v3" "github.com/sirupsen/logrus" - "next-terminal/pkg/api" - "next-terminal/pkg/global" "next-terminal/pkg/guacd" "next-terminal/pkg/model" "next-terminal/pkg/utils" @@ -42,21 +40,6 @@ func RunTicker() { } }) - // 定时任务,每隔一分钟校验一次运行中的会话信息 - _, _ = c.AddFunc("0 0/1 0/1 * * ?", func() { - sessions, _ := model.FindSessionByStatus(model.Connected) - if sessions != nil && len(sessions) > 0 { - for i := range sessions { - _, found := global.Store.Get(sessions[i].ID) - if !found { - api.CloseSessionById(sessions[i].ID, api.Normal, "") - s := sessions[i].Username + "@" + sessions[i].IP + ":" + strconv.Itoa(sessions[i].Port) - logrus.Infof("会话「%v」ID「%v」已离线,修改状态为「关闭」。", s, sessions[i].ID) - } - } - } - }) - c.Start() } @@ -72,7 +55,7 @@ func RunDataFix() { DisconnectedTime: utils.NowJsonTime(), } - model.UpdateSessionById(&session, sessions[i].ID) + _ = model.UpdateSessionById(&session, sessions[i].ID) } } diff --git a/pkg/model/session.go b/pkg/model/session.go index 6e970f0..a7a6e19 100644 --- a/pkg/model/session.go +++ b/pkg/model/session.go @@ -132,9 +132,17 @@ func FindSessionByConnectionId(connectionId string) (o Session, err error) { return } -func UpdateSessionById(o *Session, id string) { +func UpdateSessionById(o *Session, id string) error { o.ID = id - global.DB.Updates(o) + return global.DB.Updates(o).Error +} + +func UpdateSessionWindowSizeById(width, height int, id string) error { + session := Session{} + session.Width = width + session.Height = height + + return UpdateSessionById(&session, id) } func DeleteSessionById(id string) { diff --git a/web/src/components/access/Access.js b/web/src/components/access/Access.js index 8f1a1aa..bd90c01 100644 --- a/web/src/components/access/Access.js +++ b/web/src/components/access/Access.js @@ -503,7 +503,7 @@ class Access extends Component { let stateChecker = setInterval(async () => { let result = await request.get(`/sessions/${sessionId}`); if (result['code'] !== 1) { - message.error(result['message']); + clearInterval(stateChecker); } else { let session = result['data']; if (session['status'] === 'connected') { diff --git a/web/src/components/access/Monitor.js b/web/src/components/access/Monitor.js index 83b962d..6dde230 100644 --- a/web/src/components/access/Monitor.js +++ b/web/src/components/access/Monitor.js @@ -1,6 +1,6 @@ import React, {Component} from 'react'; import Guacamole from 'guacamole-common-js'; -import {message, Modal} from 'antd' +import {Modal, Result, Spin} from 'antd' import qs from "qs"; import {wsServer} from "../../common/constants"; import {getToken} from "../../utils/utils"; @@ -20,9 +20,12 @@ class Access extends Component { state = { client: {}, containerOverflow: 'hidden', - containerWidth: 0, - containerHeight: 0, - rate: 1 + width: 0, + height: 0, + rate: 1, + loading: false, + tip: '', + closed: false, }; async componentDidMount() { @@ -38,8 +41,8 @@ class Access extends Component { height = height * 2; } this.setState({ - containerWidth: width * rate, - containerHeight: height * rate, + width: width * rate, + height: height * rate, rate: rate, }) this.renderDisplay(connectionId); @@ -53,40 +56,47 @@ class Access extends Component { onTunnelStateChange = (state) => { console.log('onTunnelStateChange', state); + if (state === Guacamole.Tunnel.State.CLOSED) { + this.setState({ + loading: false, + closed: true, + }); + } }; onClientStateChange = (state) => { switch (state) { case STATE_IDLE: - console.log('初始化'); - message.destroy(); - message.loading('正在初始化中...', 0); + this.setState({ + loading: true, + tip: '正在初始化中...' + }); break; case STATE_CONNECTING: - console.log('正在连接...'); - message.destroy(); - message.loading('正在努力连接中...', 0); + this.setState({ + loading: true, + tip: '正在努力连接中...' + }); break; case STATE_WAITING: - console.log('正在等待...'); - message.destroy(); - message.loading('正在等待服务器响应...', 0); + this.setState({ + loading: true, + tip: '正在等待服务器响应...' + }); break; case STATE_CONNECTED: - console.log('连接成功。'); - message.destroy(); - message.success('连接成功'); + this.setState({ + loading: false + }); if (this.state.client) { this.state.client.getDisplay().scale(this.state.rate); } break; case STATE_DISCONNECTING: - console.log('连接正在关闭中...'); - message.destroy(); + break; case STATE_DISCONNECTED: - console.log('连接关闭。'); - message.destroy(); + break; default: break; @@ -211,16 +221,26 @@ class Access extends Component { render() { return ( -