From 3b245ec38190cdaef0b0bed5d248fd4523154b90 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Fri, 30 Dec 2022 19:34:48 +0800 Subject: [PATCH] grpc: cancel stream --- dialer/grpc/conn.go | 20 ++++---------------- dialer/grpc/dialer.go | 6 ++++-- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/dialer/grpc/conn.go b/dialer/grpc/conn.go index 5bdc7a3..d47b54c 100644 --- a/dialer/grpc/conn.go +++ b/dialer/grpc/conn.go @@ -1,8 +1,8 @@ package grpc import ( + "context" "errors" - "io" "net" "time" @@ -14,7 +14,7 @@ type conn struct { rb []byte localAddr net.Addr remoteAddr net.Addr - closed chan struct{} + cancelFunc context.CancelFunc } func (c *conn) Read(b []byte) (n int, err error) { @@ -22,9 +22,6 @@ func (c *conn) Read(b []byte) (n int, err error) { case <-c.c.Context().Done(): err = c.c.Context().Err() return - case <-c.closed: - err = io.ErrClosedPipe - return default: } @@ -46,9 +43,6 @@ func (c *conn) Write(b []byte) (n int, err error) { case <-c.c.Context().Done(): err = c.c.Context().Err() return - case <-c.closed: - err = io.ErrClosedPipe - return default: } @@ -62,14 +56,8 @@ func (c *conn) Write(b []byte) (n int, err error) { } func (c *conn) Close() error { - select { - case <-c.closed: - default: - close(c.closed) - return c.c.CloseSend() - } - - return nil + defer c.cancelFunc() + return c.c.CloseSend() } func (c *conn) LocalAddr() net.Addr { diff --git a/dialer/grpc/dialer.go b/dialer/grpc/dialer.go index 1714326..007fe3c 100644 --- a/dialer/grpc/dialer.go +++ b/dialer/grpc/dialer.go @@ -109,8 +109,10 @@ func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO d.clients[addr] = client } - cli, err := client.TunnelX(context.Background(), d.md.path) + ctx2, cancel := context.WithCancel(context.Background()) + cli, err := client.TunnelX(ctx2, d.md.path) if err != nil { + cancel() return nil, err } @@ -118,6 +120,6 @@ func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO c: cli, localAddr: &net.TCPAddr{}, remoteAddr: &net.TCPAddr{}, - closed: make(chan struct{}), + cancelFunc: cancel, }, nil }