diff --git a/test/end2end_test.go b/test/end2end_test.go index 7feea666..140673ef 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -3807,23 +3807,6 @@ func testCompressServerHasNoSupport(t *testing.T, e env) { if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) } - respParam := []*testpb.ResponseParameters{ - { - Size: 31415, - }, - } - payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415)) - if err != nil { - t.Fatal(err) - } - sreq := &testpb.StreamingOutputCallRequest{ - ResponseType: testpb.PayloadType_COMPRESSABLE, - ResponseParameters: respParam, - Payload: payload, - } - if err := stream.Send(sreq); err != nil { - t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) - } if _, err := stream.Recv(); err == nil || status.Code(err) != codes.Unimplemented { t.Fatalf("%v.Recv() = %v, want error code %s", stream, err, codes.Unimplemented) } @@ -4924,6 +4907,36 @@ func TestTapTimeout(t *testing.T) { t.Fatalf("ss.client.EmptyCall(context.Background(), _) = %v, %v; want nil, ", res, err) } } + +} + +func TestClientWriteFailsAfterServerClosesStream(t *testing.T) { + ss := &stubServer{ + fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error { + return status.Errorf(codes.Internal, "") + }, + } + sopts := []grpc.ServerOption{} + if err := ss.Start(sopts); err != nil { + t.Fatalf("Error starting endpoing server: %v", err) + } + defer ss.Stop() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + stream, err := ss.client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("Error while creating stream: %v", err) + } + for { + if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err == nil { + time.Sleep(5 * time.Millisecond) + } else if err == io.EOF { + break // Success. + } else { + t.Fatalf("stream.Send(_) = %v, want io.EOF", err) + } + } + } type windowSizeConfig struct { diff --git a/transport/http2_client.go b/transport/http2_client.go index 18fc41f0..b4495aef 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -649,6 +649,8 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e select { case <-s.ctx.Done(): return ContextErr(s.ctx.Err()) + case <-s.done: + return io.EOF case <-t.ctx.Done(): return ErrConnClosing default: