diff --git a/limiter/traffic/traffic.go b/limiter/traffic/traffic.go index e42f7b2..3c21446 100644 --- a/limiter/traffic/traffic.go +++ b/limiter/traffic/traffic.go @@ -98,10 +98,8 @@ func LoggerOption(logger logger.Logger) Option { } type trafficLimiter struct { - ipLimits map[string]TrafficLimitGenerator + limits map[string]TrafficLimitGenerator cidrLimits cidranger.Ranger - inLimits map[string]limiter.Limiter - outLimits map[string]limiter.Limiter mu sync.Mutex cancelFunc context.CancelFunc options options @@ -115,10 +113,8 @@ func NewTrafficLimiter(opts ...Option) limiter.TrafficLimiter { ctx, cancel := context.WithCancel(context.TODO()) lim := &trafficLimiter{ - ipLimits: make(map[string]TrafficLimitGenerator), + limits: make(map[string]TrafficLimitGenerator), cidrLimits: cidranger.NewPCTrieRanger(), - inLimits: make(map[string]limiter.Limiter), - outLimits: make(map[string]limiter.Limiter), options: options, cancelFunc: cancel, } @@ -136,15 +132,11 @@ func (l *trafficLimiter) In(key string) limiter.Limiter { l.mu.Lock() defer l.mu.Unlock() - if lim, ok := l.inLimits[key]; ok { - return lim - } - var lims []limiter.Limiter if ip := net.ParseIP(key); ip != nil { found := false - if p := l.ipLimits[key]; p != nil { + if p := l.limits[key]; p != nil { if lim := p.In(); lim != nil { lims = append(lims, lim) found = true @@ -161,12 +153,12 @@ func (l *trafficLimiter) In(key string) limiter.Limiter { } } - if p := l.ipLimits[ConnLimitKey]; p != nil { + if p := l.limits[ConnLimitKey]; p != nil { if lim := p.In(); lim != nil { lims = append(lims, lim) } } - if p := l.ipLimits[GlobalLimitKey]; p != nil { + if p := l.limits[GlobalLimitKey]; p != nil { if lim := p.In(); lim != nil { lims = append(lims, lim) } @@ -176,7 +168,6 @@ func (l *trafficLimiter) In(key string) limiter.Limiter { if len(lims) > 0 { lim = newLimiterGroup(lims...) } - l.inLimits[key] = lim if lim != nil && l.options.logger != nil { l.options.logger.Debugf("input limit for %s: %d", key, lim.Limit()) @@ -189,15 +180,11 @@ func (l *trafficLimiter) Out(key string) limiter.Limiter { l.mu.Lock() defer l.mu.Unlock() - if lim, ok := l.outLimits[key]; ok { - return lim - } - var lims []limiter.Limiter if ip := net.ParseIP(key); ip != nil { found := false - if p := l.ipLimits[key]; p != nil { + if p := l.limits[key]; p != nil { if lim := p.Out(); lim != nil { lims = append(lims, lim) found = true @@ -214,12 +201,12 @@ func (l *trafficLimiter) Out(key string) limiter.Limiter { } } - if p := l.ipLimits[ConnLimitKey]; p != nil { + if p := l.limits[ConnLimitKey]; p != nil { if lim := p.Out(); lim != nil { lims = append(lims, lim) } } - if p := l.ipLimits[GlobalLimitKey]; p != nil { + if p := l.limits[GlobalLimitKey]; p != nil { if lim := p.Out(); lim != nil { lims = append(lims, lim) } @@ -229,7 +216,6 @@ func (l *trafficLimiter) Out(key string) limiter.Limiter { if len(lims) > 0 { lim = newLimiterGroup(lims...) } - l.outLimits[key] = lim if lim != nil && l.options.logger != nil { l.options.logger.Debugf("output limit for %s: %d", key, lim.Limit()) @@ -267,7 +253,7 @@ func (l *trafficLimiter) reload(ctx context.Context) error { lines := append(l.options.limits, v...) - ipLimits := make(map[string]TrafficLimitGenerator) + limits := make(map[string]TrafficLimitGenerator) cidrLimits := cidranger.NewPCTrieRanger() for _, s := range lines { @@ -277,12 +263,12 @@ func (l *trafficLimiter) reload(ctx context.Context) error { } switch key { case GlobalLimitKey: - ipLimits[key] = NewTrafficLimitSingleGenerator(in, out) + limits[key] = NewTrafficLimitSingleGenerator(in, out) case ConnLimitKey: - ipLimits[key] = NewTrafficLimitGenerator(in, out) + limits[key] = NewTrafficLimitGenerator(in, out) default: if ip := net.ParseIP(key); ip != nil { - ipLimits[key] = NewTrafficLimitGenerator(in, out) + limits[key] = NewTrafficLimitGenerator(in, out) break } if _, ipNet, _ := net.ParseCIDR(key); ipNet != nil { @@ -297,10 +283,8 @@ func (l *trafficLimiter) reload(ctx context.Context) error { l.mu.Lock() defer l.mu.Unlock() - l.ipLimits = ipLimits + l.limits = limits l.cidrLimits = cidrLimits - l.inLimits = make(map[string]limiter.Limiter) - l.outLimits = make(map[string]limiter.Limiter) return nil } diff --git a/limiter/traffic/wrapper/conn.go b/limiter/traffic/wrapper/conn.go index 8b171b8..d75a4f7 100644 --- a/limiter/traffic/wrapper/conn.go +++ b/limiter/traffic/wrapper/conn.go @@ -6,6 +6,7 @@ import ( "errors" "io" "net" + "sync" "syscall" limiter "github.com/go-gost/core/limiter/traffic" @@ -20,9 +21,9 @@ var ( // serverConn is a server side Conn with metrics supported. type serverConn struct { net.Conn - rbuf bytes.Buffer - raddr string - limiter limiter.TrafficLimiter + rbuf bytes.Buffer + limiterIn limiter.Limiter + limiterOut limiter.Limiter } func WrapConn(rlimiter limiter.TrafficLimiter, c net.Conn) net.Conn { @@ -31,26 +32,23 @@ func WrapConn(rlimiter limiter.TrafficLimiter, c net.Conn) net.Conn { } host, _, _ := net.SplitHostPort(c.RemoteAddr().String()) return &serverConn{ - Conn: c, - limiter: rlimiter, - raddr: host, + Conn: c, + limiterIn: rlimiter.In(host), + limiterOut: rlimiter.Out(host), } } func (c *serverConn) Read(b []byte) (n int, err error) { - if c.limiter == nil || - c.limiter.In(c.raddr) == nil { + if c.limiterIn == nil { return c.Conn.Read(b) } - limiter := c.limiter.In(c.raddr) - if c.rbuf.Len() > 0 { burst := len(b) if c.rbuf.Len() < burst { burst = c.rbuf.Len() } - lim := limiter.Wait(context.Background(), burst) + lim := c.limiterIn.Wait(context.Background(), burst) return c.rbuf.Read(b[:lim]) } @@ -59,7 +57,7 @@ func (c *serverConn) Read(b []byte) (n int, err error) { return nn, err } - n = limiter.Wait(context.Background(), nn) + n = c.limiterIn.Wait(context.Background(), nn) if n < nn { if _, err = c.rbuf.Write(b[n:nn]); err != nil { return 0, err @@ -70,15 +68,13 @@ func (c *serverConn) Read(b []byte) (n int, err error) { } func (c *serverConn) Write(b []byte) (n int, err error) { - if c.limiter == nil || - c.limiter.Out(c.raddr) == nil { + if c.limiterOut == nil { return c.Conn.Write(b) } - limiter := c.limiter.Out(c.raddr) nn := 0 for len(b) > 0 { - nn, err = c.Conn.Write(b[:limiter.Wait(context.Background(), len(b))]) + nn, err = c.Conn.Write(b[:c.limiterOut.Wait(context.Background(), len(b))]) n += nn if err != nil { return @@ -100,19 +96,79 @@ func (c *serverConn) SyscallConn() (rc syscall.RawConn, err error) { type packetConn struct { net.PacketConn - limiter limiter.TrafficLimiter + limiter limiter.TrafficLimiter + inLimits map[string]limiter.Limiter + inMux sync.RWMutex + outLimits map[string]limiter.Limiter + outMux sync.RWMutex } -func WrapPacketConn(limiter limiter.TrafficLimiter, pc net.PacketConn) net.PacketConn { - if limiter == nil { +func WrapPacketConn(lim limiter.TrafficLimiter, pc net.PacketConn) net.PacketConn { + if lim == nil { return pc } return &packetConn{ PacketConn: pc, - limiter: limiter, + limiter: lim, + inLimits: make(map[string]limiter.Limiter), + outLimits: make(map[string]limiter.Limiter), } } +func (c *packetConn) getInLimiter(addr net.Addr) limiter.Limiter { + if c.limiter == nil { + return nil + } + + lim, ok := func() (limiter.Limiter, bool) { + c.inMux.RLock() + defer c.inMux.RUnlock() + + lim, ok := c.inLimits[addr.String()] + return lim, ok + }() + if ok { + return lim + } + + host, _, _ := net.SplitHostPort(addr.String()) + lim = c.limiter.In(host) + + c.inMux.Lock() + defer c.inMux.Unlock() + + c.inLimits[addr.String()] = lim + + return lim +} + +func (c *packetConn) getOutLimiter(addr net.Addr) limiter.Limiter { + if c.limiter == nil { + return nil + } + + lim, ok := func() (limiter.Limiter, bool) { + c.outMux.RLock() + defer c.outMux.RUnlock() + + lim, ok := c.outLimits[addr.String()] + return lim, ok + }() + if ok { + return lim + } + + host, _, _ := net.SplitHostPort(addr.String()) + lim = c.limiter.Out(host) + + c.outMux.Lock() + defer c.outMux.Unlock() + + c.outLimits[addr.String()] = lim + + return lim +} + func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { for { n, addr, err = c.PacketConn.ReadFrom(p) @@ -120,13 +176,11 @@ func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { return } - host, _, _ := net.SplitHostPort(addr.String()) - - if c.limiter == nil || c.limiter.In(host) == nil { + limiter := c.getInLimiter(addr) + if limiter == nil { return } - limiter := c.limiter.In(host) // discard when exceed the limit size. if limiter.Wait(context.Background(), n) < n { continue @@ -137,14 +191,11 @@ func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } func (c *packetConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if c.limiter != nil { - host, _, _ := net.SplitHostPort(addr.String()) - // discard when exceed the limit size. - if limiter := c.limiter.Out(host); limiter != nil && - limiter.Wait(context.Background(), len(p)) < len(p) { - n = len(p) - return - } + // discard when exceed the limit size. + if limiter := c.getOutLimiter(addr); limiter != nil && + limiter.Wait(context.Background(), len(p)) < len(p) { + n = len(p) + return } return c.PacketConn.WriteTo(p, addr) @@ -152,7 +203,11 @@ func (c *packetConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { type udpConn struct { net.PacketConn - limiter limiter.TrafficLimiter + limiter limiter.TrafficLimiter + inLimits map[string]limiter.Limiter + inMux sync.RWMutex + outLimits map[string]limiter.Limiter + outMux sync.RWMutex } func WrapUDPConn(limiter limiter.TrafficLimiter, pc net.PacketConn) udp.Conn { @@ -162,6 +217,60 @@ func WrapUDPConn(limiter limiter.TrafficLimiter, pc net.PacketConn) udp.Conn { } } +func (c *udpConn) getInLimiter(addr net.Addr) limiter.Limiter { + if c.limiter == nil { + return nil + } + + lim, ok := func() (limiter.Limiter, bool) { + c.inMux.RLock() + defer c.inMux.RUnlock() + + lim, ok := c.inLimits[addr.String()] + return lim, ok + }() + if ok { + return lim + } + + host, _, _ := net.SplitHostPort(addr.String()) + lim = c.limiter.In(host) + + c.inMux.Lock() + defer c.inMux.Unlock() + + c.inLimits[addr.String()] = lim + + return lim +} + +func (c *udpConn) getOutLimiter(addr net.Addr) limiter.Limiter { + if c.limiter == nil { + return nil + } + + lim, ok := func() (limiter.Limiter, bool) { + c.outMux.RLock() + defer c.outMux.RUnlock() + + lim, ok := c.outLimits[addr.String()] + return lim, ok + }() + if ok { + return lim + } + + host, _, _ := net.SplitHostPort(addr.String()) + lim = c.limiter.Out(host) + + c.outMux.Lock() + defer c.outMux.Unlock() + + c.outLimits[addr.String()] = lim + + return lim +} + func (c *udpConn) RemoteAddr() net.Addr { if nc, ok := c.PacketConn.(xnet.RemoteAddr); ok { return nc.RemoteAddr() @@ -198,14 +307,10 @@ func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { if err != nil { return } - host, _, _ := net.SplitHostPort(addr.String()) - if c.limiter == nil || c.limiter.In(host) == nil { - return - } - limiter := c.limiter.In(host) // discard when exceed the limit size. - if limiter.Wait(context.Background(), n) < n { + if limiter := c.getInLimiter(addr); limiter != nil && + limiter.Wait(context.Background(), n) < n { continue } return @@ -220,14 +325,9 @@ func (c *udpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { return } - host, _, _ := net.SplitHostPort(addr.String()) - - if c.limiter == nil || c.limiter.In(host) == nil { - return - } - limiter := c.limiter.In(host) // discard when exceed the limit size. - if limiter.Wait(context.Background(), n) < n { + if limiter := c.getInLimiter(addr); limiter != nil && + limiter.Wait(context.Background(), n) < n { continue } return @@ -245,14 +345,9 @@ func (c *udpConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAd return } - host, _, _ := net.SplitHostPort(addr.String()) - - if c.limiter == nil || c.limiter.In(host) == nil { - return - } - limiter := c.limiter.In(host) // discard when exceed the limit size. - if limiter.Wait(context.Background(), n) < n { + if limiter := c.getInLimiter(addr); limiter != nil && + limiter.Wait(context.Background(), n) < n { continue } return @@ -272,14 +367,11 @@ func (c *udpConn) Write(b []byte) (n int, err error) { } func (c *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if c.limiter != nil { - host, _, _ := net.SplitHostPort(addr.String()) - // discard when exceed the limit size. - if limiter := c.limiter.Out(host); limiter != nil && - limiter.Wait(context.Background(), len(p)) < len(p) { - n = len(p) - return - } + // discard when exceed the limit size. + if limiter := c.getOutLimiter(addr); limiter != nil && + limiter.Wait(context.Background(), len(p)) < len(p) { + n = len(p) + return } n, err = c.PacketConn.WriteTo(p, addr) @@ -287,14 +379,11 @@ func (c *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { } func (c *udpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) { - if c.limiter != nil { - host, _, _ := net.SplitHostPort(addr.String()) - // discard when exceed the limit size. - if limiter := c.limiter.Out(host); limiter != nil && - limiter.Wait(context.Background(), len(b)) < len(b) { - n = len(b) - return - } + // discard when exceed the limit size. + if limiter := c.getOutLimiter(addr); limiter != nil && + limiter.Wait(context.Background(), len(b)) < len(b) { + n = len(b) + return } if nc, ok := c.PacketConn.(udp.WriteUDP); ok { @@ -306,14 +395,11 @@ func (c *udpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) { } func (c *udpConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) { - if c.limiter != nil { - host, _, _ := net.SplitHostPort(addr.String()) - // discard when exceed the limit size. - if limiter := c.limiter.Out(host); limiter != nil && - limiter.Wait(context.Background(), len(b)) < len(b) { - n = len(b) - return - } + // discard when exceed the limit size. + if limiter := c.getOutLimiter(addr); limiter != nil && + limiter.Wait(context.Background(), len(b)) < len(b) { + n = len(b) + return } if nc, ok := c.PacketConn.(udp.WriteUDP); ok {