diff --git a/cmd/gost/cmd.go b/cmd/gost/cmd.go index f8b80bf..b3af87d 100644 --- a/cmd/gost/cmd.go +++ b/cmd/gost/cmd.go @@ -173,6 +173,24 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { service.Handler.Retries = v md.Del("retries") } + if v := metadata.GetString(md, "admission"); v != "" { + admCfg := &config.AdmissionConfig{ + Name: fmt.Sprintf("admission-%d", len(cfg.Admissions)), + } + if v[0] == '~' { + admCfg.Reverse = true + v = v[1:] + } + for _, s := range strings.Split(v, ",") { + if s == "" { + continue + } + admCfg.Matchers = append(admCfg.Matchers, s) + } + service.Admission = admCfg.Name + cfg.Admissions = append(cfg.Admissions, admCfg) + md.Del("admission") + } if v := metadata.GetString(md, "bypass"); v != "" { bypassCfg := &config.BypassConfig{ Name: fmt.Sprintf("bypass-%d", len(cfg.Bypasses)), diff --git a/cmd/gost/config.go b/cmd/gost/config.go index 0b4ca5c..2ccb1ef 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -12,7 +12,7 @@ import ( "github.com/go-gost/gost/pkg/service" ) -func buildService(cfg *config.Config) (services []service.Servicer) { +func buildService(cfg *config.Config) (services []service.Service) { if cfg == nil { return } @@ -25,6 +25,14 @@ func buildService(cfg *config.Config) (services []service.Servicer) { } } + for _, admissionCfg := range cfg.Admissions { + if adm := parsing.ParseAdmission(admissionCfg); adm != nil { + if err := registry.Admission().Register(admissionCfg.Name, adm); err != nil { + log.Fatal(err) + } + } + } + for _, bypassCfg := range cfg.Bypasses { if bp := parsing.ParseBypass(bypassCfg); bp != nil { if err := registry.Bypass().Register(bypassCfg.Name, bp); err != nil { diff --git a/pkg/admission/admission.go b/pkg/admission/admission.go new file mode 100644 index 0000000..b5ee368 --- /dev/null +++ b/pkg/admission/admission.go @@ -0,0 +1,85 @@ +package admission + +import ( + "net" + "strconv" + + "github.com/go-gost/gost/pkg/common/matcher" + "github.com/go-gost/gost/pkg/logger" +) + +type Admission interface { + Admit(addr string) bool +} + +type options struct { + logger logger.Logger +} + +type Option func(opts *options) + +func LoggerOption(logger logger.Logger) Option { + return func(opts *options) { + opts.logger = logger + } +} + +type admission struct { + matchers []matcher.Matcher + reversed bool + options options +} + +// NewAdmission creates and initializes a new Admission using matchers as its match rules. +// The rules will be reversed if the reversed is true. +func NewAdmission(reversed bool, matchers []matcher.Matcher, opts ...Option) Admission { + options := options{} + for _, opt := range opts { + opt(&options) + } + return &admission{ + matchers: matchers, + reversed: reversed, + options: options, + } +} + +// NewAdmissionPatterns creates and initializes a new Admission using matcher patterns as its match rules. +// The rules will be reversed if the reverse is true. +func NewAdmissionPatterns(reversed bool, patterns []string, opts ...Option) Admission { + var matchers []matcher.Matcher + for _, pattern := range patterns { + if m := matcher.NewMatcher(pattern); m != nil { + matchers = append(matchers, m) + } + } + return NewAdmission(reversed, matchers, opts...) +} + +func (p *admission) Admit(addr string) bool { + if addr == "" || p == nil || len(p.matchers) == 0 { + return false + } + + // try to strip the port + if host, port, _ := net.SplitHostPort(addr); host != "" && port != "" { + if p, _ := strconv.Atoi(port); p > 0 { // port is valid + addr = host + } + } + + var matched bool + for _, matcher := range p.matchers { + if matcher == nil { + continue + } + if matcher.Match(addr) { + matched = true + break + } + } + + b := !p.reversed && matched || + p.reversed && !matched + return b +} diff --git a/pkg/api/config_admission.go b/pkg/api/config_admission.go new file mode 100644 index 0000000..48a3ab5 --- /dev/null +++ b/pkg/api/config_admission.go @@ -0,0 +1,166 @@ +package api + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/go-gost/gost/pkg/config" + "github.com/go-gost/gost/pkg/config/parsing" + "github.com/go-gost/gost/pkg/registry" +) + +// swagger:parameters createAdmissionRequest +type createAdmissionRequest struct { + // in: body + Data config.AdmissionConfig `json:"data"` +} + +// successful operation. +// swagger:response createAdmissionResponse +type createAdmissionResponse struct { + Data Response +} + +func createAdmission(ctx *gin.Context) { + // swagger:route POST /config/admissions ConfigManagement createAdmissionRequest + // + // Create a new admission, the name of admission must be unique in admission list. + // + // Security: + // basicAuth: [] + // + // Responses: + // 200: createAdmissionResponse + + var req createAdmissionRequest + ctx.ShouldBindJSON(&req.Data) + + if req.Data.Name == "" { + writeError(ctx, ErrInvalid) + return + } + + v := parsing.ParseAdmission(&req.Data) + + if err := registry.Admission().Register(req.Data.Name, v); err != nil { + writeError(ctx, ErrDup) + return + } + + cfg := config.Global() + cfg.Admissions = append(cfg.Admissions, &req.Data) + config.SetGlobal(cfg) + + ctx.JSON(http.StatusOK, Response{ + Msg: "OK", + }) +} + +// swagger:parameters updateAdmissionRequest +type updateAdmissionRequest struct { + // in: path + // required: true + Admission string `uri:"admission" json:"admission"` + // in: body + Data config.AdmissionConfig `json:"data"` +} + +// successful operation. +// swagger:response updateAdmissionResponse +type updateAdmissionResponse struct { + Data Response +} + +func updateAdmission(ctx *gin.Context) { + // swagger:route PUT /config/admissions/{admission} ConfigManagement updateAdmissionRequest + // + // Update admission by name, the admission must already exist. + // + // Security: + // basicAuth: [] + // + // Responses: + // 200: updateAdmissionResponse + + var req updateAdmissionRequest + ctx.ShouldBindUri(&req) + ctx.ShouldBindJSON(&req.Data) + + if !registry.Admission().IsRegistered(req.Admission) { + writeError(ctx, ErrNotFound) + return + } + + req.Data.Name = req.Admission + + v := parsing.ParseAdmission(&req.Data) + + registry.Admission().Unregister(req.Admission) + + if err := registry.Admission().Register(req.Admission, v); err != nil { + writeError(ctx, ErrDup) + return + } + + cfg := config.Global() + for i := range cfg.Admissions { + if cfg.Admissions[i].Name == req.Admission { + cfg.Admissions[i] = &req.Data + break + } + } + config.SetGlobal(cfg) + + ctx.JSON(http.StatusOK, Response{ + Msg: "OK", + }) +} + +// swagger:parameters deleteAdmissionRequest +type deleteAdmissionRequest struct { + // in: path + // required: true + Admission string `uri:"admission" json:"admission"` +} + +// successful operation. +// swagger:response deleteAdmissionResponse +type deleteAdmissionResponse struct { + Data Response +} + +func deleteAdmission(ctx *gin.Context) { + // swagger:route DELETE /config/admissions/{admission} ConfigManagement deleteAdmissionRequest + // + // Delete admission by name. + // + // Security: + // basicAuth: [] + // + // Responses: + // 200: deleteAdmissionResponse + + var req deleteAdmissionRequest + ctx.ShouldBindUri(&req) + + if !registry.Admission().IsRegistered(req.Admission) { + writeError(ctx, ErrNotFound) + return + } + registry.Admission().Unregister(req.Admission) + + cfg := config.Global() + admissiones := cfg.Admissions + cfg.Admissions = nil + for _, s := range admissiones { + if s.Name == req.Admission { + continue + } + cfg.Admissions = append(cfg.Admissions, s) + } + config.SetGlobal(cfg) + + ctx.JSON(http.StatusOK, Response{ + Msg: "OK", + }) +} diff --git a/pkg/api/config_auther.go b/pkg/api/config_auther.go index 4a4bcb6..840a854 100644 --- a/pkg/api/config_auther.go +++ b/pkg/api/config_auther.go @@ -141,8 +141,7 @@ func deleteAuther(ctx *gin.Context) { var req deleteAutherRequest ctx.ShouldBindUri(&req) - svc := registry.Auther().Get(req.Auther) - if svc == nil { + if !registry.Auther().IsRegistered(req.Auther) { writeError(ctx, ErrNotFound) return } diff --git a/pkg/api/config_bypass.go b/pkg/api/config_bypass.go index 6114dfc..321ead0 100644 --- a/pkg/api/config_bypass.go +++ b/pkg/api/config_bypass.go @@ -143,8 +143,7 @@ func deleteBypass(ctx *gin.Context) { var req deleteBypassRequest ctx.ShouldBindUri(&req) - svc := registry.Bypass().Get(req.Bypass) - if svc == nil { + if !registry.Bypass().IsRegistered(req.Bypass) { writeError(ctx, ErrNotFound) return } diff --git a/pkg/api/config_chain.go b/pkg/api/config_chain.go index 540f86c..7d6393a 100644 --- a/pkg/api/config_chain.go +++ b/pkg/api/config_chain.go @@ -152,8 +152,7 @@ func deleteChain(ctx *gin.Context) { var req deleteChainRequest ctx.ShouldBindUri(&req) - svc := registry.Chain().Get(req.Chain) - if svc == nil { + if !registry.Chain().IsRegistered(req.Chain) { writeError(ctx, ErrNotFound) return } diff --git a/pkg/api/config_hosts.go b/pkg/api/config_hosts.go index c647a1a..6d75021 100644 --- a/pkg/api/config_hosts.go +++ b/pkg/api/config_hosts.go @@ -143,8 +143,7 @@ func deleteHosts(ctx *gin.Context) { var req deleteHostsRequest ctx.ShouldBindUri(&req) - svc := registry.Hosts().Get(req.Hosts) - if svc == nil { + if !registry.Hosts().IsRegistered(req.Hosts) { writeError(ctx, ErrNotFound) return } diff --git a/pkg/api/config_resolver.go b/pkg/api/config_resolver.go index 2cb90f6..e8a82da 100644 --- a/pkg/api/config_resolver.go +++ b/pkg/api/config_resolver.go @@ -151,8 +151,7 @@ func deleteResolver(ctx *gin.Context) { var req deleteResolverRequest ctx.ShouldBindUri(&req) - svc := registry.Resolver().Get(req.Resolver) - if svc == nil { + if !registry.Resolver().IsRegistered(req.Resolver) { writeError(ctx, ErrNotFound) return } diff --git a/pkg/api/error.go b/pkg/api/error.go index 2d852b6..8fdc3f9 100644 --- a/pkg/api/error.go +++ b/pkg/api/error.go @@ -8,10 +8,10 @@ import ( ) var ( - ErrInvalid = &Error{statusCode: http.StatusBadRequest, Code: 40001, Msg: "instance invalid"} - ErrDup = &Error{statusCode: http.StatusBadRequest, Code: 40002, Msg: "instance duplicated"} - ErrCreate = &Error{statusCode: http.StatusConflict, Code: 40003, Msg: "instance creation failed"} - ErrNotFound = &Error{statusCode: http.StatusBadRequest, Code: 40004, Msg: "instance not found"} + ErrInvalid = &Error{statusCode: http.StatusBadRequest, Code: 40001, Msg: "object invalid"} + ErrDup = &Error{statusCode: http.StatusBadRequest, Code: 40002, Msg: "object duplicated"} + ErrCreate = &Error{statusCode: http.StatusConflict, Code: 40003, Msg: "object creation failed"} + ErrNotFound = &Error{statusCode: http.StatusBadRequest, Code: 40004, Msg: "object not found"} ErrSave = &Error{statusCode: http.StatusInternalServerError, Code: 40005, Msg: "save config failed"} ) diff --git a/pkg/api/server.go b/pkg/api/server.go index 1048d67..5d310f4 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -113,6 +113,10 @@ func registerConfig(config *gin.RouterGroup) { config.PUT("/authers/:auther", updateAuther) config.DELETE("/authers/:auther", deleteAuther) + config.POST("/admissions", createAdmission) + config.PUT("/admissions/:admission", updateAdmission) + config.DELETE("/admissions/:admission", deleteAdmission) + config.POST("/bypasses", createBypass) config.PUT("/bypasses/:bypass", updateBypass) config.DELETE("/bypasses/:bypass", deleteBypass) diff --git a/pkg/api/swagger.yaml b/pkg/api/swagger.yaml index 649f268..adc6603 100644 --- a/pkg/api/swagger.yaml +++ b/pkg/api/swagger.yaml @@ -20,6 +20,25 @@ definitions: x-go-name: PathPrefix type: object x-go-package: github.com/go-gost/gost/pkg/config + AdmissionConfig: + properties: + matchers: + items: + type: string + type: array + x-go-name: Matchers + name: + type: string + x-go-name: Name + reverse: + type: boolean + x-go-name: Reverse + type: + description: inline, file, etc. + type: string + x-go-name: Type + type: object + x-go-package: github.com/go-gost/gost/pkg/config AuthConfig: properties: password: @@ -81,6 +100,11 @@ definitions: x-go-package: github.com/go-gost/gost/pkg/config Config: properties: + admissions: + items: + $ref: '#/definitions/AdmissionConfig' + type: array + x-go-name: Admissions api: $ref: '#/definitions/APIConfig' authers: @@ -403,6 +427,9 @@ definitions: addr: type: string x-go-name: Addr + admission: + type: string + x-go-name: Admission bypass: type: string x-go-name: Bypass @@ -481,6 +508,65 @@ paths: summary: Save current config to file (gost.yaml or gost.json). tags: - ConfigManagement + /config/admissions: + post: + operationId: createAdmissionRequest + parameters: + - in: body + name: data + schema: + $ref: '#/definitions/AdmissionConfig' + x-go-name: Data + responses: + "200": + $ref: '#/responses/createAdmissionResponse' + security: + - basicAuth: + - '[]' + summary: Create a new admission, the name of admission must be unique in admission + list. + tags: + - ConfigManagement + /config/admissions/{admission}: + delete: + operationId: deleteAdmissionRequest + parameters: + - in: path + name: admission + required: true + type: string + x-go-name: Admission + responses: + "200": + $ref: '#/responses/deleteAdmissionResponse' + security: + - basicAuth: + - '[]' + summary: Delete admission by name. + tags: + - ConfigManagement + put: + operationId: updateAdmissionRequest + parameters: + - in: path + name: admission + required: true + type: string + x-go-name: Admission + - in: body + name: data + schema: + $ref: '#/definitions/AdmissionConfig' + x-go-name: Data + responses: + "200": + $ref: '#/responses/updateAdmissionResponse' + security: + - basicAuth: + - '[]' + summary: Update admission by name, the admission must already exist. + tags: + - ConfigManagement /config/authers: post: operationId: createAutherRequest @@ -835,6 +921,12 @@ paths: produces: - application/json responses: + createAdmissionResponse: + description: successful operation. + headers: + Data: {} + schema: + $ref: '#/definitions/Response' createAutherResponse: description: successful operation. headers: @@ -871,6 +963,12 @@ responses: Data: {} schema: $ref: '#/definitions/Response' + deleteAdmissionResponse: + description: successful operation. + headers: + Data: {} + schema: + $ref: '#/definitions/Response' deleteAutherResponse: description: successful operation. headers: @@ -919,6 +1017,12 @@ responses: Data: {} schema: $ref: '#/definitions/Response' + updateAdmissionResponse: + description: successful operation. + headers: + Data: {} + schema: + $ref: '#/definitions/Response' updateAutherResponse: description: successful operation. headers: diff --git a/pkg/bypass/bypass.go b/pkg/bypass/bypass.go index 142b714..4bc8551 100644 --- a/pkg/bypass/bypass.go +++ b/pkg/bypass/bypass.go @@ -3,131 +3,39 @@ package bypass import ( "net" "strconv" - "strings" + "github.com/go-gost/gost/pkg/common/matcher" "github.com/go-gost/gost/pkg/logger" - glob "github.com/gobwas/glob" ) -// Matcher is a generic pattern matcher, -// it gives the match result of the given pattern for specific v. -type Matcher interface { - Match(v string) bool -} - -// NewMatcher creates a Matcher for the given pattern. -// The acutal Matcher depends on the pattern: -// IP Matcher if pattern is a valid IP address. -// CIDR Matcher if pattern is a valid CIDR address. -// Domain Matcher if both of the above are not. -func NewMatcher(pattern string) Matcher { - if pattern == "" { - return nil - } - if ip := net.ParseIP(pattern); ip != nil { - return IPMatcher(ip) - } - if _, inet, err := net.ParseCIDR(pattern); err == nil { - return CIDRMatcher(inet) - } - return DomainMatcher(pattern) -} - -type ipMatcher struct { - ip net.IP -} - -// IPMatcher creates a Matcher for a specific IP address. -func IPMatcher(ip net.IP) Matcher { - return &ipMatcher{ - ip: ip, - } -} - -func (m *ipMatcher) Match(ip string) bool { - if m == nil { - return false - } - return m.ip.Equal(net.ParseIP(ip)) -} - -type cidrMatcher struct { - ipNet *net.IPNet -} - -// CIDRMatcher creates a Matcher for a specific CIDR notation IP address. -func CIDRMatcher(inet *net.IPNet) Matcher { - return &cidrMatcher{ - ipNet: inet, - } -} - -func (m *cidrMatcher) Match(ip string) bool { - if m == nil || m.ipNet == nil { - return false - } - return m.ipNet.Contains(net.ParseIP(ip)) -} - -type domainMatcher struct { - pattern string - glob glob.Glob -} - -// DomainMatcher creates a Matcher for a specific domain pattern, -// the pattern can be a plain domain such as 'example.com', -// a wildcard such as '*.exmaple.com' or a special wildcard '.example.com'. -func DomainMatcher(pattern string) Matcher { - p := pattern - if strings.HasPrefix(pattern, ".") { - p = pattern[1:] // trim the prefix '.' - pattern = "*" + p - } - return &domainMatcher{ - pattern: p, - glob: glob.MustCompile(pattern), - } -} - -func (m *domainMatcher) Match(domain string) bool { - if m == nil || m.glob == nil { - return false - } - - if domain == m.pattern { - return true - } - return m.glob.Match(domain) -} - // Bypass is a filter of address (IP or domain). type Bypass interface { // Contains reports whether the bypass includes addr. Contains(addr string) bool } -type bypassOptions struct { +type options struct { logger logger.Logger } -type BypassOption func(opts *bypassOptions) +type Option func(opts *options) -func LoggerBypassOption(logger logger.Logger) BypassOption { - return func(opts *bypassOptions) { +func LoggerOption(logger logger.Logger) Option { + return func(opts *options) { opts.logger = logger } } type bypass struct { - matchers []Matcher + matchers []matcher.Matcher reversed bool - options bypassOptions + options options } // NewBypass creates and initializes a new Bypass using matchers as its match rules. // The rules will be reversed if the reversed is true. -func NewBypass(reversed bool, matchers []Matcher, opts ...BypassOption) Bypass { - options := bypassOptions{} +func NewBypass(reversed bool, matchers []matcher.Matcher, opts ...Option) Bypass { + options := options{} for _, opt := range opts { opt(&options) } @@ -140,10 +48,10 @@ func NewBypass(reversed bool, matchers []Matcher, opts ...BypassOption) Bypass { // NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules. // The rules will be reversed if the reverse is true. -func NewBypassPatterns(reversed bool, patterns []string, opts ...BypassOption) Bypass { - var matchers []Matcher +func NewBypassPatterns(reversed bool, patterns []string, opts ...Option) Bypass { + var matchers []matcher.Matcher for _, pattern := range patterns { - if m := NewMatcher(pattern); m != nil { + if m := matcher.NewMatcher(pattern); m != nil { matchers = append(matchers, m) } } diff --git a/pkg/common/matcher/matcher.go b/pkg/common/matcher/matcher.go new file mode 100644 index 0000000..81442d3 --- /dev/null +++ b/pkg/common/matcher/matcher.go @@ -0,0 +1,99 @@ +package matcher + +import ( + "net" + "strings" + + "github.com/gobwas/glob" +) + +// Matcher is a generic pattern matcher, +// it gives the match result of the given pattern for specific v. +type Matcher interface { + Match(v string) bool +} + +// NewMatcher creates a Matcher for the given pattern. +// The acutal Matcher depends on the pattern: +// IP Matcher if pattern is a valid IP address. +// CIDR Matcher if pattern is a valid CIDR address. +// Domain Matcher if both of the above are not. +func NewMatcher(pattern string) Matcher { + if pattern == "" { + return nil + } + if ip := net.ParseIP(pattern); ip != nil { + return IPMatcher(ip) + } + if _, inet, err := net.ParseCIDR(pattern); err == nil { + return CIDRMatcher(inet) + } + return DomainMatcher(pattern) +} + +type ipMatcher struct { + ip net.IP +} + +// IPMatcher creates a Matcher for a specific IP address. +func IPMatcher(ip net.IP) Matcher { + return &ipMatcher{ + ip: ip, + } +} + +func (m *ipMatcher) Match(ip string) bool { + if m == nil { + return false + } + return m.ip.Equal(net.ParseIP(ip)) +} + +type cidrMatcher struct { + ipNet *net.IPNet +} + +// CIDRMatcher creates a Matcher for a specific CIDR notation IP address. +func CIDRMatcher(inet *net.IPNet) Matcher { + return &cidrMatcher{ + ipNet: inet, + } +} + +func (m *cidrMatcher) Match(ip string) bool { + if m == nil || m.ipNet == nil { + return false + } + return m.ipNet.Contains(net.ParseIP(ip)) +} + +type domainMatcher struct { + pattern string + glob glob.Glob +} + +// DomainMatcher creates a Matcher for a specific domain pattern, +// the pattern can be a plain domain such as 'example.com', +// a wildcard such as '*.exmaple.com' or a special wildcard '.example.com'. +func DomainMatcher(pattern string) Matcher { + p := pattern + if strings.HasPrefix(pattern, ".") { + p = pattern[1:] // trim the prefix '.' + pattern = "*" + p + } + return &domainMatcher{ + pattern: p, + glob: glob.MustCompile(pattern), + } +} + +func (m *domainMatcher) Match(domain string) bool { + if m == nil || m.glob == nil { + return false + } + + if domain == m.pattern { + return true + } + return m.glob.Match(domain) +} diff --git a/pkg/config/config.go b/pkg/config/config.go index bb83855..a44d58a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -94,6 +94,14 @@ type SelectorConfig struct { FailTimeout time.Duration `yaml:"failTimeout" json:"failTimeout"` } +type AdmissionConfig struct { + Name string `json:"name"` + // inline, file, etc. + Type string `yaml:",omitempty" json:"type,omitempty"` + Reverse bool `yaml:",omitempty" json:"reverse,omitempty"` + Matchers []string `json:"matchers"` +} + type BypassConfig struct { Name string `json:"name"` // inline, file, etc. @@ -173,6 +181,7 @@ type ConnectorConfig struct { type ServiceConfig struct { Name string `json:"name"` Addr string `yaml:",omitempty" json:"addr,omitempty"` + Admission string `yaml:",omitempty" json:"admission,omitempty"` Bypass string `yaml:",omitempty" json:"bypass,omitempty"` Resolver string `yaml:",omitempty" json:"resolver,omitempty"` Hosts string `yaml:",omitempty" json:"hosts,omitempty"` @@ -207,17 +216,18 @@ type NodeConfig struct { } type Config struct { - Services []*ServiceConfig `json:"services"` - Chains []*ChainConfig `yaml:",omitempty" json:"chains,omitempty"` - Authers []*AutherConfig `yaml:",omitempty" json:"authers,omitempty"` - Bypasses []*BypassConfig `yaml:",omitempty" json:"bypasses,omitempty"` - Resolvers []*ResolverConfig `yaml:",omitempty" json:"resolvers,omitempty"` - Hosts []*HostsConfig `yaml:",omitempty" json:"hosts,omitempty"` - TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` - Log *LogConfig `yaml:",omitempty" json:"log,omitempty"` - Profiling *ProfilingConfig `yaml:",omitempty" json:"profiling,omitempty"` - API *APIConfig `yaml:",omitempty" json:"api,omitempty"` - Metrics *MetricsConfig `yaml:",omitempty" json:"metrics,omitempty"` + Services []*ServiceConfig `json:"services"` + Chains []*ChainConfig `yaml:",omitempty" json:"chains,omitempty"` + Authers []*AutherConfig `yaml:",omitempty" json:"authers,omitempty"` + Admissions []*AdmissionConfig `yaml:",omitempty" json:"admissions,omitempty"` + Bypasses []*BypassConfig `yaml:",omitempty" json:"bypasses,omitempty"` + Resolvers []*ResolverConfig `yaml:",omitempty" json:"resolvers,omitempty"` + Hosts []*HostsConfig `yaml:",omitempty" json:"hosts,omitempty"` + TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` + Log *LogConfig `yaml:",omitempty" json:"log,omitempty"` + Profiling *ProfilingConfig `yaml:",omitempty" json:"profiling,omitempty"` + API *APIConfig `yaml:",omitempty" json:"api,omitempty"` + Metrics *MetricsConfig `yaml:",omitempty" json:"metrics,omitempty"` } func (c *Config) Load() error { diff --git a/pkg/config/parsing/parse.go b/pkg/config/parsing/parse.go index 1235a31..1626701 100644 --- a/pkg/config/parsing/parse.go +++ b/pkg/config/parsing/parse.go @@ -4,6 +4,7 @@ import ( "net" "net/url" + "github.com/go-gost/gost/pkg/admission" "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" @@ -79,6 +80,20 @@ func parseSelector(cfg *config.SelectorConfig) chain.Selector { ) } +func ParseAdmission(cfg *config.AdmissionConfig) admission.Admission { + if cfg == nil { + return nil + } + return admission.NewAdmissionPatterns( + cfg.Reverse, + cfg.Matchers, + admission.LoggerOption(logger.Default().WithFields(map[string]interface{}{ + "kind": "admission", + "admission": cfg.Name, + })), + ) +} + func ParseBypass(cfg *config.BypassConfig) bypass.Bypass { if cfg == nil { return nil @@ -86,7 +101,7 @@ func ParseBypass(cfg *config.BypassConfig) bypass.Bypass { return bypass.NewBypassPatterns( cfg.Reverse, cfg.Matchers, - bypass.LoggerBypassOption(logger.Default().WithFields(map[string]interface{}{ + bypass.LoggerOption(logger.Default().WithFields(map[string]interface{}{ "kind": "bypass", "bypass": cfg.Name, })), diff --git a/pkg/config/parsing/service.go b/pkg/config/parsing/service.go index 5929512..bac1d29 100644 --- a/pkg/config/parsing/service.go +++ b/pkg/config/parsing/service.go @@ -14,7 +14,7 @@ import ( "github.com/go-gost/gost/pkg/service" ) -func ParseService(cfg *config.ServiceConfig) (service.Servicer, error) { +func ParseService(cfg *config.ServiceConfig) (service.Service, error) { if cfg.Listener == nil { cfg.Listener = &config.ListenerConfig{ Type: "tcp", @@ -112,10 +112,10 @@ func ParseService(cfg *config.ServiceConfig) (service.Servicer, error) { return nil, err } - s := (&service.Service{}). - WithListener(ln). - WithHandler(h). - WithLogger(serviceLogger) + s := service.NewService(ln, h, + service.AdmissionOption(registry.Admission().Get(cfg.Admission)), + service.LoggerOption(serviceLogger), + ) serviceLogger.Infof("listening on %s/%s", s.Addr().String(), s.Addr().Network()) return s, nil diff --git a/pkg/listener/dns/listener.go b/pkg/listener/dns/listener.go index 19c3cc2..479d3e5 100644 --- a/pkg/listener/dns/listener.go +++ b/pkg/listener/dns/listener.go @@ -180,6 +180,9 @@ func (l *dnsListener) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/dns-message") raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) + if raddr == nil { + raddr = &net.TCPAddr{} + } if err := l.serve(&dohResponseWriter{raddr: raddr, ResponseWriter: w}, buf); err != nil { l.logger.Error(err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) diff --git a/pkg/listener/grpc/server.go b/pkg/listener/grpc/server.go index c46a833..17ccf3e 100644 --- a/pkg/listener/grpc/server.go +++ b/pkg/listener/grpc/server.go @@ -8,6 +8,7 @@ import ( pb "github.com/go-gost/gost/pkg/common/util/grpc/proto" "github.com/go-gost/gost/pkg/logger" + "google.golang.org/grpc/peer" ) type server struct { @@ -24,6 +25,9 @@ func (s *server) Tunnel(srv pb.GostTunel_TunnelServer) error { remoteAddr: &net.TCPAddr{}, closed: make(chan struct{}), } + if p, ok := peer.FromContext(srv.Context()); ok { + c.remoteAddr = p.Addr + } select { case s.cqueue <- c: diff --git a/pkg/registry/admission.go b/pkg/registry/admission.go new file mode 100644 index 0000000..05bc0cd --- /dev/null +++ b/pkg/registry/admission.go @@ -0,0 +1,65 @@ +package registry + +import ( + "sync" + + "github.com/go-gost/gost/pkg/admission" +) + +var ( + admissionReg = &admissionRegistry{} +) + +func Admission() *admissionRegistry { + return admissionReg +} + +type admissionRegistry struct { + m sync.Map +} + +func (r *admissionRegistry) Register(name string, admission admission.Admission) error { + if name == "" || admission == nil { + return nil + } + if _, loaded := r.m.LoadOrStore(name, admission); loaded { + return ErrDup + } + + return nil +} + +func (r *admissionRegistry) Unregister(name string) { + r.m.Delete(name) +} + +func (r *admissionRegistry) IsRegistered(name string) bool { + _, ok := r.m.Load(name) + return ok +} + +func (r *admissionRegistry) Get(name string) admission.Admission { + if name == "" { + return nil + } + return &admissionWrapper{name: name} +} + +func (r *admissionRegistry) get(name string) admission.Admission { + if v, ok := r.m.Load(name); ok { + return v.(admission.Admission) + } + return nil +} + +type admissionWrapper struct { + name string +} + +func (w *admissionWrapper) Admit(addr string) bool { + p := admissionReg.get(w.name) + if p == nil { + return false + } + return p.Admit(addr) +} diff --git a/pkg/registry/service.go b/pkg/registry/service.go index 2889725..cf5e667 100644 --- a/pkg/registry/service.go +++ b/pkg/registry/service.go @@ -18,7 +18,7 @@ type serviceRegistry struct { m sync.Map } -func (r *serviceRegistry) Register(name string, svc service.Servicer) error { +func (r *serviceRegistry) Register(name string, svc service.Service) error { if name == "" || svc == nil { return nil } @@ -38,7 +38,7 @@ func (r *serviceRegistry) IsRegistered(name string) bool { return ok } -func (r *serviceRegistry) Get(name string) service.Servicer { +func (r *serviceRegistry) Get(name string) service.Service { if name == "" { return nil } @@ -46,5 +46,5 @@ func (r *serviceRegistry) Get(name string) service.Servicer { if !ok { return nil } - return v.(service.Servicer) + return v.(service.Service) } diff --git a/pkg/service/service.go b/pkg/service/service.go index d9b7da9..c6b66a2 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -5,47 +5,64 @@ import ( "net" "time" + "github.com/go-gost/gost/pkg/admission" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" ) -type Servicer interface { +type options struct { + admission admission.Admission + logger logger.Logger +} + +type Option func(opts *options) + +func AdmissionOption(admission admission.Admission) Option { + return func(opts *options) { + opts.admission = admission + } +} + +func LoggerOption(logger logger.Logger) Option { + return func(opts *options) { + opts.logger = logger + } +} + +type Service interface { Serve() error Addr() net.Addr Close() error } -type Service struct { +type service struct { listener listener.Listener handler handler.Handler - logger logger.Logger + options options } -func (s *Service) WithListener(ln listener.Listener) *Service { - s.listener = ln - return s +func NewService(ln listener.Listener, h handler.Handler, opts ...Option) Service { + var options options + for _, opt := range opts { + opt(&options) + } + return &service{ + listener: ln, + handler: h, + options: options, + } } -func (s *Service) WithHandler(h handler.Handler) *Service { - s.handler = h - return s -} - -func (s *Service) WithLogger(logger logger.Logger) *Service { - s.logger = logger - return s -} - -func (s *Service) Addr() net.Addr { +func (s *service) Addr() net.Addr { return s.listener.Addr() } -func (s *Service) Close() error { +func (s *service) Close() error { return s.listener.Close() } -func (s *Service) Serve() error { +func (s *service) Serve() error { var tempDelay time.Duration for { conn, e := s.listener.Accept() @@ -59,15 +76,22 @@ func (s *Service) Serve() error { if max := 5 * time.Second; tempDelay > max { tempDelay = max } - s.logger.Warnf("accept: %v, retrying in %v", e, tempDelay) + s.options.logger.Warnf("accept: %v, retrying in %v", e, tempDelay) time.Sleep(tempDelay) continue } - s.logger.Errorf("accept: %v", e) + s.options.logger.Errorf("accept: %v", e) return e } tempDelay = 0 + if s.options.admission != nil && + !s.options.admission.Admit(conn.RemoteAddr().String()) { + s.options.logger.Infof("admission: %s is denied", conn.RemoteAddr()) + conn.Close() + continue + } + go s.handler.Handle(context.Background(), conn) } }