diff --git a/limiter/conn/conn.go b/limiter/conn/conn.go index 6f06dbe..c49fa1e 100644 --- a/limiter/conn/conn.go +++ b/limiter/conn/conn.go @@ -238,7 +238,7 @@ func (l *connLimiter) reload(ctx context.Context) error { ipLimits[key] = NewConnLimitGenerator(limit) default: if ip := net.ParseIP(key); ip != nil { - ipLimits[key] = NewConnLimitGenerator(limit) + ipLimits[key] = NewConnLimitSingleGenerator(limit) break } if _, ipNet, _ := net.ParseCIDR(key); ipNet != nil { diff --git a/limiter/rate/rate.go b/limiter/rate/rate.go index 7483265..326588e 100644 --- a/limiter/rate/rate.go +++ b/limiter/rate/rate.go @@ -231,7 +231,7 @@ func (l *rateLimiter) reload(ctx context.Context) error { ipLimits[key] = NewRateLimitGenerator(limit) default: if ip := net.ParseIP(key); ip != nil { - ipLimits[key] = NewRateLimitGenerator(limit) + ipLimits[key] = NewRateLimitSingleGenerator(limit) break } if _, ipNet, _ := net.ParseCIDR(key); ipNet != nil { diff --git a/limiter/traffic/traffic.go b/limiter/traffic/traffic.go index 3c21446..26309d8 100644 --- a/limiter/traffic/traffic.go +++ b/limiter/traffic/traffic.go @@ -100,6 +100,8 @@ func LoggerOption(logger logger.Logger) Option { type trafficLimiter struct { limits map[string]TrafficLimitGenerator cidrLimits cidranger.Ranger + inLimits map[string]limiter.Limiter + outLimits map[string]limiter.Limiter mu sync.Mutex cancelFunc context.CancelFunc options options @@ -115,6 +117,8 @@ func NewTrafficLimiter(opts ...Option) limiter.TrafficLimiter { lim := &trafficLimiter{ limits: make(map[string]TrafficLimitGenerator), cidrLimits: cidranger.NewPCTrieRanger(), + inLimits: make(map[string]limiter.Limiter), + outLimits: make(map[string]limiter.Limiter), options: options, cancelFunc: cancel, } @@ -134,25 +138,6 @@ func (l *trafficLimiter) In(key string) limiter.Limiter { var lims []limiter.Limiter - if ip := net.ParseIP(key); ip != nil { - found := false - if p := l.limits[key]; p != nil { - if lim := p.In(); lim != nil { - lims = append(lims, lim) - found = true - } - } - if !found { - if p, _ := l.cidrLimits.ContainingNetworks(ip); len(p) > 0 { - if v, _ := p[0].(*cidrLimitEntry); v != nil { - if lim := v.limit.In(); lim != nil { - lims = append(lims, lim) - } - } - } - } - } - if p := l.limits[ConnLimitKey]; p != nil { if lim := p.In(); lim != nil { lims = append(lims, lim) @@ -164,6 +149,31 @@ func (l *trafficLimiter) In(key string) limiter.Limiter { } } + // IP limiter + if lim, ok := l.inLimits[key]; ok { + if lim != nil { + lims = append(lims, lim) + } + } else { + if ip := net.ParseIP(key); ip != nil { + if p := l.limits[key]; p != nil { + if lim = p.In(); lim != nil { + lims = append(lims, lim) + } + } + if lim == nil { + if p, _ := l.cidrLimits.ContainingNetworks(ip); len(p) > 0 { + if v, _ := p[0].(*cidrLimitEntry); v != nil { + if lim = v.limit.In(); lim != nil { + lims = append(lims, lim) + } + } + } + } + l.inLimits[key] = lim + } + } + var lim limiter.Limiter if len(lims) > 0 { lim = newLimiterGroup(lims...) @@ -182,25 +192,6 @@ func (l *trafficLimiter) Out(key string) limiter.Limiter { var lims []limiter.Limiter - if ip := net.ParseIP(key); ip != nil { - found := false - if p := l.limits[key]; p != nil { - if lim := p.Out(); lim != nil { - lims = append(lims, lim) - found = true - } - } - if !found { - if p, _ := l.cidrLimits.ContainingNetworks(ip); len(p) > 0 { - if v, _ := p[0].(*cidrLimitEntry); v != nil { - if lim := v.limit.Out(); lim != nil { - lims = append(lims, lim) - } - } - } - } - } - if p := l.limits[ConnLimitKey]; p != nil { if lim := p.Out(); lim != nil { lims = append(lims, lim) @@ -212,6 +203,31 @@ func (l *trafficLimiter) Out(key string) limiter.Limiter { } } + // IP limiter + if lim, ok := l.outLimits[key]; ok { + if lim != nil { + lims = append(lims, lim) + } + } else { + if ip := net.ParseIP(key); ip != nil { + if p := l.limits[key]; p != nil { + if lim = p.Out(); lim != nil { + lims = append(lims, lim) + } + } + if lim == nil { + if p, _ := l.cidrLimits.ContainingNetworks(ip); len(p) > 0 { + if v, _ := p[0].(*cidrLimitEntry); v != nil { + if lim = v.limit.Out(); lim != nil { + lims = append(lims, lim) + } + } + } + } + l.outLimits[key] = lim + } + } + var lim limiter.Limiter if len(lims) > 0 { lim = newLimiterGroup(lims...) @@ -268,7 +284,7 @@ func (l *trafficLimiter) reload(ctx context.Context) error { limits[key] = NewTrafficLimitGenerator(in, out) default: if ip := net.ParseIP(key); ip != nil { - limits[key] = NewTrafficLimitGenerator(in, out) + limits[key] = NewTrafficLimitSingleGenerator(in, out) break } if _, ipNet, _ := net.ParseCIDR(key); ipNet != nil { @@ -285,6 +301,8 @@ func (l *trafficLimiter) reload(ctx context.Context) error { l.limits = limits l.cidrLimits = cidrLimits + l.inLimits = make(map[string]limiter.Limiter) + l.outLimits = make(map[string]limiter.Limiter) return nil }