
`transport/Stream.RecvCompress` returns what the header contains, if present, or empty string if a context error occurs. However, it "prefers" the header data even if there is a context error, to prevent a related race. What happens here is: 1. RPC starts. 2. Client cancels RPC. 3. `RecvCompress` tells `ClientStream.Recv` that compression used is "" because of the context error. `as.decomp` is left nil, because there is no compressor to look up in the registry. 4. Server's header and first message hit client. 5. Client sees the header and message and allows grpc's stream to see them. (We only provide context errors if we need to block.) 6. Client performs a successful `Read` on the stream, receiving the gzipped payload, then checks `as.decomp`. 7. We have no decompressor but the payload has a bit set indicating the message is compressed, so this is an error. However, when forming the error string, `RecvCompress` now returns "gzip" because it doesn't need to block to get this from the now-received header. This leads to the confusing message about how "gzip" is not installed even though it is. This change makes `waitOnHeader` close the stream when context cancellation happens. Then `RecvCompress` uses whatever value is present in the stream at that time, which can no longer change because the stream is closed. Also, this will be in sync with the messages on the stream - if there are any messages present, the headers must have been processed first, and `RecvCompress` will contain the proper value.
159 lines
4.5 KiB
Go
159 lines
4.5 KiB
Go
/*
|
|
*
|
|
* Copyright 2019 gRPC authors.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*
|
|
*/
|
|
|
|
package test
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/encoding/gzip"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/status"
|
|
testpb "google.golang.org/grpc/test/grpc_testing"
|
|
)
|
|
|
|
func (s) TestContextCanceled(t *testing.T) {
|
|
ss := &stubServer{
|
|
fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
|
|
stream.SetTrailer(metadata.New(map[string]string{"a": "b"}))
|
|
return status.Error(codes.PermissionDenied, "perm denied")
|
|
},
|
|
}
|
|
if err := ss.Start(nil); err != nil {
|
|
t.Fatalf("Error starting endpoint server: %v", err)
|
|
}
|
|
defer ss.Stop()
|
|
|
|
// Runs 10 rounds of tests with the given delay and returns counts of status codes.
|
|
// Fails in case of trailer/status code inconsistency.
|
|
const cntRetry uint = 10
|
|
runTest := func(delay time.Duration) (cntCanceled, cntPermDenied uint) {
|
|
for i := uint(0); i < cntRetry; i++ {
|
|
ctx, cancel := context.WithTimeout(context.Background(), delay)
|
|
defer cancel()
|
|
|
|
str, err := ss.client.FullDuplexCall(ctx)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
_, err = str.Recv()
|
|
if err == nil {
|
|
t.Fatalf("non-nil error expected from Recv()")
|
|
}
|
|
|
|
_, trlOk := str.Trailer()["a"]
|
|
switch status.Code(err) {
|
|
case codes.PermissionDenied:
|
|
if !trlOk {
|
|
t.Fatalf(`status err: %v; wanted key "a" in trailer but didn't get it`, err)
|
|
}
|
|
cntPermDenied++
|
|
case codes.DeadlineExceeded:
|
|
if trlOk {
|
|
t.Fatalf(`status err: %v; didn't want key "a" in trailer but got it`, err)
|
|
}
|
|
cntCanceled++
|
|
default:
|
|
t.Fatalf(`unexpected status err: %v`, err)
|
|
}
|
|
}
|
|
return cntCanceled, cntPermDenied
|
|
}
|
|
|
|
// Tries to find the delay that causes canceled/perm denied race.
|
|
canceledOk, permDeniedOk := false, false
|
|
for lower, upper := time.Duration(0), 2*time.Millisecond; lower <= upper; {
|
|
delay := lower + (upper-lower)/2
|
|
cntCanceled, cntPermDenied := runTest(delay)
|
|
if cntPermDenied > 0 && cntCanceled > 0 {
|
|
// Delay that causes the race is found.
|
|
return
|
|
}
|
|
|
|
// Set OK flags.
|
|
if cntCanceled > 0 {
|
|
canceledOk = true
|
|
}
|
|
if cntPermDenied > 0 {
|
|
permDeniedOk = true
|
|
}
|
|
|
|
if cntPermDenied == 0 {
|
|
// No perm denied, increase the delay.
|
|
lower += (upper-lower)/10 + 1
|
|
} else {
|
|
// All perm denied, decrease the delay.
|
|
upper -= (upper-lower)/10 + 1
|
|
}
|
|
}
|
|
|
|
if !canceledOk || !permDeniedOk {
|
|
t.Fatalf(`couldn't find the delay that causes canceled/perm denied race.`)
|
|
}
|
|
}
|
|
|
|
// To make sure that canceling a stream with compression enabled won't result in
|
|
// internal error, compressed flag set with identity or empty encoding.
|
|
//
|
|
// The root cause is a select race on stream headerChan and ctx. Stream gets
|
|
// whether compression is enabled and the compression type from two separate
|
|
// functions, both include select with context. If the `case non-ctx:` wins the
|
|
// first one, but `case ctx.Done()` wins the second one, the compression info
|
|
// will be inconsistent, and it causes internal error.
|
|
func (s) TestCancelWhileRecvingWithCompression(t *testing.T) {
|
|
ss := &stubServer{
|
|
fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
|
|
for {
|
|
if err := stream.Send(&testpb.StreamingOutputCallResponse{
|
|
Payload: nil,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
},
|
|
}
|
|
if err := ss.Start(nil); err != nil {
|
|
t.Fatalf("Error starting endpoint server: %v", err)
|
|
}
|
|
defer ss.Stop()
|
|
|
|
for i := 0; i < 10; i++ {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
s, err := ss.client.FullDuplexCall(ctx, grpc.UseCompressor(gzip.Name))
|
|
if err != nil {
|
|
t.Fatalf("failed to start bidi streaming RPC: %v", err)
|
|
}
|
|
// Cancel the stream while receiving to trigger the internal error.
|
|
time.AfterFunc(time.Millisecond, cancel)
|
|
for {
|
|
_, err := s.Recv()
|
|
if err != nil {
|
|
if status.Code(err) != codes.Canceled {
|
|
t.Fatalf("recv failed with %v, want Canceled", err)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|