diff --git a/server/api/term.go b/server/api/term.go index 8101eba..a09f5ec 100644 --- a/server/api/term.go +++ b/server/api/term.go @@ -152,7 +152,6 @@ func (api WebTerminalApi) SshEndpoint(c echo.Context) error { NextTerminal: nextTerminal, Observer: session.NewObserver(s.ID), } - go nextSession.Observer.Start() session.GlobalSessionManager.Add <- nextSession termHandler := NewTermHandler(sessionId, isRecording, ws, nextTerminal) diff --git a/server/app/app.go b/server/app/app.go index 9466236..98a8e53 100644 --- a/server/app/app.go +++ b/server/app/app.go @@ -3,6 +3,10 @@ package app import ( "encoding/json" "fmt" + "net/http" + _ "net/http/pprof" + + "next-terminal/server/log" "next-terminal/server/cli" "next-terminal/server/config" @@ -104,6 +108,9 @@ func Run() error { if err != nil { return err } + go func() { + log.Fatal(http.ListenAndServe("localhost:8099", nil)) + }() fmt.Printf("当前配置为: %v\n", string(jsonBytes)) } diff --git a/server/global/gateway/gateway.go b/server/global/gateway/gateway.go index 142c847..4d38550 100644 --- a/server/global/gateway/gateway.go +++ b/server/global/gateway/gateway.go @@ -6,6 +6,9 @@ import ( "fmt" "net" "os" + "sync" + + "next-terminal/server/log" "next-terminal/server/utils" @@ -19,7 +22,7 @@ type Gateway struct { SshClient *ssh.Client Message string // 失败原因 - tunnels map[string]*Tunnel + tunnels *sync.Map Add chan *Tunnel Del chan string @@ -34,7 +37,7 @@ func NewGateway(id string, connected bool, message string, client *ssh.Client) * SshClient: client, Add: make(chan *Tunnel), Del: make(chan string), - tunnels: map[string]*Tunnel{}, + tunnels: new(sync.Map), exit: make(chan bool, 1), } } @@ -43,12 +46,15 @@ func (g *Gateway) Run() { for { select { case t := <-g.Add: - g.tunnels[t.ID] = t + g.tunnels.Store(t.ID, t) + log.Info("add tunnel: %s", t.ID) go t.Open() case k := <-g.Del: - if _, ok := g.tunnels[k]; ok { - g.tunnels[k].Close() - delete(g.tunnels, k) + if val, ok := g.tunnels.Load(k); ok { + if vval, vok := val.(*Tunnel); vok { + vval.Close() + g.tunnels.Delete(k) + } } case <-g.exit: return @@ -57,14 +63,14 @@ func (g *Gateway) Run() { } func (g *Gateway) Close() { - if g.SshClient != nil { - _ = g.SshClient.Close() - } - for id := range g.tunnels { - g.CloseSshTunnel(id) - } - + g.tunnels.Range(func(key, value interface{}) bool { + if val, ok := value.(*Tunnel); ok { + val.Close() + } + return true + }) g.exit <- true + } func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, exposedPort int, err error) { diff --git a/server/global/gateway/manager.go b/server/global/gateway/manager.go index d2e45c1..e2b0245 100644 --- a/server/global/gateway/manager.go +++ b/server/global/gateway/manager.go @@ -1,7 +1,13 @@ package gateway +import ( + "sync" + + "next-terminal/server/log" +) + type Manager struct { - gateways map[string]*Gateway + gateways *sync.Map Add chan *Gateway Del chan string @@ -11,7 +17,7 @@ func NewManager() *Manager { return &Manager{ Add: make(chan *Gateway), Del: make(chan string), - gateways: map[string]*Gateway{}, + gateways: new(sync.Map), } } @@ -19,19 +25,26 @@ func (m *Manager) Start() { for { select { case g := <-m.Add: - m.gateways[g.ID] = g + m.gateways.Store(g.ID, g) + log.Info("add gateway: %s", g.ID) go g.Run() case k := <-m.Del: - if _, ok := m.gateways[k]; ok { - m.gateways[k].Close() - delete(m.gateways, k) + if val, ok := m.gateways.Load(k); ok { + if vv, vok := val.(*Gateway); vok { + vv.Close() + m.gateways.Delete(k) + } + } } } } func (m Manager) GetById(id string) *Gateway { - return m.gateways[id] + if val, ok := m.gateways.Load(id); ok { + return val.(*Gateway) + } + return nil } var GlobalGatewayManager *Manager diff --git a/server/global/gateway/tunnel.go b/server/global/gateway/tunnel.go index 58c88f4..088bbc5 100644 --- a/server/global/gateway/tunnel.go +++ b/server/global/gateway/tunnel.go @@ -10,65 +10,69 @@ import ( ) type Tunnel struct { - ID string // 唯一标识 - LocalHost string // 本地监听地址 - LocalPort int // 本地端口 - RemoteHost string // 远程连接地址 - RemotePort int // 远程端口 - Gateway *Gateway - ctx context.Context - cancel context.CancelFunc - listener net.Listener - localConnections []net.Conn - remoteConnections []net.Conn + ID string // 唯一标识 + LocalHost string // 本地监听地址 + LocalPort int // 本地端口 + RemoteHost string // 远程连接地址 + RemotePort int // 远程端口 + Gateway *Gateway + ctx context.Context + cancel context.CancelFunc + listener net.Listener + err error } func (r *Tunnel) Open() { localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort) - - go func() { - <-r.ctx.Done() - _ = r.listener.Close() - for i := range r.localConnections { - _ = r.localConnections[i].Close() - } - r.localConnections = nil - for i := range r.remoteConnections { - _ = r.remoteConnections[i].Close() - } - r.remoteConnections = nil - log.Debugf("SSH 隧道 %v 关闭", localAddr) - }() for { - log.Debugf("等待客户端访问 %v", localAddr) - localConn, err := r.listener.Accept() - if err != nil { - log.Debugf("接受连接失败 %v, 退出循环", err.Error()) + select { + case <-r.ctx.Done(): + _ = r.listener.Close() + log.Debugf("SSH 隧道 %v 关闭", localAddr) return - } - r.localConnections = append(r.localConnections, localConn) + default: + log.Debugf("等待客户端访问 %v", localAddr) + localConn, err := r.listener.Accept() + if err != nil { + log.Debugf("接受连接失败 %v", err.Error()) + continue + } - log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr) - remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort) - log.Debugf("连接远程主机 %v ...", remoteAddr) - remoteConn, err := r.Gateway.SshClient.Dial("tcp", remoteAddr) - if err != nil { - log.Debugf("连接远程主机 %v 失败", remoteAddr) - return - } - r.remoteConnections = append(r.remoteConnections, remoteConn) + log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr) + remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort) + log.Debugf("连接远程主机 %v ...", remoteAddr) + remoteConn, err := r.Gateway.SshClient.Dial("tcp", remoteAddr) + if err != nil { + log.Debugf("连接远程主机 %v 失败", remoteAddr) + r.err = err + return + } - log.Debugf("连接远程主机 %v 成功", remoteAddr) - go copyConn(localConn, remoteConn) - go copyConn(remoteConn, localConn) - log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr) + log.Debugf("连接远程主机 %v 成功", remoteAddr) + go copyConn(r.ctx, localConn, remoteConn) + go copyConn(r.ctx, remoteConn, localConn) + log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr) + } } } func (r Tunnel) Close() { r.cancel() + err := r.listener.Close() + if err != nil { + return + } } -func copyConn(writer, reader net.Conn) { +func copyConn(ctx context.Context, writer, reader net.Conn) { _, _ = io.Copy(writer, reader) + for { + select { + case <-ctx.Done(): + _ = writer.Close() + _ = reader.Close() + return + } + } + }