From 88cc6ff4d5ba6b258f4301f4af7ad3479b7b5e3f Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sat, 18 Nov 2023 18:28:09 +0800 Subject: [PATCH] add traffic limiter for proxy handler --- auth/plugin.go | 6 +- bypass/plugin.go | 6 +- config/config.go | 3 +- config/parsing/limiter/parse.go | 28 ++++ config/parsing/service/parse.go | 1 + connector/tunnel/connector.go | 4 +- go.mod | 4 +- go.sum | 10 +- handler/forward/local/handler.go | 13 +- handler/forward/remote/handler.go | 9 +- handler/http/handler.go | 22 ++- handler/http2/handler.go | 21 ++- handler/http3/handler.go | 4 +- handler/relay/connect.go | 15 +- handler/relay/forward.go | 12 +- handler/relay/handler.go | 8 +- handler/sni/handler.go | 6 +- handler/socks/v4/handler.go | 20 ++- handler/socks/v5/connect.go | 15 +- handler/socks/v5/handler.go | 6 +- handler/socks/v5/selector.go | 4 +- handler/ss/handler.go | 4 +- handler/tunnel/connect.go | 12 +- handler/tunnel/handler.go | 8 +- hop/plugin.go | 9 +- hosts/plugin.go | 6 +- internal/ctx/value.go | 76 ++++++++++ internal/util/auth/key.go | 34 ----- internal/util/selector/key.go | 26 ---- limiter/traffic/plugin.go | 235 ++++++++++++++++++++++++++++++ limiter/traffic/traffic.go | 4 +- limiter/traffic/wrapper/conn.go | 27 ++-- limiter/traffic/wrapper/io.go | 109 ++++++++++++++ listener/http2/listener.go | 4 +- registry/limiter.go | 10 +- resolver/plugin.go | 6 +- selector/strategy.go | 4 +- service/service.go | 42 ++---- 38 files changed, 633 insertions(+), 200 deletions(-) create mode 100644 internal/ctx/value.go delete mode 100644 internal/util/auth/key.go delete mode 100644 internal/util/selector/key.go create mode 100644 limiter/traffic/plugin.go create mode 100644 limiter/traffic/wrapper/io.go diff --git a/auth/plugin.go b/auth/plugin.go index 72fe797..8ada348 100644 --- a/auth/plugin.go +++ b/auth/plugin.go @@ -10,8 +10,8 @@ import ( "github.com/go-gost/core/auth" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/auth/proto" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/plugin" - auth_util "github.com/go-gost/x/internal/util/auth" "google.golang.org/grpc" ) @@ -58,7 +58,7 @@ func (p *grpcPlugin) Authenticate(ctx context.Context, user, password string, op &proto.AuthenticateRequest{ Username: user, Password: password, - Client: string(auth_util.ClientAddrFromContext(ctx)), + Client: string(ctxvalue.ClientAddrFromContext(ctx)), }) if err != nil { p.log.Error(err) @@ -118,7 +118,7 @@ func (p *httpPlugin) Authenticate(ctx context.Context, user, password string, op rb := httpPluginRequest{ Username: user, Password: password, - Client: string(auth_util.ClientAddrFromContext(ctx)), + Client: string(ctxvalue.ClientAddrFromContext(ctx)), } v, err := json.Marshal(&rb) if err != nil { diff --git a/bypass/plugin.go b/bypass/plugin.go index 4e20703..b18919c 100644 --- a/bypass/plugin.go +++ b/bypass/plugin.go @@ -10,8 +10,8 @@ import ( "github.com/go-gost/core/bypass" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/bypass/proto" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/plugin" - auth_util "github.com/go-gost/x/internal/util/auth" "google.golang.org/grpc" ) @@ -61,7 +61,7 @@ func (p *grpcPlugin) Contains(ctx context.Context, network, addr string, opts .. &proto.BypassRequest{ Network: network, Addr: addr, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), Host: options.Host, Path: options.Path, }) @@ -129,7 +129,7 @@ func (p *httpPlugin) Contains(ctx context.Context, network, addr string, opts .. rb := httpPluginRequest{ Network: network, Addr: addr, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), Host: options.Host, Path: options.Path, } diff --git a/config/config.go b/config/config.go index a4bf055..38ab61b 100644 --- a/config/config.go +++ b/config/config.go @@ -289,6 +289,7 @@ type LimiterConfig struct { File *FileLoader `yaml:",omitempty" json:"file,omitempty"` Redis *RedisLoader `yaml:",omitempty" json:"redis,omitempty"` HTTP *HTTPLoader `yaml:"http,omitempty" json:"http,omitempty"` + Plugin *PluginConfig `yaml:",omitempty" json:"plugin,omitempty"` } type ListenerConfig struct { @@ -311,7 +312,7 @@ type HandlerConfig struct { Authers []string `yaml:",omitempty" json:"authers,omitempty"` Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` - Ingress string `yaml:",omitempty" json:"ingress,omitempty"` + Limiter string `yaml:",omitempty" json:"limiter,omitempty"` Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"` } diff --git a/config/parsing/limiter/parse.go b/config/parsing/limiter/parse.go index b26b5fd..6004290 100644 --- a/config/parsing/limiter/parse.go +++ b/config/parsing/limiter/parse.go @@ -1,12 +1,16 @@ package limiter import ( + "crypto/tls" + "strings" + "github.com/go-gost/core/limiter/conn" "github.com/go-gost/core/limiter/rate" "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" "github.com/go-gost/x/config" "github.com/go-gost/x/internal/loader" + "github.com/go-gost/x/internal/plugin" xconn "github.com/go-gost/x/limiter/conn" xrate "github.com/go-gost/x/limiter/rate" xtraffic "github.com/go-gost/x/limiter/traffic" @@ -17,6 +21,30 @@ func ParseTrafficLimiter(cfg *config.LimiterConfig) (lim traffic.TrafficLimiter) return nil } + if cfg.Plugin != nil { + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return xtraffic.NewHTTPPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return xtraffic.NewGRPCPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) + } + } + var opts []xtraffic.Option if cfg.File != nil && cfg.File.Path != "" { diff --git a/config/parsing/service/parse.go b/config/parsing/service/parse.go index 8edf2e5..8ce036c 100644 --- a/config/parsing/service/parse.go +++ b/config/parsing/service/parse.go @@ -210,6 +210,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { handler.BypassOption(bypass.BypassGroup(bypass_parser.List(cfg.Bypass, cfg.Bypasses...)...)), handler.TLSConfigOption(tlsConfig), handler.RateLimiterOption(registry.RateLimiterRegistry().Get(cfg.RLimiter)), + handler.TrafficLimiterOption(registry.TrafficLimiterRegistry().Get(cfg.Handler.Limiter)), handler.LoggerOption(handlerLogger), handler.ServiceOption(cfg.Name), ) diff --git a/connector/tunnel/connector.go b/connector/tunnel/connector.go index a427dfd..8a1c6dc 100644 --- a/connector/tunnel/connector.go +++ b/connector/tunnel/connector.go @@ -9,7 +9,7 @@ import ( "github.com/go-gost/core/connector" md "github.com/go-gost/core/metadata" "github.com/go-gost/relay" - auth_util "github.com/go-gost/x/internal/util/auth" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/registry" ) @@ -73,7 +73,7 @@ func (c *tunnelConnector) Connect(ctx context.Context, conn net.Conn, network, a } srcAddr := conn.LocalAddr().String() - if v := auth_util.ClientAddrFromContext(ctx); v != "" { + if v := ctxvalue.ClientAddrFromContext(ctx); v != "" { srcAddr = string(v) } diff --git a/go.mod b/go.mod index 4b35ed6..0db2020 100644 --- a/go.mod +++ b/go.mod @@ -7,10 +7,10 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.9.1 - github.com/go-gost/core v0.0.0-20231113123850-a916f0401649 + github.com/go-gost/core v0.0.0-20231118102540-486f2cee616a github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.4.0 - github.com/go-gost/plugin v0.0.0-20231109123346-0ae4157b9d25 + github.com/go-gost/plugin v0.0.0-20231118102615-bfe81cbb44b6 github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 github.com/go-redis/redis/v8 v8.11.5 diff --git a/go.sum b/go.sum index c476d34..a46131e 100644 --- a/go.sum +++ b/go.sum @@ -93,16 +93,14 @@ github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SU github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gost/core v0.0.0-20231109123312-8e4fc06cf1b7 h1:sDsPtmP51qf8zN/RbZZj/3vNLCoH0sdvpIRwV6TfzvY= -github.com/go-gost/core v0.0.0-20231109123312-8e4fc06cf1b7/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= -github.com/go-gost/core v0.0.0-20231113123850-a916f0401649 h1:14iGAk7cqc+aDWtsuY6CWpP0lvC54pA5Izjeh5FdQNs= -github.com/go-gost/core v0.0.0-20231113123850-a916f0401649/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= +github.com/go-gost/core v0.0.0-20231118102540-486f2cee616a h1:bGpcollgZpuI0ct6FdJxZ2k7ipu4T6qrQbHc+ZbI29I= +github.com/go-gost/core v0.0.0-20231118102540-486f2cee616a/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.4.0 h1:EIrOEkpJez4gwHrMa33frA+hHXJyevjp47thpMQsJzI= github.com/go-gost/gosocks5 v0.4.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= -github.com/go-gost/plugin v0.0.0-20231109123346-0ae4157b9d25 h1:sOarC0xAJij4VtEhkJRng5okZW23KlXprxhb5XFZ+pw= -github.com/go-gost/plugin v0.0.0-20231109123346-0ae4157b9d25/go.mod h1:qXr2Zm9Ex2ATqnWuNUzVZqySPMnuIihvblYZt4MlZLw= +github.com/go-gost/plugin v0.0.0-20231118102615-bfe81cbb44b6 h1:1zFWxk8mmTewdDzzZutO4nremhQ6N93PWdB3FrLfbaQ= +github.com/go-gost/plugin v0.0.0-20231118102615-bfe81cbb44b6/go.mod h1:qXr2Zm9Ex2ATqnWuNUzVZqySPMnuIihvblYZt4MlZLw= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 h1:qAG1OyjvdA5h221CfFSS3J359V3d2E7dJWyP29QoDSI= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 h1:xj8gUZGYO3nb5+6Bjw9+tsFkA9sYynrOvDvvC4uDV2I= diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index d90ca24..e3b743a 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -21,9 +21,9 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" "github.com/go-gost/x/config" + ctxvalue "github.com/go-gost/x/internal/ctx" xio "github.com/go-gost/x/internal/io" xnet "github.com/go-gost/x/internal/net" - auth_util "github.com/go-gost/x/internal/util/auth" "github.com/go-gost/x/internal/util/forward" tls_util "github.com/go-gost/x/internal/util/tls" "github.com/go-gost/x/registry" @@ -119,8 +119,6 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand host = net.JoinHostPort(host, "0") } - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) - var target *chain.Node if host != "" { target = &chain.Node{ @@ -223,10 +221,9 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot "src": addr.String(), }) remoteAddr = addr + ctx = ctxvalue.ContextWithClientAddr(ctx, ctxvalue.ClientAddr(remoteAddr.String())) } - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(remoteAddr.String())) - target := &chain.Node{ Addr: req.Host, } @@ -259,7 +256,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot log.Warnf("node %s(%s) 401 unauthorized", target.Name, target.Addr) return resp.Write(rw) } - ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(id)) } if httpSettings := target.Options().HTTP; httpSettings != nil { if httpSettings.Host != "" { @@ -292,8 +289,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot InsecureSkipVerify: !tlsSettings.Secure, } tls_util.SetTLSOptions(cfg, &config.TLSOptions{ - MinVersion: tlsSettings.Options.MinVersion, - MaxVersion: tlsSettings.Options.MaxVersion, + MinVersion: tlsSettings.Options.MinVersion, + MaxVersion: tlsSettings.Options.MaxVersion, CipherSuites: tlsSettings.Options.CipherSuites, }) cc = tls.Client(cc, cfg) diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index 2c91672..ea92caa 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -22,10 +22,10 @@ import ( mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" "github.com/go-gost/x/config" + ctxvalue "github.com/go-gost/x/internal/ctx" xio "github.com/go-gost/x/internal/io" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" - auth_util "github.com/go-gost/x/internal/util/auth" "github.com/go-gost/x/internal/util/forward" tls_util "github.com/go-gost/x/internal/util/tls" "github.com/go-gost/x/registry" @@ -117,8 +117,6 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand return nil } - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) - if md, ok := conn.(mdata.Metadatable); ok { if v := mdutil.GetString(md.Metadata(), "host"); v != "" { host = v @@ -224,10 +222,9 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot "src": addr.String(), }) remoteAddr = addr + ctx = ctxvalue.ContextWithClientAddr(ctx, ctxvalue.ClientAddr(remoteAddr.String())) } - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(remoteAddr.String())) - target := &chain.Node{ Addr: req.Host, } @@ -260,7 +257,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot log.Warnf("node %s(%s) 401 unauthorized", target.Name, target.Addr) return resp.Write(rw) } - ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(id)) } if httpSettings := target.Options().HTTP; httpSettings != nil { if httpSettings.Host != "" { diff --git a/handler/http/handler.go b/handler/http/handler.go index fd6249a..e6534eb 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -19,11 +19,12 @@ import ( "github.com/asaskevich/govalidator" "github.com/go-gost/core/chain" "github.com/go-gost/core/handler" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" + ctxvalue "github.com/go-gost/x/internal/ctx" netpkg "github.com/go-gost/x/internal/net" - auth_util "github.com/go-gost/x/internal/util/auth" - sx "github.com/go-gost/x/internal/util/selector" + "github.com/go-gost/x/limiter/traffic/wrapper" "github.com/go-gost/x/registry" ) @@ -89,8 +90,6 @@ func (h *httpHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler } defer req.Body.Close() - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) - return h.handleRequest(ctx, conn, req, log) } @@ -148,11 +147,11 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt resp.Header = http.Header{} } - id, ok := h.authenticate(ctx, conn, req, resp, log) + clientID, ok := h.authenticate(ctx, conn, req, resp, log) if !ok { return nil } - ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID)) if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, network, addr) { resp.StatusCode = http.StatusForbidden @@ -186,7 +185,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: addr}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: addr}) } cc, err := h.router.Dial(ctx, network, addr) @@ -222,9 +221,16 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt } } + rw := wrapper.WrapReadWriter(h.options.Limiter, conn, conn.RemoteAddr().String(), + traffic.NetworkOption(network), + traffic.AddrOption(addr), + traffic.ClientOption(clientID), + traffic.SrcOption(conn.RemoteAddr().String()), + ) + start := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), addr) - netpkg.Transport(conn, cc) + netpkg.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(start), }).Infof("%s >-< %s", conn.RemoteAddr(), addr) diff --git a/handler/http2/handler.go b/handler/http2/handler.go index 902f1a2..8dce6aa 100644 --- a/handler/http2/handler.go +++ b/handler/http2/handler.go @@ -20,12 +20,13 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/handler" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" + ctxvalue "github.com/go-gost/x/internal/ctx" xio "github.com/go-gost/x/internal/io" netpkg "github.com/go-gost/x/internal/net" - auth_util "github.com/go-gost/x/internal/util/auth" - sx "github.com/go-gost/x/internal/util/selector" + "github.com/go-gost/x/limiter/traffic/wrapper" "github.com/go-gost/x/registry" ) @@ -89,8 +90,6 @@ func (h *http2Handler) Handle(ctx context.Context, conn net.Conn, opts ...handle return err } - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) - md := v.Metadata() return h.roundTrip(ctx, md.Get("w").(http.ResponseWriter), @@ -149,11 +148,11 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req Body: io.NopCloser(bytes.NewReader([]byte{})), } - id, ok := h.authenticate(ctx, w, req, resp, log) + clientID, ok := h.authenticate(ctx, w, req, resp, log) if !ok { return nil } - ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID)) if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "tcp", addr) { w.WriteHeader(http.StatusForbidden) @@ -167,7 +166,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: addr}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: addr}) } cc, err := h.router.Dial(ctx, "tcp", addr) @@ -205,9 +204,15 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req return nil } + rw := wrapper.WrapReadWriter(h.options.Limiter, xio.NewReadWriter(req.Body, flushWriter{w}), req.RemoteAddr, + traffic.NetworkOption("tcp"), + traffic.AddrOption(addr), + traffic.ClientOption(clientID), + traffic.SrcOption(req.RemoteAddr), + ) start := time.Now() log.Infof("%s <-> %s", req.RemoteAddr, addr) - netpkg.Transport(xio.NewReadWriter(req.Body, flushWriter{w}), cc) + netpkg.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(start), }).Infof("%s >-< %s", req.RemoteAddr, addr) diff --git a/handler/http3/handler.go b/handler/http3/handler.go index a2c92c3..767657c 100644 --- a/handler/http3/handler.go +++ b/handler/http3/handler.go @@ -14,7 +14,7 @@ import ( "github.com/go-gost/core/hop" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - sx "github.com/go-gost/x/internal/util/selector" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/registry" ) @@ -114,7 +114,7 @@ func (h *http3Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: addr}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: addr}) } var target *chain.Node diff --git a/handler/relay/connect.go b/handler/relay/connect.go index c67bd89..6af1c6b 100644 --- a/handler/relay/connect.go +++ b/handler/relay/connect.go @@ -8,11 +8,13 @@ import ( "net" "time" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" "github.com/go-gost/relay" + ctxvalue "github.com/go-gost/x/internal/ctx" xnet "github.com/go-gost/x/internal/net" - sx "github.com/go-gost/x/internal/util/selector" serial "github.com/go-gost/x/internal/util/serial" + "github.com/go-gost/x/limiter/traffic/wrapper" ) func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) (err error) { @@ -51,7 +53,7 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: address}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: address}) } var cc io.ReadWriteCloser @@ -103,9 +105,16 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network } } + rw := wrapper.WrapReadWriter(h.options.Limiter, conn, conn.RemoteAddr().String(), + traffic.NetworkOption(network), + traffic.AddrOption(address), + traffic.ClientOption(string(ctxvalue.ClientIDFromContext(ctx))), + traffic.SrcOption(conn.RemoteAddr().String()), + ) + t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), address) - xnet.Transport(conn, cc) + xnet.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), address) diff --git a/handler/relay/forward.go b/handler/relay/forward.go index 62b1a24..04a9d33 100644 --- a/handler/relay/forward.go +++ b/handler/relay/forward.go @@ -7,9 +7,12 @@ import ( "net" "time" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" "github.com/go-gost/relay" + ctxvalue "github.com/go-gost/x/internal/ctx" netpkg "github.com/go-gost/x/internal/net" + "github.com/go-gost/x/limiter/traffic/wrapper" ) func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string, log logger.Logger) error { @@ -84,9 +87,16 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network conn = rc } + rw := wrapper.WrapReadWriter(h.options.Limiter, conn, conn.RemoteAddr().String(), + traffic.NetworkOption(network), + traffic.AddrOption(target.Addr), + traffic.ClientOption(string(ctxvalue.ClientIDFromContext(ctx))), + traffic.SrcOption(conn.RemoteAddr().String()), + ) + t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr) - netpkg.Transport(conn, cc) + netpkg.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr) diff --git a/handler/relay/handler.go b/handler/relay/handler.go index 5047dad..09ab0ab 100644 --- a/handler/relay/handler.go +++ b/handler/relay/handler.go @@ -13,7 +13,7 @@ import ( md "github.com/go-gost/core/metadata" "github.com/go-gost/core/service" "github.com/go-gost/relay" - auth_util "github.com/go-gost/x/internal/util/auth" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/registry" ) @@ -83,8 +83,6 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) - if !h.checkRateLimit(conn.RemoteAddr()) { return ErrRateLimit } @@ -136,13 +134,13 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle } if h.options.Auther != nil { - id, ok := h.options.Auther.Authenticate(ctx, user, pass) + clientID, ok := h.options.Auther.Authenticate(ctx, user, pass) if !ok { resp.Status = relay.StatusUnauthorized resp.WriteTo(conn) return ErrUnauthorized } - ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID)) } network := networkID.String() diff --git a/handler/sni/handler.go b/handler/sni/handler.go index 3434604..ab7acb8 100644 --- a/handler/sni/handler.go +++ b/handler/sni/handler.go @@ -21,9 +21,9 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" dissector "github.com/go-gost/tls-dissector" + ctxvalue "github.com/go-gost/x/internal/ctx" xio "github.com/go-gost/x/internal/io" netpkg "github.com/go-gost/x/internal/net" - sx "github.com/go-gost/x/internal/util/selector" "github.com/go-gost/x/registry" ) @@ -123,7 +123,7 @@ func (h *sniHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr net switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: host}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: host}) } cc, err := h.router.Dial(ctx, "tcp", host) @@ -191,7 +191,7 @@ func (h *sniHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, raddr ne switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: host}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: host}) } cc, err := h.router.Dial(ctx, "tcp", host) diff --git a/handler/socks/v4/handler.go b/handler/socks/v4/handler.go index 19b7adf..db04955 100644 --- a/handler/socks/v4/handler.go +++ b/handler/socks/v4/handler.go @@ -8,12 +8,13 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/handler" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" "github.com/go-gost/gosocks4" + ctxvalue "github.com/go-gost/x/internal/ctx" netpkg "github.com/go-gost/x/internal/net" - auth_util "github.com/go-gost/x/internal/util/auth" - sx "github.com/go-gost/x/internal/util/selector" + "github.com/go-gost/x/limiter/traffic/wrapper" "github.com/go-gost/x/registry" ) @@ -82,8 +83,6 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl conn.SetReadDeadline(time.Now().Add(h.md.readTimeout)) } - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) - req, err := gosocks4.ReadRequest(conn) if err != nil { log.Error(err) @@ -100,7 +99,7 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl log.Trace(resp) return resp.Write(conn) } - ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(id)) } switch req.Cmd { @@ -132,7 +131,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: addr}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: addr}) } cc, err := h.router.Dial(ctx, "tcp", addr) @@ -152,9 +151,16 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g return err } + rw := wrapper.WrapReadWriter(h.options.Limiter, conn, conn.RemoteAddr().String(), + traffic.NetworkOption("tcp"), + traffic.AddrOption(addr), + traffic.ClientOption(string(ctxvalue.ClientIDFromContext(ctx))), + traffic.SrcOption(conn.RemoteAddr().String()), + ) + t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), addr) - netpkg.Transport(conn, cc) + netpkg.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), addr) diff --git a/handler/socks/v5/connect.go b/handler/socks/v5/connect.go index ad4e0e3..dcccbfd 100644 --- a/handler/socks/v5/connect.go +++ b/handler/socks/v5/connect.go @@ -6,10 +6,12 @@ import ( "net" "time" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" "github.com/go-gost/gosocks5" + ctxvalue "github.com/go-gost/x/internal/ctx" netpkg "github.com/go-gost/x/internal/net" - sx "github.com/go-gost/x/internal/util/selector" + "github.com/go-gost/x/limiter/traffic/wrapper" ) func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { @@ -28,7 +30,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: address}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: address}) } cc, err := h.router.Dial(ctx, network, address) @@ -48,9 +50,16 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ return err } + rw := wrapper.WrapReadWriter(h.options.Limiter, conn, conn.RemoteAddr().String(), + traffic.NetworkOption(network), + traffic.AddrOption(address), + traffic.ClientOption(string(ctxvalue.ClientIDFromContext(ctx))), + traffic.SrcOption(conn.RemoteAddr().String()), + ) + t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), address) - netpkg.Transport(conn, cc) + netpkg.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), address) diff --git a/handler/socks/v5/handler.go b/handler/socks/v5/handler.go index 690ad24..2432cce 100644 --- a/handler/socks/v5/handler.go +++ b/handler/socks/v5/handler.go @@ -10,7 +10,7 @@ import ( "github.com/go-gost/core/handler" md "github.com/go-gost/core/metadata" "github.com/go-gost/gosocks5" - auth_util "github.com/go-gost/x/internal/util/auth" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/util/socks" "github.com/go-gost/x/registry" ) @@ -95,7 +95,9 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl } log.Trace(req) - ctx = auth_util.ContextWithID(ctx, auth_util.ID(sc.ID())) + if clientID := sc.ID(); clientID != "" { + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID)) + } conn = sc conn.SetReadDeadline(time.Time{}) diff --git a/handler/socks/v5/selector.go b/handler/socks/v5/selector.go index f4b4e86..f49adaf 100644 --- a/handler/socks/v5/selector.go +++ b/handler/socks/v5/selector.go @@ -8,7 +8,7 @@ import ( "github.com/go-gost/core/auth" "github.com/go-gost/core/logger" "github.com/go-gost/gosocks5" - auth_util "github.com/go-gost/x/internal/util/auth" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/util/socks" ) @@ -70,7 +70,7 @@ func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (string, net.Co var id string if s.Authenticator != nil { var ok bool - ctx := auth_util.ContextWithClientAddr(context.Background(), auth_util.ClientAddr(conn.RemoteAddr().String())) + ctx := ctxvalue.ContextWithClientAddr(context.Background(), ctxvalue.ClientAddr(conn.RemoteAddr().String())) id, ok = s.Authenticator.Authenticate(ctx, req.Username, req.Password) if !ok { resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) diff --git a/handler/ss/handler.go b/handler/ss/handler.go index 3014e3e..56f52e5 100644 --- a/handler/ss/handler.go +++ b/handler/ss/handler.go @@ -10,8 +10,8 @@ import ( "github.com/go-gost/core/handler" md "github.com/go-gost/core/metadata" "github.com/go-gost/gosocks5" + ctxvalue "github.com/go-gost/x/internal/ctx" netpkg "github.com/go-gost/x/internal/net" - sx "github.com/go-gost/x/internal/util/selector" "github.com/go-gost/x/internal/util/ss" "github.com/go-gost/x/registry" "github.com/shadowsocks/go-shadowsocks2/core" @@ -108,7 +108,7 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.H switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: addr.String()}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: addr.String()}) } cc, err := h.router.Dial(ctx, "tcp", addr.String()) diff --git a/handler/tunnel/connect.go b/handler/tunnel/connect.go index 4b5f020..a84d9d2 100644 --- a/handler/tunnel/connect.go +++ b/handler/tunnel/connect.go @@ -6,9 +6,12 @@ import ( "net" "time" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" "github.com/go-gost/relay" + ctxvalue "github.com/go-gost/x/internal/ctx" xnet "github.com/go-gost/x/internal/net" + "github.com/go-gost/x/limiter/traffic/wrapper" ) func (h *tunnelHandler) handleConnect(ctx context.Context, req *relay.Request, conn net.Conn, network, srcAddr string, dstAddr string, tunnelID relay.TunnelID, log logger.Logger) error { @@ -95,9 +98,16 @@ func (h *tunnelHandler) handleConnect(ctx context.Context, req *relay.Request, c req.WriteTo(cc) } + rw := wrapper.WrapReadWriter(h.options.Limiter, conn, tunnelID.String(), + traffic.NetworkOption(network), + traffic.AddrOption(dstAddr), + traffic.ClientOption(string(ctxvalue.ClientIDFromContext(ctx))), + traffic.SrcOption(conn.RemoteAddr().String()), + ) + t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) - xnet.Transport(conn, cc) + xnet.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Debugf("%s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) diff --git a/handler/tunnel/handler.go b/handler/tunnel/handler.go index e715264..af385e4 100644 --- a/handler/tunnel/handler.go +++ b/handler/tunnel/handler.go @@ -15,8 +15,8 @@ import ( "github.com/go-gost/core/recorder" "github.com/go-gost/core/service" "github.com/go-gost/relay" + ctxvalue "github.com/go-gost/x/internal/ctx" xnet "github.com/go-gost/x/internal/net" - auth_util "github.com/go-gost/x/internal/util/auth" xrecorder "github.com/go-gost/x/recorder" "github.com/go-gost/x/registry" xservice "github.com/go-gost/x/service" @@ -169,8 +169,6 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) - if !h.checkRateLimit(conn.RemoteAddr()) { return ErrRateLimit } @@ -238,13 +236,13 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl } if h.options.Auther != nil { - id, ok := h.options.Auther.Authenticate(ctx, user, pass) + clientID, ok := h.options.Auther.Authenticate(ctx, user, pass) if !ok { resp.Status = relay.StatusUnauthorized resp.WriteTo(conn) return ErrUnauthorized } - ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID)) } switch req.Cmd & relay.CmdMask { diff --git a/hop/plugin.go b/hop/plugin.go index 612d0e1..519d5f5 100644 --- a/hop/plugin.go +++ b/hop/plugin.go @@ -13,8 +13,8 @@ import ( "github.com/go-gost/plugin/hop/proto" "github.com/go-gost/x/config" node_parser "github.com/go-gost/x/config/parsing/node" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/plugin" - auth_util "github.com/go-gost/x/internal/util/auth" "google.golang.org/grpc" ) @@ -68,7 +68,8 @@ func (p *grpcPlugin) Select(ctx context.Context, opts ...hop.SelectOption) *chai Addr: options.Addr, Host: options.Host, Path: options.Path, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), + Src: string(ctxvalue.ClientAddrFromContext(ctx)), }) if err != nil { p.log.Error(err) @@ -106,6 +107,7 @@ type httpPluginRequest struct { Host string `json:"host"` Path string `json:"path"` Client string `json:"client"` + Src string `json:"src"` } type httpPluginResponse struct { @@ -154,7 +156,8 @@ func (p *httpPlugin) Select(ctx context.Context, opts ...hop.SelectOption) *chai Addr: options.Addr, Host: options.Host, Path: options.Path, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), + Src: string(ctxvalue.ClientAddrFromContext(ctx)), } v, err := json.Marshal(&rb) if err != nil { diff --git a/hosts/plugin.go b/hosts/plugin.go index 07f15d6..1ccd6b7 100644 --- a/hosts/plugin.go +++ b/hosts/plugin.go @@ -11,8 +11,8 @@ import ( "github.com/go-gost/core/hosts" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/hosts/proto" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/plugin" - auth_util "github.com/go-gost/x/internal/util/auth" "google.golang.org/grpc" ) @@ -58,7 +58,7 @@ func (p *grpcPlugin) Lookup(ctx context.Context, network, host string, opts ...h &proto.LookupRequest{ Network: network, Host: host, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), }) if err != nil { p.log.Error(err) @@ -126,7 +126,7 @@ func (p *httpPlugin) Lookup(ctx context.Context, network, host string, opts ...h rb := httpPluginRequest{ Network: network, Host: host, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), } v, err := json.Marshal(&rb) if err != nil { diff --git a/internal/ctx/value.go b/internal/ctx/value.go new file mode 100644 index 0000000..46b3dc2 --- /dev/null +++ b/internal/ctx/value.go @@ -0,0 +1,76 @@ +package ctx + +import "context" + +// clientAddrKey saves the client address. +type clientAddrKey struct{} + +type ClientAddr string + +var ( + keyClientAddr clientAddrKey +) + +func ContextWithClientAddr(ctx context.Context, addr ClientAddr) context.Context { + return context.WithValue(ctx, keyClientAddr, addr) +} + +func ClientAddrFromContext(ctx context.Context) ClientAddr { + v, _ := ctx.Value(keyClientAddr).(ClientAddr) + return v +} + +// sidKey saves the session ID. +type sidKey struct{} +type Sid string + +var ( + keySid sidKey +) + +func ContextWithSid(ctx context.Context, sid Sid) context.Context { + return context.WithValue(ctx, keySid, sid) +} + +func SidFromContext(ctx context.Context) Sid { + v, _ := ctx.Value(keySid).(Sid) + return v +} + +// hashKey saves the hash source for Selector. +type hashKey struct{} + +type Hash struct { + Source string +} + +var ( + clientHashKey = &hashKey{} +) + +func ContextWithHash(ctx context.Context, hash *Hash) context.Context { + return context.WithValue(ctx, clientHashKey, hash) +} + +func HashFromContext(ctx context.Context) *Hash { + if v, _ := ctx.Value(clientHashKey).(*Hash); v != nil { + return v + } + return nil +} + +type clientIDKey struct{} +type ClientID string + +var ( + keyClientID = &clientIDKey{} +) + +func ContextWithClientID(ctx context.Context, clientID ClientID) context.Context { + return context.WithValue(ctx, keyClientID, clientID) +} + +func ClientIDFromContext(ctx context.Context) ClientID { + v, _ := ctx.Value(keyClientID).(ClientID) + return v +} diff --git a/internal/util/auth/key.go b/internal/util/auth/key.go deleted file mode 100644 index a83c277..0000000 --- a/internal/util/auth/key.go +++ /dev/null @@ -1,34 +0,0 @@ -package auth - -import ( - "context" -) - -type idKey struct{} -type ID string - -type addrKey struct{} -type ClientAddr string - -var ( - clientIDKey = &idKey{} - clientAddrKey = &addrKey{} -) - -func ContextWithID(ctx context.Context, id ID) context.Context { - return context.WithValue(ctx, clientIDKey, id) -} - -func IDFromContext(ctx context.Context) ID { - v, _ := ctx.Value(clientIDKey).(ID) - return v -} - -func ContextWithClientAddr(ctx context.Context, addr ClientAddr) context.Context { - return context.WithValue(ctx, clientAddrKey, addr) -} - -func ClientAddrFromContext(ctx context.Context) ClientAddr { - v, _ := ctx.Value(clientAddrKey).(ClientAddr) - return v -} diff --git a/internal/util/selector/key.go b/internal/util/selector/key.go deleted file mode 100644 index 292eac3..0000000 --- a/internal/util/selector/key.go +++ /dev/null @@ -1,26 +0,0 @@ -package selector - -import ( - "context" -) - -type hashKey struct{} - -type Hash struct { - Source string -} - -var ( - clientHashKey = &hashKey{} -) - -func ContextWithHash(ctx context.Context, hash *Hash) context.Context { - return context.WithValue(ctx, clientHashKey, hash) -} - -func HashFromContext(ctx context.Context) *Hash { - if v, _ := ctx.Value(clientHashKey).(*Hash); v != nil { - return v - } - return nil -} diff --git a/limiter/traffic/plugin.go b/limiter/traffic/plugin.go new file mode 100644 index 0000000..a967566 --- /dev/null +++ b/limiter/traffic/plugin.go @@ -0,0 +1,235 @@ +package traffic + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + + "github.com/go-gost/core/limiter/traffic" + "github.com/go-gost/core/logger" + "github.com/go-gost/plugin/limiter/traffic/proto" + "github.com/go-gost/x/internal/plugin" + "google.golang.org/grpc" +) + +type grpcPlugin struct { + conn grpc.ClientConnInterface + client proto.LimiterClient + log logger.Logger +} + +// NewGRPCPlugin creates a traffic limiter plugin based on gRPC. +func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) traffic.TrafficLimiter { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + log := logger.Default().WithFields(map[string]any{ + "kind": "limiter", + "limiter": name, + }) + conn, err := plugin.NewGRPCConn(addr, &options) + if err != nil { + log.Error(err) + } + + p := &grpcPlugin{ + conn: conn, + log: log, + } + if conn != nil { + p.client = proto.NewLimiterClient(conn) + } + return p +} + +func (p *grpcPlugin) In(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter { + if p.client == nil { + return nil + } + + var options traffic.Options + for _, opt := range opts { + opt(&options) + } + + r, err := p.client.Limit(ctx, + &proto.LimitRequest{ + Network: options.Network, + Addr: options.Addr, + Client: options.Client, + Src: options.Src, + }) + if err != nil { + p.log.Error(err) + return nil + } + + return NewLimiter(int(r.In)) +} + +func (p *grpcPlugin) Out(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter { + if p.client == nil { + return nil + } + + var options traffic.Options + for _, opt := range opts { + opt(&options) + } + + r, err := p.client.Limit(ctx, + &proto.LimitRequest{ + Network: options.Network, + Addr: options.Addr, + Client: options.Client, + Src: options.Src, + }) + if err != nil { + p.log.Error(err) + return nil + } + + return NewLimiter(int(r.Out)) +} + +func (p *grpcPlugin) Close() error { + if closer, ok := p.conn.(io.Closer); ok { + return closer.Close() + } + return nil +} + +type httpPluginRequest struct { + Network string `json:"network"` + Addr string `json:"addr"` + Client string `json:"client"` + Src string `json:"src"` +} + +type httpPluginResponse struct { + In int64 `json:"in"` + Out int64 `json:"out"` +} + +type httpPlugin struct { + url string + client *http.Client + header http.Header + log logger.Logger +} + +// NewHTTPPlugin creates a traffic limiter plugin based on HTTP. +func NewHTTPPlugin(name string, url string, opts ...plugin.Option) traffic.TrafficLimiter { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + return &httpPlugin{ + url: url, + client: plugin.NewHTTPClient(&options), + header: options.Header, + log: logger.Default().WithFields(map[string]any{ + "kind": "limiter", + "limiter": name, + }), + } +} + +func (p *httpPlugin) In(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter { + if p.client == nil { + return nil + } + + var options traffic.Options + for _, opt := range opts { + opt(&options) + } + + rb := httpPluginRequest{ + Network: options.Network, + Addr: options.Addr, + Client: options.Client, + Src: options.Src, + } + v, err := json.Marshal(&rb) + if err != nil { + return nil + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) + if err != nil { + return nil + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return nil + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil + } + + res := httpPluginResponse{} + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return nil + } + return NewLimiter(int(res.In)) +} + +func (p *httpPlugin) Out(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter { + if p.client == nil { + return nil + } + + var options traffic.Options + for _, opt := range opts { + opt(&options) + } + + rb := httpPluginRequest{ + Network: options.Network, + Addr: options.Addr, + Client: options.Client, + Src: options.Src, + } + v, err := json.Marshal(&rb) + if err != nil { + return nil + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) + if err != nil { + return nil + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return nil + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil + } + + res := httpPluginResponse{} + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return nil + } + return NewLimiter(int(res.Out)) +} diff --git a/limiter/traffic/traffic.go b/limiter/traffic/traffic.go index 12b16bb..4349daa 100644 --- a/limiter/traffic/traffic.go +++ b/limiter/traffic/traffic.go @@ -121,7 +121,7 @@ func NewTrafficLimiter(opts ...Option) limiter.TrafficLimiter { // In obtains a traffic input limiter based on key. // The key should be client connection address. -func (l *trafficLimiter) In(key string) limiter.Limiter { +func (l *trafficLimiter) In(ctx context.Context, key string, opts ...limiter.Option) limiter.Limiter { var lims []limiter.Limiter // service level limiter @@ -185,7 +185,7 @@ func (l *trafficLimiter) In(key string) limiter.Limiter { // Out obtains a traffic output limiter based on key. // The key should be client connection address. -func (l *trafficLimiter) Out(key string) limiter.Limiter { +func (l *trafficLimiter) Out(ctx context.Context, key string, opts ...limiter.Option) limiter.Limiter { var lims []limiter.Limiter // service level limiter diff --git a/limiter/traffic/wrapper/conn.go b/limiter/traffic/wrapper/conn.go index 3671041..407b0d8 100644 --- a/limiter/traffic/wrapper/conn.go +++ b/limiter/traffic/wrapper/conn.go @@ -26,8 +26,8 @@ type serverConn struct { rbuf bytes.Buffer limiter limiter.TrafficLimiter limiterIn limiter.Limiter - expIn int64 limiterOut limiter.Limiter + expIn int64 expOut int64 } @@ -35,34 +35,39 @@ func WrapConn(limiter limiter.TrafficLimiter, c net.Conn) net.Conn { if limiter == nil { return c } + return &serverConn{ Conn: c, limiter: limiter, } } -func (c *serverConn) getInLimiter(addr net.Addr) limiter.Limiter { +func (c *serverConn) getInLimiter() limiter.Limiter { now := time.Now().UnixNano() // cache the limiter for 60s if c.limiter != nil && time.Duration(now-c.expIn) > 60*time.Second { - c.limiterIn = c.limiter.In(addr.String()) + if lim := c.limiter.In(context.Background(), c.RemoteAddr().String()); lim != nil { + c.limiterIn = lim + } c.expIn = now } return c.limiterIn } -func (c *serverConn) getOutLimiter(addr net.Addr) limiter.Limiter { +func (c *serverConn) getOutLimiter() limiter.Limiter { now := time.Now().UnixNano() // cache the limiter for 60s if c.limiter != nil && time.Duration(now-c.expOut) > 60*time.Second { - c.limiterOut = c.limiter.Out(addr.String()) + if lim := c.limiter.Out(context.Background(), c.RemoteAddr().String()); lim != nil { + c.limiterOut = lim + } c.expOut = now } return c.limiterOut } func (c *serverConn) Read(b []byte) (n int, err error) { - limiter := c.getInLimiter(c.RemoteAddr()) + limiter := c.getInLimiter() if limiter == nil { return c.Conn.Read(b) } @@ -92,7 +97,7 @@ func (c *serverConn) Read(b []byte) (n int, err error) { } func (c *serverConn) Write(b []byte) (n int, err error) { - limiter := c.getOutLimiter(c.RemoteAddr()) + limiter := c.getOutLimiter() if limiter == nil { return c.Conn.Write(b) } @@ -163,7 +168,7 @@ func (c *packetConn) getInLimiter(addr net.Addr) limiter.Limiter { return lim } - lim = c.limiter.In(addr.String()) + lim = c.limiter.In(context.Background(), addr.String()) c.inLimits.Set(addr.String(), lim, 0) return lim @@ -187,7 +192,7 @@ func (c *packetConn) getOutLimiter(addr net.Addr) limiter.Limiter { return lim } - lim = c.limiter.Out(addr.String()) + lim = c.limiter.Out(context.Background(), addr.String()) c.outLimits.Set(addr.String(), lim, 0) return lim @@ -266,7 +271,7 @@ func (c *udpConn) getInLimiter(addr net.Addr) limiter.Limiter { return lim } - lim = c.limiter.In(addr.String()) + lim = c.limiter.In(context.Background(), addr.String()) c.inLimits.Set(addr.String(), lim, 0) return lim @@ -290,7 +295,7 @@ func (c *udpConn) getOutLimiter(addr net.Addr) limiter.Limiter { return lim } - lim = c.limiter.Out(addr.String()) + lim = c.limiter.Out(context.Background(), addr.String()) c.outLimits.Set(addr.String(), lim, 0) return lim diff --git a/limiter/traffic/wrapper/io.go b/limiter/traffic/wrapper/io.go new file mode 100644 index 0000000..3ee0ee3 --- /dev/null +++ b/limiter/traffic/wrapper/io.go @@ -0,0 +1,109 @@ +package wrapper + +import ( + "bytes" + "context" + "io" + "time" + + "github.com/go-gost/core/limiter/traffic" + limiter "github.com/go-gost/core/limiter/traffic" +) + +// readWriter is an io.ReadWriter with traffic limiter supported. +type readWriter struct { + io.ReadWriter + rbuf bytes.Buffer + limiter limiter.TrafficLimiter + limiterIn limiter.Limiter + limiterOut limiter.Limiter + expIn int64 + expOut int64 + opts []traffic.Option + key string +} + +func WrapReadWriter(limiter limiter.TrafficLimiter, rw io.ReadWriter, key string, opts ...traffic.Option) io.ReadWriter { + if limiter == nil { + return rw + } + + return &readWriter{ + ReadWriter: rw, + limiter: limiter, + opts: opts, + } +} + +func (p *readWriter) getInLimiter() limiter.Limiter { + now := time.Now().UnixNano() + // cache the limiter for 60s + if p.limiter != nil && time.Duration(now-p.expIn) > 60*time.Second { + if lim := p.limiter.In(context.Background(), p.key, p.opts...); lim != nil { + p.limiterIn = lim + } + p.expIn = now + } + return p.limiterIn +} + +func (p *readWriter) getOutLimiter() limiter.Limiter { + now := time.Now().UnixNano() + // cache the limiter for 60s + if p.limiter != nil && time.Duration(now-p.expOut) > 60*time.Second { + if lim := p.limiter.Out(context.Background(), p.key, p.opts...); lim != nil { + p.limiterOut = lim + } + p.expOut = now + } + return p.limiterOut +} + +func (p *readWriter) Read(b []byte) (n int, err error) { + limiter := p.getInLimiter() + if limiter == nil { + return p.ReadWriter.Read(b) + } + + if p.rbuf.Len() > 0 { + burst := len(b) + if p.rbuf.Len() < burst { + burst = p.rbuf.Len() + } + lim := limiter.Wait(context.Background(), burst) + return p.rbuf.Read(b[:lim]) + } + + nn, err := p.ReadWriter.Read(b) + if err != nil { + return nn, err + } + + n = limiter.Wait(context.Background(), nn) + if n < nn { + if _, err = p.rbuf.Write(b[n:nn]); err != nil { + return 0, err + } + } + + return +} + +func (p *readWriter) Write(b []byte) (n int, err error) { + limiter := p.getOutLimiter() + if limiter == nil { + return p.ReadWriter.Write(b) + } + + nn := 0 + for len(b) > 0 { + nn, err = p.ReadWriter.Write(b[:limiter.Wait(context.Background(), len(b))]) + n += nn + if err != nil { + return + } + b = b[nn:] + } + + return +} diff --git a/listener/http2/listener.go b/listener/http2/listener.go index abee6fb..a021660 100644 --- a/listener/http2/listener.go +++ b/listener/http2/listener.go @@ -118,10 +118,10 @@ func (l *http2Listener) Close() (err error) { case <-l.errChan: default: err = l.server.Close() - l.errChan <- err + l.errChan <- http.ErrServerClosed close(l.errChan) } - return nil + return } func (l *http2Listener) handleFunc(w http.ResponseWriter, r *http.Request) { diff --git a/registry/limiter.go b/registry/limiter.go index 4ade180..fedc718 100644 --- a/registry/limiter.go +++ b/registry/limiter.go @@ -1,6 +1,8 @@ package registry import ( + "context" + "github.com/go-gost/core/limiter/conn" "github.com/go-gost/core/limiter/rate" "github.com/go-gost/core/limiter/traffic" @@ -30,20 +32,20 @@ type trafficLimiterWrapper struct { r *trafficLimiterRegistry } -func (w *trafficLimiterWrapper) In(key string) traffic.Limiter { +func (w *trafficLimiterWrapper) In(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter { v := w.r.get(w.name) if v == nil { return nil } - return v.In(key) + return v.In(ctx, key, opts...) } -func (w *trafficLimiterWrapper) Out(key string) traffic.Limiter { +func (w *trafficLimiterWrapper) Out(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter { v := w.r.get(w.name) if v == nil { return nil } - return v.Out(key) + return v.Out(ctx, key, opts...) } type connLimiterRegistry struct { diff --git a/resolver/plugin.go b/resolver/plugin.go index 1c5d47f..d523a45 100644 --- a/resolver/plugin.go +++ b/resolver/plugin.go @@ -13,8 +13,8 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/core/resolver" "github.com/go-gost/plugin/resolver/proto" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/plugin" - auth_util "github.com/go-gost/x/internal/util/auth" "google.golang.org/grpc" ) @@ -60,7 +60,7 @@ func (p *grpcPlugin) Resolve(ctx context.Context, network, host string, opts ... &proto.ResolveRequest{ Network: network, Host: host, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), }) if err != nil { p.log.Error(err) @@ -127,7 +127,7 @@ func (p *httpPlugin) Resolve(ctx context.Context, network, host string, opts ... rb := httpPluginRequest{ Network: network, Host: host, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), } v, err := json.Marshal(&rb) if err != nil { diff --git a/selector/strategy.go b/selector/strategy.go index f8b9e11..4a93448 100644 --- a/selector/strategy.go +++ b/selector/strategy.go @@ -12,7 +12,7 @@ import ( "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" "github.com/go-gost/core/selector" - sx "github.com/go-gost/x/internal/util/selector" + ctxvalue "github.com/go-gost/x/internal/ctx" ) type roundRobinStrategy[T any] struct { @@ -102,7 +102,7 @@ func (s *hashStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) { if len(vs) == 0 { return } - if h := sx.HashFromContext(ctx); h != nil { + if h := ctxvalue.HashFromContext(ctx); h != nil { value := uint64(crc32.ChecksumIEEE([]byte(h.Source))) logger.Default().Tracef("hash %s %d", h.Source, value) return vs[value%uint64(len(vs))] diff --git a/service/service.go b/service/service.go index 11a0571..38724a7 100644 --- a/service/service.go +++ b/service/service.go @@ -15,7 +15,7 @@ import ( "github.com/go-gost/core/metrics" "github.com/go-gost/core/recorder" "github.com/go-gost/core/service" - sx "github.com/go-gost/x/internal/util/selector" + ctxvalue "github.com/go-gost/x/internal/ctx" xmetrics "github.com/go-gost/x/metrics" "github.com/rs/xid" ) @@ -145,20 +145,26 @@ func (s *defaultService) Serve() error { } tempDelay = 0 - host := conn.RemoteAddr().String() - if h, _, _ := net.SplitHostPort(host); h != "" { - host = h + clientAddr := conn.RemoteAddr().String() + clientIP := clientAddr + if h, _, _ := net.SplitHostPort(clientAddr); h != "" { + clientIP = h } + + ctx := ctxvalue.ContextWithSid(context.Background(), ctxvalue.Sid(xid.New().String())) + ctx = ctxvalue.ContextWithClientAddr(ctx, ctxvalue.ClientAddr(clientAddr)) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: clientIP}) + for _, rec := range s.options.recorders { if rec.Record == recorder.RecorderServiceClientAddress { - if err := rec.Recorder.Record(context.Background(), []byte(host)); err != nil { + if err := rec.Recorder.Record(ctx, []byte(clientIP)); err != nil { s.options.logger.Errorf("record %s: %v", rec.Record, err) } break } } if s.options.admission != nil && - !s.options.admission.Admit(context.Background(), conn.RemoteAddr().String()) { + !s.options.admission.Admit(ctx, conn.RemoteAddr().String()) { conn.Close() s.options.logger.Debugf("admission: %s is denied", conn.RemoteAddr()) continue @@ -166,12 +172,12 @@ func (s *defaultService) Serve() error { go func() { if v := xmetrics.GetCounter(xmetrics.MetricServiceRequestsCounter, - metrics.Labels{"service": s.name, "client": host}); v != nil { + metrics.Labels{"service": s.name, "client": clientIP}); v != nil { v.Inc() } if v := xmetrics.GetGauge(xmetrics.MetricServiceRequestsInFlightGauge, - metrics.Labels{"service": s.name, "client": host}); v != nil { + metrics.Labels{"service": s.name, "client": clientIP}); v != nil { v.Inc() defer v.Dec() } @@ -184,13 +190,10 @@ func (s *defaultService) Serve() error { }() } - ctx := sx.ContextWithHash(context.Background(), &sx.Hash{Source: host}) - ctx = ContextWithSid(ctx, xid.New().String()) - if err := s.handler.Handle(ctx, conn); err != nil { s.options.logger.Error(err) if v := xmetrics.GetCounter(xmetrics.MetricServiceHandlerErrorsCounter, - metrics.Labels{"service": s.name, "client": host}); v != nil { + metrics.Labels{"service": s.name, "client": clientIP}); v != nil { v.Inc() } } @@ -211,18 +214,3 @@ func (s *defaultService) execCmds(phase string, cmds []string) { } } } - -type sidKey struct{} - -var ( - ssid sidKey -) - -func ContextWithSid(ctx context.Context, sid string) context.Context { - return context.WithValue(ctx, ssid, sid) -} - -func SidFromContext(ctx context.Context) string { - v, _ := ctx.Value(ssid).(string) - return v -}