netns: fix network namespaces for listeners

This commit is contained in:
ginuerzh
2024-07-08 10:59:32 +08:00
parent 949c98adc0
commit 96f4d7bf5c
9 changed files with 246 additions and 74 deletions

View File

@ -2,6 +2,8 @@ package dns
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"errors"
"io"
@ -9,12 +11,12 @@ import (
"net/http"
"strings"
admission "github.com/go-gost/x/admission/wrapper"
limiter "github.com/go-gost/x/limiter/traffic/wrapper"
"github.com/go-gost/core/listener"
"github.com/go-gost/core/logger"
md "github.com/go-gost/core/metadata"
admission "github.com/go-gost/x/admission/wrapper"
xnet "github.com/go-gost/x/internal/net"
limiter "github.com/go-gost/x/limiter/traffic/wrapper"
metrics "github.com/go-gost/x/metrics/wrapper"
stats "github.com/go-gost/x/observer/stats/wrapper"
"github.com/go-gost/x/registry"
@ -51,48 +53,144 @@ func (l *dnsListener) Init(md md.Metadata) (err error) {
return
}
l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
if err != nil {
return err
}
switch strings.ToLower(l.md.mode) {
case "tcp":
l.server = &dns.Server{
Net: "tcp",
Addr: l.options.Addr,
Handler: l,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
if err != nil {
return
}
network := "tcp"
if xnet.IsIPv4(l.options.Addr) {
network = "tcp4"
}
lc := net.ListenConfig{}
if l.md.mptcp {
lc.SetMultipathTCP(true)
l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP())
}
var ln net.Listener
ln, err = lc.Listen(context.Background(), network, l.options.Addr)
if err != nil {
return
}
l.server = &dnsServer{
server: &dns.Server{
Net: "tcp",
Addr: l.options.Addr,
Listener: ln,
Handler: l,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
},
}
case "tls":
l.server = &dns.Server{
Net: "tcp-tls",
Addr: l.options.Addr,
Handler: l,
TLSConfig: l.options.TLSConfig,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
if err != nil {
return
}
network := "tcp"
if xnet.IsIPv4(l.options.Addr) {
network = "tcp4"
}
lc := net.ListenConfig{}
if l.md.mptcp {
lc.SetMultipathTCP(true)
l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP())
}
var ln net.Listener
ln, err = lc.Listen(context.Background(), network, l.options.Addr)
if err != nil {
return
}
ln = tls.NewListener(ln, l.options.TLSConfig)
l.server = &dnsServer{
server: &dns.Server{
Net: "tcp-tls",
Addr: l.options.Addr,
Listener: ln,
Handler: l,
TLSConfig: l.options.TLSConfig,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
},
}
case "https":
l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
if err != nil {
return
}
network := "tcp"
if xnet.IsIPv4(l.options.Addr) {
network = "tcp4"
}
lc := net.ListenConfig{}
if l.md.mptcp {
lc.SetMultipathTCP(true)
l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP())
}
var ln net.Listener
ln, err = lc.Listen(context.Background(), network, l.options.Addr)
if err != nil {
return
}
ln = tls.NewListener(ln, l.options.TLSConfig)
l.server = &dohServer{
addr: l.options.Addr,
tlsConfig: l.options.TLSConfig,
listener: ln,
server: &http.Server{
Handler: l,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
},
}
default:
l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr)
l.server = &dns.Server{
Net: "udp",
Addr: l.options.Addr,
Handler: l,
UDPSize: l.md.readBufferSize,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
if err != nil {
return
}
network := "udp"
if xnet.IsIPv4(l.options.Addr) {
network = "udp4"
}
lc := net.ListenConfig{}
if l.md.mptcp {
lc.SetMultipathTCP(true)
l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP())
}
var pc net.PacketConn
pc, err = lc.ListenPacket(context.Background(), network, l.options.Addr)
if err != nil {
return
}
l.server = &dnsServer{
server: &dns.Server{
Net: "udp",
Addr: l.options.Addr,
PacketConn: pc,
Handler: l,
UDPSize: l.md.readBufferSize,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
},
}
}
@ -104,7 +202,7 @@ func (l *dnsListener) Init(md md.Metadata) (err error) {
l.errChan = make(chan error, 1)
go func() {
err := l.server.ListenAndServe()
err := l.server.Serve()
if err != nil {
l.errChan <- err
}