diff --git a/server/listener/ftcp/conn.go b/server/listener/ftcp/conn.go new file mode 100644 index 0000000..35c59ed --- /dev/null +++ b/server/listener/ftcp/conn.go @@ -0,0 +1,115 @@ +package ftcp + +import ( + "errors" + "net" + "sync" + "sync/atomic" + "time" +) + +// serverConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn. +type serverConn struct { + net.PacketConn + raddr net.Addr + rc chan []byte // data receive queue + fresh int32 + closed chan struct{} + closeMutex sync.Mutex + config *serverConnConfig +} + +type serverConnConfig struct { + ttl time.Duration + qsize int + onClose func() +} + +func newServerConn(conn net.PacketConn, raddr net.Addr, cfg *serverConnConfig) *serverConn { + if conn == nil || raddr == nil { + return nil + } + + if cfg == nil { + cfg = &serverConnConfig{} + } + c := &serverConn{ + PacketConn: conn, + raddr: raddr, + rc: make(chan []byte, cfg.qsize), + closed: make(chan struct{}), + config: cfg, + } + go c.ttlWait() + return c +} + +func (c *serverConn) send(b []byte) error { + select { + case c.rc <- b: + return nil + default: + return errors.New("queue is full") + } +} + +func (c *serverConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + select { + case bb := <-c.rc: + n = copy(b, bb) + atomic.StoreInt32(&c.fresh, 1) + case <-c.closed: + err = errors.New("read from closed connection") + return + } + + addr = c.raddr + + return +} + +func (c *serverConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.raddr) +} + +func (c *serverConn) Close() error { + c.closeMutex.Lock() + defer c.closeMutex.Unlock() + + select { + case <-c.closed: + return errors.New("connection is closed") + default: + if c.config.onClose != nil { + c.config.onClose() + } + close(c.closed) + } + return nil +} + +func (c *serverConn) RemoteAddr() net.Addr { + return c.raddr +} + +func (c *serverConn) ttlWait() { + ticker := time.NewTicker(c.config.ttl) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if !atomic.CompareAndSwapInt32(&c.fresh, 1, 0) { + c.Close() + return + } + case <-c.closed: + return + } + } +} diff --git a/server/listener/ftcp/listener.go b/server/listener/ftcp/listener.go new file mode 100644 index 0000000..9402603 --- /dev/null +++ b/server/listener/ftcp/listener.go @@ -0,0 +1,162 @@ +package ftcp + +import ( + "errors" + "net" + "sync" + "sync/atomic" + + "github.com/go-gost/gost/logger" + "github.com/go-gost/gost/server/listener" + "github.com/xtaci/tcpraw" +) + +var ( + _ listener.Listener = (*Listener)(nil) +) + +type Listener struct { + md metadata + conn net.PacketConn + connChan chan net.Conn + errChan chan error + connPool connPool + logger logger.Logger +} + +func NewListener(opts ...listener.Option) *Listener { + options := &listener.Options{} + for _, opt := range opts { + opt(options) + } + return &Listener{ + logger: options.Logger, + } +} + +func (l *Listener) Init(md listener.Metadata) (err error) { + l.md, err = l.parseMetadata(md) + if err != nil { + return + } + + l.conn, err = tcpraw.Listen("tcp", addr) + if err != nil { + return + } + + l.connChan = make(chan net.Conn, l.md.connQueueSize) + l.errChan = make(chan error, 1) + + go l.listenLoop() + + return +} + +func (l *Listener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = listener.ErrClosed + } + } + return +} + +func (l *Listener) Close() error { + err := l.conn.Close() + l.connPool.Range(func(k interface{}, v *serverConn) bool { + v.Close() + return true + }) + return err +} + +func (l *Listener) Addr() net.Addr { + return l.conn.LocalAddr() +} + +func (l *Listener) listenLoop() { + for { + b := make([]byte, l.md.readBufferSize) + + n, raddr, err := l.conn.ReadFrom(b) + if err != nil { + l.logger.Error("accept:", err) + l.errChan <- err + close(l.errChan) + return + } + + conn, ok := l.connPool.Get(raddr.String()) + if !ok { + conn = newServerConn(l.conn, raddr, + &serverConnConfig{ + ttl: l.md.ttl, + qsize: l.md.readQueueSize, + onClose: func() { + l.connPool.Delete(raddr.String()) + }, + }) + + select { + case l.connChan <- conn: + l.connPool.Set(raddr.String(), conn) + default: + conn.Close() + l.logger.Error("connection queue is full") + } + } + + if err := conn.send(b[:n]); err != nil { + l.logger.Warn("data discarded:", err) + } + l.logger.Debug("recv", n) + } +} + +func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { + if val, ok := md[addr]; ok { + m.addr = val + } else { + err = errors.New("missing address") + return + } + + return +} + +type connPool struct { + size int64 + m sync.Map +} + +func (p *connPool) Get(key interface{}) (conn *serverConn, ok bool) { + v, ok := p.m.Load(key) + if ok { + conn, ok = v.(*serverConn) + } + return +} + +func (p *connPool) Set(key interface{}, conn *serverConn) { + p.m.Store(key, conn) + atomic.AddInt64(&p.size, 1) +} + +func (p *connPool) Delete(key interface{}) { + p.m.Delete(key) + atomic.AddInt64(&p.size, -1) +} + +func (p *connPool) Range(f func(key interface{}, value *serverConn) bool) { + p.m.Range(func(k, v interface{}) bool { + return f(k, v.(*serverConn)) + }) +} + +func (p *connPool) Size() int64 { + return atomic.LoadInt64(&p.size) +} diff --git a/server/listener/ftcp/metadata.go b/server/listener/ftcp/metadata.go new file mode 100644 index 0000000..38cedc9 --- /dev/null +++ b/server/listener/ftcp/metadata.go @@ -0,0 +1,23 @@ +package ftcp + +import "time" + +const ( + defaultTTL = 60 * time.Second + defaultReadBufferSize = 1024 + defaultReadQueueSize = 128 + defaultConnQueueSize = 128 +) + +const ( + addr = "addr" +) + +type metadata struct { + addr string + ttl time.Duration + + readBufferSize int + readQueueSize int + connQueueSize int +} diff --git a/server/listener/udp/listener.go b/server/listener/udp/listener.go index 9d18584..f328489 100644 --- a/server/listener/udp/listener.go +++ b/server/listener/udp/listener.go @@ -72,7 +72,12 @@ func (l *Listener) Accept() (conn net.Conn, err error) { } func (l *Listener) Close() error { - return l.conn.Close() + err := l.conn.Close() + l.connPool.Range(func(k interface{}, v *serverConn) bool { + v.Close() + return true + }) + return err } func (l *Listener) Addr() net.Addr {