From 37fed3f372c484d11ecc5a82411155f50940b65b Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Thu, 31 Mar 2022 11:14:44 +0800 Subject: [PATCH] add tproxy --- config/config.go | 2 +- handler/redirect/tcp/handler.go | 20 ++- handler/redirect/tcp/metadata.go | 3 + handler/redirect/udp/handler.go | 5 +- listener/redirect/tcp/listener_linux.go | 2 +- listener/redirect/udp/conn.go | 2 +- listener/redirect/udp/listener.go | 10 +- listener/redirect/udp/listener_linux.go | 191 +++++++++++++++++++++++- listener/redirect/udp/metadata.go | 2 +- 9 files changed, 210 insertions(+), 27 deletions(-) diff --git a/config/config.go b/config/config.go index a5c5811..2172281 100644 --- a/config/config.go +++ b/config/config.go @@ -62,7 +62,7 @@ type APIConfig struct { type MetricsConfig struct { Addr string `json:"addr"` - Path string `json:"path"` + Path string `yaml:",omitempty" json:"path,omitempty"` } type TLSConfig struct { diff --git a/handler/redirect/tcp/handler.go b/handler/redirect/tcp/handler.go index 62ff2b0..2d4cb25 100644 --- a/handler/redirect/tcp/handler.go +++ b/handler/redirect/tcp/handler.go @@ -58,7 +58,7 @@ func (h *redirectHandler) Init(md md.Metadata) (err error) { return } -func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { +func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) (err error) { defer conn.Close() start := time.Now() @@ -74,16 +74,20 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() - network := "tcp" + var dstAddr net.Addr - dstAddr, err := h.getOriginalDstAddr(conn) - if err != nil { - log.Error(err) - return err + if h.md.tproxy { + dstAddr = conn.LocalAddr() + } else { + dstAddr, err = h.getOriginalDstAddr(conn) + if err != nil { + log.Error(err) + return + } } log = log.WithFields(map[string]any{ - "dst": fmt.Sprintf("%s/%s", dstAddr, network), + "dst": fmt.Sprintf("%s/%s", dstAddr, dstAddr.Network()), }) var rw io.ReadWriter = conn @@ -120,7 +124,7 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han return nil } - cc, err := h.router.Dial(ctx, network, dstAddr.String()) + cc, err := h.router.Dial(ctx, dstAddr.Network(), dstAddr.String()) if err != nil { log.Error(err) return err diff --git a/handler/redirect/tcp/metadata.go b/handler/redirect/tcp/metadata.go index 209bce3..6c707f3 100644 --- a/handler/redirect/tcp/metadata.go +++ b/handler/redirect/tcp/metadata.go @@ -6,12 +6,15 @@ import ( type metadata struct { sniffing bool + tproxy bool } func (h *redirectHandler) parseMetadata(md mdata.Metadata) (err error) { const ( sniffing = "sniffing" + tproxy = "tproxy" ) h.md.sniffing = mdata.GetBool(md, sniffing) + h.md.tproxy = mdata.GetBool(md, tproxy) return } diff --git a/handler/redirect/udp/handler.go b/handler/redirect/udp/handler.go index 940a4d4..f92b9b6 100644 --- a/handler/redirect/udp/handler.go +++ b/handler/redirect/udp/handler.go @@ -63,11 +63,10 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() - network := "udp" dstAddr := conn.LocalAddr() log = log.WithFields(map[string]any{ - "dst": fmt.Sprintf("%s/%s", dstAddr, network), + "dst": fmt.Sprintf("%s/%s", dstAddr, dstAddr.Network()), }) log.Infof("%s >> %s", conn.RemoteAddr(), dstAddr) @@ -77,7 +76,7 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han return nil } - cc, err := h.router.Dial(ctx, network, dstAddr.String()) + cc, err := h.router.Dial(ctx, dstAddr.Network(), dstAddr.String()) if err != nil { log.Error(err) return err diff --git a/listener/redirect/tcp/listener_linux.go b/listener/redirect/tcp/listener_linux.go index 95bae02..6e33b26 100644 --- a/listener/redirect/tcp/listener_linux.go +++ b/listener/redirect/tcp/listener_linux.go @@ -9,7 +9,7 @@ import ( func (l *redirectListener) control(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil { - l.logger.Errorf("set sockopt: %v", err) + l.logger.Errorf("SetsockoptInt(SOL_IP, IP_TRANSPARENT, 1): %v", err) } }) } diff --git a/listener/redirect/udp/conn.go b/listener/redirect/udp/conn.go index 152fa65..b023529 100644 --- a/listener/redirect/udp/conn.go +++ b/listener/redirect/udp/conn.go @@ -24,12 +24,12 @@ func (c *redirConn) Read(b []byte) (n int, err error) { c.once.Do(func() { n = copy(b, c.buf) bufpool.Put(&c.buf) - c.buf = nil }) if n == 0 { n, err = c.Conn.Read(b) } + return } diff --git a/listener/redirect/udp/listener.go b/listener/redirect/udp/listener.go index 57a066c..3a65fd7 100644 --- a/listener/redirect/udp/listener.go +++ b/listener/redirect/udp/listener.go @@ -6,6 +6,7 @@ import ( "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" + metrics "github.com/go-gost/core/metrics/wrapper" "github.com/go-gost/core/registry" ) @@ -36,12 +37,7 @@ func (l *redirectListener) Init(md md.Metadata) (err error) { return } - laddr, err := net.ResolveUDPAddr("udp", l.options.Addr) - if err != nil { - return - } - - ln, err := l.listenUDP(laddr) + ln, err := l.listenUDP(l.options.Addr) if err != nil { return } @@ -55,7 +51,7 @@ func (l *redirectListener) Accept() (conn net.Conn, err error) { if err != nil { return } - // conn = metrics.WrapConn(l.options.Service, conn) + conn = metrics.WrapConn(l.options.Service, conn) return } diff --git a/listener/redirect/udp/listener_linux.go b/listener/redirect/udp/listener_linux.go index a434ff9..79cac23 100644 --- a/listener/redirect/udp/listener_linux.go +++ b/listener/redirect/udp/listener_linux.go @@ -1,20 +1,45 @@ package udp import ( + "bytes" + "context" + "encoding/binary" + "fmt" "net" + "os" + "strconv" + "syscall" + "unsafe" - "github.com/LiamHaworth/go-tproxy" "github.com/go-gost/core/common/bufpool" + "golang.org/x/sys/unix" ) -func (l *redirectListener) listenUDP(addr *net.UDPAddr) (*net.UDPConn, error) { - return tproxy.ListenUDP("udp", addr) +func (l *redirectListener) listenUDP(addr string) (*net.UDPConn, error) { + lc := net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil { + l.logger.Errorf("SetsockoptInt(SOL_IP, IP_TRANSPARENT, 1): %v", err) + } + if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_RECVORIGDSTADDR, 1); err != nil { + l.logger.Errorf("SetsockoptInt(SOL_IP, IP_RECVORIGDSTADDR, 1): %v", err) + } + }) + }, + } + pc, err := lc.ListenPacket(context.Background(), "udp", addr) + if err != nil { + return nil, err + } + + return pc.(*net.UDPConn), nil } func (l *redirectListener) accept() (conn net.Conn, err error) { b := bufpool.Get(l.md.readBufferSize) - n, raddr, dstAddr, err := tproxy.ReadFromUDP(l.ln, *b) + n, raddr, dstAddr, err := readFromUDP(l.ln, *b) if err != nil { l.logger.Error(err) return @@ -22,7 +47,7 @@ func (l *redirectListener) accept() (conn net.Conn, err error) { l.logger.Infof("%s >> %s", raddr.String(), dstAddr.String()) - c, err := tproxy.DialUDP("udp", dstAddr, raddr) + c, err := dialUDP("udp", dstAddr, raddr) if err != nil { l.logger.Error(err) return @@ -35,3 +60,159 @@ func (l *redirectListener) accept() (conn net.Conn, err error) { } return } + +// ReadFromUDP reads a UDP packet from c, copying the payload into b. +// It returns the number of bytes copied into b and the return address +// that was on the packet. +// +// Out-of-band data is also read in so that the original destination +// address can be identified and parsed. +func readFromUDP(conn *net.UDPConn, b []byte) (n int, remoteAddr *net.UDPAddr, dstAddr *net.UDPAddr, err error) { + oob := bufpool.Get(1024) + defer bufpool.Put(oob) + + n, oobn, _, remoteAddr, err := conn.ReadMsgUDP(b, *oob) + if err != nil { + return 0, nil, nil, err + } + + msgs, err := unix.ParseSocketControlMessage((*oob)[:oobn]) + if err != nil { + return 0, nil, nil, fmt.Errorf("parsing socket control message: %s", err) + } + + for _, msg := range msgs { + if msg.Header.Level == unix.SOL_IP && msg.Header.Type == unix.IP_RECVORIGDSTADDR { + originalDstRaw := &unix.RawSockaddrInet4{} + if err = binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, originalDstRaw); err != nil { + return 0, nil, nil, fmt.Errorf("reading original destination address: %s", err) + } + + switch originalDstRaw.Family { + case unix.AF_INET: + pp := (*unix.RawSockaddrInet4)(unsafe.Pointer(originalDstRaw)) + p := (*[2]byte)(unsafe.Pointer(&pp.Port)) + dstAddr = &net.UDPAddr{ + IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]), + Port: int(p[0])<<8 + int(p[1]), + } + + case unix.AF_INET6: + pp := (*unix.RawSockaddrInet6)(unsafe.Pointer(originalDstRaw)) + p := (*[2]byte)(unsafe.Pointer(&pp.Port)) + dstAddr = &net.UDPAddr{ + IP: net.IP(pp.Addr[:]), + Port: int(p[0])<<8 + int(p[1]), + Zone: strconv.Itoa(int(pp.Scope_id)), + } + + default: + return 0, nil, nil, fmt.Errorf("original destination is an unsupported network family") + } + break + } + } + + if dstAddr == nil { + return 0, nil, nil, fmt.Errorf("unable to obtain original destination: %s", err) + } + + return +} + +// DialUDP connects to the remote address raddr on the network net, +// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is +// used as the local address for the connection. +func dialUDP(network string, laddr *net.UDPAddr, raddr *net.UDPAddr) (net.Conn, error) { + remoteSocketAddress, err := udpAddrToSocketAddr(raddr) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("build destination socket address: %s", err)} + } + + localSocketAddress, err := udpAddrToSocketAddr(laddr) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("build local socket address: %s", err)} + } + + fileDescriptor, err := unix.Socket(udpAddrFamily(network, laddr, raddr), unix.SOCK_DGRAM, 0) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket open: %s", err)} + } + + if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil { + unix.Close(fileDescriptor) + return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %s", err)} + } + + if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil { + unix.Close(fileDescriptor) + return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: SO_REUSEADDR: %s", err)} + } + if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { + unix.Close(fileDescriptor) + return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: SO_REUSEPORT: %s", err)} + } + + if err = unix.Bind(fileDescriptor, localSocketAddress); err != nil { + unix.Close(fileDescriptor) + return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket bind %v: %s", laddr, err)} + } + + if err = unix.Connect(fileDescriptor, remoteSocketAddress); err != nil { + unix.Close(fileDescriptor) + return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket connect: %s", err)} + } + + fdFile := os.NewFile(uintptr(fileDescriptor), fmt.Sprintf("net-udp-dial-%s", raddr.String())) + defer fdFile.Close() + + remoteConn, err := net.FileConn(fdFile) + if err != nil { + unix.Close(fileDescriptor) + return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("convert file descriptor to connection: %s", err)} + } + + return remoteConn, nil +} + +// udpAddToSockerAddr will convert a UDPAddr +// into a Sockaddr that may be used when +// connecting and binding sockets +func udpAddrToSocketAddr(addr *net.UDPAddr) (unix.Sockaddr, error) { + switch { + case addr.IP.To4() != nil: + ip := [4]byte{} + copy(ip[:], addr.IP.To4()) + + return &unix.SockaddrInet4{Addr: ip, Port: addr.Port}, nil + + default: + ip := [16]byte{} + copy(ip[:], addr.IP.To16()) + + zoneID, err := strconv.ParseUint(addr.Zone, 10, 32) + if err != nil { + return nil, err + } + + return &unix.SockaddrInet6{Addr: ip, Port: addr.Port, ZoneId: uint32(zoneID)}, nil + } +} + +// udpAddrFamily will attempt to work +// out the address family based on the +// network and UDP addresses +func udpAddrFamily(net string, laddr, raddr *net.UDPAddr) int { + switch net[len(net)-1] { + case '4': + return unix.AF_INET + case '6': + return unix.AF_INET6 + } + + if (laddr == nil || laddr.IP.To4() != nil) && + (raddr == nil || laddr.IP.To4() != nil) { + return unix.AF_INET + } + return unix.AF_INET6 +} diff --git a/listener/redirect/udp/metadata.go b/listener/redirect/udp/metadata.go index 34b21e6..4173cf5 100644 --- a/listener/redirect/udp/metadata.go +++ b/listener/redirect/udp/metadata.go @@ -7,7 +7,7 @@ import ( ) const ( - defaultTTL = 60 * time.Second + defaultTTL = 30 * time.Second defaultReadBufferSize = 1500 )