fix traffic limiter

This commit is contained in:
ginuerzh
2022-12-19 19:33:29 +08:00
parent 1cb719f694
commit 15feb7599e
6 changed files with 438 additions and 246 deletions

View File

@ -6,40 +6,63 @@ import (
"errors"
"io"
"net"
"sync"
"syscall"
"time"
limiter "github.com/go-gost/core/limiter/traffic"
xnet "github.com/go-gost/x/internal/net"
"github.com/go-gost/x/internal/net/udp"
"github.com/patrickmn/go-cache"
)
var (
errUnsupport = errors.New("unsupported operation")
)
// serverConn is a server side Conn with metrics supported.
// serverConn is a server side Conn with traffic limiter supported.
type serverConn struct {
net.Conn
rbuf bytes.Buffer
limiter limiter.TrafficLimiter
limiterIn limiter.Limiter
expIn int64
limiterOut limiter.Limiter
expOut int64
}
func WrapConn(rlimiter limiter.TrafficLimiter, c net.Conn) net.Conn {
if rlimiter == nil {
func WrapConn(limiter limiter.TrafficLimiter, c net.Conn) net.Conn {
if limiter == nil {
return c
}
host, _, _ := net.SplitHostPort(c.RemoteAddr().String())
return &serverConn{
Conn: c,
limiterIn: rlimiter.In(host),
limiterOut: rlimiter.Out(host),
Conn: c,
limiter: limiter,
}
}
func (c *serverConn) getInLimiter(addr net.Addr) limiter.Limiter {
now := time.Now().UnixNano()
// cache the limiter for 1s
if c.limiter != nil && time.Duration(now-c.expIn) > time.Second {
c.limiterIn = c.limiter.In(addr.String())
c.expIn = now
}
return c.limiterIn
}
func (c *serverConn) getOutLimiter(addr net.Addr) limiter.Limiter {
now := time.Now().UnixNano()
// cache the limiter for 1s
if c.limiter != nil && time.Duration(now-c.expOut) > time.Second {
c.limiterOut = c.limiter.Out(addr.String())
c.expOut = now
}
return c.limiterOut
}
func (c *serverConn) Read(b []byte) (n int, err error) {
if c.limiterIn == nil {
limiter := c.getInLimiter(c.RemoteAddr())
if limiter == nil {
return c.Conn.Read(b)
}
@ -48,7 +71,7 @@ func (c *serverConn) Read(b []byte) (n int, err error) {
if c.rbuf.Len() < burst {
burst = c.rbuf.Len()
}
lim := c.limiterIn.Wait(context.Background(), burst)
lim := limiter.Wait(context.Background(), burst)
return c.rbuf.Read(b[:lim])
}
@ -57,7 +80,7 @@ func (c *serverConn) Read(b []byte) (n int, err error) {
return nn, err
}
n = c.limiterIn.Wait(context.Background(), nn)
n = limiter.Wait(context.Background(), nn)
if n < nn {
if _, err = c.rbuf.Write(b[n:nn]); err != nil {
return 0, err
@ -68,13 +91,14 @@ func (c *serverConn) Read(b []byte) (n int, err error) {
}
func (c *serverConn) Write(b []byte) (n int, err error) {
if c.limiterOut == nil {
limiter := c.getOutLimiter(c.RemoteAddr())
if limiter == nil {
return c.Conn.Write(b)
}
nn := 0
for len(b) > 0 {
nn, err = c.Conn.Write(b[:c.limiterOut.Wait(context.Background(), len(b))])
nn, err = c.Conn.Write(b[:limiter.Wait(context.Background(), len(b))])
n += nn
if err != nil {
return
@ -97,10 +121,8 @@ func (c *serverConn) SyscallConn() (rc syscall.RawConn, err error) {
type packetConn struct {
net.PacketConn
limiter limiter.TrafficLimiter
inLimits map[string]limiter.Limiter
inMux sync.RWMutex
outLimits map[string]limiter.Limiter
outMux sync.RWMutex
inLimits *cache.Cache
outLimits *cache.Cache
}
func WrapPacketConn(lim limiter.TrafficLimiter, pc net.PacketConn) net.PacketConn {
@ -110,8 +132,8 @@ func WrapPacketConn(lim limiter.TrafficLimiter, pc net.PacketConn) net.PacketCon
return &packetConn{
PacketConn: pc,
limiter: lim,
inLimits: make(map[string]limiter.Limiter),
outLimits: make(map[string]limiter.Limiter),
inLimits: cache.New(time.Second, 10*time.Second),
outLimits: cache.New(time.Second, 10*time.Second),
}
}
@ -120,24 +142,21 @@ func (c *packetConn) getInLimiter(addr net.Addr) limiter.Limiter {
return nil
}
lim, ok := func() (limiter.Limiter, bool) {
c.inMux.RLock()
defer c.inMux.RUnlock()
lim, ok := c.inLimits[addr.String()]
return lim, ok
lim, ok := func() (lim limiter.Limiter, ok bool) {
v, ok := c.inLimits.Get(addr.String())
if ok {
if v != nil {
lim = v.(limiter.Limiter)
}
}
return
}()
if ok {
return lim
}
host, _, _ := net.SplitHostPort(addr.String())
lim = c.limiter.In(host)
c.inMux.Lock()
defer c.inMux.Unlock()
c.inLimits[addr.String()] = lim
lim = c.limiter.In(addr.String())
c.inLimits.Set(addr.String(), lim, 0)
return lim
}
@ -147,24 +166,21 @@ func (c *packetConn) getOutLimiter(addr net.Addr) limiter.Limiter {
return nil
}
lim, ok := func() (limiter.Limiter, bool) {
c.outMux.RLock()
defer c.outMux.RUnlock()
lim, ok := c.outLimits[addr.String()]
return lim, ok
lim, ok := func() (lim limiter.Limiter, ok bool) {
v, ok := c.outLimits.Get(addr.String())
if ok {
if v != nil {
lim = v.(limiter.Limiter)
}
}
return
}()
if ok {
return lim
}
host, _, _ := net.SplitHostPort(addr.String())
lim = c.limiter.Out(host)
c.outMux.Lock()
defer c.outMux.Unlock()
c.outLimits[addr.String()] = lim
lim = c.limiter.Out(addr.String())
c.outLimits.Set(addr.String(), lim, 0)
return lim
}
@ -204,16 +220,16 @@ func (c *packetConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
type udpConn struct {
net.PacketConn
limiter limiter.TrafficLimiter
inLimits map[string]limiter.Limiter
inMux sync.RWMutex
outLimits map[string]limiter.Limiter
outMux sync.RWMutex
inLimits *cache.Cache
outLimits *cache.Cache
}
func WrapUDPConn(limiter limiter.TrafficLimiter, pc net.PacketConn) udp.Conn {
return &udpConn{
PacketConn: pc,
limiter: limiter,
inLimits: cache.New(time.Second, 10*time.Second),
outLimits: cache.New(time.Second, 10*time.Second),
}
}
@ -222,24 +238,21 @@ func (c *udpConn) getInLimiter(addr net.Addr) limiter.Limiter {
return nil
}
lim, ok := func() (limiter.Limiter, bool) {
c.inMux.RLock()
defer c.inMux.RUnlock()
lim, ok := c.inLimits[addr.String()]
return lim, ok
lim, ok := func() (lim limiter.Limiter, ok bool) {
v, ok := c.inLimits.Get(addr.String())
if ok {
if v != nil {
lim = v.(limiter.Limiter)
}
}
return
}()
if ok {
return lim
}
host, _, _ := net.SplitHostPort(addr.String())
lim = c.limiter.In(host)
c.inMux.Lock()
defer c.inMux.Unlock()
c.inLimits[addr.String()] = lim
lim = c.limiter.In(addr.String())
c.inLimits.Set(addr.String(), lim, 0)
return lim
}
@ -249,24 +262,21 @@ func (c *udpConn) getOutLimiter(addr net.Addr) limiter.Limiter {
return nil
}
lim, ok := func() (limiter.Limiter, bool) {
c.outMux.RLock()
defer c.outMux.RUnlock()
lim, ok := c.outLimits[addr.String()]
return lim, ok
lim, ok := func() (lim limiter.Limiter, ok bool) {
v, ok := c.outLimits.Get(addr.String())
if ok {
if v != nil {
lim = v.(limiter.Limiter)
}
}
return
}()
if ok {
return lim
}
host, _, _ := net.SplitHostPort(addr.String())
lim = c.limiter.Out(host)
c.outMux.Lock()
defer c.outMux.Unlock()
c.outLimits[addr.String()] = lim
lim = c.limiter.Out(addr.String())
c.outLimits.Set(addr.String(), lim, 0)
return lim
}