From 3bc2524068b78f1fd8ac8eee32ead06826b7aa81 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sat, 9 Apr 2022 23:30:28 +0800 Subject: [PATCH] refine bypass and admission pattern matchers --- admission/admission.go | 72 ++++++++--------- bypass/bypass.go | 94 ++++++++++++--------- config/parsing/parse.go | 4 +- internal/util/matcher/matcher.go | 135 ++++++++++++++++++++----------- 4 files changed, 178 insertions(+), 127 deletions(-) diff --git a/admission/admission.go b/admission/admission.go index a821a55..32cbb64 100644 --- a/admission/admission.go +++ b/admission/admission.go @@ -2,7 +2,6 @@ package admission import ( "net" - "strconv" admission_pkg "github.com/go-gost/core/admission" "github.com/go-gost/core/logger" @@ -22,60 +21,52 @@ func LoggerOption(logger logger.Logger) Option { } type admission struct { - matchers []matcher.Matcher - reversed bool - options options -} - -// NewAdmission creates and initializes a new Admission using matchers as its match rules. -// The rules will be reversed if the reversed is true. -func NewAdmission(reversed bool, matchers []matcher.Matcher, opts ...Option) admission_pkg.Admission { - options := options{} - for _, opt := range opts { - opt(&options) - } - return &admission{ - matchers: matchers, - reversed: reversed, - options: options, - } + ipMatcher matcher.Matcher + cidrMatcher matcher.Matcher + reversed bool + options options } // NewAdmissionPatterns creates and initializes a new Admission using matcher patterns as its match rules. // The rules will be reversed if the reverse is true. -func NewAdmissionPatterns(reversed bool, patterns []string, opts ...Option) admission_pkg.Admission { - var matchers []matcher.Matcher +func NewAdmission(reversed bool, patterns []string, opts ...Option) admission_pkg.Admission { + var options options + for _, opt := range opts { + opt(&options) + } + + var ips []net.IP + var inets []*net.IPNet for _, pattern := range patterns { - if m := matcher.NewMatcher(pattern); m != nil { - matchers = append(matchers, m) + if ip := net.ParseIP(pattern); ip != nil { + ips = append(ips, ip) + continue + } + if _, inet, err := net.ParseCIDR(pattern); err == nil { + inets = append(inets, inet) + continue } } - return NewAdmission(reversed, matchers, opts...) + return &admission{ + reversed: reversed, + options: options, + ipMatcher: matcher.IPMatcher(ips), + cidrMatcher: matcher.CIDRMatcher(inets), + } } func (p *admission) Admit(addr string) bool { - if addr == "" || p == nil || len(p.matchers) == 0 { + if addr == "" || p == nil { p.options.logger.Debugf("admission: %v is denied", addr) return false } // try to strip the port - if host, port, _ := net.SplitHostPort(addr); host != "" && port != "" { - if p, _ := strconv.Atoi(port); p > 0 { // port is valid - addr = host - } + if host, _, _ := net.SplitHostPort(addr); host != "" { + addr = host } - var matched bool - for _, matcher := range p.matchers { - if matcher == nil { - continue - } - if matcher.Match(addr) { - matched = true - break - } - } + matched := p.matched(addr) b := !p.reversed && matched || p.reversed && !matched @@ -84,3 +75,8 @@ func (p *admission) Admit(addr string) bool { } return b } + +func (p *admission) matched(addr string) bool { + return p.ipMatcher.Match(addr) || + p.cidrMatcher.Match(addr) +} diff --git a/bypass/bypass.go b/bypass/bypass.go index 9824b92..45a7450 100644 --- a/bypass/bypass.go +++ b/bypass/bypass.go @@ -2,7 +2,7 @@ package bypass import ( "net" - "strconv" + "strings" bypass_pkg "github.com/go-gost/core/bypass" "github.com/go-gost/core/logger" @@ -22,59 +22,63 @@ func LoggerOption(logger logger.Logger) Option { } type bypass struct { - matchers []matcher.Matcher - reversed bool - options options -} - -// NewBypass creates and initializes a new Bypass using matchers as its match rules. -// The rules will be reversed if the reversed is true. -func NewBypass(reversed bool, matchers []matcher.Matcher, opts ...Option) bypass_pkg.Bypass { - options := options{} - for _, opt := range opts { - opt(&options) - } - return &bypass{ - matchers: matchers, - reversed: reversed, - options: options, - } + ipMatcher matcher.Matcher + cidrMatcher matcher.Matcher + domainMatcher matcher.Matcher + wildcardMatcher matcher.Matcher + reversed bool + options options } // NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules. // The rules will be reversed if the reverse is true. -func NewBypassPatterns(reversed bool, patterns []string, opts ...Option) bypass_pkg.Bypass { - var matchers []matcher.Matcher - for _, pattern := range patterns { - if m := matcher.NewMatcher(pattern); m != nil { - matchers = append(matchers, m) - } +func NewBypass(reversed bool, patterns []string, opts ...Option) bypass_pkg.Bypass { + var options options + for _, opt := range opts { + opt(&options) + } + + var ips []net.IP + var inets []*net.IPNet + var domains []string + var wildcards []string + for _, pattern := range patterns { + if ip := net.ParseIP(pattern); ip != nil { + ips = append(ips, ip) + continue + } + if _, inet, err := net.ParseCIDR(pattern); err == nil { + inets = append(inets, inet) + continue + } + if strings.ContainsAny(pattern, "*?") { + wildcards = append(wildcards, pattern) + continue + } + domains = append(domains, pattern) + + } + return &bypass{ + reversed: reversed, + options: options, + ipMatcher: matcher.IPMatcher(ips), + cidrMatcher: matcher.CIDRMatcher(inets), + domainMatcher: matcher.DomainMatcher(domains), + wildcardMatcher: matcher.WildcardMatcher(wildcards), } - return NewBypass(reversed, matchers, opts...) } func (bp *bypass) Contains(addr string) bool { - if addr == "" || bp == nil || len(bp.matchers) == 0 { + if addr == "" || bp == nil { return false } // try to strip the port - if host, port, _ := net.SplitHostPort(addr); host != "" && port != "" { - if p, _ := strconv.Atoi(port); p > 0 { // port is valid - addr = host - } + if host, _, _ := net.SplitHostPort(addr); host != "" { + addr = host } - var matched bool - for _, matcher := range bp.matchers { - if matcher == nil { - continue - } - if matcher.Match(addr) { - matched = true - break - } - } + matched := bp.matched(addr) b := !bp.reversed && matched || bp.reversed && !matched @@ -83,3 +87,13 @@ func (bp *bypass) Contains(addr string) bool { } return b } + +func (bp *bypass) matched(addr string) bool { + if ip := net.ParseIP(addr); ip != nil { + return bp.ipMatcher.Match(addr) || + bp.cidrMatcher.Match(addr) + } + + return bp.domainMatcher.Match(addr) || + bp.wildcardMatcher.Match(addr) +} diff --git a/config/parsing/parse.go b/config/parsing/parse.go index 6fded76..139b681 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -88,7 +88,7 @@ func ParseAdmission(cfg *config.AdmissionConfig) admission.Admission { if cfg == nil { return nil } - return admission_impl.NewAdmissionPatterns( + return admission_impl.NewAdmission( cfg.Reverse, cfg.Matchers, admission_impl.LoggerOption(logger.Default().WithFields(map[string]any{ @@ -102,7 +102,7 @@ func ParseBypass(cfg *config.BypassConfig) bypass.Bypass { if cfg == nil { return nil } - return bypass_impl.NewBypassPatterns( + return bypass_impl.NewBypass( cfg.Reverse, cfg.Matchers, bypass_impl.LoggerOption(logger.Default().WithFields(map[string]any{ diff --git a/internal/util/matcher/matcher.go b/internal/util/matcher/matcher.go index 81442d3..1e1b527 100644 --- a/internal/util/matcher/matcher.go +++ b/internal/util/matcher/matcher.go @@ -13,87 +13,128 @@ type Matcher interface { Match(v string) bool } -// NewMatcher creates a Matcher for the given pattern. -// The acutal Matcher depends on the pattern: -// IP Matcher if pattern is a valid IP address. -// CIDR Matcher if pattern is a valid CIDR address. -// Domain Matcher if both of the above are not. -func NewMatcher(pattern string) Matcher { - if pattern == "" { - return nil - } - if ip := net.ParseIP(pattern); ip != nil { - return IPMatcher(ip) - } - if _, inet, err := net.ParseCIDR(pattern); err == nil { - return CIDRMatcher(inet) - } - return DomainMatcher(pattern) -} - type ipMatcher struct { - ip net.IP + ips map[string]struct{} } -// IPMatcher creates a Matcher for a specific IP address. -func IPMatcher(ip net.IP) Matcher { - return &ipMatcher{ - ip: ip, +// IPMatcher creates a Matcher with a list of IP addresses. +func IPMatcher(ips []net.IP) Matcher { + matcher := &ipMatcher{ + ips: make(map[string]struct{}), } + for _, ip := range ips { + matcher.ips[ip.String()] = struct{}{} + } + return matcher } func (m *ipMatcher) Match(ip string) bool { - if m == nil { + if m == nil || len(m.ips) == 0 { return false } - return m.ip.Equal(net.ParseIP(ip)) + _, ok := m.ips[ip] + return ok } type cidrMatcher struct { - ipNet *net.IPNet + inets []*net.IPNet } -// CIDRMatcher creates a Matcher for a specific CIDR notation IP address. -func CIDRMatcher(inet *net.IPNet) Matcher { +// CIDRMatcher creates a Matcher for a list of CIDR notation IP addresses. +func CIDRMatcher(inets []*net.IPNet) Matcher { return &cidrMatcher{ - ipNet: inet, + inets: inets, } } func (m *cidrMatcher) Match(ip string) bool { - if m == nil || m.ipNet == nil { + if m == nil || len(m.inets) == 0 { return false } - return m.ipNet.Contains(net.ParseIP(ip)) + for _, inet := range m.inets { + if inet.Contains(net.ParseIP(ip)) { + return true + } + } + return false } type domainMatcher struct { - pattern string - glob glob.Glob + domains map[string]struct{} } -// DomainMatcher creates a Matcher for a specific domain pattern, -// the pattern can be a plain domain such as 'example.com', -// a wildcard such as '*.exmaple.com' or a special wildcard '.example.com'. -func DomainMatcher(pattern string) Matcher { - p := pattern - if strings.HasPrefix(pattern, ".") { - p = pattern[1:] // trim the prefix '.' - pattern = "*" + p +// DomainMatcher creates a Matcher for a list of domains, +// the domain should be a plain domain such as 'example.com', +// or a special pattern '.example.com' that matches 'example.com' +// and any subdomain 'abc.example.com', 'def.abc.example.com' etc. +func DomainMatcher(domains []string) Matcher { + matcher := &domainMatcher{ + domains: make(map[string]struct{}), } - return &domainMatcher{ - pattern: p, - glob: glob.MustCompile(pattern), + for _, domain := range domains { + matcher.domains[domain] = struct{}{} } + return matcher } func (m *domainMatcher) Match(domain string) bool { - if m == nil || m.glob == nil { + if m == nil || len(m.domains) == 0 { return false } - if domain == m.pattern { + if _, ok := m.domains[domain]; ok { return true } - return m.glob.Match(domain) + + if _, ok := m.domains["."+domain]; ok { + return true + } + + for { + if index := strings.IndexByte(domain, '.'); index > 0 { + if _, ok := m.domains[domain[index:]]; ok { + return true + } + domain = domain[index+1:] + continue + } + break + } + return false +} + +type wildcardMatcherPattern struct { + pattern string + glob glob.Glob +} +type wildcardMatcher struct { + patterns []wildcardMatcherPattern +} + +// WildcardMatcher creates a Matcher for a specific wildcard domain pattern, +// the pattern should be a wildcard such as '*.exmaple.com'. +func WildcardMatcher(patterns []string) Matcher { + matcher := &wildcardMatcher{} + for _, pattern := range patterns { + matcher.patterns = append(matcher.patterns, wildcardMatcherPattern{ + pattern: pattern, + glob: glob.MustCompile(pattern), + }) + } + + return matcher +} + +func (m *wildcardMatcher) Match(domain string) bool { + if m == nil || len(m.patterns) == 0 { + return false + } + + for _, pattern := range m.patterns { + if pattern.glob.Match(domain) { + return true + } + } + + return false }