update limiter

This commit is contained in:
ginuerzh
2022-09-14 20:00:35 +08:00
parent 91c12882f5
commit 01d7dc77c6
34 changed files with 1171 additions and 79 deletions

View File

@ -120,7 +120,7 @@ func updateConnLimiter(ctx *gin.Context) {
type deleteConnLimiterRequest struct {
// in: path
// required: true
Limiter string `uri:"Limiter" json:"Limiter"`
Limiter string `uri:"limiter" json:"limiter"`
}
// successful operation.

View File

@ -120,7 +120,7 @@ func updateLimiter(ctx *gin.Context) {
type deleteLimiterRequest struct {
// in: path
// required: true
Limiter string `uri:"Limiter" json:"Limiter"`
Limiter string `uri:"limiter" json:"limiter"`
}
// successful operation.

166
api/config_rate_limiter.go Normal file
View File

@ -0,0 +1,166 @@
package api
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/go-gost/x/config"
"github.com/go-gost/x/config/parsing"
"github.com/go-gost/x/registry"
)
// swagger:parameters createRateLimiterRequest
type createRateLimiterRequest struct {
// in: body
Data config.LimiterConfig `json:"data"`
}
// successful operation.
// swagger:response createRateLimiterResponse
type createRateLimiterResponse struct {
Data Response
}
func createRateLimiter(ctx *gin.Context) {
// swagger:route POST /config/rlimiters Limiter createRateLimiterRequest
//
// Create a new rate limiter, the name of limiter must be unique in limiter list.
//
// Security:
// basicAuth: []
//
// Responses:
// 200: createRateLimiterResponse
var req createRateLimiterRequest
ctx.ShouldBindJSON(&req.Data)
if req.Data.Name == "" {
writeError(ctx, ErrInvalid)
return
}
v := parsing.ParseRateLimiter(&req.Data)
if err := registry.RateLimiterRegistry().Register(req.Data.Name, v); err != nil {
writeError(ctx, ErrDup)
return
}
cfg := config.Global()
cfg.CLimiters = append(cfg.CLimiters, &req.Data)
config.SetGlobal(cfg)
ctx.JSON(http.StatusOK, Response{
Msg: "OK",
})
}
// swagger:parameters updateRateLimiterRequest
type updateRateLimiterRequest struct {
// in: path
// required: true
Limiter string `uri:"limiter" json:"limiter"`
// in: body
Data config.LimiterConfig `json:"data"`
}
// successful operation.
// swagger:response updateRateLimiterResponse
type updateRateLimiterResponse struct {
Data Response
}
func updateRateLimiter(ctx *gin.Context) {
// swagger:route PUT /config/rlimiters/{limiter} Limiter updateRateLimiterRequest
//
// Update rate limiter by name, the limiter must already exist.
//
// Security:
// basicAuth: []
//
// Responses:
// 200: updateRateLimiterResponse
var req updateRateLimiterRequest
ctx.ShouldBindUri(&req)
ctx.ShouldBindJSON(&req.Data)
if !registry.RateLimiterRegistry().IsRegistered(req.Limiter) {
writeError(ctx, ErrNotFound)
return
}
req.Data.Name = req.Limiter
v := parsing.ParseRateLimiter(&req.Data)
registry.RateLimiterRegistry().Unregister(req.Limiter)
if err := registry.RateLimiterRegistry().Register(req.Limiter, v); err != nil {
writeError(ctx, ErrDup)
return
}
cfg := config.Global()
for i := range cfg.Limiters {
if cfg.Limiters[i].Name == req.Limiter {
cfg.Limiters[i] = &req.Data
break
}
}
config.SetGlobal(cfg)
ctx.JSON(http.StatusOK, Response{
Msg: "OK",
})
}
// swagger:parameters deleteRateLimiterRequest
type deleteRateLimiterRequest struct {
// in: path
// required: true
Limiter string `uri:"limiter" json:"limiter"`
}
// successful operation.
// swagger:response deleteRateLimiterResponse
type deleteRateLimiterResponse struct {
Data Response
}
func deleteRateLimiter(ctx *gin.Context) {
// swagger:route DELETE /config/rlimiters/{limiter} Limiter deleteRateLimiterRequest
//
// Delete rate limiter by name.
//
// Security:
// basicAuth: []
//
// Responses:
// 200: deleteRateLimiterResponse
var req deleteRateLimiterRequest
ctx.ShouldBindUri(&req)
if !registry.RateLimiterRegistry().IsRegistered(req.Limiter) {
writeError(ctx, ErrNotFound)
return
}
registry.RateLimiterRegistry().Unregister(req.Limiter)
cfg := config.Global()
limiteres := cfg.Limiters
cfg.Limiters = nil
for _, s := range limiteres {
if s.Name == req.Limiter {
continue
}
cfg.Limiters = append(cfg.Limiters, s)
}
config.SetGlobal(cfg)
ctx.JSON(http.StatusOK, Response{
Msg: "OK",
})
}

View File

@ -137,4 +137,8 @@ func registerConfig(config *gin.RouterGroup) {
config.POST("/climiters", createConnLimiter)
config.PUT("/climiters/:limiter", updateConnLimiter)
config.DELETE("/climiters/:limiter", deleteConnLimiter)
config.POST("/rlimiters", createRateLimiter)
config.PUT("/rlimiters/:limiter", updateRateLimiter)
config.DELETE("/rlimiters/:limiter", deleteRateLimiter)
}

View File

@ -151,6 +151,11 @@ definitions:
$ref: '#/definitions/ChainConfig'
type: array
x-go-name: Chains
climiters:
items:
$ref: '#/definitions/LimiterConfig'
type: array
x-go-name: CLimiters
hosts:
items:
$ref: '#/definitions/HostsConfig'
@ -177,6 +182,11 @@ definitions:
$ref: '#/definitions/ResolverConfig'
type: array
x-go-name: Resolvers
rlimiters:
items:
$ref: '#/definitions/LimiterConfig'
type: array
x-go-name: RLimiters
services:
items:
$ref: '#/definitions/ServiceConfig'
@ -358,11 +368,20 @@ definitions:
x-go-package: github.com/go-gost/x/config
LimiterConfig:
properties:
file:
$ref: '#/definitions/FileLoader'
limits:
items:
type: string
type: array
x-go-name: Limits
name:
type: string
x-go-name: Name
rate:
$ref: '#/definitions/RateLimiterConfig'
redis:
$ref: '#/definitions/RedisLoader'
reload:
$ref: '#/definitions/Duration'
type: object
x-go-package: github.com/go-gost/x/config
ListenerConfig:
@ -483,21 +502,6 @@ definitions:
x-go-name: Addr
type: object
x-go-package: github.com/go-gost/x/config
RateLimiterConfig:
properties:
file:
$ref: '#/definitions/FileLoader'
limits:
items:
type: string
type: array
x-go-name: Limits
redis:
$ref: '#/definitions/RedisLoader'
reload:
$ref: '#/definitions/Duration'
type: object
x-go-package: github.com/go-gost/x/config
RecorderConfig:
properties:
file:
@ -616,6 +620,9 @@ definitions:
type: string
type: array
x-go-name: Bypasses
climiter:
type: string
x-go-name: CLimiter
forwarder:
$ref: '#/definitions/ForwarderConfig'
handler:
@ -624,6 +631,7 @@ definitions:
type: string
x-go-name: Hosts
interface:
description: DEPRECATED by metadata.interface since beta.5
type: string
x-go-name: Interface
limiter:
@ -646,6 +654,9 @@ definitions:
resolver:
type: string
x-go-name: Resolver
rlimiter:
type: string
x-go-name: RLimiter
sockopts:
$ref: '#/definitions/SockOptsConfig'
type: object
@ -956,6 +967,64 @@ paths:
summary: Update chain by name, the chain must already exist.
tags:
- Chain
/config/climiters:
post:
operationId: createConnLimiterRequest
parameters:
- in: body
name: data
schema:
$ref: '#/definitions/LimiterConfig'
x-go-name: Data
responses:
"200":
$ref: '#/responses/createConnLimiterResponse'
security:
- basicAuth:
- '[]'
summary: Create a new conn limiter, the name of limiter must be unique in limiter list.
tags:
- Limiter
/config/climiters/{limiter}:
delete:
operationId: deleteConnLimiterRequest
parameters:
- in: path
name: limiter
required: true
type: string
x-go-name: Limiter
responses:
"200":
$ref: '#/responses/deleteConnLimiterResponse'
security:
- basicAuth:
- '[]'
summary: Delete conn limiter by name.
tags:
- Limiter
put:
operationId: updateConnLimiterRequest
parameters:
- in: path
name: limiter
required: true
type: string
x-go-name: Limiter
- in: body
name: data
schema:
$ref: '#/definitions/LimiterConfig'
x-go-name: Data
responses:
"200":
$ref: '#/responses/updateConnLimiterResponse'
security:
- basicAuth:
- '[]'
summary: Update conn limiter by name, the limiter must already exist.
tags:
- Limiter
/config/hosts:
post:
operationId: createHostsRequest
@ -1037,9 +1106,10 @@ paths:
operationId: deleteLimiterRequest
parameters:
- in: path
name: Limiter
name: limiter
required: true
type: string
x-go-name: Limiter
responses:
"200":
$ref: '#/responses/deleteLimiterResponse'
@ -1129,6 +1199,64 @@ paths:
summary: Update resolver by name, the resolver must already exist.
tags:
- Resolver
/config/rlimiters:
post:
operationId: createRateLimiterRequest
parameters:
- in: body
name: data
schema:
$ref: '#/definitions/LimiterConfig'
x-go-name: Data
responses:
"200":
$ref: '#/responses/createRateLimiterResponse'
security:
- basicAuth:
- '[]'
summary: Create a new rate limiter, the name of limiter must be unique in limiter list.
tags:
- Limiter
/config/rlimiters/{limiter}:
delete:
operationId: deleteRateLimiterRequest
parameters:
- in: path
name: limiter
required: true
type: string
x-go-name: Limiter
responses:
"200":
$ref: '#/responses/deleteRateLimiterResponse'
security:
- basicAuth:
- '[]'
summary: Delete rate limiter by name.
tags:
- Limiter
put:
operationId: updateRateLimiterRequest
parameters:
- in: path
name: limiter
required: true
type: string
x-go-name: Limiter
- in: body
name: data
schema:
$ref: '#/definitions/LimiterConfig'
x-go-name: Data
responses:
"200":
$ref: '#/responses/updateRateLimiterResponse'
security:
- basicAuth:
- '[]'
summary: Update rate limiter by name, the limiter must already exist.
tags:
- Limiter
/config/services:
post:
operationId: createServiceRequest
@ -1214,6 +1342,12 @@ responses:
Data: {}
schema:
$ref: '#/definitions/Response'
createConnLimiterResponse:
description: successful operation.
headers:
Data: {}
schema:
$ref: '#/definitions/Response'
createHostsResponse:
description: successful operation.
headers:
@ -1226,6 +1360,12 @@ responses:
Data: {}
schema:
$ref: '#/definitions/Response'
createRateLimiterResponse:
description: successful operation.
headers:
Data: {}
schema:
$ref: '#/definitions/Response'
createResolverResponse:
description: successful operation.
headers:
@ -1262,6 +1402,12 @@ responses:
Data: {}
schema:
$ref: '#/definitions/Response'
deleteConnLimiterResponse:
description: successful operation.
headers:
Data: {}
schema:
$ref: '#/definitions/Response'
deleteHostsResponse:
description: successful operation.
headers:
@ -1274,6 +1420,12 @@ responses:
Data: {}
schema:
$ref: '#/definitions/Response'
deleteRateLimiterResponse:
description: successful operation.
headers:
Data: {}
schema:
$ref: '#/definitions/Response'
deleteResolverResponse:
description: successful operation.
headers:
@ -1322,6 +1474,12 @@ responses:
Data: {}
schema:
$ref: '#/definitions/Response'
updateConnLimiterResponse:
description: successful operation.
headers:
Data: {}
schema:
$ref: '#/definitions/Response'
updateHostsResponse:
description: successful operation.
headers:
@ -1334,6 +1492,12 @@ responses:
Data: {}
schema:
$ref: '#/definitions/Response'
updateRateLimiterResponse:
description: successful operation.
headers:
Data: {}
schema:
$ref: '#/definitions/Response'
updateResolverResponse:
description: successful operation.
headers:

View File

@ -254,7 +254,8 @@ type ServiceConfig struct {
Resolver string `yaml:",omitempty" json:"resolver,omitempty"`
Hosts string `yaml:",omitempty" json:"hosts,omitempty"`
Limiter string `yaml:",omitempty" json:"limiter,omitempty"`
CLimiter string `yaml:"climiter,omitempty" json:"limiter,omitempty"`
CLimiter string `yaml:"climiter,omitempty" json:"climiter,omitempty"`
RLimiter string `yaml:"rlimiter,omitempty" json:"rlimiter,omitempty"`
Recorders []*RecorderObject `yaml:",omitempty" json:"recorders,omitempty"`
Handler *HandlerConfig `yaml:",omitempty" json:"handler,omitempty"`
Listener *ListenerConfig `yaml:",omitempty" json:"listener,omitempty"`
@ -311,6 +312,7 @@ type Config struct {
Recorders []*RecorderConfig `yaml:",omitempty" json:"recorders,omitempty"`
Limiters []*LimiterConfig `yaml:",omitempty" json:"limiters,omitempty"`
CLimiters []*LimiterConfig `yaml:"climiters,omitempty" json:"climiters,omitempty"`
RLimiters []*LimiterConfig `yaml:"rlimiters,omitempty" json:"rlimiters,omitempty"`
TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"`
Log *LogConfig `yaml:",omitempty" json:"log,omitempty"`
Profiling *ProfilingConfig `yaml:",omitempty" json:"profiling,omitempty"`

View File

@ -1,6 +1,7 @@
package parsing
import (
"fmt"
"time"
"github.com/go-gost/core/bypass"
@ -63,11 +64,16 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) {
nm = mdx.NewMetadata(v.Metadata)
}
cr := registry.ConnectorRegistry().Get(v.Connector.Type)(
connector.AuthOption(parseAuth(v.Connector.Auth)),
connector.TLSConfigOption(tlsConfig),
connector.LoggerOption(connectorLogger),
)
var cr connector.Connector
if rf := registry.ConnectorRegistry().Get(v.Connector.Type); rf != nil {
cr = rf(
connector.AuthOption(parseAuth(v.Connector.Auth)),
connector.TLSConfigOption(tlsConfig),
connector.LoggerOption(connectorLogger),
)
} else {
return nil, fmt.Errorf("unregistered connector: %s", v.Connector.Type)
}
if v.Connector.Metadata == nil {
v.Connector.Metadata = make(map[string]any)
@ -97,12 +103,18 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) {
if nm != nil {
ppv = mdutil.GetInt(nm, mdKeyProxyProtocol)
}
d := registry.DialerRegistry().Get(v.Dialer.Type)(
dialer.AuthOption(parseAuth(v.Dialer.Auth)),
dialer.TLSConfigOption(tlsConfig),
dialer.LoggerOption(dialerLogger),
dialer.ProxyProtocolOption(ppv),
)
var d dialer.Dialer
if rf := registry.DialerRegistry().Get(v.Dialer.Type); rf != nil {
d = rf(
dialer.AuthOption(parseAuth(v.Dialer.Auth)),
dialer.TLSConfigOption(tlsConfig),
dialer.LoggerOption(dialerLogger),
dialer.ProxyProtocolOption(ppv),
)
} else {
return nil, fmt.Errorf("unregistered dialer: %s", v.Dialer.Type)
}
if v.Dialer.Metadata == nil {
v.Dialer.Metadata = make(map[string]any)

View File

@ -10,6 +10,7 @@ import (
"github.com/go-gost/core/chain"
"github.com/go-gost/core/hosts"
"github.com/go-gost/core/limiter/conn"
"github.com/go-gost/core/limiter/rate"
"github.com/go-gost/core/limiter/traffic"
"github.com/go-gost/core/logger"
"github.com/go-gost/core/recorder"
@ -22,6 +23,7 @@ import (
xhosts "github.com/go-gost/x/hosts"
"github.com/go-gost/x/internal/loader"
xconn "github.com/go-gost/x/limiter/conn"
xrate "github.com/go-gost/x/limiter/rate"
xtraffic "github.com/go-gost/x/limiter/traffic"
xrecorder "github.com/go-gost/x/recorder"
"github.com/go-gost/x/registry"
@ -408,3 +410,43 @@ func ParseConnLimiter(cfg *config.LimiterConfig) (lim conn.ConnLimiter) {
return xconn.NewConnLimiter(opts...)
}
func ParseRateLimiter(cfg *config.LimiterConfig) (lim rate.RateLimiter) {
if cfg == nil {
return nil
}
var opts []xrate.Option
if cfg.File != nil && cfg.File.Path != "" {
opts = append(opts, xrate.FileLoaderOption(loader.FileLoader(cfg.File.Path)))
}
if cfg.Redis != nil && cfg.Redis.Addr != "" {
switch cfg.Redis.Type {
case "list": // redis list
opts = append(opts, xrate.RedisLoaderOption(loader.RedisListLoader(
cfg.Redis.Addr,
loader.DBRedisLoaderOption(cfg.Redis.DB),
loader.PasswordRedisLoaderOption(cfg.Redis.Password),
loader.KeyRedisLoaderOption(cfg.Redis.Key),
)))
default: // redis set
opts = append(opts, xrate.RedisLoaderOption(loader.RedisSetLoader(
cfg.Redis.Addr,
loader.DBRedisLoaderOption(cfg.Redis.DB),
loader.PasswordRedisLoaderOption(cfg.Redis.Password),
loader.KeyRedisLoaderOption(cfg.Redis.Key),
)))
}
}
opts = append(opts,
xrate.LimitsOption(cfg.Limits...),
xrate.ReloadPeriodOption(cfg.Reload),
xrate.LoggerOption(logger.Default().WithFields(map[string]any{
"kind": "limiter",
"limiter": cfg.Name,
})),
)
return xrate.NewRateLimiter(opts...)
}

View File

@ -1,6 +1,7 @@
package parsing
import (
"fmt"
"strings"
"github.com/go-gost/core/admission"
@ -91,19 +92,24 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
}
}
ln := registry.ListenerRegistry().Get(cfg.Listener.Type)(
listener.AddrOption(cfg.Addr),
listener.AutherOption(auther),
listener.AuthOption(parseAuth(cfg.Listener.Auth)),
listener.TLSConfigOption(tlsConfig),
listener.AdmissionOption(admission.AdmissionGroup(admissions...)),
listener.ChainOption(chainGroup(cfg.Listener.Chain, cfg.Listener.ChainGroup)),
listener.TrafficLimiterOption(registry.TrafficLimiterRegistry().Get(cfg.Limiter)),
listener.ConnLimiterOption(registry.ConnLimiterRegistry().Get(cfg.CLimiter)),
listener.LoggerOption(listenerLogger),
listener.ServiceOption(cfg.Name),
listener.ProxyProtocolOption(ppv),
)
var ln listener.Listener
if rf := registry.ListenerRegistry().Get(cfg.Listener.Type); rf != nil {
ln = rf(
listener.AddrOption(cfg.Addr),
listener.AutherOption(auther),
listener.AuthOption(parseAuth(cfg.Listener.Auth)),
listener.TLSConfigOption(tlsConfig),
listener.AdmissionOption(admission.AdmissionGroup(admissions...)),
listener.ChainOption(chainGroup(cfg.Listener.Chain, cfg.Listener.ChainGroup)),
listener.TrafficLimiterOption(registry.TrafficLimiterRegistry().Get(cfg.Limiter)),
listener.ConnLimiterOption(registry.ConnLimiterRegistry().Get(cfg.CLimiter)),
listener.LoggerOption(listenerLogger),
listener.ServiceOption(cfg.Name),
listener.ProxyProtocolOption(ppv),
)
} else {
return nil, fmt.Errorf("unregistered listener: %s", cfg.Listener.Type)
}
if cfg.Listener.Metadata == nil {
cfg.Listener.Metadata = make(map[string]any)
@ -161,14 +167,20 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
WithRecorder(recorders...).
WithLogger(handlerLogger)
h := registry.HandlerRegistry().Get(cfg.Handler.Type)(
handler.RouterOption(router),
handler.AutherOption(auther),
handler.AuthOption(parseAuth(cfg.Handler.Auth)),
handler.BypassOption(bypass.BypassGroup(bypassList(cfg.Bypass, cfg.Bypasses...)...)),
handler.TLSConfigOption(tlsConfig),
handler.LoggerOption(handlerLogger),
)
var h handler.Handler
if rf := registry.HandlerRegistry().Get(cfg.Handler.Type); rf != nil {
h = rf(
handler.RouterOption(router),
handler.AutherOption(auther),
handler.AuthOption(parseAuth(cfg.Handler.Auth)),
handler.BypassOption(bypass.BypassGroup(bypassList(cfg.Bypass, cfg.Bypasses...)...)),
handler.TLSConfigOption(tlsConfig),
handler.RateLimiterOption(registry.RateLimiterRegistry().Get(cfg.RLimiter)),
handler.LoggerOption(handlerLogger),
)
} else {
return nil, fmt.Errorf("unregistered handler: %s", cfg.Handler.Type)
}
if forwarder, ok := h.(handler.Forwarder); ok {
forwarder.Forward(parseForwarder(cfg.Forwarder))

2
go.mod
View File

@ -7,7 +7,7 @@ require (
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
github.com/gin-contrib/cors v1.3.1
github.com/gin-gonic/gin v1.7.7
github.com/go-gost/core v0.0.0-20220913161420-45b7ac2021fe
github.com/go-gost/core v0.0.0-20220914115321-50d443049f3b
github.com/go-gost/gosocks4 v0.0.1
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09
github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7

4
go.sum
View File

@ -98,8 +98,8 @@ github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gost/core v0.0.0-20220913161420-45b7ac2021fe h1:zYcwKOe9ceGpwin84bH7J/DRZ4g9MhU+xOsTMxqOuNw=
github.com/go-gost/core v0.0.0-20220913161420-45b7ac2021fe/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ=
github.com/go-gost/core v0.0.0-20220914115321-50d443049f3b h1:fWUPYFp0W/6GEhL0wrURGPQN2AQHhf4IZKiALJJOJh8=
github.com/go-gost/core v0.0.0-20220914115321-50d443049f3b/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ=
github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s=
github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc=
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04=

View File

@ -130,6 +130,10 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
}
b := bufpool.Get(h.md.bufferSize)
defer bufpool.Put(b)
@ -152,6 +156,18 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
return nil
}
func (h *dnsHandler) checkRateLimit(addr net.Addr) bool {
if h.options.RateLimiter == nil {
return true
}
host, _, _ := net.SplitHostPort(addr.String())
if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
return limiter.Allow(1)
}
return true
}
func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) {
mq := dns.Msg{}
if err := mq.Unpack(msg); err != nil {

View File

@ -77,6 +77,10 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
}
target := h.group.Next(ctx)
if target == nil {
err := errors.New("target not available")
@ -119,3 +123,15 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
return nil
}
func (h *forwardHandler) checkRateLimit(addr net.Addr) bool {
if h.options.RateLimiter == nil {
return true
}
host, _, _ := net.SplitHostPort(addr.String())
if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
return limiter.Allow(1)
}
return true
}

View File

@ -71,6 +71,10 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
}
target := h.group.Next(ctx)
if target == nil {
err := errors.New("target not available")
@ -113,3 +117,15 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
return nil
}
func (h *forwardHandler) checkRateLimit(addr net.Addr) bool {
if h.options.RateLimiter == nil {
return true
}
host, _, _ := net.SplitHostPort(addr.String())
if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
return limiter.Allow(1)
}
return true
}

View File

@ -75,6 +75,10 @@ func (h *httpHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
}
req, err := http.ReadRequest(bufio.NewReader(conn))
if err != nil {
log.Error(err)
@ -337,3 +341,15 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
resp.Write(conn)
return
}
func (h *httpHandler) checkRateLimit(addr net.Addr) bool {
if h.options.RateLimiter == nil {
return true
}
host, _, _ := net.SplitHostPort(addr.String())
if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
return limiter.Allow(1)
}
return true
}

View File

@ -75,6 +75,10 @@ func (h *http2Handler) Handle(ctx context.Context, conn net.Conn, opts ...handle
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
}
v, ok := conn.(md.Metadatable)
if !ok || v == nil {
err := errors.New("wrong connection type")
@ -345,3 +349,15 @@ func (h *http2Handler) writeResponse(w http.ResponseWriter, resp *http.Response)
_, err := io.Copy(flushWriter{w}, resp.Body)
return err
}
func (h *http2Handler) checkRateLimit(addr net.Addr) bool {
if h.options.RateLimiter == nil {
return true
}
host, _, _ := net.SplitHostPort(addr.String())
if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
return limiter.Allow(1)
}
return true
}

View File

@ -74,6 +74,10 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
}
var dstAddr net.Addr
if h.md.tproxy {
@ -269,3 +273,15 @@ func (h *redirectHandler) getServerName(ctx context.Context, r io.Reader) (host
return
}
func (h *redirectHandler) checkRateLimit(addr net.Addr) bool {
if h.options.RateLimiter == nil {
return true
}
host, _, _ := net.SplitHostPort(addr.String())
if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
return limiter.Allow(1)
}
return true
}

View File

@ -63,6 +63,10 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
}
dstAddr := conn.LocalAddr()
log = log.WithFields(map[string]any{
@ -92,3 +96,15 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
return nil
}
func (h *redirectHandler) checkRateLimit(addr net.Addr) bool {
if h.options.RateLimiter == nil {
return true
}
host, _, _ := net.SplitHostPort(addr.String())
if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
return limiter.Allow(1)
}
return true
}

View File

@ -75,6 +75,10 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
}
if h.md.readTimeout > 0 {
conn.SetReadDeadline(time.Now().Add(h.md.readTimeout))
}
@ -145,3 +149,15 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
}
return ErrUnknownCmd
}
func (h *relayHandler) checkRateLimit(addr net.Addr) bool {
if h.options.RateLimiter == nil {
return true
}
host, _, _ := net.SplitHostPort(addr.String())
if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
return limiter.Allow(1)
}
return true
}

View File

@ -76,6 +76,10 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
}
var hdr [dissector.RecordHeaderLen]byte
if _, err := io.ReadFull(conn, hdr[:]); err != nil {
log.Error(err)
@ -251,3 +255,15 @@ func (h *sniHandler) decodeServerName(s string) (string, error) {
}
return string(v), nil
}
func (h *sniHandler) checkRateLimit(addr net.Addr) bool {
if h.options.RateLimiter == nil {
return true
}
host, _, _ := net.SplitHostPort(addr.String())
if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
return limiter.Allow(1)
}
return true
}

View File

@ -72,6 +72,10 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
}
if h.md.readTimeout > 0 {
conn.SetReadDeadline(time.Now().Add(h.md.readTimeout))
}
@ -150,3 +154,15 @@ func (h *socks4Handler) handleBind(ctx context.Context, conn net.Conn, req *goso
// TODO: bind
return ErrUnimplemented
}
func (h *socks4Handler) checkRateLimit(addr net.Addr) bool {
if h.options.RateLimiter == nil {
return true
}
host, _, _ := net.SplitHostPort(addr.String())
if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
return limiter.Allow(1)
}
return true
}

View File

@ -78,6 +78,10 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
}
if h.md.readTimeout > 0 {
conn.SetReadDeadline(time.Now().Add(h.md.readTimeout))
}
@ -113,3 +117,15 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl
return err
}
}
func (h *socks5Handler) checkRateLimit(addr net.Addr) bool {
if h.options.RateLimiter == nil {
return true
}
host, _, _ := net.SplitHostPort(addr.String())
if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
return limiter.Allow(1)
}
return true
}

View File

@ -76,6 +76,10 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.H
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
}
if h.cipher != nil {
conn = ss.ShadowConn(h.cipher.StreamConn(conn), nil)
}
@ -117,3 +121,15 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.H
return nil
}
func (h *ssHandler) checkRateLimit(addr net.Addr) bool {
if h.options.RateLimiter == nil {
return true
}
host, _, _ := net.SplitHostPort(addr.String())
if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
return limiter.Allow(1)
}
return true
}

View File

@ -77,6 +77,10 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
}
pc, ok := conn.(net.PacketConn)
if ok {
if h.cipher != nil {
@ -186,3 +190,15 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn, log logger.Logger) (er
return <-errc
}
func (h *ssuHandler) checkRateLimit(addr net.Addr) bool {
if h.options.RateLimiter == nil {
return true
}
host, _, _ := net.SplitHostPort(addr.String())
if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
return limiter.Allow(1)
}
return true
}

View File

@ -66,6 +66,10 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
"local": conn.LocalAddr().String(),
})
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
}
switch cc := conn.(type) {
case *sshd_util.DirectForwardConn:
return h.handleDirectForward(ctx, cc, log)
@ -217,6 +221,18 @@ func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_uti
return nil
}
func (h *forwardHandler) checkRateLimit(addr net.Addr) bool {
if h.options.RateLimiter == nil {
return true
}
host, _, _ := net.SplitHostPort(addr.String())
if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
return limiter.Allow(1)
}
return true
}
func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) {
host, portString, err := net.SplitHostPort(addr.String())
if err != nil {

View File

@ -102,7 +102,7 @@ type connLimiter struct {
ipLimits map[string]ConnLimitGenerator
cidrLimits cidranger.Ranger
limits map[string]limiter.Limiter
mu sync.RWMutex
mu sync.Mutex
cancelFunc context.CancelFunc
options options
}

View File

@ -20,7 +20,7 @@ func (l *llimiter) Limit() int {
}
func (l *llimiter) Allow(n int) bool {
if atomic.AddInt64(&l.current, int64(n)) >= int64(l.limit) {
if atomic.AddInt64(&l.current, int64(n)) > int64(l.limit) {
if n > 0 {
atomic.AddInt64(&l.current, -int64(n))
}

44
limiter/rate/generator.go Normal file
View File

@ -0,0 +1,44 @@
package rate
import (
"github.com/go-gost/core/limiter/rate"
limiter "github.com/go-gost/core/limiter/rate"
)
type RateLimitGenerator interface {
Limiter() limiter.Limiter
}
type rateLimitGenerator struct {
r float64
}
func NewRateLimitGenerator(r float64) RateLimitGenerator {
return &rateLimitGenerator{
r: r,
}
}
func (p *rateLimitGenerator) Limiter() limiter.Limiter {
if p == nil || p.r <= 0 {
return nil
}
return NewLimiter(p.r, int(p.r)+1)
}
type rateLimitSingleGenerator struct {
limiter rate.Limiter
}
func NewRateLimitSingleGenerator(r float64) RateLimitGenerator {
p := &rateLimitSingleGenerator{}
if r > 0 {
p.limiter = NewLimiter(r, int(r)+1)
}
return p
}
func (p *rateLimitSingleGenerator) Limiter() limiter.Limiter {
return p.limiter
}

26
limiter/rate/limiter.go Normal file
View File

@ -0,0 +1,26 @@
package rate
import (
"time"
limiter "github.com/go-gost/core/limiter/rate"
"golang.org/x/time/rate"
)
type rlimiter struct {
limiter *rate.Limiter
}
func NewLimiter(r float64, b int) limiter.Limiter {
return &rlimiter{
limiter: rate.NewLimiter(rate.Limit(r), b),
}
}
func (l *rlimiter) Allow(n int) bool {
return l.limiter.AllowN(time.Now(), n)
}
func (l *rlimiter) Limit() float64 {
return float64(l.limiter.Limit())
}

353
limiter/rate/rate.go Normal file
View File

@ -0,0 +1,353 @@
package rate
import (
"bufio"
"context"
"io"
"net"
"sort"
"strconv"
"strings"
"sync"
"time"
limiter "github.com/go-gost/core/limiter/rate"
"github.com/go-gost/core/logger"
"github.com/go-gost/x/internal/loader"
"github.com/yl2chen/cidranger"
)
const (
GlobalLimitKey = "$"
IPLimitKey = "$$"
)
type limiterGroup struct {
limiters []limiter.Limiter
}
func newLimiterGroup(limiters ...limiter.Limiter) *limiterGroup {
sort.Slice(limiters, func(i, j int) bool {
return limiters[i].Limit() < limiters[j].Limit()
})
return &limiterGroup{limiters: limiters}
}
func (l *limiterGroup) Allow(n int) (b bool) {
b = true
for i := range l.limiters {
if v := l.limiters[i].Allow(n); !v {
b = false
}
}
return
}
func (l *limiterGroup) Limit() float64 {
if len(l.limiters) == 0 {
return 0
}
return l.limiters[0].Limit()
}
type options struct {
limits []string
fileLoader loader.Loader
redisLoader loader.Loader
period time.Duration
logger logger.Logger
}
type Option func(opts *options)
func LimitsOption(limits ...string) Option {
return func(opts *options) {
opts.limits = limits
}
}
func ReloadPeriodOption(period time.Duration) Option {
return func(opts *options) {
opts.period = period
}
}
func FileLoaderOption(fileLoader loader.Loader) Option {
return func(opts *options) {
opts.fileLoader = fileLoader
}
}
func RedisLoaderOption(redisLoader loader.Loader) Option {
return func(opts *options) {
opts.redisLoader = redisLoader
}
}
func LoggerOption(logger logger.Logger) Option {
return func(opts *options) {
opts.logger = logger
}
}
type rateLimiter struct {
ipLimits map[string]RateLimitGenerator
cidrLimits cidranger.Ranger
limits map[string]limiter.Limiter
mu sync.Mutex
cancelFunc context.CancelFunc
options options
}
func NewRateLimiter(opts ...Option) limiter.RateLimiter {
var options options
for _, opt := range opts {
opt(&options)
}
ctx, cancel := context.WithCancel(context.TODO())
lim := &rateLimiter{
ipLimits: make(map[string]RateLimitGenerator),
cidrLimits: cidranger.NewPCTrieRanger(),
limits: make(map[string]limiter.Limiter),
options: options,
cancelFunc: cancel,
}
if err := lim.reload(ctx); err != nil {
options.logger.Warnf("reload: %v", err)
}
if lim.options.period > 0 {
go lim.periodReload(ctx)
}
return lim
}
func (l *rateLimiter) Limiter(key string) limiter.Limiter {
l.mu.Lock()
defer l.mu.Unlock()
if lim, ok := l.limits[key]; ok {
return lim
}
var lims []limiter.Limiter
if ip := net.ParseIP(key); ip != nil {
found := false
if p := l.ipLimits[key]; p != nil {
if lim := p.Limiter(); lim != nil {
lims = append(lims, lim)
found = true
}
}
if !found {
if p, _ := l.cidrLimits.ContainingNetworks(ip); len(p) > 0 {
if v, _ := p[0].(*cidrLimitEntry); v != nil {
if lim := v.limit.Limiter(); lim != nil {
lims = append(lims, lim)
}
}
}
}
}
if len(lims) == 0 {
if p := l.ipLimits[IPLimitKey]; p != nil {
if lim := p.Limiter(); lim != nil {
lims = append(lims, lim)
}
}
}
if p := l.ipLimits[GlobalLimitKey]; p != nil {
if lim := p.Limiter(); lim != nil {
lims = append(lims, lim)
}
}
var lim limiter.Limiter
if len(lims) > 0 {
lim = newLimiterGroup(lims...)
}
l.limits[key] = lim
if lim != nil && l.options.logger != nil {
l.options.logger.Debugf("input limit for %s: %d", key, lim.Limit())
}
return lim
}
func (l *rateLimiter) periodReload(ctx context.Context) error {
period := l.options.period
if period < time.Second {
period = time.Second
}
ticker := time.NewTicker(period)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := l.reload(ctx); err != nil {
l.options.logger.Warnf("reload: %v", err)
// return err
}
case <-ctx.Done():
return ctx.Err()
}
}
}
func (l *rateLimiter) reload(ctx context.Context) error {
v, err := l.load(ctx)
if err != nil {
return err
}
lines := append(l.options.limits, v...)
ipLimits := make(map[string]RateLimitGenerator)
cidrLimits := cidranger.NewPCTrieRanger()
for _, s := range lines {
key, limit := l.parseLimit(s)
if key == "" || limit <= 0 {
continue
}
switch key {
case GlobalLimitKey:
ipLimits[key] = NewRateLimitSingleGenerator(limit)
case IPLimitKey:
ipLimits[key] = NewRateLimitGenerator(limit)
default:
if ip := net.ParseIP(key); ip != nil {
ipLimits[key] = NewRateLimitGenerator(limit)
break
}
if _, ipNet, _ := net.ParseCIDR(key); ipNet != nil {
cidrLimits.Insert(&cidrLimitEntry{
ipNet: *ipNet,
limit: NewRateLimitGenerator(limit),
})
}
}
}
l.mu.Lock()
defer l.mu.Unlock()
l.ipLimits = ipLimits
l.cidrLimits = cidrLimits
l.limits = make(map[string]limiter.Limiter)
return nil
}
func (l *rateLimiter) load(ctx context.Context) (patterns []string, err error) {
if l.options.fileLoader != nil {
if lister, ok := l.options.fileLoader.(loader.Lister); ok {
list, er := lister.List(ctx)
if er != nil {
l.options.logger.Warnf("file loader: %v", er)
}
for _, s := range list {
if line := l.parseLine(s); line != "" {
patterns = append(patterns, line)
}
}
} else {
r, er := l.options.fileLoader.Load(ctx)
if er != nil {
l.options.logger.Warnf("file loader: %v", er)
}
if v, _ := l.parsePatterns(r); v != nil {
patterns = append(patterns, v...)
}
}
}
if l.options.redisLoader != nil {
if lister, ok := l.options.redisLoader.(loader.Lister); ok {
list, er := lister.List(ctx)
if er != nil {
l.options.logger.Warnf("redis loader: %v", er)
}
patterns = append(patterns, list...)
} else {
r, er := l.options.redisLoader.Load(ctx)
if er != nil {
l.options.logger.Warnf("redis loader: %v", er)
}
if v, _ := l.parsePatterns(r); v != nil {
patterns = append(patterns, v...)
}
}
}
l.options.logger.Debugf("load items %d", len(patterns))
return
}
func (l *rateLimiter) parsePatterns(r io.Reader) (patterns []string, err error) {
if r == nil {
return
}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
if line := l.parseLine(scanner.Text()); line != "" {
patterns = append(patterns, line)
}
}
err = scanner.Err()
return
}
func (l *rateLimiter) parseLine(s string) string {
if n := strings.IndexByte(s, '#'); n >= 0 {
s = s[:n]
}
return strings.TrimSpace(s)
}
func (l *rateLimiter) parseLimit(s string) (key string, limit float64) {
s = strings.Replace(s, "\t", " ", -1)
s = strings.TrimSpace(s)
var ss []string
for _, v := range strings.Split(s, " ") {
if v != "" {
ss = append(ss, v)
}
}
if len(ss) < 2 {
return
}
key = ss[0]
limit, _ = strconv.ParseFloat(ss[1], 64)
return
}
func (l *rateLimiter) Close() error {
l.cancelFunc()
if l.options.fileLoader != nil {
l.options.fileLoader.Close()
}
if l.options.redisLoader != nil {
l.options.redisLoader.Close()
}
return nil
}
type cidrLimitEntry struct {
ipNet net.IPNet
limit RateLimitGenerator
}
func (p *cidrLimitEntry) Network() net.IPNet {
return p.ipNet
}

View File

@ -95,7 +95,7 @@ type trafficLimiter struct {
cidrLimits cidranger.Ranger
inLimits map[string]limiter.Limiter
outLimits map[string]limiter.Limiter
mu sync.RWMutex
mu sync.Mutex
cancelFunc context.CancelFunc
options options
}

View File

@ -20,9 +20,9 @@ var (
// serverConn is a server side Conn with metrics supported.
type serverConn struct {
net.Conn
rbuf bytes.Buffer
raddr string
rlimiter limiter.TrafficLimiter
rbuf bytes.Buffer
raddr string
limiter limiter.TrafficLimiter
}
func WrapConn(rlimiter limiter.TrafficLimiter, c net.Conn) net.Conn {
@ -31,19 +31,19 @@ func WrapConn(rlimiter limiter.TrafficLimiter, c net.Conn) net.Conn {
}
host, _, _ := net.SplitHostPort(c.RemoteAddr().String())
return &serverConn{
Conn: c,
rlimiter: rlimiter,
raddr: host,
Conn: c,
limiter: rlimiter,
raddr: host,
}
}
func (c *serverConn) Read(b []byte) (n int, err error) {
if c.rlimiter == nil ||
c.rlimiter.In(c.raddr) == nil {
if c.limiter == nil ||
c.limiter.In(c.raddr) == nil {
return c.Conn.Read(b)
}
limiter := c.rlimiter.In(c.raddr)
limiter := c.limiter.In(c.raddr)
if c.rbuf.Len() > 0 {
burst := len(b)
@ -70,12 +70,12 @@ func (c *serverConn) Read(b []byte) (n int, err error) {
}
func (c *serverConn) Write(b []byte) (n int, err error) {
if c.rlimiter == nil ||
c.rlimiter.Out(c.raddr) == nil {
if c.limiter == nil ||
c.limiter.Out(c.raddr) == nil {
return c.Conn.Write(b)
}
limiter := c.rlimiter.Out(c.raddr)
limiter := c.limiter.Out(c.raddr)
nn := 0
for len(b) > 0 {
nn, err = c.Conn.Write(b[:limiter.Wait(context.Background(), len(b))])

View File

@ -2,6 +2,7 @@ package registry
import (
"github.com/go-gost/core/limiter/conn"
"github.com/go-gost/core/limiter/rate"
"github.com/go-gost/core/limiter/traffic"
)
@ -82,3 +83,38 @@ func (w *connLimiterWrapper) Limiter(key string) conn.Limiter {
}
return v.Limiter(key)
}
type rateLimiterRegistry struct {
registry
}
func (r *rateLimiterRegistry) Register(name string, v rate.RateLimiter) error {
return r.registry.Register(name, v)
}
func (r *rateLimiterRegistry) Get(name string) rate.RateLimiter {
if name != "" {
return &rateLimiterWrapper{name: name, r: r}
}
return nil
}
func (r *rateLimiterRegistry) get(name string) rate.RateLimiter {
if v := r.registry.Get(name); v != nil {
return v.(rate.RateLimiter)
}
return nil
}
type rateLimiterWrapper struct {
name string
r *rateLimiterRegistry
}
func (w *rateLimiterWrapper) Limiter(key string) rate.Limiter {
v := w.r.get(w.name)
if v == nil {
return nil
}
return v.Limiter(key)
}

View File

@ -11,6 +11,7 @@ import (
"github.com/go-gost/core/chain"
"github.com/go-gost/core/hosts"
"github.com/go-gost/core/limiter/conn"
"github.com/go-gost/core/limiter/rate"
"github.com/go-gost/core/limiter/traffic"
"github.com/go-gost/core/recorder"
"github.com/go-gost/core/resolver"
@ -27,16 +28,18 @@ var (
dialerReg Registry[NewDialer] = &dialerRegistry{}
connectorReg Registry[NewConnector] = &connectorRegistry{}
serviceReg Registry[service.Service] = &serviceRegistry{}
chainReg Registry[chain.Chainer] = &chainRegistry{}
autherReg Registry[auth.Authenticator] = &autherRegistry{}
admissionReg Registry[admission.Admission] = &admissionRegistry{}
bypassReg Registry[bypass.Bypass] = &bypassRegistry{}
resolverReg Registry[resolver.Resolver] = &resolverRegistry{}
hostsReg Registry[hosts.HostMapper] = &hostsRegistry{}
recorderReg Registry[recorder.Recorder] = &recorderRegistry{}
serviceReg Registry[service.Service] = &serviceRegistry{}
chainReg Registry[chain.Chainer] = &chainRegistry{}
autherReg Registry[auth.Authenticator] = &autherRegistry{}
admissionReg Registry[admission.Admission] = &admissionRegistry{}
bypassReg Registry[bypass.Bypass] = &bypassRegistry{}
resolverReg Registry[resolver.Resolver] = &resolverRegistry{}
hostsReg Registry[hosts.HostMapper] = &hostsRegistry{}
recorderReg Registry[recorder.Recorder] = &recorderRegistry{}
trafficLimiterReg Registry[traffic.TrafficLimiter] = &trafficLimiterRegistry{}
connLimiterReg Registry[conn.ConnLimiter] = &connLimiterRegistry{}
rateLimiterReg Registry[rate.RateLimiter] = &rateLimiterRegistry{}
)
type Registry[T any] interface {
@ -138,3 +141,7 @@ func TrafficLimiterRegistry() Registry[traffic.TrafficLimiter] {
func ConnLimiterRegistry() Registry[conn.ConnLimiter] {
return connLimiterReg
}
func RateLimiterRegistry() Registry[rate.RateLimiter] {
return rateLimiterReg
}