From 50d443049f3bb083c949c2fc2821d2bc037340d4 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 14 Sep 2022 19:53:21 +0800 Subject: [PATCH] add request rate limiter --- handler/option.go | 20 ++++++++++++++------ limiter/rate/limiter.go | 10 ++++++++++ metadata/util/util.go | 2 ++ 3 files changed, 26 insertions(+), 6 deletions(-) create mode 100644 limiter/rate/limiter.go diff --git a/handler/option.go b/handler/option.go index ad86529..b75ca03 100644 --- a/handler/option.go +++ b/handler/option.go @@ -7,17 +7,19 @@ import ( "github.com/go-gost/core/auth" "github.com/go-gost/core/bypass" "github.com/go-gost/core/chain" + "github.com/go-gost/core/limiter/rate" "github.com/go-gost/core/logger" "github.com/go-gost/core/metadata" ) type Options struct { - Bypass bypass.Bypass - Router *chain.Router - Auth *url.Userinfo - Auther auth.Authenticator - TLSConfig *tls.Config - Logger logger.Logger + Bypass bypass.Bypass + Router *chain.Router + Auth *url.Userinfo + Auther auth.Authenticator + RateLimiter rate.RateLimiter + TLSConfig *tls.Config + Logger logger.Logger } type Option func(opts *Options) @@ -46,6 +48,12 @@ func AutherOption(auther auth.Authenticator) Option { } } +func RateLimiterOption(limiter rate.RateLimiter) Option { + return func(opts *Options) { + opts.RateLimiter = limiter + } +} + func TLSConfigOption(tlsConfig *tls.Config) Option { return func(opts *Options) { opts.TLSConfig = tlsConfig diff --git a/limiter/rate/limiter.go b/limiter/rate/limiter.go new file mode 100644 index 0000000..d92d7bb --- /dev/null +++ b/limiter/rate/limiter.go @@ -0,0 +1,10 @@ +package rate + +type Limiter interface { + Allow(n int) bool + Limit() float64 +} + +type RateLimiter interface { + Limiter(key string) Limiter +} diff --git a/metadata/util/util.go b/metadata/util/util.go index c45dd21..8772a44 100644 --- a/metadata/util/util.go +++ b/metadata/util/util.go @@ -49,6 +49,8 @@ func GetFloat(md metadata.Metadata, key string) (v float64) { } switch vv := md.Get(key).(type) { + case float64: + return vv case int: return float64(vv) case string: