From 2e7aa9a2b106cce5d21e35f6ea2496876ece6b6a Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Tue, 12 Jul 2016 19:24:33 -0700 Subject: [PATCH] Error out the send call for a client streaming rpc if the server has returned an error. --- test/end2end_test.go | 134 +++++++++++++++++++++++++------------- transport/http2_client.go | 6 ++ 2 files changed, 95 insertions(+), 45 deletions(-) diff --git a/test/end2end_test.go b/test/end2end_test.go index cc8bae5f..b4165f9a 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -90,7 +90,8 @@ var ( var raceMode bool // set by race_test.go in race mode type testServer struct { - security string // indicate the authentication protocol used by this server. + security string // indicate the authentication protocol used by this server. + streamingInputCallErr bool // whether to error out the StreamingInputCall handler prematurely. } func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { @@ -219,6 +220,9 @@ func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInput } p := in.GetPayload().GetBody() sum += len(p) + if s.streamingInputCallErr { + return grpc.Errorf(codes.NotFound, "not found") + } } } @@ -404,10 +408,10 @@ func (te *test) tearDown() { // modify it before calling its startServer and clientConn methods. func newTest(t *testing.T, e env) *test { te := &test{ - t: t, - e: e, - testServer: &testServer{security: e.security}, - maxStream: math.MaxUint32, + t: t, + e: e, + //testServer: &testServer{security: e.security}, + maxStream: math.MaxUint32, } te.ctx, te.cancel = context.WithCancel(context.Background()) return te @@ -415,7 +419,8 @@ func newTest(t *testing.T, e env) *test { // startServer starts a gRPC server listening. Callers should defer a // call to te.tearDown to clean up. -func (te *test) startServer() { +func (te *test) startServer(ts testpb.TestServiceServer) { + te.testServer = ts e := te.e te.t.Logf("Running test in %s environment...", e.name) sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)} @@ -545,7 +550,7 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { "grpc: Conn.resetTransport failed to create client transport: connection error", "grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix", ) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -580,7 +585,7 @@ func testFailFast(t *testing.T, e env) { "grpc: Conn.resetTransport failed to create client transport: connection error", "grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix", ) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -629,7 +634,7 @@ func testHealthCheckOnSuccess(t *testing.T, e env) { hs := health.NewHealthServer() hs.SetServingStatus("grpc.health.v1.Health", 1) te.healthServer = hs - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -655,7 +660,7 @@ func testHealthCheckOnFailure(t *testing.T, e env) { hs := health.NewHealthServer() hs.SetServingStatus("grpc.health.v1.HealthCheck", 1) te.healthServer = hs - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -679,7 +684,7 @@ func TestHealthCheckOff(t *testing.T) { func testHealthCheckOff(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() want := grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1.Health") if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !equalErrors(err, want) { @@ -698,7 +703,7 @@ func testHealthCheckServingStatus(t *testing.T, e env) { te := newTest(t, e) hs := health.NewHealthServer() te.healthServer = hs - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -741,7 +746,7 @@ func TestErrorChanNoIO(t *testing.T) { func testErrorChanNoIO(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -760,7 +765,7 @@ func TestEmptyUnaryWithUserAgent(t *testing.T) { func testEmptyUnaryWithUserAgent(t *testing.T, e env) { te := newTest(t, e) te.userAgent = testAppUA - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -786,7 +791,7 @@ func TestFailedEmptyUnary(t *testing.T) { func testFailedEmptyUnary(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -806,7 +811,7 @@ func TestLargeUnary(t *testing.T) { func testLargeUnary(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -843,7 +848,7 @@ func TestMetadataUnaryRPC(t *testing.T) { func testMetadataUnaryRPC(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -889,7 +894,7 @@ func TestMalformedHTTP2Metadata(t *testing.T) { func testMalformedHTTP2Metadata(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -951,7 +956,7 @@ func TestRetry(t *testing.T) { func testRetry(t *testing.T, e env) { te := newTest(t, e) te.declareLogNoise("transport: http2Client.notifyError got notified that the client transport was broken") - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -1000,7 +1005,7 @@ func TestRPCTimeout(t *testing.T) { // TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism. func testRPCTimeout(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -1037,7 +1042,7 @@ func TestCancel(t *testing.T) { func testCancel(t *testing.T, e env) { te := newTest(t, e) te.declareLogNoise("grpc: the client connection is closing; please retry") - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -1075,7 +1080,7 @@ func testCancelNoIO(t *testing.T, e env) { te := newTest(t, e) te.declareLogNoise("http2Client.notifyError got notified that the client transport was broken") te.maxStream = 1 // Only allows 1 live stream per server transport. - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -1145,8 +1150,7 @@ func TestNoService(t *testing.T) { func testNoService(t *testing.T, e env) { te := newTest(t, e) - te.testServer = nil // register nothing - te.startServer() + te.startServer(nil) defer te.tearDown() cc := te.clientConn() @@ -1170,7 +1174,7 @@ func TestPingPong(t *testing.T) { func testPingPong(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1230,7 +1234,7 @@ func TestMetadataStreamingRPC(t *testing.T) { func testMetadataStreamingRPC(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1301,7 +1305,7 @@ func TestServerStreaming(t *testing.T) { func testServerStreaming(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1356,7 +1360,7 @@ func TestFailedServerStreaming(t *testing.T) { func testFailedServerStreaming(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1412,8 +1416,7 @@ func TestServerStreaming_Concurrent(t *testing.T) { func testServerStreaming_Concurrent(t *testing.T, e env) { te := newTest(t, e) - te.testServer = concurrentSendServer{} - te.startServer() + te.startServer(concurrentSendServer{}) defer te.tearDown() cc := te.clientConn() @@ -1471,7 +1474,7 @@ func TestClientStreaming(t *testing.T) { func testClientStreaming(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1504,6 +1507,47 @@ func testClientStreaming(t *testing.T, e env) { } } +func TestClientStreamingError(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testClientStreamingError(t, e) + } +} + +func testClientStreamingError(t *testing.T, e env) { + te := newTest(t, e) + te.startServer(&testServer{security: e.security, streamingInputCallErr: true}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + stream, err := tc.StreamingInputCall(te.ctx) + if err != nil { + t.Fatalf("%v.StreamingInputCall(_) = _, %v, want ", tc, err) + } + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 1) + if err != nil { + t.Fatal(err) + } + + req := &testpb.StreamingInputCallRequest{ + Payload: payload, + } + // The 1st request should go through. + if err := stream.Send(req); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, req, err) + } + for { + if err := stream.Send(req); err == nil { + continue + } else { + if grpc.Code(err) != codes.NotFound { + t.Fatalf("%v.Send(_) = %v, want error %d", stream, err, codes.NotFound) + } + break + } + } +} + func TestExceedMaxStreamsLimit(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { @@ -1519,7 +1563,7 @@ func testExceedMaxStreamsLimit(t *testing.T, e env) { "grpc: the client connection is closing", ) te.maxStream = 1 // Only allows 1 live stream per server transport. - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -1560,7 +1604,7 @@ func testStreamsQuotaRecovery(t *testing.T, e env) { "grpc: the client connection is closing", ) te.maxStream = 1 // Allows 1 live stream. - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -1611,7 +1655,7 @@ func testCompressServerHasNoSupport(t *testing.T, e env) { te := newTest(t, e) te.serverCompression = false te.clientCompression = true - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1667,7 +1711,7 @@ func testCompressOK(t *testing.T, e env) { te := newTest(t, e) te.serverCompression = true te.clientCompression = true - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1730,7 +1774,7 @@ func errInjector(ctx context.Context, req interface{}, info *grpc.UnaryServerInf func testUnaryServerInterceptor(t *testing.T, e env) { te := newTest(t, e) te.unaryInt = errInjector - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1761,7 +1805,7 @@ func fullDuplexOnly(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServ func testStreamServerInterceptor(t *testing.T, e env) { te := newTest(t, e) te.streamInt = fullDuplexOnly - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1825,12 +1869,12 @@ func TestClientRequestBodyError_UnexpectedEOF(t *testing.T) { func testClientRequestBodyError_UnexpectedEOF(t *testing.T, e env) { te := newTest(t, e) - te.testServer = &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + ts := &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { errUnexpectedCall := errors.New("unexpected call func server method") t.Error(errUnexpectedCall) return nil, errUnexpectedCall }} - te.startServer() + te.startServer(ts) defer te.tearDown() te.withServerTester(func(st *serverTester) { st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall") @@ -1850,12 +1894,12 @@ func TestClientRequestBodyError_CloseAfterLength(t *testing.T) { func testClientRequestBodyError_CloseAfterLength(t *testing.T, e env) { te := newTest(t, e) te.declareLogNoise("Server.processUnaryRPC failed to write status") - te.testServer = &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + ts := &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { errUnexpectedCall := errors.New("unexpected call func server method") t.Error(errUnexpectedCall) return nil, errUnexpectedCall }} - te.startServer() + te.startServer(ts) defer te.tearDown() te.withServerTester(func(st *serverTester) { st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall") @@ -1875,11 +1919,11 @@ func TestClientRequestBodyError_Cancel(t *testing.T) { func testClientRequestBodyError_Cancel(t *testing.T, e env) { te := newTest(t, e) gotCall := make(chan bool, 1) - te.testServer = &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + ts := &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { gotCall <- true return new(testpb.SimpleResponse), nil }} - te.startServer() + te.startServer(ts) defer te.tearDown() te.withServerTester(func(st *serverTester) { st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall") @@ -1912,12 +1956,12 @@ func TestClientRequestBodyError_Cancel_StreamingInput(t *testing.T) { func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) { te := newTest(t, e) recvErr := make(chan error, 1) - te.testServer = &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error { + ts := &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error { _, err := stream.Recv() recvErr <- err return nil }} - te.startServer() + te.startServer(ts) defer te.tearDown() te.withServerTester(func(st *serverTester) { st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall") diff --git a/transport/http2_client.go b/transport/http2_client.go index f66435fd..d7f19b2e 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -497,6 +497,12 @@ func (t *http2Client) GracefulClose() error { // TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later // if it improves the performance. func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { + s.mu.Lock() + if s.state == streamDone { + s.mu.Unlock() + return StreamErrorf(s.statusCode, "%s", s.statusDesc) + } + s.mu.Unlock() r := bytes.NewBuffer(data) for { var p []byte