From 96f4d7bf5ce466f2ebebaae3c8a63912cd75b573 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Mon, 8 Jul 2024 10:59:32 +0800 Subject: [PATCH] netns: fix network namespaces for listeners --- connector/socks/v5/connector.go | 6 + connector/socks/v5/metadata.go | 17 +-- listener/dns/listener.go | 156 +++++++++++++++++++----- listener/dns/metadata.go | 2 + listener/dns/server.go | 38 ++++-- listener/http3/listener.go | 17 ++- listener/http3/wt/listener.go | 49 +++++--- listener/quic/listener.go | 5 +- listener/redirect/udp/listener_linux.go | 30 +++++ 9 files changed, 246 insertions(+), 74 deletions(-) diff --git a/connector/socks/v5/connector.go b/connector/socks/v5/connector.go index 5844708..d52e795 100644 --- a/connector/socks/v5/connector.go +++ b/connector/socks/v5/connector.go @@ -213,8 +213,14 @@ func (c *socks5Connector) relayUDP(ctx context.Context, conn net.Conn, addr net. cc, err := opts.NetDialer.Dial(ctx, "udp", reply.Addr.String()) if err != nil { + c.options.Logger.Error(err) return nil, err } + log.Debugf("%s <- %s -> %s", cc.LocalAddr(), cc.RemoteAddr(), addr) + + if c.md.udpTimeout > 0 { + cc.SetReadDeadline(time.Now().Add(c.md.udpTimeout)) + } return &udpRelayConn{ udpConn: cc.(*net.UDPConn), diff --git a/connector/socks/v5/metadata.go b/connector/socks/v5/metadata.go index 0ee0616..e6dba39 100644 --- a/connector/socks/v5/metadata.go +++ b/connector/socks/v5/metadata.go @@ -17,24 +17,19 @@ type metadata struct { noTLS bool relay string udpBufferSize int + udpTimeout time.Duration muxCfg *mux.Config } func (c *socks5Connector) parseMetadata(md mdata.Metadata) (err error) { - const ( - connectTimeout = "timeout" - noTLS = "notls" - relay = "relay" - udpBufferSize = "udpBufferSize" - ) - - c.md.connectTimeout = mdutil.GetDuration(md, connectTimeout) - c.md.noTLS = mdutil.GetBool(md, noTLS) - c.md.relay = mdutil.GetString(md, relay) - c.md.udpBufferSize = mdutil.GetInt(md, udpBufferSize) + c.md.connectTimeout = mdutil.GetDuration(md, "timeout") + c.md.noTLS = mdutil.GetBool(md, "notls") + c.md.relay = mdutil.GetString(md, "relay") + c.md.udpBufferSize = mdutil.GetInt(md, "udp.bufferSize", "udpBufferSize") if c.md.udpBufferSize <= 0 { c.md.udpBufferSize = defaultUDPBufferSize } + c.md.udpTimeout = mdutil.GetDuration(md, "udp.timeout") c.md.muxCfg = &mux.Config{ Version: mdutil.GetInt(md, "mux.version"), diff --git a/listener/dns/listener.go b/listener/dns/listener.go index 8f8ef17..a627679 100644 --- a/listener/dns/listener.go +++ b/listener/dns/listener.go @@ -2,6 +2,8 @@ package dns import ( "bytes" + "context" + "crypto/tls" "encoding/base64" "errors" "io" @@ -9,12 +11,12 @@ import ( "net/http" "strings" - admission "github.com/go-gost/x/admission/wrapper" - limiter "github.com/go-gost/x/limiter/traffic/wrapper" - "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" + admission "github.com/go-gost/x/admission/wrapper" + xnet "github.com/go-gost/x/internal/net" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" metrics "github.com/go-gost/x/metrics/wrapper" stats "github.com/go-gost/x/observer/stats/wrapper" "github.com/go-gost/x/registry" @@ -51,48 +53,144 @@ func (l *dnsListener) Init(md md.Metadata) (err error) { return } - l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr) - if err != nil { - return err - } - switch strings.ToLower(l.md.mode) { case "tcp": - l.server = &dns.Server{ - Net: "tcp", - Addr: l.options.Addr, - Handler: l, - ReadTimeout: l.md.readTimeout, - WriteTimeout: l.md.writeTimeout, + l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr) + if err != nil { + return } + + network := "tcp" + if xnet.IsIPv4(l.options.Addr) { + network = "tcp4" + } + + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + + var ln net.Listener + ln, err = lc.Listen(context.Background(), network, l.options.Addr) + if err != nil { + return + } + l.server = &dnsServer{ + server: &dns.Server{ + Net: "tcp", + Addr: l.options.Addr, + Listener: ln, + Handler: l, + ReadTimeout: l.md.readTimeout, + WriteTimeout: l.md.writeTimeout, + }, + } + case "tls": - l.server = &dns.Server{ - Net: "tcp-tls", - Addr: l.options.Addr, - Handler: l, - TLSConfig: l.options.TLSConfig, - ReadTimeout: l.md.readTimeout, - WriteTimeout: l.md.writeTimeout, + l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr) + if err != nil { + return } + + network := "tcp" + if xnet.IsIPv4(l.options.Addr) { + network = "tcp4" + } + + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + + var ln net.Listener + ln, err = lc.Listen(context.Background(), network, l.options.Addr) + if err != nil { + return + } + ln = tls.NewListener(ln, l.options.TLSConfig) + + l.server = &dnsServer{ + server: &dns.Server{ + Net: "tcp-tls", + Addr: l.options.Addr, + Listener: ln, + Handler: l, + TLSConfig: l.options.TLSConfig, + ReadTimeout: l.md.readTimeout, + WriteTimeout: l.md.writeTimeout, + }, + } + case "https": + l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr) + if err != nil { + return + } + + network := "tcp" + if xnet.IsIPv4(l.options.Addr) { + network = "tcp4" + } + + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + + var ln net.Listener + ln, err = lc.Listen(context.Background(), network, l.options.Addr) + if err != nil { + return + } + ln = tls.NewListener(ln, l.options.TLSConfig) + l.server = &dohServer{ addr: l.options.Addr, tlsConfig: l.options.TLSConfig, + listener: ln, server: &http.Server{ Handler: l, ReadTimeout: l.md.readTimeout, WriteTimeout: l.md.writeTimeout, }, } + default: l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr) - l.server = &dns.Server{ - Net: "udp", - Addr: l.options.Addr, - Handler: l, - UDPSize: l.md.readBufferSize, - ReadTimeout: l.md.readTimeout, - WriteTimeout: l.md.writeTimeout, + if err != nil { + return + } + + network := "udp" + if xnet.IsIPv4(l.options.Addr) { + network = "udp4" + } + + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + + var pc net.PacketConn + pc, err = lc.ListenPacket(context.Background(), network, l.options.Addr) + if err != nil { + return + } + + l.server = &dnsServer{ + server: &dns.Server{ + Net: "udp", + Addr: l.options.Addr, + PacketConn: pc, + Handler: l, + UDPSize: l.md.readBufferSize, + ReadTimeout: l.md.readTimeout, + WriteTimeout: l.md.writeTimeout, + }, } } @@ -104,7 +202,7 @@ func (l *dnsListener) Init(md md.Metadata) (err error) { l.errChan = make(chan error, 1) go func() { - err := l.server.ListenAndServe() + err := l.server.Serve() if err != nil { l.errChan <- err } diff --git a/listener/dns/metadata.go b/listener/dns/metadata.go index 9097f4c..187e670 100644 --- a/listener/dns/metadata.go +++ b/listener/dns/metadata.go @@ -17,6 +17,7 @@ type metadata struct { readTimeout time.Duration writeTimeout time.Duration backlog int + mptcp bool } func (l *dnsListener) parseMetadata(md mdata.Metadata) (err error) { @@ -37,6 +38,7 @@ func (l *dnsListener) parseMetadata(md mdata.Metadata) (err error) { if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } + l.md.mptcp = mdutil.GetBool(md, "mptcp") return } diff --git a/listener/dns/server.go b/listener/dns/server.go index 4e60f6d..2c0c716 100644 --- a/listener/dns/server.go +++ b/listener/dns/server.go @@ -10,29 +10,47 @@ import ( "time" xnet "github.com/go-gost/x/internal/net" + "github.com/miekg/dns" ) type Server interface { - ListenAndServe() error + Serve() error Shutdown() error } +type dnsServer struct { + server *dns.Server +} + +func (s *dnsServer) Serve() error { + return s.server.ActivateAndServe() +} + +func (s *dnsServer) Shutdown() error { + return s.server.Shutdown() +} + type dohServer struct { addr string tlsConfig *tls.Config + listener net.Listener server *http.Server } -func (s *dohServer) ListenAndServe() error { - network := "tcp" - if xnet.IsIPv4(s.addr) { - network = "tcp4" +func (s *dohServer) Serve() error { + var err error + ln := s.listener + if ln == nil { + network := "tcp" + if xnet.IsIPv4(s.addr) { + network = "tcp4" + } + ln, err = net.Listen(network, s.addr) + if err != nil { + return err + } + ln = tls.NewListener(ln, s.tlsConfig) } - ln, err := net.Listen(network, s.addr) - if err != nil { - return err - } - ln = tls.NewListener(ln, s.tlsConfig) return s.server.Serve(ln) } diff --git a/listener/http3/listener.go b/listener/http3/listener.go index e263c53..4b193a8 100644 --- a/listener/http3/listener.go +++ b/listener/http3/listener.go @@ -46,11 +46,16 @@ func (l *http3Listener) Init(md md.Metadata) (err error) { return } + addr := l.options.Addr + if addr == "" { + addr = ":https" + } + network := "udp" - if xnet.IsIPv4(l.options.Addr) { + if xnet.IsIPv4(addr) { network = "udp4" } - l.addr, err = net.ResolveUDPAddr(network, l.options.Addr) + l.addr, err = net.ResolveUDPAddr(network, addr) if err != nil { return } @@ -66,15 +71,21 @@ func (l *http3Listener) Init(md md.Metadata) (err error) { quic.Version1, }, MaxIncomingStreams: int64(l.md.maxStreams), + Allow0RTT: true, }, Handler: http.HandlerFunc(l.handleFunc), } + ln, err := quic.ListenAddrEarly(addr, http3.ConfigureTLSConfig(l.server.TLSConfig), l.server.QUICConfig.Clone()) + if err != nil { + return + } + l.cqueue = make(chan net.Conn, l.md.backlog) l.errChan = make(chan error, 1) go func() { - if err := l.server.ListenAndServe(); err != nil { + if err := l.server.ServeListener(ln); err != nil { l.logger.Error(err) } }() diff --git a/listener/http3/wt/listener.go b/listener/http3/wt/listener.go index 118b10c..8a0d0a9 100644 --- a/listener/http3/wt/listener.go +++ b/listener/http3/wt/listener.go @@ -50,11 +50,22 @@ func (l *wtListener) Init(md md.Metadata) (err error) { return } + addr := l.options.Addr + if addr == "" { + addr = ":https" + } + network := "udp" - if xnet.IsIPv4(l.options.Addr) { + if xnet.IsIPv4(addr) { network = "udp4" } - l.addr, err = net.ResolveUDPAddr(network, l.options.Addr) + laddr, err := net.ResolveUDPAddr(network, addr) + if err != nil { + return + } + l.addr = laddr + + pc, err := net.ListenUDP(network, laddr) if err != nil { return } @@ -62,23 +73,25 @@ func (l *wtListener) Init(md md.Metadata) (err error) { mux := http.NewServeMux() mux.Handle(l.md.path, http.HandlerFunc(l.upgrade)) + quicCfg := &quic.Config{ + KeepAlivePeriod: l.md.keepAlivePeriod, + HandshakeIdleTimeout: l.md.handshakeTimeout, + MaxIdleTimeout: l.md.maxIdleTimeout, + /* + Versions: []quic.VersionNumber{ + quic.Version1, + quic.Version2, + }, + */ + MaxIncomingStreams: int64(l.md.maxStreams), + Allow0RTT: true, + } l.srv = &wt.Server{ H3: http3.Server{ - Addr: l.options.Addr, - TLSConfig: l.options.TLSConfig, - QUICConfig: &quic.Config{ - KeepAlivePeriod: l.md.keepAlivePeriod, - HandshakeIdleTimeout: l.md.handshakeTimeout, - MaxIdleTimeout: l.md.maxIdleTimeout, - /* - Versions: []quic.VersionNumber{ - quic.Version1, - quic.Version2, - }, - */ - MaxIncomingStreams: int64(l.md.maxStreams), - }, - Handler: mux, + Addr: l.options.Addr, + TLSConfig: l.options.TLSConfig, + QUICConfig: quicCfg, + Handler: mux, }, CheckOrigin: func(r *http.Request) bool { return true }, } @@ -87,7 +100,7 @@ func (l *wtListener) Init(md md.Metadata) (err error) { l.errChan = make(chan error, 1) go func() { - if err := l.srv.ListenAndServe(); err != nil { + if err := l.srv.Serve(pc); err != nil { l.logger.Error(err) } }() diff --git a/listener/quic/listener.go b/listener/quic/listener.go index 3604fef..1c8a070 100644 --- a/listener/quic/listener.go +++ b/listener/quic/listener.go @@ -52,11 +52,10 @@ func (l *quicListener) Init(md md.Metadata) (err error) { } network := "udp" - if xnet.IsIPv4(l.options.Addr) { + if xnet.IsIPv4(addr) { network = "udp4" } - var laddr *net.UDPAddr - laddr, err = net.ResolveUDPAddr(network, addr) + laddr, err := net.ResolveUDPAddr(network, addr) if err != nil { return } diff --git a/listener/redirect/udp/listener_linux.go b/listener/redirect/udp/listener_linux.go index b4c6051..ba2b779 100644 --- a/listener/redirect/udp/listener_linux.go +++ b/listener/redirect/udp/listener_linux.go @@ -7,12 +7,15 @@ import ( "fmt" "net" "os" + "runtime" "strconv" + "strings" "syscall" "unsafe" "github.com/go-gost/core/common/bufpool" xnet "github.com/go-gost/x/internal/net" + "github.com/vishvananda/netns" "golang.org/x/sys/unix" ) @@ -53,6 +56,33 @@ func (l *redirectListener) accept() (conn net.Conn, err error) { l.logger.Infof("%s >> %s", raddr.String(), dstAddr.String()) + if l.options.Netns != "" { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + originNs, err := netns.Get() + if err != nil { + return nil, fmt.Errorf("netns.Get(): %v", err) + } + defer netns.Set(originNs) + + var ns netns.NsHandle + + if strings.HasPrefix(l.options.Netns, "/") { + ns, err = netns.GetFromPath(l.options.Netns) + } else { + ns, err = netns.GetFromName(l.options.Netns) + } + if err != nil { + return nil, fmt.Errorf("netns.Get(%s): %v", l.options.Netns, err) + } + defer ns.Close() + + if err := netns.Set(ns); err != nil { + return nil, fmt.Errorf("netns.Set(%s): %v", l.options.Netns, err) + } + } + network := "udp" if xnet.IsIPv4(l.options.Addr) { network = "udp4"