update common/net

This commit is contained in:
ginuerzh
2022-12-01 19:22:14 +08:00
parent 8effb0b8b9
commit c9c31fa74c
2 changed files with 137 additions and 130 deletions

84
common/net/addr.go Normal file
View File

@ -0,0 +1,84 @@
package net
import (
"fmt"
"net"
)
func ParseInterfaceAddr(ifceName, network string) (ifce string, addr []net.Addr, err error) {
if ifceName == "" {
addr = append(addr, nil)
return
}
ip := net.ParseIP(ifceName)
if ip == nil {
var ife *net.Interface
ife, err = net.InterfaceByName(ifceName)
if err != nil {
return
}
var addrs []net.Addr
addrs, err = ife.Addrs()
if err != nil {
return
}
if len(addrs) == 0 {
err = fmt.Errorf("addr not found for interface %s", ifceName)
return
}
ifce = ifceName
for _, addr_ := range addrs {
if ipNet, ok := addr_.(*net.IPNet); ok {
addr = append(addr, ipToAddr(ipNet.IP, network))
}
}
} else {
ifce, err = findInterfaceByIP(ip)
if err != nil {
return
}
addr = []net.Addr{ipToAddr(ip, network)}
}
return
}
func ipToAddr(ip net.IP, network string) (addr net.Addr) {
port := 0
switch network {
case "tcp", "tcp4", "tcp6":
addr = &net.TCPAddr{IP: ip, Port: port}
return
case "udp", "udp4", "udp6":
addr = &net.UDPAddr{IP: ip, Port: port}
return
default:
addr = &net.IPAddr{IP: ip}
return
}
}
func findInterfaceByIP(ip net.IP) (string, error) {
ifces, err := net.Interfaces()
if err != nil {
return "", err
}
for _, ifce := range ifces {
addrs, _ := ifce.Addrs()
if len(addrs) == 0 {
continue
}
for _, addr := range addrs {
ipAddr, _ := addr.(*net.IPNet)
if ipAddr == nil {
continue
}
// logger.Default().Infof("%s-%s", ipAddr, ip)
if ipAddr.IP.Equal(ip) {
return ifce.Name, nil
}
}
}
return "", nil
}

View File

@ -8,6 +8,7 @@ import (
"syscall"
"time"
xnet "github.com/go-gost/core/common/net"
"github.com/go-gost/core/logger"
)
@ -28,6 +29,58 @@ type NetDialer struct {
deadline time.Time
}
func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Conn, err error) {
if d == nil {
d = DefaultNetDialer
}
if d.Timeout <= 0 {
d.Timeout = DefaultTimeout
}
if d.DialFunc != nil {
return d.DialFunc(ctx, network, addr)
}
log := d.Logger
if log == nil {
log = logger.Default()
}
ifces := strings.Split(d.Interface, ",")
d.deadline = time.Now().Add(d.Timeout)
for _, ifce := range ifces {
strict := strings.HasSuffix(ifce, "!")
ifce = strings.TrimSuffix(ifce, "!")
var ifceName string
var ifAddrs []net.Addr
ifceName, ifAddrs, err = xnet.ParseInterfaceAddr(ifce, network)
if err != nil && strict {
return
}
for _, ifAddr := range ifAddrs {
conn, err = d.dialOnce(ctx, network, addr, ifceName, ifAddr, log)
if err == nil {
return
}
log.Debugf("dial %s %v@%s failed: %s", network, ifAddr, ifceName, err)
if strict &&
!strings.Contains(err.Error(), "no suitable address found") &&
!strings.Contains(err.Error(), "mismatched local address type") {
return
}
if time.Until(d.deadline) < 0 {
return
}
}
}
return
}
func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string, ifAddr net.Addr, log logger.Logger) (net.Conn, error) {
if ifceName != "" {
log.Debugf("interface: %s %v/%s", ifceName, ifAddr, network)
@ -91,133 +144,3 @@ func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string
}
return netd.DialContext(ctx, network, addr)
}
func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Conn, err error) {
if d == nil {
d = DefaultNetDialer
}
if d.Timeout <= 0 {
d.Timeout = DefaultTimeout
}
if d.DialFunc != nil {
return d.DialFunc(ctx, network, addr)
}
log := d.Logger
if log == nil {
log = logger.Default()
}
ifces := strings.Split(d.Interface, ",")
d.deadline = time.Now().Add(d.Timeout)
for _, ifce := range ifces {
strict := strings.HasSuffix(ifce, "!")
ifce = strings.TrimSuffix(ifce, "!")
var ifceName string
var ifAddrs []net.Addr
ifceName, ifAddrs, err = parseInterfaceAddr(ifce, network)
if err != nil && strict {
return
}
for _, ifAddr := range ifAddrs {
conn, err = d.dialOnce(ctx, network, addr, ifceName, ifAddr, log)
if err == nil {
return
}
log.Debugf("dial %s %v@%s failed: %s", network, ifAddr, ifceName, err)
if strict &&
!strings.Contains(err.Error(), "no suitable address found") &&
!strings.Contains(err.Error(), "mismatched local address type") {
return
}
if time.Until(d.deadline) < 0 {
return
}
}
}
return
}
func ipToAddr(ip net.IP, network string) (addr net.Addr) {
port := 0
switch network {
case "tcp", "tcp4", "tcp6":
addr = &net.TCPAddr{IP: ip, Port: port}
return
case "udp", "udp4", "udp6":
addr = &net.UDPAddr{IP: ip, Port: port}
return
default:
addr = &net.IPAddr{IP: ip}
return
}
}
func parseInterfaceAddr(ifceName, network string) (ifce string, addr []net.Addr, err error) {
if ifceName == "" {
addr = append(addr, nil)
return
}
ip := net.ParseIP(ifceName)
if ip == nil {
var ife *net.Interface
ife, err = net.InterfaceByName(ifceName)
if err != nil {
return
}
var addrs []net.Addr
addrs, err = ife.Addrs()
if err != nil {
return
}
if len(addrs) == 0 {
err = fmt.Errorf("addr not found for interface %s", ifceName)
return
}
ifce = ifceName
for _, addr_ := range addrs {
if ipNet, ok := addr_.(*net.IPNet); ok {
addr = append(addr, ipToAddr(ipNet.IP, network))
}
}
} else {
ifce, err = findInterfaceByIP(ip)
if err != nil {
return
}
addr = []net.Addr{ipToAddr(ip, network)}
}
return
}
func findInterfaceByIP(ip net.IP) (string, error) {
ifces, err := net.Interfaces()
if err != nil {
return "", err
}
for _, ifce := range ifces {
addrs, _ := ifce.Addrs()
if len(addrs) == 0 {
continue
}
for _, addr := range addrs {
ipAddr, _ := addr.(*net.IPNet)
if ipAddr == nil {
continue
}
// logger.Default().Infof("%s-%s", ipAddr, ip)
if ipAddr.IP.Equal(ip) {
return ifce.Name, nil
}
}
}
return "", nil
}