From b25f90c55e98fa997022db9c5eacaef53a4640ae Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Thu, 21 Apr 2022 21:40:08 +0800 Subject: [PATCH] add limiter --- limiter/rate.go | 12 ++++ limiter/wrapper/conn.go | 130 ++++++++++++++++++++++++++++++++++++ limiter/wrapper/listener.go | 32 +++++++++ listener/option.go | 24 ++++--- 4 files changed, 190 insertions(+), 8 deletions(-) create mode 100644 limiter/rate.go create mode 100644 limiter/wrapper/conn.go create mode 100644 limiter/wrapper/listener.go diff --git a/limiter/rate.go b/limiter/rate.go new file mode 100644 index 0000000..1edd939 --- /dev/null +++ b/limiter/rate.go @@ -0,0 +1,12 @@ +package limiter + +type Limiter interface { + // Limit checks the requested size b and returns the limit size, + // the returned value is less or equal to b. + Limit(b int) int +} + +type RateLimiter interface { + Input() Limiter + Output() Limiter +} diff --git a/limiter/wrapper/conn.go b/limiter/wrapper/conn.go new file mode 100644 index 0000000..29d044c --- /dev/null +++ b/limiter/wrapper/conn.go @@ -0,0 +1,130 @@ +package wrapper + +import ( + "bytes" + "errors" + "net" + "syscall" + + "github.com/go-gost/core/limiter" +) + +var ( + errUnsupport = errors.New("unsupported operation") +) + +// serverConn is a server side Conn with metrics supported. +type serverConn struct { + net.Conn + rlimiter limiter.RateLimiter + rbuf bytes.Buffer +} + +func WrapConn(rlimiter limiter.RateLimiter, c net.Conn) net.Conn { + if rlimiter == nil { + return c + } + return &serverConn{ + Conn: c, + rlimiter: rlimiter, + } +} + +func (c *serverConn) Read(b []byte) (n int, err error) { + if c.rlimiter == nil || c.rlimiter.Input() == nil { + return c.Conn.Read(b) + } + + burst := len(b) + if c.rbuf.Len() > 0 { + if c.rbuf.Len() < burst { + burst = c.rbuf.Len() + } + return c.rbuf.Read(b[:c.rlimiter.Input().Limit(burst)]) + } + + nn, err := c.Conn.Read(b) + if err != nil { + return nn, err + } + + n = c.rlimiter.Input().Limit(nn) + if n < nn { + if _, err = c.rbuf.Write(b[n:nn]); err != nil { + return 0, err + } + } + + return +} + +func (c *serverConn) Write(b []byte) (n int, err error) { + if c.rlimiter == nil || c.rlimiter.Output() == nil { + return c.Conn.Write(b) + } + + nn := 0 + for len(b) > 0 { + nn, err = c.Conn.Write(b[:c.rlimiter.Output().Limit(len(b))]) + n += nn + if err != nil { + return + } + b = b[nn:] + } + + return +} + +func (c *serverConn) SyscallConn() (rc syscall.RawConn, err error) { + if sc, ok := c.Conn.(syscall.Conn); ok { + rc, err = sc.SyscallConn() + return + } + err = errUnsupport + return +} + +type packetConn struct { + net.PacketConn + rlimiter limiter.RateLimiter +} + +func WrapPacketConn(rlimiter limiter.RateLimiter, pc net.PacketConn) net.PacketConn { + if rlimiter == nil { + return pc + } + return &packetConn{ + PacketConn: pc, + rlimiter: rlimiter, + } +} + +func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + for { + n, addr, err = c.PacketConn.ReadFrom(p) + if err != nil { + return + } + if c.rlimiter == nil || c.rlimiter.Input() == nil { + return + } + + if c.rlimiter.Input().Limit(n) < n { + continue + } + + return + } +} + +func (c *packetConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.rlimiter != nil && + c.rlimiter.Output() != nil && + c.rlimiter.Output().Limit(len(p)) < len(p) { + n = len(p) + return + } + + return c.PacketConn.WriteTo(p, addr) +} diff --git a/limiter/wrapper/listener.go b/limiter/wrapper/listener.go new file mode 100644 index 0000000..42a2c8a --- /dev/null +++ b/limiter/wrapper/listener.go @@ -0,0 +1,32 @@ +package wrapper + +import ( + "net" + + "github.com/go-gost/core/limiter" +) + +type listener struct { + net.Listener + rlimiter limiter.RateLimiter +} + +func WrapListener(rlimiter limiter.RateLimiter, ln net.Listener) net.Listener { + if rlimiter == nil { + return ln + } + + return &listener{ + rlimiter: rlimiter, + Listener: ln, + } +} + +func (ln *listener) Accept() (net.Conn, error) { + c, err := ln.Listener.Accept() + if err != nil { + return nil, err + } + + return WrapConn(ln.rlimiter, c), nil +} diff --git a/listener/option.go b/listener/option.go index 5a553c8..9424061 100644 --- a/listener/option.go +++ b/listener/option.go @@ -7,18 +7,20 @@ import ( "github.com/go-gost/core/admission" "github.com/go-gost/core/auth" "github.com/go-gost/core/chain" + "github.com/go-gost/core/limiter" "github.com/go-gost/core/logger" ) type Options struct { - Addr string - Auther auth.Authenticator - Auth *url.Userinfo - TLSConfig *tls.Config - Admission admission.Admission - Chain chain.Chainer - Logger logger.Logger - Service string + Addr string + Auther auth.Authenticator + Auth *url.Userinfo + TLSConfig *tls.Config + Admission admission.Admission + RateLimiter limiter.RateLimiter + Chain chain.Chainer + Logger logger.Logger + Service string } type Option func(opts *Options) @@ -53,6 +55,12 @@ func AdmissionOption(admission admission.Admission) Option { } } +func RateLimiterOption(rlimiter limiter.RateLimiter) Option { + return func(opts *Options) { + opts.RateLimiter = rlimiter + } +} + func ChainOption(chain chain.Chainer) Option { return func(opts *Options) { opts.Chain = chain