update limiter
This commit is contained in:
@ -102,7 +102,7 @@ type connLimiter struct {
|
||||
ipLimits map[string]ConnLimitGenerator
|
||||
cidrLimits cidranger.Ranger
|
||||
limits map[string]limiter.Limiter
|
||||
mu sync.RWMutex
|
||||
mu sync.Mutex
|
||||
cancelFunc context.CancelFunc
|
||||
options options
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ func (l *llimiter) Limit() int {
|
||||
}
|
||||
|
||||
func (l *llimiter) Allow(n int) bool {
|
||||
if atomic.AddInt64(&l.current, int64(n)) >= int64(l.limit) {
|
||||
if atomic.AddInt64(&l.current, int64(n)) > int64(l.limit) {
|
||||
if n > 0 {
|
||||
atomic.AddInt64(&l.current, -int64(n))
|
||||
}
|
||||
|
44
limiter/rate/generator.go
Normal file
44
limiter/rate/generator.go
Normal file
@ -0,0 +1,44 @@
|
||||
package rate
|
||||
|
||||
import (
|
||||
"github.com/go-gost/core/limiter/rate"
|
||||
limiter "github.com/go-gost/core/limiter/rate"
|
||||
)
|
||||
|
||||
type RateLimitGenerator interface {
|
||||
Limiter() limiter.Limiter
|
||||
}
|
||||
|
||||
type rateLimitGenerator struct {
|
||||
r float64
|
||||
}
|
||||
|
||||
func NewRateLimitGenerator(r float64) RateLimitGenerator {
|
||||
return &rateLimitGenerator{
|
||||
r: r,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *rateLimitGenerator) Limiter() limiter.Limiter {
|
||||
if p == nil || p.r <= 0 {
|
||||
return nil
|
||||
}
|
||||
return NewLimiter(p.r, int(p.r)+1)
|
||||
}
|
||||
|
||||
type rateLimitSingleGenerator struct {
|
||||
limiter rate.Limiter
|
||||
}
|
||||
|
||||
func NewRateLimitSingleGenerator(r float64) RateLimitGenerator {
|
||||
p := &rateLimitSingleGenerator{}
|
||||
if r > 0 {
|
||||
p.limiter = NewLimiter(r, int(r)+1)
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *rateLimitSingleGenerator) Limiter() limiter.Limiter {
|
||||
return p.limiter
|
||||
}
|
26
limiter/rate/limiter.go
Normal file
26
limiter/rate/limiter.go
Normal file
@ -0,0 +1,26 @@
|
||||
package rate
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
limiter "github.com/go-gost/core/limiter/rate"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type rlimiter struct {
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
func NewLimiter(r float64, b int) limiter.Limiter {
|
||||
return &rlimiter{
|
||||
limiter: rate.NewLimiter(rate.Limit(r), b),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *rlimiter) Allow(n int) bool {
|
||||
return l.limiter.AllowN(time.Now(), n)
|
||||
}
|
||||
|
||||
func (l *rlimiter) Limit() float64 {
|
||||
return float64(l.limiter.Limit())
|
||||
}
|
353
limiter/rate/rate.go
Normal file
353
limiter/rate/rate.go
Normal file
@ -0,0 +1,353 @@
|
||||
package rate
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
limiter "github.com/go-gost/core/limiter/rate"
|
||||
"github.com/go-gost/core/logger"
|
||||
"github.com/go-gost/x/internal/loader"
|
||||
"github.com/yl2chen/cidranger"
|
||||
)
|
||||
|
||||
const (
|
||||
GlobalLimitKey = "$"
|
||||
IPLimitKey = "$$"
|
||||
)
|
||||
|
||||
type limiterGroup struct {
|
||||
limiters []limiter.Limiter
|
||||
}
|
||||
|
||||
func newLimiterGroup(limiters ...limiter.Limiter) *limiterGroup {
|
||||
sort.Slice(limiters, func(i, j int) bool {
|
||||
return limiters[i].Limit() < limiters[j].Limit()
|
||||
})
|
||||
return &limiterGroup{limiters: limiters}
|
||||
}
|
||||
|
||||
func (l *limiterGroup) Allow(n int) (b bool) {
|
||||
b = true
|
||||
for i := range l.limiters {
|
||||
if v := l.limiters[i].Allow(n); !v {
|
||||
b = false
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *limiterGroup) Limit() float64 {
|
||||
if len(l.limiters) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return l.limiters[0].Limit()
|
||||
}
|
||||
|
||||
type options struct {
|
||||
limits []string
|
||||
fileLoader loader.Loader
|
||||
redisLoader loader.Loader
|
||||
period time.Duration
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
type Option func(opts *options)
|
||||
|
||||
func LimitsOption(limits ...string) Option {
|
||||
return func(opts *options) {
|
||||
opts.limits = limits
|
||||
}
|
||||
}
|
||||
|
||||
func ReloadPeriodOption(period time.Duration) Option {
|
||||
return func(opts *options) {
|
||||
opts.period = period
|
||||
}
|
||||
}
|
||||
|
||||
func FileLoaderOption(fileLoader loader.Loader) Option {
|
||||
return func(opts *options) {
|
||||
opts.fileLoader = fileLoader
|
||||
}
|
||||
}
|
||||
|
||||
func RedisLoaderOption(redisLoader loader.Loader) Option {
|
||||
return func(opts *options) {
|
||||
opts.redisLoader = redisLoader
|
||||
}
|
||||
}
|
||||
|
||||
func LoggerOption(logger logger.Logger) Option {
|
||||
return func(opts *options) {
|
||||
opts.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
type rateLimiter struct {
|
||||
ipLimits map[string]RateLimitGenerator
|
||||
cidrLimits cidranger.Ranger
|
||||
limits map[string]limiter.Limiter
|
||||
mu sync.Mutex
|
||||
cancelFunc context.CancelFunc
|
||||
options options
|
||||
}
|
||||
|
||||
func NewRateLimiter(opts ...Option) limiter.RateLimiter {
|
||||
var options options
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
lim := &rateLimiter{
|
||||
ipLimits: make(map[string]RateLimitGenerator),
|
||||
cidrLimits: cidranger.NewPCTrieRanger(),
|
||||
limits: make(map[string]limiter.Limiter),
|
||||
options: options,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
if err := lim.reload(ctx); err != nil {
|
||||
options.logger.Warnf("reload: %v", err)
|
||||
}
|
||||
if lim.options.period > 0 {
|
||||
go lim.periodReload(ctx)
|
||||
}
|
||||
return lim
|
||||
}
|
||||
|
||||
func (l *rateLimiter) Limiter(key string) limiter.Limiter {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if lim, ok := l.limits[key]; ok {
|
||||
return lim
|
||||
}
|
||||
|
||||
var lims []limiter.Limiter
|
||||
|
||||
if ip := net.ParseIP(key); ip != nil {
|
||||
found := false
|
||||
if p := l.ipLimits[key]; p != nil {
|
||||
if lim := p.Limiter(); lim != nil {
|
||||
lims = append(lims, lim)
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if p, _ := l.cidrLimits.ContainingNetworks(ip); len(p) > 0 {
|
||||
if v, _ := p[0].(*cidrLimitEntry); v != nil {
|
||||
if lim := v.limit.Limiter(); lim != nil {
|
||||
lims = append(lims, lim)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(lims) == 0 {
|
||||
if p := l.ipLimits[IPLimitKey]; p != nil {
|
||||
if lim := p.Limiter(); lim != nil {
|
||||
lims = append(lims, lim)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if p := l.ipLimits[GlobalLimitKey]; p != nil {
|
||||
if lim := p.Limiter(); lim != nil {
|
||||
lims = append(lims, lim)
|
||||
}
|
||||
}
|
||||
|
||||
var lim limiter.Limiter
|
||||
if len(lims) > 0 {
|
||||
lim = newLimiterGroup(lims...)
|
||||
}
|
||||
l.limits[key] = lim
|
||||
|
||||
if lim != nil && l.options.logger != nil {
|
||||
l.options.logger.Debugf("input limit for %s: %d", key, lim.Limit())
|
||||
}
|
||||
|
||||
return lim
|
||||
}
|
||||
|
||||
func (l *rateLimiter) periodReload(ctx context.Context) error {
|
||||
period := l.options.period
|
||||
if period < time.Second {
|
||||
period = time.Second
|
||||
}
|
||||
ticker := time.NewTicker(period)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := l.reload(ctx); err != nil {
|
||||
l.options.logger.Warnf("reload: %v", err)
|
||||
// return err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *rateLimiter) reload(ctx context.Context) error {
|
||||
v, err := l.load(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lines := append(l.options.limits, v...)
|
||||
|
||||
ipLimits := make(map[string]RateLimitGenerator)
|
||||
cidrLimits := cidranger.NewPCTrieRanger()
|
||||
|
||||
for _, s := range lines {
|
||||
key, limit := l.parseLimit(s)
|
||||
if key == "" || limit <= 0 {
|
||||
continue
|
||||
}
|
||||
switch key {
|
||||
case GlobalLimitKey:
|
||||
ipLimits[key] = NewRateLimitSingleGenerator(limit)
|
||||
case IPLimitKey:
|
||||
ipLimits[key] = NewRateLimitGenerator(limit)
|
||||
default:
|
||||
if ip := net.ParseIP(key); ip != nil {
|
||||
ipLimits[key] = NewRateLimitGenerator(limit)
|
||||
break
|
||||
}
|
||||
if _, ipNet, _ := net.ParseCIDR(key); ipNet != nil {
|
||||
cidrLimits.Insert(&cidrLimitEntry{
|
||||
ipNet: *ipNet,
|
||||
limit: NewRateLimitGenerator(limit),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.ipLimits = ipLimits
|
||||
l.cidrLimits = cidrLimits
|
||||
l.limits = make(map[string]limiter.Limiter)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *rateLimiter) load(ctx context.Context) (patterns []string, err error) {
|
||||
if l.options.fileLoader != nil {
|
||||
if lister, ok := l.options.fileLoader.(loader.Lister); ok {
|
||||
list, er := lister.List(ctx)
|
||||
if er != nil {
|
||||
l.options.logger.Warnf("file loader: %v", er)
|
||||
}
|
||||
for _, s := range list {
|
||||
if line := l.parseLine(s); line != "" {
|
||||
patterns = append(patterns, line)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
r, er := l.options.fileLoader.Load(ctx)
|
||||
if er != nil {
|
||||
l.options.logger.Warnf("file loader: %v", er)
|
||||
}
|
||||
if v, _ := l.parsePatterns(r); v != nil {
|
||||
patterns = append(patterns, v...)
|
||||
}
|
||||
}
|
||||
}
|
||||
if l.options.redisLoader != nil {
|
||||
if lister, ok := l.options.redisLoader.(loader.Lister); ok {
|
||||
list, er := lister.List(ctx)
|
||||
if er != nil {
|
||||
l.options.logger.Warnf("redis loader: %v", er)
|
||||
}
|
||||
patterns = append(patterns, list...)
|
||||
} else {
|
||||
r, er := l.options.redisLoader.Load(ctx)
|
||||
if er != nil {
|
||||
l.options.logger.Warnf("redis loader: %v", er)
|
||||
}
|
||||
if v, _ := l.parsePatterns(r); v != nil {
|
||||
patterns = append(patterns, v...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
l.options.logger.Debugf("load items %d", len(patterns))
|
||||
return
|
||||
}
|
||||
|
||||
func (l *rateLimiter) parsePatterns(r io.Reader) (patterns []string, err error) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
if line := l.parseLine(scanner.Text()); line != "" {
|
||||
patterns = append(patterns, line)
|
||||
}
|
||||
}
|
||||
|
||||
err = scanner.Err()
|
||||
return
|
||||
}
|
||||
|
||||
func (l *rateLimiter) parseLine(s string) string {
|
||||
if n := strings.IndexByte(s, '#'); n >= 0 {
|
||||
s = s[:n]
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func (l *rateLimiter) parseLimit(s string) (key string, limit float64) {
|
||||
s = strings.Replace(s, "\t", " ", -1)
|
||||
s = strings.TrimSpace(s)
|
||||
var ss []string
|
||||
for _, v := range strings.Split(s, " ") {
|
||||
if v != "" {
|
||||
ss = append(ss, v)
|
||||
}
|
||||
}
|
||||
if len(ss) < 2 {
|
||||
return
|
||||
}
|
||||
|
||||
key = ss[0]
|
||||
limit, _ = strconv.ParseFloat(ss[1], 64)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *rateLimiter) Close() error {
|
||||
l.cancelFunc()
|
||||
if l.options.fileLoader != nil {
|
||||
l.options.fileLoader.Close()
|
||||
}
|
||||
if l.options.redisLoader != nil {
|
||||
l.options.redisLoader.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type cidrLimitEntry struct {
|
||||
ipNet net.IPNet
|
||||
limit RateLimitGenerator
|
||||
}
|
||||
|
||||
func (p *cidrLimitEntry) Network() net.IPNet {
|
||||
return p.ipNet
|
||||
}
|
@ -95,7 +95,7 @@ type trafficLimiter struct {
|
||||
cidrLimits cidranger.Ranger
|
||||
inLimits map[string]limiter.Limiter
|
||||
outLimits map[string]limiter.Limiter
|
||||
mu sync.RWMutex
|
||||
mu sync.Mutex
|
||||
cancelFunc context.CancelFunc
|
||||
options options
|
||||
}
|
||||
|
@ -20,9 +20,9 @@ var (
|
||||
// serverConn is a server side Conn with metrics supported.
|
||||
type serverConn struct {
|
||||
net.Conn
|
||||
rbuf bytes.Buffer
|
||||
raddr string
|
||||
rlimiter limiter.TrafficLimiter
|
||||
rbuf bytes.Buffer
|
||||
raddr string
|
||||
limiter limiter.TrafficLimiter
|
||||
}
|
||||
|
||||
func WrapConn(rlimiter limiter.TrafficLimiter, c net.Conn) net.Conn {
|
||||
@ -31,19 +31,19 @@ func WrapConn(rlimiter limiter.TrafficLimiter, c net.Conn) net.Conn {
|
||||
}
|
||||
host, _, _ := net.SplitHostPort(c.RemoteAddr().String())
|
||||
return &serverConn{
|
||||
Conn: c,
|
||||
rlimiter: rlimiter,
|
||||
raddr: host,
|
||||
Conn: c,
|
||||
limiter: rlimiter,
|
||||
raddr: host,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverConn) Read(b []byte) (n int, err error) {
|
||||
if c.rlimiter == nil ||
|
||||
c.rlimiter.In(c.raddr) == nil {
|
||||
if c.limiter == nil ||
|
||||
c.limiter.In(c.raddr) == nil {
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
limiter := c.rlimiter.In(c.raddr)
|
||||
limiter := c.limiter.In(c.raddr)
|
||||
|
||||
if c.rbuf.Len() > 0 {
|
||||
burst := len(b)
|
||||
@ -70,12 +70,12 @@ func (c *serverConn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(b []byte) (n int, err error) {
|
||||
if c.rlimiter == nil ||
|
||||
c.rlimiter.Out(c.raddr) == nil {
|
||||
if c.limiter == nil ||
|
||||
c.limiter.Out(c.raddr) == nil {
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
limiter := c.rlimiter.Out(c.raddr)
|
||||
limiter := c.limiter.Out(c.raddr)
|
||||
nn := 0
|
||||
for len(b) > 0 {
|
||||
nn, err = c.Conn.Write(b[:limiter.Wait(context.Background(), len(b))])
|
||||
|
Reference in New Issue
Block a user