diff --git a/handler/tun/client.go b/handler/tun/client.go index 782d614..1c2db70 100644 --- a/handler/tun/client.go +++ b/handler/tun/client.go @@ -4,6 +4,7 @@ import ( "context" "io" "net" + "time" "github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/logger" @@ -13,16 +14,57 @@ import ( "golang.org/x/net/ipv6" ) +const ( + // 4-byte magic header followed by 16-byte IP address + keepAliveDataLength = 20 +) + +var ( + keepAliveHeader = []byte("GOST") +) + func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, addr net.Addr, config *tun_util.Config, log logger.Logger) error { + ip, _, err := net.ParseCIDR(config.Net) + if err != nil { + return err + } + cc, err := h.router.Dial(ctx, addr.Network(), addr.String()) if err != nil { return err } defer cc.Close() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + if h.md.keepAlivePeriod > 0 { + go h.keepAlive(ctx, cc, ip) + } + return h.transportClient(conn, cc, config, log) } +func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) { + var keepAliveData [keepAliveDataLength]byte + copy(keepAliveData[:4], keepAliveHeader) // magic header + copy(keepAliveData[4:], ip.To16()) + + ticker := time.NewTicker(h.md.keepAlivePeriod) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if _, err := conn.Write(keepAliveData[:]); err != nil { + return + } + case <-ctx.Done(): + return + } + } +} + func (h *tunHandler) transportClient(tun net.Conn, conn net.Conn, config *tun_util.Config, log logger.Logger) error { errc := make(chan error, 1) diff --git a/handler/tun/handler.go b/handler/tun/handler.go index b3b344a..f2668b1 100644 --- a/handler/tun/handler.go +++ b/handler/tun/handler.go @@ -23,7 +23,6 @@ func init() { type tunHandler struct { group *chain.NodeGroup routes sync.Map - exit chan struct{} router *chain.Router md metadata options handler.Options @@ -36,7 +35,6 @@ func NewHandler(opts ...handler.Option) handler.Handler { } return &tunHandler{ - exit: make(chan struct{}, 1), options: options, } } diff --git a/handler/tun/metadata.go b/handler/tun/metadata.go index 72eeff2..8ed647b 100644 --- a/handler/tun/metadata.go +++ b/handler/tun/metadata.go @@ -1,22 +1,38 @@ package tun import ( + "time" + mdata "github.com/go-gost/core/metadata" mdx "github.com/go-gost/x/metadata" ) +const ( + defaultKeepAlivePeriod = 10 * time.Second +) + type metadata struct { - bufferSize int + bufferSize int + keepAlivePeriod time.Duration } func (h *tunHandler) parseMetadata(md mdata.Metadata) (err error) { const ( - bufferSize = "bufferSize" + bufferSize = "bufferSize" + keepAlive = "keepAlive" + keepAlivePeriod = "ttl" ) h.md.bufferSize = mdx.GetInt(md, bufferSize) if h.md.bufferSize <= 0 { h.md.bufferSize = 1500 } + + if mdx.GetBool(md, keepAlive) { + h.md.keepAlivePeriod = mdx.GetDuration(md, keepAlivePeriod) + if h.md.keepAlivePeriod <= 0 { + h.md.keepAlivePeriod = defaultKeepAlivePeriod + } + } return } diff --git a/handler/tun/server.go b/handler/tun/server.go index ee21ab4..3cd9907 100644 --- a/handler/tun/server.go +++ b/handler/tun/server.go @@ -1,6 +1,7 @@ package tun import ( + "bytes" "context" "io" "net" @@ -41,7 +42,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * if waterutil.IsIPv4((*b)[:n]) { header, err := ipv4.ParseHeader((*b)[:n]) if err != nil { - log.Warn(err) + log.Warnf("parse ipv4 packet header: %v", err) return nil } src, dst = header.Src, header.Dst @@ -52,7 +53,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * } else if waterutil.IsIPv6((*b)[:n]) { header, err := ipv6.ParseHeader((*b)[:n]) if err != nil { - log.Warn(err) + log.Warnf("parse ipv6 packet header: %v", err) return nil } src, dst = header.Src, header.Dst @@ -97,12 +98,18 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * if err != nil { return err } + if n == keepAliveDataLength && bytes.Equal((*b)[:4], keepAliveHeader) { + peerIP := net.IP((*b)[4:keepAliveDataLength]) + log.Debugf("keepalive from %v => %v", peerIP, addr) + h.updateRoute(peerIP, addr, log) + return nil + } var src, dst net.IP if waterutil.IsIPv4((*b)[:n]) { header, err := ipv4.ParseHeader((*b)[:n]) if err != nil { - log.Warn(err) + log.Warnf("parse ipv4 packet header: %v", err) return nil } src, dst = header.Src, header.Dst @@ -113,7 +120,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * } else if waterutil.IsIPv6((*b)[:n]) { header, err := ipv6.ParseHeader((*b)[:n]) if err != nil { - log.Warn(err) + log.Warnf("parse ipv6 packet header: %v", err) return nil } src, dst = header.Src, header.Dst @@ -127,16 +134,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * return nil } - rkey := ipToTunRouteKey(src) - if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded { - if actual.(net.Addr).String() != addr.String() { - h.routes.Store(rkey, addr) - log.Debugf("update route: %s -> %s (old %s)", - src, addr, actual.(net.Addr)) - } - } else { - log.Debugf("new route: %s -> %s", src, addr) - } + h.updateRoute(src, addr, log) if addr := h.findRouteFor(dst, config.Routes...); addr != nil { log.Debugf("find route: %s -> %s", dst, addr) @@ -164,3 +162,16 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * } return err } + +func (h *tunHandler) updateRoute(ip net.IP, addr net.Addr, log logger.Logger) { + rkey := ipToTunRouteKey(ip) + if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded { + if actual.(net.Addr).String() != addr.String() { + h.routes.Store(rkey, addr) + log.Debugf("update route: %s -> %s (old %s)", + ip, addr, actual.(net.Addr)) + } + } else { + log.Debugf("new route: %s -> %s", ip, addr) + } +} diff --git a/listener/tun/listener.go b/listener/tun/listener.go index 3f9cbeb..d47343a 100644 --- a/listener/tun/listener.go +++ b/listener/tun/listener.go @@ -99,7 +99,6 @@ func (l *tunListener) listenLoop() { if err != nil { l.logger.Error(err) cancel() - continue } select {