Merge pull request #755 from menghanl/errorf_pointer

Make grpc.Errorf return struct pointer
This commit is contained in:
Menghan Li
2016-07-11 10:11:51 -07:00
committed by GitHub
7 changed files with 49 additions and 26 deletions

View File

@ -155,7 +155,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
t, put, err = cc.getTransport(ctx, gopts) t, put, err = cc.getTransport(ctx, gopts)
if err != nil { if err != nil {
// TODO(zhaoq): Probably revisit the error handling. // TODO(zhaoq): Probably revisit the error handling.
if _, ok := err.(rpcError); ok { if _, ok := err.(*rpcError); ok {
return err return err
} }
if err == errConnClosing { if err == errConnClosing {

View File

@ -234,7 +234,7 @@ func TestInvokeLargeErr(t *testing.T) {
var reply string var reply string
req := "hello" req := "hello"
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
if _, ok := err.(rpcError); !ok { if _, ok := err.(*rpcError); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
} }
if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr { if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr {
@ -250,7 +250,7 @@ func TestInvokeErrorSpecialChars(t *testing.T) {
var reply string var reply string
req := "weird error" req := "weird error"
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
if _, ok := err.(rpcError); !ok { if _, ok := err.(*rpcError); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
} }
if got, want := ErrorDesc(err), weirdError; got != want { if got, want := ErrorDesc(err), weirdError; got != want {

View File

@ -334,7 +334,7 @@ type rpcError struct {
desc string desc string
} }
func (e rpcError) Error() string { func (e *rpcError) Error() string {
return fmt.Sprintf("rpc error: code = %d desc = %s", e.code, e.desc) return fmt.Sprintf("rpc error: code = %d desc = %s", e.code, e.desc)
} }
@ -344,7 +344,7 @@ func Code(err error) codes.Code {
if err == nil { if err == nil {
return codes.OK return codes.OK
} }
if e, ok := err.(rpcError); ok { if e, ok := err.(*rpcError); ok {
return e.code return e.code
} }
return codes.Unknown return codes.Unknown
@ -356,7 +356,7 @@ func ErrorDesc(err error) string {
if err == nil { if err == nil {
return "" return ""
} }
if e, ok := err.(rpcError); ok { if e, ok := err.(*rpcError); ok {
return e.desc return e.desc
} }
return err.Error() return err.Error()
@ -368,7 +368,7 @@ func Errorf(c codes.Code, format string, a ...interface{}) error {
if c == codes.OK { if c == codes.OK {
return nil return nil
} }
return rpcError{ return &rpcError{
code: c, code: c,
desc: fmt.Sprintf(format, a...), desc: fmt.Sprintf(format, a...),
} }
@ -377,32 +377,32 @@ func Errorf(c codes.Code, format string, a ...interface{}) error {
// toRPCErr converts an error into a rpcError. // toRPCErr converts an error into a rpcError.
func toRPCErr(err error) error { func toRPCErr(err error) error {
switch e := err.(type) { switch e := err.(type) {
case rpcError: case *rpcError:
return err return err
case transport.StreamError: case transport.StreamError:
return rpcError{ return &rpcError{
code: e.Code, code: e.Code,
desc: e.Desc, desc: e.Desc,
} }
case transport.ConnectionError: case transport.ConnectionError:
return rpcError{ return &rpcError{
code: codes.Internal, code: codes.Internal,
desc: e.Desc, desc: e.Desc,
} }
default: default:
switch err { switch err {
case context.DeadlineExceeded: case context.DeadlineExceeded:
return rpcError{ return &rpcError{
code: codes.DeadlineExceeded, code: codes.DeadlineExceeded,
desc: err.Error(), desc: err.Error(),
} }
case context.Canceled: case context.Canceled:
return rpcError{ return &rpcError{
code: codes.Canceled, code: codes.Canceled,
desc: err.Error(), desc: err.Error(),
} }
case ErrClientConnClosing: case ErrClientConnClosing:
return rpcError{ return &rpcError{
code: codes.FailedPrecondition, code: codes.FailedPrecondition,
desc: err.Error(), desc: err.Error(),
} }

View File

@ -149,13 +149,17 @@ func TestToRPCErr(t *testing.T) {
// input // input
errIn error errIn error
// outputs // outputs
errOut error errOut *rpcError
}{ }{
{transport.StreamErrorf(codes.Unknown, ""), Errorf(codes.Unknown, "")}, {transport.StreamErrorf(codes.Unknown, ""), Errorf(codes.Unknown, "").(*rpcError)},
{transport.ErrConnClosing, Errorf(codes.Internal, transport.ErrConnClosing.Desc)}, {transport.ErrConnClosing, Errorf(codes.Internal, transport.ErrConnClosing.Desc).(*rpcError)},
} { } {
err := toRPCErr(test.errIn) err := toRPCErr(test.errIn)
if err != test.errOut { rpcErr, ok := err.(*rpcError)
if !ok {
t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, rpcError{})
}
if *rpcErr != *test.errOut {
t.Fatalf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) t.Fatalf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut)
} }
} }
@ -178,6 +182,18 @@ func TestContextErr(t *testing.T) {
} }
} }
func TestErrorsWithSameParameters(t *testing.T) {
const description = "some description"
e1 := Errorf(codes.AlreadyExists, description)
e2 := Errorf(codes.AlreadyExists, description)
if e1 == e2 {
t.Fatalf("Error interfaces should not be considered equal - e1: %p - %v e2: %p - %v", e1, e1, e2, e2)
}
if Code(e1) != Code(e2) || ErrorDesc(e1) != ErrorDesc(e2) {
t.Fatalf("Expected errors to have same code and description - e1: %p - %v e2: %p - %v", e1, e1, e2, e2)
}
}
// bmEncode benchmarks encoding a Protocol Buffer message containing mSize // bmEncode benchmarks encoding a Protocol Buffer message containing mSize
// bytes. // bytes.
func bmEncode(b *testing.B, mSize int) { func bmEncode(b *testing.B, mSize int) {

View File

@ -560,7 +560,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
if appErr != nil { if appErr != nil {
if err, ok := appErr.(rpcError); ok { if err, ok := appErr.(*rpcError); ok {
statusCode = err.code statusCode = err.code
statusDesc = err.desc statusDesc = err.desc
} else { } else {
@ -645,7 +645,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler) appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler)
} }
if appErr != nil { if appErr != nil {
if err, ok := appErr.(rpcError); ok { if err, ok := appErr.(*rpcError); ok {
ss.statusCode = err.code ss.statusCode = err.code
ss.statusDesc = err.desc ss.statusDesc = err.desc
} else if err, ok := appErr.(transport.StreamError); ok { } else if err, ok := appErr.(transport.StreamError); ok {

View File

@ -149,7 +149,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
t, put, err = cc.getTransport(ctx, gopts) t, put, err = cc.getTransport(ctx, gopts)
if err != nil { if err != nil {
// TODO(zhaoq): Probably revisit the error handling. // TODO(zhaoq): Probably revisit the error handling.
if _, ok := err.(rpcError); ok { if _, ok := err.(*rpcError); ok {
return nil, err return nil, err
} }
if err == errConnClosing { if err == errConnClosing {

View File

@ -659,7 +659,8 @@ func testHealthCheckOnFailure(t *testing.T, e env) {
defer te.tearDown() defer te.tearDown()
cc := te.clientConn() cc := te.clientConn()
if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); err != grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded") { wantErr := grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded")
if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); !equalErrors(err, wantErr) {
t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.DeadlineExceeded) t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.DeadlineExceeded)
} }
awaitNewConnLogOutput() awaitNewConnLogOutput()
@ -681,7 +682,7 @@ func testHealthCheckOff(t *testing.T, e env) {
te.startServer() te.startServer()
defer te.tearDown() defer te.tearDown()
want := grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1.Health") want := grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1.Health")
if _, err := healthCheck(1*time.Second, te.clientConn(), ""); err != want { if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !equalErrors(err, want) {
t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want) t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want)
} }
} }
@ -708,7 +709,8 @@ func testHealthCheckServingStatus(t *testing.T, e env) {
if out.Status != healthpb.HealthCheckResponse_SERVING { if out.Status != healthpb.HealthCheckResponse_SERVING {
t.Fatalf("Got the serving status %v, want SERVING", out.Status) t.Fatalf("Got the serving status %v, want SERVING", out.Status)
} }
if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); err != grpc.Errorf(codes.NotFound, "unknown service") { wantErr := grpc.Errorf(codes.NotFound, "unknown service")
if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); !equalErrors(err, wantErr) {
t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.NotFound) t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.NotFound)
} }
hs.SetServingStatus("grpc.health.v1.Health", healthpb.HealthCheckResponse_SERVING) hs.SetServingStatus("grpc.health.v1.Health", healthpb.HealthCheckResponse_SERVING)
@ -790,7 +792,7 @@ func testFailedEmptyUnary(t *testing.T, e env) {
ctx := metadata.NewContext(context.Background(), testMetadata) ctx := metadata.NewContext(context.Background(), testMetadata)
wantErr := grpc.Errorf(codes.DataLoss, "missing expected user-agent") wantErr := grpc.Errorf(codes.DataLoss, "missing expected user-agent")
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != wantErr { if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !equalErrors(err, wantErr) {
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr) t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr)
} }
} }
@ -1373,8 +1375,9 @@ func testFailedServerStreaming(t *testing.T, e env) {
if err != nil { if err != nil {
t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want <nil>", tc, err) t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want <nil>", tc, err)
} }
if _, err := stream.Recv(); err != grpc.Errorf(codes.DataLoss, "got extra metadata") { wantErr := grpc.Errorf(codes.DataLoss, "got extra metadata")
t.Fatalf("%v.Recv() = _, %v, want _, %v", stream, err, grpc.Errorf(codes.DataLoss, "got extra metadata")) if _, err := stream.Recv(); !equalErrors(err, wantErr) {
t.Fatalf("%v.Recv() = _, %v, want _, %v", stream, err, wantErr)
} }
} }
@ -2124,3 +2127,7 @@ func (fw *filterWriter) Write(p []byte) (n int, err error) {
} }
return fw.dst.Write(p) return fw.dst.Write(p)
} }
func equalErrors(l, r error) bool {
return grpc.Code(l) == grpc.Code(r) && grpc.ErrorDesc(l) == grpc.ErrorDesc(r)
}