add range port support for forwarder node
This commit is contained in:
@ -1,11 +1,11 @@
|
||||
package matcher
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
xnet "github.com/go-gost/x/internal/net"
|
||||
"github.com/gobwas/glob"
|
||||
"github.com/yl2chen/cidranger"
|
||||
)
|
||||
@ -40,7 +40,7 @@ func (m *ipMatcher) Match(ip string) bool {
|
||||
}
|
||||
|
||||
type addrMatcher struct {
|
||||
addrs map[string]*PortRange
|
||||
addrs map[string]*xnet.PortRange
|
||||
}
|
||||
|
||||
// AddrMatcher creates a Matcher with a list of HOST:PORT addresses.
|
||||
@ -50,7 +50,7 @@ type addrMatcher struct {
|
||||
// The PORT can be a single port number or port range MIN-MAX(e.g. 0-65535).
|
||||
func AddrMatcher(addrs []string) Matcher {
|
||||
matcher := &addrMatcher{
|
||||
addrs: make(map[string]*PortRange),
|
||||
addrs: make(map[string]*xnet.PortRange),
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
host, port, _ := net.SplitHostPort(addr)
|
||||
@ -58,7 +58,10 @@ func AddrMatcher(addrs []string) Matcher {
|
||||
matcher.addrs[addr] = nil
|
||||
continue
|
||||
}
|
||||
pr, _ := parsePortRange(port)
|
||||
pr := &xnet.PortRange{}
|
||||
if err := pr.Parse(port); err != nil {
|
||||
pr = nil
|
||||
}
|
||||
matcher.addrs[host] = pr
|
||||
}
|
||||
return matcher
|
||||
@ -75,13 +78,13 @@ func (m *addrMatcher) Match(addr string) bool {
|
||||
port, _ := strconv.Atoi(sp)
|
||||
|
||||
if pr, ok := m.addrs[host]; ok {
|
||||
if pr == nil || pr.contains(port) {
|
||||
if pr == nil || pr.Contains(port) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if pr, ok := m.addrs["."+host]; ok {
|
||||
if pr == nil || pr.contains(port) {
|
||||
if pr == nil || pr.Contains(port) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@ -89,7 +92,7 @@ func (m *addrMatcher) Match(addr string) bool {
|
||||
for {
|
||||
if index := strings.IndexByte(host, '.'); index > 0 {
|
||||
if pr, ok := m.addrs[host[index:]]; ok {
|
||||
if pr == nil || pr.contains(port) {
|
||||
if pr == nil || pr.Contains(port) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@ -172,7 +175,7 @@ func (m *domainMatcher) Match(domain string) bool {
|
||||
|
||||
type wildcardMatcherPattern struct {
|
||||
glob glob.Glob
|
||||
pr *PortRange
|
||||
pr *xnet.PortRange
|
||||
}
|
||||
type wildcardMatcher struct {
|
||||
patterns []wildcardMatcherPattern
|
||||
@ -187,7 +190,11 @@ func WildcardMatcher(patterns []string) Matcher {
|
||||
if host == "" {
|
||||
host = pattern
|
||||
}
|
||||
pr, _ := parsePortRange(port)
|
||||
pr := &xnet.PortRange{}
|
||||
if err := pr.Parse(port); err != nil {
|
||||
pr = nil
|
||||
}
|
||||
|
||||
matcher.patterns = append(matcher.patterns, wildcardMatcherPattern{
|
||||
glob: glob.MustCompile(host),
|
||||
pr: pr,
|
||||
@ -208,8 +215,8 @@ func (m *wildcardMatcher) Match(addr string) bool {
|
||||
}
|
||||
port, _ := strconv.Atoi(sp)
|
||||
for _, pattern := range m.patterns {
|
||||
if pattern.glob.Match(addr) {
|
||||
if pattern.pr == nil || pattern.pr.contains(port) {
|
||||
if pattern.glob.Match(host) {
|
||||
if pattern.pr == nil || pattern.pr.Contains(port) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@ -217,44 +224,3 @@ func (m *wildcardMatcher) Match(addr string) bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
type PortRange struct {
|
||||
Min int
|
||||
Max int
|
||||
}
|
||||
|
||||
// ParsePortRange parses the s to a PortRange.
|
||||
// The s can be a single port number and will be converted to port range port-port.
|
||||
func parsePortRange(s string) (*PortRange, error) {
|
||||
minmax := strings.Split(s, "-")
|
||||
switch len(minmax) {
|
||||
case 1:
|
||||
port, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if port < 0 || port > 65535 {
|
||||
return nil, fmt.Errorf("invalid port: %s", s)
|
||||
}
|
||||
return &PortRange{Min: port, Max: port}, nil
|
||||
|
||||
case 2:
|
||||
min, err := strconv.Atoi(minmax[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
max, err := strconv.Atoi(minmax[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PortRange{Min: min, Max: max}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid range: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *PortRange) contains(port int) bool {
|
||||
return port >= pr.Min && port <= pr.Max
|
||||
}
|
||||
|
72
internal/net/addr.go
Normal file
72
internal/net/addr.go
Normal file
@ -0,0 +1,72 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AddrPortRange is the network address with port range supported.
|
||||
// e.g. 192.168.1.1:0-65535
|
||||
type AddrPortRange string
|
||||
|
||||
func (p AddrPortRange) Addrs() (addrs []string) {
|
||||
h, sp, err := net.SplitHostPort(string(p))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
pr := PortRange{}
|
||||
pr.Parse(sp)
|
||||
|
||||
for i := pr.Min; i <= pr.Max; i++ {
|
||||
addrs = append(addrs, net.JoinHostPort(h, strconv.Itoa(i)))
|
||||
}
|
||||
return addrs
|
||||
}
|
||||
|
||||
// Port range is a range of port list.
|
||||
type PortRange struct {
|
||||
Min int
|
||||
Max int
|
||||
}
|
||||
|
||||
// Parse parses the s to PortRange.
|
||||
// The s can be a single port number and will be converted to port range port-port.
|
||||
func (pr *PortRange) Parse(s string) error {
|
||||
minmax := strings.Split(s, "-")
|
||||
switch len(minmax) {
|
||||
case 1:
|
||||
port, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if port < 0 || port > 65535 {
|
||||
return fmt.Errorf("invalid port: %s", s)
|
||||
}
|
||||
|
||||
pr.Min, pr.Max = port, port
|
||||
return nil
|
||||
|
||||
case 2:
|
||||
min, err := strconv.Atoi(minmax[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
max, err := strconv.Atoi(minmax[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pr.Min, pr.Max = min, max
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("invalid range: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func (pr *PortRange) Contains(port int) bool {
|
||||
return port >= pr.Min && port <= pr.Max
|
||||
}
|
@ -39,7 +39,7 @@ func (r *Relay) SetBufferSize(n int) {
|
||||
r.bufferSize = n
|
||||
}
|
||||
|
||||
func (r *Relay) Run() (err error) {
|
||||
func (r *Relay) Run(ctx context.Context) (err error) {
|
||||
bufSize := r.bufferSize
|
||||
if bufSize <= 0 {
|
||||
bufSize = 4096
|
||||
@ -58,7 +58,7 @@ func (r *Relay) Run() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.bypass != nil && r.bypass.Contains(context.Background(), "udp", raddr.String()) {
|
||||
if r.bypass != nil && r.bypass.Contains(ctx, "udp", raddr.String()) {
|
||||
if r.logger != nil {
|
||||
r.logger.Warn("bypass: ", raddr)
|
||||
}
|
||||
@ -96,7 +96,7 @@ func (r *Relay) Run() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.bypass != nil && r.bypass.Contains(context.Background(), "udp", raddr.String()) {
|
||||
if r.bypass != nil && r.bypass.Contains(ctx, "udp", raddr.String()) {
|
||||
if r.logger != nil {
|
||||
r.logger.Warn("bypass: ", raddr)
|
||||
}
|
||||
|
Reference in New Issue
Block a user