update limiter

This commit is contained in:
ginuerzh
2022-09-14 20:00:35 +08:00
parent 91c12882f5
commit 01d7dc77c6
34 changed files with 1171 additions and 79 deletions

View File

@ -1,6 +1,7 @@
package parsing
import (
"fmt"
"time"
"github.com/go-gost/core/bypass"
@ -63,11 +64,16 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) {
nm = mdx.NewMetadata(v.Metadata)
}
cr := registry.ConnectorRegistry().Get(v.Connector.Type)(
connector.AuthOption(parseAuth(v.Connector.Auth)),
connector.TLSConfigOption(tlsConfig),
connector.LoggerOption(connectorLogger),
)
var cr connector.Connector
if rf := registry.ConnectorRegistry().Get(v.Connector.Type); rf != nil {
cr = rf(
connector.AuthOption(parseAuth(v.Connector.Auth)),
connector.TLSConfigOption(tlsConfig),
connector.LoggerOption(connectorLogger),
)
} else {
return nil, fmt.Errorf("unregistered connector: %s", v.Connector.Type)
}
if v.Connector.Metadata == nil {
v.Connector.Metadata = make(map[string]any)
@ -97,12 +103,18 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) {
if nm != nil {
ppv = mdutil.GetInt(nm, mdKeyProxyProtocol)
}
d := registry.DialerRegistry().Get(v.Dialer.Type)(
dialer.AuthOption(parseAuth(v.Dialer.Auth)),
dialer.TLSConfigOption(tlsConfig),
dialer.LoggerOption(dialerLogger),
dialer.ProxyProtocolOption(ppv),
)
var d dialer.Dialer
if rf := registry.DialerRegistry().Get(v.Dialer.Type); rf != nil {
d = rf(
dialer.AuthOption(parseAuth(v.Dialer.Auth)),
dialer.TLSConfigOption(tlsConfig),
dialer.LoggerOption(dialerLogger),
dialer.ProxyProtocolOption(ppv),
)
} else {
return nil, fmt.Errorf("unregistered dialer: %s", v.Dialer.Type)
}
if v.Dialer.Metadata == nil {
v.Dialer.Metadata = make(map[string]any)

View File

@ -10,6 +10,7 @@ import (
"github.com/go-gost/core/chain"
"github.com/go-gost/core/hosts"
"github.com/go-gost/core/limiter/conn"
"github.com/go-gost/core/limiter/rate"
"github.com/go-gost/core/limiter/traffic"
"github.com/go-gost/core/logger"
"github.com/go-gost/core/recorder"
@ -22,6 +23,7 @@ import (
xhosts "github.com/go-gost/x/hosts"
"github.com/go-gost/x/internal/loader"
xconn "github.com/go-gost/x/limiter/conn"
xrate "github.com/go-gost/x/limiter/rate"
xtraffic "github.com/go-gost/x/limiter/traffic"
xrecorder "github.com/go-gost/x/recorder"
"github.com/go-gost/x/registry"
@ -408,3 +410,43 @@ func ParseConnLimiter(cfg *config.LimiterConfig) (lim conn.ConnLimiter) {
return xconn.NewConnLimiter(opts...)
}
func ParseRateLimiter(cfg *config.LimiterConfig) (lim rate.RateLimiter) {
if cfg == nil {
return nil
}
var opts []xrate.Option
if cfg.File != nil && cfg.File.Path != "" {
opts = append(opts, xrate.FileLoaderOption(loader.FileLoader(cfg.File.Path)))
}
if cfg.Redis != nil && cfg.Redis.Addr != "" {
switch cfg.Redis.Type {
case "list": // redis list
opts = append(opts, xrate.RedisLoaderOption(loader.RedisListLoader(
cfg.Redis.Addr,
loader.DBRedisLoaderOption(cfg.Redis.DB),
loader.PasswordRedisLoaderOption(cfg.Redis.Password),
loader.KeyRedisLoaderOption(cfg.Redis.Key),
)))
default: // redis set
opts = append(opts, xrate.RedisLoaderOption(loader.RedisSetLoader(
cfg.Redis.Addr,
loader.DBRedisLoaderOption(cfg.Redis.DB),
loader.PasswordRedisLoaderOption(cfg.Redis.Password),
loader.KeyRedisLoaderOption(cfg.Redis.Key),
)))
}
}
opts = append(opts,
xrate.LimitsOption(cfg.Limits...),
xrate.ReloadPeriodOption(cfg.Reload),
xrate.LoggerOption(logger.Default().WithFields(map[string]any{
"kind": "limiter",
"limiter": cfg.Name,
})),
)
return xrate.NewRateLimiter(opts...)
}

View File

@ -1,6 +1,7 @@
package parsing
import (
"fmt"
"strings"
"github.com/go-gost/core/admission"
@ -91,19 +92,24 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
}
}
ln := registry.ListenerRegistry().Get(cfg.Listener.Type)(
listener.AddrOption(cfg.Addr),
listener.AutherOption(auther),
listener.AuthOption(parseAuth(cfg.Listener.Auth)),
listener.TLSConfigOption(tlsConfig),
listener.AdmissionOption(admission.AdmissionGroup(admissions...)),
listener.ChainOption(chainGroup(cfg.Listener.Chain, cfg.Listener.ChainGroup)),
listener.TrafficLimiterOption(registry.TrafficLimiterRegistry().Get(cfg.Limiter)),
listener.ConnLimiterOption(registry.ConnLimiterRegistry().Get(cfg.CLimiter)),
listener.LoggerOption(listenerLogger),
listener.ServiceOption(cfg.Name),
listener.ProxyProtocolOption(ppv),
)
var ln listener.Listener
if rf := registry.ListenerRegistry().Get(cfg.Listener.Type); rf != nil {
ln = rf(
listener.AddrOption(cfg.Addr),
listener.AutherOption(auther),
listener.AuthOption(parseAuth(cfg.Listener.Auth)),
listener.TLSConfigOption(tlsConfig),
listener.AdmissionOption(admission.AdmissionGroup(admissions...)),
listener.ChainOption(chainGroup(cfg.Listener.Chain, cfg.Listener.ChainGroup)),
listener.TrafficLimiterOption(registry.TrafficLimiterRegistry().Get(cfg.Limiter)),
listener.ConnLimiterOption(registry.ConnLimiterRegistry().Get(cfg.CLimiter)),
listener.LoggerOption(listenerLogger),
listener.ServiceOption(cfg.Name),
listener.ProxyProtocolOption(ppv),
)
} else {
return nil, fmt.Errorf("unregistered listener: %s", cfg.Listener.Type)
}
if cfg.Listener.Metadata == nil {
cfg.Listener.Metadata = make(map[string]any)
@ -161,14 +167,20 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
WithRecorder(recorders...).
WithLogger(handlerLogger)
h := registry.HandlerRegistry().Get(cfg.Handler.Type)(
handler.RouterOption(router),
handler.AutherOption(auther),
handler.AuthOption(parseAuth(cfg.Handler.Auth)),
handler.BypassOption(bypass.BypassGroup(bypassList(cfg.Bypass, cfg.Bypasses...)...)),
handler.TLSConfigOption(tlsConfig),
handler.LoggerOption(handlerLogger),
)
var h handler.Handler
if rf := registry.HandlerRegistry().Get(cfg.Handler.Type); rf != nil {
h = rf(
handler.RouterOption(router),
handler.AutherOption(auther),
handler.AuthOption(parseAuth(cfg.Handler.Auth)),
handler.BypassOption(bypass.BypassGroup(bypassList(cfg.Bypass, cfg.Bypasses...)...)),
handler.TLSConfigOption(tlsConfig),
handler.RateLimiterOption(registry.RateLimiterRegistry().Get(cfg.RLimiter)),
handler.LoggerOption(handlerLogger),
)
} else {
return nil, fmt.Errorf("unregistered handler: %s", cfg.Handler.Type)
}
if forwarder, ok := h.(handler.Forwarder); ok {
forwarder.Forward(parseForwarder(cfg.Forwarder))