diff --git a/config/config.go b/config/config.go index 1b1323b..79fc72c 100644 --- a/config/config.go +++ b/config/config.go @@ -184,6 +184,22 @@ type RecorderObject struct { Record string `json:"record"` } +type LimiterConfig struct { + Name string `json:"name"` + RateLimit *RateLimitConfig `yaml:"rate" json:"rate"` +} + +type RateLimitConfig struct { + Input string `yaml:",omitempty" json:"input,omitempty"` + Output string `yaml:",omitempty" json:"output,omitempty"` + Conn *LimitConfig `yaml:",omitempty" json:"conn,omitempty"` +} + +type LimitConfig struct { + Input string `yaml:",omitempty" json:"input,omitempty"` + Output string `yaml:",omitempty" json:"output,omitempty"` +} + type ListenerConfig struct { Type string `json:"type"` Chain string `yaml:",omitempty" json:"chain,omitempty"` @@ -247,6 +263,7 @@ type ServiceConfig struct { Handler *HandlerConfig `yaml:",omitempty" json:"handler,omitempty"` Listener *ListenerConfig `yaml:",omitempty" json:"listener,omitempty"` Forwarder *ForwarderConfig `yaml:",omitempty" json:"forwarder,omitempty"` + Limiter string `yaml:",omitempty" json:"limiter,omitempty"` Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"` } @@ -297,6 +314,7 @@ type Config struct { Resolvers []*ResolverConfig `yaml:",omitempty" json:"resolvers,omitempty"` Hosts []*HostsConfig `yaml:",omitempty" json:"hosts,omitempty"` Recorders []*RecorderConfig `yaml:",omitempty" json:"recorders,omitempty"` + Limiters []*LimiterConfig `yaml:",omitempty" json:"limiters,omitempty"` TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` Log *LogConfig `yaml:",omitempty" json:"log,omitempty"` Profiling *ProfilingConfig `yaml:",omitempty" json:"profiling,omitempty"` diff --git a/config/parsing/parse.go b/config/parsing/parse.go index e255b0e..084b69a 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -4,11 +4,13 @@ import ( "net" "net/url" + "github.com/alecthomas/units" "github.com/go-gost/core/admission" "github.com/go-gost/core/auth" "github.com/go-gost/core/bypass" "github.com/go-gost/core/chain" "github.com/go-gost/core/hosts" + "github.com/go-gost/core/limiter" "github.com/go-gost/core/logger" "github.com/go-gost/core/recorder" "github.com/go-gost/core/resolver" @@ -17,9 +19,10 @@ import ( auth_impl "github.com/go-gost/x/auth" bypass_impl "github.com/go-gost/x/bypass" "github.com/go-gost/x/config" - hosts_impl "github.com/go-gost/x/hosts" + xhosts "github.com/go-gost/x/hosts" "github.com/go-gost/x/internal/loader" - recorder_impl "github.com/go-gost/x/recorder" + xlimiter "github.com/go-gost/x/limiter" + xrecorder "github.com/go-gost/x/recorder" "github.com/go-gost/x/registry" resolver_impl "github.com/go-gost/x/resolver" xs "github.com/go-gost/x/selector" @@ -224,7 +227,7 @@ func ParseHosts(cfg *config.HostsConfig) hosts.HostMapper { return nil } - var mappings []hosts_impl.Mapping + var mappings []xhosts.Mapping for _, mapping := range cfg.Mappings { if mapping.IP == "" || mapping.Hostname == "" { continue @@ -234,33 +237,33 @@ func ParseHosts(cfg *config.HostsConfig) hosts.HostMapper { if ip == nil { continue } - mappings = append(mappings, hosts_impl.Mapping{ + mappings = append(mappings, xhosts.Mapping{ Hostname: mapping.Hostname, IP: ip, }) } - opts := []hosts_impl.Option{ - hosts_impl.MappingsOption(mappings), - hosts_impl.ReloadPeriodOption(cfg.Reload), - hosts_impl.LoggerOption(logger.Default().WithFields(map[string]any{ + opts := []xhosts.Option{ + xhosts.MappingsOption(mappings), + xhosts.ReloadPeriodOption(cfg.Reload), + xhosts.LoggerOption(logger.Default().WithFields(map[string]any{ "kind": "hosts", "hosts": cfg.Name, })), } if cfg.File != nil && cfg.File.Path != "" { - opts = append(opts, hosts_impl.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + opts = append(opts, xhosts.FileLoaderOption(loader.FileLoader(cfg.File.Path))) } if cfg.Redis != nil && cfg.Redis.Addr != "" { switch cfg.Redis.Type { case "list": // redis list - opts = append(opts, hosts_impl.RedisLoaderOption(loader.RedisListLoader( + opts = append(opts, xhosts.RedisLoaderOption(loader.RedisListLoader( cfg.Redis.Addr, loader.DBRedisLoaderOption(cfg.Redis.DB), loader.PasswordRedisLoaderOption(cfg.Redis.Password), loader.KeyRedisLoaderOption(cfg.Redis.Key), ))) default: // redis set - opts = append(opts, hosts_impl.RedisLoaderOption(loader.RedisSetLoader( + opts = append(opts, xhosts.RedisLoaderOption(loader.RedisSetLoader( cfg.Redis.Addr, loader.DBRedisLoaderOption(cfg.Redis.DB), loader.PasswordRedisLoaderOption(cfg.Redis.Password), @@ -268,7 +271,7 @@ func ParseHosts(cfg *config.HostsConfig) hosts.HostMapper { ))) } } - return hosts_impl.NewHostMapper(opts...) + return xhosts.NewHostMapper(opts...) } func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) { @@ -277,8 +280,8 @@ func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) { } if cfg.File != nil && cfg.File.Path != "" { - return recorder_impl.FileRecorder(cfg.File.Path, - recorder_impl.SepRecorderOption(cfg.File.Sep)) + return xrecorder.FileRecorder(cfg.File.Path, + xrecorder.SepRecorderOption(cfg.File.Sep)) } if cfg.Redis != nil && @@ -286,16 +289,16 @@ func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) { cfg.Redis.Key != "" { switch cfg.Redis.Type { case "list": // redis list - return recorder_impl.RedisListRecorder(cfg.Redis.Addr, - recorder_impl.DBRedisRecorderOption(cfg.Redis.DB), - recorder_impl.KeyRedisRecorderOption(cfg.Redis.Key), - recorder_impl.PasswordRedisRecorderOption(cfg.Redis.Password), + return xrecorder.RedisListRecorder(cfg.Redis.Addr, + xrecorder.DBRedisRecorderOption(cfg.Redis.DB), + xrecorder.KeyRedisRecorderOption(cfg.Redis.Key), + xrecorder.PasswordRedisRecorderOption(cfg.Redis.Password), ) default: // redis set - return recorder_impl.RedisSetRecorder(cfg.Redis.Addr, - recorder_impl.DBRedisRecorderOption(cfg.Redis.DB), - recorder_impl.KeyRedisRecorderOption(cfg.Redis.Key), - recorder_impl.PasswordRedisRecorderOption(cfg.Redis.Password), + return xrecorder.RedisSetRecorder(cfg.Redis.Addr, + xrecorder.DBRedisRecorderOption(cfg.Redis.DB), + xrecorder.KeyRedisRecorderOption(cfg.Redis.Key), + xrecorder.PasswordRedisRecorderOption(cfg.Redis.Password), ) } } @@ -318,3 +321,35 @@ func defaultChainSelector() selector.Selector[chain.Chainer] { xs.BackupFilter[chain.Chainer](), ) } + +func ParseRateLimiter(cfg *config.LimiterConfig) (lim limiter.RateLimiter) { + if cfg == nil || cfg.RateLimit == nil { + return nil + } + + var rlimiters []limiter.Limiter + var wlimiters []limiter.Limiter + if cfg.RateLimit.Conn != nil { + if v, _ := units.ParseBase2Bytes(cfg.RateLimit.Conn.Input); v > 0 { + rlimiters = append(rlimiters, xlimiter.Limiter(int(v))) + } + if v, _ := units.ParseBase2Bytes(cfg.RateLimit.Conn.Output); v > 0 { + wlimiters = append(wlimiters, xlimiter.Limiter(int(v))) + } + } + if v, _ := units.ParseBase2Bytes(cfg.RateLimit.Input); v > 0 { + rlimiters = append(rlimiters, xlimiter.Limiter(int(v))) + } + if v, _ := units.ParseBase2Bytes(cfg.RateLimit.Output); v > 0 { + wlimiters = append(wlimiters, xlimiter.Limiter(int(v))) + } + + var input, output limiter.Limiter + if len(rlimiters) > 0 { + input = xlimiter.MultiLimiter(rlimiters...) + } + if len(wlimiters) > 0 { + output = xlimiter.MultiLimiter(wlimiters...) + } + return xlimiter.RateLimiter(input, output) +} diff --git a/config/parsing/service.go b/config/parsing/service.go index 7c63d44..c8d27d3 100644 --- a/config/parsing/service.go +++ b/config/parsing/service.go @@ -75,6 +75,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { listener.TLSConfigOption(tlsConfig), listener.AdmissionOption(admission.AdmissionGroup(admissions...)), listener.ChainOption(chainGroup(cfg.Listener.Chain, cfg.Listener.ChainGroup)), + listener.RateLimiterOption(registry.RateLimiterRegistry().Get(cfg.Limiter)), listener.LoggerOption(listenerLogger), listener.ServiceOption(cfg.Name), ) diff --git a/limiter/rate.go b/limiter/rate.go new file mode 100644 index 0000000..120710b --- /dev/null +++ b/limiter/rate.go @@ -0,0 +1,83 @@ +package limiter + +import ( + "context" + + "github.com/go-gost/core/limiter" + "golang.org/x/time/rate" +) + +type llimiter struct { + limiter *rate.Limiter +} + +func Limiter(r int) limiter.Limiter { + return &llimiter{ + limiter: rate.NewLimiter(rate.Limit(r), r), + } +} + +func (l *llimiter) Limit(b int) int { + if l.limiter.Burst() < b { + b = l.limiter.Burst() + } + l.limiter.WaitN(context.Background(), b) + return b +} + +type Generator interface { + Generate() limiter.Limiter +} + +type limiterGenerator struct { + limit int +} + +func NewGenerator(r int) Generator { + return &limiterGenerator{limit: r} +} + +// Generate creates a new Limiter. +func (g *limiterGenerator) Generate() limiter.Limiter { + return Limiter(g.limit) +} + +type multiLimiter struct { + limiters []limiter.Limiter +} + +func MultiLimiter(limiters ...limiter.Limiter) limiter.Limiter { + return &multiLimiter{ + limiters: limiters, + } +} + +func (l *multiLimiter) Limit(b int) int { + for i := range l.limiters { + b = l.limiters[i].Limit(b) + } + return b +} + +type rateLimiter struct { + input limiter.Limiter + output limiter.Limiter +} + +func RateLimiter(input, output limiter.Limiter) limiter.RateLimiter { + if input == nil || output == nil { + return nil + } + return &rateLimiter{ + input: input, + output: output, + } +} + +func (l *rateLimiter) Input() limiter.Limiter { + return l.input +} + +func (l *rateLimiter) Output() limiter.Limiter { + return l.output +} diff --git a/listener/tcp/listener.go b/listener/tcp/listener.go index 079170d..366b6a7 100644 --- a/listener/tcp/listener.go +++ b/listener/tcp/listener.go @@ -3,6 +3,7 @@ package tcp import ( "net" + limiter "github.com/go-gost/core/limiter/wrapper" "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" @@ -47,7 +48,8 @@ func (l *tcpListener) Init(md md.Metadata) (err error) { return } - l.ln = metrics.WrapListener(l.options.Service, ln) + ln = metrics.WrapListener(l.options.Service, ln) + l.ln = limiter.WrapListener(l.options.RateLimiter, ln) return } diff --git a/listener/udp/listener.go b/listener/udp/listener.go index 1b37912..d5983f2 100644 --- a/listener/udp/listener.go +++ b/listener/udp/listener.go @@ -4,6 +4,7 @@ import ( "net" "github.com/go-gost/core/common/net/udp" + limiter "github.com/go-gost/core/limiter/wrapper" "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" @@ -54,6 +55,7 @@ func (l *udpListener) Init(md md.Metadata) (err error) { return } conn = metrics.WrapPacketConn(l.options.Service, conn) + conn = limiter.WrapPacketConn(l.options.RateLimiter, conn) l.ln = udp.NewListener(conn, &udp.ListenConfig{ Backlog: l.md.backlog, diff --git a/registry/limiter.go b/registry/limiter.go new file mode 100644 index 0000000..69a48d7 --- /dev/null +++ b/registry/limiter.go @@ -0,0 +1,48 @@ +package registry + +import ( + "github.com/go-gost/core/limiter" +) + +type rlimiterRegistry struct { + registry +} + +func (r *rlimiterRegistry) Register(name string, v limiter.RateLimiter) error { + return r.registry.Register(name, v) +} + +func (r *rlimiterRegistry) Get(name string) limiter.RateLimiter { + if name != "" { + return &rlimiterWrapper{name: name, r: r} + } + return nil +} + +func (r *rlimiterRegistry) get(name string) limiter.RateLimiter { + if v := r.registry.Get(name); v != nil { + return v.(limiter.RateLimiter) + } + return nil +} + +type rlimiterWrapper struct { + name string + r *rlimiterRegistry +} + +func (w *rlimiterWrapper) Input() limiter.Limiter { + v := w.r.get(w.name) + if v == nil { + return nil + } + return v.Input() +} + +func (w *rlimiterWrapper) Output() limiter.Limiter { + v := w.r.get(w.name) + if v == nil { + return nil + } + return v.Output() +} diff --git a/registry/registry.go b/registry/registry.go index f9aa5b4..ac163f0 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -10,6 +10,7 @@ import ( "github.com/go-gost/core/bypass" "github.com/go-gost/core/chain" "github.com/go-gost/core/hosts" + "github.com/go-gost/core/limiter" "github.com/go-gost/core/recorder" "github.com/go-gost/core/resolver" "github.com/go-gost/core/service" @@ -33,6 +34,7 @@ var ( resolverReg Registry[resolver.Resolver] = &resolverRegistry{} hostsReg Registry[hosts.HostMapper] = &hostsRegistry{} recorderReg Registry[recorder.Recorder] = &recorderRegistry{} + rlimiterReg Registry[limiter.RateLimiter] = &rlimiterRegistry{} ) type Registry[T any] interface { @@ -126,3 +128,7 @@ func HostsRegistry() Registry[hosts.HostMapper] { func RecorderRegistry() Registry[recorder.Recorder] { return recorderReg } + +func RateLimiterRegistry() Registry[limiter.RateLimiter] { + return rlimiterReg +}