diff --git a/call.go b/call.go index 93262360..84ac178c 100644 --- a/call.go +++ b/call.go @@ -155,7 +155,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli t, put, err = cc.getTransport(ctx, gopts) if err != nil { // TODO(zhaoq): Probably revisit the error handling. - if _, ok := err.(rpcError); ok { + if _, ok := err.(*rpcError); ok { return err } if err == errConnClosing { diff --git a/call_test.go b/call_test.go index 380bf872..49349858 100644 --- a/call_test.go +++ b/call_test.go @@ -234,7 +234,7 @@ func TestInvokeLargeErr(t *testing.T) { var reply string req := "hello" err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) - if _, ok := err.(rpcError); !ok { + if _, ok := err.(*rpcError); !ok { t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") } if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr { @@ -250,7 +250,7 @@ func TestInvokeErrorSpecialChars(t *testing.T) { var reply string req := "weird error" err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) - if _, ok := err.(rpcError); !ok { + if _, ok := err.(*rpcError); !ok { t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") } if got, want := ErrorDesc(err), weirdError; got != want { diff --git a/rpc_util.go b/rpc_util.go index 91342bd8..d6287175 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -334,7 +334,7 @@ type rpcError struct { desc string } -func (e rpcError) Error() string { +func (e *rpcError) Error() string { return fmt.Sprintf("rpc error: code = %d desc = %s", e.code, e.desc) } @@ -344,7 +344,7 @@ func Code(err error) codes.Code { if err == nil { return codes.OK } - if e, ok := err.(rpcError); ok { + if e, ok := err.(*rpcError); ok { return e.code } return codes.Unknown @@ -356,7 +356,7 @@ func ErrorDesc(err error) string { if err == nil { return "" } - if e, ok := err.(rpcError); ok { + if e, ok := err.(*rpcError); ok { return e.desc } return err.Error() @@ -368,7 +368,7 @@ func Errorf(c codes.Code, format string, a ...interface{}) error { if c == codes.OK { return nil } - return rpcError{ + return &rpcError{ code: c, desc: fmt.Sprintf(format, a...), } @@ -377,32 +377,32 @@ func Errorf(c codes.Code, format string, a ...interface{}) error { // toRPCErr converts an error into a rpcError. func toRPCErr(err error) error { switch e := err.(type) { - case rpcError: + case *rpcError: return err case transport.StreamError: - return rpcError{ + return &rpcError{ code: e.Code, desc: e.Desc, } case transport.ConnectionError: - return rpcError{ + return &rpcError{ code: codes.Internal, desc: e.Desc, } default: switch err { case context.DeadlineExceeded: - return rpcError{ + return &rpcError{ code: codes.DeadlineExceeded, desc: err.Error(), } case context.Canceled: - return rpcError{ + return &rpcError{ code: codes.Canceled, desc: err.Error(), } case ErrClientConnClosing: - return rpcError{ + return &rpcError{ code: codes.FailedPrecondition, desc: err.Error(), } diff --git a/rpc_util_test.go b/rpc_util_test.go index f6327f13..5a802d65 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -149,13 +149,17 @@ func TestToRPCErr(t *testing.T) { // input errIn error // outputs - errOut error + errOut *rpcError }{ - {transport.StreamErrorf(codes.Unknown, ""), Errorf(codes.Unknown, "")}, - {transport.ErrConnClosing, Errorf(codes.Internal, transport.ErrConnClosing.Desc)}, + {transport.StreamErrorf(codes.Unknown, ""), Errorf(codes.Unknown, "").(*rpcError)}, + {transport.ErrConnClosing, Errorf(codes.Internal, transport.ErrConnClosing.Desc).(*rpcError)}, } { err := toRPCErr(test.errIn) - if err != test.errOut { + rpcErr, ok := err.(*rpcError) + if !ok { + t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, rpcError{}) + } + if *rpcErr != *test.errOut { t.Fatalf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) } } @@ -178,6 +182,18 @@ func TestContextErr(t *testing.T) { } } +func TestErrorsWithSameParameters(t *testing.T) { + const description = "some description" + e1 := Errorf(codes.AlreadyExists, description) + e2 := Errorf(codes.AlreadyExists, description) + if e1 == e2 { + t.Fatalf("Error interfaces should not be considered equal - e1: %p - %v e2: %p - %v", e1, e1, e2, e2) + } + if Code(e1) != Code(e2) || ErrorDesc(e1) != ErrorDesc(e2) { + t.Fatalf("Expected errors to have same code and description - e1: %p - %v e2: %p - %v", e1, e1, e2, e2) + } +} + // bmEncode benchmarks encoding a Protocol Buffer message containing mSize // bytes. func bmEncode(b *testing.B, mSize int) { diff --git a/server.go b/server.go index 6782009e..7dabfc64 100644 --- a/server.go +++ b/server.go @@ -560,7 +560,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) if appErr != nil { - if err, ok := appErr.(rpcError); ok { + if err, ok := appErr.(*rpcError); ok { statusCode = err.code statusDesc = err.desc } else { @@ -645,7 +645,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler) } if appErr != nil { - if err, ok := appErr.(rpcError); ok { + if err, ok := appErr.(*rpcError); ok { ss.statusCode = err.code ss.statusDesc = err.desc } else if err, ok := appErr.(transport.StreamError); ok { diff --git a/stream.go b/stream.go index 73d1da23..7a3bef51 100644 --- a/stream.go +++ b/stream.go @@ -149,7 +149,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth t, put, err = cc.getTransport(ctx, gopts) if err != nil { // TODO(zhaoq): Probably revisit the error handling. - if _, ok := err.(rpcError); ok { + if _, ok := err.(*rpcError); ok { return nil, err } if err == errConnClosing { diff --git a/test/end2end_test.go b/test/end2end_test.go index 245b8c0f..cc8bae5f 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -659,7 +659,8 @@ func testHealthCheckOnFailure(t *testing.T, e env) { defer te.tearDown() cc := te.clientConn() - if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); err != grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded") { + wantErr := grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded") + if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); !equalErrors(err, wantErr) { t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.DeadlineExceeded) } awaitNewConnLogOutput() @@ -681,7 +682,7 @@ func testHealthCheckOff(t *testing.T, e env) { te.startServer() defer te.tearDown() want := grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1.Health") - if _, err := healthCheck(1*time.Second, te.clientConn(), ""); err != want { + if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !equalErrors(err, want) { t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want) } } @@ -708,7 +709,8 @@ func testHealthCheckServingStatus(t *testing.T, e env) { if out.Status != healthpb.HealthCheckResponse_SERVING { t.Fatalf("Got the serving status %v, want SERVING", out.Status) } - if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); err != grpc.Errorf(codes.NotFound, "unknown service") { + wantErr := grpc.Errorf(codes.NotFound, "unknown service") + if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); !equalErrors(err, wantErr) { t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.NotFound) } hs.SetServingStatus("grpc.health.v1.Health", healthpb.HealthCheckResponse_SERVING) @@ -790,7 +792,7 @@ func testFailedEmptyUnary(t *testing.T, e env) { ctx := metadata.NewContext(context.Background(), testMetadata) wantErr := grpc.Errorf(codes.DataLoss, "missing expected user-agent") - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != wantErr { + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !equalErrors(err, wantErr) { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr) } } @@ -1373,8 +1375,9 @@ func testFailedServerStreaming(t *testing.T, e env) { if err != nil { t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want ", tc, err) } - if _, err := stream.Recv(); err != grpc.Errorf(codes.DataLoss, "got extra metadata") { - t.Fatalf("%v.Recv() = _, %v, want _, %v", stream, err, grpc.Errorf(codes.DataLoss, "got extra metadata")) + wantErr := grpc.Errorf(codes.DataLoss, "got extra metadata") + if _, err := stream.Recv(); !equalErrors(err, wantErr) { + t.Fatalf("%v.Recv() = _, %v, want _, %v", stream, err, wantErr) } } @@ -2124,3 +2127,7 @@ func (fw *filterWriter) Write(p []byte) (n int, err error) { } return fw.dst.Write(p) } + +func equalErrors(l, r error) bool { + return grpc.Code(l) == grpc.Code(r) && grpc.ErrorDesc(l) == grpc.ErrorDesc(r) +}