add rate limiter
This commit is contained in:
parent
a04c8b45f3
commit
e23da0f319
@ -184,6 +184,22 @@ type RecorderObject struct {
|
||||
Record string `json:"record"`
|
||||
}
|
||||
|
||||
type LimiterConfig struct {
|
||||
Name string `json:"name"`
|
||||
RateLimit *RateLimitConfig `yaml:"rate" json:"rate"`
|
||||
}
|
||||
|
||||
type RateLimitConfig struct {
|
||||
Input string `yaml:",omitempty" json:"input,omitempty"`
|
||||
Output string `yaml:",omitempty" json:"output,omitempty"`
|
||||
Conn *LimitConfig `yaml:",omitempty" json:"conn,omitempty"`
|
||||
}
|
||||
|
||||
type LimitConfig struct {
|
||||
Input string `yaml:",omitempty" json:"input,omitempty"`
|
||||
Output string `yaml:",omitempty" json:"output,omitempty"`
|
||||
}
|
||||
|
||||
type ListenerConfig struct {
|
||||
Type string `json:"type"`
|
||||
Chain string `yaml:",omitempty" json:"chain,omitempty"`
|
||||
@ -247,6 +263,7 @@ type ServiceConfig struct {
|
||||
Handler *HandlerConfig `yaml:",omitempty" json:"handler,omitempty"`
|
||||
Listener *ListenerConfig `yaml:",omitempty" json:"listener,omitempty"`
|
||||
Forwarder *ForwarderConfig `yaml:",omitempty" json:"forwarder,omitempty"`
|
||||
Limiter string `yaml:",omitempty" json:"limiter,omitempty"`
|
||||
Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
@ -297,6 +314,7 @@ type Config struct {
|
||||
Resolvers []*ResolverConfig `yaml:",omitempty" json:"resolvers,omitempty"`
|
||||
Hosts []*HostsConfig `yaml:",omitempty" json:"hosts,omitempty"`
|
||||
Recorders []*RecorderConfig `yaml:",omitempty" json:"recorders,omitempty"`
|
||||
Limiters []*LimiterConfig `yaml:",omitempty" json:"limiters,omitempty"`
|
||||
TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"`
|
||||
Log *LogConfig `yaml:",omitempty" json:"log,omitempty"`
|
||||
Profiling *ProfilingConfig `yaml:",omitempty" json:"profiling,omitempty"`
|
||||
|
@ -4,11 +4,13 @@ import (
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
"github.com/alecthomas/units"
|
||||
"github.com/go-gost/core/admission"
|
||||
"github.com/go-gost/core/auth"
|
||||
"github.com/go-gost/core/bypass"
|
||||
"github.com/go-gost/core/chain"
|
||||
"github.com/go-gost/core/hosts"
|
||||
"github.com/go-gost/core/limiter"
|
||||
"github.com/go-gost/core/logger"
|
||||
"github.com/go-gost/core/recorder"
|
||||
"github.com/go-gost/core/resolver"
|
||||
@ -17,9 +19,10 @@ import (
|
||||
auth_impl "github.com/go-gost/x/auth"
|
||||
bypass_impl "github.com/go-gost/x/bypass"
|
||||
"github.com/go-gost/x/config"
|
||||
hosts_impl "github.com/go-gost/x/hosts"
|
||||
xhosts "github.com/go-gost/x/hosts"
|
||||
"github.com/go-gost/x/internal/loader"
|
||||
recorder_impl "github.com/go-gost/x/recorder"
|
||||
xlimiter "github.com/go-gost/x/limiter"
|
||||
xrecorder "github.com/go-gost/x/recorder"
|
||||
"github.com/go-gost/x/registry"
|
||||
resolver_impl "github.com/go-gost/x/resolver"
|
||||
xs "github.com/go-gost/x/selector"
|
||||
@ -224,7 +227,7 @@ func ParseHosts(cfg *config.HostsConfig) hosts.HostMapper {
|
||||
return nil
|
||||
}
|
||||
|
||||
var mappings []hosts_impl.Mapping
|
||||
var mappings []xhosts.Mapping
|
||||
for _, mapping := range cfg.Mappings {
|
||||
if mapping.IP == "" || mapping.Hostname == "" {
|
||||
continue
|
||||
@ -234,33 +237,33 @@ func ParseHosts(cfg *config.HostsConfig) hosts.HostMapper {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
mappings = append(mappings, hosts_impl.Mapping{
|
||||
mappings = append(mappings, xhosts.Mapping{
|
||||
Hostname: mapping.Hostname,
|
||||
IP: ip,
|
||||
})
|
||||
}
|
||||
opts := []hosts_impl.Option{
|
||||
hosts_impl.MappingsOption(mappings),
|
||||
hosts_impl.ReloadPeriodOption(cfg.Reload),
|
||||
hosts_impl.LoggerOption(logger.Default().WithFields(map[string]any{
|
||||
opts := []xhosts.Option{
|
||||
xhosts.MappingsOption(mappings),
|
||||
xhosts.ReloadPeriodOption(cfg.Reload),
|
||||
xhosts.LoggerOption(logger.Default().WithFields(map[string]any{
|
||||
"kind": "hosts",
|
||||
"hosts": cfg.Name,
|
||||
})),
|
||||
}
|
||||
if cfg.File != nil && cfg.File.Path != "" {
|
||||
opts = append(opts, hosts_impl.FileLoaderOption(loader.FileLoader(cfg.File.Path)))
|
||||
opts = append(opts, xhosts.FileLoaderOption(loader.FileLoader(cfg.File.Path)))
|
||||
}
|
||||
if cfg.Redis != nil && cfg.Redis.Addr != "" {
|
||||
switch cfg.Redis.Type {
|
||||
case "list": // redis list
|
||||
opts = append(opts, hosts_impl.RedisLoaderOption(loader.RedisListLoader(
|
||||
opts = append(opts, xhosts.RedisLoaderOption(loader.RedisListLoader(
|
||||
cfg.Redis.Addr,
|
||||
loader.DBRedisLoaderOption(cfg.Redis.DB),
|
||||
loader.PasswordRedisLoaderOption(cfg.Redis.Password),
|
||||
loader.KeyRedisLoaderOption(cfg.Redis.Key),
|
||||
)))
|
||||
default: // redis set
|
||||
opts = append(opts, hosts_impl.RedisLoaderOption(loader.RedisSetLoader(
|
||||
opts = append(opts, xhosts.RedisLoaderOption(loader.RedisSetLoader(
|
||||
cfg.Redis.Addr,
|
||||
loader.DBRedisLoaderOption(cfg.Redis.DB),
|
||||
loader.PasswordRedisLoaderOption(cfg.Redis.Password),
|
||||
@ -268,7 +271,7 @@ func ParseHosts(cfg *config.HostsConfig) hosts.HostMapper {
|
||||
)))
|
||||
}
|
||||
}
|
||||
return hosts_impl.NewHostMapper(opts...)
|
||||
return xhosts.NewHostMapper(opts...)
|
||||
}
|
||||
|
||||
func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) {
|
||||
@ -277,8 +280,8 @@ func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) {
|
||||
}
|
||||
|
||||
if cfg.File != nil && cfg.File.Path != "" {
|
||||
return recorder_impl.FileRecorder(cfg.File.Path,
|
||||
recorder_impl.SepRecorderOption(cfg.File.Sep))
|
||||
return xrecorder.FileRecorder(cfg.File.Path,
|
||||
xrecorder.SepRecorderOption(cfg.File.Sep))
|
||||
}
|
||||
|
||||
if cfg.Redis != nil &&
|
||||
@ -286,16 +289,16 @@ func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) {
|
||||
cfg.Redis.Key != "" {
|
||||
switch cfg.Redis.Type {
|
||||
case "list": // redis list
|
||||
return recorder_impl.RedisListRecorder(cfg.Redis.Addr,
|
||||
recorder_impl.DBRedisRecorderOption(cfg.Redis.DB),
|
||||
recorder_impl.KeyRedisRecorderOption(cfg.Redis.Key),
|
||||
recorder_impl.PasswordRedisRecorderOption(cfg.Redis.Password),
|
||||
return xrecorder.RedisListRecorder(cfg.Redis.Addr,
|
||||
xrecorder.DBRedisRecorderOption(cfg.Redis.DB),
|
||||
xrecorder.KeyRedisRecorderOption(cfg.Redis.Key),
|
||||
xrecorder.PasswordRedisRecorderOption(cfg.Redis.Password),
|
||||
)
|
||||
default: // redis set
|
||||
return recorder_impl.RedisSetRecorder(cfg.Redis.Addr,
|
||||
recorder_impl.DBRedisRecorderOption(cfg.Redis.DB),
|
||||
recorder_impl.KeyRedisRecorderOption(cfg.Redis.Key),
|
||||
recorder_impl.PasswordRedisRecorderOption(cfg.Redis.Password),
|
||||
return xrecorder.RedisSetRecorder(cfg.Redis.Addr,
|
||||
xrecorder.DBRedisRecorderOption(cfg.Redis.DB),
|
||||
xrecorder.KeyRedisRecorderOption(cfg.Redis.Key),
|
||||
xrecorder.PasswordRedisRecorderOption(cfg.Redis.Password),
|
||||
)
|
||||
}
|
||||
}
|
||||
@ -318,3 +321,35 @@ func defaultChainSelector() selector.Selector[chain.Chainer] {
|
||||
xs.BackupFilter[chain.Chainer](),
|
||||
)
|
||||
}
|
||||
|
||||
func ParseRateLimiter(cfg *config.LimiterConfig) (lim limiter.RateLimiter) {
|
||||
if cfg == nil || cfg.RateLimit == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var rlimiters []limiter.Limiter
|
||||
var wlimiters []limiter.Limiter
|
||||
if cfg.RateLimit.Conn != nil {
|
||||
if v, _ := units.ParseBase2Bytes(cfg.RateLimit.Conn.Input); v > 0 {
|
||||
rlimiters = append(rlimiters, xlimiter.Limiter(int(v)))
|
||||
}
|
||||
if v, _ := units.ParseBase2Bytes(cfg.RateLimit.Conn.Output); v > 0 {
|
||||
wlimiters = append(wlimiters, xlimiter.Limiter(int(v)))
|
||||
}
|
||||
}
|
||||
if v, _ := units.ParseBase2Bytes(cfg.RateLimit.Input); v > 0 {
|
||||
rlimiters = append(rlimiters, xlimiter.Limiter(int(v)))
|
||||
}
|
||||
if v, _ := units.ParseBase2Bytes(cfg.RateLimit.Output); v > 0 {
|
||||
wlimiters = append(wlimiters, xlimiter.Limiter(int(v)))
|
||||
}
|
||||
|
||||
var input, output limiter.Limiter
|
||||
if len(rlimiters) > 0 {
|
||||
input = xlimiter.MultiLimiter(rlimiters...)
|
||||
}
|
||||
if len(wlimiters) > 0 {
|
||||
output = xlimiter.MultiLimiter(wlimiters...)
|
||||
}
|
||||
return xlimiter.RateLimiter(input, output)
|
||||
}
|
||||
|
@ -75,6 +75,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
|
||||
listener.TLSConfigOption(tlsConfig),
|
||||
listener.AdmissionOption(admission.AdmissionGroup(admissions...)),
|
||||
listener.ChainOption(chainGroup(cfg.Listener.Chain, cfg.Listener.ChainGroup)),
|
||||
listener.RateLimiterOption(registry.RateLimiterRegistry().Get(cfg.Limiter)),
|
||||
listener.LoggerOption(listenerLogger),
|
||||
listener.ServiceOption(cfg.Name),
|
||||
)
|
||||
|
83
limiter/rate.go
Normal file
83
limiter/rate.go
Normal file
@ -0,0 +1,83 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-gost/core/limiter"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type llimiter struct {
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
func Limiter(r int) limiter.Limiter {
|
||||
return &llimiter{
|
||||
limiter: rate.NewLimiter(rate.Limit(r), r),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *llimiter) Limit(b int) int {
|
||||
if l.limiter.Burst() < b {
|
||||
b = l.limiter.Burst()
|
||||
}
|
||||
l.limiter.WaitN(context.Background(), b)
|
||||
return b
|
||||
}
|
||||
|
||||
type Generator interface {
|
||||
Generate() limiter.Limiter
|
||||
}
|
||||
|
||||
type limiterGenerator struct {
|
||||
limit int
|
||||
}
|
||||
|
||||
func NewGenerator(r int) Generator {
|
||||
return &limiterGenerator{limit: r}
|
||||
}
|
||||
|
||||
// Generate creates a new Limiter.
|
||||
func (g *limiterGenerator) Generate() limiter.Limiter {
|
||||
return Limiter(g.limit)
|
||||
}
|
||||
|
||||
type multiLimiter struct {
|
||||
limiters []limiter.Limiter
|
||||
}
|
||||
|
||||
func MultiLimiter(limiters ...limiter.Limiter) limiter.Limiter {
|
||||
return &multiLimiter{
|
||||
limiters: limiters,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *multiLimiter) Limit(b int) int {
|
||||
for i := range l.limiters {
|
||||
b = l.limiters[i].Limit(b)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
type rateLimiter struct {
|
||||
input limiter.Limiter
|
||||
output limiter.Limiter
|
||||
}
|
||||
|
||||
func RateLimiter(input, output limiter.Limiter) limiter.RateLimiter {
|
||||
if input == nil || output == nil {
|
||||
return nil
|
||||
}
|
||||
return &rateLimiter{
|
||||
input: input,
|
||||
output: output,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *rateLimiter) Input() limiter.Limiter {
|
||||
return l.input
|
||||
}
|
||||
|
||||
func (l *rateLimiter) Output() limiter.Limiter {
|
||||
return l.output
|
||||
}
|
@ -3,6 +3,7 @@ package tcp
|
||||
import (
|
||||
"net"
|
||||
|
||||
limiter "github.com/go-gost/core/limiter/wrapper"
|
||||
"github.com/go-gost/core/listener"
|
||||
"github.com/go-gost/core/logger"
|
||||
md "github.com/go-gost/core/metadata"
|
||||
@ -47,7 +48,8 @@ func (l *tcpListener) Init(md md.Metadata) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
l.ln = metrics.WrapListener(l.options.Service, ln)
|
||||
ln = metrics.WrapListener(l.options.Service, ln)
|
||||
l.ln = limiter.WrapListener(l.options.RateLimiter, ln)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"net"
|
||||
|
||||
"github.com/go-gost/core/common/net/udp"
|
||||
limiter "github.com/go-gost/core/limiter/wrapper"
|
||||
"github.com/go-gost/core/listener"
|
||||
"github.com/go-gost/core/logger"
|
||||
md "github.com/go-gost/core/metadata"
|
||||
@ -54,6 +55,7 @@ func (l *udpListener) Init(md md.Metadata) (err error) {
|
||||
return
|
||||
}
|
||||
conn = metrics.WrapPacketConn(l.options.Service, conn)
|
||||
conn = limiter.WrapPacketConn(l.options.RateLimiter, conn)
|
||||
|
||||
l.ln = udp.NewListener(conn, &udp.ListenConfig{
|
||||
Backlog: l.md.backlog,
|
||||
|
48
registry/limiter.go
Normal file
48
registry/limiter.go
Normal file
@ -0,0 +1,48 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"github.com/go-gost/core/limiter"
|
||||
)
|
||||
|
||||
type rlimiterRegistry struct {
|
||||
registry
|
||||
}
|
||||
|
||||
func (r *rlimiterRegistry) Register(name string, v limiter.RateLimiter) error {
|
||||
return r.registry.Register(name, v)
|
||||
}
|
||||
|
||||
func (r *rlimiterRegistry) Get(name string) limiter.RateLimiter {
|
||||
if name != "" {
|
||||
return &rlimiterWrapper{name: name, r: r}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *rlimiterRegistry) get(name string) limiter.RateLimiter {
|
||||
if v := r.registry.Get(name); v != nil {
|
||||
return v.(limiter.RateLimiter)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type rlimiterWrapper struct {
|
||||
name string
|
||||
r *rlimiterRegistry
|
||||
}
|
||||
|
||||
func (w *rlimiterWrapper) Input() limiter.Limiter {
|
||||
v := w.r.get(w.name)
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return v.Input()
|
||||
}
|
||||
|
||||
func (w *rlimiterWrapper) Output() limiter.Limiter {
|
||||
v := w.r.get(w.name)
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return v.Output()
|
||||
}
|
@ -10,6 +10,7 @@ import (
|
||||
"github.com/go-gost/core/bypass"
|
||||
"github.com/go-gost/core/chain"
|
||||
"github.com/go-gost/core/hosts"
|
||||
"github.com/go-gost/core/limiter"
|
||||
"github.com/go-gost/core/recorder"
|
||||
"github.com/go-gost/core/resolver"
|
||||
"github.com/go-gost/core/service"
|
||||
@ -33,6 +34,7 @@ var (
|
||||
resolverReg Registry[resolver.Resolver] = &resolverRegistry{}
|
||||
hostsReg Registry[hosts.HostMapper] = &hostsRegistry{}
|
||||
recorderReg Registry[recorder.Recorder] = &recorderRegistry{}
|
||||
rlimiterReg Registry[limiter.RateLimiter] = &rlimiterRegistry{}
|
||||
)
|
||||
|
||||
type Registry[T any] interface {
|
||||
@ -126,3 +128,7 @@ func HostsRegistry() Registry[hosts.HostMapper] {
|
||||
func RecorderRegistry() Registry[recorder.Recorder] {
|
||||
return recorderReg
|
||||
}
|
||||
|
||||
func RateLimiterRegistry() Registry[limiter.RateLimiter] {
|
||||
return rlimiterReg
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user