261 lines
5.4 KiB
Go
261 lines
5.4 KiB
Go
package matcher
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/gobwas/glob"
|
|
"github.com/yl2chen/cidranger"
|
|
)
|
|
|
|
// Matcher is a generic pattern matcher,
|
|
// it gives the match result of the given pattern for specific v.
|
|
type Matcher interface {
|
|
Match(v string) bool
|
|
}
|
|
|
|
type ipMatcher struct {
|
|
ips map[string]struct{}
|
|
}
|
|
|
|
// IPMatcher creates a Matcher with a list of IP addresses.
|
|
func IPMatcher(ips []net.IP) Matcher {
|
|
matcher := &ipMatcher{
|
|
ips: make(map[string]struct{}),
|
|
}
|
|
for _, ip := range ips {
|
|
matcher.ips[ip.String()] = struct{}{}
|
|
}
|
|
return matcher
|
|
}
|
|
|
|
func (m *ipMatcher) Match(ip string) bool {
|
|
if m == nil || len(m.ips) == 0 {
|
|
return false
|
|
}
|
|
_, ok := m.ips[ip]
|
|
return ok
|
|
}
|
|
|
|
type addrMatcher struct {
|
|
addrs map[string]*PortRange
|
|
}
|
|
|
|
// AddrMatcher creates a Matcher with a list of HOST:PORT addresses.
|
|
// the host can be an IP (e.g. 192.168.1.1) address, a plain domain such as 'example.com',
|
|
// or a special pattern '.example.com' that matches 'example.com'
|
|
// and any subdomain 'abc.example.com', 'def.abc.example.com' etc.
|
|
// The PORT can be a single port number or port range MIN-MAX(e.g. 0-65535).
|
|
func AddrMatcher(addrs []string) Matcher {
|
|
matcher := &addrMatcher{
|
|
addrs: make(map[string]*PortRange),
|
|
}
|
|
for _, addr := range addrs {
|
|
host, port, _ := net.SplitHostPort(addr)
|
|
if host == "" {
|
|
matcher.addrs[addr] = nil
|
|
continue
|
|
}
|
|
pr, _ := parsePortRange(port)
|
|
matcher.addrs[host] = pr
|
|
}
|
|
return matcher
|
|
}
|
|
|
|
func (m *addrMatcher) Match(addr string) bool {
|
|
if m == nil || len(m.addrs) == 0 {
|
|
return false
|
|
}
|
|
host, sp, _ := net.SplitHostPort(addr)
|
|
if host == "" {
|
|
host = addr
|
|
}
|
|
port, _ := strconv.Atoi(sp)
|
|
|
|
if pr, ok := m.addrs[host]; ok {
|
|
if pr == nil || pr.contains(port) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
if pr, ok := m.addrs["."+host]; ok {
|
|
if pr == nil || pr.contains(port) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
for {
|
|
if index := strings.IndexByte(host, '.'); index > 0 {
|
|
if pr, ok := m.addrs[host[index:]]; ok {
|
|
if pr == nil || pr.contains(port) {
|
|
return true
|
|
}
|
|
}
|
|
host = host[index+1:]
|
|
continue
|
|
}
|
|
break
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
type cidrMatcher struct {
|
|
ranger cidranger.Ranger
|
|
}
|
|
|
|
// CIDRMatcher creates a Matcher for a list of CIDR notation IP addresses.
|
|
func CIDRMatcher(inets []*net.IPNet) Matcher {
|
|
ranger := cidranger.NewPCTrieRanger()
|
|
for _, inet := range inets {
|
|
ranger.Insert(cidranger.NewBasicRangerEntry(*inet))
|
|
}
|
|
return &cidrMatcher{ranger: ranger}
|
|
}
|
|
|
|
func (m *cidrMatcher) Match(ip string) bool {
|
|
if m == nil || m.ranger == nil {
|
|
return false
|
|
}
|
|
if netIP := net.ParseIP(ip); netIP != nil {
|
|
b, _ := m.ranger.Contains(netIP)
|
|
return b
|
|
}
|
|
return false
|
|
}
|
|
|
|
type domainMatcher struct {
|
|
domains map[string]struct{}
|
|
}
|
|
|
|
// DomainMatcher creates a Matcher for a list of domains,
|
|
// the domain should be a plain domain such as 'example.com',
|
|
// or a special pattern '.example.com' that matches 'example.com'
|
|
// and any subdomain 'abc.example.com', 'def.abc.example.com' etc.
|
|
func DomainMatcher(domains []string) Matcher {
|
|
matcher := &domainMatcher{
|
|
domains: make(map[string]struct{}),
|
|
}
|
|
for _, domain := range domains {
|
|
matcher.domains[domain] = struct{}{}
|
|
}
|
|
return matcher
|
|
}
|
|
|
|
func (m *domainMatcher) Match(domain string) bool {
|
|
if m == nil || len(m.domains) == 0 {
|
|
return false
|
|
}
|
|
|
|
if _, ok := m.domains[domain]; ok {
|
|
return true
|
|
}
|
|
|
|
if _, ok := m.domains["."+domain]; ok {
|
|
return true
|
|
}
|
|
|
|
for {
|
|
if index := strings.IndexByte(domain, '.'); index > 0 {
|
|
if _, ok := m.domains[domain[index:]]; ok {
|
|
return true
|
|
}
|
|
domain = domain[index+1:]
|
|
continue
|
|
}
|
|
break
|
|
}
|
|
return false
|
|
}
|
|
|
|
type wildcardMatcherPattern struct {
|
|
glob glob.Glob
|
|
pr *PortRange
|
|
}
|
|
type wildcardMatcher struct {
|
|
patterns []wildcardMatcherPattern
|
|
}
|
|
|
|
// WildcardMatcher creates a Matcher for a specific wildcard domain pattern,
|
|
// the pattern can be a wildcard such as '*.exmaple.com', '*.example.com:80', or '*.example.com:0-65535'
|
|
func WildcardMatcher(patterns []string) Matcher {
|
|
matcher := &wildcardMatcher{}
|
|
for _, pattern := range patterns {
|
|
host, port, _ := net.SplitHostPort(pattern)
|
|
if host == "" {
|
|
host = pattern
|
|
}
|
|
pr, _ := parsePortRange(port)
|
|
matcher.patterns = append(matcher.patterns, wildcardMatcherPattern{
|
|
glob: glob.MustCompile(host),
|
|
pr: pr,
|
|
})
|
|
}
|
|
|
|
return matcher
|
|
}
|
|
|
|
func (m *wildcardMatcher) Match(addr string) bool {
|
|
if m == nil || len(m.patterns) == 0 {
|
|
return false
|
|
}
|
|
|
|
host, sp, _ := net.SplitHostPort(addr)
|
|
if host == "" {
|
|
host = addr
|
|
}
|
|
port, _ := strconv.Atoi(sp)
|
|
for _, pattern := range m.patterns {
|
|
if pattern.glob.Match(addr) {
|
|
if pattern.pr == nil || pattern.pr.contains(port) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
type PortRange struct {
|
|
Min int
|
|
Max int
|
|
}
|
|
|
|
// ParsePortRange parses the s to a PortRange.
|
|
// The s can be a single port number and will be converted to port range port-port.
|
|
func parsePortRange(s string) (*PortRange, error) {
|
|
minmax := strings.Split(s, "-")
|
|
switch len(minmax) {
|
|
case 1:
|
|
port, err := strconv.Atoi(s)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if port < 0 || port > 65535 {
|
|
return nil, fmt.Errorf("invalid port: %s", s)
|
|
}
|
|
return &PortRange{Min: port, Max: port}, nil
|
|
|
|
case 2:
|
|
min, err := strconv.Atoi(minmax[0])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
max, err := strconv.Atoi(minmax[1])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &PortRange{Min: min, Max: max}, nil
|
|
|
|
default:
|
|
return nil, fmt.Errorf("invalid range: %s", s)
|
|
}
|
|
}
|
|
|
|
func (pr *PortRange) contains(port int) bool {
|
|
return port >= pr.Min && port <= pr.Max
|
|
}
|