From 91c12882f597f72bac85aa87eb43efaf8fe303e0 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 14 Sep 2022 00:15:19 +0800 Subject: [PATCH] add conn limiter --- api/config_conn_limiter.go | 166 ++++++++++ api/config_limiter.go | 16 +- api/service.go | 3 + config/config.go | 13 +- config/parsing/parse.go | 92 ++++-- config/parsing/service.go | 3 +- connector/socks/v5/connector.go | 17 +- go.mod | 2 +- go.sum | 4 +- limiter/conn/conn.go | 360 ++++++++++++++++++++++ limiter/conn/generator.go | 42 +++ limiter/conn/limiter.go | 30 ++ limiter/conn/wrapper/conn.go | 43 +++ limiter/conn/wrapper/listener.go | 40 +++ limiter/generator.go | 61 ---- limiter/traffic/generator.go | 61 ++++ limiter/{ => traffic}/limiter.go | 4 +- limiter/{rate.go => traffic/traffic.go} | 44 +-- limiter/{ => traffic}/wrapper/conn.go | 52 ++-- limiter/{ => traffic}/wrapper/listener.go | 12 +- listener/dns/listener.go | 4 +- listener/ftcp/listener.go | 4 +- listener/grpc/listener.go | 6 +- listener/http2/h2/listener.go | 6 +- listener/http2/listener.go | 6 +- listener/http3/listener.go | 4 +- listener/icmp/listener.go | 4 +- listener/kcp/listener.go | 4 +- listener/mtls/listener.go | 6 +- listener/mws/listener.go | 6 +- listener/obfs/http/listener.go | 6 +- listener/obfs/tls/listener.go | 6 +- listener/pht/listener.go | 4 +- listener/quic/listener.go | 4 +- listener/redirect/tcp/listener.go | 6 +- listener/redirect/udp/listener.go | 4 +- listener/rtcp/listener.go | 6 +- listener/rudp/listener.go | 4 +- listener/ssh/listener.go | 6 +- listener/sshd/listener.go | 6 +- listener/tap/listener.go | 4 +- listener/tcp/listener.go | 6 +- listener/tls/listener.go | 6 +- listener/tun/listener.go | 4 +- listener/udp/listener.go | 4 +- listener/ws/listener.go | 6 +- registry/limiter.go | 58 +++- registry/registry.go | 30 +- 48 files changed, 1041 insertions(+), 244 deletions(-) create mode 100644 api/config_conn_limiter.go create mode 100644 limiter/conn/conn.go create mode 100644 limiter/conn/generator.go create mode 100644 limiter/conn/limiter.go create mode 100644 limiter/conn/wrapper/conn.go create mode 100644 limiter/conn/wrapper/listener.go delete mode 100644 limiter/generator.go create mode 100644 limiter/traffic/generator.go rename limiter/{ => traffic}/limiter.go (86%) rename limiter/{rate.go => traffic/traffic.go} (86%) rename limiter/{ => traffic}/wrapper/conn.go (84%) rename limiter/{ => traffic}/wrapper/listener.go (51%) diff --git a/api/config_conn_limiter.go b/api/config_conn_limiter.go new file mode 100644 index 0000000..bbfd931 --- /dev/null +++ b/api/config_conn_limiter.go @@ -0,0 +1,166 @@ +package api + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/go-gost/x/config" + "github.com/go-gost/x/config/parsing" + "github.com/go-gost/x/registry" +) + +// swagger:parameters createConnLimiterRequest +type createConnLimiterRequest struct { + // in: body + Data config.LimiterConfig `json:"data"` +} + +// successful operation. +// swagger:response createConnLimiterResponse +type createConnLimiterResponse struct { + Data Response +} + +func createConnLimiter(ctx *gin.Context) { + // swagger:route POST /config/climiters Limiter createConnLimiterRequest + // + // Create a new conn limiter, the name of limiter must be unique in limiter list. + // + // Security: + // basicAuth: [] + // + // Responses: + // 200: createConnLimiterResponse + + var req createConnLimiterRequest + ctx.ShouldBindJSON(&req.Data) + + if req.Data.Name == "" { + writeError(ctx, ErrInvalid) + return + } + + v := parsing.ParseConnLimiter(&req.Data) + + if err := registry.ConnLimiterRegistry().Register(req.Data.Name, v); err != nil { + writeError(ctx, ErrDup) + return + } + + cfg := config.Global() + cfg.CLimiters = append(cfg.CLimiters, &req.Data) + config.SetGlobal(cfg) + + ctx.JSON(http.StatusOK, Response{ + Msg: "OK", + }) +} + +// swagger:parameters updateConnLimiterRequest +type updateConnLimiterRequest struct { + // in: path + // required: true + Limiter string `uri:"limiter" json:"limiter"` + // in: body + Data config.LimiterConfig `json:"data"` +} + +// successful operation. +// swagger:response updateConnLimiterResponse +type updateConnLimiterResponse struct { + Data Response +} + +func updateConnLimiter(ctx *gin.Context) { + // swagger:route PUT /config/climiters/{limiter} Limiter updateConnLimiterRequest + // + // Update conn limiter by name, the limiter must already exist. + // + // Security: + // basicAuth: [] + // + // Responses: + // 200: updateConnLimiterResponse + + var req updateConnLimiterRequest + ctx.ShouldBindUri(&req) + ctx.ShouldBindJSON(&req.Data) + + if !registry.ConnLimiterRegistry().IsRegistered(req.Limiter) { + writeError(ctx, ErrNotFound) + return + } + + req.Data.Name = req.Limiter + + v := parsing.ParseConnLimiter(&req.Data) + + registry.ConnLimiterRegistry().Unregister(req.Limiter) + + if err := registry.ConnLimiterRegistry().Register(req.Limiter, v); err != nil { + writeError(ctx, ErrDup) + return + } + + cfg := config.Global() + for i := range cfg.Limiters { + if cfg.Limiters[i].Name == req.Limiter { + cfg.Limiters[i] = &req.Data + break + } + } + config.SetGlobal(cfg) + + ctx.JSON(http.StatusOK, Response{ + Msg: "OK", + }) +} + +// swagger:parameters deleteConnLimiterRequest +type deleteConnLimiterRequest struct { + // in: path + // required: true + Limiter string `uri:"Limiter" json:"Limiter"` +} + +// successful operation. +// swagger:response deleteConnLimiterResponse +type deleteConnLimiterResponse struct { + Data Response +} + +func deleteConnLimiter(ctx *gin.Context) { + // swagger:route DELETE /config/climiters/{limiter} Limiter deleteConnLimiterRequest + // + // Delete conn limiter by name. + // + // Security: + // basicAuth: [] + // + // Responses: + // 200: deleteConnLimiterResponse + + var req deleteConnLimiterRequest + ctx.ShouldBindUri(&req) + + if !registry.ConnLimiterRegistry().IsRegistered(req.Limiter) { + writeError(ctx, ErrNotFound) + return + } + registry.ConnLimiterRegistry().Unregister(req.Limiter) + + cfg := config.Global() + limiteres := cfg.Limiters + cfg.Limiters = nil + for _, s := range limiteres { + if s.Name == req.Limiter { + continue + } + cfg.Limiters = append(cfg.Limiters, s) + } + config.SetGlobal(cfg) + + ctx.JSON(http.StatusOK, Response{ + Msg: "OK", + }) +} diff --git a/api/config_limiter.go b/api/config_limiter.go index c3abfd5..3c94e02 100644 --- a/api/config_limiter.go +++ b/api/config_limiter.go @@ -40,9 +40,9 @@ func createLimiter(ctx *gin.Context) { return } - v := parsing.ParseRateLimiter(&req.Data) + v := parsing.ParseTrafficLimiter(&req.Data) - if err := registry.RateLimiterRegistry().Register(req.Data.Name, v); err != nil { + if err := registry.TrafficLimiterRegistry().Register(req.Data.Name, v); err != nil { writeError(ctx, ErrDup) return } @@ -86,18 +86,18 @@ func updateLimiter(ctx *gin.Context) { ctx.ShouldBindUri(&req) ctx.ShouldBindJSON(&req.Data) - if !registry.RateLimiterRegistry().IsRegistered(req.Limiter) { + if !registry.TrafficLimiterRegistry().IsRegistered(req.Limiter) { writeError(ctx, ErrNotFound) return } req.Data.Name = req.Limiter - v := parsing.ParseRateLimiter(&req.Data) + v := parsing.ParseTrafficLimiter(&req.Data) - registry.RateLimiterRegistry().Unregister(req.Limiter) + registry.TrafficLimiterRegistry().Unregister(req.Limiter) - if err := registry.RateLimiterRegistry().Register(req.Limiter, v); err != nil { + if err := registry.TrafficLimiterRegistry().Register(req.Limiter, v); err != nil { writeError(ctx, ErrDup) return } @@ -143,11 +143,11 @@ func deleteLimiter(ctx *gin.Context) { var req deleteLimiterRequest ctx.ShouldBindUri(&req) - if !registry.RateLimiterRegistry().IsRegistered(req.Limiter) { + if !registry.TrafficLimiterRegistry().IsRegistered(req.Limiter) { writeError(ctx, ErrNotFound) return } - registry.RateLimiterRegistry().Unregister(req.Limiter) + registry.TrafficLimiterRegistry().Unregister(req.Limiter) cfg := config.Global() limiteres := cfg.Limiters diff --git a/api/service.go b/api/service.go index bf41ba0..d25bb8a 100644 --- a/api/service.go +++ b/api/service.go @@ -134,4 +134,7 @@ func registerConfig(config *gin.RouterGroup) { config.PUT("/limiters/:limiter", updateLimiter) config.DELETE("/limiters/:limiter", deleteLimiter) + config.POST("/climiters", createConnLimiter) + config.PUT("/climiters/:limiter", updateConnLimiter) + config.DELETE("/climiters/:limiter", deleteConnLimiter) } diff --git a/config/config.go b/config/config.go index c4fbee8..da594f2 100644 --- a/config/config.go +++ b/config/config.go @@ -185,22 +185,13 @@ type RecorderObject struct { } type LimiterConfig struct { - Name string `json:"name"` - Rate *RateLimiterConfig `yaml:"rate" json:"rate"` -} - -type RateLimiterConfig struct { + Name string `json:"name"` Limits []string `yaml:",omitempty" json:"limits,omitempty"` Reload time.Duration `yaml:",omitempty" json:"reload,omitempty"` File *FileLoader `yaml:",omitempty" json:"file,omitempty"` Redis *RedisLoader `yaml:",omitempty" json:"redis,omitempty"` } -type LimitConfig struct { - In string `yaml:",omitempty" json:"in,omitempty"` - Out string `yaml:",omitempty" json:"out,omitempty"` -} - type ListenerConfig struct { Type string `json:"type"` Chain string `yaml:",omitempty" json:"chain,omitempty"` @@ -263,6 +254,7 @@ type ServiceConfig struct { Resolver string `yaml:",omitempty" json:"resolver,omitempty"` Hosts string `yaml:",omitempty" json:"hosts,omitempty"` Limiter string `yaml:",omitempty" json:"limiter,omitempty"` + CLimiter string `yaml:"climiter,omitempty" json:"limiter,omitempty"` Recorders []*RecorderObject `yaml:",omitempty" json:"recorders,omitempty"` Handler *HandlerConfig `yaml:",omitempty" json:"handler,omitempty"` Listener *ListenerConfig `yaml:",omitempty" json:"listener,omitempty"` @@ -318,6 +310,7 @@ type Config struct { Hosts []*HostsConfig `yaml:",omitempty" json:"hosts,omitempty"` Recorders []*RecorderConfig `yaml:",omitempty" json:"recorders,omitempty"` Limiters []*LimiterConfig `yaml:",omitempty" json:"limiters,omitempty"` + CLimiters []*LimiterConfig `yaml:"climiters,omitempty" json:"climiters,omitempty"` TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` Log *LogConfig `yaml:",omitempty" json:"log,omitempty"` Profiling *ProfilingConfig `yaml:",omitempty" json:"profiling,omitempty"` diff --git a/config/parsing/parse.go b/config/parsing/parse.go index e050f35..9343d7c 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -9,7 +9,8 @@ import ( "github.com/go-gost/core/bypass" "github.com/go-gost/core/chain" "github.com/go-gost/core/hosts" - "github.com/go-gost/core/limiter" + "github.com/go-gost/core/limiter/conn" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" "github.com/go-gost/core/recorder" "github.com/go-gost/core/resolver" @@ -20,7 +21,8 @@ import ( "github.com/go-gost/x/config" xhosts "github.com/go-gost/x/hosts" "github.com/go-gost/x/internal/loader" - xlimiter "github.com/go-gost/x/limiter" + xconn "github.com/go-gost/x/limiter/conn" + xtraffic "github.com/go-gost/x/limiter/traffic" xrecorder "github.com/go-gost/x/recorder" "github.com/go-gost/x/registry" resolver_impl "github.com/go-gost/x/resolver" @@ -327,42 +329,82 @@ func defaultChainSelector() selector.Selector[chain.Chainer] { ) } -func ParseRateLimiter(cfg *config.LimiterConfig) (lim limiter.RateLimiter) { - if cfg == nil || cfg.Rate == nil { +func ParseTrafficLimiter(cfg *config.LimiterConfig) (lim traffic.TrafficLimiter) { + if cfg == nil { return nil } - var opts []xlimiter.Option + var opts []xtraffic.Option - if cfg.Rate.File != nil && cfg.Rate.File.Path != "" { - opts = append(opts, xlimiter.FileLoaderOption(loader.FileLoader(cfg.Rate.File.Path))) + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, xtraffic.FileLoaderOption(loader.FileLoader(cfg.File.Path))) } - if cfg.Rate.Redis != nil && cfg.Rate.Redis.Addr != "" { - switch cfg.Rate.Redis.Type { + if cfg.Redis != nil && cfg.Redis.Addr != "" { + switch cfg.Redis.Type { case "list": // redis list - opts = append(opts, xlimiter.RedisLoaderOption(loader.RedisListLoader( - cfg.Rate.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Rate.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Rate.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Rate.Redis.Key), + opts = append(opts, xtraffic.RedisLoaderOption(loader.RedisListLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), ))) default: // redis set - opts = append(opts, xlimiter.RedisLoaderOption(loader.RedisSetLoader( - cfg.Rate.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Rate.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Rate.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Rate.Redis.Key), + opts = append(opts, xtraffic.RedisLoaderOption(loader.RedisSetLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), ))) } } opts = append(opts, - xlimiter.LimitsOption(cfg.Rate.Limits...), - xlimiter.ReloadPeriodOption(cfg.Rate.Reload), - xlimiter.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "limiter", - "hosts": cfg.Name, + xtraffic.LimitsOption(cfg.Limits...), + xtraffic.ReloadPeriodOption(cfg.Reload), + xtraffic.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "limiter", + "limiter": cfg.Name, })), ) - return xlimiter.NewRateLimiter(opts...) + return xtraffic.NewTrafficLimiter(opts...) +} + +func ParseConnLimiter(cfg *config.LimiterConfig) (lim conn.ConnLimiter) { + if cfg == nil { + return nil + } + + var opts []xconn.Option + + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, xconn.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + switch cfg.Redis.Type { + case "list": // redis list + opts = append(opts, xconn.RedisLoaderOption(loader.RedisListLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + default: // redis set + opts = append(opts, xconn.RedisLoaderOption(loader.RedisSetLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + } + opts = append(opts, + xconn.LimitsOption(cfg.Limits...), + xconn.ReloadPeriodOption(cfg.Reload), + xconn.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "limiter", + "limiter": cfg.Name, + })), + ) + + return xconn.NewConnLimiter(opts...) } diff --git a/config/parsing/service.go b/config/parsing/service.go index 087fd38..a1baba1 100644 --- a/config/parsing/service.go +++ b/config/parsing/service.go @@ -98,7 +98,8 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { listener.TLSConfigOption(tlsConfig), listener.AdmissionOption(admission.AdmissionGroup(admissions...)), listener.ChainOption(chainGroup(cfg.Listener.Chain, cfg.Listener.ChainGroup)), - listener.RateLimiterOption(registry.RateLimiterRegistry().Get(cfg.Limiter)), + listener.TrafficLimiterOption(registry.TrafficLimiterRegistry().Get(cfg.Limiter)), + listener.ConnLimiterOption(registry.ConnLimiterRegistry().Get(cfg.CLimiter)), listener.LoggerOption(listenerLogger), listener.ServiceOption(cfg.Name), listener.ProxyProtocolOption(ppv), diff --git a/connector/socks/v5/connector.go b/connector/socks/v5/connector.go index 846601c..9914553 100644 --- a/connector/socks/v5/connector.go +++ b/connector/socks/v5/connector.go @@ -100,9 +100,14 @@ func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, a defer conn.SetDeadline(time.Time{}) } + var cOpts connector.ConnectOptions + for _, opt := range opts { + opt(&cOpts) + } + switch network { case "udp", "udp4", "udp6": - return c.connectUDP(ctx, conn, network, address, log) + return c.connectUDP(ctx, conn, network, address, log, &cOpts) case "tcp", "tcp4", "tcp6": if _, ok := conn.(net.PacketConn); ok { err := fmt.Errorf("tcp over udp is unsupported") @@ -144,7 +149,7 @@ func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, a return conn, nil } -func (c *socks5Connector) connectUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) (net.Conn, error) { +func (c *socks5Connector) connectUDP(ctx context.Context, conn net.Conn, network, address string, log logger.Logger, opts *connector.ConnectOptions) (net.Conn, error) { addr, err := net.ResolveUDPAddr(network, address) if err != nil { log.Error(err) @@ -152,7 +157,7 @@ func (c *socks5Connector) connectUDP(ctx context.Context, conn net.Conn, network } if c.md.relay == "udp" { - return c.relayUDP(ctx, conn, addr, log) + return c.relayUDP(ctx, conn, addr, log, opts) } req := gosocks5.NewRequest(socks.CmdUDPTun, nil) @@ -176,7 +181,7 @@ func (c *socks5Connector) connectUDP(ctx context.Context, conn net.Conn, network return socks.UDPTunClientConn(conn, addr), nil } -func (c *socks5Connector) relayUDP(ctx context.Context, conn net.Conn, addr net.Addr, log logger.Logger) (net.Conn, error) { +func (c *socks5Connector) relayUDP(ctx context.Context, conn net.Conn, addr net.Addr, log logger.Logger, opts *connector.ConnectOptions) (net.Conn, error) { req := gosocks5.NewRequest(gosocks5.CmdUdp, nil) log.Trace(req) if err := req.Write(conn); err != nil { @@ -191,11 +196,13 @@ func (c *socks5Connector) relayUDP(ctx context.Context, conn net.Conn, addr net. } log.Trace(reply) + log.Debugf("bind on: %v", reply.Addr) + if reply.Rep != gosocks5.Succeeded { return nil, errors.New("get socks5 UDP tunnel failure") } - cc, err := (&net.Dialer{}).DialContext(ctx, "udp", reply.Addr.String()) + cc, err := opts.NetDialer.Dial(ctx, "udp", reply.Addr.String()) if err != nil { return nil, err } diff --git a/go.mod b/go.mod index 5592363..357c65c 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.7.7 - github.com/go-gost/core v0.0.0-20220908143917-e7a104651a75 + github.com/go-gost/core v0.0.0-20220913161420-45b7ac2021fe github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 diff --git a/go.sum b/go.sum index 9887f4d..015f8a4 100644 --- a/go.sum +++ b/go.sum @@ -98,8 +98,8 @@ github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gost/core v0.0.0-20220908143917-e7a104651a75 h1:8DoQErtmgR9pRajWTswswLgaqOprJtkz/iC+2oOe24g= -github.com/go-gost/core v0.0.0-20220908143917-e7a104651a75/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= +github.com/go-gost/core v0.0.0-20220913161420-45b7ac2021fe h1:zYcwKOe9ceGpwin84bH7J/DRZ4g9MhU+xOsTMxqOuNw= +github.com/go-gost/core v0.0.0-20220913161420-45b7ac2021fe/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= diff --git a/limiter/conn/conn.go b/limiter/conn/conn.go new file mode 100644 index 0000000..eaf8619 --- /dev/null +++ b/limiter/conn/conn.go @@ -0,0 +1,360 @@ +package conn + +import ( + "bufio" + "context" + "io" + "net" + "sort" + "strconv" + "strings" + "sync" + "time" + + limiter "github.com/go-gost/core/limiter/conn" + "github.com/go-gost/core/logger" + "github.com/go-gost/x/internal/loader" + "github.com/yl2chen/cidranger" +) + +const ( + GlobalLimitKey = "$" + IPLimitKey = "$$" +) + +type limiterGroup struct { + limiters []limiter.Limiter +} + +func newLimiterGroup(limiters ...limiter.Limiter) *limiterGroup { + sort.Slice(limiters, func(i, j int) bool { + return limiters[i].Limit() < limiters[j].Limit() + }) + return &limiterGroup{limiters: limiters} +} + +func (l *limiterGroup) Allow(n int) (b bool) { + var i int + + for i = range l.limiters { + if b = l.limiters[i].Allow(n); !b { + break + } + } + if !b && i > 0 && n > 0 { + for i := range l.limiters[:i] { + l.limiters[i].Allow(-n) + } + } + + return +} + +func (l *limiterGroup) Limit() int { + if len(l.limiters) == 0 { + return 0 + } + + return l.limiters[0].Limit() +} + +type options struct { + limits []string + fileLoader loader.Loader + redisLoader loader.Loader + period time.Duration + logger logger.Logger +} + +type Option func(opts *options) + +func LimitsOption(limits ...string) Option { + return func(opts *options) { + opts.limits = limits + } +} + +func ReloadPeriodOption(period time.Duration) Option { + return func(opts *options) { + opts.period = period + } +} + +func FileLoaderOption(fileLoader loader.Loader) Option { + return func(opts *options) { + opts.fileLoader = fileLoader + } +} + +func RedisLoaderOption(redisLoader loader.Loader) Option { + return func(opts *options) { + opts.redisLoader = redisLoader + } +} + +func LoggerOption(logger logger.Logger) Option { + return func(opts *options) { + opts.logger = logger + } +} + +type connLimiter struct { + ipLimits map[string]ConnLimitGenerator + cidrLimits cidranger.Ranger + limits map[string]limiter.Limiter + mu sync.RWMutex + cancelFunc context.CancelFunc + options options +} + +func NewConnLimiter(opts ...Option) limiter.ConnLimiter { + var options options + for _, opt := range opts { + opt(&options) + } + + ctx, cancel := context.WithCancel(context.TODO()) + lim := &connLimiter{ + ipLimits: make(map[string]ConnLimitGenerator), + cidrLimits: cidranger.NewPCTrieRanger(), + limits: make(map[string]limiter.Limiter), + options: options, + cancelFunc: cancel, + } + + if err := lim.reload(ctx); err != nil { + options.logger.Warnf("reload: %v", err) + } + if lim.options.period > 0 { + go lim.periodReload(ctx) + } + return lim +} + +func (l *connLimiter) Limiter(key string) limiter.Limiter { + l.mu.Lock() + defer l.mu.Unlock() + + if lim, ok := l.limits[key]; ok { + return lim + } + + var lims []limiter.Limiter + + if ip := net.ParseIP(key); ip != nil { + found := false + if p := l.ipLimits[key]; p != nil { + if lim := p.Limiter(); lim != nil { + lims = append(lims, lim) + found = true + } + } + if !found { + if p, _ := l.cidrLimits.ContainingNetworks(ip); len(p) > 0 { + if v, _ := p[0].(*cidrLimitEntry); v != nil { + if lim := v.limit.Limiter(); lim != nil { + lims = append(lims, lim) + } + } + } + } + } + + if len(lims) == 0 { + if p := l.ipLimits[IPLimitKey]; p != nil { + if lim := p.Limiter(); lim != nil { + lims = append(lims, lim) + } + } + } + + if p := l.ipLimits[GlobalLimitKey]; p != nil { + if lim := p.Limiter(); lim != nil { + lims = append(lims, lim) + } + } + + var lim limiter.Limiter + if len(lims) > 0 { + lim = newLimiterGroup(lims...) + } + l.limits[key] = lim + + if lim != nil && l.options.logger != nil { + l.options.logger.Debugf("conn limit for %s: %d", key, lim.Limit()) + } + + return lim +} + +func (l *connLimiter) periodReload(ctx context.Context) error { + period := l.options.period + if period < time.Second { + period = time.Second + } + ticker := time.NewTicker(period) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := l.reload(ctx); err != nil { + l.options.logger.Warnf("reload: %v", err) + // return err + } + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (l *connLimiter) reload(ctx context.Context) error { + v, err := l.load(ctx) + if err != nil { + return err + } + + lines := append(l.options.limits, v...) + + ipLimits := make(map[string]ConnLimitGenerator) + cidrLimits := cidranger.NewPCTrieRanger() + + for _, s := range lines { + key, limit := l.parseLimit(s) + if key == "" || limit <= 0 { + continue + } + switch key { + case GlobalLimitKey: + ipLimits[key] = NewConnLimitSingleGenerator(limit) + case IPLimitKey: + ipLimits[key] = NewConnLimitGenerator(limit) + default: + if ip := net.ParseIP(key); ip != nil { + ipLimits[key] = NewConnLimitGenerator(limit) + break + } + if _, ipNet, _ := net.ParseCIDR(key); ipNet != nil { + cidrLimits.Insert(&cidrLimitEntry{ + ipNet: *ipNet, + limit: NewConnLimitGenerator(limit), + }) + } + } + } + + l.mu.Lock() + defer l.mu.Unlock() + + l.ipLimits = ipLimits + l.cidrLimits = cidrLimits + l.limits = make(map[string]limiter.Limiter) + + return nil +} + +func (l *connLimiter) load(ctx context.Context) (patterns []string, err error) { + if l.options.fileLoader != nil { + if lister, ok := l.options.fileLoader.(loader.Lister); ok { + list, er := lister.List(ctx) + if er != nil { + l.options.logger.Warnf("file loader: %v", er) + } + for _, s := range list { + if line := l.parseLine(s); line != "" { + patterns = append(patterns, line) + } + } + } else { + r, er := l.options.fileLoader.Load(ctx) + if er != nil { + l.options.logger.Warnf("file loader: %v", er) + } + if v, _ := l.parsePatterns(r); v != nil { + patterns = append(patterns, v...) + } + } + } + if l.options.redisLoader != nil { + if lister, ok := l.options.redisLoader.(loader.Lister); ok { + list, er := lister.List(ctx) + if er != nil { + l.options.logger.Warnf("redis loader: %v", er) + } + patterns = append(patterns, list...) + } else { + r, er := l.options.redisLoader.Load(ctx) + if er != nil { + l.options.logger.Warnf("redis loader: %v", er) + } + if v, _ := l.parsePatterns(r); v != nil { + patterns = append(patterns, v...) + } + } + } + + l.options.logger.Debugf("load items %d", len(patterns)) + return +} + +func (l *connLimiter) parsePatterns(r io.Reader) (patterns []string, err error) { + if r == nil { + return + } + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + if line := l.parseLine(scanner.Text()); line != "" { + patterns = append(patterns, line) + } + } + + err = scanner.Err() + return +} + +func (l *connLimiter) parseLine(s string) string { + if n := strings.IndexByte(s, '#'); n >= 0 { + s = s[:n] + } + return strings.TrimSpace(s) +} + +func (l *connLimiter) parseLimit(s string) (key string, limit int) { + s = strings.Replace(s, "\t", " ", -1) + s = strings.TrimSpace(s) + var ss []string + for _, v := range strings.Split(s, " ") { + if v != "" { + ss = append(ss, v) + } + } + if len(ss) < 2 { + return + } + + key = ss[0] + limit, _ = strconv.Atoi(ss[1]) + + return +} + +func (l *connLimiter) Close() error { + l.cancelFunc() + if l.options.fileLoader != nil { + l.options.fileLoader.Close() + } + if l.options.redisLoader != nil { + l.options.redisLoader.Close() + } + return nil +} + +type cidrLimitEntry struct { + ipNet net.IPNet + limit ConnLimitGenerator +} + +func (p *cidrLimitEntry) Network() net.IPNet { + return p.ipNet +} diff --git a/limiter/conn/generator.go b/limiter/conn/generator.go new file mode 100644 index 0000000..07dd4e1 --- /dev/null +++ b/limiter/conn/generator.go @@ -0,0 +1,42 @@ +package conn + +import ( + limiter "github.com/go-gost/core/limiter/conn" +) + +type ConnLimitGenerator interface { + Limiter() limiter.Limiter +} + +type connLimitGenerator struct { + n int +} + +func NewConnLimitGenerator(n int) ConnLimitGenerator { + return &connLimitGenerator{ + n: n, + } +} + +func (p *connLimitGenerator) Limiter() limiter.Limiter { + if p == nil || p.n <= 0 { + return nil + } + return NewLimiter(p.n) +} + +type connLimitSingleGenerator struct { + limiter limiter.Limiter +} + +func NewConnLimitSingleGenerator(n int) ConnLimitGenerator { + p := &connLimitSingleGenerator{} + if n > 0 { + p.limiter = NewLimiter(n) + } + return p +} + +func (p *connLimitSingleGenerator) Limiter() limiter.Limiter { + return p.limiter +} diff --git a/limiter/conn/limiter.go b/limiter/conn/limiter.go new file mode 100644 index 0000000..0367555 --- /dev/null +++ b/limiter/conn/limiter.go @@ -0,0 +1,30 @@ +package conn + +import ( + "sync/atomic" + + limiter "github.com/go-gost/core/limiter/conn" +) + +type llimiter struct { + limit int + current int64 +} + +func NewLimiter(n int) limiter.Limiter { + return &llimiter{limit: n} +} + +func (l *llimiter) Limit() int { + return l.limit +} + +func (l *llimiter) Allow(n int) bool { + if atomic.AddInt64(&l.current, int64(n)) >= int64(l.limit) { + if n > 0 { + atomic.AddInt64(&l.current, -int64(n)) + } + return false + } + return true +} diff --git a/limiter/conn/wrapper/conn.go b/limiter/conn/wrapper/conn.go new file mode 100644 index 0000000..c519a80 --- /dev/null +++ b/limiter/conn/wrapper/conn.go @@ -0,0 +1,43 @@ +package wrapper + +import ( + "errors" + "net" + "syscall" + + limiter "github.com/go-gost/core/limiter/conn" +) + +var ( + errUnsupport = errors.New("unsupported operation") +) + +// serverConn is a server side Conn with metrics supported. +type serverConn struct { + net.Conn + limiter limiter.Limiter +} + +func WrapConn(limiter limiter.Limiter, c net.Conn) net.Conn { + if limiter == nil { + return c + } + return &serverConn{ + Conn: c, + limiter: limiter, + } +} + +func (c *serverConn) SyscallConn() (rc syscall.RawConn, err error) { + if sc, ok := c.Conn.(syscall.Conn); ok { + rc, err = sc.SyscallConn() + return + } + err = errUnsupport + return +} + +func (c *serverConn) Close() error { + c.limiter.Allow(-1) + return c.Conn.Close() +} diff --git a/limiter/conn/wrapper/listener.go b/limiter/conn/wrapper/listener.go new file mode 100644 index 0000000..d77d84a --- /dev/null +++ b/limiter/conn/wrapper/listener.go @@ -0,0 +1,40 @@ +package wrapper + +import ( + "net" + + limiter "github.com/go-gost/core/limiter/conn" +) + +type listener struct { + net.Listener + limiter limiter.ConnLimiter +} + +func WrapListener(limiter limiter.ConnLimiter, ln net.Listener) net.Listener { + if limiter == nil { + return ln + } + + return &listener{ + limiter: limiter, + Listener: ln, + } +} + +func (ln *listener) Accept() (net.Conn, error) { + c, err := ln.Listener.Accept() + if err != nil { + return nil, err + } + + host, _, _ := net.SplitHostPort(c.RemoteAddr().String()) + if lim := ln.limiter.Limiter(host); lim != nil { + if lim.Allow(1) { + return WrapConn(lim, c), nil + } + c.Close() + } + + return c, nil +} diff --git a/limiter/generator.go b/limiter/generator.go deleted file mode 100644 index ad2bf3c..0000000 --- a/limiter/generator.go +++ /dev/null @@ -1,61 +0,0 @@ -package limiter - -import ( - "github.com/go-gost/core/limiter" -) - -type RateLimitGenerator interface { - In() limiter.Limiter - Out() limiter.Limiter -} - -type rateLimitGenerator struct { - in int - out int -} - -func NewRateLimitGenerator(in, out int) RateLimitGenerator { - return &rateLimitGenerator{ - in: in, - out: out, - } -} - -func (p *rateLimitGenerator) In() limiter.Limiter { - if p == nil || p.in <= 0 { - return nil - } - return NewLimiter(p.in) -} - -func (p *rateLimitGenerator) Out() limiter.Limiter { - if p == nil || p.out <= 0 { - return nil - } - return NewLimiter(p.out) -} - -type rateLimitSingleGenerator struct { - in limiter.Limiter - out limiter.Limiter -} - -func NewRateLimitSingleGenerator(in, out int) RateLimitGenerator { - p := &rateLimitSingleGenerator{} - if in > 0 { - p.in = NewLimiter(in) - } - if out > 0 { - p.out = NewLimiter(out) - } - - return p -} - -func (p *rateLimitSingleGenerator) In() limiter.Limiter { - return p.in -} - -func (p *rateLimitSingleGenerator) Out() limiter.Limiter { - return p.out -} diff --git a/limiter/traffic/generator.go b/limiter/traffic/generator.go new file mode 100644 index 0000000..4eaff07 --- /dev/null +++ b/limiter/traffic/generator.go @@ -0,0 +1,61 @@ +package traffic + +import ( + limiter "github.com/go-gost/core/limiter/traffic" +) + +type TrafficLimitGenerator interface { + In() limiter.Limiter + Out() limiter.Limiter +} + +type trafficLimitGenerator struct { + in int + out int +} + +func NewTrafficLimitGenerator(in, out int) TrafficLimitGenerator { + return &trafficLimitGenerator{ + in: in, + out: out, + } +} + +func (p *trafficLimitGenerator) In() limiter.Limiter { + if p == nil || p.in <= 0 { + return nil + } + return NewLimiter(p.in) +} + +func (p *trafficLimitGenerator) Out() limiter.Limiter { + if p == nil || p.out <= 0 { + return nil + } + return NewLimiter(p.out) +} + +type trafficLimitSingleGenerator struct { + in limiter.Limiter + out limiter.Limiter +} + +func NewTrafficLimitSingleGenerator(in, out int) TrafficLimitGenerator { + p := &trafficLimitSingleGenerator{} + if in > 0 { + p.in = NewLimiter(in) + } + if out > 0 { + p.out = NewLimiter(out) + } + + return p +} + +func (p *trafficLimitSingleGenerator) In() limiter.Limiter { + return p.in +} + +func (p *trafficLimitSingleGenerator) Out() limiter.Limiter { + return p.out +} diff --git a/limiter/limiter.go b/limiter/traffic/limiter.go similarity index 86% rename from limiter/limiter.go rename to limiter/traffic/limiter.go index ea18409..7cbcdd8 100644 --- a/limiter/limiter.go +++ b/limiter/traffic/limiter.go @@ -1,9 +1,9 @@ -package limiter +package traffic import ( "context" - "github.com/go-gost/core/limiter" + limiter "github.com/go-gost/core/limiter/traffic" "golang.org/x/time/rate" ) diff --git a/limiter/rate.go b/limiter/traffic/traffic.go similarity index 86% rename from limiter/rate.go rename to limiter/traffic/traffic.go index dfbc71f..22de113 100644 --- a/limiter/rate.go +++ b/limiter/traffic/traffic.go @@ -1,4 +1,4 @@ -package limiter +package traffic import ( "bufio" @@ -11,7 +11,7 @@ import ( "time" "github.com/alecthomas/units" - "github.com/go-gost/core/limiter" + limiter "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" "github.com/go-gost/x/internal/loader" "github.com/yl2chen/cidranger" @@ -90,8 +90,8 @@ func LoggerOption(logger logger.Logger) Option { } } -type rateLimiter struct { - ipLimits map[string]RateLimitGenerator +type trafficLimiter struct { + ipLimits map[string]TrafficLimitGenerator cidrLimits cidranger.Ranger inLimits map[string]limiter.Limiter outLimits map[string]limiter.Limiter @@ -100,15 +100,15 @@ type rateLimiter struct { options options } -func NewRateLimiter(opts ...Option) limiter.RateLimiter { +func NewTrafficLimiter(opts ...Option) limiter.TrafficLimiter { var options options for _, opt := range opts { opt(&options) } ctx, cancel := context.WithCancel(context.TODO()) - lim := &rateLimiter{ - ipLimits: make(map[string]RateLimitGenerator), + lim := &trafficLimiter{ + ipLimits: make(map[string]TrafficLimitGenerator), cidrLimits: cidranger.NewPCTrieRanger(), inLimits: make(map[string]limiter.Limiter), outLimits: make(map[string]limiter.Limiter), @@ -125,7 +125,7 @@ func NewRateLimiter(opts ...Option) limiter.RateLimiter { return lim } -func (l *rateLimiter) In(key string) limiter.Limiter { +func (l *trafficLimiter) In(key string) limiter.Limiter { l.mu.Lock() defer l.mu.Unlock() @@ -178,7 +178,7 @@ func (l *rateLimiter) In(key string) limiter.Limiter { return lim } -func (l *rateLimiter) Out(key string) limiter.Limiter { +func (l *trafficLimiter) Out(key string) limiter.Limiter { l.mu.Lock() defer l.mu.Unlock() @@ -231,7 +231,7 @@ func (l *rateLimiter) Out(key string) limiter.Limiter { return lim } -func (l *rateLimiter) periodReload(ctx context.Context) error { +func (l *trafficLimiter) periodReload(ctx context.Context) error { period := l.options.period if period < time.Second { period = time.Second @@ -252,7 +252,7 @@ func (l *rateLimiter) periodReload(ctx context.Context) error { } } -func (l *rateLimiter) reload(ctx context.Context) error { +func (l *trafficLimiter) reload(ctx context.Context) error { v, err := l.load(ctx) if err != nil { return err @@ -260,7 +260,7 @@ func (l *rateLimiter) reload(ctx context.Context) error { lines := append(l.options.limits, v...) - ipLimits := make(map[string]RateLimitGenerator) + ipLimits := make(map[string]TrafficLimitGenerator) cidrLimits := cidranger.NewPCTrieRanger() for _, s := range lines { @@ -270,18 +270,18 @@ func (l *rateLimiter) reload(ctx context.Context) error { } switch key { case GlobalLimitKey: - ipLimits[key] = NewRateLimitSingleGenerator(in, out) + ipLimits[key] = NewTrafficLimitSingleGenerator(in, out) case ConnLimitKey: - ipLimits[key] = NewRateLimitGenerator(in, out) + ipLimits[key] = NewTrafficLimitGenerator(in, out) default: if ip := net.ParseIP(key); ip != nil { - ipLimits[key] = NewRateLimitGenerator(in, out) + ipLimits[key] = NewTrafficLimitGenerator(in, out) break } if _, ipNet, _ := net.ParseCIDR(key); ipNet != nil { cidrLimits.Insert(&cidrLimitEntry{ ipNet: *ipNet, - limit: NewRateLimitGenerator(in, out), + limit: NewTrafficLimitGenerator(in, out), }) } } @@ -298,7 +298,7 @@ func (l *rateLimiter) reload(ctx context.Context) error { return nil } -func (l *rateLimiter) load(ctx context.Context) (patterns []string, err error) { +func (l *trafficLimiter) load(ctx context.Context) (patterns []string, err error) { if l.options.fileLoader != nil { if lister, ok := l.options.fileLoader.(loader.Lister); ok { list, er := lister.List(ctx) @@ -342,7 +342,7 @@ func (l *rateLimiter) load(ctx context.Context) (patterns []string, err error) { return } -func (l *rateLimiter) parsePatterns(r io.Reader) (patterns []string, err error) { +func (l *trafficLimiter) parsePatterns(r io.Reader) (patterns []string, err error) { if r == nil { return } @@ -358,14 +358,14 @@ func (l *rateLimiter) parsePatterns(r io.Reader) (patterns []string, err error) return } -func (l *rateLimiter) parseLine(s string) string { +func (l *trafficLimiter) parseLine(s string) string { if n := strings.IndexByte(s, '#'); n >= 0 { s = s[:n] } return strings.TrimSpace(s) } -func (l *rateLimiter) parseLimit(s string) (key string, in, out int) { +func (l *trafficLimiter) parseLimit(s string) (key string, in, out int) { s = strings.Replace(s, "\t", " ", -1) s = strings.TrimSpace(s) var ss []string @@ -391,7 +391,7 @@ func (l *rateLimiter) parseLimit(s string) (key string, in, out int) { return } -func (l *rateLimiter) Close() error { +func (l *trafficLimiter) Close() error { l.cancelFunc() if l.options.fileLoader != nil { l.options.fileLoader.Close() @@ -404,7 +404,7 @@ func (l *rateLimiter) Close() error { type cidrLimitEntry struct { ipNet net.IPNet - limit RateLimitGenerator + limit TrafficLimitGenerator } func (p *cidrLimitEntry) Network() net.IPNet { diff --git a/limiter/wrapper/conn.go b/limiter/traffic/wrapper/conn.go similarity index 84% rename from limiter/wrapper/conn.go rename to limiter/traffic/wrapper/conn.go index 119c1b3..7dbaeae 100644 --- a/limiter/wrapper/conn.go +++ b/limiter/traffic/wrapper/conn.go @@ -8,7 +8,7 @@ import ( "net" "syscall" - "github.com/go-gost/core/limiter" + limiter "github.com/go-gost/core/limiter/traffic" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/udp" ) @@ -22,10 +22,10 @@ type serverConn struct { net.Conn rbuf bytes.Buffer raddr string - rlimiter limiter.RateLimiter + rlimiter limiter.TrafficLimiter } -func WrapConn(rlimiter limiter.RateLimiter, c net.Conn) net.Conn { +func WrapConn(rlimiter limiter.TrafficLimiter, c net.Conn) net.Conn { if rlimiter == nil { return c } @@ -100,16 +100,16 @@ func (c *serverConn) SyscallConn() (rc syscall.RawConn, err error) { type packetConn struct { net.PacketConn - rlimiter limiter.RateLimiter + limiter limiter.TrafficLimiter } -func WrapPacketConn(rlimiter limiter.RateLimiter, pc net.PacketConn) net.PacketConn { - if rlimiter == nil { +func WrapPacketConn(limiter limiter.TrafficLimiter, pc net.PacketConn) net.PacketConn { + if limiter == nil { return pc } return &packetConn{ PacketConn: pc, - rlimiter: rlimiter, + limiter: limiter, } } @@ -122,11 +122,11 @@ func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { host, _, _ := net.SplitHostPort(addr.String()) - if c.rlimiter == nil || c.rlimiter.In(host) == nil { + if c.limiter == nil || c.limiter.In(host) == nil { return } - limiter := c.rlimiter.In(host) + limiter := c.limiter.In(host) // discard when exceed the limit size. if limiter.Wait(context.Background(), n) < n { continue @@ -137,10 +137,10 @@ func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } func (c *packetConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if c.rlimiter != nil { + if c.limiter != nil { host, _, _ := net.SplitHostPort(addr.String()) // discard when exceed the limit size. - if limiter := c.rlimiter.Out(host); limiter != nil && + if limiter := c.limiter.Out(host); limiter != nil && limiter.Wait(context.Background(), len(p)) < len(p) { n = len(p) return @@ -152,13 +152,13 @@ func (c *packetConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { type udpConn struct { net.PacketConn - rlimiter limiter.RateLimiter + limiter limiter.TrafficLimiter } -func WrapUDPConn(rlimiter limiter.RateLimiter, pc net.PacketConn) udp.Conn { +func WrapUDPConn(limiter limiter.TrafficLimiter, pc net.PacketConn) udp.Conn { return &udpConn{ PacketConn: pc, - rlimiter: rlimiter, + limiter: limiter, } } @@ -200,10 +200,10 @@ func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } host, _, _ := net.SplitHostPort(addr.String()) - if c.rlimiter == nil || c.rlimiter.In(host) == nil { + if c.limiter == nil || c.limiter.In(host) == nil { return } - limiter := c.rlimiter.In(host) + limiter := c.limiter.In(host) // discard when exceed the limit size. if limiter.Wait(context.Background(), n) < n { continue @@ -222,10 +222,10 @@ func (c *udpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { host, _, _ := net.SplitHostPort(addr.String()) - if c.rlimiter == nil || c.rlimiter.In(host) == nil { + if c.limiter == nil || c.limiter.In(host) == nil { return } - limiter := c.rlimiter.In(host) + limiter := c.limiter.In(host) // discard when exceed the limit size. if limiter.Wait(context.Background(), n) < n { continue @@ -247,10 +247,10 @@ func (c *udpConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAd host, _, _ := net.SplitHostPort(addr.String()) - if c.rlimiter == nil || c.rlimiter.In(host) == nil { + if c.limiter == nil || c.limiter.In(host) == nil { return } - limiter := c.rlimiter.In(host) + limiter := c.limiter.In(host) // discard when exceed the limit size. if limiter.Wait(context.Background(), n) < n { continue @@ -272,10 +272,10 @@ func (c *udpConn) Write(b []byte) (n int, err error) { } func (c *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if c.rlimiter != nil { + if c.limiter != nil { host, _, _ := net.SplitHostPort(addr.String()) // discard when exceed the limit size. - if limiter := c.rlimiter.Out(host); limiter != nil && + if limiter := c.limiter.Out(host); limiter != nil && limiter.Wait(context.Background(), len(p)) < len(p) { n = len(p) return @@ -287,10 +287,10 @@ func (c *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { } func (c *udpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) { - if c.rlimiter != nil { + if c.limiter != nil { host, _, _ := net.SplitHostPort(addr.String()) // discard when exceed the limit size. - if limiter := c.rlimiter.Out(host); limiter != nil && + if limiter := c.limiter.Out(host); limiter != nil && limiter.Wait(context.Background(), len(b)) < len(b) { n = len(b) return @@ -306,10 +306,10 @@ func (c *udpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) { } func (c *udpConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) { - if c.rlimiter != nil { + if c.limiter != nil { host, _, _ := net.SplitHostPort(addr.String()) // discard when exceed the limit size. - if limiter := c.rlimiter.Out(host); limiter != nil && + if limiter := c.limiter.Out(host); limiter != nil && limiter.Wait(context.Background(), len(b)) < len(b) { n = len(b) return diff --git a/limiter/wrapper/listener.go b/limiter/traffic/wrapper/listener.go similarity index 51% rename from limiter/wrapper/listener.go rename to limiter/traffic/wrapper/listener.go index 42a2c8a..0e4ec5b 100644 --- a/limiter/wrapper/listener.go +++ b/limiter/traffic/wrapper/listener.go @@ -3,21 +3,21 @@ package wrapper import ( "net" - "github.com/go-gost/core/limiter" + limiter "github.com/go-gost/core/limiter/traffic" ) type listener struct { net.Listener - rlimiter limiter.RateLimiter + limiter limiter.TrafficLimiter } -func WrapListener(rlimiter limiter.RateLimiter, ln net.Listener) net.Listener { - if rlimiter == nil { +func WrapListener(limiter limiter.TrafficLimiter, ln net.Listener) net.Listener { + if limiter == nil { return ln } return &listener{ - rlimiter: rlimiter, + limiter: limiter, Listener: ln, } } @@ -28,5 +28,5 @@ func (ln *listener) Accept() (net.Conn, error) { return nil, err } - return WrapConn(ln.rlimiter, c), nil + return WrapConn(ln.limiter, c), nil } diff --git a/listener/dns/listener.go b/listener/dns/listener.go index ea27ec4..53d8d3e 100644 --- a/listener/dns/listener.go +++ b/listener/dns/listener.go @@ -10,7 +10,7 @@ import ( "strings" admission "github.com/go-gost/x/admission/wrapper" - limiter "github.com/go-gost/x/limiter/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" @@ -118,7 +118,7 @@ func (l *dnsListener) Accept() (conn net.Conn, err error) { case conn = <-l.cqueue: conn = metrics.WrapConn(l.options.Service, conn) conn = admission.WrapConn(l.options.Admission, conn) - conn = limiter.WrapConn(l.options.RateLimiter, conn) + conn = limiter.WrapConn(l.options.TrafficLimiter, conn) case err, ok = <-l.errChan: if !ok { err = listener.ErrClosed diff --git a/listener/ftcp/listener.go b/listener/ftcp/listener.go index 3dfabf6..4bfa88c 100644 --- a/listener/ftcp/listener.go +++ b/listener/ftcp/listener.go @@ -9,7 +9,7 @@ import ( md "github.com/go-gost/core/metadata" admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" - limiter "github.com/go-gost/x/limiter/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" "github.com/xtaci/tcpraw" @@ -53,7 +53,7 @@ func (l *ftcpListener) Init(md md.Metadata) (err error) { } conn = metrics.WrapPacketConn(l.options.Service, conn) conn = admission.WrapPacketConn(l.options.Admission, conn) - conn = limiter.WrapPacketConn(l.options.RateLimiter, conn) + conn = limiter.WrapPacketConn(l.options.TrafficLimiter, conn) l.ln = udp.NewListener( conn, diff --git a/listener/grpc/listener.go b/listener/grpc/listener.go index 43ffe60..b69af97 100644 --- a/listener/grpc/listener.go +++ b/listener/grpc/listener.go @@ -11,7 +11,8 @@ import ( xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" pb "github.com/go-gost/x/internal/util/grpc/proto" - limiter "github.com/go-gost/x/limiter/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" "google.golang.org/grpc" @@ -58,7 +59,8 @@ func (l *grpcListener) Init(md md.Metadata) (err error) { } ln = metrics.WrapListener(l.options.Service, ln) ln = admission.WrapListener(l.options.Admission, ln) - ln = limiter.WrapListener(l.options.RateLimiter, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) var opts []grpc.ServerOption diff --git a/listener/http2/h2/listener.go b/listener/http2/h2/listener.go index cfceae9..8e2ddaa 100644 --- a/listener/http2/h2/listener.go +++ b/listener/http2/h2/listener.go @@ -14,7 +14,8 @@ import ( admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" - limiter "github.com/go-gost/x/limiter/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" "golang.org/x/net/http2" @@ -80,7 +81,8 @@ func (l *h2Listener) Init(md md.Metadata) (err error) { l.addr = ln.Addr() ln = metrics.WrapListener(l.options.Service, ln) ln = admission.WrapListener(l.options.Admission, ln) - ln = limiter.WrapListener(l.options.RateLimiter, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) if l.h2c { diff --git a/listener/http2/listener.go b/listener/http2/listener.go index b6ee6d8..16e043c 100644 --- a/listener/http2/listener.go +++ b/listener/http2/listener.go @@ -12,7 +12,8 @@ import ( admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" - limiter "github.com/go-gost/x/limiter/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" mdx "github.com/go-gost/x/metadata" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" @@ -69,7 +70,8 @@ func (l *http2Listener) Init(md md.Metadata) (err error) { l.addr = ln.Addr() ln = metrics.WrapListener(l.options.Service, ln) ln = admission.WrapListener(l.options.Admission, ln) - ln = limiter.WrapListener(l.options.RateLimiter, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = tls.NewListener( diff --git a/listener/http3/listener.go b/listener/http3/listener.go index 1f8e41f..97cf7f1 100644 --- a/listener/http3/listener.go +++ b/listener/http3/listener.go @@ -9,7 +9,7 @@ import ( admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" pht_util "github.com/go-gost/x/internal/util/pht" - limiter "github.com/go-gost/x/limiter/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" "github.com/lucas-clemente/quic-go" @@ -79,7 +79,7 @@ func (l *http3Listener) Accept() (conn net.Conn, err error) { conn = metrics.WrapConn(l.options.Service, conn) conn = admission.WrapConn(l.options.Admission, conn) - conn = limiter.WrapConn(l.options.RateLimiter, conn) + conn = limiter.WrapConn(l.options.TrafficLimiter, conn) return conn, nil } diff --git a/listener/icmp/listener.go b/listener/icmp/listener.go index 23dad6d..1e9b0cb 100644 --- a/listener/icmp/listener.go +++ b/listener/icmp/listener.go @@ -9,7 +9,7 @@ import ( md "github.com/go-gost/core/metadata" admission "github.com/go-gost/x/admission/wrapper" icmp_pkg "github.com/go-gost/x/internal/util/icmp" - limiter "github.com/go-gost/x/limiter/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" "github.com/lucas-clemente/quic-go" @@ -58,7 +58,7 @@ func (l *icmpListener) Init(md md.Metadata) (err error) { conn = icmp_pkg.ServerConn(conn) conn = metrics.WrapPacketConn(l.options.Service, conn) conn = admission.WrapPacketConn(l.options.Admission, conn) - conn = limiter.WrapPacketConn(l.options.RateLimiter, conn) + conn = limiter.WrapPacketConn(l.options.TrafficLimiter, conn) config := &quic.Config{ KeepAlivePeriod: l.md.keepAlivePeriod, diff --git a/listener/kcp/listener.go b/listener/kcp/listener.go index 972179c..2638548 100644 --- a/listener/kcp/listener.go +++ b/listener/kcp/listener.go @@ -10,7 +10,7 @@ import ( admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" kcp_util "github.com/go-gost/x/internal/util/kcp" - limiter "github.com/go-gost/x/limiter/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" "github.com/xtaci/kcp-go/v5" @@ -76,7 +76,7 @@ func (l *kcpListener) Init(md md.Metadata) (err error) { conn = metrics.WrapUDPConn(l.options.Service, conn) conn = admission.WrapUDPConn(l.options.Admission, conn) - conn = limiter.WrapUDPConn(l.options.RateLimiter, conn) + conn = limiter.WrapUDPConn(l.options.TrafficLimiter, conn) ln, err := kcp.ServeConn( kcp_util.BlockCrypt(config.Key, config.Crypt, kcp_util.DefaultSalt), diff --git a/listener/mtls/listener.go b/listener/mtls/listener.go index 6cc728a..9d3d21e 100644 --- a/listener/mtls/listener.go +++ b/listener/mtls/listener.go @@ -11,7 +11,8 @@ import ( admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" - limiter "github.com/go-gost/x/limiter/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" "github.com/xtaci/smux" @@ -57,7 +58,8 @@ func (l *mtlsListener) Init(md md.Metadata) (err error) { ln = metrics.WrapListener(l.options.Service, ln) ln = admission.WrapListener(l.options.Admission, ln) - ln = limiter.WrapListener(l.options.RateLimiter, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.Listener = tls.NewListener(ln, l.options.TLSConfig) diff --git a/listener/mws/listener.go b/listener/mws/listener.go index 8365c70..40a4008 100644 --- a/listener/mws/listener.go +++ b/listener/mws/listener.go @@ -14,7 +14,8 @@ import ( xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" ws_util "github.com/go-gost/x/internal/util/ws" - limiter "github.com/go-gost/x/limiter/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" "github.com/gorilla/websocket" @@ -99,7 +100,8 @@ func (l *mwsListener) Init(md md.Metadata) (err error) { } ln = metrics.WrapListener(l.options.Service, ln) ln = admission.WrapListener(l.options.Admission, ln) - ln = limiter.WrapListener(l.options.RateLimiter, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) if l.tlsEnabled { diff --git a/listener/obfs/http/listener.go b/listener/obfs/http/listener.go index 4a4e190..9772049 100644 --- a/listener/obfs/http/listener.go +++ b/listener/obfs/http/listener.go @@ -8,7 +8,8 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" admission "github.com/go-gost/x/admission/wrapper" - limiter "github.com/go-gost/x/limiter/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" xnet "github.com/go-gost/x/internal/net" @@ -53,7 +54,8 @@ func (l *obfsListener) Init(md md.Metadata) (err error) { } ln = metrics.WrapListener(l.options.Service, ln) ln = admission.WrapListener(l.options.Admission, ln) - ln = limiter.WrapListener(l.options.RateLimiter, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.Listener = ln diff --git a/listener/obfs/tls/listener.go b/listener/obfs/tls/listener.go index c7a96d7..1cba1e9 100644 --- a/listener/obfs/tls/listener.go +++ b/listener/obfs/tls/listener.go @@ -10,7 +10,8 @@ import ( admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" - limiter "github.com/go-gost/x/limiter/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" ) @@ -52,7 +53,8 @@ func (l *obfsListener) Init(md md.Metadata) (err error) { } ln = metrics.WrapListener(l.options.Service, ln) ln = admission.WrapListener(l.options.Admission, ln) - ln = limiter.WrapListener(l.options.RateLimiter, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.Listener = ln diff --git a/listener/pht/listener.go b/listener/pht/listener.go index 96339a2..06ae383 100644 --- a/listener/pht/listener.go +++ b/listener/pht/listener.go @@ -11,7 +11,7 @@ import ( admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" pht_util "github.com/go-gost/x/internal/util/pht" - limiter "github.com/go-gost/x/limiter/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" ) @@ -92,7 +92,7 @@ func (l *phtListener) Accept() (conn net.Conn, err error) { } conn = metrics.WrapConn(l.options.Service, conn) conn = admission.WrapConn(l.options.Admission, conn) - conn = limiter.WrapConn(l.options.RateLimiter, conn) + conn = limiter.WrapConn(l.options.TrafficLimiter, conn) return } diff --git a/listener/quic/listener.go b/listener/quic/listener.go index 6ec3a4e..fd0997f 100644 --- a/listener/quic/listener.go +++ b/listener/quic/listener.go @@ -10,7 +10,7 @@ import ( admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" quic_util "github.com/go-gost/x/internal/util/quic" - limiter "github.com/go-gost/x/limiter/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" "github.com/lucas-clemente/quic-go" @@ -102,7 +102,7 @@ func (l *quicListener) Accept() (conn net.Conn, err error) { case conn = <-l.cqueue: conn = metrics.WrapConn(l.options.Service, conn) conn = admission.WrapConn(l.options.Admission, conn) - conn = limiter.WrapConn(l.options.RateLimiter, conn) + conn = limiter.WrapConn(l.options.TrafficLimiter, conn) case err, ok = <-l.errChan: if !ok { err = listener.ErrClosed diff --git a/listener/redirect/tcp/listener.go b/listener/redirect/tcp/listener.go index 1e98801..9b0523f 100644 --- a/listener/redirect/tcp/listener.go +++ b/listener/redirect/tcp/listener.go @@ -11,7 +11,8 @@ import ( admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" - limiter "github.com/go-gost/x/limiter/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" ) @@ -60,7 +61,8 @@ func (l *redirectListener) Init(md md.Metadata) (err error) { ln = metrics.WrapListener(l.options.Service, ln) ln = admission.WrapListener(l.options.Admission, ln) - ln = limiter.WrapListener(l.options.RateLimiter, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.ln = ln return diff --git a/listener/redirect/udp/listener.go b/listener/redirect/udp/listener.go index abc1cf9..98f29bc 100644 --- a/listener/redirect/udp/listener.go +++ b/listener/redirect/udp/listener.go @@ -7,7 +7,7 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" admission "github.com/go-gost/x/admission/wrapper" - limiter "github.com/go-gost/x/limiter/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" ) @@ -55,7 +55,7 @@ func (l *redirectListener) Accept() (conn net.Conn, err error) { } conn = metrics.WrapConn(l.options.Service, conn) conn = admission.WrapConn(l.options.Admission, conn) - conn = limiter.WrapConn(l.options.RateLimiter, conn) + conn = limiter.WrapConn(l.options.TrafficLimiter, conn) return } diff --git a/listener/rtcp/listener.go b/listener/rtcp/listener.go index d41cd7c..e614f7b 100644 --- a/listener/rtcp/listener.go +++ b/listener/rtcp/listener.go @@ -10,7 +10,8 @@ import ( md "github.com/go-gost/core/metadata" admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" - limiter "github.com/go-gost/x/limiter/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" ) @@ -79,7 +80,8 @@ func (l *rtcpListener) Accept() (conn net.Conn, err error) { } l.ln = metrics.WrapListener(l.options.Service, l.ln) l.ln = admission.WrapListener(l.options.Admission, l.ln) - l.ln = limiter.WrapListener(l.options.RateLimiter, l.ln) + l.ln = limiter.WrapListener(l.options.TrafficLimiter, l.ln) + l.ln = climiter.WrapListener(l.options.ConnLimiter, l.ln) } conn, err = l.ln.Accept() if err != nil { diff --git a/listener/rudp/listener.go b/listener/rudp/listener.go index 981dbcc..143b90a 100644 --- a/listener/rudp/listener.go +++ b/listener/rudp/listener.go @@ -10,7 +10,7 @@ import ( md "github.com/go-gost/core/metadata" admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" - limiter "github.com/go-gost/x/limiter/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" ) @@ -92,7 +92,7 @@ func (l *rudpListener) Accept() (conn net.Conn, err error) { if pc, ok := conn.(net.PacketConn); ok { uc := metrics.WrapUDPConn(l.options.Service, pc) uc = admission.WrapUDPConn(l.options.Admission, uc) - conn = limiter.WrapUDPConn(l.options.RateLimiter, uc) + conn = limiter.WrapUDPConn(l.options.TrafficLimiter, uc) } return diff --git a/listener/ssh/listener.go b/listener/ssh/listener.go index d82b373..9367baa 100644 --- a/listener/ssh/listener.go +++ b/listener/ssh/listener.go @@ -12,7 +12,8 @@ import ( xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" ssh_util "github.com/go-gost/x/internal/util/ssh" - limiter "github.com/go-gost/x/limiter/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" "golang.org/x/crypto/ssh" @@ -59,7 +60,8 @@ func (l *sshListener) Init(md md.Metadata) (err error) { ln = metrics.WrapListener(l.options.Service, ln) ln = admission.WrapListener(l.options.Admission, ln) - ln = limiter.WrapListener(l.options.RateLimiter, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.Listener = ln diff --git a/listener/sshd/listener.go b/listener/sshd/listener.go index 2fd902e..c532f76 100644 --- a/listener/sshd/listener.go +++ b/listener/sshd/listener.go @@ -15,7 +15,8 @@ import ( "github.com/go-gost/x/internal/net/proxyproto" ssh_util "github.com/go-gost/x/internal/util/ssh" sshd_util "github.com/go-gost/x/internal/util/sshd" - limiter "github.com/go-gost/x/limiter/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" "golang.org/x/crypto/ssh" @@ -68,7 +69,8 @@ func (l *sshdListener) Init(md md.Metadata) (err error) { ln = metrics.WrapListener(l.options.Service, ln) ln = admission.WrapListener(l.options.Admission, ln) - ln = limiter.WrapListener(l.options.RateLimiter, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.Listener = ln diff --git a/listener/tap/listener.go b/listener/tap/listener.go index fe1ed1c..311162b 100644 --- a/listener/tap/listener.go +++ b/listener/tap/listener.go @@ -7,7 +7,7 @@ import ( "github.com/go-gost/core/logger" mdata "github.com/go-gost/core/metadata" xnet "github.com/go-gost/x/internal/net" - limiter "github.com/go-gost/x/limiter/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" mdx "github.com/go-gost/x/metadata" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" @@ -80,7 +80,7 @@ func (l *tapListener) Init(md mdata.Metadata) (err error) { raddr: &net.IPAddr{IP: ip}, } c = metrics.WrapConn(l.options.Service, c) - c = limiter.WrapConn(l.options.RateLimiter, c) + c = limiter.WrapConn(l.options.TrafficLimiter, c) c = withMetadata(mdx.NewMetadata(map[string]any{ "config": l.md.config, }), c) diff --git a/listener/tcp/listener.go b/listener/tcp/listener.go index cc3341d..638064b 100644 --- a/listener/tcp/listener.go +++ b/listener/tcp/listener.go @@ -10,7 +10,8 @@ import ( admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" - limiter "github.com/go-gost/x/limiter/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" ) @@ -55,7 +56,8 @@ func (l *tcpListener) Init(md md.Metadata) (err error) { ln = metrics.WrapListener(l.options.Service, ln) ln = admission.WrapListener(l.options.Admission, ln) - ln = limiter.WrapListener(l.options.RateLimiter, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.ln = ln diff --git a/listener/tls/listener.go b/listener/tls/listener.go index 1492c37..259cc1c 100644 --- a/listener/tls/listener.go +++ b/listener/tls/listener.go @@ -11,7 +11,8 @@ import ( admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" - limiter "github.com/go-gost/x/limiter/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" ) @@ -53,7 +54,8 @@ func (l *tlsListener) Init(md md.Metadata) (err error) { } ln = metrics.WrapListener(l.options.Service, ln) ln = admission.WrapListener(l.options.Admission, ln) - ln = limiter.WrapListener(l.options.RateLimiter, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.ln = tls.NewListener(ln, l.options.TLSConfig) diff --git a/listener/tun/listener.go b/listener/tun/listener.go index 3b2c862..d0f7fd1 100644 --- a/listener/tun/listener.go +++ b/listener/tun/listener.go @@ -9,7 +9,7 @@ import ( "github.com/go-gost/core/logger" mdata "github.com/go-gost/core/metadata" xnet "github.com/go-gost/x/internal/net" - limiter "github.com/go-gost/x/limiter/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" mdx "github.com/go-gost/x/metadata" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" @@ -89,7 +89,7 @@ func (l *tunListener) listenLoop() { cancel: cancel, } c = metrics.WrapConn(l.options.Service, c) - c = limiter.WrapConn(l.options.RateLimiter, c) + c = limiter.WrapConn(l.options.TrafficLimiter, c) c = withMetadata(mdx.NewMetadata(map[string]any{ "config": l.md.config, }), c) diff --git a/listener/udp/listener.go b/listener/udp/listener.go index 323f02f..f428f3c 100644 --- a/listener/udp/listener.go +++ b/listener/udp/listener.go @@ -9,7 +9,7 @@ import ( md "github.com/go-gost/core/metadata" admission "github.com/go-gost/x/admission/wrapper" xnet "github.com/go-gost/x/internal/net" - limiter "github.com/go-gost/x/limiter/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" ) @@ -57,7 +57,7 @@ func (l *udpListener) Init(md md.Metadata) (err error) { } conn = metrics.WrapPacketConn(l.options.Service, conn) conn = admission.WrapPacketConn(l.options.Admission, conn) - conn = limiter.WrapPacketConn(l.options.RateLimiter, conn) + conn = limiter.WrapPacketConn(l.options.TrafficLimiter, conn) l.ln = udp.NewListener(conn, &udp.ListenConfig{ Backlog: l.md.backlog, diff --git a/listener/ws/listener.go b/listener/ws/listener.go index 2288e4b..5db2269 100644 --- a/listener/ws/listener.go +++ b/listener/ws/listener.go @@ -14,7 +14,8 @@ import ( xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" ws_util "github.com/go-gost/x/internal/util/ws" - limiter "github.com/go-gost/x/limiter/wrapper" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" "github.com/go-gost/x/registry" "github.com/gorilla/websocket" @@ -94,7 +95,8 @@ func (l *wsListener) Init(md md.Metadata) (err error) { } ln = metrics.WrapListener(l.options.Service, ln) ln = admission.WrapListener(l.options.Admission, ln) - ln = limiter.WrapListener(l.options.RateLimiter, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) if l.tlsEnabled { diff --git a/registry/limiter.go b/registry/limiter.go index 2ebadf3..8b81b81 100644 --- a/registry/limiter.go +++ b/registry/limiter.go @@ -1,37 +1,38 @@ package registry import ( - "github.com/go-gost/core/limiter" + "github.com/go-gost/core/limiter/conn" + "github.com/go-gost/core/limiter/traffic" ) -type rlimiterRegistry struct { +type trafficLimiterRegistry struct { registry } -func (r *rlimiterRegistry) Register(name string, v limiter.RateLimiter) error { +func (r *trafficLimiterRegistry) Register(name string, v traffic.TrafficLimiter) error { return r.registry.Register(name, v) } -func (r *rlimiterRegistry) Get(name string) limiter.RateLimiter { +func (r *trafficLimiterRegistry) Get(name string) traffic.TrafficLimiter { if name != "" { - return &rlimiterWrapper{name: name, r: r} + return &trafficLimiterWrapper{name: name, r: r} } return nil } -func (r *rlimiterRegistry) get(name string) limiter.RateLimiter { +func (r *trafficLimiterRegistry) get(name string) traffic.TrafficLimiter { if v := r.registry.Get(name); v != nil { - return v.(limiter.RateLimiter) + return v.(traffic.TrafficLimiter) } return nil } -type rlimiterWrapper struct { +type trafficLimiterWrapper struct { name string - r *rlimiterRegistry + r *trafficLimiterRegistry } -func (w *rlimiterWrapper) In(key string) limiter.Limiter { +func (w *trafficLimiterWrapper) In(key string) traffic.Limiter { v := w.r.get(w.name) if v == nil { return nil @@ -39,10 +40,45 @@ func (w *rlimiterWrapper) In(key string) limiter.Limiter { return v.In(key) } -func (w *rlimiterWrapper) Out(key string) limiter.Limiter { +func (w *trafficLimiterWrapper) Out(key string) traffic.Limiter { v := w.r.get(w.name) if v == nil { return nil } return v.Out(key) } + +type connLimiterRegistry struct { + registry +} + +func (r *connLimiterRegistry) Register(name string, v conn.ConnLimiter) error { + return r.registry.Register(name, v) +} + +func (r *connLimiterRegistry) Get(name string) conn.ConnLimiter { + if name != "" { + return &connLimiterWrapper{name: name, r: r} + } + return nil +} + +func (r *connLimiterRegistry) get(name string) conn.ConnLimiter { + if v := r.registry.Get(name); v != nil { + return v.(conn.ConnLimiter) + } + return nil +} + +type connLimiterWrapper struct { + name string + r *connLimiterRegistry +} + +func (w *connLimiterWrapper) Limiter(key string) conn.Limiter { + v := w.r.get(w.name) + if v == nil { + return nil + } + return v.Limiter(key) +} diff --git a/registry/registry.go b/registry/registry.go index ac163f0..0e101e8 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -10,7 +10,8 @@ import ( "github.com/go-gost/core/bypass" "github.com/go-gost/core/chain" "github.com/go-gost/core/hosts" - "github.com/go-gost/core/limiter" + "github.com/go-gost/core/limiter/conn" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/recorder" "github.com/go-gost/core/resolver" "github.com/go-gost/core/service" @@ -26,15 +27,16 @@ var ( dialerReg Registry[NewDialer] = &dialerRegistry{} connectorReg Registry[NewConnector] = &connectorRegistry{} - serviceReg Registry[service.Service] = &serviceRegistry{} - chainReg Registry[chain.Chainer] = &chainRegistry{} - autherReg Registry[auth.Authenticator] = &autherRegistry{} - admissionReg Registry[admission.Admission] = &admissionRegistry{} - bypassReg Registry[bypass.Bypass] = &bypassRegistry{} - resolverReg Registry[resolver.Resolver] = &resolverRegistry{} - hostsReg Registry[hosts.HostMapper] = &hostsRegistry{} - recorderReg Registry[recorder.Recorder] = &recorderRegistry{} - rlimiterReg Registry[limiter.RateLimiter] = &rlimiterRegistry{} + serviceReg Registry[service.Service] = &serviceRegistry{} + chainReg Registry[chain.Chainer] = &chainRegistry{} + autherReg Registry[auth.Authenticator] = &autherRegistry{} + admissionReg Registry[admission.Admission] = &admissionRegistry{} + bypassReg Registry[bypass.Bypass] = &bypassRegistry{} + resolverReg Registry[resolver.Resolver] = &resolverRegistry{} + hostsReg Registry[hosts.HostMapper] = &hostsRegistry{} + recorderReg Registry[recorder.Recorder] = &recorderRegistry{} + trafficLimiterReg Registry[traffic.TrafficLimiter] = &trafficLimiterRegistry{} + connLimiterReg Registry[conn.ConnLimiter] = &connLimiterRegistry{} ) type Registry[T any] interface { @@ -129,6 +131,10 @@ func RecorderRegistry() Registry[recorder.Recorder] { return recorderReg } -func RateLimiterRegistry() Registry[limiter.RateLimiter] { - return rlimiterReg +func TrafficLimiterRegistry() Registry[traffic.TrafficLimiter] { + return trafficLimiterReg +} + +func ConnLimiterRegistry() Registry[conn.ConnLimiter] { + return connLimiterReg }