add range port support for forwarder node
This commit is contained in:
parent
ca1f44d93c
commit
d7b7ac6357
@ -23,6 +23,7 @@ import (
|
|||||||
bypass_parser "github.com/go-gost/x/config/parsing/bypass"
|
bypass_parser "github.com/go-gost/x/config/parsing/bypass"
|
||||||
hop_parser "github.com/go-gost/x/config/parsing/hop"
|
hop_parser "github.com/go-gost/x/config/parsing/hop"
|
||||||
selector_parser "github.com/go-gost/x/config/parsing/selector"
|
selector_parser "github.com/go-gost/x/config/parsing/selector"
|
||||||
|
xnet "github.com/go-gost/x/internal/net"
|
||||||
tls_util "github.com/go-gost/x/internal/util/tls"
|
tls_util "github.com/go-gost/x/internal/util/tls"
|
||||||
"github.com/go-gost/x/metadata"
|
"github.com/go-gost/x/metadata"
|
||||||
"github.com/go-gost/x/registry"
|
"github.com/go-gost/x/registry"
|
||||||
@ -258,10 +259,18 @@ func parseForwarder(cfg *config.ForwarderConfig) (hop.Hop, error) {
|
|||||||
}
|
}
|
||||||
for _, node := range cfg.Nodes {
|
for _, node := range cfg.Nodes {
|
||||||
if node != nil {
|
if node != nil {
|
||||||
hc.Nodes = append(hc.Nodes,
|
addrs := xnet.AddrPortRange(node.Addr).Addrs()
|
||||||
&config.NodeConfig{
|
if len(addrs) == 0 {
|
||||||
Name: node.Name,
|
addrs = append(addrs, node.Addr)
|
||||||
Addr: node.Addr,
|
}
|
||||||
|
for i, addr := range addrs {
|
||||||
|
name := node.Name
|
||||||
|
if i > 0 {
|
||||||
|
name = fmt.Sprintf("%s-%d", node.Name, i)
|
||||||
|
}
|
||||||
|
hc.Nodes = append(hc.Nodes, &config.NodeConfig{
|
||||||
|
Name: name,
|
||||||
|
Addr: addr,
|
||||||
Host: node.Host,
|
Host: node.Host,
|
||||||
Network: node.Network,
|
Network: node.Network,
|
||||||
Protocol: node.Protocol,
|
Protocol: node.Protocol,
|
||||||
@ -271,8 +280,8 @@ func parseForwarder(cfg *config.ForwarderConfig) (hop.Hop, error) {
|
|||||||
HTTP: node.HTTP,
|
HTTP: node.HTTP,
|
||||||
TLS: node.TLS,
|
TLS: node.TLS,
|
||||||
Auth: node.Auth,
|
Auth: node.Auth,
|
||||||
},
|
})
|
||||||
)
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(hc.Nodes) > 0 {
|
if len(hc.Nodes) > 0 {
|
||||||
|
@ -71,7 +71,7 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, log logger.L
|
|||||||
|
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
log.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||||
relay.Run()
|
relay.Run(ctx)
|
||||||
log.WithFields(map[string]any{
|
log.WithFields(map[string]any{
|
||||||
"duration": time.Since(t),
|
"duration": time.Since(t),
|
||||||
}).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
}).Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||||
|
@ -176,7 +176,7 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr
|
|||||||
|
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
log.Debugf("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
log.Debugf("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||||
r.Run()
|
r.Run(ctx)
|
||||||
log.WithFields(map[string]any{
|
log.WithFields(map[string]any{
|
||||||
"duration": time.Since(t),
|
"duration": time.Since(t),
|
||||||
}).Debugf("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
}).Debugf("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||||
|
@ -72,7 +72,7 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger
|
|||||||
WithLogger(log)
|
WithLogger(log)
|
||||||
r.SetBufferSize(h.md.udpBufferSize)
|
r.SetBufferSize(h.md.udpBufferSize)
|
||||||
|
|
||||||
go r.Run()
|
go r.Run(ctx)
|
||||||
|
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
|
log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
|
||||||
|
@ -63,7 +63,7 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network
|
|||||||
|
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
log.Debugf("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
log.Debugf("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||||
r.Run()
|
r.Run(ctx)
|
||||||
log.WithFields(map[string]any{
|
log.WithFields(map[string]any{
|
||||||
"duration": time.Since(t),
|
"duration": time.Since(t),
|
||||||
}).Debugf("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
}).Debugf("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr())
|
||||||
|
@ -130,7 +130,7 @@ func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, log logge
|
|||||||
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
|
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
|
||||||
header.PayloadLen, header.TrafficClass)
|
header.PayloadLen, header.TrafficClass)
|
||||||
} else {
|
} else {
|
||||||
log.Warn("unknown packet, discarded")
|
log.Warnf("unknown packet, discarded(%d)", n)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ func (h *tunHandler) transportServer(ctx context.Context, tun io.ReadWriter, con
|
|||||||
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
|
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
|
||||||
header.PayloadLen, header.TrafficClass)
|
header.PayloadLen, header.TrafficClass)
|
||||||
} else {
|
} else {
|
||||||
log.Warn("unknown packet, discarded")
|
log.Warnf("unknown packet, discarded(%d)", n)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,7 +199,7 @@ func (h *tunHandler) transportServer(ctx context.Context, tun io.ReadWriter, con
|
|||||||
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
|
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
|
||||||
header.PayloadLen, header.TrafficClass)
|
header.PayloadLen, header.TrafficClass)
|
||||||
} else {
|
} else {
|
||||||
log.Warn("unknown packet, discarded")
|
log.Warnf("unknown packet, discarded(%d): % x", n, b[:n])
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
package matcher
|
package matcher
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
xnet "github.com/go-gost/x/internal/net"
|
||||||
"github.com/gobwas/glob"
|
"github.com/gobwas/glob"
|
||||||
"github.com/yl2chen/cidranger"
|
"github.com/yl2chen/cidranger"
|
||||||
)
|
)
|
||||||
@ -40,7 +40,7 @@ func (m *ipMatcher) Match(ip string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type addrMatcher struct {
|
type addrMatcher struct {
|
||||||
addrs map[string]*PortRange
|
addrs map[string]*xnet.PortRange
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddrMatcher creates a Matcher with a list of HOST:PORT addresses.
|
// AddrMatcher creates a Matcher with a list of HOST:PORT addresses.
|
||||||
@ -50,7 +50,7 @@ type addrMatcher struct {
|
|||||||
// The PORT can be a single port number or port range MIN-MAX(e.g. 0-65535).
|
// The PORT can be a single port number or port range MIN-MAX(e.g. 0-65535).
|
||||||
func AddrMatcher(addrs []string) Matcher {
|
func AddrMatcher(addrs []string) Matcher {
|
||||||
matcher := &addrMatcher{
|
matcher := &addrMatcher{
|
||||||
addrs: make(map[string]*PortRange),
|
addrs: make(map[string]*xnet.PortRange),
|
||||||
}
|
}
|
||||||
for _, addr := range addrs {
|
for _, addr := range addrs {
|
||||||
host, port, _ := net.SplitHostPort(addr)
|
host, port, _ := net.SplitHostPort(addr)
|
||||||
@ -58,7 +58,10 @@ func AddrMatcher(addrs []string) Matcher {
|
|||||||
matcher.addrs[addr] = nil
|
matcher.addrs[addr] = nil
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
pr, _ := parsePortRange(port)
|
pr := &xnet.PortRange{}
|
||||||
|
if err := pr.Parse(port); err != nil {
|
||||||
|
pr = nil
|
||||||
|
}
|
||||||
matcher.addrs[host] = pr
|
matcher.addrs[host] = pr
|
||||||
}
|
}
|
||||||
return matcher
|
return matcher
|
||||||
@ -75,13 +78,13 @@ func (m *addrMatcher) Match(addr string) bool {
|
|||||||
port, _ := strconv.Atoi(sp)
|
port, _ := strconv.Atoi(sp)
|
||||||
|
|
||||||
if pr, ok := m.addrs[host]; ok {
|
if pr, ok := m.addrs[host]; ok {
|
||||||
if pr == nil || pr.contains(port) {
|
if pr == nil || pr.Contains(port) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if pr, ok := m.addrs["."+host]; ok {
|
if pr, ok := m.addrs["."+host]; ok {
|
||||||
if pr == nil || pr.contains(port) {
|
if pr == nil || pr.Contains(port) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -89,7 +92,7 @@ func (m *addrMatcher) Match(addr string) bool {
|
|||||||
for {
|
for {
|
||||||
if index := strings.IndexByte(host, '.'); index > 0 {
|
if index := strings.IndexByte(host, '.'); index > 0 {
|
||||||
if pr, ok := m.addrs[host[index:]]; ok {
|
if pr, ok := m.addrs[host[index:]]; ok {
|
||||||
if pr == nil || pr.contains(port) {
|
if pr == nil || pr.Contains(port) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -172,7 +175,7 @@ func (m *domainMatcher) Match(domain string) bool {
|
|||||||
|
|
||||||
type wildcardMatcherPattern struct {
|
type wildcardMatcherPattern struct {
|
||||||
glob glob.Glob
|
glob glob.Glob
|
||||||
pr *PortRange
|
pr *xnet.PortRange
|
||||||
}
|
}
|
||||||
type wildcardMatcher struct {
|
type wildcardMatcher struct {
|
||||||
patterns []wildcardMatcherPattern
|
patterns []wildcardMatcherPattern
|
||||||
@ -187,7 +190,11 @@ func WildcardMatcher(patterns []string) Matcher {
|
|||||||
if host == "" {
|
if host == "" {
|
||||||
host = pattern
|
host = pattern
|
||||||
}
|
}
|
||||||
pr, _ := parsePortRange(port)
|
pr := &xnet.PortRange{}
|
||||||
|
if err := pr.Parse(port); err != nil {
|
||||||
|
pr = nil
|
||||||
|
}
|
||||||
|
|
||||||
matcher.patterns = append(matcher.patterns, wildcardMatcherPattern{
|
matcher.patterns = append(matcher.patterns, wildcardMatcherPattern{
|
||||||
glob: glob.MustCompile(host),
|
glob: glob.MustCompile(host),
|
||||||
pr: pr,
|
pr: pr,
|
||||||
@ -208,8 +215,8 @@ func (m *wildcardMatcher) Match(addr string) bool {
|
|||||||
}
|
}
|
||||||
port, _ := strconv.Atoi(sp)
|
port, _ := strconv.Atoi(sp)
|
||||||
for _, pattern := range m.patterns {
|
for _, pattern := range m.patterns {
|
||||||
if pattern.glob.Match(addr) {
|
if pattern.glob.Match(host) {
|
||||||
if pattern.pr == nil || pattern.pr.contains(port) {
|
if pattern.pr == nil || pattern.pr.Contains(port) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -217,44 +224,3 @@ func (m *wildcardMatcher) Match(addr string) bool {
|
|||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
type PortRange struct {
|
|
||||||
Min int
|
|
||||||
Max int
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParsePortRange parses the s to a PortRange.
|
|
||||||
// The s can be a single port number and will be converted to port range port-port.
|
|
||||||
func parsePortRange(s string) (*PortRange, error) {
|
|
||||||
minmax := strings.Split(s, "-")
|
|
||||||
switch len(minmax) {
|
|
||||||
case 1:
|
|
||||||
port, err := strconv.Atoi(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if port < 0 || port > 65535 {
|
|
||||||
return nil, fmt.Errorf("invalid port: %s", s)
|
|
||||||
}
|
|
||||||
return &PortRange{Min: port, Max: port}, nil
|
|
||||||
|
|
||||||
case 2:
|
|
||||||
min, err := strconv.Atoi(minmax[0])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
max, err := strconv.Atoi(minmax[1])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &PortRange{Min: min, Max: max}, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid range: %s", s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pr *PortRange) contains(port int) bool {
|
|
||||||
return port >= pr.Min && port <= pr.Max
|
|
||||||
}
|
|
||||||
|
72
internal/net/addr.go
Normal file
72
internal/net/addr.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AddrPortRange is the network address with port range supported.
|
||||||
|
// e.g. 192.168.1.1:0-65535
|
||||||
|
type AddrPortRange string
|
||||||
|
|
||||||
|
func (p AddrPortRange) Addrs() (addrs []string) {
|
||||||
|
h, sp, err := net.SplitHostPort(string(p))
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pr := PortRange{}
|
||||||
|
pr.Parse(sp)
|
||||||
|
|
||||||
|
for i := pr.Min; i <= pr.Max; i++ {
|
||||||
|
addrs = append(addrs, net.JoinHostPort(h, strconv.Itoa(i)))
|
||||||
|
}
|
||||||
|
return addrs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Port range is a range of port list.
|
||||||
|
type PortRange struct {
|
||||||
|
Min int
|
||||||
|
Max int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse parses the s to PortRange.
|
||||||
|
// The s can be a single port number and will be converted to port range port-port.
|
||||||
|
func (pr *PortRange) Parse(s string) error {
|
||||||
|
minmax := strings.Split(s, "-")
|
||||||
|
switch len(minmax) {
|
||||||
|
case 1:
|
||||||
|
port, err := strconv.Atoi(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if port < 0 || port > 65535 {
|
||||||
|
return fmt.Errorf("invalid port: %s", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
pr.Min, pr.Max = port, port
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case 2:
|
||||||
|
min, err := strconv.Atoi(minmax[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
max, err := strconv.Atoi(minmax[1])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pr.Min, pr.Max = min, max
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid range: %s", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pr *PortRange) Contains(port int) bool {
|
||||||
|
return port >= pr.Min && port <= pr.Max
|
||||||
|
}
|
@ -39,7 +39,7 @@ func (r *Relay) SetBufferSize(n int) {
|
|||||||
r.bufferSize = n
|
r.bufferSize = n
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Relay) Run() (err error) {
|
func (r *Relay) Run(ctx context.Context) (err error) {
|
||||||
bufSize := r.bufferSize
|
bufSize := r.bufferSize
|
||||||
if bufSize <= 0 {
|
if bufSize <= 0 {
|
||||||
bufSize = 4096
|
bufSize = 4096
|
||||||
@ -58,7 +58,7 @@ func (r *Relay) Run() (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.bypass != nil && r.bypass.Contains(context.Background(), "udp", raddr.String()) {
|
if r.bypass != nil && r.bypass.Contains(ctx, "udp", raddr.String()) {
|
||||||
if r.logger != nil {
|
if r.logger != nil {
|
||||||
r.logger.Warn("bypass: ", raddr)
|
r.logger.Warn("bypass: ", raddr)
|
||||||
}
|
}
|
||||||
@ -96,7 +96,7 @@ func (r *Relay) Run() (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.bypass != nil && r.bypass.Contains(context.Background(), "udp", raddr.String()) {
|
if r.bypass != nil && r.bypass.Contains(ctx, "udp", raddr.String()) {
|
||||||
if r.logger != nil {
|
if r.logger != nil {
|
||||||
r.logger.Warn("bypass: ", raddr)
|
r.logger.Warn("bypass: ", raddr)
|
||||||
}
|
}
|
||||||
|
@ -50,12 +50,13 @@ func (l *rudpListener) Init(md md.Metadata) (err error) {
|
|||||||
if xnet.IsIPv4(l.options.Addr) {
|
if xnet.IsIPv4(l.options.Addr) {
|
||||||
network = "udp4"
|
network = "udp4"
|
||||||
}
|
}
|
||||||
laddr, err := net.ResolveUDPAddr(network, l.options.Addr)
|
if laddr, _ := net.ResolveUDPAddr(network, l.options.Addr); laddr != nil {
|
||||||
if err != nil {
|
l.laddr = laddr
|
||||||
return
|
}
|
||||||
|
if l.laddr == nil {
|
||||||
|
l.laddr = &bindAddr{addr: l.options.Addr}
|
||||||
}
|
}
|
||||||
|
|
||||||
l.laddr = laddr
|
|
||||||
l.router = chain.NewRouter(
|
l.router = chain.NewRouter(
|
||||||
chain.ChainRouterOption(l.options.Chain),
|
chain.ChainRouterOption(l.options.Chain),
|
||||||
chain.LoggerRouterOption(l.logger),
|
chain.LoggerRouterOption(l.logger),
|
||||||
@ -116,3 +117,15 @@ func (l *rudpListener) Close() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type bindAddr struct {
|
||||||
|
addr string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *bindAddr) Network() string {
|
||||||
|
return "tcp"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *bindAddr) String() string {
|
||||||
|
return p.addr
|
||||||
|
}
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultTTL = 5 * time.Second
|
defaultTTL = 5 * time.Second
|
||||||
defaultReadBufferSize = 4096
|
defaultReadBufferSize = 1024
|
||||||
defaultReadQueueSize = 1024
|
defaultReadQueueSize = 1024
|
||||||
defaultBacklog = 128
|
defaultBacklog = 128
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user