diff --git a/config/parsing/service/parse.go b/config/parsing/service/parse.go index 96f0645..8edf2e5 100644 --- a/config/parsing/service/parse.go +++ b/config/parsing/service/parse.go @@ -23,6 +23,7 @@ import ( bypass_parser "github.com/go-gost/x/config/parsing/bypass" hop_parser "github.com/go-gost/x/config/parsing/hop" selector_parser "github.com/go-gost/x/config/parsing/selector" + xnet "github.com/go-gost/x/internal/net" tls_util "github.com/go-gost/x/internal/util/tls" "github.com/go-gost/x/metadata" "github.com/go-gost/x/registry" @@ -258,10 +259,18 @@ func parseForwarder(cfg *config.ForwarderConfig) (hop.Hop, error) { } for _, node := range cfg.Nodes { if node != nil { - hc.Nodes = append(hc.Nodes, - &config.NodeConfig{ - Name: node.Name, - Addr: node.Addr, + addrs := xnet.AddrPortRange(node.Addr).Addrs() + if len(addrs) == 0 { + addrs = append(addrs, node.Addr) + } + for i, addr := range addrs { + name := node.Name + if i > 0 { + name = fmt.Sprintf("%s-%d", node.Name, i) + } + hc.Nodes = append(hc.Nodes, &config.NodeConfig{ + Name: name, + Addr: addr, Host: node.Host, Network: node.Network, Protocol: node.Protocol, @@ -271,8 +280,8 @@ func parseForwarder(cfg *config.ForwarderConfig) (hop.Hop, error) { HTTP: node.HTTP, TLS: node.TLS, Auth: node.Auth, - }, - ) + }) + } } } if len(hc.Nodes) > 0 { diff --git a/handler/http/udp.go b/handler/http/udp.go index 994c640..52db1ca 100644 --- a/handler/http/udp.go +++ b/handler/http/udp.go @@ -71,7 +71,7 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, log logger.L t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) - relay.Run() + relay.Run(ctx) log.WithFields(map[string]any{ "duration": time.Since(t), }).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) diff --git a/handler/relay/bind.go b/handler/relay/bind.go index f6d9aa7..26016e9 100644 --- a/handler/relay/bind.go +++ b/handler/relay/bind.go @@ -176,7 +176,7 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) - r.Run() + r.Run(ctx) log.WithFields(map[string]any{ "duration": time.Since(t), }).Debugf("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) diff --git a/handler/socks/v5/udp.go b/handler/socks/v5/udp.go index 1a4506c..d7a34cb 100644 --- a/handler/socks/v5/udp.go +++ b/handler/socks/v5/udp.go @@ -72,7 +72,7 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger WithLogger(log) r.SetBufferSize(h.md.udpBufferSize) - go r.Run() + go r.Run(ctx) t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) diff --git a/handler/socks/v5/udp_tun.go b/handler/socks/v5/udp_tun.go index a870e2e..6b65b83 100644 --- a/handler/socks/v5/udp_tun.go +++ b/handler/socks/v5/udp_tun.go @@ -63,7 +63,7 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) - r.Run() + r.Run(ctx) log.WithFields(map[string]any{ "duration": time.Since(t), }).Debugf("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) diff --git a/handler/tun/client.go b/handler/tun/client.go index af3e4b0..b3ece0c 100644 --- a/handler/tun/client.go +++ b/handler/tun/client.go @@ -130,7 +130,7 @@ func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, log logge ipProtocol(waterutil.IPProtocol(header.NextHeader)), header.PayloadLen, header.TrafficClass) } else { - log.Warn("unknown packet, discarded") + log.Warnf("unknown packet, discarded(%d)", n) return nil } diff --git a/handler/tun/server.go b/handler/tun/server.go index 22a13d6..3c2784c 100644 --- a/handler/tun/server.go +++ b/handler/tun/server.go @@ -78,7 +78,7 @@ func (h *tunHandler) transportServer(ctx context.Context, tun io.ReadWriter, con ipProtocol(waterutil.IPProtocol(header.NextHeader)), header.PayloadLen, header.TrafficClass) } else { - log.Warn("unknown packet, discarded") + log.Warnf("unknown packet, discarded(%d)", n) return nil } @@ -199,7 +199,7 @@ func (h *tunHandler) transportServer(ctx context.Context, tun io.ReadWriter, con ipProtocol(waterutil.IPProtocol(header.NextHeader)), header.PayloadLen, header.TrafficClass) } else { - log.Warn("unknown packet, discarded") + log.Warnf("unknown packet, discarded(%d): % x", n, b[:n]) return nil } diff --git a/internal/matcher/matcher.go b/internal/matcher/matcher.go index 9ab7d17..8a2b3c0 100644 --- a/internal/matcher/matcher.go +++ b/internal/matcher/matcher.go @@ -1,11 +1,11 @@ package matcher import ( - "fmt" "net" "strconv" "strings" + xnet "github.com/go-gost/x/internal/net" "github.com/gobwas/glob" "github.com/yl2chen/cidranger" ) @@ -40,7 +40,7 @@ func (m *ipMatcher) Match(ip string) bool { } type addrMatcher struct { - addrs map[string]*PortRange + addrs map[string]*xnet.PortRange } // AddrMatcher creates a Matcher with a list of HOST:PORT addresses. @@ -50,7 +50,7 @@ type addrMatcher struct { // The PORT can be a single port number or port range MIN-MAX(e.g. 0-65535). func AddrMatcher(addrs []string) Matcher { matcher := &addrMatcher{ - addrs: make(map[string]*PortRange), + addrs: make(map[string]*xnet.PortRange), } for _, addr := range addrs { host, port, _ := net.SplitHostPort(addr) @@ -58,7 +58,10 @@ func AddrMatcher(addrs []string) Matcher { matcher.addrs[addr] = nil continue } - pr, _ := parsePortRange(port) + pr := &xnet.PortRange{} + if err := pr.Parse(port); err != nil { + pr = nil + } matcher.addrs[host] = pr } return matcher @@ -75,13 +78,13 @@ func (m *addrMatcher) Match(addr string) bool { port, _ := strconv.Atoi(sp) if pr, ok := m.addrs[host]; ok { - if pr == nil || pr.contains(port) { + if pr == nil || pr.Contains(port) { return true } } if pr, ok := m.addrs["."+host]; ok { - if pr == nil || pr.contains(port) { + if pr == nil || pr.Contains(port) { return true } } @@ -89,7 +92,7 @@ func (m *addrMatcher) Match(addr string) bool { for { if index := strings.IndexByte(host, '.'); index > 0 { if pr, ok := m.addrs[host[index:]]; ok { - if pr == nil || pr.contains(port) { + if pr == nil || pr.Contains(port) { return true } } @@ -172,7 +175,7 @@ func (m *domainMatcher) Match(domain string) bool { type wildcardMatcherPattern struct { glob glob.Glob - pr *PortRange + pr *xnet.PortRange } type wildcardMatcher struct { patterns []wildcardMatcherPattern @@ -187,7 +190,11 @@ func WildcardMatcher(patterns []string) Matcher { if host == "" { host = pattern } - pr, _ := parsePortRange(port) + pr := &xnet.PortRange{} + if err := pr.Parse(port); err != nil { + pr = nil + } + matcher.patterns = append(matcher.patterns, wildcardMatcherPattern{ glob: glob.MustCompile(host), pr: pr, @@ -208,8 +215,8 @@ func (m *wildcardMatcher) Match(addr string) bool { } port, _ := strconv.Atoi(sp) for _, pattern := range m.patterns { - if pattern.glob.Match(addr) { - if pattern.pr == nil || pattern.pr.contains(port) { + if pattern.glob.Match(host) { + if pattern.pr == nil || pattern.pr.Contains(port) { return true } } @@ -217,44 +224,3 @@ func (m *wildcardMatcher) Match(addr string) bool { return false } - -type PortRange struct { - Min int - Max int -} - -// ParsePortRange parses the s to a PortRange. -// The s can be a single port number and will be converted to port range port-port. -func parsePortRange(s string) (*PortRange, error) { - minmax := strings.Split(s, "-") - switch len(minmax) { - case 1: - port, err := strconv.Atoi(s) - if err != nil { - return nil, err - } - if port < 0 || port > 65535 { - return nil, fmt.Errorf("invalid port: %s", s) - } - return &PortRange{Min: port, Max: port}, nil - - case 2: - min, err := strconv.Atoi(minmax[0]) - if err != nil { - return nil, err - } - max, err := strconv.Atoi(minmax[1]) - if err != nil { - return nil, err - } - - return &PortRange{Min: min, Max: max}, nil - - default: - return nil, fmt.Errorf("invalid range: %s", s) - } -} - -func (pr *PortRange) contains(port int) bool { - return port >= pr.Min && port <= pr.Max -} diff --git a/internal/net/addr.go b/internal/net/addr.go new file mode 100644 index 0000000..b549a00 --- /dev/null +++ b/internal/net/addr.go @@ -0,0 +1,72 @@ +package net + +import ( + "fmt" + "net" + "strconv" + "strings" +) + +// AddrPortRange is the network address with port range supported. +// e.g. 192.168.1.1:0-65535 +type AddrPortRange string + +func (p AddrPortRange) Addrs() (addrs []string) { + h, sp, err := net.SplitHostPort(string(p)) + if err != nil { + return nil + } + + pr := PortRange{} + pr.Parse(sp) + + for i := pr.Min; i <= pr.Max; i++ { + addrs = append(addrs, net.JoinHostPort(h, strconv.Itoa(i))) + } + return addrs +} + +// Port range is a range of port list. +type PortRange struct { + Min int + Max int +} + +// Parse parses the s to PortRange. +// The s can be a single port number and will be converted to port range port-port. +func (pr *PortRange) Parse(s string) error { + minmax := strings.Split(s, "-") + switch len(minmax) { + case 1: + port, err := strconv.Atoi(s) + if err != nil { + return err + } + if port < 0 || port > 65535 { + return fmt.Errorf("invalid port: %s", s) + } + + pr.Min, pr.Max = port, port + return nil + + case 2: + min, err := strconv.Atoi(minmax[0]) + if err != nil { + return err + } + max, err := strconv.Atoi(minmax[1]) + if err != nil { + return err + } + + pr.Min, pr.Max = min, max + return nil + + default: + return fmt.Errorf("invalid range: %s", s) + } +} + +func (pr *PortRange) Contains(port int) bool { + return port >= pr.Min && port <= pr.Max +} diff --git a/internal/net/udp/relay.go b/internal/net/udp/relay.go index 79146f1..1694ba3 100644 --- a/internal/net/udp/relay.go +++ b/internal/net/udp/relay.go @@ -39,7 +39,7 @@ func (r *Relay) SetBufferSize(n int) { r.bufferSize = n } -func (r *Relay) Run() (err error) { +func (r *Relay) Run(ctx context.Context) (err error) { bufSize := r.bufferSize if bufSize <= 0 { bufSize = 4096 @@ -58,7 +58,7 @@ func (r *Relay) Run() (err error) { return err } - if r.bypass != nil && r.bypass.Contains(context.Background(), "udp", raddr.String()) { + if r.bypass != nil && r.bypass.Contains(ctx, "udp", raddr.String()) { if r.logger != nil { r.logger.Warn("bypass: ", raddr) } @@ -96,7 +96,7 @@ func (r *Relay) Run() (err error) { return err } - if r.bypass != nil && r.bypass.Contains(context.Background(), "udp", raddr.String()) { + if r.bypass != nil && r.bypass.Contains(ctx, "udp", raddr.String()) { if r.logger != nil { r.logger.Warn("bypass: ", raddr) } diff --git a/listener/rudp/listener.go b/listener/rudp/listener.go index 38d5efe..7125212 100644 --- a/listener/rudp/listener.go +++ b/listener/rudp/listener.go @@ -50,12 +50,13 @@ func (l *rudpListener) Init(md md.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "udp4" } - laddr, err := net.ResolveUDPAddr(network, l.options.Addr) - if err != nil { - return + if laddr, _ := net.ResolveUDPAddr(network, l.options.Addr); laddr != nil { + l.laddr = laddr + } + if l.laddr == nil { + l.laddr = &bindAddr{addr: l.options.Addr} } - l.laddr = laddr l.router = chain.NewRouter( chain.ChainRouterOption(l.options.Chain), chain.LoggerRouterOption(l.logger), @@ -116,3 +117,15 @@ func (l *rudpListener) Close() error { return nil } + +type bindAddr struct { + addr string +} + +func (p *bindAddr) Network() string { + return "tcp" +} + +func (p *bindAddr) String() string { + return p.addr +} diff --git a/listener/rudp/metadata.go b/listener/rudp/metadata.go index fd71f7d..b4b7af0 100644 --- a/listener/rudp/metadata.go +++ b/listener/rudp/metadata.go @@ -9,7 +9,7 @@ import ( const ( defaultTTL = 5 * time.Second - defaultReadBufferSize = 4096 + defaultReadBufferSize = 1024 defaultReadQueueSize = 1024 defaultBacklog = 128 )