add support for linux network namespace

This commit is contained in:
ginuerzh
2024-06-21 23:38:18 +08:00
parent 15f28c667a
commit 2ae0462822
9 changed files with 127 additions and 8 deletions

View File

@ -2,6 +2,7 @@ package service
import (
"fmt"
"runtime"
"strings"
"time"
@ -32,6 +33,7 @@ import (
"github.com/go-gost/x/registry"
xservice "github.com/go-gost/x/service"
"github.com/go-gost/x/stats"
"github.com/vishvananda/netns"
)
func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
@ -104,6 +106,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
var ignoreChain bool
var pStats *stats.Stats
var observePeriod time.Duration
var netnsIn, netnsOut string
if cfg.Metadata != nil {
md := metadata.NewMetadata(cfg.Metadata)
ppv = mdutil.GetInt(md, parsing.MDKeyProxyProtocol)
@ -125,6 +128,8 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
pStats = &stats.Stats{}
}
observePeriod = mdutil.GetDuration(md, "observePeriod")
netnsIn = mdutil.GetString(md, "netns")
netnsOut = mdutil.GetString(md, "netns.out")
}
listenOpts := []listener.Option{
@ -146,6 +151,27 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
)
}
if netnsIn != "" {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
originNs, err := netns.Get()
if err != nil {
return nil, fmt.Errorf("netns.Get(): %v", err)
}
defer netns.Set(originNs)
ns, err := netns.GetFromName(netnsIn)
if err != nil {
return nil, fmt.Errorf("netns.GetFromName(%s): %v", netnsIn, err)
}
defer ns.Close()
if err := netns.Set(ns); err != nil {
return nil, fmt.Errorf("netns.Set(%s): %v", netnsIn, err)
}
}
var ln listener.Listener
if rf := registry.ListenerRegistry().Get(cfg.Listener.Type); rf != nil {
ln = rf(listenOpts...)
@ -209,6 +235,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
chain.RetriesRouterOption(cfg.Handler.Retries),
// chain.TimeoutRouterOption(10*time.Second),
chain.InterfaceRouterOption(ifce),
chain.NetnsRouterOption(netnsOut),
chain.SockOptsRouterOption(sockOpts),
chain.ResolverRouterOption(registry.ResolverRegistry().Get(cfg.Resolver)),
chain.HostMapperRouterOption(registry.HostsRegistry().Get(cfg.Hosts)),