diff --git a/call.go b/call.go index 27cf6411..e4e7771f 100644 --- a/call.go +++ b/call.go @@ -51,6 +51,13 @@ import ( func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { // Try to acquire header metadata from the server if there is any. var err error + defer func() { + if err != nil { + if _, ok := err.(transport.ConnectionError); !ok { + t.CloseStream(stream, err) + } + } + }() c.headerMD, err = stream.Header() if err != nil { return err @@ -190,20 +197,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } return toRPCErr(err) } - // Receive the response err = recvResponse(cc.dopts, t, &c, stream, reply) if err != nil { if put != nil { put() put = nil } - if _, ok := err.(transport.ConnectionError); ok { + if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } continue } - t.CloseStream(stream, err) return toRPCErr(err) } if c.traceInfo.tr != nil { diff --git a/stream.go b/stream.go index fb7e50f9..008ad1e2 100644 --- a/stream.go +++ b/stream.go @@ -165,7 +165,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth put() put = nil } - if _, ok := err.(transport.ConnectionError); ok { + if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { cs.finish(err) return nil, toRPCErr(err) diff --git a/test/end2end_test.go b/test/end2end_test.go index 11e215d7..0bbc3d90 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -566,12 +566,6 @@ func TestServerGoAway(t *testing.T) { func testServerGoAway(t *testing.T, e env) { te := newTest(t, e) te.userAgent = testAppUA - te.declareLogNoise( - "transport: http2Client.notifyError got notified that the client transport was broken EOF", - "grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing", - "grpc: Conn.resetTransport failed to create client transport: connection error", - "grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix", - ) te.startServer(&testServer{security: e.security}) defer te.tearDown() diff --git a/transport/http2_client.go b/transport/http2_client.go index 62a6fa28..6a34c7e5 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -284,6 +284,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea t.mu.Unlock() return nil, ErrConnClosing } + if t.state == draining { + t.mu.Unlock() + return nil, ErrStreamDrain + } if t.state != reachable { t.mu.Unlock() return nil, ErrConnClosing @@ -308,6 +312,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return nil, err } t.mu.Lock() + if t.state == draining { + t.mu.Unlock() + if checkStreamsQuota { + t.streamsQuota.add(1) + } + // Need to make t writable again so that the rpc in flight can still proceed. + t.writableChan <- 0 + return nil, ErrStreamDrain + } if t.state != reachable { t.mu.Unlock() return nil, ErrConnClosing @@ -419,7 +432,9 @@ func (t *http2Client) CloseStream(s *Stream, err error) { delete(t.activeStreams, s.id) if t.state == draining && len(t.activeStreams) == 0 { // The transport is draining and s is the last live stream on t. - defer t.Close() + t.mu.Unlock() + t.Close() + return } t.mu.Unlock() if updateStreams { @@ -481,7 +496,7 @@ func (t *http2Client) Close() (err error) { func (t *http2Client) GracefulClose() error { t.mu.Lock() - if t.state == closing { + if t.state == closing || t.state == unreachable { t.mu.Unlock() return nil } diff --git a/transport/transport.go b/transport/transport.go index 86c8fcd6..10b4a2e2 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -240,6 +240,8 @@ func (s *Stream) Header() (metadata.MD, error) { select { case <-s.ctx.Done(): return nil, ContextErr(s.ctx.Err()) + case <-s.goAway: + return nil, ErrStreamDrain case <-s.headerChan: return s.header.Copy(), nil } diff --git a/transport/transport_test.go b/transport/transport_test.go index 6f9cc50c..f4af68b7 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -368,8 +368,8 @@ func TestGracefulClose(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - if _, err := ct.NewStream(context.Background(), callHdr); err != ErrConnClosing { - t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrConnClosing) + if _, err := ct.NewStream(context.Background(), callHdr); err != ErrStreamDrain { + t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrStreamDrain) } }() }