add keepAlive option for udp Listener

This commit is contained in:
ginuerzh
2022-04-03 22:23:27 +08:00
parent fc1e6e8ff2
commit 6340d5198f
9 changed files with 109 additions and 79 deletions

108
common/net/udp/conn.go Normal file
View File

@ -0,0 +1,108 @@
package udp
import (
"errors"
"net"
"sync"
"sync/atomic"
"github.com/go-gost/core/common/bufpool"
)
// conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type conn struct {
net.PacketConn
localAddr net.Addr
remoteAddr net.Addr
rc chan []byte // data receive queue
idle int32 // indicate the connection is idle
closed chan struct{}
closeMutex sync.Mutex
keepAlive bool
}
func newConn(c net.PacketConn, laddr, remoteAddr net.Addr, queueSize int, keepAlive bool) *conn {
return &conn{
PacketConn: c,
localAddr: laddr,
remoteAddr: remoteAddr,
rc: make(chan []byte, queueSize),
closed: make(chan struct{}),
keepAlive: keepAlive,
}
}
func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select {
case bb := <-c.rc:
n = copy(b, bb)
c.SetIdle(false)
bufpool.Put(&bb)
case <-c.closed:
err = net.ErrClosed
return
}
addr = c.remoteAddr
return
}
func (c *conn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *conn) Write(b []byte) (n int, err error) {
n, err = c.WriteTo(b, c.remoteAddr)
if !c.keepAlive {
c.Close()
}
return
}
func (c *conn) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
default:
close(c.closed)
}
return nil
}
func (c *conn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *conn) IsIdle() bool {
return atomic.LoadInt32(&c.idle) > 0
}
func (c *conn) SetIdle(idle bool) {
v := int32(0)
if idle {
v = 1
}
atomic.StoreInt32(&c.idle, v)
}
func (c *conn) WriteQueue(b []byte) error {
select {
case c.rc <- b:
return nil
case <-c.closed:
return net.ErrClosed
default:
return errors.New("recv queue is full")
}
}

130
common/net/udp/listener.go Normal file
View File

@ -0,0 +1,130 @@
package udp
import (
"net"
"sync"
"time"
"github.com/go-gost/core/common/bufpool"
"github.com/go-gost/core/logger"
)
type ListenConfig struct {
Addr net.Addr
Backlog int
ReadQueueSize int
ReadBufferSize int
TTL time.Duration
KeepAlive bool
Logger logger.Logger
}
type listener struct {
conn net.PacketConn
cqueue chan net.Conn
connPool *connPool
mux sync.Mutex
closed chan struct{}
errChan chan error
config *ListenConfig
}
func NewListener(conn net.PacketConn, cfg *ListenConfig) net.Listener {
if cfg == nil {
cfg = &ListenConfig{}
}
ln := &listener{
conn: conn,
cqueue: make(chan net.Conn, cfg.Backlog),
connPool: newConnPool(cfg.TTL).WithLogger(cfg.Logger),
closed: make(chan struct{}),
errChan: make(chan error, 1),
config: cfg,
}
go ln.listenLoop()
return ln
}
func (ln *listener) Accept() (conn net.Conn, err error) {
select {
case conn = <-ln.cqueue:
return
case <-ln.closed:
return nil, net.ErrClosed
case err = <-ln.errChan:
if err == nil {
err = net.ErrClosed
}
return
}
}
func (ln *listener) listenLoop() {
for {
select {
case <-ln.closed:
return
default:
}
b := bufpool.Get(ln.config.ReadBufferSize)
n, raddr, err := ln.conn.ReadFrom(*b)
if err != nil {
ln.errChan <- err
close(ln.errChan)
return
}
c := ln.getConn(raddr)
if c == nil {
bufpool.Put(b)
continue
}
if err := c.WriteQueue((*b)[:n]); err != nil {
ln.config.Logger.Warn("data discarded: ", err)
}
}
}
func (ln *listener) Addr() net.Addr {
if ln.config.Addr != nil {
return ln.config.Addr
}
return ln.conn.LocalAddr()
}
func (ln *listener) Close() error {
select {
case <-ln.closed:
default:
close(ln.closed)
ln.conn.Close()
ln.connPool.Close()
}
return nil
}
func (ln *listener) getConn(raddr net.Addr) *conn {
ln.mux.Lock()
defer ln.mux.Unlock()
c, ok := ln.connPool.Get(raddr.String())
if ok {
return c
}
c = newConn(ln.conn, ln.Addr(), raddr, ln.config.ReadQueueSize, ln.config.KeepAlive)
select {
case ln.cqueue <- c:
ln.connPool.Set(raddr.String(), c)
return c
default:
c.Close()
ln.config.Logger.Warnf("connection queue is full, client %s discarded", raddr)
return nil
}
}

100
common/net/udp/pool.go Normal file
View File

@ -0,0 +1,100 @@
package udp
import (
"sync"
"time"
"github.com/go-gost/core/logger"
)
type connPool struct {
m sync.Map
ttl time.Duration
closed chan struct{}
logger logger.Logger
}
func newConnPool(ttl time.Duration) *connPool {
p := &connPool{
ttl: ttl,
closed: make(chan struct{}),
}
go p.idleCheck()
return p
}
func (p *connPool) WithLogger(logger logger.Logger) *connPool {
p.logger = logger
return p
}
func (p *connPool) Get(key any) (c *conn, ok bool) {
v, ok := p.m.Load(key)
if ok {
c, ok = v.(*conn)
}
return
}
func (p *connPool) Set(key any, c *conn) {
p.m.Store(key, c)
}
func (p *connPool) Delete(key any) {
p.m.Delete(key)
}
func (p *connPool) Close() {
select {
case <-p.closed:
return
default:
}
close(p.closed)
p.m.Range(func(k, v any) bool {
if c, ok := v.(*conn); ok && c != nil {
c.Close()
}
return true
})
}
func (p *connPool) idleCheck() {
ticker := time.NewTicker(p.ttl)
defer ticker.Stop()
for {
select {
case <-ticker.C:
size := 0
idles := 0
p.m.Range(func(key, value any) bool {
c, ok := value.(*conn)
if !ok || c == nil {
p.Delete(key)
return true
}
size++
if c.IsIdle() {
idles++
p.Delete(key)
c.Close()
return true
}
c.SetIdle(true)
return true
})
if idles > 0 {
p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles)
}
case <-p.closed:
return
}
}
}

View File

@ -1,4 +1,4 @@
package net
package udp
import (
"io"
@ -6,7 +6,7 @@ import (
"syscall"
)
type UDPConn interface {
type Conn interface {
net.PacketConn
io.Reader
io.Writer