diff --git a/interop/test_utils.go b/interop/test_utils.go index c9813ba5..075e4813 100644 --- a/interop/test_utils.go +++ b/interop/test_utils.go @@ -456,9 +456,9 @@ func DoCancelAfterFirstResponse(tc testpb.TestServiceClient) { // DoStatusCodeAndMessage checks that the status code is propagated back to the client. func DoStatusCodeAndMessage(tc testpb.TestServiceClient) { + var code int32 = 2 + msg := "test status message" // Test UnaryCall. - var code int32 = 2 - msg := "test status message" respStatus := &testpb.EchoStatus{ Code: &code, Message: &msg, @@ -467,7 +467,7 @@ func DoStatusCodeAndMessage(tc testpb.TestServiceClient) { ResponseStatus: respStatus, } _, err := tc.UnaryCall(context.Background(), req) - if grpc.Code(err) != 2 { + if grpc.Code(err) != codes.Code(code) { grpclog.Fatalf("/TestService/UnaryCall RPC compleled with error code %d, want %d", grpc.Code(err), code) } if err.Error() != msg { @@ -485,7 +485,7 @@ func DoStatusCodeAndMessage(tc testpb.TestServiceClient) { grpclog.Fatalf("%v.Send(%v) = %v", stream, stream_req, err) } err = stream.CloseSend() - if grpc.Code(err) != 2 { + if grpc.Code(err) != codes.Code(code) { grpclog.Fatalf("%v compleled with error code %d, want %d", stream, grpc.Code(err), code) } if err.Error() != msg { @@ -524,6 +524,10 @@ func serverNewPayload(t testpb.PayloadType, size int32) (*testpb.Payload, error) } func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + resp := in.GetResponseStatus() + if *resp.Code != 0 { + return nil, grpc.Errorf(codes.Code(*resp.Code), *resp.Message) + } pl, err := serverNewPayload(in.GetResponseType(), in.GetResponseSize()) if err != nil { return nil, err @@ -534,6 +538,10 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* } func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error { + resp := args.GetResponseStatus() + if *resp.Code != 0 { + return grpc.Errorf(codes.Code(*resp.Code), *resp.Message) + } cs := args.GetResponseParameters() for _, c := range cs { if us := c.GetIntervalUs(); us > 0 {