diff --git a/common/net/dialer/dialer.go b/common/net/dialer/dialer.go index c86a494..0a85e32 100644 --- a/common/net/dialer/dialer.go +++ b/common/net/dialer/dialer.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "strings" "syscall" "time" @@ -24,28 +25,10 @@ type NetDialer struct { Timeout time.Duration DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) Logger logger.Logger + deadline time.Time } -func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) { - if d == nil { - d = DefaultNetDialer - } - if d.Timeout <= 0 { - d.Timeout = DefaultTimeout - } - - log := d.Logger - if log == nil { - log = logger.Default() - } - - ifceName, ifAddr, err := parseInterfaceAddr(d.Interface, network) - if err != nil { - return nil, err - } - if d.DialFunc != nil { - return d.DialFunc(ctx, network, addr) - } +func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string, ifAddr net.Addr, log logger.Logger) (net.Conn, error) { if ifceName != "" { log.Debugf("interface: %s %v/%s", ifceName, ifAddr, network) } @@ -89,7 +72,7 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, e return nil, fmt.Errorf("dial: unsupported network %s", network) } netd := net.Dialer{ - Timeout: d.Timeout, + Deadline: d.deadline, LocalAddr: ifAddr, Control: func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { @@ -109,8 +92,76 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, e return netd.DialContext(ctx, network, addr) } -func parseInterfaceAddr(ifceName, network string) (ifce string, addr net.Addr, err error) { +func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Conn, err error) { + if d == nil { + d = DefaultNetDialer + } + if d.Timeout <= 0 { + d.Timeout = DefaultTimeout + } + + if d.DialFunc != nil { + return d.DialFunc(ctx, network, addr) + } + + log := d.Logger + if log == nil { + log = logger.Default() + } + + ifces := strings.Split(d.Interface, ",") + d.deadline = time.Now().Add(d.Timeout) + for _, ifce := range ifces { + strict := strings.HasSuffix(ifce, "!") + ifce = strings.TrimSuffix(ifce, "!") + var ifceName string + var ifAddrs []net.Addr + ifceName, ifAddrs, err = parseInterfaceAddr(ifce, network) + if err != nil && strict { + return + } + + for _, ifAddr := range ifAddrs { + conn, err = d.dialOnce(ctx, network, addr, ifceName, ifAddr, log) + if err == nil { + return + } + + log.Debugf("dial %s %v@%s failed: %s", network, ifAddr, ifceName, err) + + if strict && + !strings.Contains(err.Error(), "no suitable address found") && + !strings.Contains(err.Error(), "mismatched local address type") { + return + } + + if time.Until(d.deadline) < 0 { + return + } + } + } + + return +} + +func ipToAddr(ip net.IP, network string) (addr net.Addr) { + port := 0 + switch network { + case "tcp", "tcp4", "tcp6": + addr = &net.TCPAddr{IP: ip, Port: port} + return + case "udp", "udp4", "udp6": + addr = &net.UDPAddr{IP: ip, Port: port} + return + default: + addr = &net.IPAddr{IP: ip} + return + } +} + +func parseInterfaceAddr(ifceName, network string) (ifce string, addr []net.Addr, err error) { if ifceName == "" { + addr = append(addr, nil) return } @@ -130,27 +181,21 @@ func parseInterfaceAddr(ifceName, network string) (ifce string, addr net.Addr, e err = fmt.Errorf("addr not found for interface %s", ifceName) return } - ip = addrs[0].(*net.IPNet).IP ifce = ifceName + for _, addr_ := range addrs { + if ipNet, ok := addr_.(*net.IPNet); ok { + addr = append(addr, ipToAddr(ipNet.IP, network)) + } + } } else { ifce, err = findInterfaceByIP(ip) if err != nil { return } + addr = []net.Addr{ipToAddr(ip, network)} } - port := 0 - switch network { - case "tcp", "tcp4", "tcp6": - addr = &net.TCPAddr{IP: ip, Port: port} - return - case "udp", "udp4", "udp6": - addr = &net.UDPAddr{IP: ip, Port: port} - return - default: - addr = &net.IPAddr{IP: ip} - return - } + return } func findInterfaceByIP(ip net.IP) (string, error) {