add support for linux network namespace

This commit is contained in:
ginuerzh
2024-06-21 23:34:12 +08:00
parent 8d554ddcf7
commit 5aede9a2b3
6 changed files with 65 additions and 8 deletions

View File

@ -4,12 +4,14 @@ import (
"context"
"fmt"
"net"
"runtime"
"strings"
"syscall"
"time"
xnet "github.com/go-gost/core/common/net"
"github.com/go-gost/core/logger"
"github.com/vishvananda/netns"
)
const (
@ -22,6 +24,7 @@ var (
type NetDialer struct {
Interface string
Netns string
Mark int
Timeout time.Duration
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
@ -33,6 +36,32 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co
d = DefaultNetDialer
}
log := d.Logger
if log == nil {
log = logger.Default()
}
if d.Netns != "" {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
originNs, err := netns.Get()
if err != nil {
return nil, fmt.Errorf("netns.Get(): %v", err)
}
defer netns.Set(originNs)
ns, err := netns.GetFromName(d.Netns)
if err != nil {
return nil, fmt.Errorf("netns.GetFromName(%s): %v", d.Netns, err)
}
defer ns.Close()
if err := netns.Set(ns); err != nil {
return nil, fmt.Errorf("netns.Set(%s): %v", d.Netns, err)
}
}
timeout := d.Timeout
if timeout <= 0 {
timeout = DefaultTimeout
@ -42,11 +71,6 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co
return d.DialFunc(ctx, network, addr)
}
log := d.Logger
if log == nil {
log = logger.Default()
}
switch network {
case "unix":
netd := net.Dialer{}
@ -150,5 +174,10 @@ func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string
})
},
}
if d.Netns != "" {
// https://github.com/golang/go/issues/44922#issuecomment-796645858
netd.FallbackDelay = -1
}
return netd.DialContext(ctx, network, addr)
}