From f9bfca76ed1f5a4ce9908ec1b4f9fb226d80d870 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Mon, 24 Jun 2024 21:18:04 +0800 Subject: [PATCH] fix netns for socks5 and relay handler --- config/parsing/service/parse.go | 2 ++ connector/direct/connector.go | 12 +++++-- connector/socks/v5/conn.go | 4 +++ connector/socks/v5/connector.go | 8 +++-- go.mod | 4 +-- go.sum | 8 ++--- handler/relay/bind.go | 14 +++++--- handler/socks/v5/bind.go | 8 +++-- handler/socks/v5/mbind.go | 9 +++-- handler/socks/v5/udp.go | 14 ++++---- handler/socks/v5/udp_tun.go | 42 ++++++++++++++++++----- internal/net/net.go | 59 +++++++++++++++++++++++++++++++++ 12 files changed, 150 insertions(+), 34 deletions(-) diff --git a/config/parsing/service/parse.go b/config/parsing/service/parse.go index 8f4bd1c..79a5c43 100644 --- a/config/parsing/service/parse.go +++ b/config/parsing/service/parse.go @@ -144,6 +144,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { listener.ServiceOption(cfg.Name), listener.ProxyProtocolOption(ppv), listener.StatsOption(pStats), + listener.NetnsOption(netnsIn), } if !ignoreChain { listenOpts = append(listenOpts, @@ -262,6 +263,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { handler.ObserverOption(registry.ObserverRegistry().Get(cfg.Handler.Observer)), handler.LoggerOption(handlerLogger), handler.ServiceOption(cfg.Name), + handler.NetnsOption(netnsIn), ) } else { return nil, fmt.Errorf("unknown handler: %s", cfg.Handler.Type) diff --git a/connector/direct/connector.go b/connector/direct/connector.go index 4dd0b10..e3ffd41 100644 --- a/connector/direct/connector.go +++ b/connector/direct/connector.go @@ -44,9 +44,17 @@ func (c *directConnector) Connect(ctx context.Context, _ net.Conn, network, addr return nil, err } + var localAddr, remoteAddr string + if addr := conn.LocalAddr(); addr != nil { + localAddr = addr.String() + } + if addr := conn.RemoteAddr(); addr != nil { + remoteAddr = addr.String() + } + log := c.options.Logger.WithFields(map[string]any{ - "remote": conn.RemoteAddr().String(), - "local": conn.LocalAddr().String(), + "remote": remoteAddr, + "local": localAddr, "network": network, "address": address, }) diff --git a/connector/socks/v5/conn.go b/connector/socks/v5/conn.go index da0552f..933a1f2 100644 --- a/connector/socks/v5/conn.go +++ b/connector/socks/v5/conn.go @@ -69,6 +69,10 @@ func (c *udpRelayConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { if err = socksAddr.ParseFrom(addr.String()); err != nil { return } + if socksAddr.Host == "" { + socksAddr.Type = gosocks5.AddrIPv4 + socksAddr.Host = "127.0.0.1" + } header := gosocks5.UDPHeader{ Addr: &socksAddr, diff --git a/connector/socks/v5/connector.go b/connector/socks/v5/connector.go index 0f69650..5844708 100644 --- a/connector/socks/v5/connector.go +++ b/connector/socks/v5/connector.go @@ -130,6 +130,10 @@ func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, a log.Error(err) return nil, err } + if addr.Host == "" { + addr.Type = gosocks5.AddrIPv4 + addr.Host = "127.0.0.1" + } req := gosocks5.NewRequest(gosocks5.CmdConnect, &addr) log.Trace(req) @@ -201,12 +205,12 @@ func (c *socks5Connector) relayUDP(ctx context.Context, conn net.Conn, addr net. } log.Trace(reply) - log.Debugf("bind on: %v", reply.Addr) - if reply.Rep != gosocks5.Succeeded { return nil, errors.New("get socks5 UDP tunnel failure") } + log.Debugf("bind on: %v", reply.Addr) + cc, err := opts.NetDialer.Dial(ctx, "udp", reply.Addr.String()) if err != nil { return nil, err diff --git a/go.mod b/go.mod index 8a9122f..feb4498 100644 --- a/go.mod +++ b/go.mod @@ -9,9 +9,9 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.5.0 github.com/gin-gonic/gin v1.9.1 - github.com/go-gost/core v0.0.0-20240621153412-5aede9a2b32f + github.com/go-gost/core v0.0.0-20240624131323-ca340b1bf1a2 github.com/go-gost/gosocks4 v0.0.1 - github.com/go-gost/gosocks5 v0.4.0 + github.com/go-gost/gosocks5 v0.3.1 github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a github.com/go-gost/relay v0.5.0 github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 diff --git a/go.sum b/go.sum index ae4f9b8..af59158 100644 --- a/go.sum +++ b/go.sum @@ -53,12 +53,12 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= -github.com/go-gost/core v0.0.0-20240621153412-5aede9a2b32f h1:deEX5HhpUDB03wAggTWn3/8l20cEGVq25bMsd9wqizo= -github.com/go-gost/core v0.0.0-20240621153412-5aede9a2b32f/go.mod h1:aTPFucvJyqc/o5h5/ZtyHJ0xgFIq5Ip+cMlhazm+TaI= +github.com/go-gost/core v0.0.0-20240624131323-ca340b1bf1a2 h1:+VxqwMcnO/Jqpa88n9D2YoApTFrSRbjlFd9Oy/xvE0s= +github.com/go-gost/core v0.0.0-20240624131323-ca340b1bf1a2/go.mod h1:aTPFucvJyqc/o5h5/ZtyHJ0xgFIq5Ip+cMlhazm+TaI= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= -github.com/go-gost/gosocks5 v0.4.0 h1:EIrOEkpJez4gwHrMa33frA+hHXJyevjp47thpMQsJzI= -github.com/go-gost/gosocks5 v0.4.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= +github.com/go-gost/gosocks5 v0.3.1 h1:N6K/gE8oNLJX2nVX/O50FERHjgW4gGksZ7QbOvPF3n8= +github.com/go-gost/gosocks5 v0.3.1/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a h1:ME7P1Brcg4C640DSPqlvQr7JuvvQfJ8QpmS3yCFlK3A= github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a/go.mod h1:qXr2Zm9Ex2ATqnWuNUzVZqySPMnuIihvblYZt4MlZLw= github.com/go-gost/relay v0.5.0 h1:JG1tgy/KWiVXS0ukuVXvbM0kbYuJTWxYpJ5JwzsCf/c= diff --git a/handler/relay/bind.go b/handler/relay/bind.go index 459daa1..807655f 100644 --- a/handler/relay/bind.go +++ b/handler/relay/bind.go @@ -10,6 +10,7 @@ import ( "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" "github.com/go-gost/relay" + xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/udp" "github.com/go-gost/x/internal/util/mux" relay_util "github.com/go-gost/x/internal/util/relay" @@ -50,7 +51,10 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr Status: relay.StatusOK, } - ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error + lc := xnet.ListenConfig{ + Netns: h.options.Netns, + } + ln, err := lc.Listen(ctx, network, address) // strict mode: if the port already in use, it will return error if err != nil { log.Error(err) resp.Status = relay.StatusServiceUnavailable @@ -129,10 +133,10 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr Status: relay.StatusOK, } - var pc net.PacketConn - var err error - bindAddr, _ := net.ResolveUDPAddr(network, address) - pc, err = net.ListenUDP(network, bindAddr) + lc := xnet.ListenConfig{ + Netns: h.options.Netns, + } + pc, err := lc.ListenPacket(ctx, network, address) if err != nil { log.Error(err) return err diff --git a/handler/socks/v5/bind.go b/handler/socks/v5/bind.go index ae37b40..6c7a798 100644 --- a/handler/socks/v5/bind.go +++ b/handler/socks/v5/bind.go @@ -9,6 +9,7 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/gosocks5" netpkg "github.com/go-gost/x/internal/net" + xnet "github.com/go-gost/x/internal/net" ) func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { @@ -31,7 +32,10 @@ func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, } func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { - ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error + lc := xnet.ListenConfig{ + Netns: h.options.Netns, + } + ln, err := lc.Listen(ctx, network, address) // strict mode: if the port already in use, it will return error if err != nil { log.Error(err) reply := gosocks5.NewReply(gosocks5.Failure, nil) @@ -95,7 +99,7 @@ func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Lis defer close(errc) defer pc1.Close() - errc <- netpkg.Transport(conn, pc1) + errc <- xnet.Transport(conn, pc1) }() return errc diff --git a/handler/socks/v5/mbind.go b/handler/socks/v5/mbind.go index 9038780..3c25bd2 100644 --- a/handler/socks/v5/mbind.go +++ b/handler/socks/v5/mbind.go @@ -8,7 +8,7 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/gosocks5" - netpkg "github.com/go-gost/x/internal/net" + xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/util/mux" ) @@ -31,7 +31,10 @@ func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, networ } func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { - ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error + lc := xnet.ListenConfig{ + Netns: h.options.Netns, + } + ln, err := lc.Listen(ctx, network, address) // strict mode: if the port already in use, it will return error if err != nil { log.Error(err) reply := gosocks5.NewReply(gosocks5.Failure, nil) @@ -125,7 +128,7 @@ func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net. t := time.Now() log.Debugf("%s <-> %s", c.LocalAddr(), c.RemoteAddr()) - netpkg.Transport(sc, c) + xnet.Transport(sc, c) log.WithFields(map[string]any{"duration": time.Since(t)}). Debugf("%s >-< %s", c.LocalAddr(), c.RemoteAddr()) }(rc) diff --git a/handler/socks/v5/udp.go b/handler/socks/v5/udp.go index 709a899..091a77b 100644 --- a/handler/socks/v5/udp.go +++ b/handler/socks/v5/udp.go @@ -11,6 +11,7 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/gosocks5" ctxvalue "github.com/go-gost/x/ctx" + xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/udp" "github.com/go-gost/x/internal/util/socks" "github.com/go-gost/x/stats" @@ -29,7 +30,11 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger return reply.Write(conn) } - cc, err := net.ListenUDP("udp", nil) + lc := xnet.ListenConfig{ + Netns: h.options.Netns, + } + laddr := &net.UDPAddr{IP: conn.LocalAddr().(*net.TCPAddr).IP, Port: 0} // use out-going interface's IP + cc, err := lc.ListenPacket(ctx, "udp", laddr.String()) if err != nil { log.Error(err) reply := gosocks5.NewReply(gosocks5.Failure, nil) @@ -41,8 +46,6 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger saddr := gosocks5.Addr{} saddr.ParseFrom(cc.LocalAddr().String()) - saddr.Type = 0 - saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr) log.Trace(reply) if err := reply.Write(conn); err != nil { @@ -70,17 +73,16 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, log logger return err } - var lc net.PacketConn = cc clientID := ctxvalue.ClientIDFromContext(ctx) if h.options.Observer != nil { pstats := h.stats.Stats(string(clientID)) pstats.Add(stats.KindTotalConns, 1) pstats.Add(stats.KindCurrentConns, 1) defer pstats.Add(stats.KindCurrentConns, -1) - lc = stats_wrapper.WrapPacketConn(lc, pstats) + cc = stats_wrapper.WrapPacketConn(cc, pstats) } - r := udp.NewRelay(socks.UDPConn(lc, h.md.udpBufferSize), pc). + r := udp.NewRelay(socks.UDPConn(cc, h.md.udpBufferSize), pc). WithBypass(h.options.Bypass). WithLogger(log) r.SetBufferSize(h.md.udpBufferSize) diff --git a/handler/socks/v5/udp_tun.go b/handler/socks/v5/udp_tun.go index 33242f4..1f924c2 100644 --- a/handler/socks/v5/udp_tun.go +++ b/handler/socks/v5/udp_tun.go @@ -2,12 +2,14 @@ package v5 import ( "context" + "errors" "net" "time" "github.com/go-gost/core/logger" "github.com/go-gost/gosocks5" ctxvalue "github.com/go-gost/x/ctx" + xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/udp" "github.com/go-gost/x/internal/util/socks" "github.com/go-gost/x/stats" @@ -24,28 +26,52 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network bindAddr = &net.UDPAddr{} } + var pc net.PacketConn + // relay mode if bindAddr.Port == 0 { - // relay mode if !h.md.enableUDP { reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) log.Trace(reply) log.Error("socks5: UDP relay is disabled") return reply.Write(conn) } - } else { - // BIND mode + + // obtain a udp connection + c, err := h.router.Dial(ctx, "udp", "") // UDP association + if err != nil { + log.Error(err) + return err + } + defer c.Close() + + var ok bool + pc, ok = c.(net.PacketConn) + if !ok { + err := errors.New("socks5: wrong connection type") + log.Error(err) + return err + } + + } else { // BIND mode if !h.md.enableBind { reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) log.Trace(reply) log.Error("socks5: BIND is disabled") return reply.Write(conn) } - } - pc, err := net.ListenUDP(network, bindAddr) - if err != nil { - log.Error(err) - return err + lc := xnet.ListenConfig{ + Netns: h.options.Netns, + } + var err error + pc, err = lc.ListenPacket(ctx, "udp", bindAddr.String()) + if err != nil { + log.Error(err) + reply := gosocks5.NewReply(gosocks5.Failure, nil) + log.Trace(reply) + reply.Write(conn) + return err + } } defer pc.Close() diff --git a/internal/net/net.go b/internal/net/net.go index 364dafa..618552e 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -1,8 +1,13 @@ package net import ( + "context" + "fmt" "net" + "runtime" "syscall" + + "github.com/vishvananda/netns" ) type SetBuffer interface { @@ -26,3 +31,57 @@ type SetDSCP interface { func IsIPv4(address string) bool { return address != "" && address[0] != ':' && address[0] != '[' } + +type ListenConfig struct { + Netns string + net.ListenConfig +} + +func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (net.Listener, error) { + if lc.Netns != "" { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + originNs, err := netns.Get() + if err != nil { + return nil, fmt.Errorf("netns.Get(): %v", err) + } + defer netns.Set(originNs) + + ns, err := netns.GetFromName(lc.Netns) + if err != nil { + return nil, fmt.Errorf("netns.GetFromName(%s): %v", lc.Netns, err) + } + defer ns.Close() + + if err := netns.Set(ns); err != nil { + return nil, fmt.Errorf("netns.Set(%s): %v", lc.Netns, err) + } + } + + return lc.ListenConfig.Listen(ctx, network, address) +} + +func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + if lc.Netns != "" { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + originNs, err := netns.Get() + if err != nil { + return nil, fmt.Errorf("netns.Get(): %v", err) + } + defer netns.Set(originNs) + + ns, err := netns.GetFromName(lc.Netns) + if err != nil { + return nil, fmt.Errorf("netns.GetFromName(%s): %v", lc.Netns, err) + } + defer ns.Close() + + if err := netns.Set(ns); err != nil { + return nil, fmt.Errorf("netns.Set(%s): %v", lc.Netns, err) + } + } + return lc.ListenConfig.ListenPacket(ctx, network, address) +}