优化接入网关连接

This commit is contained in:
dushixiang 2022-05-05 17:18:55 +08:00
parent f3a78761c8
commit 90751bacfc
4 changed files with 53 additions and 49 deletions

View File

@ -57,15 +57,14 @@ func (g *Gateway) Run() {
} }
func (g *Gateway) Close() { func (g *Gateway) Close() {
g.exit <- true
if g.SshClient != nil { if g.SshClient != nil {
_ = g.SshClient.Close() _ = g.SshClient.Close()
} }
if len(g.tunnels) > 0 { for id := range g.tunnels {
for _, tunnel := range g.tunnels { g.CloseSshTunnel(id)
tunnel.Close()
}
} }
g.exit <- true
} }
func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, exposedPort int, err error) { func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, exposedPort int, err error) {
@ -111,7 +110,5 @@ func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, expo
} }
func (g Gateway) CloseSshTunnel(id string) { func (g Gateway) CloseSshTunnel(id string) {
if g.tunnels[id] != nil { g.Del <- id
g.tunnels[id].Close()
}
} }

View File

@ -19,24 +19,34 @@ type Tunnel struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
listener net.Listener listener net.Listener
err error localConnections []net.Conn
remoteConnections []net.Conn
} }
func (r *Tunnel) Open() { func (r *Tunnel) Open() {
localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort) localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort)
for {
select { go func() {
case <-r.ctx.Done(): <-r.ctx.Done()
_ = r.listener.Close() _ = r.listener.Close()
for i := range r.localConnections {
_ = r.localConnections[i].Close()
}
r.localConnections = nil
for i := range r.remoteConnections {
_ = r.remoteConnections[i].Close()
}
r.remoteConnections = nil
log.Debugf("SSH 隧道 %v 关闭", localAddr) log.Debugf("SSH 隧道 %v 关闭", localAddr)
return }()
default: for {
log.Debugf("等待客户端访问 %v", localAddr) log.Debugf("等待客户端访问 %v", localAddr)
localConn, err := r.listener.Accept() localConn, err := r.listener.Accept()
if err != nil { if err != nil {
log.Debugf("接受连接失败 %v", err.Error()) log.Debugf("接受连接失败 %v, 退出循环", err.Error())
continue return
} }
r.localConnections = append(r.localConnections, localConn)
log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr) log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr)
remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort) remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort)
@ -44,16 +54,15 @@ func (r *Tunnel) Open() {
remoteConn, err := r.Gateway.SshClient.Dial("tcp", remoteAddr) remoteConn, err := r.Gateway.SshClient.Dial("tcp", remoteAddr)
if err != nil { if err != nil {
log.Debugf("连接远程主机 %v 失败", remoteAddr) log.Debugf("连接远程主机 %v 失败", remoteAddr)
r.err = err
return return
} }
r.remoteConnections = append(r.remoteConnections, remoteConn)
log.Debugf("连接远程主机 %v 成功", remoteAddr) log.Debugf("连接远程主机 %v 成功", remoteAddr)
go copyConn(localConn, remoteConn) go copyConn(localConn, remoteConn)
go copyConn(remoteConn, localConn) go copyConn(remoteConn, localConn)
log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr) log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr)
} }
}
} }
func (r Tunnel) Close() { func (r Tunnel) Close() {

View File

@ -84,7 +84,7 @@ func (m *Manager) Start() {
_ = ss.GuacdTunnel.Close() _ = ss.GuacdTunnel.Close()
} }
if ss.NextTerminal != nil { if ss.NextTerminal != nil {
_ = ss.NextTerminal.Close() ss.NextTerminal.Close()
} }
if ss.WebSocket != nil { if ss.WebSocket != nil {

View File

@ -79,7 +79,7 @@ func (ret *NextTerminal) Write(p []byte) (int, error) {
return ret.StdinPipe.Write(p) return ret.StdinPipe.Write(p)
} }
func (ret *NextTerminal) Close() error { func (ret *NextTerminal) Close() {
if ret.SftpClient != nil { if ret.SftpClient != nil {
_ = ret.SftpClient.Close() _ = ret.SftpClient.Close()
@ -96,8 +96,6 @@ func (ret *NextTerminal) Close() error {
if ret.Recorder != nil { if ret.Recorder != nil {
ret.Recorder.Close() ret.Recorder.Close()
} }
return nil
} }
func (ret *NextTerminal) WindowChange(h int, w int) error { func (ret *NextTerminal) WindowChange(h int, w int) error {