add ssu connector
This commit is contained in:
@ -6,109 +6,184 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/pkg/internal/bufpool"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
// serverConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
|
||||
type serverConn struct {
|
||||
// conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
|
||||
type conn struct {
|
||||
net.PacketConn
|
||||
raddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
rc chan []byte // data receive queue
|
||||
fresh int32
|
||||
idle int32
|
||||
closed chan struct{}
|
||||
closeMutex sync.Mutex
|
||||
config *serverConnConfig
|
||||
}
|
||||
|
||||
type serverConnConfig struct {
|
||||
ttl time.Duration
|
||||
qsize int
|
||||
onClose func()
|
||||
}
|
||||
|
||||
func newServerConn(conn net.PacketConn, raddr net.Addr, cfg *serverConnConfig) *serverConn {
|
||||
if conn == nil || raddr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
cfg = &serverConnConfig{}
|
||||
}
|
||||
c := &serverConn{
|
||||
PacketConn: conn,
|
||||
raddr: raddr,
|
||||
rc: make(chan []byte, cfg.qsize),
|
||||
func newConn(c net.PacketConn, raddr net.Addr, queue int) *conn {
|
||||
return &conn{
|
||||
PacketConn: c,
|
||||
remoteAddr: raddr,
|
||||
rc: make(chan []byte, queue),
|
||||
closed: make(chan struct{}),
|
||||
config: cfg,
|
||||
}
|
||||
go c.ttlWait()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *serverConn) send(b []byte) error {
|
||||
func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
select {
|
||||
case c.rc <- b:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("queue is full")
|
||||
case bb := <-c.rc:
|
||||
n = copy(b, bb)
|
||||
c.SetIdle(false)
|
||||
bufpool.Put(bb)
|
||||
|
||||
case <-c.closed:
|
||||
err = net.ErrClosed
|
||||
return
|
||||
}
|
||||
|
||||
addr = c.remoteAddr
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) Read(b []byte) (n int, err error) {
|
||||
func (c *conn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = c.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
select {
|
||||
case bb := <-c.rc:
|
||||
n = copy(b, bb)
|
||||
atomic.StoreInt32(&c.fresh, 1)
|
||||
case <-c.closed:
|
||||
err = errors.New("read from closed connection")
|
||||
return
|
||||
}
|
||||
|
||||
addr = c.raddr
|
||||
|
||||
return
|
||||
func (c *conn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.remoteAddr)
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(b []byte) (n int, err error) {
|
||||
return c.WriteTo(b, c.raddr)
|
||||
}
|
||||
|
||||
func (c *serverConn) Close() error {
|
||||
func (c *conn) Close() error {
|
||||
c.closeMutex.Lock()
|
||||
defer c.closeMutex.Unlock()
|
||||
|
||||
select {
|
||||
case <-c.closed:
|
||||
return errors.New("connection is closed")
|
||||
default:
|
||||
if c.config.onClose != nil {
|
||||
c.config.onClose()
|
||||
}
|
||||
close(c.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *serverConn) RemoteAddr() net.Addr {
|
||||
return c.raddr
|
||||
func (c *conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *serverConn) ttlWait() {
|
||||
ticker := time.NewTicker(c.config.ttl)
|
||||
func (c *conn) IsIdle() bool {
|
||||
return atomic.LoadInt32(&c.idle) > 0
|
||||
}
|
||||
|
||||
func (c *conn) SetIdle(idle bool) {
|
||||
v := int32(0)
|
||||
if idle {
|
||||
v = 1
|
||||
}
|
||||
atomic.StoreInt32(&c.idle, v)
|
||||
}
|
||||
|
||||
func (c *conn) Queue(b []byte) error {
|
||||
select {
|
||||
case c.rc <- b:
|
||||
return nil
|
||||
|
||||
case <-c.closed:
|
||||
return net.ErrClosed
|
||||
|
||||
default:
|
||||
return errors.New("recv queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
type connPool struct {
|
||||
m sync.Map
|
||||
ttl time.Duration
|
||||
closed chan struct{}
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func newConnPool(ttl time.Duration) *connPool {
|
||||
p := &connPool{
|
||||
ttl: ttl,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
go p.idleCheck()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *connPool) WithLogger(logger logger.Logger) *connPool {
|
||||
p.logger = logger
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *connPool) Get(key interface{}) (c *conn, ok bool) {
|
||||
v, ok := p.m.Load(key)
|
||||
if ok {
|
||||
c, ok = v.(*conn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *connPool) Set(key interface{}, c *conn) {
|
||||
p.m.Store(key, c)
|
||||
}
|
||||
|
||||
func (p *connPool) Delete(key interface{}) {
|
||||
p.m.Delete(key)
|
||||
}
|
||||
|
||||
func (p *connPool) Close() {
|
||||
select {
|
||||
case <-p.closed:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
close(p.closed)
|
||||
|
||||
p.m.Range(func(k, v interface{}) bool {
|
||||
if c, ok := v.(*conn); ok && c != nil {
|
||||
c.Close()
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (p *connPool) idleCheck() {
|
||||
ticker := time.NewTicker(p.ttl)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if !atomic.CompareAndSwapInt32(&c.fresh, 1, 0) {
|
||||
c.Close()
|
||||
return
|
||||
size := 0
|
||||
idles := 0
|
||||
p.m.Range(func(key, value interface{}) bool {
|
||||
c, ok := value.(*conn)
|
||||
if !ok || c == nil {
|
||||
p.Delete(key)
|
||||
return true
|
||||
}
|
||||
size++
|
||||
|
||||
if c.IsIdle() {
|
||||
idles++
|
||||
p.Delete(key)
|
||||
c.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
c.SetIdle(true)
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if idles > 0 {
|
||||
p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles)
|
||||
}
|
||||
case <-c.closed:
|
||||
case <-p.closed:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -2,9 +2,8 @@ package udp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/go-gost/gost/pkg/internal/bufpool"
|
||||
"github.com/go-gost/gost/pkg/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
@ -16,13 +15,14 @@ func init() {
|
||||
}
|
||||
|
||||
type udpListener struct {
|
||||
addr string
|
||||
md metadata
|
||||
conn net.PacketConn
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
connPool connPool
|
||||
logger logger.Logger
|
||||
addr string
|
||||
md metadata
|
||||
conn net.PacketConn
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
closeChan chan struct{}
|
||||
connPool *connPool
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
@ -31,8 +31,10 @@ func NewListener(opts ...listener.Option) listener.Listener {
|
||||
opt(options)
|
||||
}
|
||||
return &udpListener{
|
||||
addr: options.Addr,
|
||||
logger: options.Logger,
|
||||
addr: options.Addr,
|
||||
errChan: make(chan error, 1),
|
||||
closeChan: make(chan struct{}),
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
@ -46,15 +48,13 @@ func (l *udpListener) Init(md md.Metadata) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
var conn net.PacketConn
|
||||
conn, err = net.ListenUDP("udp", laddr)
|
||||
l.conn, err = net.ListenUDP("udp", laddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.conn = conn
|
||||
l.connChan = make(chan net.Conn, l.md.connQueueSize)
|
||||
l.errChan = make(chan error, 1)
|
||||
l.connPool = newConnPool(l.md.ttl).WithLogger(l.logger)
|
||||
|
||||
go l.listenLoop()
|
||||
|
||||
@ -74,12 +74,14 @@ func (l *udpListener) Accept() (conn net.Conn, err error) {
|
||||
}
|
||||
|
||||
func (l *udpListener) Close() error {
|
||||
err := l.conn.Close()
|
||||
l.connPool.Range(func(k interface{}, v *serverConn) bool {
|
||||
v.Close()
|
||||
return true
|
||||
})
|
||||
return err
|
||||
select {
|
||||
case <-l.closeChan:
|
||||
return nil
|
||||
default:
|
||||
close(l.closeChan)
|
||||
l.connPool.Close()
|
||||
return l.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *udpListener) Addr() net.Addr {
|
||||
@ -88,43 +90,43 @@ func (l *udpListener) Addr() net.Addr {
|
||||
|
||||
func (l *udpListener) listenLoop() {
|
||||
for {
|
||||
b := make([]byte, l.md.readBufferSize)
|
||||
b := bufpool.Get(l.md.readBufferSize)
|
||||
|
||||
n, raddr, err := l.conn.ReadFrom(b)
|
||||
if err != nil {
|
||||
l.logger.Error("accept:", err)
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
|
||||
conn, ok := l.connPool.Get(raddr.String())
|
||||
if !ok {
|
||||
conn = newServerConn(l.conn, raddr,
|
||||
&serverConnConfig{
|
||||
ttl: l.md.ttl,
|
||||
qsize: l.md.readQueueSize,
|
||||
onClose: func() {
|
||||
l.connPool.Delete(raddr.String())
|
||||
},
|
||||
})
|
||||
|
||||
select {
|
||||
case l.connChan <- conn:
|
||||
l.connPool.Set(raddr.String(), conn)
|
||||
default:
|
||||
conn.Close()
|
||||
l.logger.Error("connection queue is full")
|
||||
}
|
||||
c := l.getConn(raddr)
|
||||
if c == nil {
|
||||
bufpool.Put(b)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := conn.send(b[:n]); err != nil {
|
||||
l.logger.Warn("data discarded:", err)
|
||||
if err := c.Queue(b[:n]); err != nil {
|
||||
l.logger.Warn("data discarded: ", err)
|
||||
}
|
||||
l.logger.Debug("recv", n)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *udpListener) getConn(addr net.Addr) *conn {
|
||||
c, ok := l.connPool.Get(addr.String())
|
||||
if !ok {
|
||||
c = newConn(l.conn, addr, l.md.readQueueSize)
|
||||
select {
|
||||
case l.connChan <- c:
|
||||
l.connPool.Set(addr.String(), c)
|
||||
default:
|
||||
c.Close()
|
||||
l.logger.Warnf("connection queue is full, client %s discarded", addr.String())
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (l *udpListener) parseMetadata(md md.Metadata) (err error) {
|
||||
l.md.ttl = md.GetDuration(ttl)
|
||||
if l.md.ttl <= 0 {
|
||||
@ -147,36 +149,3 @@ func (l *udpListener) parseMetadata(md md.Metadata) (err error) {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type connPool struct {
|
||||
size int64
|
||||
m sync.Map
|
||||
}
|
||||
|
||||
func (p *connPool) Get(key interface{}) (conn *serverConn, ok bool) {
|
||||
v, ok := p.m.Load(key)
|
||||
if ok {
|
||||
conn, ok = v.(*serverConn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *connPool) Set(key interface{}, conn *serverConn) {
|
||||
p.m.Store(key, conn)
|
||||
atomic.AddInt64(&p.size, 1)
|
||||
}
|
||||
|
||||
func (p *connPool) Delete(key interface{}) {
|
||||
p.m.Delete(key)
|
||||
atomic.AddInt64(&p.size, -1)
|
||||
}
|
||||
|
||||
func (p *connPool) Range(f func(key interface{}, value *serverConn) bool) {
|
||||
p.m.Range(func(k, v interface{}) bool {
|
||||
return f(k, v.(*serverConn))
|
||||
})
|
||||
}
|
||||
|
||||
func (p *connPool) Size() int64 {
|
||||
return atomic.LoadInt64(&p.size)
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ import "time"
|
||||
|
||||
const (
|
||||
defaultTTL = 60 * time.Second
|
||||
defaultReadBufferSize = 1024
|
||||
defaultReadBufferSize = 4096
|
||||
defaultReadQueueSize = 128
|
||||
defaultConnQueueSize = 128
|
||||
)
|
||||
|
Reference in New Issue
Block a user