Rewrite gRPC server interceptor (#1171)

Co-authored-by: Aaron Abbott <aaronabbott@google.com>
This commit is contained in:
Michael Stella
2020-10-29 16:30:18 -04:00
committed by GitHub
parent 43b88daa81
commit 0c33c1eaac
5 changed files with 243 additions and 449 deletions

View File

@ -77,8 +77,7 @@ Usage Server
import grpc import grpc
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.instrumentation.grpc import GrpcInstrumentorServer, server_interceptor from opentelemetry.instrumentation.grpc import GrpcInstrumentorServer
from opentelemetry.instrumentation.grpc.grpcext import intercept_server
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ( from opentelemetry.sdk.trace.export import (
ConsoleSpanExporter, ConsoleSpanExporter,
@ -94,10 +93,10 @@ Usage Server
trace.get_tracer_provider().add_span_processor( trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter()) SimpleExportSpanProcessor(ConsoleSpanExporter())
) )
grpc_server_instrumentor = GrpcInstrumentorServer() grpc_server_instrumentor = GrpcInstrumentorServer()
grpc_server_instrumentor.instrument() grpc_server_instrumentor.instrument()
class Greeter(helloworld_pb2_grpc.GreeterServicer): class Greeter(helloworld_pb2_grpc.GreeterServicer):
def SayHello(self, request, context): def SayHello(self, request, context):
return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name) return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name)
@ -106,7 +105,6 @@ Usage Server
def serve(): def serve():
server = grpc.server(futures.ThreadPoolExecutor()) server = grpc.server(futures.ThreadPoolExecutor())
server = intercept_server(server, server_interceptor())
helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server)
server.add_insecure_port("[::]:50051") server.add_insecure_port("[::]:50051")
@ -117,18 +115,25 @@ Usage Server
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig() logging.basicConfig()
serve() serve()
You can also add the instrumentor manually, rather than using
:py:class:`~opentelemetry.instrumentation.grpc.GrpcInstrumentorServer`:
.. code-block:: python
from opentelemetry.instrumentation.grpc import server_interceptor
server = grpc.server(futures.ThreadPoolExecutor(),
interceptors = [server_interceptor()])
""" """
from contextlib import contextmanager
from functools import partial from functools import partial
import grpc import grpc
from wrapt import wrap_function_wrapper as _wrap from wrapt import wrap_function_wrapper as _wrap
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.instrumentation.grpc.grpcext import ( from opentelemetry.instrumentation.grpc.grpcext import intercept_channel
intercept_channel,
intercept_server,
)
from opentelemetry.instrumentation.grpc.version import __version__ from opentelemetry.instrumentation.grpc.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import unwrap from opentelemetry.instrumentation.utils import unwrap
@ -140,15 +145,33 @@ from opentelemetry.instrumentation.utils import unwrap
class GrpcInstrumentorServer(BaseInstrumentor): class GrpcInstrumentorServer(BaseInstrumentor):
"""
Globally instrument the grpc server.
Usage::
grpc_server_instrumentor = GrpcInstrumentorServer()
grpc_server_instrumentor.instrument()
"""
# pylint:disable=attribute-defined-outside-init
def _instrument(self, **kwargs): def _instrument(self, **kwargs):
_wrap("grpc", "server", self.wrapper_fn) self._original_func = grpc.server
def server(*args, **kwargs):
if "interceptors" in kwargs:
# add our interceptor as the first
kwargs["interceptors"].insert(0, server_interceptor())
else:
kwargs["interceptors"] = [server_interceptor()]
return self._original_func(*args, **kwargs)
grpc.server = server
def _uninstrument(self, **kwargs): def _uninstrument(self, **kwargs):
unwrap(grpc, "server") grpc.server = self._original_func
def wrapper_fn(self, original_func, instance, args, kwargs):
server = original_func(*args, **kwargs)
return intercept_server(server, server_interceptor())
class GrpcInstrumentorClient(BaseInstrumentor): class GrpcInstrumentorClient(BaseInstrumentor):

View File

@ -17,12 +17,11 @@
# pylint:disable=no-member # pylint:disable=no-member
# pylint:disable=signature-differs # pylint:disable=signature-differs
"""Implementation of the service-side open-telemetry interceptor. """
Implementation of the service-side open-telemetry interceptor.
This library borrows heavily from the OpenTracing gRPC integration:
https://github.com/opentracing-contrib/python-grpc
""" """
import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import List from typing import List
@ -30,9 +29,37 @@ import grpc
from opentelemetry import propagators, trace from opentelemetry import propagators, trace
from opentelemetry.context import attach, detach from opentelemetry.context import attach, detach
from opentelemetry.trace.status import Status, StatusCode
from . import grpcext logger = logging.getLogger(__name__)
from ._utilities import RpcInfo
# wrap an RPC call
# see https://github.com/grpc/grpc/issues/18191
def _wrap_rpc_behavior(handler, continuation):
if handler is None:
return None
if handler.request_streaming and handler.response_streaming:
behavior_fn = handler.stream_stream
handler_factory = grpc.stream_stream_rpc_method_handler
elif handler.request_streaming and not handler.response_streaming:
behavior_fn = handler.stream_unary
handler_factory = grpc.stream_unary_rpc_method_handler
elif not handler.request_streaming and handler.response_streaming:
behavior_fn = handler.unary_stream
handler_factory = grpc.unary_stream_rpc_method_handler
else:
behavior_fn = handler.unary_unary
handler_factory = grpc.unary_unary_rpc_method_handler
return handler_factory(
continuation(
behavior_fn, handler.request_streaming, handler.response_streaming
),
request_deserializer=handler.request_deserializer,
response_serializer=handler.response_serializer,
)
# pylint:disable=abstract-method # pylint:disable=abstract-method
@ -42,7 +69,7 @@ class _OpenTelemetryServicerContext(grpc.ServicerContext):
self._active_span = active_span self._active_span = active_span
self.code = grpc.StatusCode.OK self.code = grpc.StatusCode.OK
self.details = None self.details = None
super(_OpenTelemetryServicerContext, self).__init__() super().__init__()
def is_active(self, *args, **kwargs): def is_active(self, *args, **kwargs):
return self._servicer_context.is_active(*args, **kwargs) return self._servicer_context.is_active(*args, **kwargs)
@ -56,20 +83,26 @@ class _OpenTelemetryServicerContext(grpc.ServicerContext):
def add_callback(self, *args, **kwargs): def add_callback(self, *args, **kwargs):
return self._servicer_context.add_callback(*args, **kwargs) return self._servicer_context.add_callback(*args, **kwargs)
def disable_next_message_compression(self):
return self._service_context.disable_next_message_compression()
def invocation_metadata(self, *args, **kwargs): def invocation_metadata(self, *args, **kwargs):
return self._servicer_context.invocation_metadata(*args, **kwargs) return self._servicer_context.invocation_metadata(*args, **kwargs)
def peer(self, *args, **kwargs): def peer(self):
return self._servicer_context.peer(*args, **kwargs) return self._servicer_context.peer()
def peer_identities(self, *args, **kwargs): def peer_identities(self):
return self._servicer_context.peer_identities(*args, **kwargs) return self._servicer_context.peer_identities()
def peer_identity_key(self, *args, **kwargs): def peer_identity_key(self):
return self._servicer_context.peer_identity_key(*args, **kwargs) return self._servicer_context.peer_identity_key()
def auth_context(self, *args, **kwargs): def auth_context(self):
return self._servicer_context.auth_context(*args, **kwargs) return self._servicer_context.auth_context()
def set_compression(self, compression):
return self._servicer_context.set_compression(compression)
def send_initial_metadata(self, *args, **kwargs): def send_initial_metadata(self, *args, **kwargs):
return self._servicer_context.send_initial_metadata(*args, **kwargs) return self._servicer_context.send_initial_metadata(*args, **kwargs)
@ -77,47 +110,62 @@ class _OpenTelemetryServicerContext(grpc.ServicerContext):
def set_trailing_metadata(self, *args, **kwargs): def set_trailing_metadata(self, *args, **kwargs):
return self._servicer_context.set_trailing_metadata(*args, **kwargs) return self._servicer_context.set_trailing_metadata(*args, **kwargs)
def abort(self, *args, **kwargs): def abort(self, code, details):
if not hasattr(self._servicer_context, "abort"): self.code = code
raise RuntimeError( self.details = details
"abort() is not supported with the installed version of grpcio" self._active_span.set_status(
) Status(status_code=StatusCode(code.value[0]), description=details)
return self._servicer_context.abort(*args, **kwargs) )
return self._servicer_context.abort(code, details)
def abort_with_status(self, *args, **kwargs): def abort_with_status(self, status):
if not hasattr(self._servicer_context, "abort_with_status"): return self._servicer_context.abort_with_status(status)
raise RuntimeError(
"abort_with_status() is not supported with the installed "
"version of grpcio"
)
return self._servicer_context.abort_with_status(*args, **kwargs)
def set_code(self, code): def set_code(self, code):
self.code = code self.code = code
# use details if we already have it, otherwise the status description
details = self.details or code.value[1]
self._active_span.set_status(
Status(status_code=StatusCode(code.value[0]), description=details)
)
return self._servicer_context.set_code(code) return self._servicer_context.set_code(code)
def set_details(self, details): def set_details(self, details):
self.details = details self.details = details
self._active_span.set_status(
Status(
status_code=StatusCode(self.code.value[0]),
description=details,
)
)
return self._servicer_context.set_details(details) return self._servicer_context.set_details(details)
# On the service-side, errors can be signaled either by exceptions or by # pylint:disable=abstract-method
# calling `set_code` on the `servicer_context`. This function checks for the # pylint:disable=no-self-use
# latter and updates the span accordingly.
# pylint:disable=unused-argument # pylint:disable=unused-argument
def _check_error_code(span, servicer_context, rpc_info): class OpenTelemetryServerInterceptor(grpc.ServerInterceptor):
if servicer_context.code != grpc.StatusCode.OK: """
rpc_info.error = servicer_context.code A gRPC server interceptor, to add OpenTelemetry.
Usage::
tracer = some OpenTelemetry tracer
interceptors = [
OpenTelemetryServerInterceptor(tracer),
]
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=concurrency),
interceptors = interceptors)
"""
class OpenTelemetryServerInterceptor(
grpcext.UnaryServerInterceptor, grpcext.StreamServerInterceptor
):
def __init__(self, tracer): def __init__(self, tracer):
self._tracer = tracer self._tracer = tracer
@contextmanager @contextmanager
# pylint:disable=no-self-use
def _set_remote_context(self, servicer_context): def _set_remote_context(self, servicer_context):
metadata = servicer_context.invocation_metadata() metadata = servicer_context.invocation_metadata()
if metadata: if metadata:
@ -136,74 +184,67 @@ class OpenTelemetryServerInterceptor(
else: else:
yield yield
def _start_span(self, method): def _start_span(self, handler_call_details, context):
span = self._tracer.start_as_current_span(
name=method, kind=trace.SpanKind.SERVER
)
return span
def intercept_unary(self, request, servicer_context, server_info, handler): attributes = {
"rpc.method": handler_call_details.method,
"rpc.system": "grpc",
}
with self._set_remote_context(servicer_context): metadata = dict(context.invocation_metadata())
with self._start_span(server_info.full_method) as span: if "user-agent" in metadata:
rpc_info = RpcInfo( attributes["rpc.user_agent"] = metadata["user-agent"]
full_method=server_info.full_method,
metadata=servicer_context.invocation_metadata(),
timeout=servicer_context.time_remaining(),
request=request,
)
servicer_context = _OpenTelemetryServicerContext(
servicer_context, span
)
response = handler(request, servicer_context)
_check_error_code(span, servicer_context, rpc_info) # Split up the peer to keep with how other telemetry sources
# do it. This looks like:
rpc_info.response = response # * ipv6:[::1]:57284
# * ipv4:127.0.0.1:57284
return response # * ipv4:10.2.1.1:57284,127.0.0.1:57284
#
# For RPCs that stream responses, the result can be a generator. To record try:
# the span across the generated responses and detect any errors, we wrap host, port = (
# the result in a new generator that yields the response values. context.peer().split(",")[0].split(":", 1)[1].rsplit(":", 1)
def _intercept_server_stream(
self, request_or_iterator, servicer_context, server_info, handler
):
with self._set_remote_context(servicer_context):
with self._start_span(server_info.full_method) as span:
rpc_info = RpcInfo(
full_method=server_info.full_method,
metadata=servicer_context.invocation_metadata(),
timeout=servicer_context.time_remaining(),
)
if not server_info.is_client_stream:
rpc_info.request = request_or_iterator
servicer_context = _OpenTelemetryServicerContext(
servicer_context, span
)
result = handler(request_or_iterator, servicer_context)
for response in result:
yield response
_check_error_code(span, servicer_context, rpc_info)
def intercept_stream(
self, request_or_iterator, servicer_context, server_info, handler
):
if server_info.is_server_stream:
return self._intercept_server_stream(
request_or_iterator, servicer_context, server_info, handler
) )
with self._set_remote_context(servicer_context):
with self._start_span(server_info.full_method) as span: # other telemetry sources convert this, so we will too
rpc_info = RpcInfo( if host in ("[::1]", "127.0.0.1"):
full_method=server_info.full_method, host = "localhost"
metadata=servicer_context.invocation_metadata(),
timeout=servicer_context.time_remaining(), attributes.update({"net.peer.name": host, "net.peer.port": port})
) except IndexError:
servicer_context = _OpenTelemetryServicerContext( logger.warning("Failed to parse peer address '%s'", context.peer())
servicer_context, span
) return self._tracer.start_as_current_span(
response = handler(request_or_iterator, servicer_context) name=handler_call_details.method,
_check_error_code(span, servicer_context, rpc_info) kind=trace.SpanKind.SERVER,
rpc_info.response = response attributes=attributes,
return response )
def intercept_service(self, continuation, handler_call_details):
def telemetry_wrapper(behavior, request_streaming, response_streaming):
def telemetry_interceptor(request_or_iterator, context):
with self._set_remote_context(context):
with self._start_span(
handler_call_details, context
) as span:
# wrap the context
context = _OpenTelemetryServicerContext(context, span)
# And now we run the actual RPC.
try:
return behavior(request_or_iterator, context)
except Exception as error:
# Bare exceptions are likely to be gRPC aborts, which
# we handle in our context wrapper.
# Here, we're interested in uncaught exceptions.
# pylint:disable=unidiomatic-typecheck
if type(error) != Exception:
span.record_exception(error)
raise error
return telemetry_interceptor
return _wrap_rpc_behavior(
continuation(handler_call_details), telemetry_wrapper
)

View File

@ -21,32 +21,32 @@ import abc
class UnaryClientInfo(abc.ABC): class UnaryClientInfo(abc.ABC):
"""Consists of various information about a unary RPC on the """Consists of various information about a unary RPC on the
invocation-side. invocation-side.
Attributes: Attributes:
full_method: A string of the full RPC method, i.e., full_method: A string of the full RPC method, i.e.,
/package.service/method. /package.service/method.
timeout: The length of time in seconds to wait for the computation to timeout: The length of time in seconds to wait for the computation to
terminate or be cancelled, or None if this method should block until terminate or be cancelled, or None if this method should block until
the computation is terminated or is cancelled no matter how long that the computation is terminated or is cancelled no matter how long that
takes. takes.
""" """
class StreamClientInfo(abc.ABC): class StreamClientInfo(abc.ABC):
"""Consists of various information about a stream RPC on the """Consists of various information about a stream RPC on the
invocation-side. invocation-side.
Attributes: Attributes:
full_method: A string of the full RPC method, i.e., full_method: A string of the full RPC method, i.e.,
/package.service/method. /package.service/method.
is_client_stream: Indicates whether the RPC is client-streaming. is_client_stream: Indicates whether the RPC is client-streaming.
is_server_stream: Indicates whether the RPC is server-streaming. is_server_stream: Indicates whether the RPC is server-streaming.
timeout: The length of time in seconds to wait for the computation to timeout: The length of time in seconds to wait for the computation to
terminate or be cancelled, or None if this method should block until terminate or be cancelled, or None if this method should block until
the computation is terminated or is cancelled no matter how long that the computation is terminated or is cancelled no matter how long that
takes. takes.
""" """
class UnaryClientInterceptor(abc.ABC): class UnaryClientInterceptor(abc.ABC):
@ -56,18 +56,18 @@ class UnaryClientInterceptor(abc.ABC):
def intercept_unary(self, request, metadata, client_info, invoker): def intercept_unary(self, request, metadata, client_info, invoker):
"""Intercepts unary-unary RPCs on the invocation-side. """Intercepts unary-unary RPCs on the invocation-side.
Args: Args:
request: The request value for the RPC. request: The request value for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the metadata: Optional :term:`metadata` to be transmitted to the
service-side of the RPC. service-side of the RPC.
client_info: A UnaryClientInfo containing various information about client_info: A UnaryClientInfo containing various information about
the RPC. the RPC.
invoker: The handler to complete the RPC on the client. It is the invoker: The handler to complete the RPC on the client. It is the
interceptor's responsibility to call it. interceptor's responsibility to call it.
Returns: Returns:
The result from calling invoker(request, metadata). The result from calling invoker(request, metadata).
""" """
raise NotImplementedError() raise NotImplementedError()
@ -80,137 +80,46 @@ class StreamClientInterceptor(abc.ABC):
): ):
"""Intercepts stream RPCs on the invocation-side. """Intercepts stream RPCs on the invocation-side.
Args: Args:
request_or_iterator: The request value for the RPC if request_or_iterator: The request value for the RPC if
`client_info.is_client_stream` is `false`; otherwise, an iterator of `client_info.is_client_stream` is `false`; otherwise, an iterator of
request values. request values.
metadata: Optional :term:`metadata` to be transmitted to the service-side metadata: Optional :term:`metadata` to be transmitted to the service-side
of the RPC. of the RPC.
client_info: A StreamClientInfo containing various information about client_info: A StreamClientInfo containing various information about
the RPC. the RPC.
invoker: The handler to complete the RPC on the client. It is the invoker: The handler to complete the RPC on the client. It is the
interceptor's responsibility to call it. interceptor's responsibility to call it.
Returns: Returns:
The result from calling invoker(metadata). The result from calling invoker(metadata).
""" """
raise NotImplementedError() raise NotImplementedError()
def intercept_channel(channel, *interceptors): def intercept_channel(channel, *interceptors):
"""Creates an intercepted channel. """Creates an intercepted channel.
Args: Args:
channel: A Channel. channel: A Channel.
interceptors: Zero or more UnaryClientInterceptors or interceptors: Zero or more UnaryClientInterceptors or
StreamClientInterceptors StreamClientInterceptors
Returns: Returns:
A Channel. A Channel.
Raises: Raises:
TypeError: If an interceptor derives from neither UnaryClientInterceptor TypeError: If an interceptor derives from neither UnaryClientInterceptor
nor StreamClientInterceptor. nor StreamClientInterceptor.
""" """
from . import _interceptor from . import _interceptor
return _interceptor.intercept_channel(channel, *interceptors) return _interceptor.intercept_channel(channel, *interceptors)
class UnaryServerInfo(abc.ABC):
"""Consists of various information about a unary RPC on the service-side.
Attributes:
full_method: A string of the full RPC method, i.e.,
/package.service/method.
"""
class StreamServerInfo(abc.ABC):
"""Consists of various information about a stream RPC on the service-side.
Attributes:
full_method: A string of the full RPC method, i.e.,
/package.service/method.
is_client_stream: Indicates whether the RPC is client-streaming.
is_server_stream: Indicates whether the RPC is server-streaming.
"""
class UnaryServerInterceptor(abc.ABC):
"""Affords intercepting unary-unary RPCs on the service-side."""
@abc.abstractmethod
def intercept_unary(self, request, servicer_context, server_info, handler):
"""Intercepts unary-unary RPCs on the service-side.
Args:
request: The request value for the RPC.
servicer_context: A ServicerContext.
server_info: A UnaryServerInfo containing various information about
the RPC.
handler: The handler to complete the RPC on the server. It is the
interceptor's responsibility to call it.
Returns:
The result from calling handler(request, servicer_context).
"""
raise NotImplementedError()
class StreamServerInterceptor(abc.ABC):
"""Affords intercepting stream RPCs on the service-side."""
@abc.abstractmethod
def intercept_stream(
self, request_or_iterator, servicer_context, server_info, handler
):
"""Intercepts stream RPCs on the service-side.
Args:
request_or_iterator: The request value for the RPC if
`server_info.is_client_stream` is `False`; otherwise, an iterator of
request values.
servicer_context: A ServicerContext.
server_info: A StreamServerInfo containing various information about
the RPC.
handler: The handler to complete the RPC on the server. It is the
interceptor's responsibility to call it.
Returns:
The result from calling handler(servicer_context).
"""
raise NotImplementedError()
def intercept_server(server, *interceptors):
"""Creates an intercepted server.
Args:
server: A Server.
interceptors: Zero or more UnaryServerInterceptors or
StreamServerInterceptors
Returns:
A Server.
Raises:
TypeError: If an interceptor derives from neither UnaryServerInterceptor
nor StreamServerInterceptor.
"""
from . import _interceptor
return _interceptor.intercept_server(server, *interceptors)
__all__ = ( __all__ = (
"UnaryClientInterceptor", "UnaryClientInterceptor",
"StreamClientInfo", "StreamClientInfo",
"StreamClientInterceptor", "StreamClientInterceptor",
"UnaryServerInfo",
"StreamServerInfo",
"UnaryServerInterceptor",
"StreamServerInterceptor",
"intercept_channel", "intercept_channel",
"intercept_server",
) )

View File

@ -252,180 +252,3 @@ def intercept_channel(channel, *interceptors):
) )
result = _InterceptorChannel(result, interceptor) result = _InterceptorChannel(result, interceptor)
return result return result
class _UnaryServerInfo(
collections.namedtuple("_UnaryServerInfo", ("full_method",))
):
pass
class _StreamServerInfo(
collections.namedtuple(
"_StreamServerInfo",
("full_method", "is_client_stream", "is_server_stream"),
)
):
pass
class _InterceptorRpcMethodHandler(grpc.RpcMethodHandler):
def __init__(self, rpc_method_handler, method, interceptor):
self._rpc_method_handler = rpc_method_handler
self._method = method
self._interceptor = interceptor
@property
def request_streaming(self):
return self._rpc_method_handler.request_streaming
@property
def response_streaming(self):
return self._rpc_method_handler.response_streaming
@property
def request_deserializer(self):
return self._rpc_method_handler.request_deserializer
@property
def response_serializer(self):
return self._rpc_method_handler.response_serializer
@property
def unary_unary(self):
if not isinstance(self._interceptor, grpcext.UnaryServerInterceptor):
return self._rpc_method_handler.unary_unary
def adaptation(request, servicer_context):
def handler(request, servicer_context):
return self._rpc_method_handler.unary_unary(
request, servicer_context
)
return self._interceptor.intercept_unary(
request,
servicer_context,
_UnaryServerInfo(self._method),
handler,
)
return adaptation
@property
def unary_stream(self):
if not isinstance(self._interceptor, grpcext.StreamServerInterceptor):
return self._rpc_method_handler.unary_stream
def adaptation(request, servicer_context):
def handler(request, servicer_context):
return self._rpc_method_handler.unary_stream(
request, servicer_context
)
return self._interceptor.intercept_stream(
request,
servicer_context,
_StreamServerInfo(self._method, False, True),
handler,
)
return adaptation
@property
def stream_unary(self):
if not isinstance(self._interceptor, grpcext.StreamServerInterceptor):
return self._rpc_method_handler.stream_unary
def adaptation(request_iterator, servicer_context):
def handler(request_iterator, servicer_context):
return self._rpc_method_handler.stream_unary(
request_iterator, servicer_context
)
return self._interceptor.intercept_stream(
request_iterator,
servicer_context,
_StreamServerInfo(self._method, True, False),
handler,
)
return adaptation
@property
def stream_stream(self):
if not isinstance(self._interceptor, grpcext.StreamServerInterceptor):
return self._rpc_method_handler.stream_stream
def adaptation(request_iterator, servicer_context):
def handler(request_iterator, servicer_context):
return self._rpc_method_handler.stream_stream(
request_iterator, servicer_context
)
return self._interceptor.intercept_stream(
request_iterator,
servicer_context,
_StreamServerInfo(self._method, True, True),
handler,
)
return adaptation
class _InterceptorGenericRpcHandler(grpc.GenericRpcHandler):
def __init__(self, generic_rpc_handler, interceptor):
self.generic_rpc_handler = generic_rpc_handler
self._interceptor = interceptor
def service(self, handler_call_details):
result = self.generic_rpc_handler.service(handler_call_details)
if result:
result = _InterceptorRpcMethodHandler(
result, handler_call_details.method, self._interceptor
)
return result
class _InterceptorServer(grpc.Server):
def __init__(self, server, interceptor):
self._server = server
self._interceptor = interceptor
def add_generic_rpc_handlers(self, generic_rpc_handlers):
generic_rpc_handlers = [
_InterceptorGenericRpcHandler(
generic_rpc_handler, self._interceptor
)
for generic_rpc_handler in generic_rpc_handlers
]
return self._server.add_generic_rpc_handlers(generic_rpc_handlers)
def add_insecure_port(self, *args, **kwargs):
return self._server.add_insecure_port(*args, **kwargs)
def add_secure_port(self, *args, **kwargs):
return self._server.add_secure_port(*args, **kwargs)
def start(self, *args, **kwargs):
return self._server.start(*args, **kwargs)
def stop(self, *args, **kwargs):
return self._server.stop(*args, **kwargs)
def wait_for_termination(self, *args, **kwargs):
return self._server.wait_for_termination(*args, **kwargs)
def intercept_server(server, *interceptors):
result = server
for interceptor in interceptors:
if not isinstance(
interceptor, grpcext.UnaryServerInterceptor
) and not isinstance(interceptor, grpcext.StreamServerInterceptor):
raise TypeError(
"interceptor must be either a "
"grpcext.UnaryServerInterceptor or a "
"grpcext.StreamServerInterceptor"
)
result = _InterceptorServer(result, interceptor)
return result

View File

@ -26,7 +26,6 @@ from opentelemetry.instrumentation.grpc import (
GrpcInstrumentorServer, GrpcInstrumentorServer,
server_interceptor, server_interceptor,
) )
from opentelemetry.instrumentation.grpc.grpcext import intercept_server
from opentelemetry.sdk import trace as trace_sdk from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.test.test_base import TestBase from opentelemetry.test.test_base import TestBase
@ -123,10 +122,9 @@ class TestOpenTelemetryServerInterceptor(TestBase):
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor(max_workers=1), futures.ThreadPoolExecutor(max_workers=1),
options=(("grpc.so_reuseport", 0),), options=(("grpc.so_reuseport", 0),),
interceptors=[interceptor],
) )
# FIXME: grpcext interceptor doesn't apply to handlers passed to server
# init, should use intercept_service API instead.
server = intercept_server(server, interceptor)
server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),)) server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),))
port = server.add_insecure_port("[::]:0") port = server.add_insecure_port("[::]:0")
@ -166,8 +164,8 @@ class TestOpenTelemetryServerInterceptor(TestBase):
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor(max_workers=1), futures.ThreadPoolExecutor(max_workers=1),
options=(("grpc.so_reuseport", 0),), options=(("grpc.so_reuseport", 0),),
interceptors=[interceptor],
) )
server = intercept_server(server, interceptor)
server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),)) server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),))
port = server.add_insecure_port("[::]:0") port = server.add_insecure_port("[::]:0")
@ -201,8 +199,8 @@ class TestOpenTelemetryServerInterceptor(TestBase):
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor(max_workers=1), futures.ThreadPoolExecutor(max_workers=1),
options=(("grpc.so_reuseport", 0),), options=(("grpc.so_reuseport", 0),),
interceptors=[interceptor],
) )
server = intercept_server(server, interceptor)
server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),)) server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),))
port = server.add_insecure_port("[::]:0") port = server.add_insecure_port("[::]:0")
@ -248,8 +246,8 @@ class TestOpenTelemetryServerInterceptor(TestBase):
server = grpc.server( server = grpc.server(
futures.ThreadPoolExecutor(max_workers=2), futures.ThreadPoolExecutor(max_workers=2),
options=(("grpc.so_reuseport", 0),), options=(("grpc.so_reuseport", 0),),
interceptors=[interceptor],
) )
server = intercept_server(server, interceptor)
server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),)) server.add_generic_rpc_handlers((UnaryUnaryRpcHandler(handler),))
port = server.add_insecure_port("[::]:0") port = server.add_insecure_port("[::]:0")