add auth for tun

This commit is contained in:
ginuerzh 2022-10-20 21:08:26 +08:00
parent 15d0a33716
commit 669e80b780
5 changed files with 80 additions and 21 deletions

2
go.mod
View File

@ -7,7 +7,7 @@ require (
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
github.com/gin-contrib/cors v1.3.1 github.com/gin-contrib/cors v1.3.1
github.com/gin-gonic/gin v1.7.7 github.com/gin-gonic/gin v1.7.7
github.com/go-gost/core v0.0.0-20220928034632-6e7a8f461903 github.com/go-gost/core v0.0.0-20221020130224-eb9d483127cc
github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks4 v0.0.1
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09
github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7

4
go.sum
View File

@ -98,8 +98,8 @@ github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gost/core v0.0.0-20220928034632-6e7a8f461903 h1:Ye6Ns0+Ms63vC+nbe9sBgBDTr+l+ukPX18SvEJuWXUw= github.com/go-gost/core v0.0.0-20221020130224-eb9d483127cc h1:pS75VLwTkYLIC3n0QbfwE65N/1Zh8BnXfErNq9DGWd4=
github.com/go-gost/core v0.0.0-20220928034632-6e7a8f461903/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= github.com/go-gost/core v0.0.0-20221020130224-eb9d483127cc/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ=
github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s=
github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc=
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04=

View File

@ -1,6 +1,7 @@
package tun package tun
import ( import (
"bytes"
"context" "context"
"io" "io"
"net" "net"
@ -15,12 +16,12 @@ import (
) )
const ( const (
// 4-byte magic header followed by 16-byte IP address // 4-byte magic header followed by 16-byte IP address followed by 16-byte key.
keepAliveDataLength = 20 keepAliveDataLength = 36
) )
var ( var (
keepAliveHeader = []byte("GOST") magicHeader = []byte("GOST")
) )
func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) error { func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) error {
@ -38,22 +39,26 @@ func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, addr net.A
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
if h.md.keepAlivePeriod > 0 { go h.keepAlive(ctx, cc, ip)
go h.keepAlive(ctx, cc, ip)
}
return h.transportClient(conn, cc, config, log) return h.transportClient(conn, cc, config, log)
} }
func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) { func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) {
// handshake
var keepAliveData [keepAliveDataLength]byte var keepAliveData [keepAliveDataLength]byte
copy(keepAliveData[:4], keepAliveHeader) // magic header copy(keepAliveData[:4], magicHeader) // magic header
copy(keepAliveData[4:], ip.To16()) copy(keepAliveData[4:20], ip.To16())
copy(keepAliveData[20:36], []byte(h.md.passphrase))
if _, err := conn.Write(keepAliveData[:]); err != nil { if _, err := conn.Write(keepAliveData[:]); err != nil {
return return
} }
if h.md.keepAlivePeriod <= 0 {
return
}
conn.SetReadDeadline(time.Now().Add(h.md.keepAlivePeriod * 3))
ticker := time.NewTicker(h.md.keepAlivePeriod) ticker := time.NewTicker(h.md.keepAlivePeriod)
defer ticker.Stop() defer ticker.Stop()
@ -63,6 +68,7 @@ func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) {
if _, err := conn.Write(keepAliveData[:]); err != nil { if _, err := conn.Write(keepAliveData[:]); err != nil {
return return
} }
h.options.Logger.Debugf("keepalive sended")
case <-ctx.Done(): case <-ctx.Done():
return return
} }
@ -131,6 +137,16 @@ func (h *tunHandler) transportClient(tun net.Conn, conn net.Conn, config *tun_ut
return err return err
} }
if n == keepAliveDataLength && bytes.Equal((*b)[:4], magicHeader) {
ip := net.IP((*b)[4:20])
log.Debugf("keepalive received at %v", ip)
if h.md.keepAlivePeriod > 0 {
conn.SetReadDeadline(time.Now().Add(h.md.keepAlivePeriod * 3))
}
return nil
}
if waterutil.IsIPv4((*b)[:n]) { if waterutil.IsIPv4((*b)[:n]) {
header, err := ipv4.ParseHeader((*b)[:n]) header, err := ipv4.ParseHeader((*b)[:n])
if err != nil { if err != nil {

View File

@ -14,6 +14,7 @@ const (
type metadata struct { type metadata struct {
bufferSize int bufferSize int
keepAlivePeriod time.Duration keepAlivePeriod time.Duration
passphrase string
} }
func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) { func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) {
@ -21,6 +22,7 @@ func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) {
bufferSize = "bufferSize" bufferSize = "bufferSize"
keepAlive = "keepAlive" keepAlive = "keepAlive"
keepAlivePeriod = "ttl" keepAlivePeriod = "ttl"
passphrase = "passphrase"
) )
h.md.bufferSize = mdutil.GetInt(md, bufferSize) h.md.bufferSize = mdutil.GetInt(md, bufferSize)
@ -34,5 +36,7 @@ func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) {
h.md.keepAlivePeriod = defaultKeepAlivePeriod h.md.keepAlivePeriod = defaultKeepAlivePeriod
} }
} }
h.md.passphrase = mdutil.GetString(md, passphrase)
return return
} }

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"io" "io"
"net" "net"
"net/netip"
"github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/common/bufpool"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
@ -25,6 +26,11 @@ func (h *tunHandler) handleServer(ctx context.Context, conn net.Conn, config *tu
} }
func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *tun_util.Config, log logger.Logger) error { func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *tun_util.Config, log logger.Logger) error {
tunIP, _, err := net.ParseCIDR(config.Net)
if err != nil {
return err
}
errc := make(chan error, 1) errc := make(chan error, 1)
go func() { go func() {
@ -48,7 +54,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *
src, dst = header.Src, header.Dst src, dst = header.Src, header.Dst
log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d", log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d",
header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])), src, dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])),
header.Len, header.TotalLen, header.ID, header.Flags) header.Len, header.TotalLen, header.ID, header.Flags)
} else if waterutil.IsIPv6((*b)[:n]) { } else if waterutil.IsIPv6((*b)[:n]) {
header, err := ipv6.ParseHeader((*b)[:n]) header, err := ipv6.ParseHeader((*b)[:n])
@ -59,7 +65,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *
src, dst = header.Src, header.Dst src, dst = header.Src, header.Dst
log.Tracef("%s >> %s %s %d %d", log.Tracef("%s >> %s %s %d %d",
header.Src, header.Dst, src, dst,
ipProtocol(waterutil.IPProtocol(header.NextHeader)), ipProtocol(waterutil.IPProtocol(header.NextHeader)),
header.PayloadLen, header.TrafficClass) header.PayloadLen, header.TrafficClass)
} else { } else {
@ -98,9 +104,42 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *
if err != nil { if err != nil {
return err return err
} }
if n == keepAliveDataLength && bytes.Equal((*b)[:4], keepAliveHeader) { if n == keepAliveDataLength && bytes.Equal((*b)[:4], magicHeader) {
peerIP := net.IP((*b)[4:keepAliveDataLength]) peerIP := net.IP((*b)[4:20])
log.Debugf("keepalive from %v => %v", peerIP, addr) key := bytes.TrimRight((*b)[20:36], "\x00")
if peerIP.Equal(tunIP.To16()) {
return nil
}
if auther := h.options.Auther; auther != nil {
ip := peerIP
if v := peerIP.To4(); ip != nil {
ip = v
}
if !auther.Authenticate(ip.String(), string(key)) {
log.Debugf("keepalive from %v => %v, auth FAILED", addr, peerIP)
return nil
}
}
log.Debugf("keepalive from %v => %v", addr, peerIP)
addrPort, err := netip.ParseAddrPort(addr.String())
if err != nil {
log.Warnf("keepalive from %v: %v", addr, err)
return nil
}
var keepAliveData [keepAliveDataLength]byte
copy(keepAliveData[:4], magicHeader) // magic header
a16 := addrPort.Addr().As16()
copy(keepAliveData[4:], a16[:])
if _, err := conn.WriteTo(keepAliveData[:], addr); err != nil {
log.Warnf("keepalive to %v: %v", addr, err)
return nil
}
h.updateRoute(peerIP, addr, log) h.updateRoute(peerIP, addr, log)
return nil return nil
} }
@ -115,7 +154,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *
src, dst = header.Src, header.Dst src, dst = header.Src, header.Dst
log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d", log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d",
header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])), src, dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])),
header.Len, header.TotalLen, header.ID, header.Flags) header.Len, header.TotalLen, header.ID, header.Flags)
} else if waterutil.IsIPv6((*b)[:n]) { } else if waterutil.IsIPv6((*b)[:n]) {
header, err := ipv6.ParseHeader((*b)[:n]) header, err := ipv6.ParseHeader((*b)[:n])
@ -126,7 +165,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *
src, dst = header.Src, header.Dst src, dst = header.Src, header.Dst
log.Tracef("%s > %s %s %d %d", log.Tracef("%s > %s %s %d %d",
header.Src, header.Dst, src, dst,
ipProtocol(waterutil.IPProtocol(header.NextHeader)), ipProtocol(waterutil.IPProtocol(header.NextHeader)),
header.PayloadLen, header.TrafficClass) header.PayloadLen, header.TrafficClass)
} else { } else {
@ -134,7 +173,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *
return nil return nil
} }
h.updateRoute(src, addr, log) // h.updateRoute(src, addr, log)
if addr := h.findRouteFor(dst, config.Routes...); addr != nil { if addr := h.findRouteFor(dst, config.Routes...); addr != nil {
log.Debugf("find route: %s -> %s", dst, addr) log.Debugf("find route: %s -> %s", dst, addr)
@ -156,7 +195,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *
} }
}() }()
err := <-errc err = <-errc
if err != nil && err == io.EOF { if err != nil && err == io.EOF {
err = nil err = nil
} }