From 01d7dc77c67a0268e3adb1511903c574cda1ba2e Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 14 Sep 2022 20:00:35 +0800 Subject: [PATCH] update limiter --- api/config_conn_limiter.go | 2 +- api/config_limiter.go | 2 +- api/config_rate_limiter.go | 166 ++++++++++++++ api/service.go | 4 + api/swagger.yaml | 200 +++++++++++++++-- config/config.go | 4 +- config/parsing/chain.go | 34 ++- config/parsing/parse.go | 42 ++++ config/parsing/service.go | 54 +++-- go.mod | 2 +- go.sum | 4 +- handler/dns/handler.go | 16 ++ handler/forward/local/handler.go | 16 ++ handler/forward/remote/handler.go | 16 ++ handler/http/handler.go | 16 ++ handler/http2/handler.go | 16 ++ handler/redirect/tcp/handler.go | 16 ++ handler/redirect/udp/handler.go | 16 ++ handler/relay/handler.go | 16 ++ handler/sni/handler.go | 16 ++ handler/socks/v4/handler.go | 16 ++ handler/socks/v5/handler.go | 16 ++ handler/ss/handler.go | 16 ++ handler/ss/udp/handler.go | 16 ++ handler/sshd/handler.go | 16 ++ limiter/conn/conn.go | 2 +- limiter/conn/limiter.go | 2 +- limiter/rate/generator.go | 44 ++++ limiter/rate/limiter.go | 26 +++ limiter/rate/rate.go | 353 ++++++++++++++++++++++++++++++ limiter/traffic/traffic.go | 2 +- limiter/traffic/wrapper/conn.go | 24 +- registry/limiter.go | 36 +++ registry/registry.go | 23 +- 34 files changed, 1171 insertions(+), 79 deletions(-) create mode 100644 api/config_rate_limiter.go create mode 100644 limiter/rate/generator.go create mode 100644 limiter/rate/limiter.go create mode 100644 limiter/rate/rate.go diff --git a/api/config_conn_limiter.go b/api/config_conn_limiter.go index bbfd931..943bd19 100644 --- a/api/config_conn_limiter.go +++ b/api/config_conn_limiter.go @@ -120,7 +120,7 @@ func updateConnLimiter(ctx *gin.Context) { type deleteConnLimiterRequest struct { // in: path // required: true - Limiter string `uri:"Limiter" json:"Limiter"` + Limiter string `uri:"limiter" json:"limiter"` } // successful operation. diff --git a/api/config_limiter.go b/api/config_limiter.go index 3c94e02..b7fed3a 100644 --- a/api/config_limiter.go +++ b/api/config_limiter.go @@ -120,7 +120,7 @@ func updateLimiter(ctx *gin.Context) { type deleteLimiterRequest struct { // in: path // required: true - Limiter string `uri:"Limiter" json:"Limiter"` + Limiter string `uri:"limiter" json:"limiter"` } // successful operation. diff --git a/api/config_rate_limiter.go b/api/config_rate_limiter.go new file mode 100644 index 0000000..93cbe5b --- /dev/null +++ b/api/config_rate_limiter.go @@ -0,0 +1,166 @@ +package api + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/go-gost/x/config" + "github.com/go-gost/x/config/parsing" + "github.com/go-gost/x/registry" +) + +// swagger:parameters createRateLimiterRequest +type createRateLimiterRequest struct { + // in: body + Data config.LimiterConfig `json:"data"` +} + +// successful operation. +// swagger:response createRateLimiterResponse +type createRateLimiterResponse struct { + Data Response +} + +func createRateLimiter(ctx *gin.Context) { + // swagger:route POST /config/rlimiters Limiter createRateLimiterRequest + // + // Create a new rate limiter, the name of limiter must be unique in limiter list. + // + // Security: + // basicAuth: [] + // + // Responses: + // 200: createRateLimiterResponse + + var req createRateLimiterRequest + ctx.ShouldBindJSON(&req.Data) + + if req.Data.Name == "" { + writeError(ctx, ErrInvalid) + return + } + + v := parsing.ParseRateLimiter(&req.Data) + + if err := registry.RateLimiterRegistry().Register(req.Data.Name, v); err != nil { + writeError(ctx, ErrDup) + return + } + + cfg := config.Global() + cfg.CLimiters = append(cfg.CLimiters, &req.Data) + config.SetGlobal(cfg) + + ctx.JSON(http.StatusOK, Response{ + Msg: "OK", + }) +} + +// swagger:parameters updateRateLimiterRequest +type updateRateLimiterRequest struct { + // in: path + // required: true + Limiter string `uri:"limiter" json:"limiter"` + // in: body + Data config.LimiterConfig `json:"data"` +} + +// successful operation. +// swagger:response updateRateLimiterResponse +type updateRateLimiterResponse struct { + Data Response +} + +func updateRateLimiter(ctx *gin.Context) { + // swagger:route PUT /config/rlimiters/{limiter} Limiter updateRateLimiterRequest + // + // Update rate limiter by name, the limiter must already exist. + // + // Security: + // basicAuth: [] + // + // Responses: + // 200: updateRateLimiterResponse + + var req updateRateLimiterRequest + ctx.ShouldBindUri(&req) + ctx.ShouldBindJSON(&req.Data) + + if !registry.RateLimiterRegistry().IsRegistered(req.Limiter) { + writeError(ctx, ErrNotFound) + return + } + + req.Data.Name = req.Limiter + + v := parsing.ParseRateLimiter(&req.Data) + + registry.RateLimiterRegistry().Unregister(req.Limiter) + + if err := registry.RateLimiterRegistry().Register(req.Limiter, v); err != nil { + writeError(ctx, ErrDup) + return + } + + cfg := config.Global() + for i := range cfg.Limiters { + if cfg.Limiters[i].Name == req.Limiter { + cfg.Limiters[i] = &req.Data + break + } + } + config.SetGlobal(cfg) + + ctx.JSON(http.StatusOK, Response{ + Msg: "OK", + }) +} + +// swagger:parameters deleteRateLimiterRequest +type deleteRateLimiterRequest struct { + // in: path + // required: true + Limiter string `uri:"limiter" json:"limiter"` +} + +// successful operation. +// swagger:response deleteRateLimiterResponse +type deleteRateLimiterResponse struct { + Data Response +} + +func deleteRateLimiter(ctx *gin.Context) { + // swagger:route DELETE /config/rlimiters/{limiter} Limiter deleteRateLimiterRequest + // + // Delete rate limiter by name. + // + // Security: + // basicAuth: [] + // + // Responses: + // 200: deleteRateLimiterResponse + + var req deleteRateLimiterRequest + ctx.ShouldBindUri(&req) + + if !registry.RateLimiterRegistry().IsRegistered(req.Limiter) { + writeError(ctx, ErrNotFound) + return + } + registry.RateLimiterRegistry().Unregister(req.Limiter) + + cfg := config.Global() + limiteres := cfg.Limiters + cfg.Limiters = nil + for _, s := range limiteres { + if s.Name == req.Limiter { + continue + } + cfg.Limiters = append(cfg.Limiters, s) + } + config.SetGlobal(cfg) + + ctx.JSON(http.StatusOK, Response{ + Msg: "OK", + }) +} diff --git a/api/service.go b/api/service.go index d25bb8a..81ad192 100644 --- a/api/service.go +++ b/api/service.go @@ -137,4 +137,8 @@ func registerConfig(config *gin.RouterGroup) { config.POST("/climiters", createConnLimiter) config.PUT("/climiters/:limiter", updateConnLimiter) config.DELETE("/climiters/:limiter", deleteConnLimiter) + + config.POST("/rlimiters", createRateLimiter) + config.PUT("/rlimiters/:limiter", updateRateLimiter) + config.DELETE("/rlimiters/:limiter", deleteRateLimiter) } diff --git a/api/swagger.yaml b/api/swagger.yaml index 03968ef..dfcaab9 100644 --- a/api/swagger.yaml +++ b/api/swagger.yaml @@ -151,6 +151,11 @@ definitions: $ref: '#/definitions/ChainConfig' type: array x-go-name: Chains + climiters: + items: + $ref: '#/definitions/LimiterConfig' + type: array + x-go-name: CLimiters hosts: items: $ref: '#/definitions/HostsConfig' @@ -177,6 +182,11 @@ definitions: $ref: '#/definitions/ResolverConfig' type: array x-go-name: Resolvers + rlimiters: + items: + $ref: '#/definitions/LimiterConfig' + type: array + x-go-name: RLimiters services: items: $ref: '#/definitions/ServiceConfig' @@ -358,11 +368,20 @@ definitions: x-go-package: github.com/go-gost/x/config LimiterConfig: properties: + file: + $ref: '#/definitions/FileLoader' + limits: + items: + type: string + type: array + x-go-name: Limits name: type: string x-go-name: Name - rate: - $ref: '#/definitions/RateLimiterConfig' + redis: + $ref: '#/definitions/RedisLoader' + reload: + $ref: '#/definitions/Duration' type: object x-go-package: github.com/go-gost/x/config ListenerConfig: @@ -483,21 +502,6 @@ definitions: x-go-name: Addr type: object x-go-package: github.com/go-gost/x/config - RateLimiterConfig: - properties: - file: - $ref: '#/definitions/FileLoader' - limits: - items: - type: string - type: array - x-go-name: Limits - redis: - $ref: '#/definitions/RedisLoader' - reload: - $ref: '#/definitions/Duration' - type: object - x-go-package: github.com/go-gost/x/config RecorderConfig: properties: file: @@ -616,6 +620,9 @@ definitions: type: string type: array x-go-name: Bypasses + climiter: + type: string + x-go-name: CLimiter forwarder: $ref: '#/definitions/ForwarderConfig' handler: @@ -624,6 +631,7 @@ definitions: type: string x-go-name: Hosts interface: + description: DEPRECATED by metadata.interface since beta.5 type: string x-go-name: Interface limiter: @@ -646,6 +654,9 @@ definitions: resolver: type: string x-go-name: Resolver + rlimiter: + type: string + x-go-name: RLimiter sockopts: $ref: '#/definitions/SockOptsConfig' type: object @@ -956,6 +967,64 @@ paths: summary: Update chain by name, the chain must already exist. tags: - Chain + /config/climiters: + post: + operationId: createConnLimiterRequest + parameters: + - in: body + name: data + schema: + $ref: '#/definitions/LimiterConfig' + x-go-name: Data + responses: + "200": + $ref: '#/responses/createConnLimiterResponse' + security: + - basicAuth: + - '[]' + summary: Create a new conn limiter, the name of limiter must be unique in limiter list. + tags: + - Limiter + /config/climiters/{limiter}: + delete: + operationId: deleteConnLimiterRequest + parameters: + - in: path + name: limiter + required: true + type: string + x-go-name: Limiter + responses: + "200": + $ref: '#/responses/deleteConnLimiterResponse' + security: + - basicAuth: + - '[]' + summary: Delete conn limiter by name. + tags: + - Limiter + put: + operationId: updateConnLimiterRequest + parameters: + - in: path + name: limiter + required: true + type: string + x-go-name: Limiter + - in: body + name: data + schema: + $ref: '#/definitions/LimiterConfig' + x-go-name: Data + responses: + "200": + $ref: '#/responses/updateConnLimiterResponse' + security: + - basicAuth: + - '[]' + summary: Update conn limiter by name, the limiter must already exist. + tags: + - Limiter /config/hosts: post: operationId: createHostsRequest @@ -1037,9 +1106,10 @@ paths: operationId: deleteLimiterRequest parameters: - in: path - name: Limiter + name: limiter required: true type: string + x-go-name: Limiter responses: "200": $ref: '#/responses/deleteLimiterResponse' @@ -1129,6 +1199,64 @@ paths: summary: Update resolver by name, the resolver must already exist. tags: - Resolver + /config/rlimiters: + post: + operationId: createRateLimiterRequest + parameters: + - in: body + name: data + schema: + $ref: '#/definitions/LimiterConfig' + x-go-name: Data + responses: + "200": + $ref: '#/responses/createRateLimiterResponse' + security: + - basicAuth: + - '[]' + summary: Create a new rate limiter, the name of limiter must be unique in limiter list. + tags: + - Limiter + /config/rlimiters/{limiter}: + delete: + operationId: deleteRateLimiterRequest + parameters: + - in: path + name: limiter + required: true + type: string + x-go-name: Limiter + responses: + "200": + $ref: '#/responses/deleteRateLimiterResponse' + security: + - basicAuth: + - '[]' + summary: Delete rate limiter by name. + tags: + - Limiter + put: + operationId: updateRateLimiterRequest + parameters: + - in: path + name: limiter + required: true + type: string + x-go-name: Limiter + - in: body + name: data + schema: + $ref: '#/definitions/LimiterConfig' + x-go-name: Data + responses: + "200": + $ref: '#/responses/updateRateLimiterResponse' + security: + - basicAuth: + - '[]' + summary: Update rate limiter by name, the limiter must already exist. + tags: + - Limiter /config/services: post: operationId: createServiceRequest @@ -1214,6 +1342,12 @@ responses: Data: {} schema: $ref: '#/definitions/Response' + createConnLimiterResponse: + description: successful operation. + headers: + Data: {} + schema: + $ref: '#/definitions/Response' createHostsResponse: description: successful operation. headers: @@ -1226,6 +1360,12 @@ responses: Data: {} schema: $ref: '#/definitions/Response' + createRateLimiterResponse: + description: successful operation. + headers: + Data: {} + schema: + $ref: '#/definitions/Response' createResolverResponse: description: successful operation. headers: @@ -1262,6 +1402,12 @@ responses: Data: {} schema: $ref: '#/definitions/Response' + deleteConnLimiterResponse: + description: successful operation. + headers: + Data: {} + schema: + $ref: '#/definitions/Response' deleteHostsResponse: description: successful operation. headers: @@ -1274,6 +1420,12 @@ responses: Data: {} schema: $ref: '#/definitions/Response' + deleteRateLimiterResponse: + description: successful operation. + headers: + Data: {} + schema: + $ref: '#/definitions/Response' deleteResolverResponse: description: successful operation. headers: @@ -1322,6 +1474,12 @@ responses: Data: {} schema: $ref: '#/definitions/Response' + updateConnLimiterResponse: + description: successful operation. + headers: + Data: {} + schema: + $ref: '#/definitions/Response' updateHostsResponse: description: successful operation. headers: @@ -1334,6 +1492,12 @@ responses: Data: {} schema: $ref: '#/definitions/Response' + updateRateLimiterResponse: + description: successful operation. + headers: + Data: {} + schema: + $ref: '#/definitions/Response' updateResolverResponse: description: successful operation. headers: diff --git a/config/config.go b/config/config.go index da594f2..451d2b2 100644 --- a/config/config.go +++ b/config/config.go @@ -254,7 +254,8 @@ type ServiceConfig struct { Resolver string `yaml:",omitempty" json:"resolver,omitempty"` Hosts string `yaml:",omitempty" json:"hosts,omitempty"` Limiter string `yaml:",omitempty" json:"limiter,omitempty"` - CLimiter string `yaml:"climiter,omitempty" json:"limiter,omitempty"` + CLimiter string `yaml:"climiter,omitempty" json:"climiter,omitempty"` + RLimiter string `yaml:"rlimiter,omitempty" json:"rlimiter,omitempty"` Recorders []*RecorderObject `yaml:",omitempty" json:"recorders,omitempty"` Handler *HandlerConfig `yaml:",omitempty" json:"handler,omitempty"` Listener *ListenerConfig `yaml:",omitempty" json:"listener,omitempty"` @@ -311,6 +312,7 @@ type Config struct { Recorders []*RecorderConfig `yaml:",omitempty" json:"recorders,omitempty"` Limiters []*LimiterConfig `yaml:",omitempty" json:"limiters,omitempty"` CLimiters []*LimiterConfig `yaml:"climiters,omitempty" json:"climiters,omitempty"` + RLimiters []*LimiterConfig `yaml:"rlimiters,omitempty" json:"rlimiters,omitempty"` TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` Log *LogConfig `yaml:",omitempty" json:"log,omitempty"` Profiling *ProfilingConfig `yaml:",omitempty" json:"profiling,omitempty"` diff --git a/config/parsing/chain.go b/config/parsing/chain.go index 9738834..240735f 100644 --- a/config/parsing/chain.go +++ b/config/parsing/chain.go @@ -1,6 +1,7 @@ package parsing import ( + "fmt" "time" "github.com/go-gost/core/bypass" @@ -63,11 +64,16 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { nm = mdx.NewMetadata(v.Metadata) } - cr := registry.ConnectorRegistry().Get(v.Connector.Type)( - connector.AuthOption(parseAuth(v.Connector.Auth)), - connector.TLSConfigOption(tlsConfig), - connector.LoggerOption(connectorLogger), - ) + var cr connector.Connector + if rf := registry.ConnectorRegistry().Get(v.Connector.Type); rf != nil { + cr = rf( + connector.AuthOption(parseAuth(v.Connector.Auth)), + connector.TLSConfigOption(tlsConfig), + connector.LoggerOption(connectorLogger), + ) + } else { + return nil, fmt.Errorf("unregistered connector: %s", v.Connector.Type) + } if v.Connector.Metadata == nil { v.Connector.Metadata = make(map[string]any) @@ -97,12 +103,18 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { if nm != nil { ppv = mdutil.GetInt(nm, mdKeyProxyProtocol) } - d := registry.DialerRegistry().Get(v.Dialer.Type)( - dialer.AuthOption(parseAuth(v.Dialer.Auth)), - dialer.TLSConfigOption(tlsConfig), - dialer.LoggerOption(dialerLogger), - dialer.ProxyProtocolOption(ppv), - ) + + var d dialer.Dialer + if rf := registry.DialerRegistry().Get(v.Dialer.Type); rf != nil { + d = rf( + dialer.AuthOption(parseAuth(v.Dialer.Auth)), + dialer.TLSConfigOption(tlsConfig), + dialer.LoggerOption(dialerLogger), + dialer.ProxyProtocolOption(ppv), + ) + } else { + return nil, fmt.Errorf("unregistered dialer: %s", v.Dialer.Type) + } if v.Dialer.Metadata == nil { v.Dialer.Metadata = make(map[string]any) diff --git a/config/parsing/parse.go b/config/parsing/parse.go index 9343d7c..9026256 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -10,6 +10,7 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/hosts" "github.com/go-gost/core/limiter/conn" + "github.com/go-gost/core/limiter/rate" "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" "github.com/go-gost/core/recorder" @@ -22,6 +23,7 @@ import ( xhosts "github.com/go-gost/x/hosts" "github.com/go-gost/x/internal/loader" xconn "github.com/go-gost/x/limiter/conn" + xrate "github.com/go-gost/x/limiter/rate" xtraffic "github.com/go-gost/x/limiter/traffic" xrecorder "github.com/go-gost/x/recorder" "github.com/go-gost/x/registry" @@ -408,3 +410,43 @@ func ParseConnLimiter(cfg *config.LimiterConfig) (lim conn.ConnLimiter) { return xconn.NewConnLimiter(opts...) } + +func ParseRateLimiter(cfg *config.LimiterConfig) (lim rate.RateLimiter) { + if cfg == nil { + return nil + } + + var opts []xrate.Option + + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, xrate.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + switch cfg.Redis.Type { + case "list": // redis list + opts = append(opts, xrate.RedisLoaderOption(loader.RedisListLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + default: // redis set + opts = append(opts, xrate.RedisLoaderOption(loader.RedisSetLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + } + opts = append(opts, + xrate.LimitsOption(cfg.Limits...), + xrate.ReloadPeriodOption(cfg.Reload), + xrate.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "limiter", + "limiter": cfg.Name, + })), + ) + + return xrate.NewRateLimiter(opts...) +} diff --git a/config/parsing/service.go b/config/parsing/service.go index a1baba1..3eb7b4f 100644 --- a/config/parsing/service.go +++ b/config/parsing/service.go @@ -1,6 +1,7 @@ package parsing import ( + "fmt" "strings" "github.com/go-gost/core/admission" @@ -91,19 +92,24 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { } } - ln := registry.ListenerRegistry().Get(cfg.Listener.Type)( - listener.AddrOption(cfg.Addr), - listener.AutherOption(auther), - listener.AuthOption(parseAuth(cfg.Listener.Auth)), - listener.TLSConfigOption(tlsConfig), - listener.AdmissionOption(admission.AdmissionGroup(admissions...)), - listener.ChainOption(chainGroup(cfg.Listener.Chain, cfg.Listener.ChainGroup)), - listener.TrafficLimiterOption(registry.TrafficLimiterRegistry().Get(cfg.Limiter)), - listener.ConnLimiterOption(registry.ConnLimiterRegistry().Get(cfg.CLimiter)), - listener.LoggerOption(listenerLogger), - listener.ServiceOption(cfg.Name), - listener.ProxyProtocolOption(ppv), - ) + var ln listener.Listener + if rf := registry.ListenerRegistry().Get(cfg.Listener.Type); rf != nil { + ln = rf( + listener.AddrOption(cfg.Addr), + listener.AutherOption(auther), + listener.AuthOption(parseAuth(cfg.Listener.Auth)), + listener.TLSConfigOption(tlsConfig), + listener.AdmissionOption(admission.AdmissionGroup(admissions...)), + listener.ChainOption(chainGroup(cfg.Listener.Chain, cfg.Listener.ChainGroup)), + listener.TrafficLimiterOption(registry.TrafficLimiterRegistry().Get(cfg.Limiter)), + listener.ConnLimiterOption(registry.ConnLimiterRegistry().Get(cfg.CLimiter)), + listener.LoggerOption(listenerLogger), + listener.ServiceOption(cfg.Name), + listener.ProxyProtocolOption(ppv), + ) + } else { + return nil, fmt.Errorf("unregistered listener: %s", cfg.Listener.Type) + } if cfg.Listener.Metadata == nil { cfg.Listener.Metadata = make(map[string]any) @@ -161,14 +167,20 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { WithRecorder(recorders...). WithLogger(handlerLogger) - h := registry.HandlerRegistry().Get(cfg.Handler.Type)( - handler.RouterOption(router), - handler.AutherOption(auther), - handler.AuthOption(parseAuth(cfg.Handler.Auth)), - handler.BypassOption(bypass.BypassGroup(bypassList(cfg.Bypass, cfg.Bypasses...)...)), - handler.TLSConfigOption(tlsConfig), - handler.LoggerOption(handlerLogger), - ) + var h handler.Handler + if rf := registry.HandlerRegistry().Get(cfg.Handler.Type); rf != nil { + h = rf( + handler.RouterOption(router), + handler.AutherOption(auther), + handler.AuthOption(parseAuth(cfg.Handler.Auth)), + handler.BypassOption(bypass.BypassGroup(bypassList(cfg.Bypass, cfg.Bypasses...)...)), + handler.TLSConfigOption(tlsConfig), + handler.RateLimiterOption(registry.RateLimiterRegistry().Get(cfg.RLimiter)), + handler.LoggerOption(handlerLogger), + ) + } else { + return nil, fmt.Errorf("unregistered handler: %s", cfg.Handler.Type) + } if forwarder, ok := h.(handler.Forwarder); ok { forwarder.Forward(parseForwarder(cfg.Forwarder)) diff --git a/go.mod b/go.mod index 357c65c..81e06dd 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.7.7 - github.com/go-gost/core v0.0.0-20220913161420-45b7ac2021fe + github.com/go-gost/core v0.0.0-20220914115321-50d443049f3b github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 diff --git a/go.sum b/go.sum index 015f8a4..1548ac6 100644 --- a/go.sum +++ b/go.sum @@ -98,8 +98,8 @@ github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gost/core v0.0.0-20220913161420-45b7ac2021fe h1:zYcwKOe9ceGpwin84bH7J/DRZ4g9MhU+xOsTMxqOuNw= -github.com/go-gost/core v0.0.0-20220913161420-45b7ac2021fe/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= +github.com/go-gost/core v0.0.0-20220914115321-50d443049f3b h1:fWUPYFp0W/6GEhL0wrURGPQN2AQHhf4IZKiALJJOJh8= +github.com/go-gost/core v0.0.0-20220914115321-50d443049f3b/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= diff --git a/handler/dns/handler.go b/handler/dns/handler.go index 96c76c4..d227e4c 100644 --- a/handler/dns/handler.go +++ b/handler/dns/handler.go @@ -130,6 +130,10 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() + if !h.checkRateLimit(conn.RemoteAddr()) { + return nil + } + b := bufpool.Get(h.md.bufferSize) defer bufpool.Put(b) @@ -152,6 +156,18 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. return nil } +func (h *dnsHandler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} + func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) { mq := dns.Msg{} if err := mq.Unpack(msg); err != nil { diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index a93638f..1c90b55 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -77,6 +77,10 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() + if !h.checkRateLimit(conn.RemoteAddr()) { + return nil + } + target := h.group.Next(ctx) if target == nil { err := errors.New("target not available") @@ -119,3 +123,15 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand return nil } + +func (h *forwardHandler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index faecbb6..449b73a 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -71,6 +71,10 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() + if !h.checkRateLimit(conn.RemoteAddr()) { + return nil + } + target := h.group.Next(ctx) if target == nil { err := errors.New("target not available") @@ -113,3 +117,15 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand return nil } + +func (h *forwardHandler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} diff --git a/handler/http/handler.go b/handler/http/handler.go index 2cda386..9241c82 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -75,6 +75,10 @@ func (h *httpHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() + if !h.checkRateLimit(conn.RemoteAddr()) { + return nil + } + req, err := http.ReadRequest(bufio.NewReader(conn)) if err != nil { log.Error(err) @@ -337,3 +341,15 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http. resp.Write(conn) return } + +func (h *httpHandler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} diff --git a/handler/http2/handler.go b/handler/http2/handler.go index 5dc974a..e6f356c 100644 --- a/handler/http2/handler.go +++ b/handler/http2/handler.go @@ -75,6 +75,10 @@ func (h *http2Handler) Handle(ctx context.Context, conn net.Conn, opts ...handle }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() + if !h.checkRateLimit(conn.RemoteAddr()) { + return nil + } + v, ok := conn.(md.Metadatable) if !ok || v == nil { err := errors.New("wrong connection type") @@ -345,3 +349,15 @@ func (h *http2Handler) writeResponse(w http.ResponseWriter, resp *http.Response) _, err := io.Copy(flushWriter{w}, resp.Body) return err } + +func (h *http2Handler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} diff --git a/handler/redirect/tcp/handler.go b/handler/redirect/tcp/handler.go index 04942fe..55caa62 100644 --- a/handler/redirect/tcp/handler.go +++ b/handler/redirect/tcp/handler.go @@ -74,6 +74,10 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() + if !h.checkRateLimit(conn.RemoteAddr()) { + return nil + } + var dstAddr net.Addr if h.md.tproxy { @@ -269,3 +273,15 @@ func (h *redirectHandler) getServerName(ctx context.Context, r io.Reader) (host return } + +func (h *redirectHandler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} diff --git a/handler/redirect/udp/handler.go b/handler/redirect/udp/handler.go index 94a7718..0d13f6a 100644 --- a/handler/redirect/udp/handler.go +++ b/handler/redirect/udp/handler.go @@ -63,6 +63,10 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() + if !h.checkRateLimit(conn.RemoteAddr()) { + return nil + } + dstAddr := conn.LocalAddr() log = log.WithFields(map[string]any{ @@ -92,3 +96,15 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han return nil } + +func (h *redirectHandler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} diff --git a/handler/relay/handler.go b/handler/relay/handler.go index d063f39..e0c7298 100644 --- a/handler/relay/handler.go +++ b/handler/relay/handler.go @@ -75,6 +75,10 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() + if !h.checkRateLimit(conn.RemoteAddr()) { + return nil + } + if h.md.readTimeout > 0 { conn.SetReadDeadline(time.Now().Add(h.md.readTimeout)) } @@ -145,3 +149,15 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle } return ErrUnknownCmd } + +func (h *relayHandler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} diff --git a/handler/sni/handler.go b/handler/sni/handler.go index 9542cab..01bd14d 100644 --- a/handler/sni/handler.go +++ b/handler/sni/handler.go @@ -76,6 +76,10 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() + if !h.checkRateLimit(conn.RemoteAddr()) { + return nil + } + var hdr [dissector.RecordHeaderLen]byte if _, err := io.ReadFull(conn, hdr[:]); err != nil { log.Error(err) @@ -251,3 +255,15 @@ func (h *sniHandler) decodeServerName(s string) (string, error) { } return string(v), nil } + +func (h *sniHandler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} diff --git a/handler/socks/v4/handler.go b/handler/socks/v4/handler.go index 7879ae5..e7e60b3 100644 --- a/handler/socks/v4/handler.go +++ b/handler/socks/v4/handler.go @@ -72,6 +72,10 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() + if !h.checkRateLimit(conn.RemoteAddr()) { + return nil + } + if h.md.readTimeout > 0 { conn.SetReadDeadline(time.Now().Add(h.md.readTimeout)) } @@ -150,3 +154,15 @@ func (h *socks4Handler) handleBind(ctx context.Context, conn net.Conn, req *goso // TODO: bind return ErrUnimplemented } + +func (h *socks4Handler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} diff --git a/handler/socks/v5/handler.go b/handler/socks/v5/handler.go index 138851c..65266d3 100644 --- a/handler/socks/v5/handler.go +++ b/handler/socks/v5/handler.go @@ -78,6 +78,10 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() + if !h.checkRateLimit(conn.RemoteAddr()) { + return nil + } + if h.md.readTimeout > 0 { conn.SetReadDeadline(time.Now().Add(h.md.readTimeout)) } @@ -113,3 +117,15 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl return err } } + +func (h *socks5Handler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} diff --git a/handler/ss/handler.go b/handler/ss/handler.go index 3f69bee..2825c04 100644 --- a/handler/ss/handler.go +++ b/handler/ss/handler.go @@ -76,6 +76,10 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.H }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() + if !h.checkRateLimit(conn.RemoteAddr()) { + return nil + } + if h.cipher != nil { conn = ss.ShadowConn(h.cipher.StreamConn(conn), nil) } @@ -117,3 +121,15 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.H return nil } + +func (h *ssHandler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} diff --git a/handler/ss/udp/handler.go b/handler/ss/udp/handler.go index 6f6f60e..24d451b 100644 --- a/handler/ss/udp/handler.go +++ b/handler/ss/udp/handler.go @@ -77,6 +77,10 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() + if !h.checkRateLimit(conn.RemoteAddr()) { + return nil + } + pc, ok := conn.(net.PacketConn) if ok { if h.cipher != nil { @@ -186,3 +190,15 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn, log logger.Logger) (er return <-errc } + +func (h *ssuHandler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} diff --git a/handler/sshd/handler.go b/handler/sshd/handler.go index 0f42468..813699c 100644 --- a/handler/sshd/handler.go +++ b/handler/sshd/handler.go @@ -66,6 +66,10 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand "local": conn.LocalAddr().String(), }) + if !h.checkRateLimit(conn.RemoteAddr()) { + return nil + } + switch cc := conn.(type) { case *sshd_util.DirectForwardConn: return h.handleDirectForward(ctx, cc, log) @@ -217,6 +221,18 @@ func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_uti return nil } +func (h *forwardHandler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} + func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { host, portString, err := net.SplitHostPort(addr.String()) if err != nil { diff --git a/limiter/conn/conn.go b/limiter/conn/conn.go index eaf8619..961fc25 100644 --- a/limiter/conn/conn.go +++ b/limiter/conn/conn.go @@ -102,7 +102,7 @@ type connLimiter struct { ipLimits map[string]ConnLimitGenerator cidrLimits cidranger.Ranger limits map[string]limiter.Limiter - mu sync.RWMutex + mu sync.Mutex cancelFunc context.CancelFunc options options } diff --git a/limiter/conn/limiter.go b/limiter/conn/limiter.go index 0367555..9920a98 100644 --- a/limiter/conn/limiter.go +++ b/limiter/conn/limiter.go @@ -20,7 +20,7 @@ func (l *llimiter) Limit() int { } func (l *llimiter) Allow(n int) bool { - if atomic.AddInt64(&l.current, int64(n)) >= int64(l.limit) { + if atomic.AddInt64(&l.current, int64(n)) > int64(l.limit) { if n > 0 { atomic.AddInt64(&l.current, -int64(n)) } diff --git a/limiter/rate/generator.go b/limiter/rate/generator.go new file mode 100644 index 0000000..92c4614 --- /dev/null +++ b/limiter/rate/generator.go @@ -0,0 +1,44 @@ +package rate + +import ( + "github.com/go-gost/core/limiter/rate" + limiter "github.com/go-gost/core/limiter/rate" +) + +type RateLimitGenerator interface { + Limiter() limiter.Limiter +} + +type rateLimitGenerator struct { + r float64 +} + +func NewRateLimitGenerator(r float64) RateLimitGenerator { + return &rateLimitGenerator{ + r: r, + } +} + +func (p *rateLimitGenerator) Limiter() limiter.Limiter { + if p == nil || p.r <= 0 { + return nil + } + return NewLimiter(p.r, int(p.r)+1) +} + +type rateLimitSingleGenerator struct { + limiter rate.Limiter +} + +func NewRateLimitSingleGenerator(r float64) RateLimitGenerator { + p := &rateLimitSingleGenerator{} + if r > 0 { + p.limiter = NewLimiter(r, int(r)+1) + } + + return p +} + +func (p *rateLimitSingleGenerator) Limiter() limiter.Limiter { + return p.limiter +} diff --git a/limiter/rate/limiter.go b/limiter/rate/limiter.go new file mode 100644 index 0000000..8057185 --- /dev/null +++ b/limiter/rate/limiter.go @@ -0,0 +1,26 @@ +package rate + +import ( + "time" + + limiter "github.com/go-gost/core/limiter/rate" + "golang.org/x/time/rate" +) + +type rlimiter struct { + limiter *rate.Limiter +} + +func NewLimiter(r float64, b int) limiter.Limiter { + return &rlimiter{ + limiter: rate.NewLimiter(rate.Limit(r), b), + } +} + +func (l *rlimiter) Allow(n int) bool { + return l.limiter.AllowN(time.Now(), n) +} + +func (l *rlimiter) Limit() float64 { + return float64(l.limiter.Limit()) +} diff --git a/limiter/rate/rate.go b/limiter/rate/rate.go new file mode 100644 index 0000000..a52ecb8 --- /dev/null +++ b/limiter/rate/rate.go @@ -0,0 +1,353 @@ +package rate + +import ( + "bufio" + "context" + "io" + "net" + "sort" + "strconv" + "strings" + "sync" + "time" + + limiter "github.com/go-gost/core/limiter/rate" + "github.com/go-gost/core/logger" + "github.com/go-gost/x/internal/loader" + "github.com/yl2chen/cidranger" +) + +const ( + GlobalLimitKey = "$" + IPLimitKey = "$$" +) + +type limiterGroup struct { + limiters []limiter.Limiter +} + +func newLimiterGroup(limiters ...limiter.Limiter) *limiterGroup { + sort.Slice(limiters, func(i, j int) bool { + return limiters[i].Limit() < limiters[j].Limit() + }) + return &limiterGroup{limiters: limiters} +} + +func (l *limiterGroup) Allow(n int) (b bool) { + b = true + for i := range l.limiters { + if v := l.limiters[i].Allow(n); !v { + b = false + } + } + return +} + +func (l *limiterGroup) Limit() float64 { + if len(l.limiters) == 0 { + return 0 + } + + return l.limiters[0].Limit() +} + +type options struct { + limits []string + fileLoader loader.Loader + redisLoader loader.Loader + period time.Duration + logger logger.Logger +} + +type Option func(opts *options) + +func LimitsOption(limits ...string) Option { + return func(opts *options) { + opts.limits = limits + } +} + +func ReloadPeriodOption(period time.Duration) Option { + return func(opts *options) { + opts.period = period + } +} + +func FileLoaderOption(fileLoader loader.Loader) Option { + return func(opts *options) { + opts.fileLoader = fileLoader + } +} + +func RedisLoaderOption(redisLoader loader.Loader) Option { + return func(opts *options) { + opts.redisLoader = redisLoader + } +} + +func LoggerOption(logger logger.Logger) Option { + return func(opts *options) { + opts.logger = logger + } +} + +type rateLimiter struct { + ipLimits map[string]RateLimitGenerator + cidrLimits cidranger.Ranger + limits map[string]limiter.Limiter + mu sync.Mutex + cancelFunc context.CancelFunc + options options +} + +func NewRateLimiter(opts ...Option) limiter.RateLimiter { + var options options + for _, opt := range opts { + opt(&options) + } + + ctx, cancel := context.WithCancel(context.TODO()) + lim := &rateLimiter{ + ipLimits: make(map[string]RateLimitGenerator), + cidrLimits: cidranger.NewPCTrieRanger(), + limits: make(map[string]limiter.Limiter), + options: options, + cancelFunc: cancel, + } + + if err := lim.reload(ctx); err != nil { + options.logger.Warnf("reload: %v", err) + } + if lim.options.period > 0 { + go lim.periodReload(ctx) + } + return lim +} + +func (l *rateLimiter) Limiter(key string) limiter.Limiter { + l.mu.Lock() + defer l.mu.Unlock() + + if lim, ok := l.limits[key]; ok { + return lim + } + + var lims []limiter.Limiter + + if ip := net.ParseIP(key); ip != nil { + found := false + if p := l.ipLimits[key]; p != nil { + if lim := p.Limiter(); lim != nil { + lims = append(lims, lim) + found = true + } + } + if !found { + if p, _ := l.cidrLimits.ContainingNetworks(ip); len(p) > 0 { + if v, _ := p[0].(*cidrLimitEntry); v != nil { + if lim := v.limit.Limiter(); lim != nil { + lims = append(lims, lim) + } + } + } + } + } + + if len(lims) == 0 { + if p := l.ipLimits[IPLimitKey]; p != nil { + if lim := p.Limiter(); lim != nil { + lims = append(lims, lim) + } + } + } + + if p := l.ipLimits[GlobalLimitKey]; p != nil { + if lim := p.Limiter(); lim != nil { + lims = append(lims, lim) + } + } + + var lim limiter.Limiter + if len(lims) > 0 { + lim = newLimiterGroup(lims...) + } + l.limits[key] = lim + + if lim != nil && l.options.logger != nil { + l.options.logger.Debugf("input limit for %s: %d", key, lim.Limit()) + } + + return lim +} + +func (l *rateLimiter) periodReload(ctx context.Context) error { + period := l.options.period + if period < time.Second { + period = time.Second + } + ticker := time.NewTicker(period) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := l.reload(ctx); err != nil { + l.options.logger.Warnf("reload: %v", err) + // return err + } + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (l *rateLimiter) reload(ctx context.Context) error { + v, err := l.load(ctx) + if err != nil { + return err + } + + lines := append(l.options.limits, v...) + + ipLimits := make(map[string]RateLimitGenerator) + cidrLimits := cidranger.NewPCTrieRanger() + + for _, s := range lines { + key, limit := l.parseLimit(s) + if key == "" || limit <= 0 { + continue + } + switch key { + case GlobalLimitKey: + ipLimits[key] = NewRateLimitSingleGenerator(limit) + case IPLimitKey: + ipLimits[key] = NewRateLimitGenerator(limit) + default: + if ip := net.ParseIP(key); ip != nil { + ipLimits[key] = NewRateLimitGenerator(limit) + break + } + if _, ipNet, _ := net.ParseCIDR(key); ipNet != nil { + cidrLimits.Insert(&cidrLimitEntry{ + ipNet: *ipNet, + limit: NewRateLimitGenerator(limit), + }) + } + } + } + + l.mu.Lock() + defer l.mu.Unlock() + + l.ipLimits = ipLimits + l.cidrLimits = cidrLimits + l.limits = make(map[string]limiter.Limiter) + + return nil +} + +func (l *rateLimiter) load(ctx context.Context) (patterns []string, err error) { + if l.options.fileLoader != nil { + if lister, ok := l.options.fileLoader.(loader.Lister); ok { + list, er := lister.List(ctx) + if er != nil { + l.options.logger.Warnf("file loader: %v", er) + } + for _, s := range list { + if line := l.parseLine(s); line != "" { + patterns = append(patterns, line) + } + } + } else { + r, er := l.options.fileLoader.Load(ctx) + if er != nil { + l.options.logger.Warnf("file loader: %v", er) + } + if v, _ := l.parsePatterns(r); v != nil { + patterns = append(patterns, v...) + } + } + } + if l.options.redisLoader != nil { + if lister, ok := l.options.redisLoader.(loader.Lister); ok { + list, er := lister.List(ctx) + if er != nil { + l.options.logger.Warnf("redis loader: %v", er) + } + patterns = append(patterns, list...) + } else { + r, er := l.options.redisLoader.Load(ctx) + if er != nil { + l.options.logger.Warnf("redis loader: %v", er) + } + if v, _ := l.parsePatterns(r); v != nil { + patterns = append(patterns, v...) + } + } + } + + l.options.logger.Debugf("load items %d", len(patterns)) + return +} + +func (l *rateLimiter) parsePatterns(r io.Reader) (patterns []string, err error) { + if r == nil { + return + } + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + if line := l.parseLine(scanner.Text()); line != "" { + patterns = append(patterns, line) + } + } + + err = scanner.Err() + return +} + +func (l *rateLimiter) parseLine(s string) string { + if n := strings.IndexByte(s, '#'); n >= 0 { + s = s[:n] + } + return strings.TrimSpace(s) +} + +func (l *rateLimiter) parseLimit(s string) (key string, limit float64) { + s = strings.Replace(s, "\t", " ", -1) + s = strings.TrimSpace(s) + var ss []string + for _, v := range strings.Split(s, " ") { + if v != "" { + ss = append(ss, v) + } + } + if len(ss) < 2 { + return + } + + key = ss[0] + limit, _ = strconv.ParseFloat(ss[1], 64) + + return +} + +func (l *rateLimiter) Close() error { + l.cancelFunc() + if l.options.fileLoader != nil { + l.options.fileLoader.Close() + } + if l.options.redisLoader != nil { + l.options.redisLoader.Close() + } + return nil +} + +type cidrLimitEntry struct { + ipNet net.IPNet + limit RateLimitGenerator +} + +func (p *cidrLimitEntry) Network() net.IPNet { + return p.ipNet +} diff --git a/limiter/traffic/traffic.go b/limiter/traffic/traffic.go index 22de113..54260e2 100644 --- a/limiter/traffic/traffic.go +++ b/limiter/traffic/traffic.go @@ -95,7 +95,7 @@ type trafficLimiter struct { cidrLimits cidranger.Ranger inLimits map[string]limiter.Limiter outLimits map[string]limiter.Limiter - mu sync.RWMutex + mu sync.Mutex cancelFunc context.CancelFunc options options } diff --git a/limiter/traffic/wrapper/conn.go b/limiter/traffic/wrapper/conn.go index 7dbaeae..8b171b8 100644 --- a/limiter/traffic/wrapper/conn.go +++ b/limiter/traffic/wrapper/conn.go @@ -20,9 +20,9 @@ var ( // serverConn is a server side Conn with metrics supported. type serverConn struct { net.Conn - rbuf bytes.Buffer - raddr string - rlimiter limiter.TrafficLimiter + rbuf bytes.Buffer + raddr string + limiter limiter.TrafficLimiter } func WrapConn(rlimiter limiter.TrafficLimiter, c net.Conn) net.Conn { @@ -31,19 +31,19 @@ func WrapConn(rlimiter limiter.TrafficLimiter, c net.Conn) net.Conn { } host, _, _ := net.SplitHostPort(c.RemoteAddr().String()) return &serverConn{ - Conn: c, - rlimiter: rlimiter, - raddr: host, + Conn: c, + limiter: rlimiter, + raddr: host, } } func (c *serverConn) Read(b []byte) (n int, err error) { - if c.rlimiter == nil || - c.rlimiter.In(c.raddr) == nil { + if c.limiter == nil || + c.limiter.In(c.raddr) == nil { return c.Conn.Read(b) } - limiter := c.rlimiter.In(c.raddr) + limiter := c.limiter.In(c.raddr) if c.rbuf.Len() > 0 { burst := len(b) @@ -70,12 +70,12 @@ func (c *serverConn) Read(b []byte) (n int, err error) { } func (c *serverConn) Write(b []byte) (n int, err error) { - if c.rlimiter == nil || - c.rlimiter.Out(c.raddr) == nil { + if c.limiter == nil || + c.limiter.Out(c.raddr) == nil { return c.Conn.Write(b) } - limiter := c.rlimiter.Out(c.raddr) + limiter := c.limiter.Out(c.raddr) nn := 0 for len(b) > 0 { nn, err = c.Conn.Write(b[:limiter.Wait(context.Background(), len(b))]) diff --git a/registry/limiter.go b/registry/limiter.go index 8b81b81..8f38727 100644 --- a/registry/limiter.go +++ b/registry/limiter.go @@ -2,6 +2,7 @@ package registry import ( "github.com/go-gost/core/limiter/conn" + "github.com/go-gost/core/limiter/rate" "github.com/go-gost/core/limiter/traffic" ) @@ -82,3 +83,38 @@ func (w *connLimiterWrapper) Limiter(key string) conn.Limiter { } return v.Limiter(key) } + +type rateLimiterRegistry struct { + registry +} + +func (r *rateLimiterRegistry) Register(name string, v rate.RateLimiter) error { + return r.registry.Register(name, v) +} + +func (r *rateLimiterRegistry) Get(name string) rate.RateLimiter { + if name != "" { + return &rateLimiterWrapper{name: name, r: r} + } + return nil +} + +func (r *rateLimiterRegistry) get(name string) rate.RateLimiter { + if v := r.registry.Get(name); v != nil { + return v.(rate.RateLimiter) + } + return nil +} + +type rateLimiterWrapper struct { + name string + r *rateLimiterRegistry +} + +func (w *rateLimiterWrapper) Limiter(key string) rate.Limiter { + v := w.r.get(w.name) + if v == nil { + return nil + } + return v.Limiter(key) +} diff --git a/registry/registry.go b/registry/registry.go index 0e101e8..9d2ef45 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -11,6 +11,7 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/hosts" "github.com/go-gost/core/limiter/conn" + "github.com/go-gost/core/limiter/rate" "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/recorder" "github.com/go-gost/core/resolver" @@ -27,16 +28,18 @@ var ( dialerReg Registry[NewDialer] = &dialerRegistry{} connectorReg Registry[NewConnector] = &connectorRegistry{} - serviceReg Registry[service.Service] = &serviceRegistry{} - chainReg Registry[chain.Chainer] = &chainRegistry{} - autherReg Registry[auth.Authenticator] = &autherRegistry{} - admissionReg Registry[admission.Admission] = &admissionRegistry{} - bypassReg Registry[bypass.Bypass] = &bypassRegistry{} - resolverReg Registry[resolver.Resolver] = &resolverRegistry{} - hostsReg Registry[hosts.HostMapper] = &hostsRegistry{} - recorderReg Registry[recorder.Recorder] = &recorderRegistry{} + serviceReg Registry[service.Service] = &serviceRegistry{} + chainReg Registry[chain.Chainer] = &chainRegistry{} + autherReg Registry[auth.Authenticator] = &autherRegistry{} + admissionReg Registry[admission.Admission] = &admissionRegistry{} + bypassReg Registry[bypass.Bypass] = &bypassRegistry{} + resolverReg Registry[resolver.Resolver] = &resolverRegistry{} + hostsReg Registry[hosts.HostMapper] = &hostsRegistry{} + recorderReg Registry[recorder.Recorder] = &recorderRegistry{} + trafficLimiterReg Registry[traffic.TrafficLimiter] = &trafficLimiterRegistry{} connLimiterReg Registry[conn.ConnLimiter] = &connLimiterRegistry{} + rateLimiterReg Registry[rate.RateLimiter] = &rateLimiterRegistry{} ) type Registry[T any] interface { @@ -138,3 +141,7 @@ func TrafficLimiterRegistry() Registry[traffic.TrafficLimiter] { func ConnLimiterRegistry() Registry[conn.ConnLimiter] { return connLimiterReg } + +func RateLimiterRegistry() Registry[rate.RateLimiter] { + return rateLimiterReg +}