From 025674fec58a059fa5bdb94d3cd64c2b1c3242e4 Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Tue, 12 Apr 2016 21:45:35 -0700 Subject: [PATCH] transport: do not create a Stream on a canceled context Occasionally Invoke() would let a message slip through when the context is already canceled. --- call_test.go | 23 +++++++++++++++++++++++ transport/http2_client.go | 6 +++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/call_test.go b/call_test.go index feeeb7ef..7d01f457 100644 --- a/call_test.go +++ b/call_test.go @@ -54,6 +54,7 @@ var ( expectedResponse = "pong" weirdError = "format verbs: %v%s" sizeLargeErr = 1024 * 1024 + canceled = 0 ) type testCodec struct { @@ -100,6 +101,11 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { h.t.WriteStatus(s, codes.Internal, weirdError) return } + if v == "canceled" { + canceled++ + h.t.WriteStatus(s, codes.Internal, "") + return + } if v != expectedRequest { h.t.WriteStatus(s, codes.Internal, strings.Repeat("A", sizeLargeErr)) return @@ -244,3 +250,20 @@ func TestInvokeErrorSpecialChars(t *testing.T) { cc.Close() server.stop() } + +// TestInvokeCancel checks that an Invoke with a canceled context is not sent. +func TestInvokeCancel(t *testing.T) { + server, cc := setUp(t, 0, math.MaxUint32) + var reply string + req := "canceled" + for i := 0; i < 100; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + Invoke(ctx, "/foo/bar", &req, &reply, cc) + } + if canceled != 0 { + t.Fatalf("received %d of 100 canceled requests", canceled) + } + cc.Close() + server.stop() +} diff --git a/transport/http2_client.go b/transport/http2_client.go index 76101d7a..77c05443 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -236,9 +236,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea var timeout time.Duration if dl, ok := ctx.Deadline(); ok { timeout = dl.Sub(time.Now()) - if timeout <= 0 { - return nil, ContextErr(context.DeadlineExceeded) - } + } + if err := ctx.Err(); err != nil { + return nil, ContextErr(err) } pr := &peer.Peer{ Addr: t.conn.RemoteAddr(),