Be fix goroutine leak (#252)

* fix(gateway):fix goroutine leak

* fix(term):remove useless code

* fix(be):add pprof for debug mode
This commit is contained in:
1mtrue 2022-05-05 17:44:05 +08:00 committed by GitHub
parent 90751bacfc
commit 9e44b25b87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 95 additions and 66 deletions

View File

@ -152,7 +152,6 @@ func (api WebTerminalApi) SshEndpoint(c echo.Context) error {
NextTerminal: nextTerminal, NextTerminal: nextTerminal,
Observer: session.NewObserver(s.ID), Observer: session.NewObserver(s.ID),
} }
go nextSession.Observer.Start()
session.GlobalSessionManager.Add <- nextSession session.GlobalSessionManager.Add <- nextSession
termHandler := NewTermHandler(sessionId, isRecording, ws, nextTerminal) termHandler := NewTermHandler(sessionId, isRecording, ws, nextTerminal)

View File

@ -3,6 +3,10 @@ package app
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
_ "net/http/pprof"
"next-terminal/server/log"
"next-terminal/server/cli" "next-terminal/server/cli"
"next-terminal/server/config" "next-terminal/server/config"
@ -104,6 +108,9 @@ func Run() error {
if err != nil { if err != nil {
return err return err
} }
go func() {
log.Fatal(http.ListenAndServe("localhost:8099", nil))
}()
fmt.Printf("当前配置为: %v\n", string(jsonBytes)) fmt.Printf("当前配置为: %v\n", string(jsonBytes))
} }

View File

@ -6,6 +6,9 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"sync"
"next-terminal/server/log"
"next-terminal/server/utils" "next-terminal/server/utils"
@ -19,7 +22,7 @@ type Gateway struct {
SshClient *ssh.Client SshClient *ssh.Client
Message string // 失败原因 Message string // 失败原因
tunnels map[string]*Tunnel tunnels *sync.Map
Add chan *Tunnel Add chan *Tunnel
Del chan string Del chan string
@ -34,7 +37,7 @@ func NewGateway(id string, connected bool, message string, client *ssh.Client) *
SshClient: client, SshClient: client,
Add: make(chan *Tunnel), Add: make(chan *Tunnel),
Del: make(chan string), Del: make(chan string),
tunnels: map[string]*Tunnel{}, tunnels: new(sync.Map),
exit: make(chan bool, 1), exit: make(chan bool, 1),
} }
} }
@ -43,12 +46,15 @@ func (g *Gateway) Run() {
for { for {
select { select {
case t := <-g.Add: case t := <-g.Add:
g.tunnels[t.ID] = t g.tunnels.Store(t.ID, t)
log.Info("add tunnel: %s", t.ID)
go t.Open() go t.Open()
case k := <-g.Del: case k := <-g.Del:
if _, ok := g.tunnels[k]; ok { if val, ok := g.tunnels.Load(k); ok {
g.tunnels[k].Close() if vval, vok := val.(*Tunnel); vok {
delete(g.tunnels, k) vval.Close()
g.tunnels.Delete(k)
}
} }
case <-g.exit: case <-g.exit:
return return
@ -57,14 +63,14 @@ func (g *Gateway) Run() {
} }
func (g *Gateway) Close() { func (g *Gateway) Close() {
if g.SshClient != nil { g.tunnels.Range(func(key, value interface{}) bool {
_ = g.SshClient.Close() if val, ok := value.(*Tunnel); ok {
val.Close()
} }
for id := range g.tunnels { return true
g.CloseSshTunnel(id) })
}
g.exit <- true g.exit <- true
} }
func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, exposedPort int, err error) { func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, exposedPort int, err error) {

View File

@ -1,7 +1,13 @@
package gateway package gateway
import (
"sync"
"next-terminal/server/log"
)
type Manager struct { type Manager struct {
gateways map[string]*Gateway gateways *sync.Map
Add chan *Gateway Add chan *Gateway
Del chan string Del chan string
@ -11,7 +17,7 @@ func NewManager() *Manager {
return &Manager{ return &Manager{
Add: make(chan *Gateway), Add: make(chan *Gateway),
Del: make(chan string), Del: make(chan string),
gateways: map[string]*Gateway{}, gateways: new(sync.Map),
} }
} }
@ -19,19 +25,26 @@ func (m *Manager) Start() {
for { for {
select { select {
case g := <-m.Add: case g := <-m.Add:
m.gateways[g.ID] = g m.gateways.Store(g.ID, g)
log.Info("add gateway: %s", g.ID)
go g.Run() go g.Run()
case k := <-m.Del: case k := <-m.Del:
if _, ok := m.gateways[k]; ok { if val, ok := m.gateways.Load(k); ok {
m.gateways[k].Close() if vv, vok := val.(*Gateway); vok {
delete(m.gateways, k) vv.Close()
m.gateways.Delete(k)
}
} }
} }
} }
} }
func (m Manager) GetById(id string) *Gateway { func (m Manager) GetById(id string) *Gateway {
return m.gateways[id] if val, ok := m.gateways.Load(id); ok {
return val.(*Gateway)
}
return nil
} }
var GlobalGatewayManager *Manager var GlobalGatewayManager *Manager

View File

@ -19,34 +19,24 @@ type Tunnel struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
listener net.Listener listener net.Listener
localConnections []net.Conn err error
remoteConnections []net.Conn
} }
func (r *Tunnel) Open() { func (r *Tunnel) Open() {
localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort) localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort)
go func() {
<-r.ctx.Done()
_ = r.listener.Close()
for i := range r.localConnections {
_ = r.localConnections[i].Close()
}
r.localConnections = nil
for i := range r.remoteConnections {
_ = r.remoteConnections[i].Close()
}
r.remoteConnections = nil
log.Debugf("SSH 隧道 %v 关闭", localAddr)
}()
for { for {
select {
case <-r.ctx.Done():
_ = r.listener.Close()
log.Debugf("SSH 隧道 %v 关闭", localAddr)
return
default:
log.Debugf("等待客户端访问 %v", localAddr) log.Debugf("等待客户端访问 %v", localAddr)
localConn, err := r.listener.Accept() localConn, err := r.listener.Accept()
if err != nil { if err != nil {
log.Debugf("接受连接失败 %v, 退出循环", err.Error()) log.Debugf("接受连接失败 %v", err.Error())
return continue
} }
r.localConnections = append(r.localConnections, localConn)
log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr) log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr)
remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort) remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort)
@ -54,21 +44,35 @@ func (r *Tunnel) Open() {
remoteConn, err := r.Gateway.SshClient.Dial("tcp", remoteAddr) remoteConn, err := r.Gateway.SshClient.Dial("tcp", remoteAddr)
if err != nil { if err != nil {
log.Debugf("连接远程主机 %v 失败", remoteAddr) log.Debugf("连接远程主机 %v 失败", remoteAddr)
r.err = err
return return
} }
r.remoteConnections = append(r.remoteConnections, remoteConn)
log.Debugf("连接远程主机 %v 成功", remoteAddr) log.Debugf("连接远程主机 %v 成功", remoteAddr)
go copyConn(localConn, remoteConn) go copyConn(r.ctx, localConn, remoteConn)
go copyConn(remoteConn, localConn) go copyConn(r.ctx, remoteConn, localConn)
log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr) log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr)
} }
} }
}
func (r Tunnel) Close() { func (r Tunnel) Close() {
r.cancel() r.cancel()
err := r.listener.Close()
if err != nil {
return
}
} }
func copyConn(writer, reader net.Conn) { func copyConn(ctx context.Context, writer, reader net.Conn) {
_, _ = io.Copy(writer, reader) _, _ = io.Copy(writer, reader)
for {
select {
case <-ctx.Done():
_ = writer.Close()
_ = reader.Close()
return
}
}
} }