transport: fix race between header and RPC cancellation (#2947)
@@ -184,6 +184,19 @@ func (r *recvBufferReader) readClient(p []byte) (n int, err error) {
 	// r.readAdditional acts on that message and returns the necessary error.
 	select {
 	case <-r.ctxDone:
+		// Note that this adds the ctx error to the end of recv buffer, and
+		// reads from the head. This will delay the error until recv buffer is
+		// empty, thus will delay ctx cancellation in Recv().
+		//
+		// It's done this way to fix a race between ctx cancel and trailer. The
+		// race was, stream.Recv() may return ctx error if ctxDone wins the
+		// race, but stream.Trailer() may return a non-nil md because the stream
+		// was not marked as done when trailer is received. This closeStream
+		// call will mark stream as done, thus fix the race.
+		//
+		// TODO: delaying ctx error seems like an unnecessary side effect. What
+		// we really want is to mark the stream as done, and return ctx error
+		// faster.
 		r.closeStream(ContextErr(r.ctx.Err()))
 		m := <-r.recv.get()
 		return r.readAdditional(m, p)
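
The comment added above describes a deliberate ordering: mark the stream done, queue the ctx error behind any buffered messages, and only then let Recv() surface the cancellation. Below is a minimal, self-contained sketch of that pattern; the reader and recvMsg types and the closeStream helper are illustrative stand-ins, not grpc-go's actual recvBufferReader.

package main

import (
	"context"
	"fmt"
)

type recvMsg struct {
	data []byte
	err  error
}

type reader struct {
	ctx    context.Context
	recv   chan recvMsg // buffered: messages queued by the transport
	closed bool
}

// closeStream stands in for the closeStream call in the diff: it marks the
// stream done, then appends the error to the *end* of the recv buffer so
// buffered data drains before the error surfaces.
func (r *reader) closeStream(err error) {
	if !r.closed {
		r.closed = true
		r.recv <- recvMsg{err: err}
	}
}

func (r *reader) read() ([]byte, error) {
	select {
	case <-r.ctx.Done():
		r.closeStream(r.ctx.Err())
		m := <-r.recv // still reads from the head: data first, error last
		return m.data, m.err
	case m := <-r.recv:
		return m.data, m.err
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	r := &reader{ctx: ctx, recv: make(chan recvMsg, 4)}
	r.recv <- recvMsg{data: []byte("buffered before cancel")}
	cancel()
	for {
		data, err := r.read()
		if err != nil {
			fmt.Println("err:", err) // context canceled, after the data
			return
		}
		fmt.Printf("got %q\n", data)
	}
}

Whichever select case wins after cancellation, the buffered message is delivered before the ctx error, which is exactly the delayed-cancellation behavior the TODO in the diff calls out.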
@@ -298,6 +311,14 @@ func (s *Stream) waitOnHeader() error {
 	}
 	select {
 	case <-s.ctx.Done():
+		// We prefer success over failure when reading messages because we delay
+		// context error in stream.Read(). To keep behavior consistent, we also
+		// prefer success here.
+		select {
+		case <-s.headerChan:
+			return nil
+		default:
+		}
 		return ContextErr(s.ctx.Err())
 	case <-s.headerChan:
 		return nil
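
The nested select above is the "prefer success" pattern: when the context and the header channel are both ready, Go's select picks pseudo-randomly, so the code re-checks the success channel non-blockingly before reporting the ctx error. A minimal sketch, assuming a plain headerChan that is closed once headers arrive (the standalone function signature is illustrative, not grpc-go's):

package main

import (
	"context"
	"fmt"
)

func waitOnHeader(ctx context.Context, headerChan <-chan struct{}) error {
	select {
	case <-ctx.Done():
		// Both cases may have been ready; re-check headerChan so a header
		// that already arrived is never reported as a ctx error.
		select {
		case <-headerChan:
			return nil
		default:
		}
		return ctx.Err()
	case <-headerChan:
		return nil
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	headerChan := make(chan struct{})
	close(headerChan) // header already received...
	cancel()          // ...and the RPC is canceled at the same time

	// Without the inner select this could randomly return ctx.Err().
	fmt.Println("err:", waitOnHeader(ctx, headerChan)) // err: <nil>
}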
@@ -23,7 +23,9 @@ import (
 	"testing"
 	"time"
 
+	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/encoding/gzip"
 	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/status"
 	testpb "google.golang.org/grpc/test/grpc_testing"
@@ -109,3 +111,49 @@ func (s) TestContextCanceled(t *testing.T) {
 		t.Fatalf(`couldn't find the delay that causes canceled/perm denied race.`)
 	}
 }
+
+// To make sure that canceling a stream with compression enabled won't result
+// in an internal error (compressed flag set with an identity or empty encoding).
+//
+// The root cause is a select race on the stream's headerChan and ctx. The
+// stream learns whether compression is enabled and the compression type from
+// two separate functions, both of which select on the context. If the
+// `case non-ctx:` wins the first select but `case ctx.Done():` wins the
+// second, the compression info is inconsistent, causing an internal error.
+func (s) TestCancelWhileRecvingWithCompression(t *testing.T) {
+	ss := &stubServer{
+		fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
+			for {
+				if err := stream.Send(&testpb.StreamingOutputCallResponse{
+					Payload: nil,
+				}); err != nil {
+					return err
+				}
+			}
+		},
+	}
+	if err := ss.Start(nil); err != nil {
+		t.Fatalf("Error starting endpoint server: %v", err)
+	}
+	defer ss.Stop()
+
+	for i := 0; i < 10; i++ {
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+		defer cancel()
+		s, err := ss.client.FullDuplexCall(ctx, grpc.UseCompressor(gzip.Name))
+		if err != nil {
+			t.Fatalf("failed to start bidi streaming RPC: %v", err)
+		}
+		// Cancel the stream while receiving to trigger the internal error.
+		time.AfterFunc(time.Millisecond*10, cancel)
+		for {
+			_, err := s.Recv()
+			if err != nil {
+				if status.Code(err) != codes.Canceled {
+					t.Fatalf("recv failed with %v, want Canceled", err)
+				}
+				break
+			}
+		}
+	}
+}
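
The test's comment pins the root cause on two independent selects over the same headerChan/ctx pair resolving differently. A minimal sketch of that inconsistency follows; isCompressed and recvCompress are hypothetical stand-ins for the two separate accessor functions, not grpc-go's API.

package main

import (
	"context"
	"fmt"
)

// Each accessor independently races headerChan against ctx, mirroring the
// two separate functions mentioned in the test comment.
func isCompressed(ctx context.Context, headerChan <-chan struct{}) bool {
	select {
	case <-headerChan:
		return true // headers said: payload is compressed
	case <-ctx.Done():
		return false
	}
}

func recvCompress(ctx context.Context, headerChan <-chan struct{}) string {
	select {
	case <-headerChan:
		return "gzip"
	case <-ctx.Done():
		return "" // canceled before headers: no compressor name
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	headerChan := make(chan struct{})
	close(headerChan)
	cancel()

	// With both channels ready, each select flips its own coin. Across runs
	// this can print compressed=true name="": the inconsistent pair that
	// produced the internal error before the fix.
	for i := 0; i < 10; i++ {
		fmt.Printf("compressed=%v name=%q\n",
			isCompressed(ctx, headerChan), recvCompress(ctx, headerChan))
	}
}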