From 6340d5198f836acdb5a6e36dae4b363ea5062544 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sun, 3 Apr 2022 22:23:27 +0800 Subject: [PATCH] add keepAlive option for udp Listener --- chain/route.go | 13 ++++-- common/{util => net}/udp/conn.go | 36 +++++++++------- common/{util => net}/udp/listener.go | 62 ++++++++++++++++------------ common/{util => net}/udp/pool.go | 24 +++++------ common/net/{ => udp}/udp.go | 4 +- connector/socks/v5/bind.go | 19 +++++---- listener/udp/listener.go | 17 ++++---- listener/udp/metadata.go | 8 ++-- resolver/exchanger/exchanger.go | 5 ++- 9 files changed, 109 insertions(+), 79 deletions(-) rename common/{util => net}/udp/conn.go (60%) rename common/{util => net}/udp/listener.go (53%) rename common/{util => net}/udp/pool.go (69%) rename common/net/{ => udp}/udp.go (94%) diff --git a/chain/route.go b/chain/route.go index 6c4c309..1d16b70 100644 --- a/chain/route.go +++ b/chain/route.go @@ -8,7 +8,7 @@ import ( "time" "github.com/go-gost/core/common/net/dialer" - "github.com/go-gost/core/common/util/udp" + "github.com/go-gost/core/common/net/udp" "github.com/go-gost/core/connector" "github.com/go-gost/core/logger" "github.com/go-gost/core/metrics" @@ -198,9 +198,14 @@ func (r *Route) bindLocal(ctx context.Context, network, address string, opts ... "network": network, "address": address, }) - ln := udp.NewListener(conn, addr, - options.Backlog, options.UDPDataQueueSize, options.UDPDataBufferSize, - options.UDPConnTTL, logger) + ln := udp.NewListener(conn, &udp.ListenConfig{ + Backlog: options.Backlog, + ReadQueueSize: options.UDPDataQueueSize, + ReadBufferSize: options.UDPDataBufferSize, + TTL: options.UDPConnTTL, + KeepAlive: true, + Logger: logger, + }) return ln, err default: err := fmt.Errorf("network %s unsupported", network) diff --git a/common/util/udp/conn.go b/common/net/udp/conn.go similarity index 60% rename from common/util/udp/conn.go rename to common/net/udp/conn.go index 5884da0..f3f4c9e 100644 --- a/common/util/udp/conn.go +++ b/common/net/udp/conn.go @@ -9,8 +9,8 @@ import ( "github.com/go-gost/core/common/bufpool" ) -// Conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn. -type Conn struct { +// conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn. +type conn struct { net.PacketConn localAddr net.Addr remoteAddr net.Addr @@ -18,19 +18,21 @@ type Conn struct { idle int32 // indicate the connection is idle closed chan struct{} closeMutex sync.Mutex + keepAlive bool } -func NewConn(c net.PacketConn, localAddr, remoteAddr net.Addr, queueSize int) *Conn { - return &Conn{ +func newConn(c net.PacketConn, laddr, remoteAddr net.Addr, queueSize int, keepAlive bool) *conn { + return &conn{ PacketConn: c, - localAddr: localAddr, + localAddr: laddr, remoteAddr: remoteAddr, rc: make(chan []byte, queueSize), closed: make(chan struct{}), + keepAlive: keepAlive, } } -func (c *Conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { +func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { select { case bb := <-c.rc: n = copy(b, bb) @@ -47,16 +49,20 @@ func (c *Conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { return } -func (c *Conn) Read(b []byte) (n int, err error) { +func (c *conn) Read(b []byte) (n int, err error) { n, _, err = c.ReadFrom(b) return } -func (c *Conn) Write(b []byte) (n int, err error) { - return c.WriteTo(b, c.remoteAddr) +func (c *conn) Write(b []byte) (n int, err error) { + n, err = c.WriteTo(b, c.remoteAddr) + if !c.keepAlive { + c.Close() + } + return } -func (c *Conn) Close() error { +func (c *conn) Close() error { c.closeMutex.Lock() defer c.closeMutex.Unlock() @@ -68,19 +74,19 @@ func (c *Conn) Close() error { return nil } -func (c *Conn) LocalAddr() net.Addr { +func (c *conn) LocalAddr() net.Addr { return c.localAddr } -func (c *Conn) RemoteAddr() net.Addr { +func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } -func (c *Conn) IsIdle() bool { +func (c *conn) IsIdle() bool { return atomic.LoadInt32(&c.idle) > 0 } -func (c *Conn) SetIdle(idle bool) { +func (c *conn) SetIdle(idle bool) { v := int32(0) if idle { v = 1 @@ -88,7 +94,7 @@ func (c *Conn) SetIdle(idle bool) { atomic.StoreInt32(&c.idle, v) } -func (c *Conn) WriteQueue(b []byte) error { +func (c *conn) WriteQueue(b []byte) error { select { case c.rc <- b: return nil diff --git a/common/util/udp/listener.go b/common/net/udp/listener.go similarity index 53% rename from common/util/udp/listener.go rename to common/net/udp/listener.go index 0aee50b..146ef09 100644 --- a/common/util/udp/listener.go +++ b/common/net/udp/listener.go @@ -9,30 +9,37 @@ import ( "github.com/go-gost/core/logger" ) +type ListenConfig struct { + Addr net.Addr + Backlog int + ReadQueueSize int + ReadBufferSize int + TTL time.Duration + KeepAlive bool + Logger logger.Logger +} type listener struct { - addr net.Addr - conn net.PacketConn - cqueue chan net.Conn - readQueueSize int - readBufferSize int - connPool *ConnPool - mux sync.Mutex - closed chan struct{} - errChan chan error - logger logger.Logger + conn net.PacketConn + cqueue chan net.Conn + connPool *connPool + mux sync.Mutex + closed chan struct{} + errChan chan error + config *ListenConfig } -func NewListener(conn net.PacketConn, addr net.Addr, backlog, dataQueueSize, dataBufferSize int, ttl time.Duration, logger logger.Logger) net.Listener { +func NewListener(conn net.PacketConn, cfg *ListenConfig) net.Listener { + if cfg == nil { + cfg = &ListenConfig{} + } + ln := &listener{ - conn: conn, - addr: addr, - cqueue: make(chan net.Conn, backlog), - connPool: NewConnPool(ttl).WithLogger(logger), - readQueueSize: dataQueueSize, - readBufferSize: dataBufferSize, - closed: make(chan struct{}), - errChan: make(chan error, 1), - logger: logger, + conn: conn, + cqueue: make(chan net.Conn, cfg.Backlog), + connPool: newConnPool(cfg.TTL).WithLogger(cfg.Logger), + closed: make(chan struct{}), + errChan: make(chan error, 1), + config: cfg, } go ln.listenLoop() @@ -61,7 +68,7 @@ func (ln *listener) listenLoop() { default: } - b := bufpool.Get(ln.readBufferSize) + b := bufpool.Get(ln.config.ReadBufferSize) n, raddr, err := ln.conn.ReadFrom(*b) if err != nil { @@ -77,13 +84,16 @@ func (ln *listener) listenLoop() { } if err := c.WriteQueue((*b)[:n]); err != nil { - ln.logger.Warn("data discarded: ", err) + ln.config.Logger.Warn("data discarded: ", err) } } } func (ln *listener) Addr() net.Addr { - return ln.addr + if ln.config.Addr != nil { + return ln.config.Addr + } + return ln.conn.LocalAddr() } func (ln *listener) Close() error { @@ -98,7 +108,7 @@ func (ln *listener) Close() error { return nil } -func (ln *listener) getConn(raddr net.Addr) *Conn { +func (ln *listener) getConn(raddr net.Addr) *conn { ln.mux.Lock() defer ln.mux.Unlock() @@ -107,14 +117,14 @@ func (ln *listener) getConn(raddr net.Addr) *Conn { return c } - c = NewConn(ln.conn, ln.addr, raddr, ln.readQueueSize) + c = newConn(ln.conn, ln.Addr(), raddr, ln.config.ReadQueueSize, ln.config.KeepAlive) select { case ln.cqueue <- c: ln.connPool.Set(raddr.String(), c) return c default: c.Close() - ln.logger.Warnf("connection queue is full, client %s discarded", raddr) + ln.config.Logger.Warnf("connection queue is full, client %s discarded", raddr) return nil } } diff --git a/common/util/udp/pool.go b/common/net/udp/pool.go similarity index 69% rename from common/util/udp/pool.go rename to common/net/udp/pool.go index 7048227..1d69578 100644 --- a/common/util/udp/pool.go +++ b/common/net/udp/pool.go @@ -7,15 +7,15 @@ import ( "github.com/go-gost/core/logger" ) -type ConnPool struct { +type connPool struct { m sync.Map ttl time.Duration closed chan struct{} logger logger.Logger } -func NewConnPool(ttl time.Duration) *ConnPool { - p := &ConnPool{ +func newConnPool(ttl time.Duration) *connPool { + p := &connPool{ ttl: ttl, closed: make(chan struct{}), } @@ -23,28 +23,28 @@ func NewConnPool(ttl time.Duration) *ConnPool { return p } -func (p *ConnPool) WithLogger(logger logger.Logger) *ConnPool { +func (p *connPool) WithLogger(logger logger.Logger) *connPool { p.logger = logger return p } -func (p *ConnPool) Get(key any) (c *Conn, ok bool) { +func (p *connPool) Get(key any) (c *conn, ok bool) { v, ok := p.m.Load(key) if ok { - c, ok = v.(*Conn) + c, ok = v.(*conn) } return } -func (p *ConnPool) Set(key any, c *Conn) { +func (p *connPool) Set(key any, c *conn) { p.m.Store(key, c) } -func (p *ConnPool) Delete(key any) { +func (p *connPool) Delete(key any) { p.m.Delete(key) } -func (p *ConnPool) Close() { +func (p *connPool) Close() { select { case <-p.closed: return @@ -54,14 +54,14 @@ func (p *ConnPool) Close() { close(p.closed) p.m.Range(func(k, v any) bool { - if c, ok := v.(*Conn); ok && c != nil { + if c, ok := v.(*conn); ok && c != nil { c.Close() } return true }) } -func (p *ConnPool) idleCheck() { +func (p *connPool) idleCheck() { ticker := time.NewTicker(p.ttl) defer ticker.Stop() @@ -71,7 +71,7 @@ func (p *ConnPool) idleCheck() { size := 0 idles := 0 p.m.Range(func(key, value any) bool { - c, ok := value.(*Conn) + c, ok := value.(*conn) if !ok || c == nil { p.Delete(key) return true diff --git a/common/net/udp.go b/common/net/udp/udp.go similarity index 94% rename from common/net/udp.go rename to common/net/udp/udp.go index 515060f..6428dee 100644 --- a/common/net/udp.go +++ b/common/net/udp/udp.go @@ -1,4 +1,4 @@ -package net +package udp import ( "io" @@ -6,7 +6,7 @@ import ( "syscall" ) -type UDPConn interface { +type Conn interface { net.PacketConn io.Reader io.Writer diff --git a/connector/socks/v5/bind.go b/connector/socks/v5/bind.go index 7d28764..ba419c7 100644 --- a/connector/socks/v5/bind.go +++ b/connector/socks/v5/bind.go @@ -5,7 +5,7 @@ import ( "fmt" "net" - "github.com/go-gost/core/common/util/udp" + "github.com/go-gost/core/common/net/udp" "github.com/go-gost/core/connector" "github.com/go-gost/core/internal/util/mux" "github.com/go-gost/core/internal/util/socks" @@ -80,13 +80,16 @@ func (c *socks5Connector) bindUDP(ctx context.Context, conn net.Conn, network, a return nil, err } - ln := udp.NewListener( - socks.UDPTunClientPacketConn(conn), - laddr, - opts.Backlog, - opts.UDPDataQueueSize, opts.UDPDataBufferSize, - opts.UDPConnTTL, - log) + ln := udp.NewListener(socks.UDPTunClientPacketConn(conn), + &udp.ListenConfig{ + Addr: laddr, + Backlog: opts.Backlog, + ReadQueueSize: opts.UDPDataQueueSize, + ReadBufferSize: opts.UDPDataBufferSize, + TTL: opts.UDPConnTTL, + KeepAlive: true, + Logger: log, + }) return ln, nil } diff --git a/listener/udp/listener.go b/listener/udp/listener.go index c516b2d..2858107 100644 --- a/listener/udp/listener.go +++ b/listener/udp/listener.go @@ -3,7 +3,7 @@ package udp import ( "net" - "github.com/go-gost/core/common/util/udp" + "github.com/go-gost/core/common/net/udp" "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" @@ -50,13 +50,14 @@ func (l *udpListener) Init(md md.Metadata) (err error) { } conn = metrics.WrapPacketConn(l.options.Service, conn) - l.ln = udp.NewListener( - conn, - laddr, - l.md.backlog, - l.md.readQueueSize, l.md.readBufferSize, - l.md.ttl, - l.logger) + l.ln = udp.NewListener(conn, &udp.ListenConfig{ + Backlog: l.md.backlog, + ReadQueueSize: l.md.readQueueSize, + ReadBufferSize: l.md.readBufferSize, + KeepAlive: l.md.keepalive, + TTL: l.md.ttl, + Logger: l.logger, + }) return } diff --git a/listener/udp/metadata.go b/listener/udp/metadata.go index a9fb453..a82add3 100644 --- a/listener/udp/metadata.go +++ b/listener/udp/metadata.go @@ -14,19 +14,20 @@ const ( ) type metadata struct { - ttl time.Duration - readBufferSize int readQueueSize int backlog int + keepalive bool + ttl time.Duration } func (l *udpListener) parseMetadata(md mdata.Metadata) (err error) { const ( - ttl = "ttl" readBufferSize = "readBufferSize" readQueueSize = "readQueueSize" backlog = "backlog" + keepAlive = "keepAlive" + ttl = "ttl" ) l.md.ttl = mdata.GetDuration(md, ttl) @@ -47,6 +48,7 @@ func (l *udpListener) parseMetadata(md mdata.Metadata) (err error) { if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } + l.md.keepalive = mdata.GetBool(md, keepAlive) return } diff --git a/resolver/exchanger/exchanger.go b/resolver/exchanger/exchanger.go index 509ab16..2d66d28 100644 --- a/resolver/exchanger/exchanger.go +++ b/resolver/exchanger/exchanger.go @@ -197,7 +197,10 @@ func (ex *exchanger) exchange(ctx context.Context, msg []byte) ([]byte, error) { c = tls.Client(c, ex.options.tlsConfig) } - conn := &dns.Conn{Conn: c} + conn := &dns.Conn{ + UDPSize: 1024, + Conn: c, + } if _, err = conn.Write(msg); err != nil { return nil, err