add rate limiter

This commit is contained in:
ginuerzh 2022-04-21 21:37:30 +08:00
parent a04c8b45f3
commit e23da0f319
8 changed files with 218 additions and 23 deletions

View File

@ -184,6 +184,22 @@ type RecorderObject struct {
Record string `json:"record"`
}
type LimiterConfig struct {
Name string `json:"name"`
RateLimit *RateLimitConfig `yaml:"rate" json:"rate"`
}
type RateLimitConfig struct {
Input string `yaml:",omitempty" json:"input,omitempty"`
Output string `yaml:",omitempty" json:"output,omitempty"`
Conn *LimitConfig `yaml:",omitempty" json:"conn,omitempty"`
}
type LimitConfig struct {
Input string `yaml:",omitempty" json:"input,omitempty"`
Output string `yaml:",omitempty" json:"output,omitempty"`
}
type ListenerConfig struct {
Type string `json:"type"`
Chain string `yaml:",omitempty" json:"chain,omitempty"`
@ -247,6 +263,7 @@ type ServiceConfig struct {
Handler *HandlerConfig `yaml:",omitempty" json:"handler,omitempty"`
Listener *ListenerConfig `yaml:",omitempty" json:"listener,omitempty"`
Forwarder *ForwarderConfig `yaml:",omitempty" json:"forwarder,omitempty"`
Limiter string `yaml:",omitempty" json:"limiter,omitempty"`
Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"`
}
@ -297,6 +314,7 @@ type Config struct {
Resolvers []*ResolverConfig `yaml:",omitempty" json:"resolvers,omitempty"`
Hosts []*HostsConfig `yaml:",omitempty" json:"hosts,omitempty"`
Recorders []*RecorderConfig `yaml:",omitempty" json:"recorders,omitempty"`
Limiters []*LimiterConfig `yaml:",omitempty" json:"limiters,omitempty"`
TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"`
Log *LogConfig `yaml:",omitempty" json:"log,omitempty"`
Profiling *ProfilingConfig `yaml:",omitempty" json:"profiling,omitempty"`

View File

@ -4,11 +4,13 @@ import (
"net"
"net/url"
"github.com/alecthomas/units"
"github.com/go-gost/core/admission"
"github.com/go-gost/core/auth"
"github.com/go-gost/core/bypass"
"github.com/go-gost/core/chain"
"github.com/go-gost/core/hosts"
"github.com/go-gost/core/limiter"
"github.com/go-gost/core/logger"
"github.com/go-gost/core/recorder"
"github.com/go-gost/core/resolver"
@ -17,9 +19,10 @@ import (
auth_impl "github.com/go-gost/x/auth"
bypass_impl "github.com/go-gost/x/bypass"
"github.com/go-gost/x/config"
hosts_impl "github.com/go-gost/x/hosts"
xhosts "github.com/go-gost/x/hosts"
"github.com/go-gost/x/internal/loader"
recorder_impl "github.com/go-gost/x/recorder"
xlimiter "github.com/go-gost/x/limiter"
xrecorder "github.com/go-gost/x/recorder"
"github.com/go-gost/x/registry"
resolver_impl "github.com/go-gost/x/resolver"
xs "github.com/go-gost/x/selector"
@ -224,7 +227,7 @@ func ParseHosts(cfg *config.HostsConfig) hosts.HostMapper {
return nil
}
var mappings []hosts_impl.Mapping
var mappings []xhosts.Mapping
for _, mapping := range cfg.Mappings {
if mapping.IP == "" || mapping.Hostname == "" {
continue
@ -234,33 +237,33 @@ func ParseHosts(cfg *config.HostsConfig) hosts.HostMapper {
if ip == nil {
continue
}
mappings = append(mappings, hosts_impl.Mapping{
mappings = append(mappings, xhosts.Mapping{
Hostname: mapping.Hostname,
IP: ip,
})
}
opts := []hosts_impl.Option{
hosts_impl.MappingsOption(mappings),
hosts_impl.ReloadPeriodOption(cfg.Reload),
hosts_impl.LoggerOption(logger.Default().WithFields(map[string]any{
opts := []xhosts.Option{
xhosts.MappingsOption(mappings),
xhosts.ReloadPeriodOption(cfg.Reload),
xhosts.LoggerOption(logger.Default().WithFields(map[string]any{
"kind": "hosts",
"hosts": cfg.Name,
})),
}
if cfg.File != nil && cfg.File.Path != "" {
opts = append(opts, hosts_impl.FileLoaderOption(loader.FileLoader(cfg.File.Path)))
opts = append(opts, xhosts.FileLoaderOption(loader.FileLoader(cfg.File.Path)))
}
if cfg.Redis != nil && cfg.Redis.Addr != "" {
switch cfg.Redis.Type {
case "list": // redis list
opts = append(opts, hosts_impl.RedisLoaderOption(loader.RedisListLoader(
opts = append(opts, xhosts.RedisLoaderOption(loader.RedisListLoader(
cfg.Redis.Addr,
loader.DBRedisLoaderOption(cfg.Redis.DB),
loader.PasswordRedisLoaderOption(cfg.Redis.Password),
loader.KeyRedisLoaderOption(cfg.Redis.Key),
)))
default: // redis set
opts = append(opts, hosts_impl.RedisLoaderOption(loader.RedisSetLoader(
opts = append(opts, xhosts.RedisLoaderOption(loader.RedisSetLoader(
cfg.Redis.Addr,
loader.DBRedisLoaderOption(cfg.Redis.DB),
loader.PasswordRedisLoaderOption(cfg.Redis.Password),
@ -268,7 +271,7 @@ func ParseHosts(cfg *config.HostsConfig) hosts.HostMapper {
)))
}
}
return hosts_impl.NewHostMapper(opts...)
return xhosts.NewHostMapper(opts...)
}
func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) {
@ -277,8 +280,8 @@ func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) {
}
if cfg.File != nil && cfg.File.Path != "" {
return recorder_impl.FileRecorder(cfg.File.Path,
recorder_impl.SepRecorderOption(cfg.File.Sep))
return xrecorder.FileRecorder(cfg.File.Path,
xrecorder.SepRecorderOption(cfg.File.Sep))
}
if cfg.Redis != nil &&
@ -286,16 +289,16 @@ func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) {
cfg.Redis.Key != "" {
switch cfg.Redis.Type {
case "list": // redis list
return recorder_impl.RedisListRecorder(cfg.Redis.Addr,
recorder_impl.DBRedisRecorderOption(cfg.Redis.DB),
recorder_impl.KeyRedisRecorderOption(cfg.Redis.Key),
recorder_impl.PasswordRedisRecorderOption(cfg.Redis.Password),
return xrecorder.RedisListRecorder(cfg.Redis.Addr,
xrecorder.DBRedisRecorderOption(cfg.Redis.DB),
xrecorder.KeyRedisRecorderOption(cfg.Redis.Key),
xrecorder.PasswordRedisRecorderOption(cfg.Redis.Password),
)
default: // redis set
return recorder_impl.RedisSetRecorder(cfg.Redis.Addr,
recorder_impl.DBRedisRecorderOption(cfg.Redis.DB),
recorder_impl.KeyRedisRecorderOption(cfg.Redis.Key),
recorder_impl.PasswordRedisRecorderOption(cfg.Redis.Password),
return xrecorder.RedisSetRecorder(cfg.Redis.Addr,
xrecorder.DBRedisRecorderOption(cfg.Redis.DB),
xrecorder.KeyRedisRecorderOption(cfg.Redis.Key),
xrecorder.PasswordRedisRecorderOption(cfg.Redis.Password),
)
}
}
@ -318,3 +321,35 @@ func defaultChainSelector() selector.Selector[chain.Chainer] {
xs.BackupFilter[chain.Chainer](),
)
}
func ParseRateLimiter(cfg *config.LimiterConfig) (lim limiter.RateLimiter) {
if cfg == nil || cfg.RateLimit == nil {
return nil
}
var rlimiters []limiter.Limiter
var wlimiters []limiter.Limiter
if cfg.RateLimit.Conn != nil {
if v, _ := units.ParseBase2Bytes(cfg.RateLimit.Conn.Input); v > 0 {
rlimiters = append(rlimiters, xlimiter.Limiter(int(v)))
}
if v, _ := units.ParseBase2Bytes(cfg.RateLimit.Conn.Output); v > 0 {
wlimiters = append(wlimiters, xlimiter.Limiter(int(v)))
}
}
if v, _ := units.ParseBase2Bytes(cfg.RateLimit.Input); v > 0 {
rlimiters = append(rlimiters, xlimiter.Limiter(int(v)))
}
if v, _ := units.ParseBase2Bytes(cfg.RateLimit.Output); v > 0 {
wlimiters = append(wlimiters, xlimiter.Limiter(int(v)))
}
var input, output limiter.Limiter
if len(rlimiters) > 0 {
input = xlimiter.MultiLimiter(rlimiters...)
}
if len(wlimiters) > 0 {
output = xlimiter.MultiLimiter(wlimiters...)
}
return xlimiter.RateLimiter(input, output)
}

View File

@ -75,6 +75,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
listener.TLSConfigOption(tlsConfig),
listener.AdmissionOption(admission.AdmissionGroup(admissions...)),
listener.ChainOption(chainGroup(cfg.Listener.Chain, cfg.Listener.ChainGroup)),
listener.RateLimiterOption(registry.RateLimiterRegistry().Get(cfg.Limiter)),
listener.LoggerOption(listenerLogger),
listener.ServiceOption(cfg.Name),
)

83
limiter/rate.go Normal file
View File

@ -0,0 +1,83 @@
package limiter
import (
"context"
"github.com/go-gost/core/limiter"
"golang.org/x/time/rate"
)
type llimiter struct {
limiter *rate.Limiter
}
func Limiter(r int) limiter.Limiter {
return &llimiter{
limiter: rate.NewLimiter(rate.Limit(r), r),
}
}
func (l *llimiter) Limit(b int) int {
if l.limiter.Burst() < b {
b = l.limiter.Burst()
}
l.limiter.WaitN(context.Background(), b)
return b
}
type Generator interface {
Generate() limiter.Limiter
}
type limiterGenerator struct {
limit int
}
func NewGenerator(r int) Generator {
return &limiterGenerator{limit: r}
}
// Generate creates a new Limiter.
func (g *limiterGenerator) Generate() limiter.Limiter {
return Limiter(g.limit)
}
type multiLimiter struct {
limiters []limiter.Limiter
}
func MultiLimiter(limiters ...limiter.Limiter) limiter.Limiter {
return &multiLimiter{
limiters: limiters,
}
}
func (l *multiLimiter) Limit(b int) int {
for i := range l.limiters {
b = l.limiters[i].Limit(b)
}
return b
}
type rateLimiter struct {
input limiter.Limiter
output limiter.Limiter
}
func RateLimiter(input, output limiter.Limiter) limiter.RateLimiter {
if input == nil || output == nil {
return nil
}
return &rateLimiter{
input: input,
output: output,
}
}
func (l *rateLimiter) Input() limiter.Limiter {
return l.input
}
func (l *rateLimiter) Output() limiter.Limiter {
return l.output
}

View File

@ -3,6 +3,7 @@ package tcp
import (
"net"
limiter "github.com/go-gost/core/limiter/wrapper"
"github.com/go-gost/core/listener"
"github.com/go-gost/core/logger"
md "github.com/go-gost/core/metadata"
@ -47,7 +48,8 @@ func (l *tcpListener) Init(md md.Metadata) (err error) {
return
}
l.ln = metrics.WrapListener(l.options.Service, ln)
ln = metrics.WrapListener(l.options.Service, ln)
l.ln = limiter.WrapListener(l.options.RateLimiter, ln)
return
}

View File

@ -4,6 +4,7 @@ import (
"net"
"github.com/go-gost/core/common/net/udp"
limiter "github.com/go-gost/core/limiter/wrapper"
"github.com/go-gost/core/listener"
"github.com/go-gost/core/logger"
md "github.com/go-gost/core/metadata"
@ -54,6 +55,7 @@ func (l *udpListener) Init(md md.Metadata) (err error) {
return
}
conn = metrics.WrapPacketConn(l.options.Service, conn)
conn = limiter.WrapPacketConn(l.options.RateLimiter, conn)
l.ln = udp.NewListener(conn, &udp.ListenConfig{
Backlog: l.md.backlog,

48
registry/limiter.go Normal file
View File

@ -0,0 +1,48 @@
package registry
import (
"github.com/go-gost/core/limiter"
)
type rlimiterRegistry struct {
registry
}
func (r *rlimiterRegistry) Register(name string, v limiter.RateLimiter) error {
return r.registry.Register(name, v)
}
func (r *rlimiterRegistry) Get(name string) limiter.RateLimiter {
if name != "" {
return &rlimiterWrapper{name: name, r: r}
}
return nil
}
func (r *rlimiterRegistry) get(name string) limiter.RateLimiter {
if v := r.registry.Get(name); v != nil {
return v.(limiter.RateLimiter)
}
return nil
}
type rlimiterWrapper struct {
name string
r *rlimiterRegistry
}
func (w *rlimiterWrapper) Input() limiter.Limiter {
v := w.r.get(w.name)
if v == nil {
return nil
}
return v.Input()
}
func (w *rlimiterWrapper) Output() limiter.Limiter {
v := w.r.get(w.name)
if v == nil {
return nil
}
return v.Output()
}

View File

@ -10,6 +10,7 @@ import (
"github.com/go-gost/core/bypass"
"github.com/go-gost/core/chain"
"github.com/go-gost/core/hosts"
"github.com/go-gost/core/limiter"
"github.com/go-gost/core/recorder"
"github.com/go-gost/core/resolver"
"github.com/go-gost/core/service"
@ -33,6 +34,7 @@ var (
resolverReg Registry[resolver.Resolver] = &resolverRegistry{}
hostsReg Registry[hosts.HostMapper] = &hostsRegistry{}
recorderReg Registry[recorder.Recorder] = &recorderRegistry{}
rlimiterReg Registry[limiter.RateLimiter] = &rlimiterRegistry{}
)
type Registry[T any] interface {
@ -126,3 +128,7 @@ func HostsRegistry() Registry[hosts.HostMapper] {
func RecorderRegistry() Registry[recorder.Recorder] {
return recorderReg
}
func RateLimiterRegistry() Registry[limiter.RateLimiter] {
return rlimiterReg
}