diff --git a/stream.go b/stream.go index 7a3bef51..a0373600 100644 --- a/stream.go +++ b/stream.go @@ -251,7 +251,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { if err != nil { cs.finish(err) } - if err == nil || err == io.EOF { + if err == nil || err == io.EOF || err == transport.ErrEarlyDone { return } if _, ok := err.(transport.ConnectionError); !ok { @@ -328,6 +328,11 @@ func (cs *clientStream) CloseSend() (err error) { if err == nil || err == io.EOF { return } + if err == transport.ErrEarlyDone { + // If the RPC is done prematurely, Stream.RecvMsg(...) needs to be + // called to get the final status and clear the footprint. + return nil + } if _, ok := err.(transport.ConnectionError); !ok { cs.closeTransportStream(err) } diff --git a/test/end2end_test.go b/test/end2end_test.go index c2404f24..c109c885 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -1540,7 +1540,7 @@ func testClientStreamingError(t *testing.T, e env) { continue } if _, err := stream.CloseAndRecv(); grpc.Code(err) != codes.NotFound { - t.Fatalf("%v.Send(_) = %v, want error %d", stream, err, codes.NotFound) + t.Fatalf("%v.CloseAndRecv() = %v, want error %d", stream, err, codes.NotFound) } break } diff --git a/transport/http2_client.go b/transport/http2_client.go index 3f1ef2a5..2715e2d0 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -202,6 +202,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // TODO(zhaoq): Handle uint32 overflow of Stream.id. s := &Stream{ id: t.nextID, + earlyDone: make(chan struct{}), method: callHdr.Method, sendCompress: callHdr.SendCompress, buf: newRecvBuffer(), @@ -278,7 +279,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea checkStreamsQuota := t.streamsQuota != nil t.mu.Unlock() if checkStreamsQuota { - sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire()) + sq, err := wait(ctx, nil, t.shutdownChan, t.streamsQuota.acquire()) if err != nil { return nil, err } @@ -287,7 +288,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea t.streamsQuota.add(sq - 1) } } - if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(ctx, nil, t.shutdownChan, t.writableChan); err != nil { // Return the quota back now because there is no stream returned to the caller. if _, ok := err.(StreamError); ok && checkStreamsQuota { t.streamsQuota.add(1) @@ -497,13 +498,6 @@ func (t *http2Client) GracefulClose() error { // TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later // if it improves the performance. func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { - s.mu.Lock() - // The stream has been done. Return the status directly. - if s.state == streamDone { - s.mu.Unlock() - return StreamErrorf(s.statusCode, "%v", s.statusDesc) - } - s.mu.Unlock() r := bytes.NewBuffer(data) for { var p []byte @@ -511,13 +505,13 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { size := http2MaxFrameLen s.sendQuotaPool.add(0) // Wait until the stream has some quota to send the data. - sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire()) + sq, err := wait(s.ctx, s.earlyDone, t.shutdownChan, s.sendQuotaPool.acquire()) if err != nil { return err } t.sendQuotaPool.add(0) // Wait until the transport has some quota to send the data. - tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, s.earlyDone, t.shutdownChan, t.sendQuotaPool.acquire()) if err != nil { if _, ok := err.(StreamError); ok { t.sendQuotaPool.cancel() @@ -551,7 +545,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { // Indicate there is a writer who is about to write a data frame. t.framer.adjustNumWriters(1) // Got some quota. Try to acquire writing privilege on the transport. - if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, s.earlyDone, t.shutdownChan, t.writableChan); err != nil { if _, ok := err.(StreamError); ok { // Return the connection quota back. t.sendQuotaPool.add(len(p)) @@ -781,18 +775,14 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } s.statusCode = state.statusCode s.statusDesc = state.statusDesc - var cancel bool if s.state != streamWriteDone { - // s will be canceled. This is required to interrupt any pending - // blocking Write calls when the final RPC status has been arrived. - cancel = true + // This is required to interrupt any pending blocking Write calls + // when the final RPC status has been arrived. + close(s.earlyDone) } s.state = streamDone s.mu.Unlock() s.write(recvMsg{err: io.EOF}) - if cancel { - s.cancel() - } } func handleMalformedHTTP2(s *Stream, err error) { diff --git a/transport/http2_server.go b/transport/http2_server.go index cee15429..9e35fdd8 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -364,11 +364,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) { // Received the end of stream from the client. s.mu.Lock() if s.state != streamDone { - if s.state == streamWriteDone { - s.state = streamDone - } else { - s.state = streamReadDone - } + s.state = streamReadDone } s.mu.Unlock() s.write(recvMsg{err: io.EOF}) @@ -455,7 +451,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { } s.headerOk = true s.mu.Unlock() - if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, t.shutdownChan, t.writableChan); err != nil { return err } t.hBuf.Reset() @@ -495,7 +491,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s headersSent = true } s.mu.Unlock() - if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, t.shutdownChan, t.writableChan); err != nil { return err } t.hBuf.Reset() @@ -544,7 +540,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } s.mu.Unlock() if writeHeaderFrame { - if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, t.shutdownChan, t.writableChan); err != nil { return err } t.hBuf.Reset() @@ -572,13 +568,13 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { size := http2MaxFrameLen s.sendQuotaPool.add(0) // Wait until the stream has some quota to send the data. - sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire()) + sq, err := wait(s.ctx, nil, t.shutdownChan, s.sendQuotaPool.acquire()) if err != nil { return err } t.sendQuotaPool.add(0) // Wait until the transport has some quota to send the data. - tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, nil, t.shutdownChan, t.sendQuotaPool.acquire()) if err != nil { if _, ok := err.(StreamError); ok { t.sendQuotaPool.cancel() @@ -604,7 +600,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { t.framer.adjustNumWriters(1) // Got some quota. Try to acquire writing privilege on the // transport. - if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, t.shutdownChan, t.writableChan); err != nil { if _, ok := err.(StreamError); ok { // Return the connection quota back. t.sendQuotaPool.add(ps) diff --git a/transport/transport.go b/transport/transport.go index 1c888083..3340e9ad 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -140,18 +140,6 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) { } select { case <-r.ctx.Done(): - // ctx might be canceled by gRPC internals to unblocking pending writing operations - // when the client receives the final status prematurely (for client and bi-directional - // streaming RPCs). Used to return the real status to the users instead of the - // cancellation. - select { - case i := <-r.recv.get(): - m := i.(*recvMsg) - if m.err != nil { - return 0, m.err - } - default: - } return 0, ContextErr(r.ctx.Err()) case i := <-r.recv.get(): r.recv.load() @@ -169,6 +157,7 @@ type streamState uint8 const ( streamActive streamState = iota streamWriteDone // EndStream sent + streamReadDone // EndStream received streamDone // the entire stream is finished. ) @@ -178,8 +167,9 @@ type Stream struct { // nil for client side Stream. st ServerTransport // ctx is the associated context of the stream. - ctx context.Context - cancel context.CancelFunc + ctx context.Context + cancel context.CancelFunc + earlyDone chan struct{} // method records the associated RPC method of the stream. method string recvCompress string @@ -469,6 +459,8 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError { } } +var ErrEarlyDone = StreamErrorf(codes.Internal, "rpc is done prematurely") + // ConnectionErrorf creates an ConnectionError with the specified error description. func ConnectionErrorf(format string, a ...interface{}) ConnectionError { return ConnectionError{ @@ -512,12 +504,15 @@ func ContextErr(err error) StreamError { // wait blocks until it can receive from ctx.Done, closing, or proceed. // If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. +// If it receives from earlyDone, it returns 0, errEarlyDone. // If it receives from closing, it returns 0, ErrConnClosing. // If it receives from proceed, it returns the received integer, nil. -func wait(ctx context.Context, closing <-chan struct{}, proceed <-chan int) (int, error) { +func wait(ctx context.Context, earlyDone, closing <-chan struct{}, proceed <-chan int) (int, error) { select { case <-ctx.Done(): return 0, ContextErr(ctx.Err()) + case <-earlyDone: + return 0, ErrEarlyDone case <-closing: return 0, ErrConnClosing case i := <-proceed: