diff --git a/cmd/gost/main.go b/cmd/gost/main.go index ea9a84f..b092a25 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -13,7 +13,7 @@ import ( ) var ( - log = logger.NewLogger() + log = logger.Default() cfgFile string outputCfgFile string diff --git a/pkg/chain/route.go b/pkg/chain/route.go index 70d8951..6460525 100644 --- a/pkg/chain/route.go +++ b/pkg/chain/route.go @@ -6,7 +6,9 @@ import ( "fmt" "net" + "github.com/go-gost/gost/pkg/common/util/udp" "github.com/go-gost/gost/pkg/connector" + "github.com/go-gost/gost/pkg/logger" ) var ( @@ -96,9 +98,9 @@ func (r *Route) dialDirect(ctx context.Context, network, address string) (net.Co return d.DialContext(ctx, network, address) } -func (r *Route) Bind(ctx context.Context, network, address string) (connector.Accepter, error) { +func (r *Route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { if r.IsEmpty() { - return r.bindLocal(ctx, network, address) + return r.bindLocal(ctx, network, address, opts...) } conn, err := r.Connect(ctx) @@ -106,29 +108,13 @@ func (r *Route) Bind(ctx context.Context, network, address string) (connector.Ac return nil, err } - accepter, err := r.Last().transport.Bind(ctx, conn, network, address) + ln, err := r.Last().transport.Bind(ctx, conn, network, address, opts...) if err != nil { conn.Close() return nil, err } - return accepter, nil -} - -func (r *Route) bindLocal(ctx context.Context, network, address string) (connector.Accepter, error) { - switch network { - case "tcp", "tcp4", "tcp6": - addr, err := net.ResolveTCPAddr(network, address) - if err != nil { - return nil, err - } - return net.ListenTCP(network, addr) - case "udp", "udp4", "udp6": - return nil, nil - default: - err := fmt.Errorf("network %s unsupported", network) - return nil, err - } + return ln, nil } func (r *Route) IsEmpty() bool { @@ -155,3 +141,39 @@ func (r *Route) Path() (path []*Node) { } return } + +func (r *Route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { + options := connector.BindOptions{} + for _, opt := range opts { + opt(&options) + } + + switch network { + case "tcp", "tcp4", "tcp6": + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + return net.ListenTCP(network, addr) + case "udp", "udp4", "udp6": + addr, err := net.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + conn, err := net.ListenUDP(network, addr) + if err != nil { + return nil, err + } + logger := logger.Default().WithFields(map[string]interface{}{ + "network": network, + "address": address, + }) + ln := udp.NewListener(conn, addr, + options.Backlog, options.UDPDataQueueSize, options.UDPDataBufferSize, + options.UDPConnTTL, logger) + return ln, err + default: + err := fmt.Errorf("network %s unsupported", network) + return nil, err + } +} diff --git a/pkg/chain/router.go b/pkg/chain/router.go index acbff74..3401df2 100644 --- a/pkg/chain/router.go +++ b/pkg/chain/router.go @@ -60,35 +60,6 @@ func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Co return } -func (r *Router) Bind(ctx context.Context, network, address string) (accepter connector.Accepter, err error) { - count := r.retries + 1 - if count <= 0 { - count = 1 - } - r.logger.Debugf("bind: %s/%s", address, network) - - for i := 0; i < count; i++ { - route := r.chain.GetRouteFor(network, address) - - if r.logger.IsLevelEnabled(logger.DebugLevel) { - buf := bytes.Buffer{} - for _, node := range route.Path() { - fmt.Fprintf(&buf, "%s@%s > ", node.Name(), node.Addr()) - } - fmt.Fprintf(&buf, "%s", address) - r.logger.Debugf("route(retry=%d): %s", i, buf.String()) - } - - accepter, err = route.Bind(ctx, network, address) - if err == nil { - break - } - r.logger.Errorf("route(retry=%d): %s", i, err) - } - - return -} - func (r *Router) Connect(ctx context.Context) (conn net.Conn, err error) { count := r.retries + 1 if count <= 0 { @@ -115,3 +86,32 @@ func (r *Router) Connect(ctx context.Context) (conn net.Conn, err error) { return } + +func (r *Router) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (ln net.Listener, err error) { + count := r.retries + 1 + if count <= 0 { + count = 1 + } + r.logger.Debugf("bind: %s/%s", address, network) + + for i := 0; i < count; i++ { + route := r.chain.GetRouteFor(network, address) + + if r.logger.IsLevelEnabled(logger.DebugLevel) { + buf := bytes.Buffer{} + for _, node := range route.Path() { + fmt.Fprintf(&buf, "%s@%s > ", node.Name(), node.Addr()) + } + fmt.Fprintf(&buf, "%s", address) + r.logger.Debugf("route(retry=%d): %s", i, buf.String()) + } + + ln, err = route.Bind(ctx, network, address, opts...) + if err == nil { + break + } + r.logger.Errorf("route(retry=%d): %s", i, err) + } + + return +} diff --git a/pkg/chain/transport.go b/pkg/chain/transport.go index 14b6920..d74c0af 100644 --- a/pkg/chain/transport.go +++ b/pkg/chain/transport.go @@ -62,9 +62,9 @@ func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, addres return tr.connector.Connect(ctx, conn, network, address) } -func (tr *Transport) Bind(ctx context.Context, conn net.Conn, network, address string) (connector.Accepter, error) { +func (tr *Transport) Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error) { if binder, ok := tr.connector.(connector.Binder); ok { - return binder.Bind(ctx, conn, network, address, connector.MuxBindOption(true)) + return binder.Bind(ctx, conn, network, address, opts...) } return nil, connector.ErrBindUnsupported } diff --git a/pkg/common/util/udp/listener.go b/pkg/common/util/udp/listener.go new file mode 100644 index 0000000..022e32d --- /dev/null +++ b/pkg/common/util/udp/listener.go @@ -0,0 +1,111 @@ +package udp + +import ( + "net" + "sync" + "time" + + "github.com/go-gost/gost/pkg/common/bufpool" + "github.com/go-gost/gost/pkg/logger" +) + +type listener struct { + addr net.Addr + conn net.PacketConn + cqueue chan net.Conn + readQueueSize int + readBufferSize int + connPool *ConnPool + mux sync.Mutex + closed chan struct{} + logger logger.Logger +} + +func NewListener(conn net.PacketConn, addr net.Addr, backlog, dataQueueSize, dataBufferSize int, ttl time.Duration, logger logger.Logger) net.Listener { + ln := &listener{ + conn: conn, + addr: addr, + cqueue: make(chan net.Conn, backlog), + connPool: NewConnPool(ttl).WithLogger(logger), + readQueueSize: dataQueueSize, + readBufferSize: dataBufferSize, + closed: make(chan struct{}), + logger: logger, + } + go ln.listenLoop() + + return ln +} + +func (ln *listener) Accept() (conn net.Conn, err error) { + select { + case conn = <-ln.cqueue: + return + case <-ln.closed: + return nil, net.ErrClosed + } +} + +func (ln *listener) listenLoop() { + for { + select { + case <-ln.closed: + return + default: + } + + b := bufpool.Get(ln.readBufferSize) + + n, raddr, err := ln.conn.ReadFrom(b) + if err != nil { + return + } + + c := ln.getConn(raddr) + if c == nil { + bufpool.Put(b) + continue + } + + if err := c.WriteQueue(b[:n]); err != nil { + ln.logger.Warn("data discarded: ", err) + } + } +} + +func (ln *listener) Addr() net.Addr { + return ln.addr +} + +func (ln *listener) Close() error { + select { + case <-ln.closed: + default: + close(ln.closed) + ln.conn.Close() + ln.connPool.Close() + } + + return nil +} + +func (ln *listener) getConn(raddr net.Addr) *Conn { + ln.mux.Lock() + defer ln.mux.Unlock() + + c, ok := ln.connPool.Get(raddr.String()) + if ok { + return c + } + + c = NewConn(ln.conn, ln.addr, raddr, ln.readQueueSize) + select { + case ln.cqueue <- c: + ln.connPool.Set(raddr.String(), c) + return c + default: + c.Close() + ln.logger.Warnf("connection queue is full, client %s discarded", raddr) + return nil + } +} diff --git a/pkg/connector/bind.go b/pkg/connector/binder.go similarity index 93% rename from pkg/connector/bind.go rename to pkg/connector/binder.go index 4888b18..287ea41 100644 --- a/pkg/connector/bind.go +++ b/pkg/connector/binder.go @@ -17,7 +17,7 @@ type Accepter interface { } type Binder interface { - Bind(ctx context.Context, conn net.Conn, network, address string, opts ...BindOption) (Accepter, error) + Bind(ctx context.Context, conn net.Conn, network, address string, opts ...BindOption) (net.Listener, error) } type AcceptError struct { diff --git a/pkg/connector/option.go b/pkg/connector/option.go index 88f634e..1832b5a 100644 --- a/pkg/connector/option.go +++ b/pkg/connector/option.go @@ -1,6 +1,8 @@ package connector import ( + "time" + "github.com/go-gost/gost/pkg/logger" ) @@ -22,7 +24,11 @@ type ConnectOptions struct { type ConnectOption func(opts *ConnectOptions) type BindOptions struct { - Mux bool + Mux bool + Backlog int + UDPDataQueueSize int + UDPDataBufferSize int + UDPConnTTL time.Duration } type BindOption func(opts *BindOptions) @@ -32,3 +38,27 @@ func MuxBindOption(mux bool) BindOption { opts.Mux = mux } } + +func BacklogBindOption(backlog int) BindOption { + return func(opts *BindOptions) { + opts.Backlog = backlog + } +} + +func UDPDataQueueSizeBindOption(size int) BindOption { + return func(opts *BindOptions) { + opts.UDPDataQueueSize = size + } +} + +func UDPDataBufferSizeBindOption(size int) BindOption { + return func(opts *BindOptions) { + opts.UDPDataBufferSize = size + } +} + +func UDPConnTTLBindOption(ttl time.Duration) BindOption { + return func(opts *BindOptions) { + opts.UDPConnTTL = ttl + } +} diff --git a/pkg/connector/socks/v5/accepter.go b/pkg/connector/socks/v5/accepter.go deleted file mode 100644 index 7a23de5..0000000 --- a/pkg/connector/socks/v5/accepter.go +++ /dev/null @@ -1,191 +0,0 @@ -package v5 - -import ( - "fmt" - "io" - "net" - - "github.com/go-gost/gosocks5" - "github.com/go-gost/gost/pkg/common/bufpool" - "github.com/go-gost/gost/pkg/common/util/mux" - "github.com/go-gost/gost/pkg/common/util/udp" - "github.com/go-gost/gost/pkg/logger" -) - -type tcpAccepter struct { - addr net.Addr - conn net.Conn - logger logger.Logger - done chan struct{} -} - -func (p *tcpAccepter) Accept() (net.Conn, error) { - select { - case <-p.done: - return nil, io.EOF - default: - close(p.done) - } - - // second reply, peer connected - rep, err := gosocks5.ReadReply(p.conn) - if err != nil { - return nil, err - } - p.logger.Debug(rep) - - if rep.Rep != gosocks5.Succeeded { - return nil, fmt.Errorf("peer connect failed") - } - - raddr, err := net.ResolveTCPAddr("tcp", rep.Addr.String()) - if err != nil { - return nil, err - } - - return &bindConn{ - Conn: p.conn, - localAddr: p.addr, - remoteAddr: raddr, - }, nil -} - -func (p *tcpAccepter) Addr() net.Addr { - return p.addr -} - -func (p *tcpAccepter) Close() error { - return p.conn.Close() -} - -type tcpMuxAccepter struct { - addr net.Addr - session *mux.Session - logger logger.Logger -} - -func (p *tcpMuxAccepter) Accept() (net.Conn, error) { - cc, err := p.session.Accept() - if err != nil { - return nil, err - } - - conn, err := p.getPeerConn(cc) - if err != nil { - cc.Close() - return nil, err - } - - return conn, nil -} - -func (p *tcpMuxAccepter) getPeerConn(conn net.Conn) (net.Conn, error) { - // second reply, peer connected - rep, err := gosocks5.ReadReply(conn) - if err != nil { - return nil, err - } - p.logger.Debug(rep) - - if rep.Rep != gosocks5.Succeeded { - err = fmt.Errorf("peer connect failed") - return nil, err - } - - raddr, err := net.ResolveTCPAddr("tcp", rep.Addr.String()) - if err != nil { - return nil, err - } - - return &bindConn{ - Conn: conn, - localAddr: p.addr, - remoteAddr: raddr, - }, nil -} - -func (p *tcpMuxAccepter) Addr() net.Addr { - return p.addr -} - -func (p *tcpMuxAccepter) Close() error { - return p.session.Close() -} - -type udpAccepter struct { - addr net.Addr - conn net.PacketConn - cqueue chan net.Conn - connPool *udp.ConnPool - readQueueSize int - readBufferSize int - closed chan struct{} - logger logger.Logger -} - -func (p *udpAccepter) Accept() (conn net.Conn, err error) { - select { - case conn = <-p.cqueue: - return - case <-p.closed: - return nil, net.ErrClosed - } -} - -func (p *udpAccepter) acceptLoop() { - for { - select { - case <-p.closed: - return - default: - } - - b := bufpool.Get(p.readBufferSize) - - n, raddr, err := p.conn.ReadFrom(b) - if err != nil { - return - } - - c := p.getConn(raddr) - if c == nil { - bufpool.Put(b) - continue - } - - if err := c.WriteQueue(b[:n]); err != nil { - p.logger.Warn("data discarded: ", err) - } - } -} - -func (p *udpAccepter) Addr() net.Addr { - return p.addr -} - -func (p *udpAccepter) Close() error { - select { - case <-p.closed: - default: - close(p.closed) - p.connPool.Close() - } - - return nil -} - -func (p *udpAccepter) getConn(raddr net.Addr) *udp.Conn { - c, ok := p.connPool.Get(raddr.String()) - if !ok { - c = udp.NewConn(p.conn, p.addr, raddr, p.readQueueSize) - select { - case p.cqueue <- c: - p.connPool.Set(raddr.String(), c) - default: - c.Close() - p.logger.Warnf("connection queue is full, client %s discarded", raddr) - return nil - } - } - return c -} diff --git a/pkg/connector/socks/v5/bind.go b/pkg/connector/socks/v5/bind.go index 6d945f8..8d62c44 100644 --- a/pkg/connector/socks/v5/bind.go +++ b/pkg/connector/socks/v5/bind.go @@ -13,7 +13,7 @@ import ( ) // Bind implements connector.Binder. -func (c *socks5Connector) Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (connector.Accepter, error) { +func (c *socks5Connector) Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error) { c.logger = c.logger.WithFields(map[string]interface{}{ "network": network, "address": address, @@ -32,7 +32,7 @@ func (c *socks5Connector) Bind(ctx context.Context, conn net.Conn, network, addr } return c.bindTCP(ctx, conn, network, address) case "udp", "udp4", "udp6": - return c.bindUDP(ctx, conn, network, address) + return c.bindUDP(ctx, conn, network, address, &options) default: err := fmt.Errorf("network %s is unsupported", network) c.logger.Error(err) @@ -40,21 +40,20 @@ func (c *socks5Connector) Bind(ctx context.Context, conn net.Conn, network, addr } } -func (c *socks5Connector) bindTCP(ctx context.Context, conn net.Conn, network, address string) (connector.Accepter, error) { +func (c *socks5Connector) bindTCP(ctx context.Context, conn net.Conn, network, address string) (net.Listener, error) { laddr, err := c.bind(conn, gosocks5.CmdBind, network, address) if err != nil { return nil, err } - return &tcpAccepter{ + return &tcpListener{ addr: laddr, conn: conn, logger: c.logger, - done: make(chan struct{}), }, nil } -func (c *socks5Connector) muxBindTCP(ctx context.Context, conn net.Conn, network, address string) (connector.Accepter, error) { +func (c *socks5Connector) muxBindTCP(ctx context.Context, conn net.Conn, network, address string) (net.Listener, error) { laddr, err := c.bind(conn, socks.CmdMuxBind, network, address) if err != nil { return nil, err @@ -65,42 +64,33 @@ func (c *socks5Connector) muxBindTCP(ctx context.Context, conn net.Conn, network return nil, err } - return &tcpMuxAccepter{ + return &tcpMuxListener{ addr: laddr, session: session, logger: c.logger, }, nil } -func (c *socks5Connector) bindUDP(ctx context.Context, conn net.Conn, network, address string) (connector.Accepter, error) { +func (c *socks5Connector) bindUDP(ctx context.Context, conn net.Conn, network, address string, opts *connector.BindOptions) (net.Listener, error) { laddr, err := c.bind(conn, socks.CmdUDPTun, network, address) if err != nil { return nil, err } - accepter := &udpAccepter{ - addr: laddr, - conn: socks.UDPTunClientPacketConn(conn), - cqueue: make(chan net.Conn, c.md.backlog), - connPool: udp.NewConnPool(c.md.ttl).WithLogger(c.logger), - readQueueSize: c.md.readQueueSize, - readBufferSize: c.md.readBufferSize, - closed: make(chan struct{}), - logger: c.logger, - } - go accepter.acceptLoop() + ln := udp.NewListener( + socks.UDPTunClientPacketConn(conn), + laddr, + opts.Backlog, + opts.UDPDataQueueSize, opts.UDPDataBufferSize, + opts.UDPConnTTL, + c.logger) - return accepter, nil + return ln, nil } func (l *socks5Connector) bind(conn net.Conn, cmd uint8, network, address string) (net.Addr, error) { - laddr, err := net.ResolveTCPAddr(network, address) - if err != nil { - return nil, err - } - addr := gosocks5.Addr{} - addr.ParseFrom(laddr.String()) + addr.ParseFrom(address) req := gosocks5.NewRequest(cmd, &addr) if err := req.Write(conn); err != nil { return nil, err @@ -116,7 +106,7 @@ func (l *socks5Connector) bind(conn net.Conn, cmd uint8, network, address string l.logger.Debug(reply) if reply.Rep != gosocks5.Succeeded { - return nil, fmt.Errorf("bind on %s/%s failed", laddr, laddr.Network()) + return nil, fmt.Errorf("bind on %s/%s failed", address, network) } var baddr net.Addr @@ -133,5 +123,5 @@ func (l *socks5Connector) bind(conn net.Conn, cmd uint8, network, address string } l.logger.Debugf("bind on %s/%s OK", baddr, baddr.Network()) - return laddr, nil + return baddr, nil } diff --git a/pkg/connector/socks/v5/listener.go b/pkg/connector/socks/v5/listener.go new file mode 100644 index 0000000..8519875 --- /dev/null +++ b/pkg/connector/socks/v5/listener.go @@ -0,0 +1,102 @@ +package v5 + +import ( + "fmt" + "net" + + "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/common/util/mux" + "github.com/go-gost/gost/pkg/logger" +) + +type tcpListener struct { + addr net.Addr + conn net.Conn + logger logger.Logger +} + +func (p *tcpListener) Accept() (net.Conn, error) { + // second reply, peer connected + rep, err := gosocks5.ReadReply(p.conn) + if err != nil { + return nil, err + } + p.logger.Debug(rep) + + if rep.Rep != gosocks5.Succeeded { + return nil, fmt.Errorf("peer connect failed") + } + + raddr, err := net.ResolveTCPAddr("tcp", rep.Addr.String()) + if err != nil { + return nil, err + } + + return &bindConn{ + Conn: p.conn, + localAddr: p.addr, + remoteAddr: raddr, + }, nil +} + +func (p *tcpListener) Addr() net.Addr { + return p.addr +} + +func (p *tcpListener) Close() error { + return p.conn.Close() +} + +type tcpMuxListener struct { + addr net.Addr + session *mux.Session + logger logger.Logger +} + +func (p *tcpMuxListener) Accept() (net.Conn, error) { + cc, err := p.session.Accept() + if err != nil { + return nil, err + } + + conn, err := p.getPeerConn(cc) + if err != nil { + cc.Close() + return nil, err + } + + return conn, nil +} + +func (p *tcpMuxListener) getPeerConn(conn net.Conn) (net.Conn, error) { + // second reply, peer connected + rep, err := gosocks5.ReadReply(conn) + if err != nil { + return nil, err + } + p.logger.Debug(rep) + + if rep.Rep != gosocks5.Succeeded { + err = fmt.Errorf("peer connect failed") + return nil, err + } + + raddr, err := net.ResolveTCPAddr("tcp", rep.Addr.String()) + if err != nil { + return nil, err + } + + return &bindConn{ + Conn: conn, + localAddr: p.addr, + remoteAddr: raddr, + }, nil +} + +func (p *tcpMuxListener) Addr() net.Addr { + return p.addr +} + +func (p *tcpMuxListener) Close() error { + return p.session.Close() +} diff --git a/pkg/connector/socks/v5/metadata.go b/pkg/connector/socks/v5/metadata.go index e2919b3..336130f 100644 --- a/pkg/connector/socks/v5/metadata.go +++ b/pkg/connector/socks/v5/metadata.go @@ -9,23 +9,11 @@ import ( md "github.com/go-gost/gost/pkg/metadata" ) -const ( - defaultTTL = 60 * time.Second - defaultReadBufferSize = 4096 - defaultReadQueueSize = 128 - defaultBacklog = 128 -) - type metadata struct { connectTimeout time.Duration User *url.Userinfo tlsConfig *tls.Config noTLS bool - - ttl time.Duration - readBufferSize int - readQueueSize int - backlog int } func (c *socks5Connector) parseMetadata(md md.Metadata) (err error) { @@ -33,11 +21,6 @@ func (c *socks5Connector) parseMetadata(md md.Metadata) (err error) { connectTimeout = "timeout" auth = "auth" noTLS = "notls" - - ttl = "ttl" - readBufferSize = "readBufferSize" - readQueueSize = "readQueueSize" - backlog = "backlog" ) if v := md.GetString(auth); v != "" { @@ -52,23 +35,5 @@ func (c *socks5Connector) parseMetadata(md md.Metadata) (err error) { c.md.connectTimeout = md.GetDuration(connectTimeout) c.md.noTLS = md.GetBool(noTLS) - c.md.ttl = md.GetDuration(ttl) - if c.md.ttl <= 0 { - c.md.ttl = defaultTTL - } - c.md.readBufferSize = md.GetInt(readBufferSize) - if c.md.readBufferSize <= 0 { - c.md.readBufferSize = defaultReadBufferSize - } - - c.md.readQueueSize = md.GetInt(readQueueSize) - if c.md.readQueueSize <= 0 { - c.md.readQueueSize = defaultReadQueueSize - } - - c.md.backlog = md.GetInt(backlog) - if c.md.backlog <= 0 { - c.md.backlog = defaultBacklog - } return } diff --git a/pkg/listener/rtcp/listener.go b/pkg/listener/rtcp/listener.go index ec31849..68e87ba 100644 --- a/pkg/listener/rtcp/listener.go +++ b/pkg/listener/rtcp/listener.go @@ -17,13 +17,13 @@ func init() { } type rtcpListener struct { - addr string - laddr net.Addr - chain *chain.Chain - accepter connector.Accepter - md metadata - logger logger.Logger - closed chan struct{} + addr string + laddr net.Addr + chain *chain.Chain + ln net.Listener + md metadata + logger logger.Logger + closed chan struct{} } func NewListener(opts ...listener.Option) listener.Listener { @@ -58,6 +58,34 @@ func (l *rtcpListener) Init(md md.Metadata) (err error) { return } +func (l *rtcpListener) Accept() (conn net.Conn, err error) { + select { + case <-l.closed: + return nil, net.ErrClosed + default: + } + + if l.ln == nil { + r := (&chain.Router{}). + WithChain(l.chain). + WithRetry(l.md.retryCount). + WithLogger(l.logger) + l.ln, err = r.Bind(context.Background(), "tcp", l.laddr.String(), + connector.MuxBindOption(true), + ) + if err != nil { + return nil, connector.NewAcceptError(err) + } + } + conn, err = l.ln.Accept() + if err != nil { + l.ln.Close() + l.ln = nil + return nil, connector.NewAcceptError(err) + } + return +} + func (l *rtcpListener) Addr() net.Addr { return l.laddr } @@ -67,27 +95,11 @@ func (l *rtcpListener) Close() error { case <-l.closed: default: close(l.closed) + if l.ln != nil { + l.ln.Close() + l.ln = nil + } } return nil } - -func (l *rtcpListener) Accept() (conn net.Conn, err error) { - if l.accepter == nil { - r := (&chain.Router{}). - WithChain(l.chain). - WithRetry(l.md.retryCount). - WithLogger(l.logger) - l.accepter, err = r.Bind(context.Background(), "tcp", l.laddr.String()) - if err != nil { - return nil, connector.NewAcceptError(err) - } - } - conn, err = l.accepter.Accept() - if err != nil { - l.accepter.Close() - l.accepter = nil - return nil, connector.NewAcceptError(err) - } - return -} diff --git a/pkg/listener/rudp/listener.go b/pkg/listener/rudp/listener.go index 56032e1..931eb93 100644 --- a/pkg/listener/rudp/listener.go +++ b/pkg/listener/rudp/listener.go @@ -2,15 +2,10 @@ package rudp import ( "context" - "fmt" "net" - "time" - "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/chain" - "github.com/go-gost/gost/pkg/common/bufpool" - "github.com/go-gost/gost/pkg/common/util/socks" - "github.com/go-gost/gost/pkg/common/util/udp" + "github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -22,14 +17,13 @@ func init() { } type rudpListener struct { - addr string - laddr *net.UDPAddr - chain *chain.Chain - md metadata - cqueue chan net.Conn - closed chan struct{} - connPool *udp.ConnPool - logger logger.Logger + addr string + laddr *net.UDPAddr + chain *chain.Chain + ln net.Listener + md metadata + logger logger.Logger + closed chan struct{} } func NewListener(opts ...listener.Option) listener.Listener { @@ -60,21 +54,43 @@ func (l *rudpListener) Init(md md.Metadata) (err error) { } l.laddr = laddr - l.cqueue = make(chan net.Conn, l.md.backlog) - l.connPool = udp.NewConnPool(l.md.ttl).WithLogger(l.logger) - - go l.listenLoop() return } func (l *rudpListener) Accept() (conn net.Conn, err error) { select { - case conn = <-l.cqueue: - return case <-l.closed: - return nil, listener.ErrClosed + return nil, net.ErrClosed + default: } + + if l.ln == nil { + r := (&chain.Router{}). + WithChain(l.chain). + WithRetry(l.md.retryCount). + WithLogger(l.logger) + l.ln, err = r.Bind(context.Background(), "udp", l.laddr.String(), + connector.BacklogBindOption(l.md.backlog), + connector.UDPConnTTLBindOption(l.md.ttl), + connector.UDPDataBufferSizeBindOption(l.md.readBufferSize), + connector.UDPDataQueueSizeBindOption(l.md.readQueueSize), + ) + if err != nil { + return nil, connector.NewAcceptError(err) + } + } + conn, err = l.ln.Accept() + if err != nil { + l.ln.Close() + l.ln = nil + return nil, connector.NewAcceptError(err) + } + return +} + +func (l *rudpListener) Addr() net.Addr { + return l.laddr } func (l *rudpListener) Close() error { @@ -82,136 +98,11 @@ func (l *rudpListener) Close() error { case <-l.closed: default: close(l.closed) - l.connPool.Close() + if l.ln != nil { + l.ln.Close() + l.ln = nil + } } return nil } - -func (l *rudpListener) Addr() net.Addr { - return l.laddr -} - -func (l *rudpListener) listenLoop() { - for { - conn, err := l.connect() - if err != nil { - l.logger.Error(err) - return - } - - func() { - defer conn.Close() - - for { - b := bufpool.Get(l.md.readBufferSize) - - n, raddr, err := conn.ReadFrom(b) - if err != nil { - return - } - - c := l.getConn(conn, raddr) - if c == nil { - bufpool.Put(b) - continue - } - - if err := c.WriteQueue(b[:n]); err != nil { - l.logger.Warn("data discarded: ", err) - } - } - }() - } -} - -func (l *rudpListener) connect() (conn net.PacketConn, err error) { - var tempDelay time.Duration - - for { - select { - case <-l.closed: - return nil, net.ErrClosed - default: - } - - conn, err = func() (net.PacketConn, error) { - if l.chain.IsEmpty() { - return net.ListenUDP("udp", l.laddr) - } - r := (&chain.Router{}). - WithChain(l.chain). - WithRetry(l.md.retryCount). - WithLogger(l.logger) - cc, err := r.Connect(context.Background()) - if err != nil { - return nil, err - } - - conn, err := l.initUDPTunnel(cc) - if err != nil { - cc.Close() - return nil, err - } - return conn, err - }() - if err == nil { - return - } - - if tempDelay == 0 { - tempDelay = 1000 * time.Millisecond - } else { - tempDelay *= 2 - } - if max := 6 * time.Second; tempDelay > max { - tempDelay = max - } - l.logger.Warnf("accept: %v, retrying in %v", err, tempDelay) - time.Sleep(tempDelay) - } -} - -func (l *rudpListener) initUDPTunnel(conn net.Conn) (net.PacketConn, error) { - socksAddr := gosocks5.Addr{} - socksAddr.ParseFrom(l.laddr.String()) - req := gosocks5.NewRequest(socks.CmdUDPTun, &socksAddr) - if err := req.Write(conn); err != nil { - return nil, err - } - l.logger.Debug(req) - - reply, err := gosocks5.ReadReply(conn) - if err != nil { - return nil, err - } - l.logger.Debug(reply) - - if reply.Rep != gosocks5.Succeeded { - return nil, fmt.Errorf("bind on %s failed", l.laddr) - } - - baddr, err := net.ResolveUDPAddr("udp", reply.Addr.String()) - if err != nil { - return nil, err - } - l.logger.Debugf("bind on %s OK", baddr) - - return socks.UDPTunClientPacketConn(conn), nil -} - -func (l *rudpListener) getConn(conn net.PacketConn, raddr net.Addr) *udp.Conn { - c, ok := l.connPool.Get(raddr.String()) - if !ok { - c = udp.NewConn(conn, l.laddr, raddr, l.md.readQueueSize) - select { - case l.cqueue <- c: - l.connPool.Set(raddr.String(), c) - default: - c.Close() - l.logger.Warnf("connection queue is full, client %s discarded", raddr.String()) - return nil - } - } - return c -} diff --git a/pkg/listener/rudp/metadata.go b/pkg/listener/rudp/metadata.go index 986b051..f389955 100644 --- a/pkg/listener/rudp/metadata.go +++ b/pkg/listener/rudp/metadata.go @@ -7,8 +7,8 @@ import ( ) const ( - defaultTTL = 60 * time.Second - defaultReadBufferSize = 4096 + defaultTTL = 5 * time.Second + defaultReadBufferSize = 1024 defaultReadQueueSize = 128 defaultBacklog = 128 ) diff --git a/pkg/listener/udp/listener.go b/pkg/listener/udp/listener.go index 075c1a1..1c9c8f0 100644 --- a/pkg/listener/udp/listener.go +++ b/pkg/listener/udp/listener.go @@ -3,7 +3,6 @@ package udp import ( "net" - "github.com/go-gost/gost/pkg/common/bufpool" "github.com/go-gost/gost/pkg/common/util/udp" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" @@ -16,14 +15,10 @@ func init() { } type udpListener struct { - addr string - md metadata - conn net.PacketConn - cqueue chan net.Conn - errChan chan error - closed chan struct{} - connPool *udp.ConnPool - logger logger.Logger + addr string + md metadata + net.Listener + logger logger.Logger } func NewListener(opts ...listener.Option) listener.Listener { @@ -32,10 +27,8 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(options) } return &udpListener{ - addr: options.Addr, - errChan: make(chan error, 1), - closed: make(chan struct{}), - logger: options.Logger, + addr: options.Addr, + logger: options.Logger, } } @@ -49,82 +42,15 @@ func (l *udpListener) Init(md md.Metadata) (err error) { return } - l.conn, err = net.ListenUDP("udp", laddr) + conn, err := net.ListenUDP("udp", laddr) if err != nil { return } - l.cqueue = make(chan net.Conn, l.md.backlog) - l.connPool = udp.NewConnPool(l.md.ttl).WithLogger(l.logger) - - go l.listenLoop() - + l.Listener = udp.NewListener(conn, laddr, + l.md.backlog, + l.md.readQueueSize, l.md.readBufferSize, + l.md.ttl, + l.logger) return } - -func (l *udpListener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.cqueue: - case err, ok = <-l.errChan: - if !ok { - err = listener.ErrClosed - } - } - return -} - -func (l *udpListener) Close() error { - select { - case <-l.closed: - default: - close(l.closed) - l.connPool.Close() - return l.conn.Close() - } - - return nil -} - -func (l *udpListener) Addr() net.Addr { - return l.conn.LocalAddr() -} - -func (l *udpListener) listenLoop() { - for { - b := bufpool.Get(l.md.readBufferSize) - - n, raddr, err := l.conn.ReadFrom(b) - if err != nil { - l.errChan <- err - close(l.errChan) - return - } - - c := l.getConn(raddr) - if c == nil { - bufpool.Put(b) - continue - } - - if err := c.WriteQueue(b[:n]); err != nil { - l.logger.Warn("data discarded: ", err) - } - } -} - -func (l *udpListener) getConn(addr net.Addr) *udp.Conn { - c, ok := l.connPool.Get(addr.String()) - if !ok { - c = udp.NewConn(l.conn, l.conn.LocalAddr(), addr, l.md.readQueueSize) - select { - case l.cqueue <- c: - l.connPool.Set(addr.String(), c) - default: - c.Close() - l.logger.Warnf("connection queue is full, client %s discarded", addr.String()) - return nil - } - } - return c -} diff --git a/pkg/listener/udp/metadata.go b/pkg/listener/udp/metadata.go index b2dd89f..e1c2c94 100644 --- a/pkg/listener/udp/metadata.go +++ b/pkg/listener/udp/metadata.go @@ -7,8 +7,8 @@ import ( ) const ( - defaultTTL = 60 * time.Second - defaultReadBufferSize = 4096 + defaultTTL = 5 * time.Second + defaultReadBufferSize = 1024 defaultReadQueueSize = 128 defaultBacklog = 128 ) diff --git a/pkg/logger/gost_logger.go b/pkg/logger/gost_logger.go index 28c125b..773ae11 100644 --- a/pkg/logger/gost_logger.go +++ b/pkg/logger/gost_logger.go @@ -8,10 +8,54 @@ import ( "github.com/sirupsen/logrus" ) +var ( + defaultLogger = NewLogger() +) + +func Default() Logger { + return defaultLogger +} + type logger struct { logger *logrus.Entry } +func NewLogger(opts ...LoggerOption) Logger { + var options LoggerOptions + for _, opt := range opts { + opt(&options) + } + + log := logrus.New() + if options.Output != nil { + log.SetOutput(options.Output) + } + + switch options.Format { + case TextFormat: + log.SetFormatter(&logrus.TextFormatter{ + FullTimestamp: true, + }) + default: + log.SetFormatter(&logrus.JSONFormatter{ + DisableHTMLEscape: true, + // PrettyPrint: true, + }) + } + + switch options.Level { + case DebugLevel, InfoLevel, WarnLevel, ErrorLevel, FatalLevel: + lvl, _ := logrus.ParseLevel(string(options.Level)) + log.SetLevel(lvl) + default: + log.SetLevel(logrus.InfoLevel) + } + + return &logger{ + logger: logrus.NewEntry(log), + } +} + // WithFields adds new fields to log. func (l *logger) WithFields(fields map[string]interface{}) Logger { return &logger{ diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 2c7dd10..a823639 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -2,8 +2,6 @@ package logger import ( "io" - - "github.com/sirupsen/logrus" ) // LogFormat is format type @@ -71,39 +69,3 @@ func LevelLoggerOption(level LogLevel) LoggerOption { opts.Level = level } } - -func NewLogger(opts ...LoggerOption) Logger { - var options LoggerOptions - for _, opt := range opts { - opt(&options) - } - - log := logrus.New() - if options.Output != nil { - log.SetOutput(options.Output) - } - - switch options.Format { - case TextFormat: - log.SetFormatter(&logrus.TextFormatter{ - FullTimestamp: true, - }) - default: - log.SetFormatter(&logrus.JSONFormatter{ - DisableHTMLEscape: true, - // PrettyPrint: true, - }) - } - - switch options.Level { - case DebugLevel, InfoLevel, WarnLevel, ErrorLevel, FatalLevel: - lvl, _ := logrus.ParseLevel(string(options.Level)) - log.SetLevel(lvl) - default: - log.SetLevel(logrus.InfoLevel) - } - - return &logger{ - logger: logrus.NewEntry(log), - } -}