add keepAlive option for udp Listener

This commit is contained in:
ginuerzh
2022-04-03 22:23:27 +08:00
parent fc1e6e8ff2
commit 6340d5198f
9 changed files with 109 additions and 79 deletions

View File

@ -8,7 +8,7 @@ import (
"time" "time"
"github.com/go-gost/core/common/net/dialer" "github.com/go-gost/core/common/net/dialer"
"github.com/go-gost/core/common/util/udp" "github.com/go-gost/core/common/net/udp"
"github.com/go-gost/core/connector" "github.com/go-gost/core/connector"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
"github.com/go-gost/core/metrics" "github.com/go-gost/core/metrics"
@ -198,9 +198,14 @@ func (r *Route) bindLocal(ctx context.Context, network, address string, opts ...
"network": network, "network": network,
"address": address, "address": address,
}) })
ln := udp.NewListener(conn, addr, ln := udp.NewListener(conn, &udp.ListenConfig{
options.Backlog, options.UDPDataQueueSize, options.UDPDataBufferSize, Backlog: options.Backlog,
options.UDPConnTTL, logger) ReadQueueSize: options.UDPDataQueueSize,
ReadBufferSize: options.UDPDataBufferSize,
TTL: options.UDPConnTTL,
KeepAlive: true,
Logger: logger,
})
return ln, err return ln, err
default: default:
err := fmt.Errorf("network %s unsupported", network) err := fmt.Errorf("network %s unsupported", network)

View File

@ -9,8 +9,8 @@ import (
"github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/common/bufpool"
) )
// Conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn. // conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type Conn struct { type conn struct {
net.PacketConn net.PacketConn
localAddr net.Addr localAddr net.Addr
remoteAddr net.Addr remoteAddr net.Addr
@ -18,19 +18,21 @@ type Conn struct {
idle int32 // indicate the connection is idle idle int32 // indicate the connection is idle
closed chan struct{} closed chan struct{}
closeMutex sync.Mutex closeMutex sync.Mutex
keepAlive bool
} }
func NewConn(c net.PacketConn, localAddr, remoteAddr net.Addr, queueSize int) *Conn { func newConn(c net.PacketConn, laddr, remoteAddr net.Addr, queueSize int, keepAlive bool) *conn {
return &Conn{ return &conn{
PacketConn: c, PacketConn: c,
localAddr: localAddr, localAddr: laddr,
remoteAddr: remoteAddr, remoteAddr: remoteAddr,
rc: make(chan []byte, queueSize), rc: make(chan []byte, queueSize),
closed: make(chan struct{}), closed: make(chan struct{}),
keepAlive: keepAlive,
} }
} }
func (c *Conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select { select {
case bb := <-c.rc: case bb := <-c.rc:
n = copy(b, bb) n = copy(b, bb)
@ -47,16 +49,20 @@ func (c *Conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
return return
} }
func (c *Conn) Read(b []byte) (n int, err error) { func (c *conn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b) n, _, err = c.ReadFrom(b)
return return
} }
func (c *Conn) Write(b []byte) (n int, err error) { func (c *conn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.remoteAddr) n, err = c.WriteTo(b, c.remoteAddr)
if !c.keepAlive {
c.Close()
}
return
} }
func (c *Conn) Close() error { func (c *conn) Close() error {
c.closeMutex.Lock() c.closeMutex.Lock()
defer c.closeMutex.Unlock() defer c.closeMutex.Unlock()
@ -68,19 +74,19 @@ func (c *Conn) Close() error {
return nil return nil
} }
func (c *Conn) LocalAddr() net.Addr { func (c *conn) LocalAddr() net.Addr {
return c.localAddr return c.localAddr
} }
func (c *Conn) RemoteAddr() net.Addr { func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr return c.remoteAddr
} }
func (c *Conn) IsIdle() bool { func (c *conn) IsIdle() bool {
return atomic.LoadInt32(&c.idle) > 0 return atomic.LoadInt32(&c.idle) > 0
} }
func (c *Conn) SetIdle(idle bool) { func (c *conn) SetIdle(idle bool) {
v := int32(0) v := int32(0)
if idle { if idle {
v = 1 v = 1
@ -88,7 +94,7 @@ func (c *Conn) SetIdle(idle bool) {
atomic.StoreInt32(&c.idle, v) atomic.StoreInt32(&c.idle, v)
} }
func (c *Conn) WriteQueue(b []byte) error { func (c *conn) WriteQueue(b []byte) error {
select { select {
case c.rc <- b: case c.rc <- b:
return nil return nil

View File

@ -9,30 +9,37 @@ import (
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
) )
type ListenConfig struct {
Addr net.Addr
Backlog int
ReadQueueSize int
ReadBufferSize int
TTL time.Duration
KeepAlive bool
Logger logger.Logger
}
type listener struct { type listener struct {
addr net.Addr
conn net.PacketConn conn net.PacketConn
cqueue chan net.Conn cqueue chan net.Conn
readQueueSize int connPool *connPool
readBufferSize int
connPool *ConnPool
mux sync.Mutex mux sync.Mutex
closed chan struct{} closed chan struct{}
errChan chan error errChan chan error
logger logger.Logger config *ListenConfig
} }
func NewListener(conn net.PacketConn, addr net.Addr, backlog, dataQueueSize, dataBufferSize int, ttl time.Duration, logger logger.Logger) net.Listener { func NewListener(conn net.PacketConn, cfg *ListenConfig) net.Listener {
if cfg == nil {
cfg = &ListenConfig{}
}
ln := &listener{ ln := &listener{
conn: conn, conn: conn,
addr: addr, cqueue: make(chan net.Conn, cfg.Backlog),
cqueue: make(chan net.Conn, backlog), connPool: newConnPool(cfg.TTL).WithLogger(cfg.Logger),
connPool: NewConnPool(ttl).WithLogger(logger),
readQueueSize: dataQueueSize,
readBufferSize: dataBufferSize,
closed: make(chan struct{}), closed: make(chan struct{}),
errChan: make(chan error, 1), errChan: make(chan error, 1),
logger: logger, config: cfg,
} }
go ln.listenLoop() go ln.listenLoop()
@ -61,7 +68,7 @@ func (ln *listener) listenLoop() {
default: default:
} }
b := bufpool.Get(ln.readBufferSize) b := bufpool.Get(ln.config.ReadBufferSize)
n, raddr, err := ln.conn.ReadFrom(*b) n, raddr, err := ln.conn.ReadFrom(*b)
if err != nil { if err != nil {
@ -77,13 +84,16 @@ func (ln *listener) listenLoop() {
} }
if err := c.WriteQueue((*b)[:n]); err != nil { if err := c.WriteQueue((*b)[:n]); err != nil {
ln.logger.Warn("data discarded: ", err) ln.config.Logger.Warn("data discarded: ", err)
} }
} }
} }
func (ln *listener) Addr() net.Addr { func (ln *listener) Addr() net.Addr {
return ln.addr if ln.config.Addr != nil {
return ln.config.Addr
}
return ln.conn.LocalAddr()
} }
func (ln *listener) Close() error { func (ln *listener) Close() error {
@ -98,7 +108,7 @@ func (ln *listener) Close() error {
return nil return nil
} }
func (ln *listener) getConn(raddr net.Addr) *Conn { func (ln *listener) getConn(raddr net.Addr) *conn {
ln.mux.Lock() ln.mux.Lock()
defer ln.mux.Unlock() defer ln.mux.Unlock()
@ -107,14 +117,14 @@ func (ln *listener) getConn(raddr net.Addr) *Conn {
return c return c
} }
c = NewConn(ln.conn, ln.addr, raddr, ln.readQueueSize) c = newConn(ln.conn, ln.Addr(), raddr, ln.config.ReadQueueSize, ln.config.KeepAlive)
select { select {
case ln.cqueue <- c: case ln.cqueue <- c:
ln.connPool.Set(raddr.String(), c) ln.connPool.Set(raddr.String(), c)
return c return c
default: default:
c.Close() c.Close()
ln.logger.Warnf("connection queue is full, client %s discarded", raddr) ln.config.Logger.Warnf("connection queue is full, client %s discarded", raddr)
return nil return nil
} }
} }

View File

@ -7,15 +7,15 @@ import (
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
) )
type ConnPool struct { type connPool struct {
m sync.Map m sync.Map
ttl time.Duration ttl time.Duration
closed chan struct{} closed chan struct{}
logger logger.Logger logger logger.Logger
} }
func NewConnPool(ttl time.Duration) *ConnPool { func newConnPool(ttl time.Duration) *connPool {
p := &ConnPool{ p := &connPool{
ttl: ttl, ttl: ttl,
closed: make(chan struct{}), closed: make(chan struct{}),
} }
@ -23,28 +23,28 @@ func NewConnPool(ttl time.Duration) *ConnPool {
return p return p
} }
func (p *ConnPool) WithLogger(logger logger.Logger) *ConnPool { func (p *connPool) WithLogger(logger logger.Logger) *connPool {
p.logger = logger p.logger = logger
return p return p
} }
func (p *ConnPool) Get(key any) (c *Conn, ok bool) { func (p *connPool) Get(key any) (c *conn, ok bool) {
v, ok := p.m.Load(key) v, ok := p.m.Load(key)
if ok { if ok {
c, ok = v.(*Conn) c, ok = v.(*conn)
} }
return return
} }
func (p *ConnPool) Set(key any, c *Conn) { func (p *connPool) Set(key any, c *conn) {
p.m.Store(key, c) p.m.Store(key, c)
} }
func (p *ConnPool) Delete(key any) { func (p *connPool) Delete(key any) {
p.m.Delete(key) p.m.Delete(key)
} }
func (p *ConnPool) Close() { func (p *connPool) Close() {
select { select {
case <-p.closed: case <-p.closed:
return return
@ -54,14 +54,14 @@ func (p *ConnPool) Close() {
close(p.closed) close(p.closed)
p.m.Range(func(k, v any) bool { p.m.Range(func(k, v any) bool {
if c, ok := v.(*Conn); ok && c != nil { if c, ok := v.(*conn); ok && c != nil {
c.Close() c.Close()
} }
return true return true
}) })
} }
func (p *ConnPool) idleCheck() { func (p *connPool) idleCheck() {
ticker := time.NewTicker(p.ttl) ticker := time.NewTicker(p.ttl)
defer ticker.Stop() defer ticker.Stop()
@ -71,7 +71,7 @@ func (p *ConnPool) idleCheck() {
size := 0 size := 0
idles := 0 idles := 0
p.m.Range(func(key, value any) bool { p.m.Range(func(key, value any) bool {
c, ok := value.(*Conn) c, ok := value.(*conn)
if !ok || c == nil { if !ok || c == nil {
p.Delete(key) p.Delete(key)
return true return true

View File

@ -1,4 +1,4 @@
package net package udp
import ( import (
"io" "io"
@ -6,7 +6,7 @@ import (
"syscall" "syscall"
) )
type UDPConn interface { type Conn interface {
net.PacketConn net.PacketConn
io.Reader io.Reader
io.Writer io.Writer

View File

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"net" "net"
"github.com/go-gost/core/common/util/udp" "github.com/go-gost/core/common/net/udp"
"github.com/go-gost/core/connector" "github.com/go-gost/core/connector"
"github.com/go-gost/core/internal/util/mux" "github.com/go-gost/core/internal/util/mux"
"github.com/go-gost/core/internal/util/socks" "github.com/go-gost/core/internal/util/socks"
@ -80,13 +80,16 @@ func (c *socks5Connector) bindUDP(ctx context.Context, conn net.Conn, network, a
return nil, err return nil, err
} }
ln := udp.NewListener( ln := udp.NewListener(socks.UDPTunClientPacketConn(conn),
socks.UDPTunClientPacketConn(conn), &udp.ListenConfig{
laddr, Addr: laddr,
opts.Backlog, Backlog: opts.Backlog,
opts.UDPDataQueueSize, opts.UDPDataBufferSize, ReadQueueSize: opts.UDPDataQueueSize,
opts.UDPConnTTL, ReadBufferSize: opts.UDPDataBufferSize,
log) TTL: opts.UDPConnTTL,
KeepAlive: true,
Logger: log,
})
return ln, nil return ln, nil
} }

View File

@ -3,7 +3,7 @@ package udp
import ( import (
"net" "net"
"github.com/go-gost/core/common/util/udp" "github.com/go-gost/core/common/net/udp"
"github.com/go-gost/core/listener" "github.com/go-gost/core/listener"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
md "github.com/go-gost/core/metadata" md "github.com/go-gost/core/metadata"
@ -50,13 +50,14 @@ func (l *udpListener) Init(md md.Metadata) (err error) {
} }
conn = metrics.WrapPacketConn(l.options.Service, conn) conn = metrics.WrapPacketConn(l.options.Service, conn)
l.ln = udp.NewListener( l.ln = udp.NewListener(conn, &udp.ListenConfig{
conn, Backlog: l.md.backlog,
laddr, ReadQueueSize: l.md.readQueueSize,
l.md.backlog, ReadBufferSize: l.md.readBufferSize,
l.md.readQueueSize, l.md.readBufferSize, KeepAlive: l.md.keepalive,
l.md.ttl, TTL: l.md.ttl,
l.logger) Logger: l.logger,
})
return return
} }

View File

@ -14,19 +14,20 @@ const (
) )
type metadata struct { type metadata struct {
ttl time.Duration
readBufferSize int readBufferSize int
readQueueSize int readQueueSize int
backlog int backlog int
keepalive bool
ttl time.Duration
} }
func (l *udpListener) parseMetadata(md mdata.Metadata) (err error) { func (l *udpListener) parseMetadata(md mdata.Metadata) (err error) {
const ( const (
ttl = "ttl"
readBufferSize = "readBufferSize" readBufferSize = "readBufferSize"
readQueueSize = "readQueueSize" readQueueSize = "readQueueSize"
backlog = "backlog" backlog = "backlog"
keepAlive = "keepAlive"
ttl = "ttl"
) )
l.md.ttl = mdata.GetDuration(md, ttl) l.md.ttl = mdata.GetDuration(md, ttl)
@ -47,6 +48,7 @@ func (l *udpListener) parseMetadata(md mdata.Metadata) (err error) {
if l.md.backlog <= 0 { if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog l.md.backlog = defaultBacklog
} }
l.md.keepalive = mdata.GetBool(md, keepAlive)
return return
} }

View File

@ -197,7 +197,10 @@ func (ex *exchanger) exchange(ctx context.Context, msg []byte) ([]byte, error) {
c = tls.Client(c, ex.options.tlsConfig) c = tls.Client(c, ex.options.tlsConfig)
} }
conn := &dns.Conn{Conn: c} conn := &dns.Conn{
UDPSize: 1024,
Conn: c,
}
if _, err = conn.Write(msg); err != nil { if _, err = conn.Write(msg); err != nil {
return nil, err return nil, err