diff --git a/call.go b/call.go index 27cf6411..e4e7771f 100644 --- a/call.go +++ b/call.go @@ -51,6 +51,13 @@ import ( func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { // Try to acquire header metadata from the server if there is any. var err error + defer func() { + if err != nil { + if _, ok := err.(transport.ConnectionError); !ok { + t.CloseStream(stream, err) + } + } + }() c.headerMD, err = stream.Header() if err != nil { return err @@ -190,20 +197,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } return toRPCErr(err) } - // Receive the response err = recvResponse(cc.dopts, t, &c, stream, reply) if err != nil { if put != nil { put() put = nil } - if _, ok := err.(transport.ConnectionError); ok { + if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } continue } - t.CloseStream(stream, err) return toRPCErr(err) } if c.traceInfo.tr != nil { diff --git a/transport/http2_client.go b/transport/http2_client.go index 6c72431a..d561ab66 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -430,7 +430,9 @@ func (t *http2Client) CloseStream(s *Stream, err error) { delete(t.activeStreams, s.id) if t.state == draining && len(t.activeStreams) == 0 { // The transport is draining and s is the last live stream on t. - defer t.Close() + t.mu.Unlock() + t.Close() + return } t.mu.Unlock() if updateStreams { diff --git a/transport/transport.go b/transport/transport.go index 86c8fcd6..10b4a2e2 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -240,6 +240,8 @@ func (s *Stream) Header() (metadata.MD, error) { select { case <-s.ctx.Done(): return nil, ContextErr(s.ctx.Err()) + case <-s.goAway: + return nil, ErrStreamDrain case <-s.headerChan: return s.header.Copy(), nil } diff --git a/transport/transport_test.go b/transport/transport_test.go index 6f9cc50c..f4af68b7 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -368,8 +368,8 @@ func TestGracefulClose(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - if _, err := ct.NewStream(context.Background(), callHdr); err != ErrConnClosing { - t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrConnClosing) + if _, err := ct.NewStream(context.Background(), callHdr); err != ErrStreamDrain { + t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrStreamDrain) } }() }