修复rdp协议无法监控的bug

This commit is contained in:
dushixiang 2021-01-26 20:57:48 +08:00
parent ac99332453
commit 4f1bfa6c5a
6 changed files with 63 additions and 37 deletions

View File

@ -115,7 +115,7 @@ func SetupRoutes() *echo.Echo {
sessions.DELETE("/:id/rm", SessionRmEndpoint) sessions.DELETE("/:id/rm", SessionRmEndpoint)
sessions.DELETE("/:id", SessionDeleteEndpoint) sessions.DELETE("/:id", SessionDeleteEndpoint)
sessions.GET("/:id/recording", SessionRecordingEndpoint) sessions.GET("/:id/recording", SessionRecordingEndpoint)
sessions.GET("/:id", SessionGetEndpoint) sessions.GET("/:id/status", SessionGetStatusEndpoint)
} }
resourceSharers := e.Group("/resource-sharers") resourceSharers := e.Group("/resource-sharers")

View File

@ -17,6 +17,7 @@ import (
"path" "path"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
) )
@ -98,13 +99,19 @@ func SessionDiscontentEndpoint(c echo.Context) error {
return Success(c, nil) return Success(c, nil)
} }
var mutex sync.Mutex
func CloseSessionById(sessionId string, code int, reason string) { func CloseSessionById(sessionId string, code int, reason string) {
mutex.Lock()
defer mutex.Unlock()
observable, _ := global.Store.Get(sessionId) observable, _ := global.Store.Get(sessionId)
if observable != nil { if observable != nil {
logrus.Debugf("会话%v创建者退出", observable.Subject.Tunnel.UUID)
_ = observable.Subject.Tunnel.Close() _ = observable.Subject.Tunnel.Close()
for i := 0; i < len(observable.Observers); i++ { for i := 0; i < len(observable.Observers); i++ {
_ = observable.Observers[i].Tunnel.Close() _ = observable.Observers[i].Tunnel.Close()
CloseWebSocket(observable.Observers[i].WebSocket, code, reason) CloseWebSocket(observable.Observers[i].WebSocket, code, reason)
logrus.Debugf("强制踢出会话%v的观察者", observable.Observers[i].Tunnel.UUID)
} }
CloseWebSocket(observable.Subject.WebSocket, code, reason) CloseWebSocket(observable.Subject.WebSocket, code, reason)
} }
@ -515,11 +522,13 @@ func SessionRecordingEndpoint(c echo.Context) error {
return c.File(recording) return c.File(recording)
} }
func SessionGetEndpoint(c echo.Context) error { func SessionGetStatusEndpoint(c echo.Context) error {
sessionId := c.Param("id") sessionId := c.Param("id")
session, err := model.FindSessionById(sessionId) session, err := model.FindSessionById(sessionId)
if err != nil { if err != nil {
return err return err
} }
return Success(c, session) return Success(c, H{
"status": session.Status,
})
} }

View File

@ -2,7 +2,6 @@ package api
import ( import (
"errors" "errors"
"fmt"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -40,9 +39,6 @@ func TunEndpoint(c echo.Context) error {
intHeight, _ := strconv.Atoi(height) intHeight, _ := strconv.Atoi(height)
configuration := guacd.NewConfiguration() configuration := guacd.NewConfiguration()
configuration.SetParameter("width", width)
configuration.SetParameter("height", height)
configuration.SetParameter("dpi", dpi)
propertyMap := model.FindAllPropertiesMap() propertyMap := model.FindAllPropertiesMap()
@ -59,7 +55,14 @@ func TunEndpoint(c echo.Context) error {
return errors.New("会话未在线") return errors.New("会话未在线")
} }
configuration.ConnectionID = connectionId configuration.ConnectionID = connectionId
sessionId = session.ID
configuration.SetParameter("width", strconv.Itoa(session.Width))
configuration.SetParameter("height", strconv.Itoa(session.Height))
configuration.SetParameter("dpi", "96")
} else { } else {
configuration.SetParameter("width", width)
configuration.SetParameter("height", height)
configuration.SetParameter("dpi", dpi)
session, err = model.FindSessionById(sessionId) session, err = model.FindSessionById(sessionId)
if err != nil { if err != nil {
CloseSessionById(sessionId, NotFoundSession, "会话不存在") CloseSessionById(sessionId, NotFoundSession, "会话不存在")
@ -95,7 +98,6 @@ func TunEndpoint(c echo.Context) error {
configuration.SetParameter(guacd.DisableBitmapCaching, propertyMap[guacd.DisableBitmapCaching]) configuration.SetParameter(guacd.DisableBitmapCaching, propertyMap[guacd.DisableBitmapCaching])
configuration.SetParameter(guacd.DisableOffscreenCaching, propertyMap[guacd.DisableOffscreenCaching]) configuration.SetParameter(guacd.DisableOffscreenCaching, propertyMap[guacd.DisableOffscreenCaching])
configuration.SetParameter(guacd.DisableGlyphCaching, propertyMap[guacd.DisableGlyphCaching]) configuration.SetParameter(guacd.DisableGlyphCaching, propertyMap[guacd.DisableGlyphCaching])
configuration.SetParameter("server-layout", "en-us-qwerty")
break break
case "ssh": case "ssh":
if len(session.PrivateKey) > 0 && session.PrivateKey != "-" { if len(session.PrivateKey) > 0 && session.PrivateKey != "-" {
@ -133,8 +135,6 @@ func TunEndpoint(c echo.Context) error {
addr := propertyMap[guacd.Host] + ":" + propertyMap[guacd.Port] addr := propertyMap[guacd.Host] + ":" + propertyMap[guacd.Port]
logrus.Infof("connect to %v with global: %+v", addr, configuration)
tunnel, err := guacd.NewTunnel(addr, configuration) tunnel, err := guacd.NewTunnel(addr, configuration)
if err != nil { if err != nil {
if connectionId == "" { if connectionId == "" {
@ -150,6 +150,7 @@ func TunEndpoint(c echo.Context) error {
} }
if len(session.ConnectionId) == 0 { if len(session.ConnectionId) == 0 {
var observers []global.Tun var observers []global.Tun
observable := global.Observable{ observable := global.Observable{
Subject: &tun, Subject: &tun,
@ -157,31 +158,33 @@ func TunEndpoint(c echo.Context) error {
} }
global.Store.Set(sessionId, &observable) global.Store.Set(sessionId, &observable)
// 创建新会话
session.ConnectionId = tunnel.UUID
session.Width = intWidth
session.Height = intHeight
session.Status = model.Connecting
session.Recording = configuration.GetParameter(guacd.RecordingPath)
if err := model.UpdateSessionById(&session, sessionId); err != nil { sess := model.Session{
ConnectionId: tunnel.UUID,
Width: intWidth,
Height: intHeight,
Status: model.Connecting,
Recording: configuration.GetParameter(guacd.RecordingPath),
}
// 创建新会话
logrus.Debugf("创建新会话 %v", sess.ConnectionId)
if err := model.UpdateSessionById(&sess, sessionId); err != nil {
return err return err
} }
} else { } else {
// TODO 处理监控会话的退出
// 监控会话 // 监控会话
observable, ok := global.Store.Get(sessionId) observable, ok := global.Store.Get(sessionId)
if ok { if ok {
observers := append(observable.Observers, tun) observers := append(observable.Observers, tun)
observable.Observers = observers observable.Observers = observers
global.Store.Set(sessionId, observable) global.Store.Set(sessionId, observable)
logrus.Debugf("加入会话%v,当前观察者数量为:%v", session.ConnectionId, len(observers))
} }
} }
go func() { go func() {
for true { for true {
instruction, err := tunnel.Read() instruction, err := tunnel.Read()
fmt.Printf("<- %v \n", string(instruction))
if err != nil { if err != nil {
if connectionId == "" { if connectionId == "" {
CloseSessionById(sessionId, TunnelClosed, "远程连接关闭") CloseSessionById(sessionId, TunnelClosed, "远程连接关闭")

View File

@ -140,7 +140,6 @@ func NewTunnel(address string, config Configuration) (ret *Tunnel, err error) {
if err := ret.WriteInstructionAndFlush(NewInstruction("size", width, height, dpi)); err != nil { if err := ret.WriteInstructionAndFlush(NewInstruction("size", width, height, dpi)); err != nil {
return nil, err return nil, err
} }
if err := ret.WriteInstructionAndFlush(NewInstruction("audio")); err != nil { if err := ret.WriteInstructionAndFlush(NewInstruction("audio")); err != nil {
return nil, err return nil, err
} }
@ -150,10 +149,9 @@ func NewTunnel(address string, config Configuration) (ret *Tunnel, err error) {
if err := ret.WriteInstructionAndFlush(NewInstruction("image")); err != nil { if err := ret.WriteInstructionAndFlush(NewInstruction("image")); err != nil {
return nil, err return nil, err
} }
if err := ret.WriteInstructionAndFlush(NewInstruction("timezone", "Asia/Shanghai")); err != nil {
//if err := ret.WriteInstructionAndFlush(NewInstruction("timezone", "Asia/Shanghai")); err != nil { return nil, err
// return nil, err }
//}
parameters := make([]string, len(args.Args)) parameters := make([]string, len(args.Args))
for i := range args.Args { for i := range args.Args {
@ -198,7 +196,7 @@ func (opt *Tunnel) WriteInstruction(instruction Instruction) error {
} }
func (opt *Tunnel) WriteAndFlush(p []byte) (int, error) { func (opt *Tunnel) WriteAndFlush(p []byte) (int, error) {
fmt.Printf("-> %v \n", string(p)) //fmt.Printf("-> %v \n", string(p))
nn, err := opt.rw.Write(p) nn, err := opt.rw.Write(p)
if err != nil { if err != nil {
return nn, err return nn, err
@ -211,7 +209,7 @@ func (opt *Tunnel) WriteAndFlush(p []byte) (int, error) {
} }
func (opt *Tunnel) Write(p []byte) (int, error) { func (opt *Tunnel) Write(p []byte) (int, error) {
fmt.Printf("-> %v \n", string(p)) //fmt.Printf("-> %v \n", string(p))
nn, err := opt.rw.Write(p) nn, err := opt.rw.Write(p)
if err != nil { if err != nil {
return nn, err return nn, err
@ -225,15 +223,17 @@ func (opt *Tunnel) Flush() error {
func (opt *Tunnel) ReadInstruction() (instruction Instruction, err error) { func (opt *Tunnel) ReadInstruction() (instruction Instruction, err error) {
msg, err := opt.rw.ReadString(Delimiter) msg, err := opt.rw.ReadString(Delimiter)
fmt.Printf("<- %v \n", msg) //fmt.Printf("<- %v \n", msg)
if err != nil { if err != nil {
return instruction, err return instruction, err
} }
return instruction.Parse(msg), err return instruction.Parse(msg), err
} }
func (opt *Tunnel) Read() ([]byte, error) { func (opt *Tunnel) Read() (p []byte, err error) {
return opt.rw.ReadBytes(Delimiter) p, err = opt.rw.ReadBytes(Delimiter)
//fmt.Printf("<- %v \n", string(p))
return
} }
func (opt *Tunnel) expect(opcode string) (instruction Instruction, err error) { func (opt *Tunnel) expect(opcode string) (instruction Instruction, err error) {

View File

@ -58,6 +58,8 @@ type SessionVo struct {
DisconnectedTime utils.JsonTime `json:"disconnectedTime"` DisconnectedTime utils.JsonTime `json:"disconnectedTime"`
AssetName string `json:"assetName"` AssetName string `json:"assetName"`
CreatorName string `json:"creatorName"` CreatorName string `json:"creatorName"`
Code int `json:"code"`
Message string `json:"message"`
} }
func FindPageSession(pageIndex, pageSize int, status, userId, clientIp, assetId, protocol string) (results []SessionVo, total int64, err error) { func FindPageSession(pageIndex, pageSize int, status, userId, clientIp, assetId, protocol string) (results []SessionVo, total int64, err error) {
@ -67,7 +69,7 @@ func FindPageSession(pageIndex, pageSize int, status, userId, clientIp, assetId,
params = append(params, status) params = append(params, status)
itemSql := "SELECT s.id, s.protocol,s.recording, s.connection_id, s.asset_id, s.creator, s.client_ip, s.width, s.height, s.ip, s.port, s.username, s.status, s.connected_time, s.disconnected_time, a.name AS asset_name, u.nickname AS creator_name FROM sessions s LEFT JOIN assets a ON s.asset_id = a.id LEFT JOIN users u ON s.creator = u.id WHERE s.STATUS = ? " itemSql := "SELECT s.id, s.protocol,s.recording, s.connection_id, s.asset_id, s.creator, s.client_ip, s.width, s.height, s.ip, s.port, s.username, s.status, s.connected_time, s.disconnected_time,s.code, s.message, a.name AS asset_name, u.nickname AS creator_name FROM sessions s LEFT JOIN assets a ON s.asset_id = a.id LEFT JOIN users u ON s.creator = u.id WHERE s.STATUS = ? "
countSql := "select count(*) from sessions as s where s.status = ? " countSql := "select count(*) from sessions as s where s.status = ? "
if len(userId) > 0 { if len(userId) > 0 {

View File

@ -28,7 +28,7 @@ import {
CloudUploadOutlined, CloudUploadOutlined,
CopyOutlined, CopyOutlined,
DeleteOutlined, DeleteOutlined,
DesktopOutlined, DesktopOutlined, ExclamationCircleOutlined,
ExpandOutlined, ExpandOutlined,
FileZipOutlined, FileZipOutlined,
FolderAddOutlined, FolderAddOutlined,
@ -146,18 +146,20 @@ class Access extends Component {
}) })
if (this.state.protocol === 'ssh') { if (this.state.protocol === 'ssh') {
if (data.data && data.data.length > 0) { if (data.data && data.data.length > 0) {
message.success('您输入的内容已复制到远程服务器上,使用右键将自动粘贴。'); // message.success('您输入的内容已复制到远程服务器上,使用右键将自动粘贴。');
} }
} else { } else {
if (data.data && data.data.length > 0) { if (data.data && data.data.length > 0) {
message.success('您输入的内容已复制到远程服务器上'); // message.success('您输入的内容已复制到远程服务器上');
} }
} }
} }
onTunnelStateChange = (state) => { onTunnelStateChange = (state) => {
if(state === Guacamole.Tunnel.State.CLOSED){
this.showMessage('连接已关闭');
}
}; };
updateSessionStatus = async (sessionId) => { updateSessionStatus = async (sessionId) => {
@ -281,9 +283,19 @@ class Access extends Component {
showMessage(msg) { showMessage(msg) {
message.destroy(); message.destroy();
Modal.error({ Modal.confirm({
title: '提示', title: '提示',
icon: <ExclamationCircleOutlined />,
content: msg, content: msg,
centered: true,
okText: '重新连接',
cancelText: '关闭页面',
onOk() {
window.location.reload();
},
onCancel() {
window.close();
},
}); });
} }
@ -304,7 +316,7 @@ class Access extends Component {
// Set clipboard contents once stream is finished // Set clipboard contents once stream is finished
reader.onend = async () => { reader.onend = async () => {
message.success('您选择的内容已复制到您的粘贴板中,在右侧的输入框中可同时查看到。'); // message.success('您选择的内容已复制到您的粘贴板中,在右侧的输入框中可同时查看到。');
this.setState({ this.setState({
clipboardText: data clipboardText: data
}); });
@ -501,7 +513,7 @@ class Access extends Component {
keyboard.onkeyup = this.onKeyUp; keyboard.onkeyup = this.onKeyUp;
let stateChecker = setInterval(async () => { let stateChecker = setInterval(async () => {
let result = await request.get(`/sessions/${sessionId}`); let result = await request.get(`/sessions/${sessionId}/status`);
if (result['code'] !== 1) { if (result['code'] !== 1) {
clearInterval(stateChecker); clearInterval(stateChecker);
} else { } else {