Merge pull request #634 from heyitsanthony/cancel-nosend

transport: do not create a Stream on a canceled context
This commit is contained in:
Qi Zhao
2016-04-13 10:35:15 -07:00
2 changed files with 26 additions and 3 deletions

View File

@ -54,6 +54,7 @@ var (
expectedResponse = "pong" expectedResponse = "pong"
weirdError = "format verbs: %v%s" weirdError = "format verbs: %v%s"
sizeLargeErr = 1024 * 1024 sizeLargeErr = 1024 * 1024
canceled = 0
) )
type testCodec struct { type testCodec struct {
@ -100,6 +101,11 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
h.t.WriteStatus(s, codes.Internal, weirdError) h.t.WriteStatus(s, codes.Internal, weirdError)
return return
} }
if v == "canceled" {
canceled++
h.t.WriteStatus(s, codes.Internal, "")
return
}
if v != expectedRequest { if v != expectedRequest {
h.t.WriteStatus(s, codes.Internal, strings.Repeat("A", sizeLargeErr)) h.t.WriteStatus(s, codes.Internal, strings.Repeat("A", sizeLargeErr))
return return
@ -244,3 +250,20 @@ func TestInvokeErrorSpecialChars(t *testing.T) {
cc.Close() cc.Close()
server.stop() server.stop()
} }
// TestInvokeCancel checks that an Invoke with a canceled context is not sent.
func TestInvokeCancel(t *testing.T) {
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
req := "canceled"
for i := 0; i < 100; i++ {
ctx, cancel := context.WithCancel(context.Background())
cancel()
Invoke(ctx, "/foo/bar", &req, &reply, cc)
}
if canceled != 0 {
t.Fatalf("received %d of 100 canceled requests", canceled)
}
cc.Close()
server.stop()
}

View File

@ -236,9 +236,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
var timeout time.Duration var timeout time.Duration
if dl, ok := ctx.Deadline(); ok { if dl, ok := ctx.Deadline(); ok {
timeout = dl.Sub(time.Now()) timeout = dl.Sub(time.Now())
if timeout <= 0 { }
return nil, ContextErr(context.DeadlineExceeded) if err := ctx.Err(); err != nil {
} return nil, ContextErr(err)
} }
pr := &peer.Peer{ pr := &peer.Peer{
Addr: t.conn.RemoteAddr(), Addr: t.conn.RemoteAddr(),