diff --git a/chain/route.go b/chain/route.go index f19d9a9..0e09203 100644 --- a/chain/route.go +++ b/chain/route.go @@ -38,6 +38,7 @@ func (*route) Dial(ctx context.Context, network, address string, opts ...DialOpt netd := dialer.NetDialer{ Timeout: options.Timeout, Interface: options.Interface, + Netns: options.Netns, Logger: options.Logger, } if options.SockOpts != nil { @@ -95,6 +96,7 @@ func (r *route) Nodes() []*Node { type DialOptions struct { Timeout time.Duration Interface string + Netns string SockOpts *SockOpts Logger logger.Logger } @@ -113,6 +115,12 @@ func InterfaceDialOption(ifName string) DialOption { } } +func NetnsDialOption(netns string) DialOption { + return func(opts *DialOptions) { + opts.Netns = netns + } +} + func SockOptsDialOption(so *SockOpts) DialOption { return func(opts *DialOptions) { opts.SockOpts = so diff --git a/chain/router.go b/chain/router.go index 62bc9b3..5552525 100644 --- a/chain/router.go +++ b/chain/router.go @@ -21,6 +21,7 @@ type RouterOptions struct { Retries int Timeout time.Duration IfceName string + Netns string SockOpts *SockOpts Chain Chainer Resolver resolver.Resolver @@ -37,6 +38,12 @@ func InterfaceRouterOption(ifceName string) RouterOption { } } +func NetnsRouterOption(netns string) RouterOption { + return func(o *RouterOptions) { + o.Netns = netns + } +} + func SockOptsRouterOption(so *SockOpts) RouterOption { return func(o *RouterOptions) { o.SockOpts = so @@ -181,6 +188,7 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co } conn, err = route.Dial(ctx, network, ipAddr, InterfaceDialOption(r.options.IfceName), + NetnsDialOption(r.options.Netns), SockOptsDialOption(r.options.SockOpts), LoggerDialOption(r.options.Logger), TimeoutDialOption(r.options.Timeout), diff --git a/chain/transport.go b/chain/transport.go index 0b3c232..d68e94f 100644 --- a/chain/transport.go +++ b/chain/transport.go @@ -13,6 +13,7 @@ import ( type TransportOptions struct { Addr string IfceName string + Netns string SockOpts *SockOpts Route Route Timeout time.Duration @@ -32,6 +33,12 @@ func InterfaceTransportOption(ifceName string) TransportOption { } } +func NetnsTransportOption(netns string) TransportOption { + return func(o *TransportOptions) { + o.Netns = netns + } +} + func SockOptsTransportOption(so *SockOpts) TransportOption { return func(o *TransportOptions) { o.SockOpts = so @@ -73,6 +80,7 @@ func NewTransport(d dialer.Dialer, c connector.Connector, opts ...TransportOptio func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) { netd := &net_dialer.NetDialer{ Interface: tr.options.IfceName, + Netns: tr.options.Netns, Timeout: tr.options.Timeout, } if tr.options.SockOpts != nil { @@ -108,6 +116,7 @@ func (tr *Transport) Handshake(ctx context.Context, conn net.Conn) (net.Conn, er func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) { netd := &net_dialer.NetDialer{ Interface: tr.options.IfceName, + Netns: tr.options.Netns, Timeout: tr.options.Timeout, } if tr.options.SockOpts != nil { diff --git a/common/net/dialer/dialer.go b/common/net/dialer/dialer.go index 6112e4c..0b63ff2 100644 --- a/common/net/dialer/dialer.go +++ b/common/net/dialer/dialer.go @@ -4,12 +4,14 @@ import ( "context" "fmt" "net" + "runtime" "strings" "syscall" "time" xnet "github.com/go-gost/core/common/net" "github.com/go-gost/core/logger" + "github.com/vishvananda/netns" ) const ( @@ -22,6 +24,7 @@ var ( type NetDialer struct { Interface string + Netns string Mark int Timeout time.Duration DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) @@ -33,6 +36,32 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co d = DefaultNetDialer } + log := d.Logger + if log == nil { + log = logger.Default() + } + + if d.Netns != "" { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + originNs, err := netns.Get() + if err != nil { + return nil, fmt.Errorf("netns.Get(): %v", err) + } + defer netns.Set(originNs) + + ns, err := netns.GetFromName(d.Netns) + if err != nil { + return nil, fmt.Errorf("netns.GetFromName(%s): %v", d.Netns, err) + } + defer ns.Close() + + if err := netns.Set(ns); err != nil { + return nil, fmt.Errorf("netns.Set(%s): %v", d.Netns, err) + } + } + timeout := d.Timeout if timeout <= 0 { timeout = DefaultTimeout @@ -42,11 +71,6 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co return d.DialFunc(ctx, network, addr) } - log := d.Logger - if log == nil { - log = logger.Default() - } - switch network { case "unix": netd := net.Dialer{} @@ -150,5 +174,10 @@ func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string }) }, } + if d.Netns != "" { + // https://github.com/golang/go/issues/44922#issuecomment-796645858 + netd.FallbackDelay = -1 + } + return netd.DialContext(ctx, network, addr) } diff --git a/go.mod b/go.mod index 66f5e02..b9bffba 100644 --- a/go.mod +++ b/go.mod @@ -6,5 +6,6 @@ toolchain go1.22.2 require ( github.com/go-gost/x v0.0.0-20240131151842-25dcf536c6f5 - golang.org/x/sys v0.18.0 + github.com/vishvananda/netns v0.0.4 + golang.org/x/sys v0.21.0 ) diff --git a/go.sum b/go.sum index f2775ca..87ed461 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ github.com/go-gost/x v0.0.0-20240131151842-25dcf536c6f5 h1:IiZLdqGMx0lGVbDBy/N9LPu10qSlxm939EBvZ77qJNI= github.com/go-gost/x v0.0.0-20240131151842-25dcf536c6f5/go.mod h1:FDqjiiPbCqJLU/wY+q2IZCBVcYnfTJTw+SJLrspLQms= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= +github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=