add keepAlive option for udp Listener
This commit is contained in:
@ -8,7 +8,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-gost/core/common/net/dialer"
|
"github.com/go-gost/core/common/net/dialer"
|
||||||
"github.com/go-gost/core/common/util/udp"
|
"github.com/go-gost/core/common/net/udp"
|
||||||
"github.com/go-gost/core/connector"
|
"github.com/go-gost/core/connector"
|
||||||
"github.com/go-gost/core/logger"
|
"github.com/go-gost/core/logger"
|
||||||
"github.com/go-gost/core/metrics"
|
"github.com/go-gost/core/metrics"
|
||||||
@ -198,9 +198,14 @@ func (r *Route) bindLocal(ctx context.Context, network, address string, opts ...
|
|||||||
"network": network,
|
"network": network,
|
||||||
"address": address,
|
"address": address,
|
||||||
})
|
})
|
||||||
ln := udp.NewListener(conn, addr,
|
ln := udp.NewListener(conn, &udp.ListenConfig{
|
||||||
options.Backlog, options.UDPDataQueueSize, options.UDPDataBufferSize,
|
Backlog: options.Backlog,
|
||||||
options.UDPConnTTL, logger)
|
ReadQueueSize: options.UDPDataQueueSize,
|
||||||
|
ReadBufferSize: options.UDPDataBufferSize,
|
||||||
|
TTL: options.UDPConnTTL,
|
||||||
|
KeepAlive: true,
|
||||||
|
Logger: logger,
|
||||||
|
})
|
||||||
return ln, err
|
return ln, err
|
||||||
default:
|
default:
|
||||||
err := fmt.Errorf("network %s unsupported", network)
|
err := fmt.Errorf("network %s unsupported", network)
|
||||||
|
@ -9,8 +9,8 @@ import (
|
|||||||
"github.com/go-gost/core/common/bufpool"
|
"github.com/go-gost/core/common/bufpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
|
// conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
|
||||||
type Conn struct {
|
type conn struct {
|
||||||
net.PacketConn
|
net.PacketConn
|
||||||
localAddr net.Addr
|
localAddr net.Addr
|
||||||
remoteAddr net.Addr
|
remoteAddr net.Addr
|
||||||
@ -18,19 +18,21 @@ type Conn struct {
|
|||||||
idle int32 // indicate the connection is idle
|
idle int32 // indicate the connection is idle
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
closeMutex sync.Mutex
|
closeMutex sync.Mutex
|
||||||
|
keepAlive bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConn(c net.PacketConn, localAddr, remoteAddr net.Addr, queueSize int) *Conn {
|
func newConn(c net.PacketConn, laddr, remoteAddr net.Addr, queueSize int, keepAlive bool) *conn {
|
||||||
return &Conn{
|
return &conn{
|
||||||
PacketConn: c,
|
PacketConn: c,
|
||||||
localAddr: localAddr,
|
localAddr: laddr,
|
||||||
remoteAddr: remoteAddr,
|
remoteAddr: remoteAddr,
|
||||||
rc: make(chan []byte, queueSize),
|
rc: make(chan []byte, queueSize),
|
||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
|
keepAlive: keepAlive,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||||
select {
|
select {
|
||||||
case bb := <-c.rc:
|
case bb := <-c.rc:
|
||||||
n = copy(b, bb)
|
n = copy(b, bb)
|
||||||
@ -47,16 +49,20 @@ func (c *Conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *conn) Read(b []byte) (n int, err error) {
|
||||||
n, _, err = c.ReadFrom(b)
|
n, _, err = c.ReadFrom(b)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
func (c *conn) Write(b []byte) (n int, err error) {
|
||||||
return c.WriteTo(b, c.remoteAddr)
|
n, err = c.WriteTo(b, c.remoteAddr)
|
||||||
|
if !c.keepAlive {
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Close() error {
|
func (c *conn) Close() error {
|
||||||
c.closeMutex.Lock()
|
c.closeMutex.Lock()
|
||||||
defer c.closeMutex.Unlock()
|
defer c.closeMutex.Unlock()
|
||||||
|
|
||||||
@ -68,19 +74,19 @@ func (c *Conn) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) LocalAddr() net.Addr {
|
func (c *conn) LocalAddr() net.Addr {
|
||||||
return c.localAddr
|
return c.localAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) RemoteAddr() net.Addr {
|
func (c *conn) RemoteAddr() net.Addr {
|
||||||
return c.remoteAddr
|
return c.remoteAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) IsIdle() bool {
|
func (c *conn) IsIdle() bool {
|
||||||
return atomic.LoadInt32(&c.idle) > 0
|
return atomic.LoadInt32(&c.idle) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) SetIdle(idle bool) {
|
func (c *conn) SetIdle(idle bool) {
|
||||||
v := int32(0)
|
v := int32(0)
|
||||||
if idle {
|
if idle {
|
||||||
v = 1
|
v = 1
|
||||||
@ -88,7 +94,7 @@ func (c *Conn) SetIdle(idle bool) {
|
|||||||
atomic.StoreInt32(&c.idle, v)
|
atomic.StoreInt32(&c.idle, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) WriteQueue(b []byte) error {
|
func (c *conn) WriteQueue(b []byte) error {
|
||||||
select {
|
select {
|
||||||
case c.rc <- b:
|
case c.rc <- b:
|
||||||
return nil
|
return nil
|
@ -9,30 +9,37 @@ import (
|
|||||||
"github.com/go-gost/core/logger"
|
"github.com/go-gost/core/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ListenConfig struct {
|
||||||
|
Addr net.Addr
|
||||||
|
Backlog int
|
||||||
|
ReadQueueSize int
|
||||||
|
ReadBufferSize int
|
||||||
|
TTL time.Duration
|
||||||
|
KeepAlive bool
|
||||||
|
Logger logger.Logger
|
||||||
|
}
|
||||||
type listener struct {
|
type listener struct {
|
||||||
addr net.Addr
|
|
||||||
conn net.PacketConn
|
conn net.PacketConn
|
||||||
cqueue chan net.Conn
|
cqueue chan net.Conn
|
||||||
readQueueSize int
|
connPool *connPool
|
||||||
readBufferSize int
|
|
||||||
connPool *ConnPool
|
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
errChan chan error
|
errChan chan error
|
||||||
logger logger.Logger
|
config *ListenConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(conn net.PacketConn, addr net.Addr, backlog, dataQueueSize, dataBufferSize int, ttl time.Duration, logger logger.Logger) net.Listener {
|
func NewListener(conn net.PacketConn, cfg *ListenConfig) net.Listener {
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = &ListenConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
ln := &listener{
|
ln := &listener{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
addr: addr,
|
cqueue: make(chan net.Conn, cfg.Backlog),
|
||||||
cqueue: make(chan net.Conn, backlog),
|
connPool: newConnPool(cfg.TTL).WithLogger(cfg.Logger),
|
||||||
connPool: NewConnPool(ttl).WithLogger(logger),
|
|
||||||
readQueueSize: dataQueueSize,
|
|
||||||
readBufferSize: dataBufferSize,
|
|
||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
errChan: make(chan error, 1),
|
errChan: make(chan error, 1),
|
||||||
logger: logger,
|
config: cfg,
|
||||||
}
|
}
|
||||||
go ln.listenLoop()
|
go ln.listenLoop()
|
||||||
|
|
||||||
@ -61,7 +68,7 @@ func (ln *listener) listenLoop() {
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
b := bufpool.Get(ln.readBufferSize)
|
b := bufpool.Get(ln.config.ReadBufferSize)
|
||||||
|
|
||||||
n, raddr, err := ln.conn.ReadFrom(*b)
|
n, raddr, err := ln.conn.ReadFrom(*b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -77,13 +84,16 @@ func (ln *listener) listenLoop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := c.WriteQueue((*b)[:n]); err != nil {
|
if err := c.WriteQueue((*b)[:n]); err != nil {
|
||||||
ln.logger.Warn("data discarded: ", err)
|
ln.config.Logger.Warn("data discarded: ", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ln *listener) Addr() net.Addr {
|
func (ln *listener) Addr() net.Addr {
|
||||||
return ln.addr
|
if ln.config.Addr != nil {
|
||||||
|
return ln.config.Addr
|
||||||
|
}
|
||||||
|
return ln.conn.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ln *listener) Close() error {
|
func (ln *listener) Close() error {
|
||||||
@ -98,7 +108,7 @@ func (ln *listener) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ln *listener) getConn(raddr net.Addr) *Conn {
|
func (ln *listener) getConn(raddr net.Addr) *conn {
|
||||||
ln.mux.Lock()
|
ln.mux.Lock()
|
||||||
defer ln.mux.Unlock()
|
defer ln.mux.Unlock()
|
||||||
|
|
||||||
@ -107,14 +117,14 @@ func (ln *listener) getConn(raddr net.Addr) *Conn {
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
c = NewConn(ln.conn, ln.addr, raddr, ln.readQueueSize)
|
c = newConn(ln.conn, ln.Addr(), raddr, ln.config.ReadQueueSize, ln.config.KeepAlive)
|
||||||
select {
|
select {
|
||||||
case ln.cqueue <- c:
|
case ln.cqueue <- c:
|
||||||
ln.connPool.Set(raddr.String(), c)
|
ln.connPool.Set(raddr.String(), c)
|
||||||
return c
|
return c
|
||||||
default:
|
default:
|
||||||
c.Close()
|
c.Close()
|
||||||
ln.logger.Warnf("connection queue is full, client %s discarded", raddr)
|
ln.config.Logger.Warnf("connection queue is full, client %s discarded", raddr)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -7,15 +7,15 @@ import (
|
|||||||
"github.com/go-gost/core/logger"
|
"github.com/go-gost/core/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnPool struct {
|
type connPool struct {
|
||||||
m sync.Map
|
m sync.Map
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
logger logger.Logger
|
logger logger.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnPool(ttl time.Duration) *ConnPool {
|
func newConnPool(ttl time.Duration) *connPool {
|
||||||
p := &ConnPool{
|
p := &connPool{
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
}
|
}
|
||||||
@ -23,28 +23,28 @@ func NewConnPool(ttl time.Duration) *ConnPool {
|
|||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ConnPool) WithLogger(logger logger.Logger) *ConnPool {
|
func (p *connPool) WithLogger(logger logger.Logger) *connPool {
|
||||||
p.logger = logger
|
p.logger = logger
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ConnPool) Get(key any) (c *Conn, ok bool) {
|
func (p *connPool) Get(key any) (c *conn, ok bool) {
|
||||||
v, ok := p.m.Load(key)
|
v, ok := p.m.Load(key)
|
||||||
if ok {
|
if ok {
|
||||||
c, ok = v.(*Conn)
|
c, ok = v.(*conn)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ConnPool) Set(key any, c *Conn) {
|
func (p *connPool) Set(key any, c *conn) {
|
||||||
p.m.Store(key, c)
|
p.m.Store(key, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ConnPool) Delete(key any) {
|
func (p *connPool) Delete(key any) {
|
||||||
p.m.Delete(key)
|
p.m.Delete(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ConnPool) Close() {
|
func (p *connPool) Close() {
|
||||||
select {
|
select {
|
||||||
case <-p.closed:
|
case <-p.closed:
|
||||||
return
|
return
|
||||||
@ -54,14 +54,14 @@ func (p *ConnPool) Close() {
|
|||||||
close(p.closed)
|
close(p.closed)
|
||||||
|
|
||||||
p.m.Range(func(k, v any) bool {
|
p.m.Range(func(k, v any) bool {
|
||||||
if c, ok := v.(*Conn); ok && c != nil {
|
if c, ok := v.(*conn); ok && c != nil {
|
||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ConnPool) idleCheck() {
|
func (p *connPool) idleCheck() {
|
||||||
ticker := time.NewTicker(p.ttl)
|
ticker := time.NewTicker(p.ttl)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@ -71,7 +71,7 @@ func (p *ConnPool) idleCheck() {
|
|||||||
size := 0
|
size := 0
|
||||||
idles := 0
|
idles := 0
|
||||||
p.m.Range(func(key, value any) bool {
|
p.m.Range(func(key, value any) bool {
|
||||||
c, ok := value.(*Conn)
|
c, ok := value.(*conn)
|
||||||
if !ok || c == nil {
|
if !ok || c == nil {
|
||||||
p.Delete(key)
|
p.Delete(key)
|
||||||
return true
|
return true
|
@ -1,4 +1,4 @@
|
|||||||
package net
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
@ -6,7 +6,7 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UDPConn interface {
|
type Conn interface {
|
||||||
net.PacketConn
|
net.PacketConn
|
||||||
io.Reader
|
io.Reader
|
||||||
io.Writer
|
io.Writer
|
@ -5,7 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/go-gost/core/common/util/udp"
|
"github.com/go-gost/core/common/net/udp"
|
||||||
"github.com/go-gost/core/connector"
|
"github.com/go-gost/core/connector"
|
||||||
"github.com/go-gost/core/internal/util/mux"
|
"github.com/go-gost/core/internal/util/mux"
|
||||||
"github.com/go-gost/core/internal/util/socks"
|
"github.com/go-gost/core/internal/util/socks"
|
||||||
@ -80,13 +80,16 @@ func (c *socks5Connector) bindUDP(ctx context.Context, conn net.Conn, network, a
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ln := udp.NewListener(
|
ln := udp.NewListener(socks.UDPTunClientPacketConn(conn),
|
||||||
socks.UDPTunClientPacketConn(conn),
|
&udp.ListenConfig{
|
||||||
laddr,
|
Addr: laddr,
|
||||||
opts.Backlog,
|
Backlog: opts.Backlog,
|
||||||
opts.UDPDataQueueSize, opts.UDPDataBufferSize,
|
ReadQueueSize: opts.UDPDataQueueSize,
|
||||||
opts.UDPConnTTL,
|
ReadBufferSize: opts.UDPDataBufferSize,
|
||||||
log)
|
TTL: opts.UDPConnTTL,
|
||||||
|
KeepAlive: true,
|
||||||
|
Logger: log,
|
||||||
|
})
|
||||||
|
|
||||||
return ln, nil
|
return ln, nil
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@ package udp
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/go-gost/core/common/util/udp"
|
"github.com/go-gost/core/common/net/udp"
|
||||||
"github.com/go-gost/core/listener"
|
"github.com/go-gost/core/listener"
|
||||||
"github.com/go-gost/core/logger"
|
"github.com/go-gost/core/logger"
|
||||||
md "github.com/go-gost/core/metadata"
|
md "github.com/go-gost/core/metadata"
|
||||||
@ -50,13 +50,14 @@ func (l *udpListener) Init(md md.Metadata) (err error) {
|
|||||||
}
|
}
|
||||||
conn = metrics.WrapPacketConn(l.options.Service, conn)
|
conn = metrics.WrapPacketConn(l.options.Service, conn)
|
||||||
|
|
||||||
l.ln = udp.NewListener(
|
l.ln = udp.NewListener(conn, &udp.ListenConfig{
|
||||||
conn,
|
Backlog: l.md.backlog,
|
||||||
laddr,
|
ReadQueueSize: l.md.readQueueSize,
|
||||||
l.md.backlog,
|
ReadBufferSize: l.md.readBufferSize,
|
||||||
l.md.readQueueSize, l.md.readBufferSize,
|
KeepAlive: l.md.keepalive,
|
||||||
l.md.ttl,
|
TTL: l.md.ttl,
|
||||||
l.logger)
|
Logger: l.logger,
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,19 +14,20 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type metadata struct {
|
type metadata struct {
|
||||||
ttl time.Duration
|
|
||||||
|
|
||||||
readBufferSize int
|
readBufferSize int
|
||||||
readQueueSize int
|
readQueueSize int
|
||||||
backlog int
|
backlog int
|
||||||
|
keepalive bool
|
||||||
|
ttl time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *udpListener) parseMetadata(md mdata.Metadata) (err error) {
|
func (l *udpListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||||
const (
|
const (
|
||||||
ttl = "ttl"
|
|
||||||
readBufferSize = "readBufferSize"
|
readBufferSize = "readBufferSize"
|
||||||
readQueueSize = "readQueueSize"
|
readQueueSize = "readQueueSize"
|
||||||
backlog = "backlog"
|
backlog = "backlog"
|
||||||
|
keepAlive = "keepAlive"
|
||||||
|
ttl = "ttl"
|
||||||
)
|
)
|
||||||
|
|
||||||
l.md.ttl = mdata.GetDuration(md, ttl)
|
l.md.ttl = mdata.GetDuration(md, ttl)
|
||||||
@ -47,6 +48,7 @@ func (l *udpListener) parseMetadata(md mdata.Metadata) (err error) {
|
|||||||
if l.md.backlog <= 0 {
|
if l.md.backlog <= 0 {
|
||||||
l.md.backlog = defaultBacklog
|
l.md.backlog = defaultBacklog
|
||||||
}
|
}
|
||||||
|
l.md.keepalive = mdata.GetBool(md, keepAlive)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -197,7 +197,10 @@ func (ex *exchanger) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
|||||||
c = tls.Client(c, ex.options.tlsConfig)
|
c = tls.Client(c, ex.options.tlsConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := &dns.Conn{Conn: c}
|
conn := &dns.Conn{
|
||||||
|
UDPSize: 1024,
|
||||||
|
Conn: c,
|
||||||
|
}
|
||||||
|
|
||||||
if _, err = conn.Write(msg); err != nil {
|
if _, err = conn.Write(msg); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
Reference in New Issue
Block a user