server: handle context errors returned by service handler (#5156)

This commit is contained in:
Menghan Li
2022-01-26 11:02:23 -08:00
committed by GitHub
parent e27717498d
commit 61a6a06b88
4 changed files with 54 additions and 11 deletions

View File

@ -72,9 +72,12 @@ type UnaryServerInfo struct {
}
// UnaryHandler defines the handler invoked by UnaryServerInterceptor to complete the normal
// execution of a unary RPC. If a UnaryHandler returns an error, it should be produced by the
// status package, or else gRPC will use codes.Unknown as the status code and err.Error() as
// the status message of the RPC.
// execution of a unary RPC.
//
// If a UnaryHandler returns an error, it should either be produced by the
// status package, or be one of the context errors. Otherwise, gRPC will use
// codes.Unknown as the status code and err.Error() as the status message of the
// RPC.
type UnaryHandler func(ctx context.Context, req interface{}) (interface{}, error)
// UnaryServerInterceptor provides a hook to intercept the execution of a unary RPC on the server. info

View File

@ -1283,9 +1283,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if appErr != nil {
appStatus, ok := status.FromError(appErr)
if !ok {
// Convert appErr if it is not a grpc status error.
appErr = status.Error(codes.Unknown, appErr.Error())
appStatus, _ = status.FromError(appErr)
// Convert non-status application error to a status error with code
// Unknown, but handle context errors specifically.
appStatus = status.FromContextError(appErr)
appErr = appStatus.Err()
}
if trInfo != nil {
trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
@ -1549,7 +1550,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
if appErr != nil {
appStatus, ok := status.FromError(appErr)
if !ok {
appStatus = status.New(codes.Unknown, appErr.Error())
// Convert non-status application error to a status error with code
// Unknown, but handle context errors specifically.
appStatus = status.FromContextError(appErr)
appErr = appStatus.Err()
}
if trInfo != nil {

View File

@ -46,10 +46,12 @@ import (
)
// StreamHandler defines the handler called by gRPC server to complete the
// execution of a streaming RPC. If a StreamHandler returns an error, it
// should be produced by the status package, or else gRPC will use
// codes.Unknown as the status code and err.Error() as the status message
// of the RPC.
// execution of a streaming RPC.
//
// If a StreamHandler returns an error, it should either be produced by the
// status package, or be one of the context errors. Otherwise, gRPC will use
// codes.Unknown as the status code and err.Error() as the status message of the
// RPC.
type StreamHandler func(srv interface{}, stream ServerStream) error
// StreamDesc represents a streaming RPC service's method specification. Used

View File

@ -32,6 +32,41 @@ import (
type ctxKey string
// TestServerReturningContextError verifies that if a context error is returned
// by the service handler, the status will have the correct status code, not
// Unknown.
func (s) TestServerReturningContextError(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return nil, context.DeadlineExceeded
},
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
return context.DeadlineExceeded
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
_, err := ss.Client.EmptyCall(ctx, &testpb.Empty{})
if s, ok := status.FromError(err); !ok || s.Code() != codes.DeadlineExceeded {
t.Fatalf("ss.Client.EmptyCall() got error %v; want <status with Code()=DeadlineExceeded>", err)
}
stream, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("unexpected error starting the stream: %v", err)
}
_, err = stream.Recv()
if s, ok := status.FromError(err); !ok || s.Code() != codes.DeadlineExceeded {
t.Fatalf("ss.Client.FullDuplexCall().Recv() got error %v; want <status with Code()=DeadlineExceeded>", err)
}
}
func (s) TestChainUnaryServerInterceptor(t *testing.T) {
var (
firstIntKey = ctxKey("firstIntKey")