diff --git a/call.go b/call.go index 9af9c04f..13ca5b78 100644 --- a/call.go +++ b/call.go @@ -231,7 +231,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli t, put, err = cc.getTransport(ctx, gopts) if err != nil { // TODO(zhaoq): Probably revisit the error handling. - if _, ok := err.(status.Status); ok { + if _, ok := status.FromError(err); ok { return err } if err == errConnClosing || err == errConnUnavailable { diff --git a/call_test.go b/call_test.go index 63e87c21..437197c8 100644 --- a/call_test.go +++ b/call_test.go @@ -240,7 +240,7 @@ func TestInvokeLargeErr(t *testing.T) { var reply string req := "hello" err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) - if _, ok := err.(status.Status); !ok { + if _, ok := status.FromError(err); !ok { t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") } if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr { @@ -256,7 +256,7 @@ func TestInvokeErrorSpecialChars(t *testing.T) { var reply string req := "weird error" err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) - if _, ok := err.(status.Status); !ok { + if _, ok := status.FromError(err); !ok { t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") } if got, want := ErrorDesc(err), weirdError; got != want { diff --git a/rpc_util.go b/rpc_util.go index 6386660d..db56a88d 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -404,9 +404,10 @@ func Errorf(c codes.Code, format string, a ...interface{}) error { // toRPCErr converts an error into an error from the status package. func toRPCErr(err error) error { - switch e := err.(type) { - case status.Status: + if _, ok := status.FromError(err); ok { return err + } + switch e := err.(type) { case transport.StreamError: return status.Error(e.Code, e.Desc) case transport.ConnectionError: diff --git a/rpc_util_test.go b/rpc_util_test.go index f2b43f0f..8c92b963 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -156,7 +156,7 @@ func TestToRPCErr(t *testing.T) { {transport.ErrConnClosing, status.Error(codes.Internal, transport.ErrConnClosing.Desc)}, } { err := toRPCErr(test.errIn) - if _, ok := err.(status.Status); !ok { + if _, ok := status.FromError(err); !ok { t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, status.Error(codes.Unknown, "")) } if !reflect.DeepEqual(err, test.errOut) { diff --git a/server.go b/server.go index 74f9788d..7956c22a 100644 --- a/server.go +++ b/server.go @@ -682,25 +682,27 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) } if err != nil { - switch st := err.(type) { - case status.Status: + if st, ok := status.FromError(err); ok { if e := t.WriteStatus(stream, st); e != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } - case transport.ConnectionError: - // Nothing to do here. - case transport.StreamError: - if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) + } else { + switch st := err.(type) { + case transport.ConnectionError: + // Nothing to do here. + case transport.StreamError: + if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) + } + default: + panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st)) } - default: - panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st)) } return err } if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { - if st, ok := err.(status.Status); ok { + if st, ok := status.FromError(err); ok { if e := t.WriteStatus(stream, st); e != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } @@ -852,15 +854,16 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp appErr = s.opts.streamInt(server, ss, info, sd.Handler) } if appErr != nil { - switch err := appErr.(type) { - case status.Status: - // Do nothing - case transport.StreamError: - appErr = status.Error(err.Code, err.Desc) - default: - appErr = status.Error(convertCode(appErr), appErr.Error()) + appStatus, ok := status.FromError(appErr) + if !ok { + switch err := appErr.(type) { + case transport.StreamError: + appStatus = status.New(err.Code, err.Desc) + default: + appStatus = status.New(convertCode(appErr), appErr.Error()) + } + appErr = appStatus.Err() } - appStatus, _ := status.FromError(appErr) if trInfo != nil { ss.mu.Lock() ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true) diff --git a/status/status.go b/status/status.go index 0e402081..9637aff3 100644 --- a/status/status.go +++ b/status/status.go @@ -50,78 +50,56 @@ import ( "google.golang.org/grpc/codes" ) -// Status provides access to grpc status details and is implemented by all -// errors returned from this package except nil errors, which are not typed. -// Note: gRPC users should not implement their own Statuses. Custom data may -// be attached to the spb.Status proto's Details field. -type Status interface { - // Code returns the status code. - Code() codes.Code - // Message returns the status message. - Message() string - // Proto returns a copy of the status in proto form. - Proto() *spb.Status - // Err returns an error representing the status. - Err() error -} - -// okStatus is a Status whose Code method returns codes.OK, but does not -// implement error. To represent an OK code as an error, use an untyped nil. -type okStatus struct{} - -func (okStatus) Code() codes.Code { - return codes.OK -} - -func (okStatus) Message() string { - return "" -} - -func (okStatus) Proto() *spb.Status { - return nil -} - -func (okStatus) Err() error { - return nil -} - -// statusError contains a status proto. It is embedded and not aliased to -// allow for accessor functions of the same name. It implements error and -// Status, and a nil statusError should never be returned by this package. -type statusError struct { - *spb.Status -} +// statusError is an alias of a status proto. It implements error and Status, +// and a nil statusError should never be returned by this package. +type statusError spb.Status func (se *statusError) Error() string { - return fmt.Sprintf("rpc error: code = %s desc = %s", se.Code(), se.Message()) + p := (*spb.Status)(se) + return fmt.Sprintf("rpc error: code = %s desc = %s", codes.Code(p.GetCode()), p.GetMessage()) } -func (se *statusError) Code() codes.Code { - return codes.Code(se.Status.Code) +func (se *statusError) status() *Status { + return &Status{s: (*spb.Status)(se)} } -func (se *statusError) Message() string { - return se.Status.Message +// Status represents an RPC status code, message, and details. It is immutable +// and should be created with New, Newf, or FromProto. +type Status struct { + s *spb.Status } -func (se *statusError) Proto() *spb.Status { - return proto.Clone(se.Status).(*spb.Status) +// Code returns the status code contained in s. +func (s *Status) Code() codes.Code { + return codes.Code(s.s.Code) } -func (se *statusError) Err() error { - return se +// Message returns the message contained in s. +func (s *Status) Message() string { + return s.s.Message +} + +// Proto returns s's status as an spb.Status proto message. +func (s *Status) Proto() *spb.Status { + return proto.Clone(s.s).(*spb.Status) +} + +// Err returns an immutable error representing s; returns nil if s.Code() is +// OK. +func (s *Status) Err() error { + if s.Code() == codes.OK { + return nil + } + return (*statusError)(s.s) } // New returns a Status representing c and msg. -func New(c codes.Code, msg string) Status { - if c == codes.OK { - return okStatus{} - } - return &statusError{Status: &spb.Status{Code: int32(c), Message: msg}} +func New(c codes.Code, msg string) *Status { + return &Status{s: &spb.Status{Code: int32(c), Message: msg}} } // Newf returns New(c, fmt.Sprintf(format, a...)). -func Newf(c codes.Code, format string, a ...interface{}) Status { +func Newf(c codes.Code, format string, a ...interface{}) *Status { return New(c, fmt.Sprintf(format, a...)) } @@ -140,21 +118,19 @@ func ErrorProto(s *spb.Status) error { return FromProto(s).Err() } -// FromProto returns a Status representing s. If s.Code is OK, Message and -// Details may be lost. -func FromProto(s *spb.Status) Status { - if s.GetCode() == int32(codes.OK) { - return okStatus{} - } - return &statusError{Status: proto.Clone(s).(*spb.Status)} +// FromProto returns a Status representing s. +func FromProto(s *spb.Status) *Status { + return &Status{s: proto.Clone(s).(*spb.Status)} } // FromError returns a Status representing err if it was produced from this // package, otherwise it returns nil, false. -func FromError(err error) (s Status, ok bool) { +func FromError(err error) (s *Status, ok bool) { if err == nil { - return okStatus{}, true + return &Status{s: &spb.Status{Code: int32(codes.OK)}}, true } - s, ok = err.(Status) - return s, ok + if s, ok := err.(*statusError); ok { + return s.status(), true + } + return nil, false } diff --git a/status/status_test.go b/status/status_test.go index 34de196c..5705f141 100644 --- a/status/status_test.go +++ b/status/status_test.go @@ -69,7 +69,7 @@ func TestError(t *testing.T) { if got, want := err.Error(), "rpc error: code = Internal desc = test description"; got != want { t.Fatalf("err.Error() = %q; want %q", got, want) } - s := err.(Status) + s, _ := FromError(err) if got, want := s.Code(), codes.Internal; got != want { t.Fatalf("err.Code() = %s; want %s", got, want) } diff --git a/stream.go b/stream.go index 399e93f3..ecb1a31f 100644 --- a/stream.go +++ b/stream.go @@ -178,7 +178,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth t, put, err = cc.getTransport(ctx, gopts) if err != nil { // TODO(zhaoq): Probably revisit the error handling. - if _, ok := err.(status.Status); ok { + if _, ok := status.FromError(err); ok { return nil, err } if err == errConnClosing || err == errConnUnavailable { diff --git a/transport/handler_server.go b/transport/handler_server.go index 5bf63630..e1c43f68 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -183,7 +183,7 @@ func (ht *serverHandlerTransport) do(fn func()) error { } } -func (ht *serverHandlerTransport) WriteStatus(s *Stream, st status.Status) error { +func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error { err := ht.do(func() { ht.writeCommonHeaders(s) diff --git a/transport/http2_server.go b/transport/http2_server.go index 9972a839..db72e940 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -671,7 +671,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { // There is no further I/O operations being able to perform on this stream. // TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early // OK is adopted. -func (t *http2Server) WriteStatus(s *Stream, st status.Status) error { +func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { var headersSent, hasHeader bool s.mu.Lock() if s.state == streamDone { diff --git a/transport/http_util.go b/transport/http_util.go index 57aad62d..bec3e3a8 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -97,7 +97,7 @@ type decodeState struct { // statusGen caches the stream status received from the trailer the server // sent. Client side only. Do not access directly. After all trailers are // parsed, use the status method to retrieve the status. - statusGen status.Status + statusGen *status.Status // rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not // intended for direct access outside of parsing. rawStatusCode int32 @@ -156,7 +156,7 @@ func validContentType(t string) bool { return true } -func (d *decodeState) status() status.Status { +func (d *decodeState) status() *status.Status { if d.statusGen == nil { // No status-details were provided; generate status using code/msg. d.statusGen = status.New(codes.Code(d.rawStatusCode), d.rawStatusMsg) diff --git a/transport/transport.go b/transport/transport.go index 3b8bd01c..df67d57e 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -214,7 +214,7 @@ type Stream struct { // multiple times. headerDone bool // the status error received from the server. - status status.Status + status *status.Status // rstStream indicates whether a RST_STREAM frame needs to be sent // to the server to signify that this stream is closing. rstStream bool @@ -285,7 +285,7 @@ func (s *Stream) Method() string { } // Status returns the status received from the server. -func (s *Stream) Status() status.Status { +func (s *Stream) Status() *status.Status { return s.status } @@ -334,8 +334,8 @@ func (s *Stream) Read(p []byte) (n int, err error) { } // finish sets the stream's state and status, and closes the done channel. -// s.mu must be held by the caller. -func (s *Stream) finish(st status.Status) { +// s.mu must be held by the caller. st must always be non-nil. +func (s *Stream) finish(st *status.Status) { s.status = st s.state = streamDone close(s.done) @@ -508,7 +508,7 @@ type ServerTransport interface { // WriteStatus sends the status of a stream to the client. WriteStatus is // the final call made on a stream and always occurs. - WriteStatus(s *Stream, st status.Status) error + WriteStatus(s *Stream, st *status.Status) error // Close tears down the transport. Once it is called, the transport // should not be accessed any more. All the pending streams and their