core/common/net/dialer/dialer.go
2024-06-25 20:37:08 +08:00

172 lines
3.6 KiB
Go

package dialer
import (
"context"
"fmt"
"net"
"runtime"
"strings"
"syscall"
"time"
xnet "github.com/go-gost/core/common/net"
"github.com/go-gost/core/logger"
"github.com/vishvananda/netns"
)
// DefaultTimeout is the package's default timeout, ten seconds.
// Nothing in this file reads it; it is exported for callers.
const DefaultTimeout = 10 * time.Second
// DefaultNetDialer is a zero-value NetDialer: no interface binding, no
// network namespace, no socket mark, no custom dial function and no logger.
// Dial falls back to it when invoked on a nil *NetDialer.
var DefaultNetDialer = new(NetDialer)
// NetDialer dials outbound connections, optionally from inside a named
// network namespace, bound to specific interfaces/addresses, and with a
// socket mark applied. The zero value is usable.
type NetDialer struct {
// Interface is a comma-separated list of entries handed one by one to
// xnet.ParseInterfaceAddr. A trailing "!" on an entry makes it strict:
// Dial then stops at that entry's failure instead of trying the next one.
Interface string
// Netns, when non-empty, is the name of the network namespace (looked up
// with netns.GetFromName) that Dial switches into for the duration of the call.
Netns string
// Mark, when non-zero, is passed to setMark for every socket the dialer creates.
Mark int
// DialFunc, when non-nil, replaces the built-in dialing logic. Dial calls it
// after the namespace switch; Interface and Mark are not applied by Dial in that case.
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
// Logger receives diagnostics; Dial uses logger.Default() when it is nil.
Logger logger.Logger
}
// Dial connects to addr on the given network and returns the connection.
//
// A nil receiver behaves like DefaultNetDialer. When d.Netns is set, the
// calling goroutine is pinned to its OS thread and moved into that network
// namespace for the duration of the call; the original namespace is restored
// before Dial returns. When d.DialFunc is set it performs the dial and the
// interface logic is skipped. "unix" addresses are dialed directly.
//
// Otherwise every entry of the comma-separated d.Interface list is resolved
// with xnet.ParseInterfaceAddr and each resulting local address is tried in
// order until one succeeds. An entry suffixed with "!" is strict: a resolve
// error, or a dial error other than a local-address mismatch, aborts
// immediately instead of moving on to the next entry.
//
// On failure the returned connection is nil and the error is the most recent
// one encountered; Dial never returns (nil, nil).
func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Conn, err error) {
	if d == nil {
		d = DefaultNetDialer
	}

	log := d.Logger
	if log == nil {
		log = logger.Default()
	}

	if d.Netns != "" {
		// Network namespaces are a per-thread property, so the goroutine must
		// stay on this OS thread while the namespace is switched.
		runtime.LockOSThread()
		defer runtime.UnlockOSThread()

		originNs, err := netns.Get()
		if err != nil {
			return nil, fmt.Errorf("netns.Get(): %w", err)
		}
		// Deferred calls run last-in-first-out: the namespace is restored
		// first, then the handle is closed so its descriptor is not leaked.
		defer originNs.Close()
		defer func() {
			if err := netns.Set(originNs); err != nil {
				log.Warnf("netns.Set(origin): %v", err)
			}
		}()

		ns, err := netns.GetFromName(d.Netns)
		if err != nil {
			return nil, fmt.Errorf("netns.GetFromName(%s): %w", d.Netns, err)
		}
		defer ns.Close()

		if err := netns.Set(ns); err != nil {
			return nil, fmt.Errorf("netns.Set(%s): %w", d.Netns, err)
		}
	}

	if d.DialFunc != nil {
		return d.DialFunc(ctx, network, addr)
	}

	if network == "unix" {
		var netd net.Dialer
		return netd.DialContext(ctx, network, addr)
	}

	// lastErr remembers the most recent failure so that a later entry that
	// yields no addresses cannot turn an earlier error into (nil, nil).
	var lastErr error
	for _, ifce := range strings.Split(d.Interface, ",") {
		strict := strings.HasSuffix(ifce, "!")
		ifce = strings.TrimSuffix(ifce, "!")

		ifceName, ifAddrs, parseErr := xnet.ParseInterfaceAddr(ifce, network)
		if parseErr != nil {
			if strict {
				return nil, parseErr
			}
			lastErr = parseErr
		}

		for _, ifAddr := range ifAddrs {
			c, dialErr := d.dialOnce(ctx, network, addr, ifceName, ifAddr, log)
			if dialErr == nil {
				return c, nil
			}
			lastErr = dialErr
			log.Debugf("dial %s %v@%s failed: %s", network, ifAddr, ifceName, dialErr)

			// In strict mode only address-family mismatches are tolerated;
			// those just mean this local address cannot reach the target.
			if strict &&
				!strings.Contains(dialErr.Error(), "no suitable address found") &&
				!strings.Contains(dialErr.Error(), "mismatched local address type") {
				return nil, dialErr
			}
		}
	}

	if lastErr == nil {
		// No entry produced a local address and none reported an error.
		lastErr = fmt.Errorf("dial %s %s: no usable local address for interface %q", network, addr, d.Interface)
	}
	return nil, lastErr
}
// dialOnce makes a single connection attempt from one local address.
//
// ifceName, when non-empty, is the device the socket is bound to via
// bindDevice; ifAddr is the local address to use (nil lets the system pick).
// d.Mark, when non-zero, is applied with setMark. Failures to bind the
// device or set the mark are logged as warnings and do not fail the attempt.
//
// For udp/udp4/udp6 with an empty addr, an unconnected UDP socket listening on
// ifAddr is returned. For udp with a target address and for tcp/tcp4/tcp6 a
// regular net.Dialer is used. Any other network yields an error.
func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string, ifAddr net.Addr, log logger.Logger) (net.Conn, error) {
	if ifceName != "" {
		log.Debugf("interface: %s %v/%s", ifceName, ifAddr, network)
	}

	// setSockOpts applies the device binding and the mark to a raw socket.
	// Both are best effort; it is shared by the listen and the dial path.
	setSockOpts := func(fd uintptr) {
		if ifceName != "" {
			if err := bindDevice(fd, ifceName); err != nil {
				log.Warnf("bind device: %v", err)
			}
		}
		if d.Mark != 0 {
			if err := setMark(fd, d.Mark); err != nil {
				log.Warnf("set mark: %v", err)
			}
		}
	}

	switch network {
	case "udp", "udp4", "udp6":
		if addr == "" {
			// No remote address: hand back an unconnected socket.
			var laddr *net.UDPAddr
			if ifAddr != nil {
				// A non-UDP address leaves laddr nil, i.e. an unspecified local address.
				laddr, _ = ifAddr.(*net.UDPAddr)
			}
			c, err := net.ListenUDP(network, laddr)
			if err != nil {
				return nil, err
			}
			sc, err := c.SyscallConn()
			if err != nil {
				// The socket is not returned to the caller, so release it here.
				c.Close()
				log.Error(err)
				return nil, err
			}
			if err := sc.Control(setSockOpts); err != nil {
				log.Error(err)
			}
			return c, nil
		}
	case "tcp", "tcp4", "tcp6":
	default:
		return nil, fmt.Errorf("dial: unsupported network %s", network)
	}

	netd := net.Dialer{
		LocalAddr: ifAddr,
		Control: func(network, address string, c syscall.RawConn) error {
			return c.Control(setSockOpts)
		},
	}
	if d.Netns != "" {
		// A negative FallbackDelay disables the dual-stack fast-fallback dial.
		// See https://github.com/golang/go/issues/44922#issuecomment-796645858
		netd.FallbackDelay = -1
	}
	return netd.DialContext(ctx, network, addr)
}