netns: fix network namespaces for listeners

This commit is contained in:
ginuerzh 2024-07-08 10:59:32 +08:00
parent 949c98adc0
commit 96f4d7bf5c
9 changed files with 246 additions and 74 deletions

View File

@ -213,8 +213,14 @@ func (c *socks5Connector) relayUDP(ctx context.Context, conn net.Conn, addr net.
cc, err := opts.NetDialer.Dial(ctx, "udp", reply.Addr.String()) cc, err := opts.NetDialer.Dial(ctx, "udp", reply.Addr.String())
if err != nil { if err != nil {
c.options.Logger.Error(err)
return nil, err return nil, err
} }
log.Debugf("%s <- %s -> %s", cc.LocalAddr(), cc.RemoteAddr(), addr)
if c.md.udpTimeout > 0 {
cc.SetReadDeadline(time.Now().Add(c.md.udpTimeout))
}
return &udpRelayConn{ return &udpRelayConn{
udpConn: cc.(*net.UDPConn), udpConn: cc.(*net.UDPConn),

View File

@ -17,24 +17,19 @@ type metadata struct {
noTLS bool noTLS bool
relay string relay string
udpBufferSize int udpBufferSize int
udpTimeout time.Duration
muxCfg *mux.Config muxCfg *mux.Config
} }
func (c *socks5Connector) parseMetadata(md mdata.Metadata) (err error) { func (c *socks5Connector) parseMetadata(md mdata.Metadata) (err error) {
const ( c.md.connectTimeout = mdutil.GetDuration(md, "timeout")
connectTimeout = "timeout" c.md.noTLS = mdutil.GetBool(md, "notls")
noTLS = "notls" c.md.relay = mdutil.GetString(md, "relay")
relay = "relay" c.md.udpBufferSize = mdutil.GetInt(md, "udp.bufferSize", "udpBufferSize")
udpBufferSize = "udpBufferSize"
)
c.md.connectTimeout = mdutil.GetDuration(md, connectTimeout)
c.md.noTLS = mdutil.GetBool(md, noTLS)
c.md.relay = mdutil.GetString(md, relay)
c.md.udpBufferSize = mdutil.GetInt(md, udpBufferSize)
if c.md.udpBufferSize <= 0 { if c.md.udpBufferSize <= 0 {
c.md.udpBufferSize = defaultUDPBufferSize c.md.udpBufferSize = defaultUDPBufferSize
} }
c.md.udpTimeout = mdutil.GetDuration(md, "udp.timeout")
c.md.muxCfg = &mux.Config{ c.md.muxCfg = &mux.Config{
Version: mdutil.GetInt(md, "mux.version"), Version: mdutil.GetInt(md, "mux.version"),

View File

@ -2,6 +2,8 @@ package dns
import ( import (
"bytes" "bytes"
"context"
"crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
"io" "io"
@ -9,12 +11,12 @@ import (
"net/http" "net/http"
"strings" "strings"
admission "github.com/go-gost/x/admission/wrapper"
limiter "github.com/go-gost/x/limiter/traffic/wrapper"
"github.com/go-gost/core/listener" "github.com/go-gost/core/listener"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
md "github.com/go-gost/core/metadata" md "github.com/go-gost/core/metadata"
admission "github.com/go-gost/x/admission/wrapper"
xnet "github.com/go-gost/x/internal/net"
limiter "github.com/go-gost/x/limiter/traffic/wrapper"
metrics "github.com/go-gost/x/metrics/wrapper" metrics "github.com/go-gost/x/metrics/wrapper"
stats "github.com/go-gost/x/observer/stats/wrapper" stats "github.com/go-gost/x/observer/stats/wrapper"
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
@ -51,48 +53,144 @@ func (l *dnsListener) Init(md md.Metadata) (err error) {
return return
} }
l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
if err != nil {
return err
}
switch strings.ToLower(l.md.mode) { switch strings.ToLower(l.md.mode) {
case "tcp": case "tcp":
l.server = &dns.Server{ l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
Net: "tcp", if err != nil {
Addr: l.options.Addr, return
Handler: l,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
} }
network := "tcp"
if xnet.IsIPv4(l.options.Addr) {
network = "tcp4"
}
lc := net.ListenConfig{}
if l.md.mptcp {
lc.SetMultipathTCP(true)
l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP())
}
var ln net.Listener
ln, err = lc.Listen(context.Background(), network, l.options.Addr)
if err != nil {
return
}
l.server = &dnsServer{
server: &dns.Server{
Net: "tcp",
Addr: l.options.Addr,
Listener: ln,
Handler: l,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
},
}
case "tls": case "tls":
l.server = &dns.Server{ l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
Net: "tcp-tls", if err != nil {
Addr: l.options.Addr, return
Handler: l,
TLSConfig: l.options.TLSConfig,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
} }
network := "tcp"
if xnet.IsIPv4(l.options.Addr) {
network = "tcp4"
}
lc := net.ListenConfig{}
if l.md.mptcp {
lc.SetMultipathTCP(true)
l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP())
}
var ln net.Listener
ln, err = lc.Listen(context.Background(), network, l.options.Addr)
if err != nil {
return
}
ln = tls.NewListener(ln, l.options.TLSConfig)
l.server = &dnsServer{
server: &dns.Server{
Net: "tcp-tls",
Addr: l.options.Addr,
Listener: ln,
Handler: l,
TLSConfig: l.options.TLSConfig,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
},
}
case "https": case "https":
l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
if err != nil {
return
}
network := "tcp"
if xnet.IsIPv4(l.options.Addr) {
network = "tcp4"
}
lc := net.ListenConfig{}
if l.md.mptcp {
lc.SetMultipathTCP(true)
l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP())
}
var ln net.Listener
ln, err = lc.Listen(context.Background(), network, l.options.Addr)
if err != nil {
return
}
ln = tls.NewListener(ln, l.options.TLSConfig)
l.server = &dohServer{ l.server = &dohServer{
addr: l.options.Addr, addr: l.options.Addr,
tlsConfig: l.options.TLSConfig, tlsConfig: l.options.TLSConfig,
listener: ln,
server: &http.Server{ server: &http.Server{
Handler: l, Handler: l,
ReadTimeout: l.md.readTimeout, ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout, WriteTimeout: l.md.writeTimeout,
}, },
} }
default: default:
l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr) l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr)
l.server = &dns.Server{ if err != nil {
Net: "udp", return
Addr: l.options.Addr, }
Handler: l,
UDPSize: l.md.readBufferSize, network := "udp"
ReadTimeout: l.md.readTimeout, if xnet.IsIPv4(l.options.Addr) {
WriteTimeout: l.md.writeTimeout, network = "udp4"
}
lc := net.ListenConfig{}
if l.md.mptcp {
lc.SetMultipathTCP(true)
l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP())
}
var pc net.PacketConn
pc, err = lc.ListenPacket(context.Background(), network, l.options.Addr)
if err != nil {
return
}
l.server = &dnsServer{
server: &dns.Server{
Net: "udp",
Addr: l.options.Addr,
PacketConn: pc,
Handler: l,
UDPSize: l.md.readBufferSize,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
},
} }
} }
@ -104,7 +202,7 @@ func (l *dnsListener) Init(md md.Metadata) (err error) {
l.errChan = make(chan error, 1) l.errChan = make(chan error, 1)
go func() { go func() {
err := l.server.ListenAndServe() err := l.server.Serve()
if err != nil { if err != nil {
l.errChan <- err l.errChan <- err
} }

View File

@ -17,6 +17,7 @@ type metadata struct {
readTimeout time.Duration readTimeout time.Duration
writeTimeout time.Duration writeTimeout time.Duration
backlog int backlog int
mptcp bool
} }
func (l *dnsListener) parseMetadata(md mdata.Metadata) (err error) { func (l *dnsListener) parseMetadata(md mdata.Metadata) (err error) {
@ -37,6 +38,7 @@ func (l *dnsListener) parseMetadata(md mdata.Metadata) (err error) {
if l.md.backlog <= 0 { if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog l.md.backlog = defaultBacklog
} }
l.md.mptcp = mdutil.GetBool(md, "mptcp")
return return
} }

View File

@ -10,29 +10,47 @@ import (
"time" "time"
xnet "github.com/go-gost/x/internal/net" xnet "github.com/go-gost/x/internal/net"
"github.com/miekg/dns"
) )
type Server interface { type Server interface {
ListenAndServe() error Serve() error
Shutdown() error Shutdown() error
} }
type dnsServer struct {
server *dns.Server
}
func (s *dnsServer) Serve() error {
return s.server.ActivateAndServe()
}
func (s *dnsServer) Shutdown() error {
return s.server.Shutdown()
}
type dohServer struct { type dohServer struct {
addr string addr string
tlsConfig *tls.Config tlsConfig *tls.Config
listener net.Listener
server *http.Server server *http.Server
} }
func (s *dohServer) ListenAndServe() error { func (s *dohServer) Serve() error {
network := "tcp" var err error
if xnet.IsIPv4(s.addr) { ln := s.listener
network = "tcp4" if ln == nil {
network := "tcp"
if xnet.IsIPv4(s.addr) {
network = "tcp4"
}
ln, err = net.Listen(network, s.addr)
if err != nil {
return err
}
ln = tls.NewListener(ln, s.tlsConfig)
} }
ln, err := net.Listen(network, s.addr)
if err != nil {
return err
}
ln = tls.NewListener(ln, s.tlsConfig)
return s.server.Serve(ln) return s.server.Serve(ln)
} }

View File

@ -46,11 +46,16 @@ func (l *http3Listener) Init(md md.Metadata) (err error) {
return return
} }
addr := l.options.Addr
if addr == "" {
addr = ":https"
}
network := "udp" network := "udp"
if xnet.IsIPv4(l.options.Addr) { if xnet.IsIPv4(addr) {
network = "udp4" network = "udp4"
} }
l.addr, err = net.ResolveUDPAddr(network, l.options.Addr) l.addr, err = net.ResolveUDPAddr(network, addr)
if err != nil { if err != nil {
return return
} }
@ -66,15 +71,21 @@ func (l *http3Listener) Init(md md.Metadata) (err error) {
quic.Version1, quic.Version1,
}, },
MaxIncomingStreams: int64(l.md.maxStreams), MaxIncomingStreams: int64(l.md.maxStreams),
Allow0RTT: true,
}, },
Handler: http.HandlerFunc(l.handleFunc), Handler: http.HandlerFunc(l.handleFunc),
} }
ln, err := quic.ListenAddrEarly(addr, http3.ConfigureTLSConfig(l.server.TLSConfig), l.server.QUICConfig.Clone())
if err != nil {
return
}
l.cqueue = make(chan net.Conn, l.md.backlog) l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1) l.errChan = make(chan error, 1)
go func() { go func() {
if err := l.server.ListenAndServe(); err != nil { if err := l.server.ServeListener(ln); err != nil {
l.logger.Error(err) l.logger.Error(err)
} }
}() }()

View File

@ -50,11 +50,22 @@ func (l *wtListener) Init(md md.Metadata) (err error) {
return return
} }
addr := l.options.Addr
if addr == "" {
addr = ":https"
}
network := "udp" network := "udp"
if xnet.IsIPv4(l.options.Addr) { if xnet.IsIPv4(addr) {
network = "udp4" network = "udp4"
} }
l.addr, err = net.ResolveUDPAddr(network, l.options.Addr) laddr, err := net.ResolveUDPAddr(network, addr)
if err != nil {
return
}
l.addr = laddr
pc, err := net.ListenUDP(network, laddr)
if err != nil { if err != nil {
return return
} }
@ -62,23 +73,25 @@ func (l *wtListener) Init(md md.Metadata) (err error) {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle(l.md.path, http.HandlerFunc(l.upgrade)) mux.Handle(l.md.path, http.HandlerFunc(l.upgrade))
quicCfg := &quic.Config{
KeepAlivePeriod: l.md.keepAlivePeriod,
HandshakeIdleTimeout: l.md.handshakeTimeout,
MaxIdleTimeout: l.md.maxIdleTimeout,
/*
Versions: []quic.VersionNumber{
quic.Version1,
quic.Version2,
},
*/
MaxIncomingStreams: int64(l.md.maxStreams),
Allow0RTT: true,
}
l.srv = &wt.Server{ l.srv = &wt.Server{
H3: http3.Server{ H3: http3.Server{
Addr: l.options.Addr, Addr: l.options.Addr,
TLSConfig: l.options.TLSConfig, TLSConfig: l.options.TLSConfig,
QUICConfig: &quic.Config{ QUICConfig: quicCfg,
KeepAlivePeriod: l.md.keepAlivePeriod, Handler: mux,
HandshakeIdleTimeout: l.md.handshakeTimeout,
MaxIdleTimeout: l.md.maxIdleTimeout,
/*
Versions: []quic.VersionNumber{
quic.Version1,
quic.Version2,
},
*/
MaxIncomingStreams: int64(l.md.maxStreams),
},
Handler: mux,
}, },
CheckOrigin: func(r *http.Request) bool { return true }, CheckOrigin: func(r *http.Request) bool { return true },
} }
@ -87,7 +100,7 @@ func (l *wtListener) Init(md md.Metadata) (err error) {
l.errChan = make(chan error, 1) l.errChan = make(chan error, 1)
go func() { go func() {
if err := l.srv.ListenAndServe(); err != nil { if err := l.srv.Serve(pc); err != nil {
l.logger.Error(err) l.logger.Error(err)
} }
}() }()

View File

@ -52,11 +52,10 @@ func (l *quicListener) Init(md md.Metadata) (err error) {
} }
network := "udp" network := "udp"
if xnet.IsIPv4(l.options.Addr) { if xnet.IsIPv4(addr) {
network = "udp4" network = "udp4"
} }
var laddr *net.UDPAddr laddr, err := net.ResolveUDPAddr(network, addr)
laddr, err = net.ResolveUDPAddr(network, addr)
if err != nil { if err != nil {
return return
} }

View File

@ -7,12 +7,15 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"runtime"
"strconv" "strconv"
"strings"
"syscall" "syscall"
"unsafe" "unsafe"
"github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/common/bufpool"
xnet "github.com/go-gost/x/internal/net" xnet "github.com/go-gost/x/internal/net"
"github.com/vishvananda/netns"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -53,6 +56,33 @@ func (l *redirectListener) accept() (conn net.Conn, err error) {
l.logger.Infof("%s >> %s", raddr.String(), dstAddr.String()) l.logger.Infof("%s >> %s", raddr.String(), dstAddr.String())
if l.options.Netns != "" {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
originNs, err := netns.Get()
if err != nil {
return nil, fmt.Errorf("netns.Get(): %v", err)
}
defer netns.Set(originNs)
var ns netns.NsHandle
if strings.HasPrefix(l.options.Netns, "/") {
ns, err = netns.GetFromPath(l.options.Netns)
} else {
ns, err = netns.GetFromName(l.options.Netns)
}
if err != nil {
return nil, fmt.Errorf("netns.Get(%s): %v", l.options.Netns, err)
}
defer ns.Close()
if err := netns.Set(ns); err != nil {
return nil, fmt.Errorf("netns.Set(%s): %v", l.options.Netns, err)
}
}
network := "udp" network := "udp"
if xnet.IsIPv4(l.options.Addr) { if xnet.IsIPv4(l.options.Addr) {
network = "udp4" network = "udp4"