add auth for tun
This commit is contained in:
parent
15d0a33716
commit
669e80b780
2
go.mod
2
go.mod
@ -7,7 +7,7 @@ require (
|
|||||||
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
|
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
|
||||||
github.com/gin-contrib/cors v1.3.1
|
github.com/gin-contrib/cors v1.3.1
|
||||||
github.com/gin-gonic/gin v1.7.7
|
github.com/gin-gonic/gin v1.7.7
|
||||||
github.com/go-gost/core v0.0.0-20220928034632-6e7a8f461903
|
github.com/go-gost/core v0.0.0-20221020130224-eb9d483127cc
|
||||||
github.com/go-gost/gosocks4 v0.0.1
|
github.com/go-gost/gosocks4 v0.0.1
|
||||||
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09
|
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09
|
||||||
github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7
|
github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7
|
||||||
|
4
go.sum
4
go.sum
@ -98,8 +98,8 @@ github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ
|
|||||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||||
github.com/go-gost/core v0.0.0-20220928034632-6e7a8f461903 h1:Ye6Ns0+Ms63vC+nbe9sBgBDTr+l+ukPX18SvEJuWXUw=
|
github.com/go-gost/core v0.0.0-20221020130224-eb9d483127cc h1:pS75VLwTkYLIC3n0QbfwE65N/1Zh8BnXfErNq9DGWd4=
|
||||||
github.com/go-gost/core v0.0.0-20220928034632-6e7a8f461903/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ=
|
github.com/go-gost/core v0.0.0-20221020130224-eb9d483127cc/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ=
|
||||||
github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s=
|
github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s=
|
||||||
github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc=
|
github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc=
|
||||||
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04=
|
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04=
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package tun
|
package tun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -15,12 +16,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// 4-byte magic header followed by 16-byte IP address
|
// 4-byte magic header followed by 16-byte IP address followed by 16-byte key.
|
||||||
keepAliveDataLength = 20
|
keepAliveDataLength = 36
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
keepAliveHeader = []byte("GOST")
|
magicHeader = []byte("GOST")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) error {
|
func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) error {
|
||||||
@ -38,22 +39,26 @@ func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, addr net.A
|
|||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if h.md.keepAlivePeriod > 0 {
|
|
||||||
go h.keepAlive(ctx, cc, ip)
|
go h.keepAlive(ctx, cc, ip)
|
||||||
}
|
|
||||||
|
|
||||||
return h.transportClient(conn, cc, config, log)
|
return h.transportClient(conn, cc, config, log)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) {
|
func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) {
|
||||||
|
// handshake
|
||||||
var keepAliveData [keepAliveDataLength]byte
|
var keepAliveData [keepAliveDataLength]byte
|
||||||
copy(keepAliveData[:4], keepAliveHeader) // magic header
|
copy(keepAliveData[:4], magicHeader) // magic header
|
||||||
copy(keepAliveData[4:], ip.To16())
|
copy(keepAliveData[4:20], ip.To16())
|
||||||
|
copy(keepAliveData[20:36], []byte(h.md.passphrase))
|
||||||
if _, err := conn.Write(keepAliveData[:]); err != nil {
|
if _, err := conn.Write(keepAliveData[:]); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.md.keepAlivePeriod <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn.SetReadDeadline(time.Now().Add(h.md.keepAlivePeriod * 3))
|
||||||
|
|
||||||
ticker := time.NewTicker(h.md.keepAlivePeriod)
|
ticker := time.NewTicker(h.md.keepAlivePeriod)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@ -63,6 +68,7 @@ func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) {
|
|||||||
if _, err := conn.Write(keepAliveData[:]); err != nil {
|
if _, err := conn.Write(keepAliveData[:]); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
h.options.Logger.Debugf("keepalive sended")
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -131,6 +137,16 @@ func (h *tunHandler) transportClient(tun net.Conn, conn net.Conn, config *tun_ut
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if n == keepAliveDataLength && bytes.Equal((*b)[:4], magicHeader) {
|
||||||
|
ip := net.IP((*b)[4:20])
|
||||||
|
log.Debugf("keepalive received at %v", ip)
|
||||||
|
|
||||||
|
if h.md.keepAlivePeriod > 0 {
|
||||||
|
conn.SetReadDeadline(time.Now().Add(h.md.keepAlivePeriod * 3))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if waterutil.IsIPv4((*b)[:n]) {
|
if waterutil.IsIPv4((*b)[:n]) {
|
||||||
header, err := ipv4.ParseHeader((*b)[:n])
|
header, err := ipv4.ParseHeader((*b)[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -14,6 +14,7 @@ const (
|
|||||||
type metadata struct {
|
type metadata struct {
|
||||||
bufferSize int
|
bufferSize int
|
||||||
keepAlivePeriod time.Duration
|
keepAlivePeriod time.Duration
|
||||||
|
passphrase string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) {
|
func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||||
@ -21,6 +22,7 @@ func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) {
|
|||||||
bufferSize = "bufferSize"
|
bufferSize = "bufferSize"
|
||||||
keepAlive = "keepAlive"
|
keepAlive = "keepAlive"
|
||||||
keepAlivePeriod = "ttl"
|
keepAlivePeriod = "ttl"
|
||||||
|
passphrase = "passphrase"
|
||||||
)
|
)
|
||||||
|
|
||||||
h.md.bufferSize = mdutil.GetInt(md, bufferSize)
|
h.md.bufferSize = mdutil.GetInt(md, bufferSize)
|
||||||
@ -34,5 +36,7 @@ func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) {
|
|||||||
h.md.keepAlivePeriod = defaultKeepAlivePeriod
|
h.md.keepAlivePeriod = defaultKeepAlivePeriod
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.md.passphrase = mdutil.GetString(md, passphrase)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"github.com/go-gost/core/common/bufpool"
|
"github.com/go-gost/core/common/bufpool"
|
||||||
"github.com/go-gost/core/logger"
|
"github.com/go-gost/core/logger"
|
||||||
@ -25,6 +26,11 @@ func (h *tunHandler) handleServer(ctx context.Context, conn net.Conn, config *tu
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *tun_util.Config, log logger.Logger) error {
|
func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *tun_util.Config, log logger.Logger) error {
|
||||||
|
tunIP, _, err := net.ParseCIDR(config.Net)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
errc := make(chan error, 1)
|
errc := make(chan error, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@ -48,7 +54,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *
|
|||||||
src, dst = header.Src, header.Dst
|
src, dst = header.Src, header.Dst
|
||||||
|
|
||||||
log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d",
|
log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d",
|
||||||
header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])),
|
src, dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])),
|
||||||
header.Len, header.TotalLen, header.ID, header.Flags)
|
header.Len, header.TotalLen, header.ID, header.Flags)
|
||||||
} else if waterutil.IsIPv6((*b)[:n]) {
|
} else if waterutil.IsIPv6((*b)[:n]) {
|
||||||
header, err := ipv6.ParseHeader((*b)[:n])
|
header, err := ipv6.ParseHeader((*b)[:n])
|
||||||
@ -59,7 +65,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *
|
|||||||
src, dst = header.Src, header.Dst
|
src, dst = header.Src, header.Dst
|
||||||
|
|
||||||
log.Tracef("%s >> %s %s %d %d",
|
log.Tracef("%s >> %s %s %d %d",
|
||||||
header.Src, header.Dst,
|
src, dst,
|
||||||
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
|
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
|
||||||
header.PayloadLen, header.TrafficClass)
|
header.PayloadLen, header.TrafficClass)
|
||||||
} else {
|
} else {
|
||||||
@ -98,9 +104,42 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if n == keepAliveDataLength && bytes.Equal((*b)[:4], keepAliveHeader) {
|
if n == keepAliveDataLength && bytes.Equal((*b)[:4], magicHeader) {
|
||||||
peerIP := net.IP((*b)[4:keepAliveDataLength])
|
peerIP := net.IP((*b)[4:20])
|
||||||
log.Debugf("keepalive from %v => %v", peerIP, addr)
|
key := bytes.TrimRight((*b)[20:36], "\x00")
|
||||||
|
|
||||||
|
if peerIP.Equal(tunIP.To16()) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if auther := h.options.Auther; auther != nil {
|
||||||
|
ip := peerIP
|
||||||
|
if v := peerIP.To4(); ip != nil {
|
||||||
|
ip = v
|
||||||
|
}
|
||||||
|
if !auther.Authenticate(ip.String(), string(key)) {
|
||||||
|
log.Debugf("keepalive from %v => %v, auth FAILED", addr, peerIP)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("keepalive from %v => %v", addr, peerIP)
|
||||||
|
|
||||||
|
addrPort, err := netip.ParseAddrPort(addr.String())
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("keepalive from %v: %v", addr, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var keepAliveData [keepAliveDataLength]byte
|
||||||
|
copy(keepAliveData[:4], magicHeader) // magic header
|
||||||
|
a16 := addrPort.Addr().As16()
|
||||||
|
copy(keepAliveData[4:], a16[:])
|
||||||
|
|
||||||
|
if _, err := conn.WriteTo(keepAliveData[:], addr); err != nil {
|
||||||
|
log.Warnf("keepalive to %v: %v", addr, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
h.updateRoute(peerIP, addr, log)
|
h.updateRoute(peerIP, addr, log)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -115,7 +154,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *
|
|||||||
src, dst = header.Src, header.Dst
|
src, dst = header.Src, header.Dst
|
||||||
|
|
||||||
log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d",
|
log.Tracef("%s >> %s %-4s %d/%-4d %-4x %d",
|
||||||
header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])),
|
src, dst, ipProtocol(waterutil.IPv4Protocol((*b)[:n])),
|
||||||
header.Len, header.TotalLen, header.ID, header.Flags)
|
header.Len, header.TotalLen, header.ID, header.Flags)
|
||||||
} else if waterutil.IsIPv6((*b)[:n]) {
|
} else if waterutil.IsIPv6((*b)[:n]) {
|
||||||
header, err := ipv6.ParseHeader((*b)[:n])
|
header, err := ipv6.ParseHeader((*b)[:n])
|
||||||
@ -126,7 +165,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *
|
|||||||
src, dst = header.Src, header.Dst
|
src, dst = header.Src, header.Dst
|
||||||
|
|
||||||
log.Tracef("%s > %s %s %d %d",
|
log.Tracef("%s > %s %s %d %d",
|
||||||
header.Src, header.Dst,
|
src, dst,
|
||||||
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
|
ipProtocol(waterutil.IPProtocol(header.NextHeader)),
|
||||||
header.PayloadLen, header.TrafficClass)
|
header.PayloadLen, header.TrafficClass)
|
||||||
} else {
|
} else {
|
||||||
@ -134,7 +173,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
h.updateRoute(src, addr, log)
|
// h.updateRoute(src, addr, log)
|
||||||
|
|
||||||
if addr := h.findRouteFor(dst, config.Routes...); addr != nil {
|
if addr := h.findRouteFor(dst, config.Routes...); addr != nil {
|
||||||
log.Debugf("find route: %s -> %s", dst, addr)
|
log.Debugf("find route: %s -> %s", dst, addr)
|
||||||
@ -156,7 +195,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := <-errc
|
err = <-errc
|
||||||
if err != nil && err == io.EOF {
|
if err != nil && err == io.EOF {
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user