add ssu connector

This commit is contained in:
ginuerzh
2021-11-09 23:34:19 +08:00
parent 92dc87830f
commit cae199dbd9
29 changed files with 1031 additions and 678 deletions

View File

@ -6,109 +6,184 @@ import (
"sync"
"sync/atomic"
"time"
"github.com/go-gost/gost/pkg/internal/bufpool"
"github.com/go-gost/gost/pkg/logger"
)
// serverConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type serverConn struct {
// conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type conn struct {
net.PacketConn
raddr net.Addr
remoteAddr net.Addr
rc chan []byte // data receive queue
fresh int32
idle int32
closed chan struct{}
closeMutex sync.Mutex
config *serverConnConfig
}
type serverConnConfig struct {
ttl time.Duration
qsize int
onClose func()
}
func newServerConn(conn net.PacketConn, raddr net.Addr, cfg *serverConnConfig) *serverConn {
if conn == nil || raddr == nil {
return nil
}
if cfg == nil {
cfg = &serverConnConfig{}
}
c := &serverConn{
PacketConn: conn,
raddr: raddr,
rc: make(chan []byte, cfg.qsize),
func newConn(c net.PacketConn, raddr net.Addr, queue int) *conn {
return &conn{
PacketConn: c,
remoteAddr: raddr,
rc: make(chan []byte, queue),
closed: make(chan struct{}),
config: cfg,
}
go c.ttlWait()
return c
}
func (c *serverConn) send(b []byte) error {
func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select {
case c.rc <- b:
return nil
default:
return errors.New("queue is full")
case bb := <-c.rc:
n = copy(b, bb)
c.SetIdle(false)
bufpool.Put(bb)
case <-c.closed:
err = net.ErrClosed
return
}
addr = c.remoteAddr
return
}
func (c *serverConn) Read(b []byte) (n int, err error) {
func (c *conn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select {
case bb := <-c.rc:
n = copy(b, bb)
atomic.StoreInt32(&c.fresh, 1)
case <-c.closed:
err = errors.New("read from closed connection")
return
}
addr = c.raddr
return
func (c *conn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.remoteAddr)
}
func (c *serverConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.raddr)
}
func (c *serverConn) Close() error {
func (c *conn) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
return errors.New("connection is closed")
default:
if c.config.onClose != nil {
c.config.onClose()
}
close(c.closed)
}
return nil
}
func (c *serverConn) RemoteAddr() net.Addr {
return c.raddr
func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *serverConn) ttlWait() {
ticker := time.NewTicker(c.config.ttl)
func (c *conn) IsIdle() bool {
return atomic.LoadInt32(&c.idle) > 0
}
func (c *conn) SetIdle(idle bool) {
v := int32(0)
if idle {
v = 1
}
atomic.StoreInt32(&c.idle, v)
}
func (c *conn) Queue(b []byte) error {
select {
case c.rc <- b:
return nil
case <-c.closed:
return net.ErrClosed
default:
return errors.New("recv queue is full")
}
}
type connPool struct {
m sync.Map
ttl time.Duration
closed chan struct{}
logger logger.Logger
}
func newConnPool(ttl time.Duration) *connPool {
p := &connPool{
ttl: ttl,
closed: make(chan struct{}),
}
go p.idleCheck()
return p
}
func (p *connPool) WithLogger(logger logger.Logger) *connPool {
p.logger = logger
return p
}
func (p *connPool) Get(key interface{}) (c *conn, ok bool) {
v, ok := p.m.Load(key)
if ok {
c, ok = v.(*conn)
}
return
}
func (p *connPool) Set(key interface{}, c *conn) {
p.m.Store(key, c)
}
func (p *connPool) Delete(key interface{}) {
p.m.Delete(key)
}
func (p *connPool) Close() {
select {
case <-p.closed:
return
default:
}
close(p.closed)
p.m.Range(func(k, v interface{}) bool {
if c, ok := v.(*conn); ok && c != nil {
c.Close()
}
return true
})
}
func (p *connPool) idleCheck() {
ticker := time.NewTicker(p.ttl)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if !atomic.CompareAndSwapInt32(&c.fresh, 1, 0) {
c.Close()
return
size := 0
idles := 0
p.m.Range(func(key, value interface{}) bool {
c, ok := value.(*conn)
if !ok || c == nil {
p.Delete(key)
return true
}
size++
if c.IsIdle() {
idles++
p.Delete(key)
c.Close()
return true
}
c.SetIdle(true)
return true
})
if idles > 0 {
p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles)
}
case <-c.closed:
case <-p.closed:
return
}
}

View File

@ -2,9 +2,8 @@ package udp
import (
"net"
"sync"
"sync/atomic"
"github.com/go-gost/gost/pkg/internal/bufpool"
"github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
@ -16,13 +15,14 @@ func init() {
}
type udpListener struct {
addr string
md metadata
conn net.PacketConn
connChan chan net.Conn
errChan chan error
connPool connPool
logger logger.Logger
addr string
md metadata
conn net.PacketConn
connChan chan net.Conn
errChan chan error
closeChan chan struct{}
connPool *connPool
logger logger.Logger
}
func NewListener(opts ...listener.Option) listener.Listener {
@ -31,8 +31,10 @@ func NewListener(opts ...listener.Option) listener.Listener {
opt(options)
}
return &udpListener{
addr: options.Addr,
logger: options.Logger,
addr: options.Addr,
errChan: make(chan error, 1),
closeChan: make(chan struct{}),
logger: options.Logger,
}
}
@ -46,15 +48,13 @@ func (l *udpListener) Init(md md.Metadata) (err error) {
return
}
var conn net.PacketConn
conn, err = net.ListenUDP("udp", laddr)
l.conn, err = net.ListenUDP("udp", laddr)
if err != nil {
return
}
l.conn = conn
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.errChan = make(chan error, 1)
l.connPool = newConnPool(l.md.ttl).WithLogger(l.logger)
go l.listenLoop()
@ -74,12 +74,14 @@ func (l *udpListener) Accept() (conn net.Conn, err error) {
}
func (l *udpListener) Close() error {
err := l.conn.Close()
l.connPool.Range(func(k interface{}, v *serverConn) bool {
v.Close()
return true
})
return err
select {
case <-l.closeChan:
return nil
default:
close(l.closeChan)
l.connPool.Close()
return l.conn.Close()
}
}
func (l *udpListener) Addr() net.Addr {
@ -88,43 +90,43 @@ func (l *udpListener) Addr() net.Addr {
func (l *udpListener) listenLoop() {
for {
b := make([]byte, l.md.readBufferSize)
b := bufpool.Get(l.md.readBufferSize)
n, raddr, err := l.conn.ReadFrom(b)
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
conn, ok := l.connPool.Get(raddr.String())
if !ok {
conn = newServerConn(l.conn, raddr,
&serverConnConfig{
ttl: l.md.ttl,
qsize: l.md.readQueueSize,
onClose: func() {
l.connPool.Delete(raddr.String())
},
})
select {
case l.connChan <- conn:
l.connPool.Set(raddr.String(), conn)
default:
conn.Close()
l.logger.Error("connection queue is full")
}
c := l.getConn(raddr)
if c == nil {
bufpool.Put(b)
continue
}
if err := conn.send(b[:n]); err != nil {
l.logger.Warn("data discarded:", err)
if err := c.Queue(b[:n]); err != nil {
l.logger.Warn("data discarded: ", err)
}
l.logger.Debug("recv", n)
}
}
func (l *udpListener) getConn(addr net.Addr) *conn {
c, ok := l.connPool.Get(addr.String())
if !ok {
c = newConn(l.conn, addr, l.md.readQueueSize)
select {
case l.connChan <- c:
l.connPool.Set(addr.String(), c)
default:
c.Close()
l.logger.Warnf("connection queue is full, client %s discarded", addr.String())
return nil
}
}
return c
}
func (l *udpListener) parseMetadata(md md.Metadata) (err error) {
l.md.ttl = md.GetDuration(ttl)
if l.md.ttl <= 0 {
@ -147,36 +149,3 @@ func (l *udpListener) parseMetadata(md md.Metadata) (err error) {
return
}
type connPool struct {
size int64
m sync.Map
}
func (p *connPool) Get(key interface{}) (conn *serverConn, ok bool) {
v, ok := p.m.Load(key)
if ok {
conn, ok = v.(*serverConn)
}
return
}
func (p *connPool) Set(key interface{}, conn *serverConn) {
p.m.Store(key, conn)
atomic.AddInt64(&p.size, 1)
}
func (p *connPool) Delete(key interface{}) {
p.m.Delete(key)
atomic.AddInt64(&p.size, -1)
}
func (p *connPool) Range(f func(key interface{}, value *serverConn) bool) {
p.m.Range(func(k, v interface{}) bool {
return f(k, v.(*serverConn))
})
}
func (p *connPool) Size() int64 {
return atomic.LoadInt64(&p.size)
}

View File

@ -4,7 +4,7 @@ import "time"
const (
defaultTTL = 60 * time.Second
defaultReadBufferSize = 1024
defaultReadBufferSize = 4096
defaultReadQueueSize = 128
defaultConnQueueSize = 128
)