server: handle context errors returned by service handler (#5156)
This commit is contained in:
@ -72,9 +72,12 @@ type UnaryServerInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UnaryHandler defines the handler invoked by UnaryServerInterceptor to complete the normal
|
// UnaryHandler defines the handler invoked by UnaryServerInterceptor to complete the normal
|
||||||
// execution of a unary RPC. If a UnaryHandler returns an error, it should be produced by the
|
// execution of a unary RPC.
|
||||||
// status package, or else gRPC will use codes.Unknown as the status code and err.Error() as
|
//
|
||||||
// the status message of the RPC.
|
// If a UnaryHandler returns an error, it should either be produced by the
|
||||||
|
// status package, or be one of the context errors. Otherwise, gRPC will use
|
||||||
|
// codes.Unknown as the status code and err.Error() as the status message of the
|
||||||
|
// RPC.
|
||||||
type UnaryHandler func(ctx context.Context, req interface{}) (interface{}, error)
|
type UnaryHandler func(ctx context.Context, req interface{}) (interface{}, error)
|
||||||
|
|
||||||
// UnaryServerInterceptor provides a hook to intercept the execution of a unary RPC on the server. info
|
// UnaryServerInterceptor provides a hook to intercept the execution of a unary RPC on the server. info
|
||||||
|
11
server.go
11
server.go
@ -1283,9 +1283,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
if appErr != nil {
|
if appErr != nil {
|
||||||
appStatus, ok := status.FromError(appErr)
|
appStatus, ok := status.FromError(appErr)
|
||||||
if !ok {
|
if !ok {
|
||||||
// Convert appErr if it is not a grpc status error.
|
// Convert non-status application error to a status error with code
|
||||||
appErr = status.Error(codes.Unknown, appErr.Error())
|
// Unknown, but handle context errors specifically.
|
||||||
appStatus, _ = status.FromError(appErr)
|
appStatus = status.FromContextError(appErr)
|
||||||
|
appErr = appStatus.Err()
|
||||||
}
|
}
|
||||||
if trInfo != nil {
|
if trInfo != nil {
|
||||||
trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
|
trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
|
||||||
@ -1549,7 +1550,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||||||
if appErr != nil {
|
if appErr != nil {
|
||||||
appStatus, ok := status.FromError(appErr)
|
appStatus, ok := status.FromError(appErr)
|
||||||
if !ok {
|
if !ok {
|
||||||
appStatus = status.New(codes.Unknown, appErr.Error())
|
// Convert non-status application error to a status error with code
|
||||||
|
// Unknown, but handle context errors specifically.
|
||||||
|
appStatus = status.FromContextError(appErr)
|
||||||
appErr = appStatus.Err()
|
appErr = appStatus.Err()
|
||||||
}
|
}
|
||||||
if trInfo != nil {
|
if trInfo != nil {
|
||||||
|
10
stream.go
10
stream.go
@ -46,10 +46,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// StreamHandler defines the handler called by gRPC server to complete the
|
// StreamHandler defines the handler called by gRPC server to complete the
|
||||||
// execution of a streaming RPC. If a StreamHandler returns an error, it
|
// execution of a streaming RPC.
|
||||||
// should be produced by the status package, or else gRPC will use
|
//
|
||||||
// codes.Unknown as the status code and err.Error() as the status message
|
// If a StreamHandler returns an error, it should either be produced by the
|
||||||
// of the RPC.
|
// status package, or be one of the context errors. Otherwise, gRPC will use
|
||||||
|
// codes.Unknown as the status code and err.Error() as the status message of the
|
||||||
|
// RPC.
|
||||||
type StreamHandler func(srv interface{}, stream ServerStream) error
|
type StreamHandler func(srv interface{}, stream ServerStream) error
|
||||||
|
|
||||||
// StreamDesc represents a streaming RPC service's method specification. Used
|
// StreamDesc represents a streaming RPC service's method specification. Used
|
||||||
|
@ -32,6 +32,41 @@ import (
|
|||||||
|
|
||||||
type ctxKey string
|
type ctxKey string
|
||||||
|
|
||||||
|
// TestServerReturningContextError verifies that if a context error is returned
|
||||||
|
// by the service handler, the status will have the correct status code, not
|
||||||
|
// Unknown.
|
||||||
|
func (s) TestServerReturningContextError(t *testing.T) {
|
||||||
|
ss := &stubserver.StubServer{
|
||||||
|
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||||
|
return nil, context.DeadlineExceeded
|
||||||
|
},
|
||||||
|
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
|
||||||
|
return context.DeadlineExceeded
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := ss.Start(nil); err != nil {
|
||||||
|
t.Fatalf("Error starting endpoint server: %v", err)
|
||||||
|
}
|
||||||
|
defer ss.Stop()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
_, err := ss.Client.EmptyCall(ctx, &testpb.Empty{})
|
||||||
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.DeadlineExceeded {
|
||||||
|
t.Fatalf("ss.Client.EmptyCall() got error %v; want <status with Code()=DeadlineExceeded>", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := ss.Client.FullDuplexCall(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error starting the stream: %v", err)
|
||||||
|
}
|
||||||
|
_, err = stream.Recv()
|
||||||
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.DeadlineExceeded {
|
||||||
|
t.Fatalf("ss.Client.FullDuplexCall().Recv() got error %v; want <status with Code()=DeadlineExceeded>", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func (s) TestChainUnaryServerInterceptor(t *testing.T) {
|
func (s) TestChainUnaryServerInterceptor(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
firstIntKey = ctxKey("firstIntKey")
|
firstIntKey = ctxKey("firstIntKey")
|
||||||
|
Reference in New Issue
Block a user