From 669e80b78047d2bba6d7e3d6deac9ea319bd4c09 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Thu, 20 Oct 2022 21:08:26 +0800 Subject: [PATCH] add auth for tun --- go.mod | 2 +- go.sum | 4 +-- handler/tun/client.go | 34 +++++++++++++++++------- handler/tun/metadata.go | 4 +++ handler/tun/server.go | 57 ++++++++++++++++++++++++++++++++++------- 5 files changed, 80 insertions(+), 21 deletions(-) diff --git a/go.mod b/go.mod index 2b45c7a..915eb5a 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.7.7 - github.com/go-gost/core v0.0.0-20220928034632-6e7a8f461903 + github.com/go-gost/core v0.0.0-20221020130224-eb9d483127cc github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 diff --git a/go.sum b/go.sum index e754954..4040cb1 100644 --- a/go.sum +++ b/go.sum @@ -98,8 +98,8 @@ github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gost/core v0.0.0-20220928034632-6e7a8f461903 h1:Ye6Ns0+Ms63vC+nbe9sBgBDTr+l+ukPX18SvEJuWXUw= -github.com/go-gost/core v0.0.0-20220928034632-6e7a8f461903/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= +github.com/go-gost/core v0.0.0-20221020130224-eb9d483127cc h1:pS75VLwTkYLIC3n0QbfwE65N/1Zh8BnXfErNq9DGWd4= +github.com/go-gost/core v0.0.0-20221020130224-eb9d483127cc/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= diff --git a/handler/tun/client.go b/handler/tun/client.go index a3b6f40..1d05715 100644 --- a/handler/tun/client.go +++ b/handler/tun/client.go @@ -1,6 +1,7 @@ package tun import ( + "bytes" "context" "io" "net" @@ -15,12 +16,12 @@ import ( ) const ( - // 4-byte magic header followed by 16-byte IP address - keepAliveDataLength = 20 + // 4-byte magic header followed by 16-byte IP address followed by 16-byte key. + keepAliveDataLength = 36 ) var ( - keepAliveHeader = []byte("GOST") + magicHeader = []byte("GOST") ) func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) error { @@ -38,22 +39,26 @@ func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, addr net.A ctx, cancel := context.WithCancel(ctx) defer cancel() - if h.md.keepAlivePeriod > 0 { - go h.keepAlive(ctx, cc, ip) - } + go h.keepAlive(ctx, cc, ip) return h.transportClient(conn, cc, config, log) } func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) { + // handshake var keepAliveData [keepAliveDataLength]byte - copy(keepAliveData[:4], keepAliveHeader) // magic header - copy(keepAliveData[4:], ip.To16()) - + copy(keepAliveData[:4], magicHeader) // magic header + copy(keepAliveData[4:20], ip.To16()) + copy(keepAliveData[20:36], []byte(h.md.passphrase)) if _, err := conn.Write(keepAliveData[:]); err != nil { return } + if h.md.keepAlivePeriod <= 0 { + return + } + conn.SetReadDeadline(time.Now().Add(h.md.keepAlivePeriod * 3)) + ticker := time.NewTicker(h.md.keepAlivePeriod) defer ticker.Stop() @@ -63,6 +68,7 @@ func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) { if _, err := conn.Write(keepAliveData[:]); err != nil { return } + h.options.Logger.Debugf("keepalive sended") case <-ctx.Done(): return } @@ -131,6 +137,16 @@ func (h *tunHandler) transportClient(tun net.Conn, conn net.Conn, config *tun_ut return err } + if n == keepAliveDataLength && bytes.Equal((*b)[:4], magicHeader) { + ip := net.IP((*b)[4:20]) + log.Debugf("keepalive received at %v", ip) + + if h.md.keepAlivePeriod > 0 { + conn.SetReadDeadline(time.Now().Add(h.md.keepAlivePeriod * 3)) + } + return nil + } + if waterutil.IsIPv4((*b)[:n]) { header, err := ipv4.ParseHeader((*b)[:n]) if err != nil { diff --git a/handler/tun/metadata.go b/handler/tun/metadata.go index b7d3db3..aeb1c49 100644 --- a/handler/tun/metadata.go +++ b/handler/tun/metadata.go @@ -14,6 +14,7 @@ const ( type metadata struct { bufferSize int keepAlivePeriod time.Duration + passphrase string } func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) { @@ -21,6 +22,7 @@ func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) { bufferSize = "bufferSize" keepAlive = "keepAlive" keepAlivePeriod = "ttl" + passphrase = "passphrase" ) h.md.bufferSize = mdutil.GetInt(md, bufferSize) @@ -34,5 +36,7 @@ func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) { h.md.keepAlivePeriod = defaultKeepAlivePeriod } } + + h.md.passphrase = mdutil.GetString(md, passphrase) return } diff --git a/handler/tun/server.go b/handler/tun/server.go index 3cd9907..e319bae 100644 --- a/handler/tun/server.go +++ b/handler/tun/server.go @@ -5,6 +5,7 @@ import ( "context" "io" "net" + "net/netip" "github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/logger" @@ -25,6 +26,11 @@ func (h *tunHandler) handleServer(ctx context.Context, conn net.Conn, config *tu } func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *tun_util.Config, log logger.Logger) error { + tunIP, _, err := net.ParseCIDR(config.Net) + if err != nil { + return err + } + errc := make(chan error, 1) go func() { @@ -48,7 +54,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * src, dst = header.Src, header.Dst log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d", - header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])), + src, dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])), header.Len, header.TotalLen, header.ID, header.Flags) } else if waterutil.IsIPv6((*b)[:n]) { header, err := ipv6.ParseHeader((*b)[:n]) @@ -59,7 +65,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * src, dst = header.Src, header.Dst log.Tracef("%s >> %s %s %d %d", - header.Src, header.Dst, + src, dst, ipProtocol(waterutil.IPProtocol(header.NextHeader)), header.PayloadLen, header.TrafficClass) } else { @@ -98,9 +104,42 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * if err != nil { return err } - if n == keepAliveDataLength && bytes.Equal((*b)[:4], keepAliveHeader) { - peerIP := net.IP((*b)[4:keepAliveDataLength]) - log.Debugf("keepalive from %v => %v", peerIP, addr) + if n == keepAliveDataLength && bytes.Equal((*b)[:4], magicHeader) { + peerIP := net.IP((*b)[4:20]) + key := bytes.TrimRight((*b)[20:36], "\x00") + + if peerIP.Equal(tunIP.To16()) { + return nil + } + + if auther := h.options.Auther; auther != nil { + ip := peerIP + if v := peerIP.To4(); ip != nil { + ip = v + } + if !auther.Authenticate(ip.String(), string(key)) { + log.Debugf("keepalive from %v => %v, auth FAILED", addr, peerIP) + return nil + } + } + + log.Debugf("keepalive from %v => %v", addr, peerIP) + + addrPort, err := netip.ParseAddrPort(addr.String()) + if err != nil { + log.Warnf("keepalive from %v: %v", addr, err) + return nil + } + var keepAliveData [keepAliveDataLength]byte + copy(keepAliveData[:4], magicHeader) // magic header + a16 := addrPort.Addr().As16() + copy(keepAliveData[4:], a16[:]) + + if _, err := conn.WriteTo(keepAliveData[:], addr); err != nil { + log.Warnf("keepalive to %v: %v", addr, err) + return nil + } + h.updateRoute(peerIP, addr, log) return nil } @@ -115,7 +154,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * src, dst = header.Src, header.Dst log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d", - header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])), + src, dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])), header.Len, header.TotalLen, header.ID, header.Flags) } else if waterutil.IsIPv6((*b)[:n]) { header, err := ipv6.ParseHeader((*b)[:n]) @@ -126,7 +165,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * src, dst = header.Src, header.Dst log.Tracef("%s > %s %s %d %d", - header.Src, header.Dst, + src, dst, ipProtocol(waterutil.IPProtocol(header.NextHeader)), header.PayloadLen, header.TrafficClass) } else { @@ -134,7 +173,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * return nil } - h.updateRoute(src, addr, log) + // h.updateRoute(src, addr, log) if addr := h.findRouteFor(dst, config.Routes...); addr != nil { log.Debugf("find route: %s -> %s", dst, addr) @@ -156,7 +195,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * } }() - err := <-errc + err = <-errc if err != nil && err == io.EOF { err = nil }