From 83dacf67d57e454cdb0cdf2e746443e5cff2dc52 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Tue, 16 Nov 2021 16:12:16 +0800 Subject: [PATCH] add rudp --- cmd/gost/gost.yml | 22 +- cmd/gost/register.go | 3 +- pkg/chain/node.go | 30 +-- pkg/chain/selector.go | 8 +- pkg/common/util/udp/conn.go | 102 ++++++++++ pkg/common/util/udp/pool.go | 100 +++++++++ pkg/handler/forward/{local => }/handler.go | 12 +- pkg/handler/forward/{local => }/metadata.go | 4 +- pkg/listener/rtcp/listener.go | 19 +- pkg/listener/rtcp/metadata.go | 20 +- pkg/listener/rudp/listener.go | 213 ++++++++++++++++++++ pkg/listener/rudp/metadata.go | 55 +++++ pkg/listener/udp/conn.go | 190 ----------------- pkg/listener/udp/listener.go | 19 +- pkg/listener/udp/metadata.go | 12 +- 15 files changed, 543 insertions(+), 266 deletions(-) create mode 100644 pkg/common/util/udp/conn.go create mode 100644 pkg/common/util/udp/pool.go rename pkg/handler/forward/{local => }/handler.go (88%) rename pkg/handler/forward/{local => }/metadata.go (77%) create mode 100644 pkg/listener/rudp/listener.go create mode 100644 pkg/listener/rudp/metadata.go delete mode 100644 pkg/listener/udp/conn.go diff --git a/cmd/gost/gost.yml b/cmd/gost/gost.yml index 4fa4648..451bf48 100644 --- a/cmd/gost/gost.yml +++ b/cmd/gost/gost.yml @@ -131,10 +131,30 @@ services: retry: 3 listener: type: rtcp - chain: chain-socks5 + # chain: chain-socks5 metadata: keepAlive: 15s mux: true +- name: rudp + addr: ":1053" + forwarder: + targets: + - 192.168.8.8:53 + - 192.168.8.1:53 + selector: + strategy: round + maxFails: 1 + failTimeout: 30s + handler: + type: forward + metadata: + readTimeout: 5s + retry: 3 + listener: + type: rudp + chain: chain-socks5 + metadata: + keepAlive: 15s chains: - name: chain01 diff --git a/cmd/gost/register.go b/cmd/gost/register.go index a87682b..35a32f0 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -12,7 +12,7 @@ import ( _ "github.com/go-gost/gost/pkg/dialer/udp" // Register handlers - _ "github.com/go-gost/gost/pkg/handler/forward/local" + _ "github.com/go-gost/gost/pkg/handler/forward" _ "github.com/go-gost/gost/pkg/handler/http" _ "github.com/go-gost/gost/pkg/handler/socks/v4" _ "github.com/go-gost/gost/pkg/handler/socks/v5" @@ -27,6 +27,7 @@ import ( _ "github.com/go-gost/gost/pkg/listener/obfs/tls" _ "github.com/go-gost/gost/pkg/listener/quic" _ "github.com/go-gost/gost/pkg/listener/rtcp" + _ "github.com/go-gost/gost/pkg/listener/rudp" _ "github.com/go-gost/gost/pkg/listener/tcp" _ "github.com/go-gost/gost/pkg/listener/tls" _ "github.com/go-gost/gost/pkg/listener/tls/mux" diff --git a/pkg/chain/node.go b/pkg/chain/node.go index d4c011f..67d8469 100644 --- a/pkg/chain/node.go +++ b/pkg/chain/node.go @@ -1,7 +1,7 @@ package chain import ( - "sync" + "sync/atomic" "time" "github.com/go-gost/gost/pkg/bypass" @@ -86,8 +86,7 @@ func (g *NodeGroup) Next() *Node { type FailMarker struct { failTime int64 - failCount uint32 - mux sync.RWMutex + failCount int64 } func (m *FailMarker) FailTime() int64 { @@ -95,21 +94,15 @@ func (m *FailMarker) FailTime() int64 { return 0 } - m.mux.RLock() - defer m.mux.RUnlock() - - return m.failTime + return atomic.LoadInt64(&m.failCount) } -func (m *FailMarker) FailCount() uint32 { +func (m *FailMarker) FailCount() int64 { if m == nil { return 0 } - m.mux.RLock() - defer m.mux.RUnlock() - - return m.failCount + return atomic.LoadInt64(&m.failCount) } func (m *FailMarker) Mark() { @@ -117,11 +110,8 @@ func (m *FailMarker) Mark() { return } - m.mux.Lock() - defer m.mux.Unlock() - - m.failTime = time.Now().Unix() - m.failCount++ + atomic.AddInt64(&m.failCount, 1) + atomic.StoreInt64(&m.failTime, time.Now().Unix()) } func (m *FailMarker) Reset() { @@ -129,9 +119,5 @@ func (m *FailMarker) Reset() { return } - m.mux.Lock() - defer m.mux.Unlock() - - m.failTime = 0 - m.failCount = 0 + atomic.StoreInt64(&m.failCount, 0) } diff --git a/pkg/chain/selector.go b/pkg/chain/selector.go index 8c259db..2588407 100644 --- a/pkg/chain/selector.go +++ b/pkg/chain/selector.go @@ -11,7 +11,6 @@ import ( // default options for FailFilter const ( - DefaultMaxFails = 1 DefaultFailTimeout = 30 * time.Second ) @@ -128,20 +127,17 @@ func FailFilter(maxFails int, timeout time.Duration) Filter { // Filter filters dead nodes. func (f *failFilter) Filter(nodes ...*Node) []*Node { maxFails := f.maxFails - if maxFails == 0 { - maxFails = DefaultMaxFails - } failTimeout := f.failTimeout if failTimeout == 0 { failTimeout = DefaultFailTimeout } - if len(nodes) <= 1 || maxFails < 0 { + if len(nodes) <= 1 || maxFails <= 0 { return nodes } var nl []*Node for _, node := range nodes { - if node.Marker().FailCount() < uint32(maxFails) || + if node.Marker().FailCount() < int64(maxFails) || time.Since(time.Unix(node.Marker().FailTime(), 0)) >= failTimeout { nl = append(nl, node) } diff --git a/pkg/common/util/udp/conn.go b/pkg/common/util/udp/conn.go new file mode 100644 index 0000000..184c995 --- /dev/null +++ b/pkg/common/util/udp/conn.go @@ -0,0 +1,102 @@ +package udp + +import ( + "errors" + "net" + "sync" + "sync/atomic" + + "github.com/go-gost/gost/pkg/common/bufpool" +) + +// Conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn. +type Conn struct { + net.PacketConn + localAddr net.Addr + remoteAddr net.Addr + rc chan []byte // data receive queue + idle int32 // indicate the connection is idle + closed chan struct{} + closeMutex sync.Mutex +} + +func NewConn(c net.PacketConn, localAddr, remoteAddr net.Addr, queueSize int) *Conn { + return &Conn{ + PacketConn: c, + localAddr: localAddr, + remoteAddr: remoteAddr, + rc: make(chan []byte, queueSize), + closed: make(chan struct{}), + } +} + +func (c *Conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + select { + case bb := <-c.rc: + n = copy(b, bb) + c.SetIdle(false) + bufpool.Put(bb) + + case <-c.closed: + err = net.ErrClosed + return + } + + addr = c.remoteAddr + + return +} + +func (c *Conn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *Conn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.remoteAddr) +} + +func (c *Conn) Close() error { + c.closeMutex.Lock() + defer c.closeMutex.Unlock() + + select { + case <-c.closed: + default: + close(c.closed) + } + return nil +} + +func (c *Conn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *Conn) IsIdle() bool { + return atomic.LoadInt32(&c.idle) > 0 +} + +func (c *Conn) SetIdle(idle bool) { + v := int32(0) + if idle { + v = 1 + } + atomic.StoreInt32(&c.idle, v) +} + +func (c *Conn) WriteQueue(b []byte) error { + select { + case c.rc <- b: + return nil + + case <-c.closed: + return net.ErrClosed + + default: + return errors.New("recv queue is full") + } +} diff --git a/pkg/common/util/udp/pool.go b/pkg/common/util/udp/pool.go new file mode 100644 index 0000000..513f70e --- /dev/null +++ b/pkg/common/util/udp/pool.go @@ -0,0 +1,100 @@ +package udp + +import ( + "sync" + "time" + + "github.com/go-gost/gost/pkg/logger" +) + +type ConnPool struct { + m sync.Map + ttl time.Duration + closed chan struct{} + logger logger.Logger +} + +func NewConnPool(ttl time.Duration) *ConnPool { + p := &ConnPool{ + ttl: ttl, + closed: make(chan struct{}), + } + go p.idleCheck() + return p +} + +func (p *ConnPool) WithLogger(logger logger.Logger) *ConnPool { + p.logger = logger + return p +} + +func (p *ConnPool) Get(key interface{}) (c *Conn, ok bool) { + v, ok := p.m.Load(key) + if ok { + c, ok = v.(*Conn) + } + return +} + +func (p *ConnPool) Set(key interface{}, c *Conn) { + p.m.Store(key, c) +} + +func (p *ConnPool) Delete(key interface{}) { + p.m.Delete(key) +} + +func (p *ConnPool) Close() { + select { + case <-p.closed: + return + default: + } + + close(p.closed) + + p.m.Range(func(k, v interface{}) bool { + if c, ok := v.(*Conn); ok && c != nil { + c.Close() + } + return true + }) +} + +func (p *ConnPool) idleCheck() { + ticker := time.NewTicker(p.ttl) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + size := 0 + idles := 0 + p.m.Range(func(key, value interface{}) bool { + c, ok := value.(*Conn) + if !ok || c == nil { + p.Delete(key) + return true + } + size++ + + if c.IsIdle() { + idles++ + p.Delete(key) + c.Close() + return true + } + + c.SetIdle(true) + + return true + }) + + if idles > 0 { + p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles) + } + case <-p.closed: + return + } + } +} diff --git a/pkg/handler/forward/local/handler.go b/pkg/handler/forward/handler.go similarity index 88% rename from pkg/handler/forward/local/handler.go rename to pkg/handler/forward/handler.go index ef55200..e4cd309 100644 --- a/pkg/handler/forward/local/handler.go +++ b/pkg/handler/forward/handler.go @@ -1,4 +1,4 @@ -package local +package forward import ( "context" @@ -17,7 +17,7 @@ func init() { registry.RegisterHandler("forward", NewHandler) } -type localForwardHandler struct { +type forwardHandler struct { group *chain.NodeGroup chain *chain.Chain bypass bypass.Bypass @@ -31,23 +31,23 @@ func NewHandler(opts ...handler.Option) handler.Handler { opt(options) } - return &localForwardHandler{ + return &forwardHandler{ chain: options.Chain, bypass: options.Bypass, logger: options.Logger, } } -func (h *localForwardHandler) Init(md md.Metadata) (err error) { +func (h *forwardHandler) Init(md md.Metadata) (err error) { return h.parseMetadata(md) } // Forward implements handler.Forwarder. -func (h *localForwardHandler) Forward(group *chain.NodeGroup) { +func (h *forwardHandler) Forward(group *chain.NodeGroup) { h.group = group } -func (h *localForwardHandler) Handle(ctx context.Context, conn net.Conn) { +func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() start := time.Now() diff --git a/pkg/handler/forward/local/metadata.go b/pkg/handler/forward/metadata.go similarity index 77% rename from pkg/handler/forward/local/metadata.go rename to pkg/handler/forward/metadata.go index b9e2c3a..9bf7df0 100644 --- a/pkg/handler/forward/local/metadata.go +++ b/pkg/handler/forward/metadata.go @@ -1,4 +1,4 @@ -package local +package forward import ( "time" @@ -11,7 +11,7 @@ type metadata struct { retryCount int } -func (h *localForwardHandler) parseMetadata(md md.Metadata) (err error) { +func (h *forwardHandler) parseMetadata(md md.Metadata) (err error) { const ( readTimeout = "readTimeout" retryCount = "retry" diff --git a/pkg/listener/rtcp/listener.go b/pkg/listener/rtcp/listener.go index 827e2a5..61852cb 100644 --- a/pkg/listener/rtcp/listener.go +++ b/pkg/listener/rtcp/listener.go @@ -26,7 +26,7 @@ type rtcpListener struct { chain *chain.Chain md metadata ln net.Listener - connChan chan net.Conn + cqueue chan net.Conn session *mux.Session sessionMux sync.Mutex logger logger.Logger @@ -57,7 +57,7 @@ func (l *rtcpListener) Init(md md.Metadata) (err error) { } l.laddr = laddr - l.connChan = make(chan net.Conn, l.md.connQueueSize) + l.cqueue = make(chan net.Conn, l.md.backlog) if l.chain.IsEmpty() { l.ln, err = net.ListenTCP("tcp", laddr) @@ -93,7 +93,7 @@ func (l *rtcpListener) Accept() (conn net.Conn, err error) { } select { - case conn = <-l.connChan: + case conn = <-l.cqueue: case <-l.closed: err = net.ErrClosed } @@ -130,7 +130,7 @@ func (l *rtcpListener) listenLoop() { tempDelay = 0 select { - case l.connChan <- conn: + case l.cqueue <- conn: default: conn.Close() l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr().String()) @@ -169,36 +169,29 @@ func (l *rtcpListener) waitPeer(conn net.Conn) (net.Conn, error) { addr.ParseFrom(l.addr) req := gosocks5.NewRequest(gosocks5.CmdBind, &addr) if err := req.Write(conn); err != nil { - l.logger.Error(err) return nil, err } // first reply, bind status rep, err := gosocks5.ReadReply(conn) if err != nil { - l.logger.Error(err) return nil, err } l.logger.Debug(rep) if rep.Rep != gosocks5.Succeeded { - err = fmt.Errorf("bind on %s failed", l.addr) - l.logger.Error(err) - return nil, err + return nil, fmt.Errorf("bind on %s failed", l.addr) } l.logger.Debugf("bind on %s OK", rep.Addr) // second reply, peer connected rep, err = gosocks5.ReadReply(conn) if err != nil { - l.logger.Error(err) return nil, err } if rep.Rep != gosocks5.Succeeded { - err = fmt.Errorf("peer connect failed") - l.logger.Error(err) - return nil, err + return nil, fmt.Errorf("peer connect failed") } raddr, err := net.ResolveTCPAddr("tcp", rep.Addr.String()) diff --git a/pkg/listener/rtcp/metadata.go b/pkg/listener/rtcp/metadata.go index 8ee306e..c25e577 100644 --- a/pkg/listener/rtcp/metadata.go +++ b/pkg/listener/rtcp/metadata.go @@ -8,28 +8,28 @@ import ( const ( defaultKeepAlivePeriod = 180 * time.Second - defaultConnQueueSize = 128 + defaultBacklog = 128 ) type metadata struct { - enableMux bool - connQueueSize int - retryCount int + enableMux bool + backlog int + retryCount int } func (l *rtcpListener) parseMetadata(md md.Metadata) (err error) { const ( - enableMux = "mux" - connQueueSize = "connQueueSize" - retryCount = "retry" + enableMux = "mux" + backlog = "backlog" + retryCount = "retry" ) l.md.enableMux = md.GetBool(enableMux) l.md.retryCount = md.GetInt(retryCount) - l.md.connQueueSize = md.GetInt(connQueueSize) - if l.md.connQueueSize <= 0 { - l.md.connQueueSize = defaultConnQueueSize + l.md.backlog = md.GetInt(backlog) + if l.md.backlog <= 0 { + l.md.backlog = defaultBacklog } return } diff --git a/pkg/listener/rudp/listener.go b/pkg/listener/rudp/listener.go new file mode 100644 index 0000000..6105fc9 --- /dev/null +++ b/pkg/listener/rudp/listener.go @@ -0,0 +1,213 @@ +package rudp + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/common/bufpool" + "github.com/go-gost/gost/pkg/common/util/socks" + "github.com/go-gost/gost/pkg/common/util/udp" + "github.com/go-gost/gost/pkg/listener" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegisterListener("rudp", NewListener) +} + +type rudpListener struct { + addr string + laddr *net.UDPAddr + chain *chain.Chain + md metadata + cqueue chan net.Conn + closed chan struct{} + connPool *udp.ConnPool + logger logger.Logger +} + +func NewListener(opts ...listener.Option) listener.Listener { + options := &listener.Options{} + for _, opt := range opts { + opt(options) + } + return &rudpListener{ + addr: options.Addr, + chain: options.Chain, + closed: make(chan struct{}), + logger: options.Logger, + } +} + +func (l *rudpListener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { + return + } + + laddr, err := net.ResolveUDPAddr("udp", l.addr) + if err != nil { + return + } + + l.laddr = laddr + l.cqueue = make(chan net.Conn, l.md.backlog) + l.connPool = udp.NewConnPool(l.md.ttl).WithLogger(l.logger) + + go l.listenLoop() + + return +} + +func (l *rudpListener) Accept() (conn net.Conn, err error) { + select { + case conn = <-l.cqueue: + return + case <-l.closed: + return nil, listener.ErrClosed + } +} + +func (l *rudpListener) Close() error { + select { + case <-l.closed: + default: + close(l.closed) + l.connPool.Close() + } + + return nil +} + +func (l *rudpListener) Addr() net.Addr { + return l.laddr +} + +func (l *rudpListener) listenLoop() { + for { + conn, err := l.connect() + if err != nil { + l.logger.Error(err) + return + } + + func() { + defer conn.Close() + + for { + b := bufpool.Get(l.md.readBufferSize) + + n, raddr, err := conn.ReadFrom(b) + if err != nil { + return + } + + c := l.getConn(conn, raddr) + if c == nil { + bufpool.Put(b) + continue + } + + if err := c.WriteQueue(b[:n]); err != nil { + l.logger.Warn("data discarded: ", err) + } + } + }() + } +} + +func (l *rudpListener) connect() (conn net.PacketConn, err error) { + var tempDelay time.Duration + + for { + select { + case <-l.closed: + return nil, net.ErrClosed + default: + } + + conn, err = func() (net.PacketConn, error) { + if l.chain.IsEmpty() { + return net.ListenUDP("udp", l.laddr) + } + r := (&chain.Router{}). + WithChain(l.chain). + WithRetry(l.md.retryCount). + WithLogger(l.logger) + cc, err := r.Connect(context.Background()) + if err != nil { + return nil, err + } + + conn, err := l.initUDPTunnel(cc) + if err != nil { + cc.Close() + return nil, err + } + return conn, err + }() + if err == nil { + return + } + + if tempDelay == 0 { + tempDelay = 1000 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 6 * time.Second; tempDelay > max { + tempDelay = max + } + l.logger.Warnf("accept: %v, retrying in %v", err, tempDelay) + time.Sleep(tempDelay) + } +} + +func (l *rudpListener) initUDPTunnel(conn net.Conn) (net.PacketConn, error) { + socksAddr := gosocks5.Addr{} + socksAddr.ParseFrom(l.laddr.String()) + req := gosocks5.NewRequest(socks.CmdUDPTun, &socksAddr) + if err := req.Write(conn); err != nil { + return nil, err + } + l.logger.Debug(req) + + reply, err := gosocks5.ReadReply(conn) + if err != nil { + return nil, err + } + l.logger.Debug(reply) + + if reply.Rep != gosocks5.Succeeded { + return nil, fmt.Errorf("bind on %s failed", l.laddr) + } + + baddr, err := net.ResolveUDPAddr("udp", reply.Addr.String()) + if err != nil { + return nil, err + } + l.logger.Debugf("bind on %s OK", baddr) + + return socks.UDPTunClientConn(conn, nil), nil +} + +func (l *rudpListener) getConn(conn net.PacketConn, raddr net.Addr) *udp.Conn { + c, ok := l.connPool.Get(raddr.String()) + if !ok { + c = udp.NewConn(conn, l.laddr, raddr, l.md.readQueueSize) + select { + case l.cqueue <- c: + l.connPool.Set(raddr.String(), c) + default: + c.Close() + l.logger.Warnf("connection queue is full, client %s discarded", raddr.String()) + return nil + } + } + return c +} diff --git a/pkg/listener/rudp/metadata.go b/pkg/listener/rudp/metadata.go new file mode 100644 index 0000000..6986813 --- /dev/null +++ b/pkg/listener/rudp/metadata.go @@ -0,0 +1,55 @@ +package rudp + +import ( + "time" + + md "github.com/go-gost/gost/pkg/metadata" +) + +const ( + defaultTTL = 60 * time.Second + defaultReadBufferSize = 4096 + defaultReadQueueSize = 128 + defaultBacklog = 128 +) + +type metadata struct { + ttl time.Duration + + readBufferSize int + readQueueSize int + backlog int + retryCount int +} + +func (l *rudpListener) parseMetadata(md md.Metadata) (err error) { + const ( + ttl = "ttl" + readBufferSize = "readBufferSize" + readQueueSize = "readQueueSize" + backlog = "backlog" + retryCount = "retry" + ) + + l.md.ttl = md.GetDuration(ttl) + if l.md.ttl <= 0 { + l.md.ttl = defaultTTL + } + l.md.readBufferSize = md.GetInt(readBufferSize) + if l.md.readBufferSize <= 0 { + l.md.readBufferSize = defaultReadBufferSize + } + + l.md.readQueueSize = md.GetInt(readQueueSize) + if l.md.readQueueSize <= 0 { + l.md.readQueueSize = defaultReadQueueSize + } + + l.md.backlog = md.GetInt(backlog) + if l.md.backlog <= 0 { + l.md.backlog = defaultBacklog + } + + l.md.retryCount = md.GetInt(retryCount) + return +} diff --git a/pkg/listener/udp/conn.go b/pkg/listener/udp/conn.go deleted file mode 100644 index 5c7f1f6..0000000 --- a/pkg/listener/udp/conn.go +++ /dev/null @@ -1,190 +0,0 @@ -package udp - -import ( - "errors" - "net" - "sync" - "sync/atomic" - "time" - - "github.com/go-gost/gost/pkg/common/bufpool" - "github.com/go-gost/gost/pkg/logger" -) - -// conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn. -type conn struct { - net.PacketConn - remoteAddr net.Addr - rc chan []byte // data receive queue - idle int32 - closed chan struct{} - closeMutex sync.Mutex -} - -func newConn(c net.PacketConn, raddr net.Addr, queue int) *conn { - return &conn{ - PacketConn: c, - remoteAddr: raddr, - rc: make(chan []byte, queue), - closed: make(chan struct{}), - } -} - -func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { - select { - case bb := <-c.rc: - n = copy(b, bb) - c.SetIdle(false) - bufpool.Put(bb) - - case <-c.closed: - err = net.ErrClosed - return - } - - addr = c.remoteAddr - - return -} - -func (c *conn) Read(b []byte) (n int, err error) { - n, _, err = c.ReadFrom(b) - return -} - -func (c *conn) Write(b []byte) (n int, err error) { - return c.WriteTo(b, c.remoteAddr) -} - -func (c *conn) Close() error { - c.closeMutex.Lock() - defer c.closeMutex.Unlock() - - select { - case <-c.closed: - default: - close(c.closed) - } - return nil -} - -func (c *conn) RemoteAddr() net.Addr { - return c.remoteAddr -} - -func (c *conn) IsIdle() bool { - return atomic.LoadInt32(&c.idle) > 0 -} - -func (c *conn) SetIdle(idle bool) { - v := int32(0) - if idle { - v = 1 - } - atomic.StoreInt32(&c.idle, v) -} - -func (c *conn) Queue(b []byte) error { - select { - case c.rc <- b: - return nil - - case <-c.closed: - return net.ErrClosed - - default: - return errors.New("recv queue is full") - } -} - -type connPool struct { - m sync.Map - ttl time.Duration - closed chan struct{} - logger logger.Logger -} - -func newConnPool(ttl time.Duration) *connPool { - p := &connPool{ - ttl: ttl, - closed: make(chan struct{}), - } - go p.idleCheck() - return p -} - -func (p *connPool) WithLogger(logger logger.Logger) *connPool { - p.logger = logger - return p -} - -func (p *connPool) Get(key interface{}) (c *conn, ok bool) { - v, ok := p.m.Load(key) - if ok { - c, ok = v.(*conn) - } - return -} - -func (p *connPool) Set(key interface{}, c *conn) { - p.m.Store(key, c) -} - -func (p *connPool) Delete(key interface{}) { - p.m.Delete(key) -} - -func (p *connPool) Close() { - select { - case <-p.closed: - return - default: - } - - close(p.closed) - - p.m.Range(func(k, v interface{}) bool { - if c, ok := v.(*conn); ok && c != nil { - c.Close() - } - return true - }) -} - -func (p *connPool) idleCheck() { - ticker := time.NewTicker(p.ttl) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - size := 0 - idles := 0 - p.m.Range(func(key, value interface{}) bool { - c, ok := value.(*conn) - if !ok || c == nil { - p.Delete(key) - return true - } - size++ - - if c.IsIdle() { - idles++ - p.Delete(key) - c.Close() - return true - } - - c.SetIdle(true) - - return true - }) - - if idles > 0 { - p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles) - } - case <-p.closed: - return - } - } -} diff --git a/pkg/listener/udp/listener.go b/pkg/listener/udp/listener.go index 89427b6..075c1a1 100644 --- a/pkg/listener/udp/listener.go +++ b/pkg/listener/udp/listener.go @@ -4,6 +4,7 @@ import ( "net" "github.com/go-gost/gost/pkg/common/bufpool" + "github.com/go-gost/gost/pkg/common/util/udp" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -18,10 +19,10 @@ type udpListener struct { addr string md metadata conn net.PacketConn - connChan chan net.Conn + cqueue chan net.Conn errChan chan error closed chan struct{} - connPool *connPool + connPool *udp.ConnPool logger logger.Logger } @@ -53,8 +54,8 @@ func (l *udpListener) Init(md md.Metadata) (err error) { return } - l.connChan = make(chan net.Conn, l.md.connQueueSize) - l.connPool = newConnPool(l.md.ttl).WithLogger(l.logger) + l.cqueue = make(chan net.Conn, l.md.backlog) + l.connPool = udp.NewConnPool(l.md.ttl).WithLogger(l.logger) go l.listenLoop() @@ -64,7 +65,7 @@ func (l *udpListener) Init(md md.Metadata) (err error) { func (l *udpListener) Accept() (conn net.Conn, err error) { var ok bool select { - case conn = <-l.connChan: + case conn = <-l.cqueue: case err, ok = <-l.errChan: if !ok { err = listener.ErrClosed @@ -106,18 +107,18 @@ func (l *udpListener) listenLoop() { continue } - if err := c.Queue(b[:n]); err != nil { + if err := c.WriteQueue(b[:n]); err != nil { l.logger.Warn("data discarded: ", err) } } } -func (l *udpListener) getConn(addr net.Addr) *conn { +func (l *udpListener) getConn(addr net.Addr) *udp.Conn { c, ok := l.connPool.Get(addr.String()) if !ok { - c = newConn(l.conn, addr, l.md.readQueueSize) + c = udp.NewConn(l.conn, l.conn.LocalAddr(), addr, l.md.readQueueSize) select { - case l.connChan <- c: + case l.cqueue <- c: l.connPool.Set(addr.String(), c) default: c.Close() diff --git a/pkg/listener/udp/metadata.go b/pkg/listener/udp/metadata.go index 89cf7d8..b2dd89f 100644 --- a/pkg/listener/udp/metadata.go +++ b/pkg/listener/udp/metadata.go @@ -10,7 +10,7 @@ const ( defaultTTL = 60 * time.Second defaultReadBufferSize = 4096 defaultReadQueueSize = 128 - defaultConnQueueSize = 128 + defaultBacklog = 128 ) type metadata struct { @@ -18,7 +18,7 @@ type metadata struct { readBufferSize int readQueueSize int - connQueueSize int + backlog int } func (l *udpListener) parseMetadata(md md.Metadata) (err error) { @@ -26,7 +26,7 @@ func (l *udpListener) parseMetadata(md md.Metadata) (err error) { ttl = "ttl" readBufferSize = "readBufferSize" readQueueSize = "readQueueSize" - connQueueSize = "connQueueSize" + backlog = "backlog" ) l.md.ttl = md.GetDuration(ttl) @@ -43,9 +43,9 @@ func (l *udpListener) parseMetadata(md md.Metadata) (err error) { l.md.readQueueSize = defaultReadQueueSize } - l.md.connQueueSize = md.GetInt(connQueueSize) - if l.md.connQueueSize <= 0 { - l.md.connQueueSize = defaultConnQueueSize + l.md.backlog = md.GetInt(backlog) + if l.md.backlog <= 0 { + l.md.backlog = defaultBacklog } return