diff --git a/handler/tun/client.go b/handler/tun/client.go index 1d05715..c5c21c6 100644 --- a/handler/tun/client.go +++ b/handler/tun/client.go @@ -30,18 +30,28 @@ func (h *tunHandler) handleClient(ctx context.Context, conn net.Conn, addr net.A return err } - cc, err := h.router.Dial(ctx, addr.Network(), addr.String()) - if err != nil { - return err + for { + err := func() error { + cc, err := h.router.Dial(ctx, addr.Network(), addr.String()) + if err != nil { + return err + } + defer cc.Close() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go h.keepAlive(ctx, cc, ip) + + return h.transportClient(conn, cc, config, log) + }() + if err == ErrTun { + return err + } + + log.Error(err) + time.Sleep(time.Second) } - defer cc.Close() - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - go h.keepAlive(ctx, cc, ip) - - return h.transportClient(conn, cc, config, log) } func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) { @@ -75,7 +85,7 @@ func (h *tunHandler) keepAlive(ctx context.Context, conn net.Conn, ip net.IP) { } } -func (h *tunHandler) transportClient(tun net.Conn, conn net.Conn, config *tun_util.Config, log logger.Logger) error { +func (h *tunHandler) transportClient(tun io.ReadWriter, conn net.Conn, config *tun_util.Config, log logger.Logger) error { errc := make(chan error, 1) go func() { @@ -86,7 +96,7 @@ func (h *tunHandler) transportClient(tun net.Conn, conn net.Conn, config *tun_ut n, err := tun.Read(*b) if err != nil { - return err + return ErrTun } if waterutil.IsIPv4((*b)[:n]) { @@ -173,8 +183,10 @@ func (h *tunHandler) transportClient(tun net.Conn, conn net.Conn, config *tun_ut return nil } - _, err = tun.Write((*b)[:n]) - return err + if _, err = tun.Write((*b)[:n]); err != nil { + return ErrTun + } + return nil }() if err != nil { diff --git a/handler/tun/handler.go b/handler/tun/handler.go index 9ff2110..d4a42d3 100644 --- a/handler/tun/handler.go +++ b/handler/tun/handler.go @@ -16,6 +16,10 @@ import ( "github.com/songgao/water/waterutil" ) +var ( + ErrTun = errors.New("tun device error") +) + func init() { registry.HandlerRegistry().Register("tun", NewHandler) } diff --git a/handler/tun/server.go b/handler/tun/server.go index 5f5dd07..a7bca79 100644 --- a/handler/tun/server.go +++ b/handler/tun/server.go @@ -6,6 +6,7 @@ import ( "io" "net" "net/netip" + "time" "github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/logger" @@ -16,16 +17,26 @@ import ( ) func (h *tunHandler) handleServer(ctx context.Context, conn net.Conn, config *tun_util.Config, log logger.Logger) error { - pc, err := net.ListenPacket(conn.LocalAddr().Network(), conn.LocalAddr().String()) - if err != nil { - return err - } - defer pc.Close() + for { + err := func() error { + pc, err := net.ListenPacket(conn.LocalAddr().Network(), conn.LocalAddr().String()) + if err != nil { + return err + } + defer pc.Close() - return h.transportServer(conn, pc, config, log) + return h.transportServer(conn, pc, config, log) + }() + if err == ErrTun { + return err + } + + log.Error(err) + time.Sleep(time.Second) + } } -func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config *tun_util.Config, log logger.Logger) error { +func (h *tunHandler) transportServer(tun io.ReadWriter, conn net.PacketConn, config *tun_util.Config, log logger.Logger) error { tunIP, _, err := net.ParseCIDR(config.Net) if err != nil { return err @@ -41,7 +52,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * n, err := tun.Read(*b) if err != nil { - return err + return ErrTun } var src, dst net.IP @@ -181,7 +192,7 @@ func (h *tunHandler) transportServer(tun net.Conn, conn net.PacketConn, config * } if _, err := tun.Write((*b)[:n]); err != nil { - return err + return ErrTun } return nil }()