From 8f62eab194695ffd05e7d9abbceaea9c1b364872 Mon Sep 17 00:00:00 2001 From: 1mtrue Date: Thu, 5 May 2022 18:08:31 +0800 Subject: [PATCH] fix(tunnel):recover code --- server/global/gateway/tunnel.go | 52 ++++++++++++++------------------- 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/server/global/gateway/tunnel.go b/server/global/gateway/tunnel.go index 088bbc5..307231c 100644 --- a/server/global/gateway/tunnel.go +++ b/server/global/gateway/tunnel.go @@ -24,44 +24,36 @@ type Tunnel struct { func (r *Tunnel) Open() { localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort) + go func() { + <-r.ctx.Done() + _ = r.listener.Close() + }() for { - select { - case <-r.ctx.Done(): - _ = r.listener.Close() - log.Debugf("SSH 隧道 %v 关闭", localAddr) + log.Debugf("等待客户端访问 %v", localAddr) + localConn, err := r.listener.Accept() + if err != nil { + log.Debugf("接受连接失败 %v, 退出循环", err.Error()) return - default: - log.Debugf("等待客户端访问 %v", localAddr) - localConn, err := r.listener.Accept() - if err != nil { - log.Debugf("接受连接失败 %v", err.Error()) - continue - } - - log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr) - remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort) - log.Debugf("连接远程主机 %v ...", remoteAddr) - remoteConn, err := r.Gateway.SshClient.Dial("tcp", remoteAddr) - if err != nil { - log.Debugf("连接远程主机 %v 失败", remoteAddr) - r.err = err - return - } - - log.Debugf("连接远程主机 %v 成功", remoteAddr) - go copyConn(r.ctx, localConn, remoteConn) - go copyConn(r.ctx, remoteConn, localConn) - log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr) } + + log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr) + remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort) + log.Debugf("连接远程主机 %v ...", remoteAddr) + remoteConn, err := r.Gateway.SshClient.Dial("tcp", remoteAddr) + if err != nil { + log.Debugf("连接远程主机 %v 失败", remoteAddr) + return + } + + log.Debugf("连接远程主机 %v 成功", remoteAddr) + go copyConn(r.ctx, localConn, remoteConn) + go copyConn(r.ctx, remoteConn, localConn) + log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr) } } func (r Tunnel) Close() { r.cancel() - err := r.listener.Close() - if err != nil { - return - } } func copyConn(ctx context.Context, writer, reader net.Conn) {