diff --git a/server.go b/server.go index fde4311a..53a2a92c 100644 --- a/server.go +++ b/server.go @@ -99,6 +99,7 @@ type options struct { codec Codec cp Compressor dc Decompressor + unaryInt UnaryServerInterceptor maxConcurrentStreams uint32 useHandlerImpl bool // use http.Handler-based server } @@ -140,6 +141,18 @@ func Creds(c credentials.Credentials) ServerOption { } } +// UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the +// server. Only one interceptor can be installed. The construction of multiple interceptors +// (e.g., chaining) can be implemented at the caller. +func UnaryInterceptor(i UnaryServerInterceptor) ServerOption { + return func(o *options) { + if o.unaryInt != nil { + panic("The unary server interceptor has been set.") + } + o.unaryInt = i + } +} + // NewServer creates a gRPC server which has no service registered and has not // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { @@ -494,7 +507,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } return nil } - reply, appErr := md.Handler(srv.server, stream.Context(), df, nil) + reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) if appErr != nil { if err, ok := appErr.(rpcError); ok { statusCode = err.code diff --git a/test/end2end_test.go b/test/end2end_test.go index 88f1dce2..09b786d3 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -420,6 +420,7 @@ type test struct { userAgent string clientCompression bool serverCompression bool + unaryInt grpc.UnaryServerInterceptor // srv and srvAddr are set once startServer is called. srv *grpc.Server @@ -468,7 +469,9 @@ func (te *test) startServer() { grpc.RPCDecompressor(grpc.NewGZIPDecompressor()), ) } - + if te.unaryInt != nil { + sopts = append(sopts, grpc.UnaryInterceptor(te.unaryInt)) + } la := "localhost:0" switch e.network { case "unix": @@ -1685,6 +1688,29 @@ func testCompressOK(t *testing.T, e env) { } } +func TestUnaryServerInterceptor(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testUnaryServerInterceptor(t, e) + } +} + +func errInjector(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return nil, grpc.Errorf(codes.PermissionDenied, "") +} + +func testUnaryServerInterceptor(t *testing.T, e env) { + te := newTest(t, e) + te.unaryInt = errInjector + te.startServer() + defer te.tearDown() + + tc := testpb.NewTestServiceClient(te.clientConn()) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.PermissionDenied { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, error code %d", err, codes.PermissionDenied) + } +} + // funcServer implements methods of TestServiceServer using funcs, // similar to an http.HandlerFunc. // Any unimplemented method will crash. Tests implement the method(s)