Change status package to deal with concrete types instead of interfaces (#1171)

This commit is contained in:
dfawley
2017-04-06 11:41:07 -07:00
committed by GitHub
parent b507112439
commit 1d27587e10
12 changed files with 82 additions and 102 deletions

View File

@ -231,7 +231,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
t, put, err = cc.getTransport(ctx, gopts) t, put, err = cc.getTransport(ctx, gopts)
if err != nil { if err != nil {
// TODO(zhaoq): Probably revisit the error handling. // TODO(zhaoq): Probably revisit the error handling.
if _, ok := err.(status.Status); ok { if _, ok := status.FromError(err); ok {
return err return err
} }
if err == errConnClosing || err == errConnUnavailable { if err == errConnClosing || err == errConnUnavailable {

View File

@ -240,7 +240,7 @@ func TestInvokeLargeErr(t *testing.T) {
var reply string var reply string
req := "hello" req := "hello"
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
if _, ok := err.(status.Status); !ok { if _, ok := status.FromError(err); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
} }
if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr { if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr {
@ -256,7 +256,7 @@ func TestInvokeErrorSpecialChars(t *testing.T) {
var reply string var reply string
req := "weird error" req := "weird error"
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
if _, ok := err.(status.Status); !ok { if _, ok := status.FromError(err); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
} }
if got, want := ErrorDesc(err), weirdError; got != want { if got, want := ErrorDesc(err), weirdError; got != want {

View File

@ -404,9 +404,10 @@ func Errorf(c codes.Code, format string, a ...interface{}) error {
// toRPCErr converts an error into an error from the status package. // toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error { func toRPCErr(err error) error {
switch e := err.(type) { if _, ok := status.FromError(err); ok {
case status.Status:
return err return err
}
switch e := err.(type) {
case transport.StreamError: case transport.StreamError:
return status.Error(e.Code, e.Desc) return status.Error(e.Code, e.Desc)
case transport.ConnectionError: case transport.ConnectionError:

View File

@ -156,7 +156,7 @@ func TestToRPCErr(t *testing.T) {
{transport.ErrConnClosing, status.Error(codes.Internal, transport.ErrConnClosing.Desc)}, {transport.ErrConnClosing, status.Error(codes.Internal, transport.ErrConnClosing.Desc)},
} { } {
err := toRPCErr(test.errIn) err := toRPCErr(test.errIn)
if _, ok := err.(status.Status); !ok { if _, ok := status.FromError(err); !ok {
t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, status.Error(codes.Unknown, "")) t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, status.Error(codes.Unknown, ""))
} }
if !reflect.DeepEqual(err, test.errOut) { if !reflect.DeepEqual(err, test.errOut) {

View File

@ -682,11 +682,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error())
} }
if err != nil { if err != nil {
switch st := err.(type) { if st, ok := status.FromError(err); ok {
case status.Status:
if e := t.WriteStatus(stream, st); e != nil { if e := t.WriteStatus(stream, st); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
} }
} else {
switch st := err.(type) {
case transport.ConnectionError: case transport.ConnectionError:
// Nothing to do here. // Nothing to do here.
case transport.StreamError: case transport.StreamError:
@ -696,11 +697,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
default: default:
panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st)) panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st))
} }
}
return err return err
} }
if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
if st, ok := err.(status.Status); ok { if st, ok := status.FromError(err); ok {
if e := t.WriteStatus(stream, st); e != nil { if e := t.WriteStatus(stream, st); e != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
} }
@ -852,15 +854,16 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
appErr = s.opts.streamInt(server, ss, info, sd.Handler) appErr = s.opts.streamInt(server, ss, info, sd.Handler)
} }
if appErr != nil { if appErr != nil {
appStatus, ok := status.FromError(appErr)
if !ok {
switch err := appErr.(type) { switch err := appErr.(type) {
case status.Status:
// Do nothing
case transport.StreamError: case transport.StreamError:
appErr = status.Error(err.Code, err.Desc) appStatus = status.New(err.Code, err.Desc)
default: default:
appErr = status.Error(convertCode(appErr), appErr.Error()) appStatus = status.New(convertCode(appErr), appErr.Error())
}
appErr = appStatus.Err()
} }
appStatus, _ := status.FromError(appErr)
if trInfo != nil { if trInfo != nil {
ss.mu.Lock() ss.mu.Lock()
ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true) ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true)

View File

@ -50,78 +50,56 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
) )
// Status provides access to grpc status details and is implemented by all // statusError is an alias of a status proto. It implements error and Status,
// errors returned from this package except nil errors, which are not typed. // and a nil statusError should never be returned by this package.
// Note: gRPC users should not implement their own Statuses. Custom data may type statusError spb.Status
// be attached to the spb.Status proto's Details field.
type Status interface {
// Code returns the status code.
Code() codes.Code
// Message returns the status message.
Message() string
// Proto returns a copy of the status in proto form.
Proto() *spb.Status
// Err returns an error representing the status.
Err() error
}
// okStatus is a Status whose Code method returns codes.OK, but does not
// implement error. To represent an OK code as an error, use an untyped nil.
type okStatus struct{}
func (okStatus) Code() codes.Code {
return codes.OK
}
func (okStatus) Message() string {
return ""
}
func (okStatus) Proto() *spb.Status {
return nil
}
func (okStatus) Err() error {
return nil
}
// statusError contains a status proto. It is embedded and not aliased to
// allow for accessor functions of the same name. It implements error and
// Status, and a nil statusError should never be returned by this package.
type statusError struct {
*spb.Status
}
func (se *statusError) Error() string { func (se *statusError) Error() string {
return fmt.Sprintf("rpc error: code = %s desc = %s", se.Code(), se.Message()) p := (*spb.Status)(se)
return fmt.Sprintf("rpc error: code = %s desc = %s", codes.Code(p.GetCode()), p.GetMessage())
} }
func (se *statusError) Code() codes.Code { func (se *statusError) status() *Status {
return codes.Code(se.Status.Code) return &Status{s: (*spb.Status)(se)}
} }
func (se *statusError) Message() string { // Status represents an RPC status code, message, and details. It is immutable
return se.Status.Message // and should be created with New, Newf, or FromProto.
type Status struct {
s *spb.Status
} }
func (se *statusError) Proto() *spb.Status { // Code returns the status code contained in s.
return proto.Clone(se.Status).(*spb.Status) func (s *Status) Code() codes.Code {
return codes.Code(s.s.Code)
} }
func (se *statusError) Err() error { // Message returns the message contained in s.
return se func (s *Status) Message() string {
return s.s.Message
}
// Proto returns s's status as an spb.Status proto message.
func (s *Status) Proto() *spb.Status {
return proto.Clone(s.s).(*spb.Status)
}
// Err returns an immutable error representing s; returns nil if s.Code() is
// OK.
func (s *Status) Err() error {
if s.Code() == codes.OK {
return nil
}
return (*statusError)(s.s)
} }
// New returns a Status representing c and msg. // New returns a Status representing c and msg.
func New(c codes.Code, msg string) Status { func New(c codes.Code, msg string) *Status {
if c == codes.OK { return &Status{s: &spb.Status{Code: int32(c), Message: msg}}
return okStatus{}
}
return &statusError{Status: &spb.Status{Code: int32(c), Message: msg}}
} }
// Newf returns New(c, fmt.Sprintf(format, a...)). // Newf returns New(c, fmt.Sprintf(format, a...)).
func Newf(c codes.Code, format string, a ...interface{}) Status { func Newf(c codes.Code, format string, a ...interface{}) *Status {
return New(c, fmt.Sprintf(format, a...)) return New(c, fmt.Sprintf(format, a...))
} }
@ -140,21 +118,19 @@ func ErrorProto(s *spb.Status) error {
return FromProto(s).Err() return FromProto(s).Err()
} }
// FromProto returns a Status representing s. If s.Code is OK, Message and // FromProto returns a Status representing s.
// Details may be lost. func FromProto(s *spb.Status) *Status {
func FromProto(s *spb.Status) Status { return &Status{s: proto.Clone(s).(*spb.Status)}
if s.GetCode() == int32(codes.OK) {
return okStatus{}
}
return &statusError{Status: proto.Clone(s).(*spb.Status)}
} }
// FromError returns a Status representing err if it was produced from this // FromError returns a Status representing err if it was produced from this
// package, otherwise it returns nil, false. // package, otherwise it returns nil, false.
func FromError(err error) (s Status, ok bool) { func FromError(err error) (s *Status, ok bool) {
if err == nil { if err == nil {
return okStatus{}, true return &Status{s: &spb.Status{Code: int32(codes.OK)}}, true
} }
s, ok = err.(Status) if s, ok := err.(*statusError); ok {
return s, ok return s.status(), true
}
return nil, false
} }

View File

@ -69,7 +69,7 @@ func TestError(t *testing.T) {
if got, want := err.Error(), "rpc error: code = Internal desc = test description"; got != want { if got, want := err.Error(), "rpc error: code = Internal desc = test description"; got != want {
t.Fatalf("err.Error() = %q; want %q", got, want) t.Fatalf("err.Error() = %q; want %q", got, want)
} }
s := err.(Status) s, _ := FromError(err)
if got, want := s.Code(), codes.Internal; got != want { if got, want := s.Code(), codes.Internal; got != want {
t.Fatalf("err.Code() = %s; want %s", got, want) t.Fatalf("err.Code() = %s; want %s", got, want)
} }

View File

@ -178,7 +178,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
t, put, err = cc.getTransport(ctx, gopts) t, put, err = cc.getTransport(ctx, gopts)
if err != nil { if err != nil {
// TODO(zhaoq): Probably revisit the error handling. // TODO(zhaoq): Probably revisit the error handling.
if _, ok := err.(status.Status); ok { if _, ok := status.FromError(err); ok {
return nil, err return nil, err
} }
if err == errConnClosing || err == errConnUnavailable { if err == errConnClosing || err == errConnUnavailable {

View File

@ -183,7 +183,7 @@ func (ht *serverHandlerTransport) do(fn func()) error {
} }
} }
func (ht *serverHandlerTransport) WriteStatus(s *Stream, st status.Status) error { func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error {
err := ht.do(func() { err := ht.do(func() {
ht.writeCommonHeaders(s) ht.writeCommonHeaders(s)

View File

@ -671,7 +671,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
// There is no further I/O operations being able to perform on this stream. // There is no further I/O operations being able to perform on this stream.
// TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early // TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early
// OK is adopted. // OK is adopted.
func (t *http2Server) WriteStatus(s *Stream, st status.Status) error { func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
var headersSent, hasHeader bool var headersSent, hasHeader bool
s.mu.Lock() s.mu.Lock()
if s.state == streamDone { if s.state == streamDone {

View File

@ -97,7 +97,7 @@ type decodeState struct {
// statusGen caches the stream status received from the trailer the server // statusGen caches the stream status received from the trailer the server
// sent. Client side only. Do not access directly. After all trailers are // sent. Client side only. Do not access directly. After all trailers are
// parsed, use the status method to retrieve the status. // parsed, use the status method to retrieve the status.
statusGen status.Status statusGen *status.Status
// rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not // rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not
// intended for direct access outside of parsing. // intended for direct access outside of parsing.
rawStatusCode int32 rawStatusCode int32
@ -156,7 +156,7 @@ func validContentType(t string) bool {
return true return true
} }
func (d *decodeState) status() status.Status { func (d *decodeState) status() *status.Status {
if d.statusGen == nil { if d.statusGen == nil {
// No status-details were provided; generate status using code/msg. // No status-details were provided; generate status using code/msg.
d.statusGen = status.New(codes.Code(d.rawStatusCode), d.rawStatusMsg) d.statusGen = status.New(codes.Code(d.rawStatusCode), d.rawStatusMsg)

View File

@ -214,7 +214,7 @@ type Stream struct {
// multiple times. // multiple times.
headerDone bool headerDone bool
// the status error received from the server. // the status error received from the server.
status status.Status status *status.Status
// rstStream indicates whether a RST_STREAM frame needs to be sent // rstStream indicates whether a RST_STREAM frame needs to be sent
// to the server to signify that this stream is closing. // to the server to signify that this stream is closing.
rstStream bool rstStream bool
@ -285,7 +285,7 @@ func (s *Stream) Method() string {
} }
// Status returns the status received from the server. // Status returns the status received from the server.
func (s *Stream) Status() status.Status { func (s *Stream) Status() *status.Status {
return s.status return s.status
} }
@ -334,8 +334,8 @@ func (s *Stream) Read(p []byte) (n int, err error) {
} }
// finish sets the stream's state and status, and closes the done channel. // finish sets the stream's state and status, and closes the done channel.
// s.mu must be held by the caller. // s.mu must be held by the caller. st must always be non-nil.
func (s *Stream) finish(st status.Status) { func (s *Stream) finish(st *status.Status) {
s.status = st s.status = st
s.state = streamDone s.state = streamDone
close(s.done) close(s.done)
@ -508,7 +508,7 @@ type ServerTransport interface {
// WriteStatus sends the status of a stream to the client. WriteStatus is // WriteStatus sends the status of a stream to the client. WriteStatus is
// the final call made on a stream and always occurs. // the final call made on a stream and always occurs.
WriteStatus(s *Stream, st status.Status) error WriteStatus(s *Stream, st *status.Status) error
// Close tears down the transport. Once it is called, the transport // Close tears down the transport. Once it is called, the transport
// should not be accessed any more. All the pending streams and their // should not be accessed any more. All the pending streams and their