From 90751bacfc3b227f09b5aebbed6bac17835b66e1 Mon Sep 17 00:00:00 2001 From: dushixiang Date: Thu, 5 May 2022 17:18:55 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=8E=A5=E5=85=A5=E7=BD=91?= =?UTF-8?q?=E5=85=B3=E8=BF=9E=E6=8E=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/global/gateway/gateway.go | 13 ++--- server/global/gateway/tunnel.go | 83 ++++++++++++++++++-------------- server/global/session/session.go | 2 +- server/term/next_terminal.go | 4 +- 4 files changed, 53 insertions(+), 49 deletions(-) diff --git a/server/global/gateway/gateway.go b/server/global/gateway/gateway.go index 54fc6e7..142c847 100644 --- a/server/global/gateway/gateway.go +++ b/server/global/gateway/gateway.go @@ -57,15 +57,14 @@ func (g *Gateway) Run() { } func (g *Gateway) Close() { - g.exit <- true if g.SshClient != nil { _ = g.SshClient.Close() } - if len(g.tunnels) > 0 { - for _, tunnel := range g.tunnels { - tunnel.Close() - } + for id := range g.tunnels { + g.CloseSshTunnel(id) } + + g.exit <- true } func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, exposedPort int, err error) { @@ -111,7 +110,5 @@ func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, expo } func (g Gateway) CloseSshTunnel(id string) { - if g.tunnels[id] != nil { - g.tunnels[id].Close() - } + g.Del <- id } diff --git a/server/global/gateway/tunnel.go b/server/global/gateway/tunnel.go index 5a272da..58c88f4 100644 --- a/server/global/gateway/tunnel.go +++ b/server/global/gateway/tunnel.go @@ -10,49 +10,58 @@ import ( ) type Tunnel struct { - ID string // 唯一标识 - LocalHost string // 本地监听地址 - LocalPort int // 本地端口 - RemoteHost string // 远程连接地址 - RemotePort int // 远程端口 - Gateway *Gateway - ctx context.Context - cancel context.CancelFunc - listener net.Listener - err error + ID string // 唯一标识 + LocalHost string // 本地监听地址 + LocalPort int // 本地端口 + RemoteHost string // 远程连接地址 + RemotePort int // 远程端口 + Gateway *Gateway + ctx context.Context + cancel context.CancelFunc + listener net.Listener + localConnections []net.Conn + remoteConnections []net.Conn } func (r *Tunnel) Open() { localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort) - for { - select { - case <-r.ctx.Done(): - _ = r.listener.Close() - log.Debugf("SSH 隧道 %v 关闭", localAddr) - return - default: - log.Debugf("等待客户端访问 %v", localAddr) - localConn, err := r.listener.Accept() - if err != nil { - log.Debugf("接受连接失败 %v", err.Error()) - continue - } - log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr) - remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort) - log.Debugf("连接远程主机 %v ...", remoteAddr) - remoteConn, err := r.Gateway.SshClient.Dial("tcp", remoteAddr) - if err != nil { - log.Debugf("连接远程主机 %v 失败", remoteAddr) - r.err = err - return - } - - log.Debugf("连接远程主机 %v 成功", remoteAddr) - go copyConn(localConn, remoteConn) - go copyConn(remoteConn, localConn) - log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr) + go func() { + <-r.ctx.Done() + _ = r.listener.Close() + for i := range r.localConnections { + _ = r.localConnections[i].Close() } + r.localConnections = nil + for i := range r.remoteConnections { + _ = r.remoteConnections[i].Close() + } + r.remoteConnections = nil + log.Debugf("SSH 隧道 %v 关闭", localAddr) + }() + for { + log.Debugf("等待客户端访问 %v", localAddr) + localConn, err := r.listener.Accept() + if err != nil { + log.Debugf("接受连接失败 %v, 退出循环", err.Error()) + return + } + r.localConnections = append(r.localConnections, localConn) + + log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr) + remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort) + log.Debugf("连接远程主机 %v ...", remoteAddr) + remoteConn, err := r.Gateway.SshClient.Dial("tcp", remoteAddr) + if err != nil { + log.Debugf("连接远程主机 %v 失败", remoteAddr) + return + } + r.remoteConnections = append(r.remoteConnections, remoteConn) + + log.Debugf("连接远程主机 %v 成功", remoteAddr) + go copyConn(localConn, remoteConn) + go copyConn(remoteConn, localConn) + log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr) } } diff --git a/server/global/session/session.go b/server/global/session/session.go index 40858a8..e52aa29 100644 --- a/server/global/session/session.go +++ b/server/global/session/session.go @@ -84,7 +84,7 @@ func (m *Manager) Start() { _ = ss.GuacdTunnel.Close() } if ss.NextTerminal != nil { - _ = ss.NextTerminal.Close() + ss.NextTerminal.Close() } if ss.WebSocket != nil { diff --git a/server/term/next_terminal.go b/server/term/next_terminal.go index a78b707..307b5c9 100644 --- a/server/term/next_terminal.go +++ b/server/term/next_terminal.go @@ -79,7 +79,7 @@ func (ret *NextTerminal) Write(p []byte) (int, error) { return ret.StdinPipe.Write(p) } -func (ret *NextTerminal) Close() error { +func (ret *NextTerminal) Close() { if ret.SftpClient != nil { _ = ret.SftpClient.Close() @@ -96,8 +96,6 @@ func (ret *NextTerminal) Close() error { if ret.Recorder != nil { ret.Recorder.Close() } - - return nil } func (ret *NextTerminal) WindowChange(h int, w int) error {