diff --git a/api/config.go b/api/config.go index 2cce93e..affa7b3 100644 --- a/api/config.go +++ b/api/config.go @@ -62,6 +62,9 @@ type saveConfigRequest struct { // output format, one of yaml|json, default is yaml. // in: query Format string `form:"format" json:"format"` + // file path, default is gost.yaml|gost.json in current working directory. + // in: query + Path string `form:"path" json:"path"` } // successful operation. @@ -92,6 +95,10 @@ func saveConfig(ctx *gin.Context) { req.Format = "yaml" } + if req.Path != "" { + file = req.Path + } + f, err := os.Create(file) if err != nil { writeError(ctx, &Error{ diff --git a/api/config_router.go b/api/config_router.go new file mode 100644 index 0000000..6d0c624 --- /dev/null +++ b/api/config_router.go @@ -0,0 +1,169 @@ +package api + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/go-gost/x/config" + parser "github.com/go-gost/x/config/parsing/router" + "github.com/go-gost/x/registry" +) + +// swagger:parameters createRouterRequest +type createRouterRequest struct { + // in: body + Data config.RouterConfig `json:"data"` +} + +// successful operation. +// swagger:response createRouterResponse +type createRouterResponse struct { + Data Response +} + +func createRouter(ctx *gin.Context) { + // swagger:route POST /config/routers Router createRouterRequest + // + // Create a new router, the name of the router must be unique in router list. + // + // Security: + // basicAuth: [] + // + // Responses: + // 200: createRouterResponse + + var req createRouterRequest + ctx.ShouldBindJSON(&req.Data) + + if req.Data.Name == "" { + writeError(ctx, ErrInvalid) + return + } + + v := parser.ParseRouter(&req.Data) + + if err := registry.RouterRegistry().Register(req.Data.Name, v); err != nil { + writeError(ctx, ErrDup) + return + } + + config.OnUpdate(func(c *config.Config) error { + c.Routers = append(c.Routers, &req.Data) + return nil + }) + + ctx.JSON(http.StatusOK, Response{ + Msg: "OK", + }) +} + +// swagger:parameters updateRouterRequest +type updateRouterRequest struct { + // in: path + // required: true + Router string `uri:"router" json:"router"` + // in: body + Data config.RouterConfig `json:"data"` +} + +// successful operation. +// swagger:response updateRouterResponse +type updateRouterResponse struct { + Data Response +} + +func updateRouter(ctx *gin.Context) { + // swagger:route PUT /config/routers/{router} Router updateRouterRequest + // + // Update router by name, the router must already exist. + // + // Security: + // basicAuth: [] + // + // Responses: + // 200: updateRouterResponse + + var req updateRouterRequest + ctx.ShouldBindUri(&req) + ctx.ShouldBindJSON(&req.Data) + + if !registry.RouterRegistry().IsRegistered(req.Router) { + writeError(ctx, ErrNotFound) + return + } + + req.Data.Name = req.Router + + v := parser.ParseRouter(&req.Data) + + registry.RouterRegistry().Unregister(req.Router) + + if err := registry.RouterRegistry().Register(req.Router, v); err != nil { + writeError(ctx, ErrDup) + return + } + + config.OnUpdate(func(c *config.Config) error { + for i := range c.Routers { + if c.Routers[i].Name == req.Router { + c.Routers[i] = &req.Data + break + } + } + return nil + }) + + ctx.JSON(http.StatusOK, Response{ + Msg: "OK", + }) +} + +// swagger:parameters deleteRouterRequest +type deleteRouterRequest struct { + // in: path + // required: true + Router string `uri:"router" json:"router"` +} + +// successful operation. +// swagger:response deleteRouterResponse +type deleteRouterResponse struct { + Data Response +} + +func deleteRouter(ctx *gin.Context) { + // swagger:route DELETE /config/routers/{router} Router deleteRouterRequest + // + // Delete router by name. + // + // Security: + // basicAuth: [] + // + // Responses: + // 200: deleteRouterResponse + + var req deleteRouterRequest + ctx.ShouldBindUri(&req) + + if !registry.RouterRegistry().IsRegistered(req.Router) { + writeError(ctx, ErrNotFound) + return + } + registry.RouterRegistry().Unregister(req.Router) + + config.OnUpdate(func(c *config.Config) error { + routeres := c.Routers + c.Routers = nil + for _, s := range routeres { + if s.Name == req.Router { + continue + } + c.Routers = append(c.Routers, s) + } + return nil + }) + + ctx.JSON(http.StatusOK, Response{ + Msg: "OK", + }) +} diff --git a/api/service.go b/api/service.go index 6802179..089bb10 100644 --- a/api/service.go +++ b/api/service.go @@ -138,6 +138,10 @@ func registerConfig(config *gin.RouterGroup) { config.PUT("/ingresses/:ingress", updateIngress) config.DELETE("/ingresses/:ingress", deleteIngress) + config.POST("/routers", createRouter) + config.PUT("/routers/:router", updateRouter) + config.DELETE("/routers/:router", deleteRouter) + config.POST("/limiters", createLimiter) config.PUT("/limiters/:limiter", updateLimiter) config.DELETE("/limiters/:limiter", deleteLimiter) diff --git a/api/swagger.yaml b/api/swagger.yaml index b2392f1..dd00c58 100644 --- a/api/swagger.yaml +++ b/api/swagger.yaml @@ -34,6 +34,8 @@ definitions: name: type: string x-go-name: Name + plugin: + $ref: '#/definitions/PluginConfig' redis: $ref: '#/definitions/RedisLoader' reload: @@ -71,6 +73,8 @@ definitions: name: type: string x-go-name: Name + plugin: + $ref: '#/definitions/PluginConfig' redis: $ref: '#/definitions/RedisLoader' reload: @@ -91,6 +95,8 @@ definitions: name: type: string x-go-name: Name + plugin: + $ref: '#/definitions/PluginConfig' redis: $ref: '#/definitions/RedisLoader' reload: @@ -107,9 +113,6 @@ definitions: ChainConfig: properties: hops: - description: |- - REMOVED since beta.6 - Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"` items: $ref: '#/definitions/HopConfig' type: array @@ -204,6 +207,16 @@ definitions: $ref: '#/definitions/LimiterConfig' type: array x-go-name: RLimiters + routers: + items: + $ref: '#/definitions/RouterConfig' + type: array + x-go-name: Routers + sds: + items: + $ref: '#/definitions/SDConfig' + type: array + x-go-name: SDs services: items: $ref: '#/definitions/ServiceConfig' @@ -273,6 +286,8 @@ definitions: addr: type: string x-go-name: Addr + auth: + $ref: '#/definitions/AuthConfig' bypass: type: string x-go-name: Bypass @@ -284,12 +299,22 @@ definitions: host: type: string x-go-name: Host + http: + $ref: '#/definitions/HTTPNodeConfig' name: type: string x-go-name: Name + network: + type: string + x-go-name: Network + path: + type: string + x-go-name: Path protocol: type: string x-go-name: Protocol + tls: + $ref: '#/definitions/TLSNodeConfig' type: object x-go-package: github.com/go-gost/x/config ForwarderConfig: @@ -304,12 +329,6 @@ definitions: x-go-name: Nodes selector: $ref: '#/definitions/SelectorConfig' - targets: - description: DEPRECATED by nodes since beta.4 - items: - type: string - type: array - x-go-name: Targets type: object x-go-package: github.com/go-gost/x/config HTTPLoader: @@ -321,6 +340,27 @@ definitions: x-go-name: URL type: object x-go-package: github.com/go-gost/x/config + HTTPNodeConfig: + properties: + header: + additionalProperties: + type: string + type: object + x-go-name: Header + host: + type: string + x-go-name: Host + type: object + x-go-package: github.com/go-gost/x/config + HTTPRecorder: + properties: + timeout: + $ref: '#/definitions/Duration' + url: + type: string + x-go-name: URL + type: object + x-go-package: github.com/go-gost/x/config HandlerConfig: properties: auth: @@ -338,9 +378,9 @@ definitions: x-go-name: Chain chainGroup: $ref: '#/definitions/ChainGroupConfig' - ingress: + limiter: type: string - x-go-name: Ingress + x-go-name: Limiter metadata: additionalProperties: {} type: object @@ -366,9 +406,13 @@ definitions: type: string type: array x-go-name: Bypasses + file: + $ref: '#/definitions/FileLoader' hosts: type: string x-go-name: Hosts + http: + $ref: '#/definitions/HTTPLoader' interface: type: string x-go-name: Interface @@ -380,6 +424,12 @@ definitions: $ref: '#/definitions/NodeConfig' type: array x-go-name: Nodes + plugin: + $ref: '#/definitions/PluginConfig' + redis: + $ref: '#/definitions/RedisLoader' + reload: + $ref: '#/definitions/Duration' resolver: type: string x-go-name: Resolver @@ -418,6 +468,8 @@ definitions: name: type: string x-go-name: Name + plugin: + $ref: '#/definitions/PluginConfig' redis: $ref: '#/definitions/RedisLoader' reload: @@ -433,6 +485,8 @@ definitions: name: type: string x-go-name: Name + plugin: + $ref: '#/definitions/PluginConfig' redis: $ref: '#/definitions/RedisLoader' reload: @@ -468,6 +522,8 @@ definitions: name: type: string x-go-name: Name + plugin: + $ref: '#/definitions/PluginConfig' redis: $ref: '#/definitions/RedisLoader' reload: @@ -564,6 +620,11 @@ definitions: addr: type: string x-go-name: Addr + auth: + $ref: '#/definitions/AuthConfig' + auther: + type: string + x-go-name: Auther path: type: string x-go-name: Path @@ -574,6 +635,9 @@ definitions: addr: type: string x-go-name: Addr + async: + type: boolean + x-go-name: Async chain: type: string x-go-name: Chain @@ -583,6 +647,9 @@ definitions: hostname: type: string x-go-name: Hostname + only: + type: string + x-go-name: Only prefer: type: string x-go-name: Prefer @@ -597,6 +664,8 @@ definitions: addr: type: string x-go-name: Addr + auth: + $ref: '#/definitions/AuthConfig' bypass: type: string x-go-name: Bypass @@ -615,6 +684,8 @@ definitions: hosts: type: string x-go-name: Hosts + http: + $ref: '#/definitions/HTTPNodeConfig' interface: type: string x-go-name: Interface @@ -625,6 +696,12 @@ definitions: name: type: string x-go-name: Name + network: + type: string + x-go-name: Network + path: + type: string + x-go-name: Path protocol: type: string x-go-name: Protocol @@ -633,6 +710,25 @@ definitions: x-go-name: Resolver sockopts: $ref: '#/definitions/SockOptsConfig' + tls: + $ref: '#/definitions/TLSNodeConfig' + type: object + x-go-package: github.com/go-gost/x/config + PluginConfig: + properties: + addr: + type: string + x-go-name: Addr + timeout: + $ref: '#/definitions/Duration' + tls: + $ref: '#/definitions/TLSConfig' + token: + type: string + x-go-name: Token + type: + type: string + x-go-name: Type type: object x-go-package: github.com/go-gost/x/config ProfilingConfig: @@ -646,15 +742,24 @@ definitions: properties: file: $ref: '#/definitions/FileRecorder' + http: + $ref: '#/definitions/HTTPRecorder' name: type: string x-go-name: Name + plugin: + $ref: '#/definitions/PluginConfig' redis: $ref: '#/definitions/RedisRecorder' + tcp: + $ref: '#/definitions/TCPRecorder' type: object x-go-package: github.com/go-gost/x/config RecorderObject: properties: + Metadata: + additionalProperties: {} + type: object name: type: string x-go-name: Name @@ -713,6 +818,8 @@ definitions: $ref: '#/definitions/NameserverConfig' type: array x-go-name: Nameservers + plugin: + $ref: '#/definitions/PluginConfig' type: object x-go-package: github.com/go-gost/x/config Response: @@ -726,6 +833,47 @@ definitions: x-go-name: Msg type: object x-go-package: github.com/go-gost/x/api + RouterConfig: + properties: + file: + $ref: '#/definitions/FileLoader' + http: + $ref: '#/definitions/HTTPLoader' + name: + type: string + x-go-name: Name + plugin: + $ref: '#/definitions/PluginConfig' + redis: + $ref: '#/definitions/RedisLoader' + reload: + $ref: '#/definitions/Duration' + routes: + items: + $ref: '#/definitions/RouterRouteConfig' + type: array + x-go-name: Routes + type: object + x-go-package: github.com/go-gost/x/config + RouterRouteConfig: + properties: + gateway: + type: string + x-go-name: Gateway + net: + type: string + x-go-name: Net + type: object + x-go-package: github.com/go-gost/x/config + SDConfig: + properties: + name: + type: string + x-go-name: Name + plugin: + $ref: '#/definitions/PluginConfig' + type: object + x-go-package: github.com/go-gost/x/config SelectorConfig: properties: failTimeout: @@ -809,6 +957,15 @@ definitions: x-go-name: Mark type: object x-go-package: github.com/go-gost/x/config + TCPRecorder: + properties: + addr: + type: string + x-go-name: Addr + timeout: + $ref: '#/definitions/Duration' + type: object + x-go-package: github.com/go-gost/x/config TLSConfig: properties: caFile: @@ -823,6 +980,8 @@ definitions: keyFile: type: string x-go-name: KeyFile + options: + $ref: '#/definitions/TLSOptions' organization: type: string x-go-name: Organization @@ -836,6 +995,33 @@ definitions: $ref: '#/definitions/Duration' type: object x-go-package: github.com/go-gost/x/config + TLSNodeConfig: + properties: + options: + $ref: '#/definitions/TLSOptions' + secure: + type: boolean + x-go-name: Secure + serverName: + type: string + x-go-name: ServerName + type: object + x-go-package: github.com/go-gost/x/config + TLSOptions: + properties: + cipherSuites: + items: + type: string + type: array + x-go-name: CipherSuites + maxVersion: + type: string + x-go-name: MaxVersion + minVersion: + type: string + x-go-name: MinVersion + type: object + x-go-package: github.com/go-gost/x/config info: title: Documentation of Web API. version: 1.0.0 @@ -866,6 +1052,11 @@ paths: name: format type: string x-go-name: Format + - description: file path, default is gost.yaml|gost.json in current working directory. + in: query + name: path + type: string + x-go-name: Path responses: "200": $ref: '#/responses/saveConfigResponse' @@ -1513,6 +1704,64 @@ paths: summary: Update rate limiter by name, the limiter must already exist. tags: - Limiter + /config/routers: + post: + operationId: createRouterRequest + parameters: + - in: body + name: data + schema: + $ref: '#/definitions/RouterConfig' + x-go-name: Data + responses: + "200": + $ref: '#/responses/createRouterResponse' + security: + - basicAuth: + - '[]' + summary: Create a new router, the name of the router must be unique in router list. + tags: + - Router + /config/routers/{router}: + delete: + operationId: deleteRouterRequest + parameters: + - in: path + name: router + required: true + type: string + x-go-name: Router + responses: + "200": + $ref: '#/responses/deleteRouterResponse' + security: + - basicAuth: + - '[]' + summary: Delete router by name. + tags: + - Router + put: + operationId: updateRouterRequest + parameters: + - in: path + name: router + required: true + type: string + x-go-name: Router + - in: body + name: data + schema: + $ref: '#/definitions/RouterConfig' + x-go-name: Data + responses: + "200": + $ref: '#/responses/updateRouterResponse' + security: + - basicAuth: + - '[]' + summary: Update router by name, the router must already exist. + tags: + - Router /config/services: post: operationId: createServiceRequest @@ -1640,6 +1889,12 @@ responses: Data: {} schema: $ref: '#/definitions/Response' + createRouterResponse: + description: successful operation. + headers: + Data: {} + schema: + $ref: '#/definitions/Response' createServiceResponse: description: successful operation. headers: @@ -1712,6 +1967,12 @@ responses: Data: {} schema: $ref: '#/definitions/Response' + deleteRouterResponse: + description: successful operation. + headers: + Data: {} + schema: + $ref: '#/definitions/Response' deleteServiceResponse: description: successful operation. headers: @@ -1796,6 +2057,12 @@ responses: Data: {} schema: $ref: '#/definitions/Response' + updateRouterResponse: + description: successful operation. + headers: + Data: {} + schema: + $ref: '#/definitions/Response' updateServiceResponse: description: successful operation. headers: diff --git a/auth/plugin.go b/auth/plugin.go index 72fe797..8ada348 100644 --- a/auth/plugin.go +++ b/auth/plugin.go @@ -10,8 +10,8 @@ import ( "github.com/go-gost/core/auth" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/auth/proto" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/plugin" - auth_util "github.com/go-gost/x/internal/util/auth" "google.golang.org/grpc" ) @@ -58,7 +58,7 @@ func (p *grpcPlugin) Authenticate(ctx context.Context, user, password string, op &proto.AuthenticateRequest{ Username: user, Password: password, - Client: string(auth_util.ClientAddrFromContext(ctx)), + Client: string(ctxvalue.ClientAddrFromContext(ctx)), }) if err != nil { p.log.Error(err) @@ -118,7 +118,7 @@ func (p *httpPlugin) Authenticate(ctx context.Context, user, password string, op rb := httpPluginRequest{ Username: user, Password: password, - Client: string(auth_util.ClientAddrFromContext(ctx)), + Client: string(ctxvalue.ClientAddrFromContext(ctx)), } v, err := json.Marshal(&rb) if err != nil { diff --git a/bypass/plugin.go b/bypass/plugin.go index 4e20703..b18919c 100644 --- a/bypass/plugin.go +++ b/bypass/plugin.go @@ -10,8 +10,8 @@ import ( "github.com/go-gost/core/bypass" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/bypass/proto" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/plugin" - auth_util "github.com/go-gost/x/internal/util/auth" "google.golang.org/grpc" ) @@ -61,7 +61,7 @@ func (p *grpcPlugin) Contains(ctx context.Context, network, addr string, opts .. &proto.BypassRequest{ Network: network, Addr: addr, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), Host: options.Host, Path: options.Path, }) @@ -129,7 +129,7 @@ func (p *httpPlugin) Contains(ctx context.Context, network, addr string, opts .. rb := httpPluginRequest{ Network: network, Addr: addr, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), Host: options.Host, Path: options.Path, } diff --git a/config/config.go b/config/config.go index a4bf055..30d663c 100644 --- a/config/config.go +++ b/config/config.go @@ -79,6 +79,11 @@ type LogRotationConfig struct { Compress bool `yaml:"compress,omitempty" json:"compress,omitempty"` } +type LoggerConfig struct { + Name string `json:"name"` + Log *LogConfig `yaml:",omitempty" json:"log,omitempty"` +} + type ProfilingConfig struct { Addr string `json:"addr"` } @@ -244,6 +249,21 @@ type SDConfig struct { Plugin *PluginConfig `yaml:",omitempty" json:"plugin,omitempty"` } +type RouterRouteConfig struct { + Net string `json:"net"` + Gateway string `json:"gateway"` +} + +type RouterConfig struct { + Name string `json:"name"` + Routes []*RouterRouteConfig `yaml:",omitempty" json:"routes,omitempty"` + Reload time.Duration `yaml:",omitempty" json:"reload,omitempty"` + File *FileLoader `yaml:",omitempty" json:"file,omitempty"` + Redis *RedisLoader `yaml:",omitempty" json:"redis,omitempty"` + HTTP *HTTPLoader `yaml:"http,omitempty" json:"http,omitempty"` + Plugin *PluginConfig `yaml:",omitempty" json:"plugin,omitempty"` +} + type RecorderConfig struct { Name string `json:"name"` File *FileRecorder `yaml:",omitempty" json:"file,omitempty"` @@ -289,6 +309,7 @@ type LimiterConfig struct { File *FileLoader `yaml:",omitempty" json:"file,omitempty"` Redis *RedisLoader `yaml:",omitempty" json:"redis,omitempty"` HTTP *HTTPLoader `yaml:"http,omitempty" json:"http,omitempty"` + Plugin *PluginConfig `yaml:",omitempty" json:"plugin,omitempty"` } type ListenerConfig struct { @@ -311,7 +332,7 @@ type HandlerConfig struct { Authers []string `yaml:",omitempty" json:"authers,omitempty"` Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` - Ingress string `yaml:",omitempty" json:"ingress,omitempty"` + Limiter string `yaml:",omitempty" json:"limiter,omitempty"` Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"` } @@ -380,6 +401,7 @@ type ServiceConfig struct { Limiter string `yaml:",omitempty" json:"limiter,omitempty"` CLimiter string `yaml:"climiter,omitempty" json:"climiter,omitempty"` RLimiter string `yaml:"rlimiter,omitempty" json:"rlimiter,omitempty"` + Logger string `yaml:",omitempty" json:"logger,omitempty"` Recorders []*RecorderObject `yaml:",omitempty" json:"recorders,omitempty"` Handler *HandlerConfig `yaml:",omitempty" json:"handler,omitempty"` Listener *ListenerConfig `yaml:",omitempty" json:"listener,omitempty"` @@ -446,11 +468,13 @@ type Config struct { Resolvers []*ResolverConfig `yaml:",omitempty" json:"resolvers,omitempty"` Hosts []*HostsConfig `yaml:",omitempty" json:"hosts,omitempty"` Ingresses []*IngressConfig `yaml:",omitempty" json:"ingresses,omitempty"` + Routers []*RouterConfig `yaml:",omitempty" json:"routers,omitempty"` SDs []*SDConfig `yaml:"sds,omitempty" json:"sds,omitempty"` Recorders []*RecorderConfig `yaml:",omitempty" json:"recorders,omitempty"` Limiters []*LimiterConfig `yaml:",omitempty" json:"limiters,omitempty"` CLimiters []*LimiterConfig `yaml:"climiters,omitempty" json:"climiters,omitempty"` RLimiters []*LimiterConfig `yaml:"rlimiters,omitempty" json:"rlimiters,omitempty"` + Loggers []*LoggerConfig `yaml:",omitempty" json:"loggers,omitempty"` TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` Log *LogConfig `yaml:",omitempty" json:"log,omitempty"` Profiling *ProfilingConfig `yaml:",omitempty" json:"profiling,omitempty"` diff --git a/config/parsing/ingress/parse.go b/config/parsing/ingress/parse.go index 5c95303..f7dafc3 100644 --- a/config/parsing/ingress/parse.go +++ b/config/parsing/ingress/parse.go @@ -41,13 +41,13 @@ func ParseIngress(cfg *config.IngressConfig) ingress.Ingress { } } - var rules []xingress.Rule + var rules []*ingress.Rule for _, rule := range cfg.Rules { if rule.Hostname == "" || rule.Endpoint == "" { continue } - rules = append(rules, xingress.Rule{ + rules = append(rules, &ingress.Rule{ Hostname: rule.Hostname, Endpoint: rule.Endpoint, }) diff --git a/config/parsing/limiter/parse.go b/config/parsing/limiter/parse.go index b26b5fd..6004290 100644 --- a/config/parsing/limiter/parse.go +++ b/config/parsing/limiter/parse.go @@ -1,12 +1,16 @@ package limiter import ( + "crypto/tls" + "strings" + "github.com/go-gost/core/limiter/conn" "github.com/go-gost/core/limiter/rate" "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" "github.com/go-gost/x/config" "github.com/go-gost/x/internal/loader" + "github.com/go-gost/x/internal/plugin" xconn "github.com/go-gost/x/limiter/conn" xrate "github.com/go-gost/x/limiter/rate" xtraffic "github.com/go-gost/x/limiter/traffic" @@ -17,6 +21,30 @@ func ParseTrafficLimiter(cfg *config.LimiterConfig) (lim traffic.TrafficLimiter) return nil } + if cfg.Plugin != nil { + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return xtraffic.NewHTTPPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return xtraffic.NewGRPCPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) + } + } + var opts []xtraffic.Option if cfg.File != nil && cfg.File.Path != "" { diff --git a/config/parsing/logger/parse.go b/config/parsing/logger/parse.go new file mode 100644 index 0000000..4f207b8 --- /dev/null +++ b/config/parsing/logger/parse.go @@ -0,0 +1,55 @@ +package logger + +import ( + "io" + "os" + "path/filepath" + + "github.com/go-gost/core/logger" + "github.com/go-gost/x/config" + xlogger "github.com/go-gost/x/logger" + "gopkg.in/natefinch/lumberjack.v2" +) + +func ParseLogger(cfg *config.LoggerConfig) logger.Logger { + if cfg == nil || cfg.Log == nil { + return nil + } + opts := []xlogger.Option{ + xlogger.NameOption(cfg.Name), + xlogger.FormatOption(logger.LogFormat(cfg.Log.Format)), + xlogger.LevelOption(logger.LogLevel(cfg.Log.Level)), + } + + var out io.Writer = os.Stderr + switch cfg.Log.Output { + case "none", "null": + return xlogger.Nop() + case "stdout": + out = os.Stdout + case "stderr", "": + out = os.Stderr + default: + if cfg.Log.Rotation != nil { + out = &lumberjack.Logger{ + Filename: cfg.Log.Output, + MaxSize: cfg.Log.Rotation.MaxSize, + MaxAge: cfg.Log.Rotation.MaxAge, + MaxBackups: cfg.Log.Rotation.MaxBackups, + LocalTime: cfg.Log.Rotation.LocalTime, + Compress: cfg.Log.Rotation.Compress, + } + } else { + os.MkdirAll(filepath.Dir(cfg.Log.Output), 0755) + f, err := os.OpenFile(cfg.Log.Output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + logger.Default().Warn(err) + } else { + out = f + } + } + } + opts = append(opts, xlogger.OutputOption(out)) + + return xlogger.NewLogger(opts...) +} diff --git a/config/parsing/router/parse.go b/config/parsing/router/parse.go new file mode 100644 index 0000000..f3a563d --- /dev/null +++ b/config/parsing/router/parse.go @@ -0,0 +1,104 @@ +package router + +import ( + "crypto/tls" + "net" + "strings" + + "github.com/go-gost/core/logger" + "github.com/go-gost/core/router" + "github.com/go-gost/x/config" + "github.com/go-gost/x/internal/loader" + "github.com/go-gost/x/internal/plugin" + xrouter "github.com/go-gost/x/router" +) + +func ParseRouter(cfg *config.RouterConfig) router.Router { + if cfg == nil { + return nil + } + + if cfg.Plugin != nil { + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return xrouter.NewHTTPPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return xrouter.NewGRPCPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) + } + } + + var routes []*router.Route + for _, route := range cfg.Routes { + _, ipNet, _ := net.ParseCIDR(route.Net) + if ipNet == nil { + continue + } + gw := net.ParseIP(route.Gateway) + if gw == nil { + continue + } + + routes = append(routes, &router.Route{ + Net: ipNet, + Gateway: gw, + }) + } + opts := []xrouter.Option{ + xrouter.RoutesOption(routes), + xrouter.ReloadPeriodOption(cfg.Reload), + xrouter.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "router", + "router": cfg.Name, + })), + } + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, xrouter.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + switch cfg.Redis.Type { + case "list": // rediss list + opts = append(opts, xrouter.RedisLoaderOption(loader.RedisListLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + case "set": // redis set + opts = append(opts, xrouter.RedisLoaderOption(loader.RedisSetLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + default: // redis hash + opts = append(opts, xrouter.RedisLoaderOption(loader.RedisHashLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + } + if cfg.HTTP != nil && cfg.HTTP.URL != "" { + opts = append(opts, xrouter.HTTPLoaderOption(loader.HTTPLoader( + cfg.HTTP.URL, + loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), + ))) + } + return xrouter.NewRouter(opts...) +} diff --git a/config/parsing/service/parse.go b/config/parsing/service/parse.go index 8edf2e5..5deb7eb 100644 --- a/config/parsing/service/parse.go +++ b/config/parsing/service/parse.go @@ -41,7 +41,12 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { Type: "auto", } } - serviceLogger := logger.Default().WithFields(map[string]any{ + + log := registry.LoggerRegistry().Get(cfg.Logger) + if log == nil { + log = logger.Default() + } + serviceLogger := log.WithFields(map[string]any{ "kind": "service", "service": cfg.Name, "listener": cfg.Listener.Type, @@ -210,6 +215,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { handler.BypassOption(bypass.BypassGroup(bypass_parser.List(cfg.Bypass, cfg.Bypasses...)...)), handler.TLSConfigOption(tlsConfig), handler.RateLimiterOption(registry.RateLimiterRegistry().Get(cfg.RLimiter)), + handler.TrafficLimiterOption(registry.TrafficLimiterRegistry().Get(cfg.Handler.Limiter)), handler.LoggerOption(handlerLogger), handler.ServiceOption(cfg.Name), ) diff --git a/connector/tunnel/connector.go b/connector/tunnel/connector.go index a427dfd..8a1c6dc 100644 --- a/connector/tunnel/connector.go +++ b/connector/tunnel/connector.go @@ -9,7 +9,7 @@ import ( "github.com/go-gost/core/connector" md "github.com/go-gost/core/metadata" "github.com/go-gost/relay" - auth_util "github.com/go-gost/x/internal/util/auth" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/registry" ) @@ -73,7 +73,7 @@ func (c *tunnelConnector) Connect(ctx context.Context, conn net.Conn, network, a } srcAddr := conn.LocalAddr().String() - if v := auth_util.ClientAddrFromContext(ctx); v != "" { + if v := ctxvalue.ClientAddrFromContext(ctx); v != "" { srcAddr = string(v) } diff --git a/go.mod b/go.mod index a055ff2..267a5e5 100644 --- a/go.mod +++ b/go.mod @@ -7,10 +7,10 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.9.1 - github.com/go-gost/core v0.0.0-20231113123850-a916f0401649 + github.com/go-gost/core v0.0.0-20231119081403-abc73f2ca2b7 github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.4.0 - github.com/go-gost/plugin v0.0.0-20231109123346-0ae4157b9d25 + github.com/go-gost/plugin v0.0.0-20231119084331-d49a1cb23b3b github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 github.com/go-redis/redis/v8 v8.11.5 @@ -44,6 +44,7 @@ require ( golang.zx2c4.com/wireguard v0.0.0-20220703234212-c31a7b1ab478 google.golang.org/grpc v1.59.0 google.golang.org/protobuf v1.31.0 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 8278c02..7d8908f 100644 --- a/go.sum +++ b/go.sum @@ -93,16 +93,14 @@ github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SU github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gost/core v0.0.0-20231109123312-8e4fc06cf1b7 h1:sDsPtmP51qf8zN/RbZZj/3vNLCoH0sdvpIRwV6TfzvY= -github.com/go-gost/core v0.0.0-20231109123312-8e4fc06cf1b7/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= -github.com/go-gost/core v0.0.0-20231113123850-a916f0401649 h1:14iGAk7cqc+aDWtsuY6CWpP0lvC54pA5Izjeh5FdQNs= -github.com/go-gost/core v0.0.0-20231113123850-a916f0401649/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= +github.com/go-gost/core v0.0.0-20231119081403-abc73f2ca2b7 h1:fxVUlZANqPApygO7lT8bYySyajiCFA62bDiNorral1w= +github.com/go-gost/core v0.0.0-20231119081403-abc73f2ca2b7/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.4.0 h1:EIrOEkpJez4gwHrMa33frA+hHXJyevjp47thpMQsJzI= github.com/go-gost/gosocks5 v0.4.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= -github.com/go-gost/plugin v0.0.0-20231109123346-0ae4157b9d25 h1:sOarC0xAJij4VtEhkJRng5okZW23KlXprxhb5XFZ+pw= -github.com/go-gost/plugin v0.0.0-20231109123346-0ae4157b9d25/go.mod h1:qXr2Zm9Ex2ATqnWuNUzVZqySPMnuIihvblYZt4MlZLw= +github.com/go-gost/plugin v0.0.0-20231119084331-d49a1cb23b3b h1:ZmnYutflq+KOZK+Px5RDckorDSxTYlkT4aQbjTC8/C4= +github.com/go-gost/plugin v0.0.0-20231119084331-d49a1cb23b3b/go.mod h1:qXr2Zm9Ex2ATqnWuNUzVZqySPMnuIihvblYZt4MlZLw= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 h1:qAG1OyjvdA5h221CfFSS3J359V3d2E7dJWyP29QoDSI= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 h1:xj8gUZGYO3nb5+6Bjw9+tsFkA9sYynrOvDvvC4uDV2I= @@ -726,6 +724,8 @@ gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8 gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index d90ca24..e3b743a 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -21,9 +21,9 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" "github.com/go-gost/x/config" + ctxvalue "github.com/go-gost/x/internal/ctx" xio "github.com/go-gost/x/internal/io" xnet "github.com/go-gost/x/internal/net" - auth_util "github.com/go-gost/x/internal/util/auth" "github.com/go-gost/x/internal/util/forward" tls_util "github.com/go-gost/x/internal/util/tls" "github.com/go-gost/x/registry" @@ -119,8 +119,6 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand host = net.JoinHostPort(host, "0") } - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) - var target *chain.Node if host != "" { target = &chain.Node{ @@ -223,10 +221,9 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot "src": addr.String(), }) remoteAddr = addr + ctx = ctxvalue.ContextWithClientAddr(ctx, ctxvalue.ClientAddr(remoteAddr.String())) } - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(remoteAddr.String())) - target := &chain.Node{ Addr: req.Host, } @@ -259,7 +256,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot log.Warnf("node %s(%s) 401 unauthorized", target.Name, target.Addr) return resp.Write(rw) } - ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(id)) } if httpSettings := target.Options().HTTP; httpSettings != nil { if httpSettings.Host != "" { @@ -292,8 +289,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot InsecureSkipVerify: !tlsSettings.Secure, } tls_util.SetTLSOptions(cfg, &config.TLSOptions{ - MinVersion: tlsSettings.Options.MinVersion, - MaxVersion: tlsSettings.Options.MaxVersion, + MinVersion: tlsSettings.Options.MinVersion, + MaxVersion: tlsSettings.Options.MaxVersion, CipherSuites: tlsSettings.Options.CipherSuites, }) cc = tls.Client(cc, cfg) diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index 2c91672..ea92caa 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -22,10 +22,10 @@ import ( mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" "github.com/go-gost/x/config" + ctxvalue "github.com/go-gost/x/internal/ctx" xio "github.com/go-gost/x/internal/io" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" - auth_util "github.com/go-gost/x/internal/util/auth" "github.com/go-gost/x/internal/util/forward" tls_util "github.com/go-gost/x/internal/util/tls" "github.com/go-gost/x/registry" @@ -117,8 +117,6 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand return nil } - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) - if md, ok := conn.(mdata.Metadatable); ok { if v := mdutil.GetString(md.Metadata(), "host"); v != "" { host = v @@ -224,10 +222,9 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot "src": addr.String(), }) remoteAddr = addr + ctx = ctxvalue.ContextWithClientAddr(ctx, ctxvalue.ClientAddr(remoteAddr.String())) } - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(remoteAddr.String())) - target := &chain.Node{ Addr: req.Host, } @@ -260,7 +257,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot log.Warnf("node %s(%s) 401 unauthorized", target.Name, target.Addr) return resp.Write(rw) } - ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(id)) } if httpSettings := target.Options().HTTP; httpSettings != nil { if httpSettings.Host != "" { diff --git a/handler/http/handler.go b/handler/http/handler.go index fd6249a..e6534eb 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -19,11 +19,12 @@ import ( "github.com/asaskevich/govalidator" "github.com/go-gost/core/chain" "github.com/go-gost/core/handler" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" + ctxvalue "github.com/go-gost/x/internal/ctx" netpkg "github.com/go-gost/x/internal/net" - auth_util "github.com/go-gost/x/internal/util/auth" - sx "github.com/go-gost/x/internal/util/selector" + "github.com/go-gost/x/limiter/traffic/wrapper" "github.com/go-gost/x/registry" ) @@ -89,8 +90,6 @@ func (h *httpHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler } defer req.Body.Close() - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) - return h.handleRequest(ctx, conn, req, log) } @@ -148,11 +147,11 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt resp.Header = http.Header{} } - id, ok := h.authenticate(ctx, conn, req, resp, log) + clientID, ok := h.authenticate(ctx, conn, req, resp, log) if !ok { return nil } - ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID)) if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, network, addr) { resp.StatusCode = http.StatusForbidden @@ -186,7 +185,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: addr}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: addr}) } cc, err := h.router.Dial(ctx, network, addr) @@ -222,9 +221,16 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt } } + rw := wrapper.WrapReadWriter(h.options.Limiter, conn, conn.RemoteAddr().String(), + traffic.NetworkOption(network), + traffic.AddrOption(addr), + traffic.ClientOption(clientID), + traffic.SrcOption(conn.RemoteAddr().String()), + ) + start := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), addr) - netpkg.Transport(conn, cc) + netpkg.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(start), }).Infof("%s >-< %s", conn.RemoteAddr(), addr) diff --git a/handler/http2/handler.go b/handler/http2/handler.go index 902f1a2..8dce6aa 100644 --- a/handler/http2/handler.go +++ b/handler/http2/handler.go @@ -20,12 +20,13 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/handler" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" + ctxvalue "github.com/go-gost/x/internal/ctx" xio "github.com/go-gost/x/internal/io" netpkg "github.com/go-gost/x/internal/net" - auth_util "github.com/go-gost/x/internal/util/auth" - sx "github.com/go-gost/x/internal/util/selector" + "github.com/go-gost/x/limiter/traffic/wrapper" "github.com/go-gost/x/registry" ) @@ -89,8 +90,6 @@ func (h *http2Handler) Handle(ctx context.Context, conn net.Conn, opts ...handle return err } - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) - md := v.Metadata() return h.roundTrip(ctx, md.Get("w").(http.ResponseWriter), @@ -149,11 +148,11 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req Body: io.NopCloser(bytes.NewReader([]byte{})), } - id, ok := h.authenticate(ctx, w, req, resp, log) + clientID, ok := h.authenticate(ctx, w, req, resp, log) if !ok { return nil } - ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID)) if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "tcp", addr) { w.WriteHeader(http.StatusForbidden) @@ -167,7 +166,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: addr}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: addr}) } cc, err := h.router.Dial(ctx, "tcp", addr) @@ -205,9 +204,15 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req return nil } + rw := wrapper.WrapReadWriter(h.options.Limiter, xio.NewReadWriter(req.Body, flushWriter{w}), req.RemoteAddr, + traffic.NetworkOption("tcp"), + traffic.AddrOption(addr), + traffic.ClientOption(clientID), + traffic.SrcOption(req.RemoteAddr), + ) start := time.Now() log.Infof("%s <-> %s", req.RemoteAddr, addr) - netpkg.Transport(xio.NewReadWriter(req.Body, flushWriter{w}), cc) + netpkg.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(start), }).Infof("%s >-< %s", req.RemoteAddr, addr) diff --git a/handler/http3/handler.go b/handler/http3/handler.go index a2c92c3..767657c 100644 --- a/handler/http3/handler.go +++ b/handler/http3/handler.go @@ -14,7 +14,7 @@ import ( "github.com/go-gost/core/hop" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - sx "github.com/go-gost/x/internal/util/selector" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/registry" ) @@ -114,7 +114,7 @@ func (h *http3Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: addr}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: addr}) } var target *chain.Node diff --git a/handler/relay/connect.go b/handler/relay/connect.go index c67bd89..6af1c6b 100644 --- a/handler/relay/connect.go +++ b/handler/relay/connect.go @@ -8,11 +8,13 @@ import ( "net" "time" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" "github.com/go-gost/relay" + ctxvalue "github.com/go-gost/x/internal/ctx" xnet "github.com/go-gost/x/internal/net" - sx "github.com/go-gost/x/internal/util/selector" serial "github.com/go-gost/x/internal/util/serial" + "github.com/go-gost/x/limiter/traffic/wrapper" ) func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) (err error) { @@ -51,7 +53,7 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: address}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: address}) } var cc io.ReadWriteCloser @@ -103,9 +105,16 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network } } + rw := wrapper.WrapReadWriter(h.options.Limiter, conn, conn.RemoteAddr().String(), + traffic.NetworkOption(network), + traffic.AddrOption(address), + traffic.ClientOption(string(ctxvalue.ClientIDFromContext(ctx))), + traffic.SrcOption(conn.RemoteAddr().String()), + ) + t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), address) - xnet.Transport(conn, cc) + xnet.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), address) diff --git a/handler/relay/forward.go b/handler/relay/forward.go index 62b1a24..04a9d33 100644 --- a/handler/relay/forward.go +++ b/handler/relay/forward.go @@ -7,9 +7,12 @@ import ( "net" "time" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" "github.com/go-gost/relay" + ctxvalue "github.com/go-gost/x/internal/ctx" netpkg "github.com/go-gost/x/internal/net" + "github.com/go-gost/x/limiter/traffic/wrapper" ) func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string, log logger.Logger) error { @@ -84,9 +87,16 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network conn = rc } + rw := wrapper.WrapReadWriter(h.options.Limiter, conn, conn.RemoteAddr().String(), + traffic.NetworkOption(network), + traffic.AddrOption(target.Addr), + traffic.ClientOption(string(ctxvalue.ClientIDFromContext(ctx))), + traffic.SrcOption(conn.RemoteAddr().String()), + ) + t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr) - netpkg.Transport(conn, cc) + netpkg.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr) diff --git a/handler/relay/handler.go b/handler/relay/handler.go index 5047dad..09ab0ab 100644 --- a/handler/relay/handler.go +++ b/handler/relay/handler.go @@ -13,7 +13,7 @@ import ( md "github.com/go-gost/core/metadata" "github.com/go-gost/core/service" "github.com/go-gost/relay" - auth_util "github.com/go-gost/x/internal/util/auth" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/registry" ) @@ -83,8 +83,6 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) - if !h.checkRateLimit(conn.RemoteAddr()) { return ErrRateLimit } @@ -136,13 +134,13 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle } if h.options.Auther != nil { - id, ok := h.options.Auther.Authenticate(ctx, user, pass) + clientID, ok := h.options.Auther.Authenticate(ctx, user, pass) if !ok { resp.Status = relay.StatusUnauthorized resp.WriteTo(conn) return ErrUnauthorized } - ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID)) } network := networkID.String() diff --git a/handler/sni/handler.go b/handler/sni/handler.go index 3434604..ab7acb8 100644 --- a/handler/sni/handler.go +++ b/handler/sni/handler.go @@ -21,9 +21,9 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" dissector "github.com/go-gost/tls-dissector" + ctxvalue "github.com/go-gost/x/internal/ctx" xio "github.com/go-gost/x/internal/io" netpkg "github.com/go-gost/x/internal/net" - sx "github.com/go-gost/x/internal/util/selector" "github.com/go-gost/x/registry" ) @@ -123,7 +123,7 @@ func (h *sniHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr net switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: host}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: host}) } cc, err := h.router.Dial(ctx, "tcp", host) @@ -191,7 +191,7 @@ func (h *sniHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, raddr ne switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: host}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: host}) } cc, err := h.router.Dial(ctx, "tcp", host) diff --git a/handler/socks/v4/handler.go b/handler/socks/v4/handler.go index 19b7adf..db04955 100644 --- a/handler/socks/v4/handler.go +++ b/handler/socks/v4/handler.go @@ -8,12 +8,13 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/handler" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" "github.com/go-gost/gosocks4" + ctxvalue "github.com/go-gost/x/internal/ctx" netpkg "github.com/go-gost/x/internal/net" - auth_util "github.com/go-gost/x/internal/util/auth" - sx "github.com/go-gost/x/internal/util/selector" + "github.com/go-gost/x/limiter/traffic/wrapper" "github.com/go-gost/x/registry" ) @@ -82,8 +83,6 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl conn.SetReadDeadline(time.Now().Add(h.md.readTimeout)) } - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) - req, err := gosocks4.ReadRequest(conn) if err != nil { log.Error(err) @@ -100,7 +99,7 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl log.Trace(resp) return resp.Write(conn) } - ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(id)) } switch req.Cmd { @@ -132,7 +131,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: addr}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: addr}) } cc, err := h.router.Dial(ctx, "tcp", addr) @@ -152,9 +151,16 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g return err } + rw := wrapper.WrapReadWriter(h.options.Limiter, conn, conn.RemoteAddr().String(), + traffic.NetworkOption("tcp"), + traffic.AddrOption(addr), + traffic.ClientOption(string(ctxvalue.ClientIDFromContext(ctx))), + traffic.SrcOption(conn.RemoteAddr().String()), + ) + t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), addr) - netpkg.Transport(conn, cc) + netpkg.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), addr) diff --git a/handler/socks/v5/connect.go b/handler/socks/v5/connect.go index ad4e0e3..dcccbfd 100644 --- a/handler/socks/v5/connect.go +++ b/handler/socks/v5/connect.go @@ -6,10 +6,12 @@ import ( "net" "time" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" "github.com/go-gost/gosocks5" + ctxvalue "github.com/go-gost/x/internal/ctx" netpkg "github.com/go-gost/x/internal/net" - sx "github.com/go-gost/x/internal/util/selector" + "github.com/go-gost/x/limiter/traffic/wrapper" ) func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { @@ -28,7 +30,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: address}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: address}) } cc, err := h.router.Dial(ctx, network, address) @@ -48,9 +50,16 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ return err } + rw := wrapper.WrapReadWriter(h.options.Limiter, conn, conn.RemoteAddr().String(), + traffic.NetworkOption(network), + traffic.AddrOption(address), + traffic.ClientOption(string(ctxvalue.ClientIDFromContext(ctx))), + traffic.SrcOption(conn.RemoteAddr().String()), + ) + t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), address) - netpkg.Transport(conn, cc) + netpkg.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), address) diff --git a/handler/socks/v5/handler.go b/handler/socks/v5/handler.go index 690ad24..2432cce 100644 --- a/handler/socks/v5/handler.go +++ b/handler/socks/v5/handler.go @@ -10,7 +10,7 @@ import ( "github.com/go-gost/core/handler" md "github.com/go-gost/core/metadata" "github.com/go-gost/gosocks5" - auth_util "github.com/go-gost/x/internal/util/auth" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/util/socks" "github.com/go-gost/x/registry" ) @@ -95,7 +95,9 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl } log.Trace(req) - ctx = auth_util.ContextWithID(ctx, auth_util.ID(sc.ID())) + if clientID := sc.ID(); clientID != "" { + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID)) + } conn = sc conn.SetReadDeadline(time.Time{}) diff --git a/handler/socks/v5/selector.go b/handler/socks/v5/selector.go index f4b4e86..f49adaf 100644 --- a/handler/socks/v5/selector.go +++ b/handler/socks/v5/selector.go @@ -8,7 +8,7 @@ import ( "github.com/go-gost/core/auth" "github.com/go-gost/core/logger" "github.com/go-gost/gosocks5" - auth_util "github.com/go-gost/x/internal/util/auth" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/util/socks" ) @@ -70,7 +70,7 @@ func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (string, net.Co var id string if s.Authenticator != nil { var ok bool - ctx := auth_util.ContextWithClientAddr(context.Background(), auth_util.ClientAddr(conn.RemoteAddr().String())) + ctx := ctxvalue.ContextWithClientAddr(context.Background(), ctxvalue.ClientAddr(conn.RemoteAddr().String())) id, ok = s.Authenticator.Authenticate(ctx, req.Username, req.Password) if !ok { resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) diff --git a/handler/ss/handler.go b/handler/ss/handler.go index 3014e3e..56f52e5 100644 --- a/handler/ss/handler.go +++ b/handler/ss/handler.go @@ -10,8 +10,8 @@ import ( "github.com/go-gost/core/handler" md "github.com/go-gost/core/metadata" "github.com/go-gost/gosocks5" + ctxvalue "github.com/go-gost/x/internal/ctx" netpkg "github.com/go-gost/x/internal/net" - sx "github.com/go-gost/x/internal/util/selector" "github.com/go-gost/x/internal/util/ss" "github.com/go-gost/x/registry" "github.com/shadowsocks/go-shadowsocks2/core" @@ -108,7 +108,7 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.H switch h.md.hash { case "host": - ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: addr.String()}) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: addr.String()}) } cc, err := h.router.Dial(ctx, "tcp", addr.String()) diff --git a/handler/tun/handler.go b/handler/tun/handler.go index 87a14b3..445190e 100644 --- a/handler/tun/handler.go +++ b/handler/tun/handler.go @@ -9,9 +9,10 @@ import ( "time" "github.com/go-gost/core/chain" - "github.com/go-gost/core/hop" "github.com/go-gost/core/handler" + "github.com/go-gost/core/hop" md "github.com/go-gost/core/metadata" + "github.com/go-gost/core/router" tun_util "github.com/go-gost/x/internal/util/tun" "github.com/go-gost/x/registry" "github.com/songgao/water/waterutil" @@ -108,15 +109,18 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. return h.handleServer(ctx, conn, config, log) } -func (h *tunHandler) findRouteFor(dst net.IP, routes ...tun_util.Route) net.Addr { +func (h *tunHandler) findRouteFor(ctx context.Context, dst net.IP, router router.Router) net.Addr { if v, ok := h.routes.Load(ipToTunRouteKey(dst)); ok { return v.(net.Addr) } - for _, route := range routes { - if route.Net.Contains(dst) && route.Gateway != nil { - if v, ok := h.routes.Load(ipToTunRouteKey(route.Gateway)); ok { - return v.(net.Addr) - } + + if router == nil { + return nil + } + + if route := router.GetRoute(ctx, dst); route != nil && route.Gateway != nil { + if v, ok := h.routes.Load(ipToTunRouteKey(route.Gateway)); ok { + return v.(net.Addr) } } return nil diff --git a/handler/tun/server.go b/handler/tun/server.go index 3c2784c..64c9b8e 100644 --- a/handler/tun/server.go +++ b/handler/tun/server.go @@ -82,7 +82,7 @@ func (h *tunHandler) transportServer(ctx context.Context, tun io.ReadWriter, con return nil } - addr := h.findRouteFor(dst, config.Routes...) + addr := h.findRouteFor(ctx, dst, config.Router) if addr == nil { log.Debugf("no route for %s -> %s, packet discarded", src, dst) return nil @@ -203,7 +203,7 @@ func (h *tunHandler) transportServer(ctx context.Context, tun io.ReadWriter, con return nil } - if addr := h.findRouteFor(dst, config.Routes...); addr != nil { + if addr := h.findRouteFor(ctx, dst, config.Router); addr != nil { log.Debugf("find route: %s -> %s", dst, addr) _, err := conn.WriteTo(b[:n], addr) diff --git a/handler/tunnel/bind.go b/handler/tunnel/bind.go index 1099e07..a7ddf89 100644 --- a/handler/tunnel/bind.go +++ b/handler/tunnel/bind.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "net" + "github.com/go-gost/core/ingress" "github.com/go-gost/core/logger" "github.com/go-gost/core/sd" "github.com/go-gost/relay" @@ -56,7 +57,10 @@ func (h *tunnelHandler) handleBind(ctx context.Context, conn net.Conn, network, h.pool.Add(tunnelID, NewConnector(connectorID, tunnelID, h.id, session, h.md.sd), h.md.tunnelTTL) if h.md.ingress != nil { - h.md.ingress.Set(ctx, addr, tunnelID.String()) + h.md.ingress.SetRule(ctx, &ingress.Rule{ + Hostname: addr, + Endpoint: tunnelID.String(), + }) } if h.md.sd != nil { err := h.md.sd.Register(ctx, &sd.Service{ diff --git a/handler/tunnel/connect.go b/handler/tunnel/connect.go index 4b5f020..51c329b 100644 --- a/handler/tunnel/connect.go +++ b/handler/tunnel/connect.go @@ -6,9 +6,12 @@ import ( "net" "time" + "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/logger" "github.com/go-gost/relay" + ctxvalue "github.com/go-gost/x/internal/ctx" xnet "github.com/go-gost/x/internal/net" + "github.com/go-gost/x/limiter/traffic/wrapper" ) func (h *tunnelHandler) handleConnect(ctx context.Context, req *relay.Request, conn net.Conn, network, srcAddr string, dstAddr string, tunnelID relay.TunnelID, log logger.Logger) error { @@ -41,7 +44,9 @@ func (h *tunnelHandler) handleConnect(ctx context.Context, req *relay.Request, c if !h.md.directTunnel { var tid relay.TunnelID if ingress := h.md.ingress; ingress != nil && host != "" { - tid = parseTunnelID(ingress.Get(ctx, host)) + if rule := ingress.GetRule(ctx, host); rule != nil { + tid = parseTunnelID(rule.Endpoint) + } } if !tid.Equal(tunnelID) { resp.Status = relay.StatusHostUnreachable @@ -95,9 +100,16 @@ func (h *tunnelHandler) handleConnect(ctx context.Context, req *relay.Request, c req.WriteTo(cc) } + rw := wrapper.WrapReadWriter(h.options.Limiter, conn, tunnelID.String(), + traffic.NetworkOption(network), + traffic.AddrOption(dstAddr), + traffic.ClientOption(string(ctxvalue.ClientIDFromContext(ctx))), + traffic.SrcOption(conn.RemoteAddr().String()), + ) + t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) - xnet.Transport(conn, cc) + xnet.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), }).Debugf("%s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) diff --git a/handler/tunnel/entrypoint.go b/handler/tunnel/entrypoint.go index d454058..a8f3126 100644 --- a/handler/tunnel/entrypoint.go +++ b/handler/tunnel/entrypoint.go @@ -85,7 +85,9 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error { var tunnelID relay.TunnelID if ep.ingress != nil { - tunnelID = parseTunnelID(ep.ingress.Get(ctx, req.Host)) + if rule := ep.ingress.GetRule(ctx, req.Host); rule != nil { + tunnelID = parseTunnelID(rule.Endpoint) + } } if tunnelID.IsZero() { err := fmt.Errorf("no route to host %s", req.Host) diff --git a/handler/tunnel/handler.go b/handler/tunnel/handler.go index e715264..af385e4 100644 --- a/handler/tunnel/handler.go +++ b/handler/tunnel/handler.go @@ -15,8 +15,8 @@ import ( "github.com/go-gost/core/recorder" "github.com/go-gost/core/service" "github.com/go-gost/relay" + ctxvalue "github.com/go-gost/x/internal/ctx" xnet "github.com/go-gost/x/internal/net" - auth_util "github.com/go-gost/x/internal/util/auth" xrecorder "github.com/go-gost/x/recorder" "github.com/go-gost/x/registry" xservice "github.com/go-gost/x/service" @@ -169,8 +169,6 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() - ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) - if !h.checkRateLimit(conn.RemoteAddr()) { return ErrRateLimit } @@ -238,13 +236,13 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl } if h.options.Auther != nil { - id, ok := h.options.Auther.Authenticate(ctx, user, pass) + clientID, ok := h.options.Auther.Authenticate(ctx, user, pass) if !ok { resp.Status = relay.StatusUnauthorized resp.WriteTo(conn) return ErrUnauthorized } - ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID)) } switch req.Cmd & relay.CmdMask { diff --git a/handler/tunnel/metadata.go b/handler/tunnel/metadata.go index 62ca919..14ee9be 100644 --- a/handler/tunnel/metadata.go +++ b/handler/tunnel/metadata.go @@ -45,13 +45,13 @@ func (h *tunnelHandler) parseMetadata(md mdata.Metadata) (err error) { h.md.ingress = registry.IngressRegistry().Get(mdutil.GetString(md, "ingress")) if h.md.ingress == nil { - var rules []xingress.Rule + var rules []*ingress.Rule for _, s := range strings.Split(mdutil.GetString(md, "tunnel"), ",") { ss := strings.SplitN(s, ":", 2) if len(ss) != 2 { continue } - rules = append(rules, xingress.Rule{ + rules = append(rules, &ingress.Rule{ Hostname: ss[0], Endpoint: ss[1], }) @@ -60,7 +60,8 @@ func (h *tunnelHandler) parseMetadata(md mdata.Metadata) (err error) { h.md.ingress = xingress.NewIngress( xingress.RulesOption(rules), xingress.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "ingress", + "kind": "ingress", + "ingress": "@internal", })), ) } diff --git a/hop/plugin.go b/hop/plugin.go index 612d0e1..519d5f5 100644 --- a/hop/plugin.go +++ b/hop/plugin.go @@ -13,8 +13,8 @@ import ( "github.com/go-gost/plugin/hop/proto" "github.com/go-gost/x/config" node_parser "github.com/go-gost/x/config/parsing/node" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/plugin" - auth_util "github.com/go-gost/x/internal/util/auth" "google.golang.org/grpc" ) @@ -68,7 +68,8 @@ func (p *grpcPlugin) Select(ctx context.Context, opts ...hop.SelectOption) *chai Addr: options.Addr, Host: options.Host, Path: options.Path, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), + Src: string(ctxvalue.ClientAddrFromContext(ctx)), }) if err != nil { p.log.Error(err) @@ -106,6 +107,7 @@ type httpPluginRequest struct { Host string `json:"host"` Path string `json:"path"` Client string `json:"client"` + Src string `json:"src"` } type httpPluginResponse struct { @@ -154,7 +156,8 @@ func (p *httpPlugin) Select(ctx context.Context, opts ...hop.SelectOption) *chai Addr: options.Addr, Host: options.Host, Path: options.Path, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), + Src: string(ctxvalue.ClientAddrFromContext(ctx)), } v, err := json.Marshal(&rb) if err != nil { diff --git a/hosts/plugin.go b/hosts/plugin.go index 07f15d6..1ccd6b7 100644 --- a/hosts/plugin.go +++ b/hosts/plugin.go @@ -11,8 +11,8 @@ import ( "github.com/go-gost/core/hosts" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/hosts/proto" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/plugin" - auth_util "github.com/go-gost/x/internal/util/auth" "google.golang.org/grpc" ) @@ -58,7 +58,7 @@ func (p *grpcPlugin) Lookup(ctx context.Context, network, host string, opts ...h &proto.LookupRequest{ Network: network, Host: host, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), }) if err != nil { p.log.Error(err) @@ -126,7 +126,7 @@ func (p *httpPlugin) Lookup(ctx context.Context, network, host string, opts ...h rb := httpPluginRequest{ Network: network, Host: host, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), } v, err := json.Marshal(&rb) if err != nil { diff --git a/ingress/ingress.go b/ingress/ingress.go index bc2edf7..6e50350 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -14,13 +14,8 @@ import ( "github.com/go-gost/x/internal/loader" ) -type Rule struct { - Hostname string - Endpoint string -} - type options struct { - rules []Rule + rules []*ingress.Rule fileLoader loader.Loader redisLoader loader.Loader httpLoader loader.Loader @@ -30,7 +25,7 @@ type options struct { type Option func(opts *options) -func RulesOption(rules []Rule) Option { +func RulesOption(rules []*ingress.Rule) Option { return func(opts *options) { opts.rules = rules } @@ -67,7 +62,7 @@ func LoggerOption(logger logger.Logger) Option { } type localIngress struct { - rules map[string]Rule + rules map[string]*ingress.Rule cancelFunc context.CancelFunc options options mu sync.RWMutex @@ -119,9 +114,9 @@ func (ing *localIngress) periodReload(ctx context.Context) error { } func (ing *localIngress) reload(ctx context.Context) error { - rules := make(map[string]Rule) + rules := make(map[string]*ingress.Rule) - fn := func(rule Rule) { + fn := func(rule *ingress.Rule) { if rule.Hostname == "" || rule.Endpoint == "" { return } @@ -152,7 +147,7 @@ func (ing *localIngress) reload(ctx context.Context) error { return nil } -func (ing *localIngress) load(ctx context.Context) (rules []Rule, err error) { +func (ing *localIngress) load(ctx context.Context) (rules []*ingress.Rule, err error) { if ing.options.fileLoader != nil { if lister, ok := ing.options.fileLoader.(loader.Lister); ok { list, er := lister.List(ctx) @@ -203,7 +198,7 @@ func (ing *localIngress) load(ctx context.Context) (rules []Rule, err error) { return } -func (ing *localIngress) parseRules(r io.Reader) (rules []Rule, err error) { +func (ing *localIngress) parseRules(r io.Reader) (rules []*ingress.Rule, err error) { if r == nil { return } @@ -219,9 +214,9 @@ func (ing *localIngress) parseRules(r io.Reader) (rules []Rule, err error) { return } -func (ing *localIngress) Get(ctx context.Context, host string, opts ...ingress.GetOption) string { +func (ing *localIngress) GetRule(ctx context.Context, host string, opts ...ingress.Option) *ingress.Rule { if host == "" || ing == nil { - return "" + return nil } // try to strip the port @@ -229,22 +224,18 @@ func (ing *localIngress) Get(ctx context.Context, host string, opts ...ingress.G host = v } - if ing == nil || len(ing.rules) == 0 { - return "" - } - ing.options.logger.Debugf("ingress: lookup %s", host) ep := ing.lookup(host) - if ep == "" { + if ep == nil { ep = ing.lookup("." + host) } - if ep == "" { + if ep == nil { s := host for { if index := strings.IndexByte(s, '.'); index > 0 { ep = ing.lookup(s[index:]) s = s[index+1:] - if ep == "" { + if ep == nil { continue } } @@ -252,29 +243,29 @@ func (ing *localIngress) Get(ctx context.Context, host string, opts ...ingress.G } } - if ep != "" { + if ep != nil { ing.options.logger.Debugf("ingress: %s -> %s", host, ep) } return ep } -func (ing *localIngress) Set(ctx context.Context, host, endpoint string, opts ...ingress.SetOption) bool { +func (ing *localIngress) SetRule(ctx context.Context, rule *ingress.Rule, opts ...ingress.Option) bool { return false } -func (ing *localIngress) lookup(host string) string { - if ing == nil || len(ing.rules) == 0 { - return "" +func (ing *localIngress) lookup(host string) *ingress.Rule { + if ing == nil { + return nil } ing.mu.RLock() defer ing.mu.RUnlock() - return ing.rules[host].Endpoint + return ing.rules[host] } -func (ing *localIngress) parseLine(s string) (rule Rule) { +func (ing *localIngress) parseLine(s string) (rule *ingress.Rule) { line := strings.Replace(s, "\t", " ", -1) line = strings.TrimSpace(line) if n := strings.IndexByte(line, '#'); n >= 0 { @@ -290,7 +281,7 @@ func (ing *localIngress) parseLine(s string) (rule Rule) { return // invalid lines are ignored } - return Rule{ + return &ingress.Rule{ Hostname: sp[0], Endpoint: sp[1], } diff --git a/ingress/plugin.go b/ingress/plugin.go index fbbd997..c09a524 100644 --- a/ingress/plugin.go +++ b/ingress/plugin.go @@ -46,30 +46,36 @@ func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) ingress.Ingr return p } -func (p *grpcPlugin) Get(ctx context.Context, host string, opts ...ingress.GetOption) string { +func (p *grpcPlugin) GetRule(ctx context.Context, host string, opts ...ingress.Option) *ingress.Rule { if p.client == nil { - return "" + return nil } - r, err := p.client.Get(ctx, - &proto.GetRequest{ + r, err := p.client.GetRule(ctx, + &proto.GetRuleRequest{ Host: host, }) if err != nil { p.log.Error(err) - return "" + return nil + } + if r.Endpoint == "" { + return nil + } + return &ingress.Rule{ + Hostname: host, + Endpoint: r.Endpoint, } - return r.GetEndpoint() } -func (p *grpcPlugin) Set(ctx context.Context, host, endpoint string, opts ...ingress.SetOption) bool { - if p.client == nil { +func (p *grpcPlugin) SetRule(ctx context.Context, rule *ingress.Rule, opts ...ingress.Option) bool { + if p.client == nil || rule == nil { return false } - r, _ := p.client.Set(ctx, &proto.SetRequest{ - Host: host, - Endpoint: endpoint, + r, _ := p.client.SetRule(ctx, &proto.SetRuleRequest{ + Host: rule.Hostname, + Endpoint: rule.Endpoint, }) if r == nil { return false @@ -85,20 +91,20 @@ func (p *grpcPlugin) Close() error { return nil } -type httpPluginGetRequest struct { +type httpPluginGetRuleRequest struct { Host string `json:"host"` } -type httpPluginGetResponse struct { +type httpPluginGetRuleResponse struct { Endpoint string `json:"endpoint"` } -type httpPluginSetRequest struct { +type httpPluginSetRuleRequest struct { Host string `json:"host"` Endpoint string `json:"endpoint"` } -type httpPluginSetResponse struct { +type httpPluginSetRuleResponse struct { OK bool `json:"ok"` } @@ -127,14 +133,14 @@ func NewHTTPPlugin(name string, url string, opts ...plugin.Option) ingress.Ingre } } -func (p *httpPlugin) Get(ctx context.Context, host string, opts ...ingress.GetOption) (endpoint string) { +func (p *httpPlugin) GetRule(ctx context.Context, host string, opts ...ingress.Option) *ingress.Rule { if p.client == nil { - return + return nil } req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.url, nil) if err != nil { - return + return nil } if p.header != nil { req.Header = p.header.Clone() @@ -147,29 +153,35 @@ func (p *httpPlugin) Get(ctx context.Context, host string, opts ...ingress.GetOp resp, err := p.client.Do(req) if err != nil { - return + return nil } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return + return nil } - res := httpPluginGetResponse{} + res := httpPluginGetRuleResponse{} if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { - return + return nil + } + if res.Endpoint == "" { + return nil + } + return &ingress.Rule{ + Hostname: host, + Endpoint: res.Endpoint, } - return res.Endpoint } -func (p *httpPlugin) Set(ctx context.Context, host, endpoint string, opts ...ingress.SetOption) bool { - if p.client == nil { +func (p *httpPlugin) SetRule(ctx context.Context, rule *ingress.Rule, opts ...ingress.Option) bool { + if p.client == nil || rule == nil { return false } - rb := httpPluginSetRequest{ - Host: host, - Endpoint: endpoint, + rb := httpPluginSetRuleRequest{ + Host: rule.Hostname, + Endpoint: rule.Endpoint, } v, err := json.Marshal(&rb) if err != nil { @@ -195,7 +207,7 @@ func (p *httpPlugin) Set(ctx context.Context, host, endpoint string, opts ...ing return false } - res := httpPluginSetResponse{ } + res := httpPluginSetRuleResponse{} if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { return false } diff --git a/internal/ctx/value.go b/internal/ctx/value.go new file mode 100644 index 0000000..46b3dc2 --- /dev/null +++ b/internal/ctx/value.go @@ -0,0 +1,76 @@ +package ctx + +import "context" + +// clientAddrKey saves the client address. +type clientAddrKey struct{} + +type ClientAddr string + +var ( + keyClientAddr clientAddrKey +) + +func ContextWithClientAddr(ctx context.Context, addr ClientAddr) context.Context { + return context.WithValue(ctx, keyClientAddr, addr) +} + +func ClientAddrFromContext(ctx context.Context) ClientAddr { + v, _ := ctx.Value(keyClientAddr).(ClientAddr) + return v +} + +// sidKey saves the session ID. +type sidKey struct{} +type Sid string + +var ( + keySid sidKey +) + +func ContextWithSid(ctx context.Context, sid Sid) context.Context { + return context.WithValue(ctx, keySid, sid) +} + +func SidFromContext(ctx context.Context) Sid { + v, _ := ctx.Value(keySid).(Sid) + return v +} + +// hashKey saves the hash source for Selector. +type hashKey struct{} + +type Hash struct { + Source string +} + +var ( + clientHashKey = &hashKey{} +) + +func ContextWithHash(ctx context.Context, hash *Hash) context.Context { + return context.WithValue(ctx, clientHashKey, hash) +} + +func HashFromContext(ctx context.Context) *Hash { + if v, _ := ctx.Value(clientHashKey).(*Hash); v != nil { + return v + } + return nil +} + +type clientIDKey struct{} +type ClientID string + +var ( + keyClientID = &clientIDKey{} +) + +func ContextWithClientID(ctx context.Context, clientID ClientID) context.Context { + return context.WithValue(ctx, keyClientID, clientID) +} + +func ClientIDFromContext(ctx context.Context) ClientID { + v, _ := ctx.Value(keyClientID).(ClientID) + return v +} diff --git a/internal/util/auth/key.go b/internal/util/auth/key.go deleted file mode 100644 index a83c277..0000000 --- a/internal/util/auth/key.go +++ /dev/null @@ -1,34 +0,0 @@ -package auth - -import ( - "context" -) - -type idKey struct{} -type ID string - -type addrKey struct{} -type ClientAddr string - -var ( - clientIDKey = &idKey{} - clientAddrKey = &addrKey{} -) - -func ContextWithID(ctx context.Context, id ID) context.Context { - return context.WithValue(ctx, clientIDKey, id) -} - -func IDFromContext(ctx context.Context) ID { - v, _ := ctx.Value(clientIDKey).(ID) - return v -} - -func ContextWithClientAddr(ctx context.Context, addr ClientAddr) context.Context { - return context.WithValue(ctx, clientAddrKey, addr) -} - -func ClientAddrFromContext(ctx context.Context) ClientAddr { - v, _ := ctx.Value(clientAddrKey).(ClientAddr) - return v -} diff --git a/internal/util/selector/key.go b/internal/util/selector/key.go deleted file mode 100644 index 292eac3..0000000 --- a/internal/util/selector/key.go +++ /dev/null @@ -1,26 +0,0 @@ -package selector - -import ( - "context" -) - -type hashKey struct{} - -type Hash struct { - Source string -} - -var ( - clientHashKey = &hashKey{} -) - -func ContextWithHash(ctx context.Context, hash *Hash) context.Context { - return context.WithValue(ctx, clientHashKey, hash) -} - -func HashFromContext(ctx context.Context) *Hash { - if v, _ := ctx.Value(clientHashKey).(*Hash); v != nil { - return v - } - return nil -} diff --git a/internal/util/tun/config.go b/internal/util/tun/config.go index 899deb5..1253d58 100644 --- a/internal/util/tun/config.go +++ b/internal/util/tun/config.go @@ -1,12 +1,10 @@ package tun -import "net" +import ( + "net" -// Route is an IP routing entry -type Route struct { - Net net.IPNet - Gateway net.IP -} + "github.com/go-gost/core/router" +) type Config struct { Name string @@ -15,5 +13,5 @@ type Config struct { Peer string MTU int Gateway net.IP - Routes []Route + Router router.Router } diff --git a/limiter/traffic/plugin.go b/limiter/traffic/plugin.go new file mode 100644 index 0000000..a967566 --- /dev/null +++ b/limiter/traffic/plugin.go @@ -0,0 +1,235 @@ +package traffic + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + + "github.com/go-gost/core/limiter/traffic" + "github.com/go-gost/core/logger" + "github.com/go-gost/plugin/limiter/traffic/proto" + "github.com/go-gost/x/internal/plugin" + "google.golang.org/grpc" +) + +type grpcPlugin struct { + conn grpc.ClientConnInterface + client proto.LimiterClient + log logger.Logger +} + +// NewGRPCPlugin creates a traffic limiter plugin based on gRPC. +func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) traffic.TrafficLimiter { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + log := logger.Default().WithFields(map[string]any{ + "kind": "limiter", + "limiter": name, + }) + conn, err := plugin.NewGRPCConn(addr, &options) + if err != nil { + log.Error(err) + } + + p := &grpcPlugin{ + conn: conn, + log: log, + } + if conn != nil { + p.client = proto.NewLimiterClient(conn) + } + return p +} + +func (p *grpcPlugin) In(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter { + if p.client == nil { + return nil + } + + var options traffic.Options + for _, opt := range opts { + opt(&options) + } + + r, err := p.client.Limit(ctx, + &proto.LimitRequest{ + Network: options.Network, + Addr: options.Addr, + Client: options.Client, + Src: options.Src, + }) + if err != nil { + p.log.Error(err) + return nil + } + + return NewLimiter(int(r.In)) +} + +func (p *grpcPlugin) Out(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter { + if p.client == nil { + return nil + } + + var options traffic.Options + for _, opt := range opts { + opt(&options) + } + + r, err := p.client.Limit(ctx, + &proto.LimitRequest{ + Network: options.Network, + Addr: options.Addr, + Client: options.Client, + Src: options.Src, + }) + if err != nil { + p.log.Error(err) + return nil + } + + return NewLimiter(int(r.Out)) +} + +func (p *grpcPlugin) Close() error { + if closer, ok := p.conn.(io.Closer); ok { + return closer.Close() + } + return nil +} + +type httpPluginRequest struct { + Network string `json:"network"` + Addr string `json:"addr"` + Client string `json:"client"` + Src string `json:"src"` +} + +type httpPluginResponse struct { + In int64 `json:"in"` + Out int64 `json:"out"` +} + +type httpPlugin struct { + url string + client *http.Client + header http.Header + log logger.Logger +} + +// NewHTTPPlugin creates a traffic limiter plugin based on HTTP. +func NewHTTPPlugin(name string, url string, opts ...plugin.Option) traffic.TrafficLimiter { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + return &httpPlugin{ + url: url, + client: plugin.NewHTTPClient(&options), + header: options.Header, + log: logger.Default().WithFields(map[string]any{ + "kind": "limiter", + "limiter": name, + }), + } +} + +func (p *httpPlugin) In(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter { + if p.client == nil { + return nil + } + + var options traffic.Options + for _, opt := range opts { + opt(&options) + } + + rb := httpPluginRequest{ + Network: options.Network, + Addr: options.Addr, + Client: options.Client, + Src: options.Src, + } + v, err := json.Marshal(&rb) + if err != nil { + return nil + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) + if err != nil { + return nil + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return nil + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil + } + + res := httpPluginResponse{} + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return nil + } + return NewLimiter(int(res.In)) +} + +func (p *httpPlugin) Out(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter { + if p.client == nil { + return nil + } + + var options traffic.Options + for _, opt := range opts { + opt(&options) + } + + rb := httpPluginRequest{ + Network: options.Network, + Addr: options.Addr, + Client: options.Client, + Src: options.Src, + } + v, err := json.Marshal(&rb) + if err != nil { + return nil + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) + if err != nil { + return nil + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return nil + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil + } + + res := httpPluginResponse{} + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return nil + } + return NewLimiter(int(res.Out)) +} diff --git a/limiter/traffic/traffic.go b/limiter/traffic/traffic.go index 12b16bb..4349daa 100644 --- a/limiter/traffic/traffic.go +++ b/limiter/traffic/traffic.go @@ -121,7 +121,7 @@ func NewTrafficLimiter(opts ...Option) limiter.TrafficLimiter { // In obtains a traffic input limiter based on key. // The key should be client connection address. -func (l *trafficLimiter) In(key string) limiter.Limiter { +func (l *trafficLimiter) In(ctx context.Context, key string, opts ...limiter.Option) limiter.Limiter { var lims []limiter.Limiter // service level limiter @@ -185,7 +185,7 @@ func (l *trafficLimiter) In(key string) limiter.Limiter { // Out obtains a traffic output limiter based on key. // The key should be client connection address. -func (l *trafficLimiter) Out(key string) limiter.Limiter { +func (l *trafficLimiter) Out(ctx context.Context, key string, opts ...limiter.Option) limiter.Limiter { var lims []limiter.Limiter // service level limiter diff --git a/limiter/traffic/wrapper/conn.go b/limiter/traffic/wrapper/conn.go index 3671041..67a84b2 100644 --- a/limiter/traffic/wrapper/conn.go +++ b/limiter/traffic/wrapper/conn.go @@ -26,43 +26,54 @@ type serverConn struct { rbuf bytes.Buffer limiter limiter.TrafficLimiter limiterIn limiter.Limiter - expIn int64 limiterOut limiter.Limiter + expIn int64 expOut int64 + opts []limiter.Option } -func WrapConn(limiter limiter.TrafficLimiter, c net.Conn) net.Conn { - if limiter == nil { +func WrapConn(tlimiter limiter.TrafficLimiter, c net.Conn) net.Conn { + if tlimiter == nil { return c } + return &serverConn{ Conn: c, - limiter: limiter, + limiter: tlimiter, + opts: []limiter.Option{ + limiter.NetworkOption(c.LocalAddr().Network()), + limiter.SrcOption(c.RemoteAddr().String()), + limiter.AddrOption(c.LocalAddr().String()), + }, } } -func (c *serverConn) getInLimiter(addr net.Addr) limiter.Limiter { +func (c *serverConn) getInLimiter() limiter.Limiter { now := time.Now().UnixNano() // cache the limiter for 60s if c.limiter != nil && time.Duration(now-c.expIn) > 60*time.Second { - c.limiterIn = c.limiter.In(addr.String()) + if lim := c.limiter.In(context.Background(), c.RemoteAddr().String()); lim != nil { + c.limiterIn = lim + } c.expIn = now } return c.limiterIn } -func (c *serverConn) getOutLimiter(addr net.Addr) limiter.Limiter { +func (c *serverConn) getOutLimiter() limiter.Limiter { now := time.Now().UnixNano() // cache the limiter for 60s if c.limiter != nil && time.Duration(now-c.expOut) > 60*time.Second { - c.limiterOut = c.limiter.Out(addr.String()) + if lim := c.limiter.Out(context.Background(), c.RemoteAddr().String()); lim != nil { + c.limiterOut = lim + } c.expOut = now } return c.limiterOut } func (c *serverConn) Read(b []byte) (n int, err error) { - limiter := c.getInLimiter(c.RemoteAddr()) + limiter := c.getInLimiter() if limiter == nil { return c.Conn.Read(b) } @@ -92,7 +103,7 @@ func (c *serverConn) Read(b []byte) (n int, err error) { } func (c *serverConn) Write(b []byte) (n int, err error) { - limiter := c.getOutLimiter(c.RemoteAddr()) + limiter := c.getOutLimiter() if limiter == nil { return c.Conn.Write(b) } @@ -163,7 +174,7 @@ func (c *packetConn) getInLimiter(addr net.Addr) limiter.Limiter { return lim } - lim = c.limiter.In(addr.String()) + lim = c.limiter.In(context.Background(), addr.String()) c.inLimits.Set(addr.String(), lim, 0) return lim @@ -187,7 +198,7 @@ func (c *packetConn) getOutLimiter(addr net.Addr) limiter.Limiter { return lim } - lim = c.limiter.Out(addr.String()) + lim = c.limiter.Out(context.Background(), addr.String()) c.outLimits.Set(addr.String(), lim, 0) return lim @@ -266,7 +277,7 @@ func (c *udpConn) getInLimiter(addr net.Addr) limiter.Limiter { return lim } - lim = c.limiter.In(addr.String()) + lim = c.limiter.In(context.Background(), addr.String()) c.inLimits.Set(addr.String(), lim, 0) return lim @@ -290,7 +301,7 @@ func (c *udpConn) getOutLimiter(addr net.Addr) limiter.Limiter { return lim } - lim = c.limiter.Out(addr.String()) + lim = c.limiter.Out(context.Background(), addr.String()) c.outLimits.Set(addr.String(), lim, 0) return lim diff --git a/limiter/traffic/wrapper/io.go b/limiter/traffic/wrapper/io.go new file mode 100644 index 0000000..901a943 --- /dev/null +++ b/limiter/traffic/wrapper/io.go @@ -0,0 +1,108 @@ +package wrapper + +import ( + "bytes" + "context" + "io" + "time" + + limiter "github.com/go-gost/core/limiter/traffic" +) + +// readWriter is an io.ReadWriter with traffic limiter supported. +type readWriter struct { + io.ReadWriter + rbuf bytes.Buffer + limiter limiter.TrafficLimiter + limiterIn limiter.Limiter + limiterOut limiter.Limiter + expIn int64 + expOut int64 + opts []limiter.Option + key string +} + +func WrapReadWriter(limiter limiter.TrafficLimiter, rw io.ReadWriter, key string, opts ...limiter.Option) io.ReadWriter { + if limiter == nil { + return rw + } + + return &readWriter{ + ReadWriter: rw, + limiter: limiter, + opts: opts, + } +} + +func (p *readWriter) getInLimiter() limiter.Limiter { + now := time.Now().UnixNano() + // cache the limiter for 60s + if p.limiter != nil && time.Duration(now-p.expIn) > 60*time.Second { + if lim := p.limiter.In(context.Background(), p.key, p.opts...); lim != nil { + p.limiterIn = lim + } + p.expIn = now + } + return p.limiterIn +} + +func (p *readWriter) getOutLimiter() limiter.Limiter { + now := time.Now().UnixNano() + // cache the limiter for 60s + if p.limiter != nil && time.Duration(now-p.expOut) > 60*time.Second { + if lim := p.limiter.Out(context.Background(), p.key, p.opts...); lim != nil { + p.limiterOut = lim + } + p.expOut = now + } + return p.limiterOut +} + +func (p *readWriter) Read(b []byte) (n int, err error) { + limiter := p.getInLimiter() + if limiter == nil { + return p.ReadWriter.Read(b) + } + + if p.rbuf.Len() > 0 { + burst := len(b) + if p.rbuf.Len() < burst { + burst = p.rbuf.Len() + } + lim := limiter.Wait(context.Background(), burst) + return p.rbuf.Read(b[:lim]) + } + + nn, err := p.ReadWriter.Read(b) + if err != nil { + return nn, err + } + + n = limiter.Wait(context.Background(), nn) + if n < nn { + if _, err = p.rbuf.Write(b[n:nn]); err != nil { + return 0, err + } + } + + return +} + +func (p *readWriter) Write(b []byte) (n int, err error) { + limiter := p.getOutLimiter() + if limiter == nil { + return p.ReadWriter.Write(b) + } + + nn := 0 + for len(b) > 0 { + nn, err = p.ReadWriter.Write(b[:limiter.Wait(context.Background(), len(b))]) + n += nn + if err != nil { + return + } + b = b[nn:] + } + + return +} diff --git a/listener/http2/listener.go b/listener/http2/listener.go index abee6fb..a021660 100644 --- a/listener/http2/listener.go +++ b/listener/http2/listener.go @@ -118,10 +118,10 @@ func (l *http2Listener) Close() (err error) { case <-l.errChan: default: err = l.server.Close() - l.errChan <- err + l.errChan <- http.ErrServerClosed close(l.errChan) } - return nil + return } func (l *http2Listener) handleFunc(w http.ResponseWriter, r *http.Request) { diff --git a/listener/tun/listener.go b/listener/tun/listener.go index d0f7fd1..2b2cce8 100644 --- a/listener/tun/listener.go +++ b/listener/tun/listener.go @@ -8,6 +8,7 @@ import ( "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" mdata "github.com/go-gost/core/metadata" + "github.com/go-gost/core/router" xnet "github.com/go-gost/x/internal/net" limiter "github.com/go-gost/x/limiter/traffic/wrapper" mdx "github.com/go-gost/x/metadata" @@ -26,6 +27,7 @@ type tunListener struct { logger logger.Logger md metadata options listener.Options + routes []*router.Route } func NewListener(opts ...listener.Option) listener.Listener { diff --git a/listener/tun/metadata.go b/listener/tun/metadata.go index 8f3f7d1..6fec254 100644 --- a/listener/tun/metadata.go +++ b/listener/tun/metadata.go @@ -4,9 +4,13 @@ import ( "net" "strings" + "github.com/go-gost/core/logger" mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" + "github.com/go-gost/core/router" tun_util "github.com/go-gost/x/internal/util/tun" + "github.com/go-gost/x/registry" + xrouter "github.com/go-gost/x/router" ) const ( @@ -36,9 +40,10 @@ func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) { } config := &tun_util.Config{ - Name: mdutil.GetString(md, name), - Peer: mdutil.GetString(md, peer), - MTU: mdutil.GetInt(md, mtu), + Name: mdutil.GetString(md, name), + Peer: mdutil.GetString(md, peer), + MTU: mdutil.GetInt(md, mtu), + Router: registry.RouterRegistry().Get(mdutil.GetString(md, "router")), } if config.MTU <= 0 { config.MTU = defaultMTU @@ -62,35 +67,48 @@ func (l *tunListener) parseMetadata(md mdata.Metadata) (err error) { } for _, s := range strings.Split(mdutil.GetString(md, route), ",") { - var route tun_util.Route _, ipNet, _ := net.ParseCIDR(strings.TrimSpace(s)) if ipNet == nil { continue } - route.Net = *ipNet - route.Gateway = config.Gateway - config.Routes = append(config.Routes, route) + l.routes = append(l.routes, &router.Route{ + Net: ipNet, + Gateway: config.Gateway, + }) } for _, s := range mdutil.GetStrings(md, routes) { ss := strings.SplitN(s, " ", 2) if len(ss) == 2 { - var route tun_util.Route + var route router.Route _, ipNet, _ := net.ParseCIDR(strings.TrimSpace(ss[0])) if ipNet == nil { continue } - route.Net = *ipNet - route.Gateway = net.ParseIP(ss[1]) - if route.Gateway == nil { - route.Gateway = config.Gateway + route.Net = ipNet + gw := net.ParseIP(ss[1]) + if gw == nil { + gw = config.Gateway } - config.Routes = append(config.Routes, route) + l.routes = append(l.routes, &router.Route{ + Net: ipNet, + Gateway: gw, + }) } } + if config.Router == nil && len(l.routes) > 0 { + config.Router = xrouter.NewRouter( + xrouter.RoutesOption(l.routes), + xrouter.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "router", + "router": "@internal", + })), + ) + } + l.md.config = config return diff --git a/listener/tun/tun_darwin.go b/listener/tun/tun_darwin.go index 9cb4eba..d517898 100644 --- a/listener/tun/tun_darwin.go +++ b/listener/tun/tun_darwin.go @@ -6,8 +6,6 @@ import ( "net" "os/exec" "strings" - - tun_util "github.com/go-gost/x/internal/util/tun" ) const ( @@ -38,15 +36,15 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net. ip = l.md.config.Net[0].IP } - if err = l.addRoutes(name, l.md.config.Routes...); err != nil { + if err = l.addRoutes(name); err != nil { return } return } -func (l *tunListener) addRoutes(ifName string, routes ...tun_util.Route) error { - for _, route := range routes { +func (l *tunListener) addRoutes(ifName string) error { + for _, route := range l.routes { cmd := fmt.Sprintf("route add -net %s -interface %s", route.Net.String(), ifName) l.logger.Debug(cmd) args := strings.Split(cmd, " ") diff --git a/listener/tun/tun_linux.go b/listener/tun/tun_linux.go index be66f2c..8f6b47b 100644 --- a/listener/tun/tun_linux.go +++ b/listener/tun/tun_linux.go @@ -6,8 +6,6 @@ import ( "net" "github.com/vishvananda/netlink" - - tun_util "github.com/go-gost/x/internal/util/tun" ) func (l *tunListener) createTun() (dev io.ReadWriteCloser, name string, ip net.IP, err error) { @@ -42,17 +40,17 @@ func (l *tunListener) createTun() (dev io.ReadWriteCloser, name string, ip net.I return } - if err = l.addRoutes(ifce, l.md.config.Routes...); err != nil { + if err = l.addRoutes(ifce); err != nil { return } return } -func (l *tunListener) addRoutes(ifce *net.Interface, routes ...tun_util.Route) error { - for _, route := range routes { +func (l *tunListener) addRoutes(ifce *net.Interface) error { + for _, route := range l.routes { r := netlink.Route{ - Dst: &route.Net, + Dst: route.Net, Gw: route.Gateway, } if r.Gw == nil { diff --git a/listener/tun/tun_unix.go b/listener/tun/tun_unix.go index 191607b..9f9872e 100644 --- a/listener/tun/tun_unix.go +++ b/listener/tun/tun_unix.go @@ -8,8 +8,6 @@ import ( "net" "os/exec" "strings" - - tun_util "github.com/go-gost/x/internal/util/tun" ) const ( @@ -38,15 +36,15 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net. ip = l.md.config.Net[0].IP } - if err = l.addRoutes(name, l.md.config.Routes...); err != nil { + if err = l.addRoutes(name); err != nil { return } return } -func (l *tunListener) addRoutes(ifName string, routes ...tun_util.Route) error { - for _, route := range routes { +func (l *tunListener) addRoutes(ifName string) error { + for _, route := range l.routes { cmd := fmt.Sprintf("route add -net %s -interface %s", route.Net.String(), ifName) l.logger.Debug(cmd) args := strings.Split(cmd, " ") diff --git a/listener/tun/tun_windows.go b/listener/tun/tun_windows.go index 971dc2c..452b948 100644 --- a/listener/tun/tun_windows.go +++ b/listener/tun/tun_windows.go @@ -6,8 +6,6 @@ import ( "net" "os/exec" "strings" - - tun_util "github.com/go-gost/x/internal/util/tun" ) const ( @@ -38,15 +36,15 @@ func (l *tunListener) createTun() (ifce io.ReadWriteCloser, name string, ip net. ip = ipNet.IP } - if err = l.addRoutes(name, l.md.config.Gateway, l.md.config.Routes...); err != nil { + if err = l.addRoutes(name, l.md.config.Gateway); err != nil { return } return } -func (l *tunListener) addRoutes(ifName string, gw net.IP, routes ...tun_util.Route) error { - for _, route := range routes { +func (l *tunListener) addRoutes(ifName string, gw net.IP) error { + for _, route := range l.routes { l.deleteRoute(ifName, route.Net.String()) cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=%s store=active", diff --git a/logger/logger.go b/logger/logger.go index 7fab886..effc465 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -10,28 +10,35 @@ import ( "github.com/sirupsen/logrus" ) -type LoggerOptions struct { +type Options struct { + Name string Output io.Writer Format logger.LogFormat Level logger.LogLevel } -type LoggerOption func(opts *LoggerOptions) +type Option func(opts *Options) -func OutputLoggerOption(out io.Writer) LoggerOption { - return func(opts *LoggerOptions) { +func NameOption(name string) Option { + return func(opts *Options) { + opts.Name = name + } +} + +func OutputOption(out io.Writer) Option { + return func(opts *Options) { opts.Output = out } } -func FormatLoggerOption(format logger.LogFormat) LoggerOption { - return func(opts *LoggerOptions) { +func FormatOption(format logger.LogFormat) Option { + return func(opts *Options) { opts.Format = format } } -func LevelLoggerOption(level logger.LogLevel) LoggerOption { - return func(opts *LoggerOptions) { +func LevelOption(level logger.LogLevel) Option { + return func(opts *Options) { opts.Level = level } } @@ -40,8 +47,8 @@ type logrusLogger struct { logger *logrus.Entry } -func NewLogger(opts ...LoggerOption) logger.Logger { - var options LoggerOptions +func NewLogger(opts ...Option) logger.Logger { + var options Options for _, opt := range opts { opt(&options) } diff --git a/registry/ingress.go b/registry/ingress.go index 19c91e7..837773b 100644 --- a/registry/ingress.go +++ b/registry/ingress.go @@ -30,19 +30,19 @@ type ingressWrapper struct { r *ingressRegistry } -func (w *ingressWrapper) Get(ctx context.Context, host string, opts ...ingress.GetOption) string { +func (w *ingressWrapper) GetRule(ctx context.Context, host string, opts ...ingress.Option) *ingress.Rule { v := w.r.get(w.name) if v == nil { - return "" + return nil } - return v.Get(ctx, host, opts...) + return v.GetRule(ctx, host, opts...) } -func (w *ingressWrapper) Set(ctx context.Context, host, endpoint string, opts ...ingress.SetOption) bool { +func (w *ingressWrapper) SetRule(ctx context.Context, rule *ingress.Rule, opts ...ingress.Option) bool { v := w.r.get(w.name) if v == nil { return false } - return v.Set(ctx, host, endpoint, opts...) + return v.SetRule(ctx, rule, opts...) } diff --git a/registry/limiter.go b/registry/limiter.go index 4ade180..fedc718 100644 --- a/registry/limiter.go +++ b/registry/limiter.go @@ -1,6 +1,8 @@ package registry import ( + "context" + "github.com/go-gost/core/limiter/conn" "github.com/go-gost/core/limiter/rate" "github.com/go-gost/core/limiter/traffic" @@ -30,20 +32,20 @@ type trafficLimiterWrapper struct { r *trafficLimiterRegistry } -func (w *trafficLimiterWrapper) In(key string) traffic.Limiter { +func (w *trafficLimiterWrapper) In(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter { v := w.r.get(w.name) if v == nil { return nil } - return v.In(key) + return v.In(ctx, key, opts...) } -func (w *trafficLimiterWrapper) Out(key string) traffic.Limiter { +func (w *trafficLimiterWrapper) Out(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter { v := w.r.get(w.name) if v == nil { return nil } - return v.Out(key) + return v.Out(ctx, key, opts...) } type connLimiterRegistry struct { diff --git a/registry/logger.go b/registry/logger.go new file mode 100644 index 0000000..8bf17fa --- /dev/null +++ b/registry/logger.go @@ -0,0 +1,13 @@ +package registry + +import ( + "github.com/go-gost/core/logger" +) + +type loggerRegistry struct { + registry[logger.Logger] +} + +func (r *loggerRegistry) Register(name string, v logger.Logger) error { + return r.registry.Register(name, v) +} diff --git a/registry/registry.go b/registry/registry.go index 9a3731d..933386c 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -15,9 +15,11 @@ import ( "github.com/go-gost/core/limiter/conn" "github.com/go-gost/core/limiter/rate" "github.com/go-gost/core/limiter/traffic" + "github.com/go-gost/core/logger" "github.com/go-gost/core/recorder" reg "github.com/go-gost/core/registry" "github.com/go-gost/core/resolver" + "github.com/go-gost/core/router" "github.com/go-gost/core/sd" "github.com/go-gost/core/service" ) @@ -46,7 +48,10 @@ var ( rateLimiterReg reg.Registry[rate.RateLimiter] = new(rateLimiterRegistry) ingressReg reg.Registry[ingress.Ingress] = new(ingressRegistry) + routerReg reg.Registry[router.Router] = new(routerRegistry) sdReg reg.Registry[sd.SD] = new(sdRegistry) + + loggerReg reg.Registry[logger.Logger] = new(loggerRegistry) ) type registry[T any] struct { @@ -166,6 +171,14 @@ func IngressRegistry() reg.Registry[ingress.Ingress] { return ingressReg } +func RouterRegistry() reg.Registry[router.Router] { + return routerReg +} + func SDRegistry() reg.Registry[sd.SD] { return sdReg } + +func LoggerRegistry() reg.Registry[logger.Logger] { + return loggerReg +} diff --git a/registry/router.go b/registry/router.go new file mode 100644 index 0000000..e8b4633 --- /dev/null +++ b/registry/router.go @@ -0,0 +1,40 @@ +package registry + +import ( + "context" + "net" + + "github.com/go-gost/core/router" +) + +type routerRegistry struct { + registry[router.Router] +} + +func (r *routerRegistry) Register(name string, v router.Router) error { + return r.registry.Register(name, v) +} + +func (r *routerRegistry) Get(name string) router.Router { + if name != "" { + return &routerWrapper{name: name, r: r} + } + return nil +} + +func (r *routerRegistry) get(name string) router.Router { + return r.registry.Get(name) +} + +type routerWrapper struct { + name string + r *routerRegistry +} + +func (w *routerWrapper) GetRoute(ctx context.Context, dst net.IP, opts ...router.Option) *router.Route { + v := w.r.get(w.name) + if v == nil { + return nil + } + return v.GetRoute(ctx, dst, opts...) +} diff --git a/resolver/plugin.go b/resolver/plugin.go index 1c5d47f..d523a45 100644 --- a/resolver/plugin.go +++ b/resolver/plugin.go @@ -13,8 +13,8 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/core/resolver" "github.com/go-gost/plugin/resolver/proto" + ctxvalue "github.com/go-gost/x/internal/ctx" "github.com/go-gost/x/internal/plugin" - auth_util "github.com/go-gost/x/internal/util/auth" "google.golang.org/grpc" ) @@ -60,7 +60,7 @@ func (p *grpcPlugin) Resolve(ctx context.Context, network, host string, opts ... &proto.ResolveRequest{ Network: network, Host: host, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), }) if err != nil { p.log.Error(err) @@ -127,7 +127,7 @@ func (p *httpPlugin) Resolve(ctx context.Context, network, host string, opts ... rb := httpPluginRequest{ Network: network, Host: host, - Client: string(auth_util.IDFromContext(ctx)), + Client: string(ctxvalue.ClientIDFromContext(ctx)), } v, err := json.Marshal(&rb) if err != nil { diff --git a/router/plugin.go b/router/plugin.go new file mode 100644 index 0000000..23c1268 --- /dev/null +++ b/router/plugin.go @@ -0,0 +1,141 @@ +package router + +import ( + "context" + "encoding/json" + "io" + "net" + "net/http" + + "github.com/go-gost/core/logger" + "github.com/go-gost/core/router" + "github.com/go-gost/plugin/router/proto" + "github.com/go-gost/x/internal/plugin" + "google.golang.org/grpc" +) + +type grpcPlugin struct { + conn grpc.ClientConnInterface + client proto.RouterClient + log logger.Logger +} + +// NewGRPCPlugin creates an Router plugin based on gRPC. +func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) router.Router { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + log := logger.Default().WithFields(map[string]any{ + "kind": "router", + "router": name, + }) + conn, err := plugin.NewGRPCConn(addr, &options) + if err != nil { + log.Error(err) + } + + p := &grpcPlugin{ + conn: conn, + log: log, + } + if conn != nil { + p.client = proto.NewRouterClient(conn) + } + return p +} + +func (p *grpcPlugin) GetRoute(ctx context.Context, dst net.IP, opts ...router.Option) *router.Route { + if p.client == nil { + return nil + } + + r, err := p.client.GetRoute(ctx, + &proto.GetRouteRequest{ + Dst: dst.String(), + }) + if err != nil { + p.log.Error(err) + return nil + } + + return ParseRoute(r.Net, r.Gateway) +} + +func (p *grpcPlugin) Close() error { + if closer, ok := p.conn.(io.Closer); ok { + return closer.Close() + } + return nil +} + +type httpPluginGetRouteRequest struct { + Dst string `json:"dst"` +} + +type httpPluginGetRouteResponse struct { + Net string `json:"net"` + Gateway string `json:"gateway"` +} + +type httpPlugin struct { + url string + client *http.Client + header http.Header + log logger.Logger +} + +// NewHTTPPlugin creates an Router plugin based on HTTP. +func NewHTTPPlugin(name string, url string, opts ...plugin.Option) router.Router { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + return &httpPlugin{ + url: url, + client: plugin.NewHTTPClient(&options), + header: options.Header, + log: logger.Default().WithFields(map[string]any{ + "kind": "router", + "router": name, + }), + } +} + +func (p *httpPlugin) GetRoute(ctx context.Context, dst net.IP, opts ...router.Option) *router.Route { + if p.client == nil { + return nil + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.url, nil) + if err != nil { + return nil + } + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + + q := req.URL.Query() + q.Set("dst", dst.String()) + req.URL.RawQuery = q.Encode() + + resp, err := p.client.Do(req) + if err != nil { + return nil + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil + } + + res := httpPluginGetRouteResponse{} + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return nil + } + + return ParseRoute(res.Net, res.Gateway) +} diff --git a/router/router.go b/router/router.go new file mode 100644 index 0000000..50dc125 --- /dev/null +++ b/router/router.go @@ -0,0 +1,261 @@ +package router + +import ( + "bufio" + "context" + "io" + "net" + "strings" + "sync" + "time" + + "github.com/go-gost/core/logger" + "github.com/go-gost/core/router" + "github.com/go-gost/x/internal/loader" +) + +type options struct { + routes []*router.Route + fileLoader loader.Loader + redisLoader loader.Loader + httpLoader loader.Loader + period time.Duration + logger logger.Logger +} + +type Option func(opts *options) + +func RoutesOption(routes []*router.Route) Option { + return func(opts *options) { + opts.routes = routes + } +} + +func ReloadPeriodOption(period time.Duration) Option { + return func(opts *options) { + opts.period = period + } +} + +func FileLoaderOption(fileLoader loader.Loader) Option { + return func(opts *options) { + opts.fileLoader = fileLoader + } +} + +func RedisLoaderOption(redisLoader loader.Loader) Option { + return func(opts *options) { + opts.redisLoader = redisLoader + } +} + +func HTTPLoaderOption(httpLoader loader.Loader) Option { + return func(opts *options) { + opts.httpLoader = httpLoader + } +} + +func LoggerOption(logger logger.Logger) Option { + return func(opts *options) { + opts.logger = logger + } +} + +type localRouter struct { + routes []*router.Route + cancelFunc context.CancelFunc + options options + mu sync.RWMutex +} + +// NewRouter creates and initializes a new Router. +func NewRouter(opts ...Option) router.Router { + var options options + for _, opt := range opts { + opt(&options) + } + + ctx, cancel := context.WithCancel(context.TODO()) + + r := &localRouter{ + cancelFunc: cancel, + options: options, + } + + if err := r.reload(ctx); err != nil { + options.logger.Warnf("reload: %v", err) + } + if r.options.period > 0 { + go r.periodReload(ctx) + } + + return r +} + +func (p *localRouter) periodReload(ctx context.Context) error { + period := p.options.period + if period < time.Second { + period = time.Second + } + ticker := time.NewTicker(period) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := p.reload(ctx); err != nil { + p.options.logger.Warnf("reload: %v", err) + // return err + } + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (r *localRouter) reload(ctx context.Context) error { + routes := r.options.routes + + v, err := r.load(ctx) + if err != nil { + return err + } + routes = append(routes, v...) + + r.mu.Lock() + defer r.mu.Unlock() + + r.routes = routes + + return nil +} + +func (p *localRouter) load(ctx context.Context) (routes []*router.Route, err error) { + if p.options.fileLoader != nil { + if lister, ok := p.options.fileLoader.(loader.Lister); ok { + list, er := lister.List(ctx) + if er != nil { + p.options.logger.Warnf("file loader: %v", er) + } + for _, s := range list { + routes = append(routes, p.parseLine(s)) + } + } else { + fr, er := p.options.fileLoader.Load(ctx) + if er != nil { + p.options.logger.Warnf("file loader: %v", er) + } + if v, _ := p.parseRoutes(fr); v != nil { + routes = append(routes, v...) + } + } + } + if p.options.redisLoader != nil { + if lister, ok := p.options.redisLoader.(loader.Lister); ok { + list, er := lister.List(ctx) + if er != nil { + p.options.logger.Warnf("redis loader: %v", er) + } + for _, v := range list { + routes = append(routes, p.parseLine(v)) + } + } else { + r, er := p.options.redisLoader.Load(ctx) + if er != nil { + p.options.logger.Warnf("redis loader: %v", er) + } + v, _ := p.parseRoutes(r) + routes = append(routes, v...) + } + } + if p.options.httpLoader != nil { + r, er := p.options.httpLoader.Load(ctx) + if er != nil { + p.options.logger.Warnf("http loader: %v", er) + } + v, _ := p.parseRoutes(r) + routes = append(routes, v...) + } + + p.options.logger.Debugf("load items %d", len(routes)) + return +} + +func (p *localRouter) parseRoutes(r io.Reader) (routes []*router.Route, err error) { + if r == nil { + return + } + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + if route := p.parseLine(scanner.Text()); route != nil { + routes = append(routes, route) + } + } + + err = scanner.Err() + return +} + +func (p *localRouter) GetRoute(ctx context.Context, dst net.IP, opts ...router.Option) *router.Route { + if dst == nil || p == nil { + return nil + } + + p.mu.RLock() + routes := p.routes + p.mu.RUnlock() + + for _, route := range routes { + if route.Net != nil && route.Net.Contains(dst) { + return route + } + } + return nil +} + +func (*localRouter) parseLine(s string) (route *router.Route) { + line := strings.Replace(s, "\t", " ", -1) + line = strings.TrimSpace(line) + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + var sp []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + sp = append(sp, s) + } + } + if len(sp) < 2 { + return // invalid lines are ignored + } + + return ParseRoute(sp[0], sp[1]) +} + +func (p *localRouter) Close() error { + p.cancelFunc() + if p.options.fileLoader != nil { + p.options.fileLoader.Close() + } + if p.options.redisLoader != nil { + p.options.redisLoader.Close() + } + return nil +} + +func ParseRoute(dst string, gateway string) *router.Route { + _, ipNet, _ := net.ParseCIDR(dst) + if ipNet == nil { + return nil + } + gw := net.ParseIP(gateway) + if gw == nil { + return nil + } + + return &router.Route{ + Net: ipNet, + Gateway: gw, + } +} diff --git a/selector/strategy.go b/selector/strategy.go index f8b9e11..4a93448 100644 --- a/selector/strategy.go +++ b/selector/strategy.go @@ -12,7 +12,7 @@ import ( "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" "github.com/go-gost/core/selector" - sx "github.com/go-gost/x/internal/util/selector" + ctxvalue "github.com/go-gost/x/internal/ctx" ) type roundRobinStrategy[T any] struct { @@ -102,7 +102,7 @@ func (s *hashStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) { if len(vs) == 0 { return } - if h := sx.HashFromContext(ctx); h != nil { + if h := ctxvalue.HashFromContext(ctx); h != nil { value := uint64(crc32.ChecksumIEEE([]byte(h.Source))) logger.Default().Tracef("hash %s %d", h.Source, value) return vs[value%uint64(len(vs))] diff --git a/service/service.go b/service/service.go index 11a0571..38724a7 100644 --- a/service/service.go +++ b/service/service.go @@ -15,7 +15,7 @@ import ( "github.com/go-gost/core/metrics" "github.com/go-gost/core/recorder" "github.com/go-gost/core/service" - sx "github.com/go-gost/x/internal/util/selector" + ctxvalue "github.com/go-gost/x/internal/ctx" xmetrics "github.com/go-gost/x/metrics" "github.com/rs/xid" ) @@ -145,20 +145,26 @@ func (s *defaultService) Serve() error { } tempDelay = 0 - host := conn.RemoteAddr().String() - if h, _, _ := net.SplitHostPort(host); h != "" { - host = h + clientAddr := conn.RemoteAddr().String() + clientIP := clientAddr + if h, _, _ := net.SplitHostPort(clientAddr); h != "" { + clientIP = h } + + ctx := ctxvalue.ContextWithSid(context.Background(), ctxvalue.Sid(xid.New().String())) + ctx = ctxvalue.ContextWithClientAddr(ctx, ctxvalue.ClientAddr(clientAddr)) + ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: clientIP}) + for _, rec := range s.options.recorders { if rec.Record == recorder.RecorderServiceClientAddress { - if err := rec.Recorder.Record(context.Background(), []byte(host)); err != nil { + if err := rec.Recorder.Record(ctx, []byte(clientIP)); err != nil { s.options.logger.Errorf("record %s: %v", rec.Record, err) } break } } if s.options.admission != nil && - !s.options.admission.Admit(context.Background(), conn.RemoteAddr().String()) { + !s.options.admission.Admit(ctx, conn.RemoteAddr().String()) { conn.Close() s.options.logger.Debugf("admission: %s is denied", conn.RemoteAddr()) continue @@ -166,12 +172,12 @@ func (s *defaultService) Serve() error { go func() { if v := xmetrics.GetCounter(xmetrics.MetricServiceRequestsCounter, - metrics.Labels{"service": s.name, "client": host}); v != nil { + metrics.Labels{"service": s.name, "client": clientIP}); v != nil { v.Inc() } if v := xmetrics.GetGauge(xmetrics.MetricServiceRequestsInFlightGauge, - metrics.Labels{"service": s.name, "client": host}); v != nil { + metrics.Labels{"service": s.name, "client": clientIP}); v != nil { v.Inc() defer v.Dec() } @@ -184,13 +190,10 @@ func (s *defaultService) Serve() error { }() } - ctx := sx.ContextWithHash(context.Background(), &sx.Hash{Source: host}) - ctx = ContextWithSid(ctx, xid.New().String()) - if err := s.handler.Handle(ctx, conn); err != nil { s.options.logger.Error(err) if v := xmetrics.GetCounter(xmetrics.MetricServiceHandlerErrorsCounter, - metrics.Labels{"service": s.name, "client": host}); v != nil { + metrics.Labels{"service": s.name, "client": clientIP}); v != nil { v.Inc() } } @@ -211,18 +214,3 @@ func (s *defaultService) execCmds(phase string, cmds []string) { } } } - -type sidKey struct{} - -var ( - ssid sidKey -) - -func ContextWithSid(ctx context.Context, sid string) context.Context { - return context.WithValue(ctx, ssid, sid) -} - -func SidFromContext(ctx context.Context) string { - v, _ := ctx.Value(ssid).(string) - return v -}