Rewrite gRPC server interceptor (#1171)

Co-authored-by: Aaron Abbott <aaronabbott@google.com>
This commit is contained in:
Michael Stella
2020-10-29 16:30:18 -04:00
committed by GitHub
parent 43b88daa81
commit 0c33c1eaac
5 changed files with 243 additions and 449 deletions

View File

@ -77,8 +77,7 @@ Usage Server
import grpc
from opentelemetry import trace
from opentelemetry.instrumentation.grpc import GrpcInstrumentorServer, server_interceptor
from opentelemetry.instrumentation.grpc.grpcext import intercept_server
from opentelemetry.instrumentation.grpc import GrpcInstrumentorServer
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
ConsoleSpanExporter,
@ -94,10 +93,10 @@ Usage Server
trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter())
)
grpc_server_instrumentor = GrpcInstrumentorServer()
grpc_server_instrumentor.instrument()
class Greeter(helloworld_pb2_grpc.GreeterServicer):
def SayHello(self, request, context):
return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name)
@ -106,7 +105,6 @@ Usage Server
def serve():
server = grpc.server(futures.ThreadPoolExecutor())
server = intercept_server(server, server_interceptor())
helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server)
server.add_insecure_port("[::]:50051")
@ -117,18 +115,25 @@ Usage Server
if __name__ == "__main__":
logging.basicConfig()
serve()
You can also add the instrumentor manually, rather than using
:py:class:`~opentelemetry.instrumentation.grpc.GrpcInstrumentorServer`:
.. code-block:: python
from opentelemetry.instrumentation.grpc import server_interceptor
server = grpc.server(futures.ThreadPoolExecutor(),
interceptors = [server_interceptor()])
"""
from contextlib import contextmanager
from functools import partial
import grpc
from wrapt import wrap_function_wrapper as _wrap
from opentelemetry import trace
from opentelemetry.instrumentation.grpc.grpcext import (
intercept_channel,
intercept_server,
)
from opentelemetry.instrumentation.grpc.grpcext import intercept_channel
from opentelemetry.instrumentation.grpc.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import unwrap
@ -140,15 +145,33 @@ from opentelemetry.instrumentation.utils import unwrap
class GrpcInstrumentorServer(BaseInstrumentor):
"""
Globally instrument the grpc server.
Usage::
grpc_server_instrumentor = GrpcInstrumentorServer()
grpc_server_instrumentor.instrument()
"""
# pylint:disable=attribute-defined-outside-init
def _instrument(self, **kwargs):
_wrap("grpc", "server", self.wrapper_fn)
self._original_func = grpc.server
def server(*args, **kwargs):
if "interceptors" in kwargs:
# add our interceptor as the first
kwargs["interceptors"].insert(0, server_interceptor())
else:
kwargs["interceptors"] = [server_interceptor()]
return self._original_func(*args, **kwargs)
grpc.server = server
def _uninstrument(self, **kwargs):
unwrap(grpc, "server")
def wrapper_fn(self, original_func, instance, args, kwargs):
server = original_func(*args, **kwargs)
return intercept_server(server, server_interceptor())
grpc.server = self._original_func
class GrpcInstrumentorClient(BaseInstrumentor):

View File

@ -17,12 +17,11 @@
# pylint:disable=no-member
# pylint:disable=signature-differs
"""Implementation of the service-side open-telemetry interceptor.
This library borrows heavily from the OpenTracing gRPC integration:
https://github.com/opentracing-contrib/python-grpc
"""
Implementation of the service-side open-telemetry interceptor.
"""
import logging
from contextlib import contextmanager
from typing import List
@ -30,9 +29,37 @@ import grpc
from opentelemetry import propagators, trace
from opentelemetry.context import attach, detach
from opentelemetry.trace.status import Status, StatusCode
from . import grpcext
from ._utilities import RpcInfo
logger = logging.getLogger(__name__)
# wrap an RPC call
# see https://github.com/grpc/grpc/issues/18191
def _wrap_rpc_behavior(handler, continuation):
if handler is None:
return None
if handler.request_streaming and handler.response_streaming:
behavior_fn = handler.stream_stream
handler_factory = grpc.stream_stream_rpc_method_handler
elif handler.request_streaming and not handler.response_streaming:
behavior_fn = handler.stream_unary
handler_factory = grpc.stream_unary_rpc_method_handler
elif not handler.request_streaming and handler.response_streaming:
behavior_fn = handler.unary_stream
handler_factory = grpc.unary_stream_rpc_method_handler
else:
behavior_fn = handler.unary_unary
handler_factory = grpc.unary_unary_rpc_method_handler
return handler_factory(
continuation(
behavior_fn, handler.request_streaming, handler.response_streaming
),
request_deserializer=handler.request_deserializer,
response_serializer=handler.response_serializer,
)
# pylint:disable=abstract-method
@ -42,7 +69,7 @@ class _OpenTelemetryServicerContext(grpc.ServicerContext):
self._active_span = active_span
self.code = grpc.StatusCode.OK
self.details = None
super(_OpenTelemetryServicerContext, self).__init__()
super().__init__()
def is_active(self, *args, **kwargs):
return self._servicer_context.is_active(*args, **kwargs)
@ -56,20 +83,26 @@ class _OpenTelemetryServicerContext(grpc.ServicerContext):
def add_callback(self, *args, **kwargs):
return self._servicer_context.add_callback(*args, **kwargs)
def disable_next_message_compression(self):
return self._service_context.disable_next_message_compression()
def invocation_metadata(self, *args, **kwargs):
return self._servicer_context.invocation_metadata(*args, **kwargs)
def peer(self, *args, **kwargs):
return self._servicer_context.peer(*args, **kwargs)
def peer(self):
return self._servicer_context.peer()
def peer_identities(self, *args, **kwargs):
return self._servicer_context.peer_identities(*args, **kwargs)
def peer_identities(self):
return self._servicer_context.peer_identities()
def peer_identity_key(self, *args, **kwargs):
return self._servicer_context.peer_identity_key(*args, **kwargs)
def peer_identity_key(self):
return self._servicer_context.peer_identity_key()
def auth_context(self, *args, **kwargs):
return self._servicer_context.auth_context(*args, **kwargs)
def auth_context(self):
return self._servicer_context.auth_context()
def set_compression(self, compression):
return self._servicer_context.set_compression(compression)
def send_initial_metadata(self, *args, **kwargs):
return self._servicer_context.send_initial_metadata(*args, **kwargs)
@ -77,47 +110,62 @@ class _OpenTelemetryServicerContext(grpc.ServicerContext):
def set_trailing_metadata(self, *args, **kwargs):
return self._servicer_context.set_trailing_metadata(*args, **kwargs)
def abort(self, *args, **kwargs):
if not hasattr(self._servicer_context, "abort"):
raise RuntimeError(
"abort() is not supported with the installed version of grpcio"
def abort(self, code, details):
self.code = code
self.details = details
self._active_span.set_status(
Status(status_code=StatusCode(code.value[0]), description=details)
)
return self._servicer_context.abort(*args, **kwargs)
return self._servicer_context.abort(code, details)
def abort_with_status(self, *args, **kwargs):
if not hasattr(self._servicer_context, "abort_with_status"):
raise RuntimeError(
"abort_with_status() is not supported with the installed "
"version of grpcio"
)
return self._servicer_context.abort_with_status(*args, **kwargs)
def abort_with_status(self, status):
return self._servicer_context.abort_with_status(status)
def set_code(self, code):
self.code = code
# use details if we already have it, otherwise the status description
details = self.details or code.value[1]
self._active_span.set_status(
Status(status_code=StatusCode(code.value[0]), description=details)
)
return self._servicer_context.set_code(code)
def set_details(self, details):
self.details = details
self._active_span.set_status(
Status(
status_code=StatusCode(self.code.value[0]),
description=details,
)
)
return self._servicer_context.set_details(details)
# On the service-side, errors can be signaled either by exceptions or by
# calling `set_code` on the `servicer_context`. This function checks for the
# latter and updates the span accordingly.
# pylint:disable=abstract-method
# pylint:disable=no-self-use
# pylint:disable=unused-argument
def _check_error_code(span, servicer_context, rpc_info):
if servicer_context.code != grpc.StatusCode.OK:
rpc_info.error = servicer_context.code
class OpenTelemetryServerInterceptor(grpc.ServerInterceptor):
"""
A gRPC server interceptor, to add OpenTelemetry.
Usage::
tracer = some OpenTelemetry tracer
interceptors = [
OpenTelemetryServerInterceptor(tracer),
]
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=concurrency),
interceptors = interceptors)
"""
class OpenTelemetryServerInterceptor(
grpcext.UnaryServerInterceptor, grpcext.StreamServerInterceptor
):
def __init__(self, tracer):
self._tracer = tracer
@contextmanager
# pylint:disable=no-self-use
def _set_remote_context(self, servicer_context):
metadata = servicer_context.invocation_metadata()
if metadata:
@ -136,74 +184,67 @@ class OpenTelemetryServerInterceptor(
else:
yield
def _start_span(self, method):
span = self._tracer.start_as_current_span(
name=method, kind=trace.SpanKind.SERVER
)
return span
def _start_span(self, handler_call_details, context):
def intercept_unary(self, request, servicer_context, server_info, handler):
attributes = {
"rpc.method": handler_call_details.method,
"rpc.system": "grpc",
}
with self._set_remote_context(servicer_context):
with self._start_span(server_info.full_method) as span:
rpc_info = RpcInfo(
full_method=server_info.full_method,
metadata=servicer_context.invocation_metadata(),
timeout=servicer_context.time_remaining(),
request=request,
)
servicer_context = _OpenTelemetryServicerContext(
servicer_context, span
)
response = handler(request, servicer_context)
metadata = dict(context.invocation_metadata())
if "user-agent" in metadata:
attributes["rpc.user_agent"] = metadata["user-agent"]
_check_error_code(span, servicer_context, rpc_info)
# Split up the peer to keep with how other telemetry sources
# do it. This looks like:
# * ipv6:[::1]:57284
# * ipv4:127.0.0.1:57284
# * ipv4:10.2.1.1:57284,127.0.0.1:57284
#
try:
host, port = (
context.peer().split(",")[0].split(":", 1)[1].rsplit(":", 1)
)
rpc_info.response = response
# other telemetry sources convert this, so we will too
if host in ("[::1]", "127.0.0.1"):
host = "localhost"
return response
attributes.update({"net.peer.name": host, "net.peer.port": port})
except IndexError:
logger.warning("Failed to parse peer address '%s'", context.peer())
# For RPCs that stream responses, the result can be a generator. To record
# the span across the generated responses and detect any errors, we wrap
# the result in a new generator that yields the response values.
def _intercept_server_stream(
self, request_or_iterator, servicer_context, server_info, handler
):
with self._set_remote_context(servicer_context):
with self._start_span(server_info.full_method) as span:
rpc_info = RpcInfo(
full_method=server_info.full_method,
metadata=servicer_context.invocation_metadata(),
timeout=servicer_context.time_remaining(),
return self._tracer.start_as_current_span(
name=handler_call_details.method,
kind=trace.SpanKind.SERVER,
attributes=attributes,
)
if not server_info.is_client_stream:
rpc_info.request = request_or_iterator
servicer_context = _OpenTelemetryServicerContext(
servicer_context, span
)
result = handler(request_or_iterator, servicer_context)
for response in result:
yield response
_check_error_code(span, servicer_context, rpc_info)
def intercept_stream(
self, request_or_iterator, servicer_context, server_info, handler
):
if server_info.is_server_stream:
return self._intercept_server_stream(
request_or_iterator, servicer_context, server_info, handler
def intercept_service(self, continuation, handler_call_details):
def telemetry_wrapper(behavior, request_streaming, response_streaming):
def telemetry_interceptor(request_or_iterator, context):
with self._set_remote_context(context):
with self._start_span(
handler_call_details, context
) as span:
# wrap the context
context = _OpenTelemetryServicerContext(context, span)
# And now we run the actual RPC.
try:
return behavior(request_or_iterator, context)
except Exception as error:
# Bare exceptions are likely to be gRPC aborts, which
# we handle in our context wrapper.
# Here, we're interested in uncaught exceptions.
# pylint:disable=unidiomatic-typecheck
if type(error) != Exception:
span.record_exception(error)
raise error
return telemetry_interceptor
return _wrap_rpc_behavior(
continuation(handler_call_details), telemetry_wrapper
)
with self._set_remote_context(servicer_context):
with self._start_span(server_info.full_method) as span:
rpc_info = RpcInfo(
full_method=server_info.full_method,
metadata=servicer_context.invocation_metadata(),
timeout=servicer_context.time_remaining(),
)
servicer_context = _OpenTelemetryServicerContext(
servicer_context, span
)
response = handler(request_or_iterator, servicer_context)
_check_error_code(span, servicer_context, rpc_info)
rpc_info.response = response
return response

View File

@ -117,100 +117,9 @@ def intercept_channel(channel, *interceptors):
return _interceptor.intercept_channel(channel, *interceptors)
class UnaryServerInfo(abc.ABC):
"""Consists of various information about a unary RPC on the service-side.
Attributes:
full_method: A string of the full RPC method, i.e.,
/package.service/method.
"""
class StreamServerInfo(abc.ABC):
"""Consists of various information about a stream RPC on the service-side.
Attributes:
full_method: A string of the full RPC method, i.e.,
/package.service/method.
is_client_stream: Indicates whether the RPC is client-streaming.
is_server_stream: Indicates whether the RPC is server-streaming.
"""
class UnaryServerInterceptor(abc.ABC):
"""Affords intercepting unary-unary RPCs on the service-side."""
@abc.abstractmethod
def intercept_unary(self, request, servicer_context, server_info, handler):
"""Intercepts unary-unary RPCs on the service-side.
Args:
request: The request value for the RPC.
servicer_context: A ServicerContext.
server_info: A UnaryServerInfo containing various information about
the RPC.
handler: The handler to complete the RPC on the server. It is the
interceptor's responsibility to call it.
Returns:
The result from calling handler(request, servicer_context).
"""
raise NotImplementedError()
class StreamServerInterceptor(abc.ABC):
"""Affords intercepting stream RPCs on the service-side."""
@abc.abstractmethod
def intercept_stream(
self, request_or_iterator, servicer_context, server_info, handler
):
"""Intercepts stream RPCs on the service-side.
Args:
request_or_iterator: The request value for the RPC if
`server_info.is_client_stream` is `False`; otherwise, an iterator of
request values.
servicer_context: A ServicerContext.
server_info: A StreamServerInfo containing various information about
the RPC.
handler: The handler to complete the RPC on the server. It is the
interceptor's responsibility to call it.
Returns:
The result from calling handler(servicer_context).
"""
raise NotImplementedError()
def intercept_server(server, *interceptors):
"""Creates an intercepted server.
Args:
server: A Server.
interceptors: Zero or more UnaryServerInterceptors or
StreamServerInterceptors
Returns:
A Server.
Raises:
TypeError: If an interceptor derives from neither UnaryServerInterceptor
nor StreamServerInterceptor.
"""
from . import _interceptor
return _interceptor.intercept_server(server, *interceptors)
__all__ = (
"UnaryClientInterceptor",
"StreamClientInfo",
"StreamClientInterceptor",
"UnaryServerInfo",
"StreamServerInfo",
"UnaryServerInterceptor",
"StreamServerInterceptor",
"intercept_channel",
"intercept_server",
)

View File

@ -252,180 +252,3 @@ def intercept_channel(channel, *interceptors):
)
result = _InterceptorChannel(result, interceptor)
return result
class _UnaryServerInfo(
collections.namedtuple("_UnaryServerInfo", ("full_method",))
):
pass
class _StreamServerInfo(
collections.namedtuple(
"_StreamServerInfo",
("full_method", "is_client_stream", "is_server_stream"),
)
):
pass
class _InterceptorRpcMethodHandler(grpc.RpcMethodHandler):
def __init__(self, rpc_method_handler, method, interceptor):
self._rpc_method_handler = rpc_method_handler
self._method = method
self._interceptor = interceptor
@property
def request_streaming(self):
return self._rpc_method_handler.request_streaming
@property
def response_streaming(self):
return self._rpc_method_handler.response_streaming
@property
def request_deserializer(self):
return self._rpc_method_handler.request_deserializer
@property
def response_serializer(self):
return self._rpc_method_handler.response_serializer
@property
def unary_unary(self):
if not isinstance(self._interceptor, grpcext.UnaryServerInterceptor):
return self._rpc_method_handler.unary_unary
def adaptation(request, servicer_context):
def handler(request, servicer_context):
return self._rpc_method_handler.unary_unary(
request, servicer_context
)
return self._interceptor.intercept_unary(
request,
servicer_context,
_UnaryServerInfo(self._method),
handler,
)
return adaptation
@property
def unary_stream(self):
if not isinstance(self._interceptor, grpcext.StreamServerInterceptor):
return self._rpc_method_handler.unary_stream
def adaptation(request, servicer_context):
def handler(request, servicer_context):
return self._rpc_method_handler.unary_stream(
request, servicer_context
)
return self._interceptor.intercept_stream(
request,
servicer_context,
_StreamServerInfo(self._method, False, True),
handler,
)
return adaptation
@property
def stream_unary(self):
if not isinstance(self._interceptor, grpcext.StreamServerInterceptor):
return self._rpc_method_handler.stream_unary
def adaptation(request_iterator, servicer_context):
def handler(request_iterator, servicer_context):
return self._rpc_method_handler.stream_unary(
request_iterator, servicer_context
)
return self._interceptor.intercept_stream(
request_iterator,
servicer_context,
_StreamServerInfo(self._method, True, False),
handler,
)
return adaptation
@property
def stream_stream(self):
if not isinstance(self._interceptor, grpcext.StreamServerInterceptor):
return self._rpc_method_handler.stream_stream
def adaptation(request_iterator, servicer_context):
def handler(request_iterator, servicer_context):
return self._rpc_method_handler.stream_stream(
request_iterator, servicer_context
)
return self._interceptor.intercept_stream(
request_iterator,
servicer_context,
_StreamServerInfo(self._method, True, True),
handler,
)
return adaptation
class _InterceptorGenericRpcHandler(grpc.GenericRpcHandler):
def __init__(self, generic_rpc_handler, interceptor):
self.generic_rpc_handler = generic_rpc_handler
self._interceptor = interceptor
def service(self, handler_call_details):
result = self.generic_rpc_handler.service(handler_call_details)
if result:
result = _InterceptorRpcMethodHandler(
result, handler_call_details.method, self._interceptor
)
return result
class _InterceptorServer(grpc.Server):
def __init__(self, server, interceptor):
self._server = server
self._interceptor = interceptor
def add_generic_rpc_handlers(self, generic_rpc_handlers):
generic_rpc_handlers = [
_InterceptorGenericRpcHandler(
generic_rpc_handler, self._interceptor
)
for generic_rpc_handler in generic_rpc_handlers
]
return self._server.add_generic_rpc_handlers(generic_rpc_handlers)
def add_insecure_port(self, *args, **kwargs):
return self._server.add_insecure_port(*args, **kwargs)
def add_secure_port(self, *args, **kwargs):
return self._server.add_secure_port(*args, **kwargs)
def start(self, *args, **kwargs):
return self._server.start(*args, **kwargs)
def stop(self, *args, **kwargs):
return self._server.stop(*args, **kwargs)
def wait_for_termination(self, *args, **kwargs):
return self._server.wait_for_termination(*args, **kwargs)
def intercept_server(server, *interceptors):
result = server
for interceptor in interceptors:
if not isinstance(
interceptor, grpcext.UnaryServerInterceptor
) and not isinstance(interceptor, grpcext.StreamServerInterceptor):
raise TypeError(
"interceptor must be either a "
"grpcext.UnaryServerInterceptor or a "
"grpcext.StreamServerInterceptor"
)
result = _InterceptorServer(result, interceptor)
return result

View File

@ -26,7 +26,6 @@ from opentelemetry.instrumentation.grpc import (
GrpcInstrumentorServer,
server_interceptor,
)
from opentelemetry.instrumentation.grpc.grpcext import intercept_server
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.test.test_base import TestBase
@ -123,10 +122,9 @@ class TestOpenTelemetryServerInterceptor(TestBase):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=1),
options=(("grpc.so_reuseport", 0),),
interceptors=[interceptor],
)
# FIXME: grpcext interceptor doesn't apply to handlers passed to server
# init, should use intercept_service API instead.
server = intercept_server(server, interceptor)
server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),))
port = server.add_insecure_port("[::]:0")
@ -166,8 +164,8 @@ class TestOpenTelemetryServerInterceptor(TestBase):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=1),
options=(("grpc.so_reuseport", 0),),
interceptors=[interceptor],
)
server = intercept_server(server, interceptor)
server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),))
port = server.add_insecure_port("[::]:0")
@ -201,8 +199,8 @@ class TestOpenTelemetryServerInterceptor(TestBase):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=1),
options=(("grpc.so_reuseport", 0),),
interceptors=[interceptor],
)
server = intercept_server(server, interceptor)
server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),))
port = server.add_insecure_port("[::]:0")
@ -248,8 +246,8 @@ class TestOpenTelemetryServerInterceptor(TestBase):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=2),
options=(("grpc.so_reuseport", 0),),
interceptors=[interceptor],
)
server = intercept_server(server, interceptor)
server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),))
port = server.add_insecure_port("[::]:0")