Be fix goroutine leak (#252)

* fix(gateway):fix goroutine leak

* fix(term):remove useless code

* fix(be):add pprof for debug mode
This commit is contained in:
1mtrue
2022-05-05 17:44:05 +08:00
committed by GitHub
parent 90751bacfc
commit 9e44b25b87
5 changed files with 95 additions and 66 deletions

View File

@ -10,65 +10,69 @@ import (
)
type Tunnel struct {
ID string // 唯一标识
LocalHost string // 本地监听地址
LocalPort int // 本地端口
RemoteHost string // 远程连接地址
RemotePort int // 远程端口
Gateway *Gateway
ctx context.Context
cancel context.CancelFunc
listener net.Listener
localConnections []net.Conn
remoteConnections []net.Conn
ID string // 唯一标识
LocalHost string // 本地监听地址
LocalPort int // 本地端口
RemoteHost string // 远程连接地址
RemotePort int // 远程端口
Gateway *Gateway
ctx context.Context
cancel context.CancelFunc
listener net.Listener
err error
}
func (r *Tunnel) Open() {
localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort)
go func() {
<-r.ctx.Done()
_ = r.listener.Close()
for i := range r.localConnections {
_ = r.localConnections[i].Close()
}
r.localConnections = nil
for i := range r.remoteConnections {
_ = r.remoteConnections[i].Close()
}
r.remoteConnections = nil
log.Debugf("SSH 隧道 %v 关闭", localAddr)
}()
for {
log.Debugf("等待客户端访问 %v", localAddr)
localConn, err := r.listener.Accept()
if err != nil {
log.Debugf("接受连接失败 %v, 退出循环", err.Error())
select {
case <-r.ctx.Done():
_ = r.listener.Close()
log.Debugf("SSH 隧道 %v 关闭", localAddr)
return
}
r.localConnections = append(r.localConnections, localConn)
default:
log.Debugf("等待客户端访问 %v", localAddr)
localConn, err := r.listener.Accept()
if err != nil {
log.Debugf("接受连接失败 %v", err.Error())
continue
}
log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr)
remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort)
log.Debugf("连接远程主机 %v ...", remoteAddr)
remoteConn, err := r.Gateway.SshClient.Dial("tcp", remoteAddr)
if err != nil {
log.Debugf("连接远程主机 %v 失败", remoteAddr)
return
}
r.remoteConnections = append(r.remoteConnections, remoteConn)
log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr)
remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort)
log.Debugf("连接远程主机 %v ...", remoteAddr)
remoteConn, err := r.Gateway.SshClient.Dial("tcp", remoteAddr)
if err != nil {
log.Debugf("连接远程主机 %v 失败", remoteAddr)
r.err = err
return
}
log.Debugf("连接远程主机 %v 成功", remoteAddr)
go copyConn(localConn, remoteConn)
go copyConn(remoteConn, localConn)
log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr)
log.Debugf("连接远程主机 %v 成功", remoteAddr)
go copyConn(r.ctx, localConn, remoteConn)
go copyConn(r.ctx, remoteConn, localConn)
log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr)
}
}
}
func (r Tunnel) Close() {
r.cancel()
err := r.listener.Close()
if err != nil {
return
}
}
func copyConn(writer, reader net.Conn) {
func copyConn(ctx context.Context, writer, reader net.Conn) {
_, _ = io.Copy(writer, reader)
for {
select {
case <-ctx.Done():
_ = writer.Close()
_ = reader.Close()
return
}
}
}