diff --git a/admission/admission.go b/admission/admission.go index 48b9f42..6be5de3 100644 --- a/admission/admission.go +++ b/admission/admission.go @@ -113,8 +113,13 @@ func (p *localAdmission) Admit(ctx context.Context, addr string, opts ...admissi matched := p.matched(addr) - return !p.options.whitelist && !matched || + b := !p.options.whitelist && !matched || p.options.whitelist && matched + + if !b { + p.options.logger.Debugf("%s is denied", addr) + } + return b } func (p *localAdmission) periodReload(ctx context.Context) error { @@ -156,6 +161,10 @@ func (p *localAdmission) reload(ctx context.Context) error { inets = append(inets, inet) continue } + if ipAddr, _ := net.ResolveIPAddr("ip", pattern); ipAddr != nil { + p.options.logger.Debugf("resolve IP: %s -> %s", pattern, ipAddr) + ips = append(ips, ipAddr.IP) + } } p.mu.Lock() diff --git a/bypass/bypass.go b/bypass/bypass.go index 52cce57..ecebdb5 100644 --- a/bypass/bypass.go +++ b/bypass/bypass.go @@ -70,9 +70,8 @@ func LoggerOption(logger logger.Logger) Option { } type localBypass struct { - ipMatcher matcher.Matcher cidrMatcher matcher.Matcher - domainMatcher matcher.Matcher + addrMatcher matcher.Matcher wildcardMatcher matcher.Matcher cancelFunc context.CancelFunc options options @@ -132,15 +131,10 @@ func (bp *localBypass) reload(ctx context.Context) error { } patterns := append(bp.options.matchers, v...) - var ips []net.IP + var addrs []string var inets []*net.IPNet - var domains []string var wildcards []string for _, pattern := range patterns { - if ip := net.ParseIP(pattern); ip != nil { - ips = append(ips, ip) - continue - } if _, inet, err := net.ParseCIDR(pattern); err == nil { inets = append(inets, inet) continue @@ -149,15 +143,14 @@ func (bp *localBypass) reload(ctx context.Context) error { wildcards = append(wildcards, pattern) continue } - domains = append(domains, pattern) + addrs = append(addrs, pattern) } bp.mu.Lock() defer bp.mu.Unlock() - bp.ipMatcher = matcher.IPMatcher(ips) bp.cidrMatcher = matcher.CIDRMatcher(inets) - bp.domainMatcher = matcher.DomainMatcher(domains) + bp.addrMatcher = matcher.AddrMatcher(addrs) bp.wildcardMatcher = matcher.WildcardMatcher(wildcards) return nil @@ -237,11 +230,6 @@ func (bp *localBypass) Contains(ctx context.Context, network, addr string, opts return false } - // try to strip the port - if host, _, _ := net.SplitHostPort(addr); host != "" { - addr = host - } - matched := bp.matched(addr) b := !bp.options.whitelist && matched || @@ -263,13 +251,20 @@ func (bp *localBypass) matched(addr string) bool { bp.mu.RLock() defer bp.mu.RUnlock() - if ip := net.ParseIP(addr); ip != nil { - return bp.ipMatcher.Match(addr) || - bp.cidrMatcher.Match(addr) + if bp.addrMatcher.Match(addr) { + return true } - return bp.domainMatcher.Match(addr) || - bp.wildcardMatcher.Match(addr) + host, _, _ := net.SplitHostPort(addr) + if host == "" { + host = addr + } + + if ip := net.ParseIP(host); ip != nil { + return bp.cidrMatcher.Match(host) + } + + return bp.wildcardMatcher.Match(addr) } func (bp *localBypass) Close() error { diff --git a/internal/matcher/matcher.go b/internal/matcher/matcher.go index f8f3d64..9ab7d17 100644 --- a/internal/matcher/matcher.go +++ b/internal/matcher/matcher.go @@ -1,7 +1,9 @@ package matcher import ( + "fmt" "net" + "strconv" "strings" "github.com/gobwas/glob" @@ -37,6 +39,69 @@ func (m *ipMatcher) Match(ip string) bool { return ok } +type addrMatcher struct { + addrs map[string]*PortRange +} + +// AddrMatcher creates a Matcher with a list of HOST:PORT addresses. +// the host can be an IP (e.g. 192.168.1.1) address, a plain domain such as 'example.com', +// or a special pattern '.example.com' that matches 'example.com' +// and any subdomain 'abc.example.com', 'def.abc.example.com' etc. +// The PORT can be a single port number or port range MIN-MAX(e.g. 0-65535). +func AddrMatcher(addrs []string) Matcher { + matcher := &addrMatcher{ + addrs: make(map[string]*PortRange), + } + for _, addr := range addrs { + host, port, _ := net.SplitHostPort(addr) + if host == "" { + matcher.addrs[addr] = nil + continue + } + pr, _ := parsePortRange(port) + matcher.addrs[host] = pr + } + return matcher +} + +func (m *addrMatcher) Match(addr string) bool { + if m == nil || len(m.addrs) == 0 { + return false + } + host, sp, _ := net.SplitHostPort(addr) + if host == "" { + host = addr + } + port, _ := strconv.Atoi(sp) + + if pr, ok := m.addrs[host]; ok { + if pr == nil || pr.contains(port) { + return true + } + } + + if pr, ok := m.addrs["."+host]; ok { + if pr == nil || pr.contains(port) { + return true + } + } + + for { + if index := strings.IndexByte(host, '.'); index > 0 { + if pr, ok := m.addrs[host[index:]]; ok { + if pr == nil || pr.contains(port) { + return true + } + } + host = host[index+1:] + continue + } + break + } + + return false +} + type cidrMatcher struct { ranger cidranger.Ranger } @@ -106,37 +171,90 @@ func (m *domainMatcher) Match(domain string) bool { } type wildcardMatcherPattern struct { - pattern string - glob glob.Glob + glob glob.Glob + pr *PortRange } type wildcardMatcher struct { patterns []wildcardMatcherPattern } // WildcardMatcher creates a Matcher for a specific wildcard domain pattern, -// the pattern should be a wildcard such as '*.exmaple.com'. +// the pattern can be a wildcard such as '*.exmaple.com', '*.example.com:80', or '*.example.com:0-65535' func WildcardMatcher(patterns []string) Matcher { matcher := &wildcardMatcher{} for _, pattern := range patterns { + host, port, _ := net.SplitHostPort(pattern) + if host == "" { + host = pattern + } + pr, _ := parsePortRange(port) matcher.patterns = append(matcher.patterns, wildcardMatcherPattern{ - pattern: pattern, - glob: glob.MustCompile(pattern), + glob: glob.MustCompile(host), + pr: pr, }) } return matcher } -func (m *wildcardMatcher) Match(domain string) bool { +func (m *wildcardMatcher) Match(addr string) bool { if m == nil || len(m.patterns) == 0 { return false } + host, sp, _ := net.SplitHostPort(addr) + if host == "" { + host = addr + } + port, _ := strconv.Atoi(sp) for _, pattern := range m.patterns { - if pattern.glob.Match(domain) { - return true + if pattern.glob.Match(addr) { + if pattern.pr == nil || pattern.pr.contains(port) { + return true + } } } return false } + +type PortRange struct { + Min int + Max int +} + +// ParsePortRange parses the s to a PortRange. +// The s can be a single port number and will be converted to port range port-port. +func parsePortRange(s string) (*PortRange, error) { + minmax := strings.Split(s, "-") + switch len(minmax) { + case 1: + port, err := strconv.Atoi(s) + if err != nil { + return nil, err + } + if port < 0 || port > 65535 { + return nil, fmt.Errorf("invalid port: %s", s) + } + return &PortRange{Min: port, Max: port}, nil + + case 2: + min, err := strconv.Atoi(minmax[0]) + if err != nil { + return nil, err + } + max, err := strconv.Atoi(minmax[1]) + if err != nil { + return nil, err + } + + return &PortRange{Min: min, Max: max}, nil + + default: + return nil, fmt.Errorf("invalid range: %s", s) + } +} + +func (pr *PortRange) contains(port int) bool { + return port >= pr.Min && port <= pr.Max +}