add tproxy

This commit is contained in:
ginuerzh
2022-03-31 11:14:44 +08:00
parent 8564d711b8
commit 37fed3f372
9 changed files with 210 additions and 27 deletions

View File

@ -62,7 +62,7 @@ type APIConfig struct {
type MetricsConfig struct { type MetricsConfig struct {
Addr string `json:"addr"` Addr string `json:"addr"`
Path string `json:"path"` Path string `yaml:",omitempty" json:"path,omitempty"`
} }
type TLSConfig struct { type TLSConfig struct {

View File

@ -58,7 +58,7 @@ func (h *redirectHandler) Init(md md.Metadata) (err error) {
return return
} }
func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) (err error) {
defer conn.Close() defer conn.Close()
start := time.Now() start := time.Now()
@ -74,16 +74,20 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}() }()
network := "tcp" var dstAddr net.Addr
dstAddr, err := h.getOriginalDstAddr(conn) if h.md.tproxy {
dstAddr = conn.LocalAddr()
} else {
dstAddr, err = h.getOriginalDstAddr(conn)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return err return
}
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", dstAddr, network), "dst": fmt.Sprintf("%s/%s", dstAddr, dstAddr.Network()),
}) })
var rw io.ReadWriter = conn var rw io.ReadWriter = conn
@ -120,7 +124,7 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
return nil return nil
} }
cc, err := h.router.Dial(ctx, network, dstAddr.String()) cc, err := h.router.Dial(ctx, dstAddr.Network(), dstAddr.String())
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return err return err

View File

@ -6,12 +6,15 @@ import (
type metadata struct { type metadata struct {
sniffing bool sniffing bool
tproxy bool
} }
func (h *redirectHandler) parseMetadata(md mdata.Metadata) (err error) { func (h *redirectHandler) parseMetadata(md mdata.Metadata) (err error) {
const ( const (
sniffing = "sniffing" sniffing = "sniffing"
tproxy = "tproxy"
) )
h.md.sniffing = mdata.GetBool(md, sniffing) h.md.sniffing = mdata.GetBool(md, sniffing)
h.md.tproxy = mdata.GetBool(md, tproxy)
return return
} }

View File

@ -63,11 +63,10 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}() }()
network := "udp"
dstAddr := conn.LocalAddr() dstAddr := conn.LocalAddr()
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", dstAddr, network), "dst": fmt.Sprintf("%s/%s", dstAddr, dstAddr.Network()),
}) })
log.Infof("%s >> %s", conn.RemoteAddr(), dstAddr) log.Infof("%s >> %s", conn.RemoteAddr(), dstAddr)
@ -77,7 +76,7 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
return nil return nil
} }
cc, err := h.router.Dial(ctx, network, dstAddr.String()) cc, err := h.router.Dial(ctx, dstAddr.Network(), dstAddr.String())
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return err return err

View File

@ -9,7 +9,7 @@ import (
func (l *redirectListener) control(network, address string, c syscall.RawConn) error { func (l *redirectListener) control(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) { return c.Control(func(fd uintptr) {
if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil { if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil {
l.logger.Errorf("set sockopt: %v", err) l.logger.Errorf("SetsockoptInt(SOL_IP, IP_TRANSPARENT, 1): %v", err)
} }
}) })
} }

View File

@ -24,12 +24,12 @@ func (c *redirConn) Read(b []byte) (n int, err error) {
c.once.Do(func() { c.once.Do(func() {
n = copy(b, c.buf) n = copy(b, c.buf)
bufpool.Put(&c.buf) bufpool.Put(&c.buf)
c.buf = nil
}) })
if n == 0 { if n == 0 {
n, err = c.Conn.Read(b) n, err = c.Conn.Read(b)
} }
return return
} }

View File

@ -6,6 +6,7 @@ import (
"github.com/go-gost/core/listener" "github.com/go-gost/core/listener"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
md "github.com/go-gost/core/metadata" md "github.com/go-gost/core/metadata"
metrics "github.com/go-gost/core/metrics/wrapper"
"github.com/go-gost/core/registry" "github.com/go-gost/core/registry"
) )
@ -36,12 +37,7 @@ func (l *redirectListener) Init(md md.Metadata) (err error) {
return return
} }
laddr, err := net.ResolveUDPAddr("udp", l.options.Addr) ln, err := l.listenUDP(l.options.Addr)
if err != nil {
return
}
ln, err := l.listenUDP(laddr)
if err != nil { if err != nil {
return return
} }
@ -55,7 +51,7 @@ func (l *redirectListener) Accept() (conn net.Conn, err error) {
if err != nil { if err != nil {
return return
} }
// conn = metrics.WrapConn(l.options.Service, conn) conn = metrics.WrapConn(l.options.Service, conn)
return return
} }

View File

@ -1,20 +1,45 @@
package udp package udp
import ( import (
"bytes"
"context"
"encoding/binary"
"fmt"
"net" "net"
"os"
"strconv"
"syscall"
"unsafe"
"github.com/LiamHaworth/go-tproxy"
"github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/common/bufpool"
"golang.org/x/sys/unix"
) )
func (l *redirectListener) listenUDP(addr *net.UDPAddr) (*net.UDPConn, error) { func (l *redirectListener) listenUDP(addr string) (*net.UDPConn, error) {
return tproxy.ListenUDP("udp", addr) lc := net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil {
l.logger.Errorf("SetsockoptInt(SOL_IP, IP_TRANSPARENT, 1): %v", err)
}
if err := unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_RECVORIGDSTADDR, 1); err != nil {
l.logger.Errorf("SetsockoptInt(SOL_IP, IP_RECVORIGDSTADDR, 1): %v", err)
}
})
},
}
pc, err := lc.ListenPacket(context.Background(), "udp", addr)
if err != nil {
return nil, err
}
return pc.(*net.UDPConn), nil
} }
func (l *redirectListener) accept() (conn net.Conn, err error) { func (l *redirectListener) accept() (conn net.Conn, err error) {
b := bufpool.Get(l.md.readBufferSize) b := bufpool.Get(l.md.readBufferSize)
n, raddr, dstAddr, err := tproxy.ReadFromUDP(l.ln, *b) n, raddr, dstAddr, err := readFromUDP(l.ln, *b)
if err != nil { if err != nil {
l.logger.Error(err) l.logger.Error(err)
return return
@ -22,7 +47,7 @@ func (l *redirectListener) accept() (conn net.Conn, err error) {
l.logger.Infof("%s >> %s", raddr.String(), dstAddr.String()) l.logger.Infof("%s >> %s", raddr.String(), dstAddr.String())
c, err := tproxy.DialUDP("udp", dstAddr, raddr) c, err := dialUDP("udp", dstAddr, raddr)
if err != nil { if err != nil {
l.logger.Error(err) l.logger.Error(err)
return return
@ -35,3 +60,159 @@ func (l *redirectListener) accept() (conn net.Conn, err error) {
} }
return return
} }
// ReadFromUDP reads a UDP packet from c, copying the payload into b.
// It returns the number of bytes copied into b and the return address
// that was on the packet.
//
// Out-of-band data is also read in so that the original destination
// address can be identified and parsed.
func readFromUDP(conn *net.UDPConn, b []byte) (n int, remoteAddr *net.UDPAddr, dstAddr *net.UDPAddr, err error) {
oob := bufpool.Get(1024)
defer bufpool.Put(oob)
n, oobn, _, remoteAddr, err := conn.ReadMsgUDP(b, *oob)
if err != nil {
return 0, nil, nil, err
}
msgs, err := unix.ParseSocketControlMessage((*oob)[:oobn])
if err != nil {
return 0, nil, nil, fmt.Errorf("parsing socket control message: %s", err)
}
for _, msg := range msgs {
if msg.Header.Level == unix.SOL_IP && msg.Header.Type == unix.IP_RECVORIGDSTADDR {
originalDstRaw := &unix.RawSockaddrInet4{}
if err = binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, originalDstRaw); err != nil {
return 0, nil, nil, fmt.Errorf("reading original destination address: %s", err)
}
switch originalDstRaw.Family {
case unix.AF_INET:
pp := (*unix.RawSockaddrInet4)(unsafe.Pointer(originalDstRaw))
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
dstAddr = &net.UDPAddr{
IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
Port: int(p[0])<<8 + int(p[1]),
}
case unix.AF_INET6:
pp := (*unix.RawSockaddrInet6)(unsafe.Pointer(originalDstRaw))
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
dstAddr = &net.UDPAddr{
IP: net.IP(pp.Addr[:]),
Port: int(p[0])<<8 + int(p[1]),
Zone: strconv.Itoa(int(pp.Scope_id)),
}
default:
return 0, nil, nil, fmt.Errorf("original destination is an unsupported network family")
}
break
}
}
if dstAddr == nil {
return 0, nil, nil, fmt.Errorf("unable to obtain original destination: %s", err)
}
return
}
// DialUDP connects to the remote address raddr on the network net,
// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is
// used as the local address for the connection.
func dialUDP(network string, laddr *net.UDPAddr, raddr *net.UDPAddr) (net.Conn, error) {
remoteSocketAddress, err := udpAddrToSocketAddr(raddr)
if err != nil {
return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("build destination socket address: %s", err)}
}
localSocketAddress, err := udpAddrToSocketAddr(laddr)
if err != nil {
return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("build local socket address: %s", err)}
}
fileDescriptor, err := unix.Socket(udpAddrFamily(network, laddr, raddr), unix.SOCK_DGRAM, 0)
if err != nil {
return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket open: %s", err)}
}
if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_IP, unix.IP_TRANSPARENT, 1); err != nil {
unix.Close(fileDescriptor)
return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %s", err)}
}
if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil {
unix.Close(fileDescriptor)
return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: SO_REUSEADDR: %s", err)}
}
if err = unix.SetsockoptInt(fileDescriptor, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
unix.Close(fileDescriptor)
return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: SO_REUSEPORT: %s", err)}
}
if err = unix.Bind(fileDescriptor, localSocketAddress); err != nil {
unix.Close(fileDescriptor)
return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket bind %v: %s", laddr, err)}
}
if err = unix.Connect(fileDescriptor, remoteSocketAddress); err != nil {
unix.Close(fileDescriptor)
return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket connect: %s", err)}
}
fdFile := os.NewFile(uintptr(fileDescriptor), fmt.Sprintf("net-udp-dial-%s", raddr.String()))
defer fdFile.Close()
remoteConn, err := net.FileConn(fdFile)
if err != nil {
unix.Close(fileDescriptor)
return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("convert file descriptor to connection: %s", err)}
}
return remoteConn, nil
}
// udpAddToSockerAddr will convert a UDPAddr
// into a Sockaddr that may be used when
// connecting and binding sockets
func udpAddrToSocketAddr(addr *net.UDPAddr) (unix.Sockaddr, error) {
switch {
case addr.IP.To4() != nil:
ip := [4]byte{}
copy(ip[:], addr.IP.To4())
return &unix.SockaddrInet4{Addr: ip, Port: addr.Port}, nil
default:
ip := [16]byte{}
copy(ip[:], addr.IP.To16())
zoneID, err := strconv.ParseUint(addr.Zone, 10, 32)
if err != nil {
return nil, err
}
return &unix.SockaddrInet6{Addr: ip, Port: addr.Port, ZoneId: uint32(zoneID)}, nil
}
}
// udpAddrFamily will attempt to work
// out the address family based on the
// network and UDP addresses
func udpAddrFamily(net string, laddr, raddr *net.UDPAddr) int {
switch net[len(net)-1] {
case '4':
return unix.AF_INET
case '6':
return unix.AF_INET6
}
if (laddr == nil || laddr.IP.To4() != nil) &&
(raddr == nil || laddr.IP.To4() != nil) {
return unix.AF_INET
}
return unix.AF_INET6
}

View File

@ -7,7 +7,7 @@ import (
) )
const ( const (
defaultTTL = 60 * time.Second defaultTTL = 30 * time.Second
defaultReadBufferSize = 1500 defaultReadBufferSize = 1500
) )