refine bypass and admission pattern matchers

This commit is contained in:
ginuerzh 2022-04-09 23:30:28 +08:00
parent 60b30598a2
commit 3bc2524068
4 changed files with 178 additions and 127 deletions

View File

@ -2,7 +2,6 @@ package admission
import ( import (
"net" "net"
"strconv"
admission_pkg "github.com/go-gost/core/admission" admission_pkg "github.com/go-gost/core/admission"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
@ -22,60 +21,52 @@ func LoggerOption(logger logger.Logger) Option {
} }
type admission struct { type admission struct {
matchers []matcher.Matcher ipMatcher matcher.Matcher
cidrMatcher matcher.Matcher
reversed bool reversed bool
options options options options
} }
// NewAdmission creates and initializes a new Admission using matchers as its match rules. // NewAdmissionPatterns creates and initializes a new Admission using matcher patterns as its match rules.
// The rules will be reversed if the reversed is true. // The rules will be reversed if the reverse is true.
func NewAdmission(reversed bool, matchers []matcher.Matcher, opts ...Option) admission_pkg.Admission { func NewAdmission(reversed bool, patterns []string, opts ...Option) admission_pkg.Admission {
options := options{} var options options
for _, opt := range opts { for _, opt := range opts {
opt(&options) opt(&options)
} }
var ips []net.IP
var inets []*net.IPNet
for _, pattern := range patterns {
if ip := net.ParseIP(pattern); ip != nil {
ips = append(ips, ip)
continue
}
if _, inet, err := net.ParseCIDR(pattern); err == nil {
inets = append(inets, inet)
continue
}
}
return &admission{ return &admission{
matchers: matchers,
reversed: reversed, reversed: reversed,
options: options, options: options,
ipMatcher: matcher.IPMatcher(ips),
cidrMatcher: matcher.CIDRMatcher(inets),
} }
} }
// NewAdmissionPatterns creates and initializes a new Admission using matcher patterns as its match rules.
// The rules will be reversed if the reverse is true.
func NewAdmissionPatterns(reversed bool, patterns []string, opts ...Option) admission_pkg.Admission {
var matchers []matcher.Matcher
for _, pattern := range patterns {
if m := matcher.NewMatcher(pattern); m != nil {
matchers = append(matchers, m)
}
}
return NewAdmission(reversed, matchers, opts...)
}
func (p *admission) Admit(addr string) bool { func (p *admission) Admit(addr string) bool {
if addr == "" || p == nil || len(p.matchers) == 0 { if addr == "" || p == nil {
p.options.logger.Debugf("admission: %v is denied", addr) p.options.logger.Debugf("admission: %v is denied", addr)
return false return false
} }
// try to strip the port // try to strip the port
if host, port, _ := net.SplitHostPort(addr); host != "" && port != "" { if host, _, _ := net.SplitHostPort(addr); host != "" {
if p, _ := strconv.Atoi(port); p > 0 { // port is valid
addr = host addr = host
} }
}
var matched bool matched := p.matched(addr)
for _, matcher := range p.matchers {
if matcher == nil {
continue
}
if matcher.Match(addr) {
matched = true
break
}
}
b := !p.reversed && matched || b := !p.reversed && matched ||
p.reversed && !matched p.reversed && !matched
@ -84,3 +75,8 @@ func (p *admission) Admit(addr string) bool {
} }
return b return b
} }
func (p *admission) matched(addr string) bool {
return p.ipMatcher.Match(addr) ||
p.cidrMatcher.Match(addr)
}

View File

@ -2,7 +2,7 @@ package bypass
import ( import (
"net" "net"
"strconv" "strings"
bypass_pkg "github.com/go-gost/core/bypass" bypass_pkg "github.com/go-gost/core/bypass"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
@ -22,59 +22,63 @@ func LoggerOption(logger logger.Logger) Option {
} }
type bypass struct { type bypass struct {
matchers []matcher.Matcher ipMatcher matcher.Matcher
cidrMatcher matcher.Matcher
domainMatcher matcher.Matcher
wildcardMatcher matcher.Matcher
reversed bool reversed bool
options options options options
} }
// NewBypass creates and initializes a new Bypass using matchers as its match rules. // NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules.
// The rules will be reversed if the reversed is true. // The rules will be reversed if the reverse is true.
func NewBypass(reversed bool, matchers []matcher.Matcher, opts ...Option) bypass_pkg.Bypass { func NewBypass(reversed bool, patterns []string, opts ...Option) bypass_pkg.Bypass {
options := options{} var options options
for _, opt := range opts { for _, opt := range opts {
opt(&options) opt(&options)
} }
var ips []net.IP
var inets []*net.IPNet
var domains []string
var wildcards []string
for _, pattern := range patterns {
if ip := net.ParseIP(pattern); ip != nil {
ips = append(ips, ip)
continue
}
if _, inet, err := net.ParseCIDR(pattern); err == nil {
inets = append(inets, inet)
continue
}
if strings.ContainsAny(pattern, "*?") {
wildcards = append(wildcards, pattern)
continue
}
domains = append(domains, pattern)
}
return &bypass{ return &bypass{
matchers: matchers,
reversed: reversed, reversed: reversed,
options: options, options: options,
ipMatcher: matcher.IPMatcher(ips),
cidrMatcher: matcher.CIDRMatcher(inets),
domainMatcher: matcher.DomainMatcher(domains),
wildcardMatcher: matcher.WildcardMatcher(wildcards),
} }
} }
// NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules.
// The rules will be reversed if the reverse is true.
func NewBypassPatterns(reversed bool, patterns []string, opts ...Option) bypass_pkg.Bypass {
var matchers []matcher.Matcher
for _, pattern := range patterns {
if m := matcher.NewMatcher(pattern); m != nil {
matchers = append(matchers, m)
}
}
return NewBypass(reversed, matchers, opts...)
}
func (bp *bypass) Contains(addr string) bool { func (bp *bypass) Contains(addr string) bool {
if addr == "" || bp == nil || len(bp.matchers) == 0 { if addr == "" || bp == nil {
return false return false
} }
// try to strip the port // try to strip the port
if host, port, _ := net.SplitHostPort(addr); host != "" && port != "" { if host, _, _ := net.SplitHostPort(addr); host != "" {
if p, _ := strconv.Atoi(port); p > 0 { // port is valid
addr = host addr = host
} }
}
var matched bool matched := bp.matched(addr)
for _, matcher := range bp.matchers {
if matcher == nil {
continue
}
if matcher.Match(addr) {
matched = true
break
}
}
b := !bp.reversed && matched || b := !bp.reversed && matched ||
bp.reversed && !matched bp.reversed && !matched
@ -83,3 +87,13 @@ func (bp *bypass) Contains(addr string) bool {
} }
return b return b
} }
func (bp *bypass) matched(addr string) bool {
if ip := net.ParseIP(addr); ip != nil {
return bp.ipMatcher.Match(addr) ||
bp.cidrMatcher.Match(addr)
}
return bp.domainMatcher.Match(addr) ||
bp.wildcardMatcher.Match(addr)
}

View File

@ -88,7 +88,7 @@ func ParseAdmission(cfg *config.AdmissionConfig) admission.Admission {
if cfg == nil { if cfg == nil {
return nil return nil
} }
return admission_impl.NewAdmissionPatterns( return admission_impl.NewAdmission(
cfg.Reverse, cfg.Reverse,
cfg.Matchers, cfg.Matchers,
admission_impl.LoggerOption(logger.Default().WithFields(map[string]any{ admission_impl.LoggerOption(logger.Default().WithFields(map[string]any{
@ -102,7 +102,7 @@ func ParseBypass(cfg *config.BypassConfig) bypass.Bypass {
if cfg == nil { if cfg == nil {
return nil return nil
} }
return bypass_impl.NewBypassPatterns( return bypass_impl.NewBypass(
cfg.Reverse, cfg.Reverse,
cfg.Matchers, cfg.Matchers,
bypass_impl.LoggerOption(logger.Default().WithFields(map[string]any{ bypass_impl.LoggerOption(logger.Default().WithFields(map[string]any{

View File

@ -13,87 +13,128 @@ type Matcher interface {
Match(v string) bool Match(v string) bool
} }
// NewMatcher creates a Matcher for the given pattern.
// The acutal Matcher depends on the pattern:
// IP Matcher if pattern is a valid IP address.
// CIDR Matcher if pattern is a valid CIDR address.
// Domain Matcher if both of the above are not.
func NewMatcher(pattern string) Matcher {
if pattern == "" {
return nil
}
if ip := net.ParseIP(pattern); ip != nil {
return IPMatcher(ip)
}
if _, inet, err := net.ParseCIDR(pattern); err == nil {
return CIDRMatcher(inet)
}
return DomainMatcher(pattern)
}
type ipMatcher struct { type ipMatcher struct {
ip net.IP ips map[string]struct{}
} }
// IPMatcher creates a Matcher for a specific IP address. // IPMatcher creates a Matcher with a list of IP addresses.
func IPMatcher(ip net.IP) Matcher { func IPMatcher(ips []net.IP) Matcher {
return &ipMatcher{ matcher := &ipMatcher{
ip: ip, ips: make(map[string]struct{}),
} }
for _, ip := range ips {
matcher.ips[ip.String()] = struct{}{}
}
return matcher
} }
func (m *ipMatcher) Match(ip string) bool { func (m *ipMatcher) Match(ip string) bool {
if m == nil { if m == nil || len(m.ips) == 0 {
return false return false
} }
return m.ip.Equal(net.ParseIP(ip)) _, ok := m.ips[ip]
return ok
} }
type cidrMatcher struct { type cidrMatcher struct {
ipNet *net.IPNet inets []*net.IPNet
} }
// CIDRMatcher creates a Matcher for a specific CIDR notation IP address. // CIDRMatcher creates a Matcher for a list of CIDR notation IP addresses.
func CIDRMatcher(inet *net.IPNet) Matcher { func CIDRMatcher(inets []*net.IPNet) Matcher {
return &cidrMatcher{ return &cidrMatcher{
ipNet: inet, inets: inets,
} }
} }
func (m *cidrMatcher) Match(ip string) bool { func (m *cidrMatcher) Match(ip string) bool {
if m == nil || m.ipNet == nil { if m == nil || len(m.inets) == 0 {
return false return false
} }
return m.ipNet.Contains(net.ParseIP(ip)) for _, inet := range m.inets {
if inet.Contains(net.ParseIP(ip)) {
return true
}
}
return false
} }
type domainMatcher struct { type domainMatcher struct {
pattern string domains map[string]struct{}
glob glob.Glob
} }
// DomainMatcher creates a Matcher for a specific domain pattern, // DomainMatcher creates a Matcher for a list of domains,
// the pattern can be a plain domain such as 'example.com', // the domain should be a plain domain such as 'example.com',
// a wildcard such as '*.exmaple.com' or a special wildcard '.example.com'. // or a special pattern '.example.com' that matches 'example.com'
func DomainMatcher(pattern string) Matcher { // and any subdomain 'abc.example.com', 'def.abc.example.com' etc.
p := pattern func DomainMatcher(domains []string) Matcher {
if strings.HasPrefix(pattern, ".") { matcher := &domainMatcher{
p = pattern[1:] // trim the prefix '.' domains: make(map[string]struct{}),
pattern = "*" + p
} }
return &domainMatcher{ for _, domain := range domains {
pattern: p, matcher.domains[domain] = struct{}{}
glob: glob.MustCompile(pattern),
} }
return matcher
} }
func (m *domainMatcher) Match(domain string) bool { func (m *domainMatcher) Match(domain string) bool {
if m == nil || m.glob == nil { if m == nil || len(m.domains) == 0 {
return false return false
} }
if domain == m.pattern { if _, ok := m.domains[domain]; ok {
return true return true
} }
return m.glob.Match(domain)
if _, ok := m.domains["."+domain]; ok {
return true
}
for {
if index := strings.IndexByte(domain, '.'); index > 0 {
if _, ok := m.domains[domain[index:]]; ok {
return true
}
domain = domain[index+1:]
continue
}
break
}
return false
}
type wildcardMatcherPattern struct {
pattern string
glob glob.Glob
}
type wildcardMatcher struct {
patterns []wildcardMatcherPattern
}
// WildcardMatcher creates a Matcher for a specific wildcard domain pattern,
// the pattern should be a wildcard such as '*.exmaple.com'.
func WildcardMatcher(patterns []string) Matcher {
matcher := &wildcardMatcher{}
for _, pattern := range patterns {
matcher.patterns = append(matcher.patterns, wildcardMatcherPattern{
pattern: pattern,
glob: glob.MustCompile(pattern),
})
}
return matcher
}
func (m *wildcardMatcher) Match(domain string) bool {
if m == nil || len(m.patterns) == 0 {
return false
}
for _, pattern := range m.patterns {
if pattern.glob.Match(domain) {
return true
}
}
return false
} }