优化代码
This commit is contained in:
@ -8,8 +8,6 @@ import (
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"next-terminal/server/log"
|
||||
|
||||
"next-terminal/server/utils"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
@ -22,11 +20,7 @@ type Gateway struct {
|
||||
SshClient *ssh.Client
|
||||
Message string // 失败原因
|
||||
|
||||
tunnels *sync.Map
|
||||
|
||||
Add chan *Tunnel
|
||||
Del chan string
|
||||
exit chan bool
|
||||
tunnels sync.Map
|
||||
}
|
||||
|
||||
func NewGateway(id string, connected bool, message string, client *ssh.Client) *Gateway {
|
||||
@ -35,42 +29,14 @@ func NewGateway(id string, connected bool, message string, client *ssh.Client) *
|
||||
Connected: connected,
|
||||
Message: message,
|
||||
SshClient: client,
|
||||
Add: make(chan *Tunnel),
|
||||
Del: make(chan string),
|
||||
tunnels: new(sync.Map),
|
||||
exit: make(chan bool, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) Run() {
|
||||
for {
|
||||
select {
|
||||
case t := <-g.Add:
|
||||
g.tunnels.Store(t.ID, t)
|
||||
log.Info("add tunnel: %s", t.ID)
|
||||
go t.Open()
|
||||
case k := <-g.Del:
|
||||
if val, ok := g.tunnels.Load(k); ok {
|
||||
if vval, vok := val.(*Tunnel); vok {
|
||||
vval.Close()
|
||||
g.tunnels.Delete(k)
|
||||
}
|
||||
}
|
||||
case <-g.exit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) Close() {
|
||||
g.tunnels.Range(func(key, value interface{}) bool {
|
||||
if val, ok := value.(*Tunnel); ok {
|
||||
val.Close()
|
||||
}
|
||||
g.CloseSshTunnel(key.(string))
|
||||
return true
|
||||
})
|
||||
g.exit <- true
|
||||
|
||||
}
|
||||
|
||||
func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, exposedPort int, err error) {
|
||||
@ -110,11 +76,17 @@ func (g *Gateway) OpenSshTunnel(id, ip string, port int) (exposedIP string, expo
|
||||
cancel: cancel,
|
||||
listener: listener,
|
||||
}
|
||||
g.Add <- tunnel
|
||||
go tunnel.Open()
|
||||
g.tunnels.Store(tunnel.ID, tunnel)
|
||||
|
||||
return tunnel.LocalHost, tunnel.LocalPort, nil
|
||||
}
|
||||
|
||||
func (g Gateway) CloseSshTunnel(id string) {
|
||||
g.Del <- id
|
||||
func (g *Gateway) CloseSshTunnel(id string) {
|
||||
if value, ok := g.tunnels.Load(id); ok {
|
||||
if tunnel, vok := value.(*Tunnel); vok {
|
||||
tunnel.Close()
|
||||
g.tunnels.Delete(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7,49 +7,32 @@ import (
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
gateways *sync.Map
|
||||
|
||||
Add chan *Gateway
|
||||
Del chan string
|
||||
gateways sync.Map
|
||||
}
|
||||
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
Add: make(chan *Gateway),
|
||||
Del: make(chan string),
|
||||
gateways: new(sync.Map),
|
||||
}
|
||||
return &Manager{}
|
||||
}
|
||||
|
||||
func (m *Manager) Start() {
|
||||
for {
|
||||
select {
|
||||
case g := <-m.Add:
|
||||
m.gateways.Store(g.ID, g)
|
||||
log.Info("add gateway: %s", g.ID)
|
||||
go g.Run()
|
||||
case k := <-m.Del:
|
||||
if val, ok := m.gateways.Load(k); ok {
|
||||
if vv, vok := val.(*Gateway); vok {
|
||||
vv.Close()
|
||||
m.gateways.Delete(k)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m Manager) GetById(id string) *Gateway {
|
||||
func (m *Manager) GetById(id string) *Gateway {
|
||||
if val, ok := m.gateways.Load(id); ok {
|
||||
return val.(*Gateway)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Add(g *Gateway) {
|
||||
m.gateways.Store(g.ID, g)
|
||||
log.Infof("add gateway: %s", g.ID)
|
||||
}
|
||||
|
||||
func (m *Manager) Del(id string) {
|
||||
m.gateways.Delete(id)
|
||||
log.Infof("del gateway: %s", id)
|
||||
}
|
||||
|
||||
var GlobalGatewayManager *Manager
|
||||
|
||||
func init() {
|
||||
GlobalGatewayManager = NewManager()
|
||||
go GlobalGatewayManager.Start()
|
||||
}
|
||||
|
@ -10,23 +10,26 @@ import (
|
||||
)
|
||||
|
||||
type Tunnel struct {
|
||||
ID string // 唯一标识
|
||||
LocalHost string // 本地监听地址
|
||||
LocalPort int // 本地端口
|
||||
RemoteHost string // 远程连接地址
|
||||
RemotePort int // 远程端口
|
||||
Gateway *Gateway
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
listener net.Listener
|
||||
err error
|
||||
ID string // 唯一标识
|
||||
LocalHost string // 本地监听地址
|
||||
LocalPort int // 本地端口
|
||||
RemoteHost string // 远程连接地址
|
||||
RemotePort int // 远程端口
|
||||
Gateway *Gateway
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
listener net.Listener
|
||||
localConnections []net.Conn
|
||||
remoteConnections []net.Conn
|
||||
}
|
||||
|
||||
func (r *Tunnel) Open() {
|
||||
localAddr := fmt.Sprintf("%s:%d", r.LocalHost, r.LocalPort)
|
||||
|
||||
go func() {
|
||||
<-r.ctx.Done()
|
||||
_ = r.listener.Close()
|
||||
log.Debugf("SSH 隧道 %v 关闭", localAddr)
|
||||
}()
|
||||
for {
|
||||
log.Debugf("等待客户端访问 %v", localAddr)
|
||||
@ -35,6 +38,7 @@ func (r *Tunnel) Open() {
|
||||
log.Debugf("接受连接失败 %v, 退出循环", err.Error())
|
||||
return
|
||||
}
|
||||
r.localConnections = append(r.localConnections, localConn)
|
||||
|
||||
log.Debugf("客户端 %v 连接至 %v", localConn.RemoteAddr().String(), localAddr)
|
||||
remoteAddr := fmt.Sprintf("%s:%d", r.RemoteHost, r.RemotePort)
|
||||
@ -44,27 +48,27 @@ func (r *Tunnel) Open() {
|
||||
log.Debugf("连接远程主机 %v 失败", remoteAddr)
|
||||
return
|
||||
}
|
||||
r.remoteConnections = append(r.remoteConnections, remoteConn)
|
||||
|
||||
log.Debugf("连接远程主机 %v 成功", remoteAddr)
|
||||
go copyConn(r.ctx, localConn, remoteConn)
|
||||
go copyConn(r.ctx, remoteConn, localConn)
|
||||
go copyConn(localConn, remoteConn)
|
||||
go copyConn(remoteConn, localConn)
|
||||
log.Debugf("转发数据 [%v]->[%v]", localAddr, remoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func (r Tunnel) Close() {
|
||||
func (r *Tunnel) Close() {
|
||||
for i := range r.localConnections {
|
||||
_ = r.localConnections[i].Close()
|
||||
}
|
||||
r.localConnections = nil
|
||||
for i := range r.remoteConnections {
|
||||
_ = r.remoteConnections[i].Close()
|
||||
}
|
||||
r.remoteConnections = nil
|
||||
r.cancel()
|
||||
}
|
||||
|
||||
func copyConn(ctx context.Context, writer, reader net.Conn) {
|
||||
func copyConn(writer, reader net.Conn) {
|
||||
_, _ = io.Copy(writer, reader)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = writer.Close()
|
||||
_ = reader.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user