add range port support for forwarder node

This commit is contained in:
ginuerzh 2023-11-14 19:41:57 +08:00
parent ca1f44d93c
commit d7b7ac6357
12 changed files with 133 additions and 73 deletions

View File

@ -23,6 +23,7 @@ import (
bypass_parser "github.com/go-gost/x/config/parsing/bypass" bypass_parser "github.com/go-gost/x/config/parsing/bypass"
hop_parser "github.com/go-gost/x/config/parsing/hop" hop_parser "github.com/go-gost/x/config/parsing/hop"
selector_parser "github.com/go-gost/x/config/parsing/selector" selector_parser "github.com/go-gost/x/config/parsing/selector"
xnet "github.com/go-gost/x/internal/net"
tls_util "github.com/go-gost/x/internal/util/tls" tls_util "github.com/go-gost/x/internal/util/tls"
"github.com/go-gost/x/metadata" "github.com/go-gost/x/metadata"
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
@ -258,10 +259,18 @@ func parseForwarder(cfg *config.ForwarderConfig) (hop.Hop, error) {
} }
for _, node := range cfg.Nodes { for _, node := range cfg.Nodes {
if node != nil { if node != nil {
hc.Nodes = append(hc.Nodes, addrs := xnet.AddrPortRange(node.Addr).Addrs()
&config.NodeConfig{ if len(addrs) == 0 {
Name: node.Name, addrs = append(addrs, node.Addr)
Addr: node.Addr, }
for i, addr := range addrs {
name := node.Name
if i > 0 {
name = fmt.Sprintf("%s-%d", node.Name, i)
}
hc.Nodes = append(hc.Nodes, &config.NodeConfig{
Name: name,
Addr: addr,
Host: node.Host, Host: node.Host,
Network: node.Network, Network: node.Network,
Protocol: node.Protocol, Protocol: node.Protocol,
@ -271,8 +280,8 @@ func parseForwarder(cfg *config.ForwarderConfig) (hop.Hop, error) {
HTTP: node.HTTP, HTTP: node.HTTP,
TLS: node.TLS, TLS: node.TLS,
Auth: node.Auth, Auth: node.Auth,
}, })
) }
} }
} }
if len(hc.Nodes) > 0 { if len(hc.Nodes) > 0 {

View File

@ -71,7 +71,7 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, log logger.L
t := time.Now() t := time.Now()
log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
relay.Run() relay.Run(ctx)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) }).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())

View File

@ -176,7 +176,7 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr
t := time.Now() t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) log.Debugf("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
r.Run() r.Run(ctx)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) }).Debugf("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())

View File

@ -72,7 +72,7 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger
WithLogger(log) WithLogger(log)
r.SetBufferSize(h.md.udpBufferSize) r.SetBufferSize(h.md.udpBufferSize)
go r.Run() go r.Run(ctx)
t := time.Now() t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())

View File

@ -63,7 +63,7 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network
t := time.Now() t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) log.Debugf("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
r.Run() r.Run(ctx)
log.WithFields(map[string]any{ log.WithFields(map[string]any{
"duration": time.Since(t), "duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) }).Debugf("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())

View File

@ -130,7 +130,7 @@ func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, log logge
ipProtocol(waterutil.IPProtocol(header.NextHeader)), ipProtocol(waterutil.IPProtocol(header.NextHeader)),
header.PayloadLen, header.TrafficClass) header.PayloadLen, header.TrafficClass)
} else { } else {
log.Warn("unknown packet, discarded") log.Warnf("unknown packet, discarded(%d)", n)
return nil return nil
} }

View File

@ -78,7 +78,7 @@ func (h *tunHandler) transportServer(ctx context.Context, tun io.ReadWriter, con
ipProtocol(waterutil.IPProtocol(header.NextHeader)), ipProtocol(waterutil.IPProtocol(header.NextHeader)),
header.PayloadLen, header.TrafficClass) header.PayloadLen, header.TrafficClass)
} else { } else {
log.Warn("unknown packet, discarded") log.Warnf("unknown packet, discarded(%d)", n)
return nil return nil
} }
@ -199,7 +199,7 @@ func (h *tunHandler) transportServer(ctx context.Context, tun io.ReadWriter, con
ipProtocol(waterutil.IPProtocol(header.NextHeader)), ipProtocol(waterutil.IPProtocol(header.NextHeader)),
header.PayloadLen, header.TrafficClass) header.PayloadLen, header.TrafficClass)
} else { } else {
log.Warn("unknown packet, discarded") log.Warnf("unknown packet, discarded(%d): % x", n, b[:n])
return nil return nil
} }

View File

@ -1,11 +1,11 @@
package matcher package matcher
import ( import (
"fmt"
"net" "net"
"strconv" "strconv"
"strings" "strings"
xnet "github.com/go-gost/x/internal/net"
"github.com/gobwas/glob" "github.com/gobwas/glob"
"github.com/yl2chen/cidranger" "github.com/yl2chen/cidranger"
) )
@ -40,7 +40,7 @@ func (m *ipMatcher) Match(ip string) bool {
} }
type addrMatcher struct { type addrMatcher struct {
addrs map[string]*PortRange addrs map[string]*xnet.PortRange
} }
// AddrMatcher creates a Matcher with a list of HOST:PORT addresses. // AddrMatcher creates a Matcher with a list of HOST:PORT addresses.
@ -50,7 +50,7 @@ type addrMatcher struct {
// The PORT can be a single port number or port range MIN-MAX(e.g. 0-65535). // The PORT can be a single port number or port range MIN-MAX(e.g. 0-65535).
func AddrMatcher(addrs []string) Matcher { func AddrMatcher(addrs []string) Matcher {
matcher := &addrMatcher{ matcher := &addrMatcher{
addrs: make(map[string]*PortRange), addrs: make(map[string]*xnet.PortRange),
} }
for _, addr := range addrs { for _, addr := range addrs {
host, port, _ := net.SplitHostPort(addr) host, port, _ := net.SplitHostPort(addr)
@ -58,7 +58,10 @@ func AddrMatcher(addrs []string) Matcher {
matcher.addrs[addr] = nil matcher.addrs[addr] = nil
continue continue
} }
pr, _ := parsePortRange(port) pr := &xnet.PortRange{}
if err := pr.Parse(port); err != nil {
pr = nil
}
matcher.addrs[host] = pr matcher.addrs[host] = pr
} }
return matcher return matcher
@ -75,13 +78,13 @@ func (m *addrMatcher) Match(addr string) bool {
port, _ := strconv.Atoi(sp) port, _ := strconv.Atoi(sp)
if pr, ok := m.addrs[host]; ok { if pr, ok := m.addrs[host]; ok {
if pr == nil || pr.contains(port) { if pr == nil || pr.Contains(port) {
return true return true
} }
} }
if pr, ok := m.addrs["."+host]; ok { if pr, ok := m.addrs["."+host]; ok {
if pr == nil || pr.contains(port) { if pr == nil || pr.Contains(port) {
return true return true
} }
} }
@ -89,7 +92,7 @@ func (m *addrMatcher) Match(addr string) bool {
for { for {
if index := strings.IndexByte(host, '.'); index > 0 { if index := strings.IndexByte(host, '.'); index > 0 {
if pr, ok := m.addrs[host[index:]]; ok { if pr, ok := m.addrs[host[index:]]; ok {
if pr == nil || pr.contains(port) { if pr == nil || pr.Contains(port) {
return true return true
} }
} }
@ -172,7 +175,7 @@ func (m *domainMatcher) Match(domain string) bool {
type wildcardMatcherPattern struct { type wildcardMatcherPattern struct {
glob glob.Glob glob glob.Glob
pr *PortRange pr *xnet.PortRange
} }
type wildcardMatcher struct { type wildcardMatcher struct {
patterns []wildcardMatcherPattern patterns []wildcardMatcherPattern
@ -187,7 +190,11 @@ func WildcardMatcher(patterns []string) Matcher {
if host == "" { if host == "" {
host = pattern host = pattern
} }
pr, _ := parsePortRange(port) pr := &xnet.PortRange{}
if err := pr.Parse(port); err != nil {
pr = nil
}
matcher.patterns = append(matcher.patterns, wildcardMatcherPattern{ matcher.patterns = append(matcher.patterns, wildcardMatcherPattern{
glob: glob.MustCompile(host), glob: glob.MustCompile(host),
pr: pr, pr: pr,
@ -208,8 +215,8 @@ func (m *wildcardMatcher) Match(addr string) bool {
} }
port, _ := strconv.Atoi(sp) port, _ := strconv.Atoi(sp)
for _, pattern := range m.patterns { for _, pattern := range m.patterns {
if pattern.glob.Match(addr) { if pattern.glob.Match(host) {
if pattern.pr == nil || pattern.pr.contains(port) { if pattern.pr == nil || pattern.pr.Contains(port) {
return true return true
} }
} }
@ -217,44 +224,3 @@ func (m *wildcardMatcher) Match(addr string) bool {
return false return false
} }
type PortRange struct {
Min int
Max int
}
// ParsePortRange parses the s to a PortRange.
// The s can be a single port number and will be converted to port range port-port.
func parsePortRange(s string) (*PortRange, error) {
minmax := strings.Split(s, "-")
switch len(minmax) {
case 1:
port, err := strconv.Atoi(s)
if err != nil {
return nil, err
}
if port < 0 || port > 65535 {
return nil, fmt.Errorf("invalid port: %s", s)
}
return &PortRange{Min: port, Max: port}, nil
case 2:
min, err := strconv.Atoi(minmax[0])
if err != nil {
return nil, err
}
max, err := strconv.Atoi(minmax[1])
if err != nil {
return nil, err
}
return &PortRange{Min: min, Max: max}, nil
default:
return nil, fmt.Errorf("invalid range: %s", s)
}
}
func (pr *PortRange) contains(port int) bool {
return port >= pr.Min && port <= pr.Max
}

72
internal/net/addr.go Normal file
View File

@ -0,0 +1,72 @@
package net
import (
"fmt"
"net"
"strconv"
"strings"
)
// AddrPortRange is the network address with port range supported.
// e.g. 192.168.1.1:0-65535
type AddrPortRange string
func (p AddrPortRange) Addrs() (addrs []string) {
h, sp, err := net.SplitHostPort(string(p))
if err != nil {
return nil
}
pr := PortRange{}
pr.Parse(sp)
for i := pr.Min; i <= pr.Max; i++ {
addrs = append(addrs, net.JoinHostPort(h, strconv.Itoa(i)))
}
return addrs
}
// Port range is a range of port list.
type PortRange struct {
Min int
Max int
}
// Parse parses the s to PortRange.
// The s can be a single port number and will be converted to port range port-port.
func (pr *PortRange) Parse(s string) error {
minmax := strings.Split(s, "-")
switch len(minmax) {
case 1:
port, err := strconv.Atoi(s)
if err != nil {
return err
}
if port < 0 || port > 65535 {
return fmt.Errorf("invalid port: %s", s)
}
pr.Min, pr.Max = port, port
return nil
case 2:
min, err := strconv.Atoi(minmax[0])
if err != nil {
return err
}
max, err := strconv.Atoi(minmax[1])
if err != nil {
return err
}
pr.Min, pr.Max = min, max
return nil
default:
return fmt.Errorf("invalid range: %s", s)
}
}
func (pr *PortRange) Contains(port int) bool {
return port >= pr.Min && port <= pr.Max
}

View File

@ -39,7 +39,7 @@ func (r *Relay) SetBufferSize(n int) {
r.bufferSize = n r.bufferSize = n
} }
func (r *Relay) Run() (err error) { func (r *Relay) Run(ctx context.Context) (err error) {
bufSize := r.bufferSize bufSize := r.bufferSize
if bufSize <= 0 { if bufSize <= 0 {
bufSize = 4096 bufSize = 4096
@ -58,7 +58,7 @@ func (r *Relay) Run() (err error) {
return err return err
} }
if r.bypass != nil && r.bypass.Contains(context.Background(), "udp", raddr.String()) { if r.bypass != nil && r.bypass.Contains(ctx, "udp", raddr.String()) {
if r.logger != nil { if r.logger != nil {
r.logger.Warn("bypass: ", raddr) r.logger.Warn("bypass: ", raddr)
} }
@ -96,7 +96,7 @@ func (r *Relay) Run() (err error) {
return err return err
} }
if r.bypass != nil && r.bypass.Contains(context.Background(), "udp", raddr.String()) { if r.bypass != nil && r.bypass.Contains(ctx, "udp", raddr.String()) {
if r.logger != nil { if r.logger != nil {
r.logger.Warn("bypass: ", raddr) r.logger.Warn("bypass: ", raddr)
} }

View File

@ -50,12 +50,13 @@ func (l *rudpListener) Init(md md.Metadata) (err error) {
if xnet.IsIPv4(l.options.Addr) { if xnet.IsIPv4(l.options.Addr) {
network = "udp4" network = "udp4"
} }
laddr, err := net.ResolveUDPAddr(network, l.options.Addr) if laddr, _ := net.ResolveUDPAddr(network, l.options.Addr); laddr != nil {
if err != nil { l.laddr = laddr
return }
if l.laddr == nil {
l.laddr = &bindAddr{addr: l.options.Addr}
} }
l.laddr = laddr
l.router = chain.NewRouter( l.router = chain.NewRouter(
chain.ChainRouterOption(l.options.Chain), chain.ChainRouterOption(l.options.Chain),
chain.LoggerRouterOption(l.logger), chain.LoggerRouterOption(l.logger),
@ -116,3 +117,15 @@ func (l *rudpListener) Close() error {
return nil return nil
} }
type bindAddr struct {
addr string
}
func (p *bindAddr) Network() string {
return "tcp"
}
func (p *bindAddr) String() string {
return p.addr
}

View File

@ -9,7 +9,7 @@ import (
const ( const (
defaultTTL = 5 * time.Second defaultTTL = 5 * time.Second
defaultReadBufferSize = 4096 defaultReadBufferSize = 1024
defaultReadQueueSize = 1024 defaultReadQueueSize = 1024
defaultBacklog = 128 defaultBacklog = 128
) )