added request and response hooks for grpc client (#1706)

This commit is contained in:
Yaron
2023-04-15 14:18:57 +03:00
committed by GitHub
parent a7a4f71570
commit d01c96fb42
6 changed files with 379 additions and 15 deletions

View File

@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
- `opentelemetry-instrumentation-system-metrics` Add `process.` prefix to `runtime.memory`, `runtime.cpu.time`, and `runtime.gc_count`. Change `runtime.memory` from count to UpDownCounter. ([#1735](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1735))
- Add request and response hooks for GRPC instrumentation (client only)
([#1706](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1706))
### Added

View File

@ -434,6 +434,8 @@ class GrpcInstrumentorClient(BaseInstrumentor):
else:
filter_ = any_of(filter_, excluded_service_filter)
self._filter = filter_
self._request_hook = None
self._response_hook = None
# Figures out which channel type we need to wrap
def _which_channel(self, kwargs):
@ -455,6 +457,8 @@ class GrpcInstrumentorClient(BaseInstrumentor):
return _instruments
def _instrument(self, **kwargs):
self._request_hook = kwargs.get("request_hook")
self._response_hook = kwargs.get("response_hook")
for ctype in self._which_channel(kwargs):
_wrap(
"grpc",
@ -469,11 +473,15 @@ class GrpcInstrumentorClient(BaseInstrumentor):
def wrapper_fn(self, original_func, instance, args, kwargs):
channel = original_func(*args, **kwargs)
tracer_provider = kwargs.get("tracer_provider")
request_hook = self._request_hook
response_hook = self._response_hook
return intercept_channel(
channel,
client_interceptor(
tracer_provider=tracer_provider,
filter_=self._filter,
request_hook=request_hook,
response_hook=response_hook,
),
)
@ -499,6 +507,8 @@ class GrpcAioInstrumentorClient(BaseInstrumentor):
else:
filter_ = any_of(filter_, excluded_service_filter)
self._filter = filter_
self._request_hook = None
self._response_hook = None
def instrumentation_dependencies(self) -> Collection[str]:
return _instruments
@ -507,13 +517,19 @@ class GrpcAioInstrumentorClient(BaseInstrumentor):
if "interceptors" in kwargs and kwargs["interceptors"]:
kwargs["interceptors"] = (
aio_client_interceptors(
tracer_provider=tracer_provider, filter_=self._filter
tracer_provider=tracer_provider,
filter_=self._filter,
request_hook=self._request_hook,
response_hook=self._response_hook,
)
+ kwargs["interceptors"]
)
else:
kwargs["interceptors"] = aio_client_interceptors(
tracer_provider=tracer_provider, filter_=self._filter
tracer_provider=tracer_provider,
filter_=self._filter,
request_hook=self._request_hook,
response_hook=self._response_hook,
)
return kwargs
@ -521,6 +537,8 @@ class GrpcAioInstrumentorClient(BaseInstrumentor):
def _instrument(self, **kwargs):
self._original_insecure = grpc.aio.insecure_channel
self._original_secure = grpc.aio.secure_channel
self._request_hook = kwargs.get("request_hook")
self._response_hook = kwargs.get("response_hook")
tracer_provider = kwargs.get("tracer_provider")
def insecure(*args, **kwargs):
@ -541,7 +559,9 @@ class GrpcAioInstrumentorClient(BaseInstrumentor):
grpc.aio.secure_channel = self._original_secure
def client_interceptor(tracer_provider=None, filter_=None):
def client_interceptor(
tracer_provider=None, filter_=None, request_hook=None, response_hook=None
):
"""Create a gRPC client channel interceptor.
Args:
@ -558,7 +578,12 @@ def client_interceptor(tracer_provider=None, filter_=None):
tracer = trace.get_tracer(__name__, __version__, tracer_provider)
return _client.OpenTelemetryClientInterceptor(tracer, filter_=filter_)
return _client.OpenTelemetryClientInterceptor(
tracer,
filter_=filter_,
request_hook=request_hook,
response_hook=response_hook,
)
def server_interceptor(tracer_provider=None, filter_=None):
@ -581,7 +606,9 @@ def server_interceptor(tracer_provider=None, filter_=None):
return _server.OpenTelemetryServerInterceptor(tracer, filter_=filter_)
def aio_client_interceptors(tracer_provider=None, filter_=None):
def aio_client_interceptors(
tracer_provider=None, filter_=None, request_hook=None, response_hook=None
):
"""Create a gRPC client channel interceptor.
Args:
@ -595,10 +622,30 @@ def aio_client_interceptors(tracer_provider=None, filter_=None):
tracer = trace.get_tracer(__name__, __version__, tracer_provider)
return [
_aio_client.UnaryUnaryAioClientInterceptor(tracer, filter_=filter_),
_aio_client.UnaryStreamAioClientInterceptor(tracer, filter_=filter_),
_aio_client.StreamUnaryAioClientInterceptor(tracer, filter_=filter_),
_aio_client.StreamStreamAioClientInterceptor(tracer, filter_=filter_),
_aio_client.UnaryUnaryAioClientInterceptor(
tracer,
filter_=filter_,
request_hook=request_hook,
response_hook=response_hook,
),
_aio_client.UnaryStreamAioClientInterceptor(
tracer,
filter_=filter_,
request_hook=request_hook,
response_hook=response_hook,
),
_aio_client.StreamUnaryAioClientInterceptor(
tracer,
filter_=filter_,
request_hook=request_hook,
response_hook=response_hook,
),
_aio_client.StreamStreamAioClientInterceptor(
tracer,
filter_=filter_,
request_hook=request_hook,
response_hook=response_hook,
),
]

View File

@ -13,6 +13,7 @@
# limitations under the License.
import functools
import logging
from collections import OrderedDict
import grpc
@ -28,8 +29,10 @@ from opentelemetry.propagate import inject
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace.status import Status, StatusCode
logger = logging.getLogger(__name__)
def _unary_done_callback(span, code, details):
def _unary_done_callback(span, code, details, response_hook):
def callback(call):
try:
span.set_attribute(
@ -43,6 +46,8 @@ def _unary_done_callback(span, code, details):
description=details,
)
)
response_hook(span, details)
finally:
span.end()
@ -110,7 +115,11 @@ class _BaseAioClientInterceptor(OpenTelemetryClientInterceptor):
code = await call.code()
details = await call.details()
call.add_done_callback(_unary_done_callback(span, code, details))
call.add_done_callback(
_unary_done_callback(
span, code, details, self._call_response_hook
)
)
return call
except grpc.aio.AioRpcError as exc:
@ -120,6 +129,8 @@ class _BaseAioClientInterceptor(OpenTelemetryClientInterceptor):
async def _wrap_stream_response(self, span, call):
try:
async for response in call:
if self._response_hook:
self._call_response_hook(span, response)
yield response
except Exception as exc:
self.add_error_details_to_span(span, exc)
@ -151,6 +162,9 @@ class UnaryUnaryAioClientInterceptor(
) as span:
new_details = self.propagate_trace_in_details(client_call_details)
if self._request_hook:
self._call_request_hook(span, request)
continuation_with_args = functools.partial(
continuation, new_details, request
)
@ -175,7 +189,8 @@ class UnaryStreamAioClientInterceptor(
new_details = self.propagate_trace_in_details(client_call_details)
resp = await continuation(new_details, request)
if self._request_hook:
self._call_request_hook(span, request)
return self._wrap_stream_response(span, resp)

View File

@ -19,8 +19,9 @@
"""Implementation of the invocation-side open-telemetry interceptor."""
import logging
from collections import OrderedDict
from typing import MutableMapping
from typing import Callable, MutableMapping
import grpc
@ -33,6 +34,8 @@ from opentelemetry.propagators.textmap import Setter
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace.status import Status, StatusCode
logger = logging.getLogger(__name__)
class _CarrierSetter(Setter):
"""We use a custom setter in order to be able to lower case
@ -59,12 +62,27 @@ def _make_future_done_callback(span, rpc_info):
return callback
def _safe_invoke(function: Callable, *args):
function_name = "<unknown>"
try:
function_name = function.__name__
function(*args)
except Exception as ex: # pylint:disable=broad-except
logger.error(
"Error when invoking function '%s'", function_name, exc_info=ex
)
class OpenTelemetryClientInterceptor(
grpcext.UnaryClientInterceptor, grpcext.StreamClientInterceptor
):
def __init__(self, tracer, filter_=None):
def __init__(
self, tracer, filter_=None, request_hook=None, response_hook=None
):
self._tracer = tracer
self._filter = filter_
self._request_hook = request_hook
self._response_hook = response_hook
def _start_span(self, method, **kwargs):
service, meth = method.lstrip("/").split("/", 1)
@ -99,6 +117,8 @@ class OpenTelemetryClientInterceptor(
if isinstance(result, tuple):
response = result[0]
rpc_info.response = response
if self._response_hook:
self._call_response_hook(span, response)
span.end()
return result
@ -127,7 +147,8 @@ class OpenTelemetryClientInterceptor(
timeout=client_info.timeout,
request=request,
)
if self._request_hook:
self._call_request_hook(span, request)
result = invoker(request, metadata)
except Exception as exc:
if isinstance(exc, grpc.RpcError):
@ -148,6 +169,16 @@ class OpenTelemetryClientInterceptor(
span.end()
return self._trace_result(span, rpc_info, result)
def _call_request_hook(self, span, request):
if not callable(self._request_hook):
return
_safe_invoke(self._request_hook, span, request)
def _call_response_hook(self, span, response):
if not callable(self._response_hook):
return
_safe_invoke(self._response_hook, span, response)
def intercept_unary(self, request, metadata, client_info, invoker):
if self._filter is not None and not self._filter(client_info):
return invoker(request, metadata)

View File

@ -0,0 +1,120 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
from unittest import IsolatedAsyncioTestCase
except ImportError:
# unittest.IsolatedAsyncioTestCase was introduced in Python 3.8. It's use
# simplifies the following tests. Without it, the amount of test code
# increases significantly, with most of the additional code handling
# the asyncio set up.
from unittest import TestCase
class IsolatedAsyncioTestCase(TestCase):
def run(self, result=None):
self.skipTest(
"This test requires Python 3.8 for unittest.IsolatedAsyncioTestCase"
)
import grpc
import pytest
from opentelemetry.instrumentation.grpc import GrpcAioInstrumentorClient
from opentelemetry.test.test_base import TestBase
from ._aio_client import simple_method
from ._server import create_test_server
from .protobuf import test_server_pb2_grpc # pylint: disable=no-name-in-module
def request_hook(span, request):
span.set_attribute("request_data", request.request_data)
def response_hook(span, response):
span.set_attribute("response_data", response)
def request_hook_with_exception(_span, _request):
raise Exception()
def response_hook_with_exception(_span, _response):
raise Exception()
@pytest.mark.asyncio
class TestAioClientInterceptorWithHooks(TestBase, IsolatedAsyncioTestCase):
def setUp(self):
super().setUp()
self.server = create_test_server(25565)
self.server.start()
def tearDown(self):
super().tearDown()
self.server.stop(None)
async def test_request_and_response_hooks(self):
instrumentor = GrpcAioInstrumentorClient()
try:
instrumentor.instrument(
request_hook=request_hook,
response_hook=response_hook,
)
channel = grpc.aio.insecure_channel(
"localhost:25565",
)
stub = test_server_pb2_grpc.GRPCTestServerStub(channel)
response = await simple_method(stub)
assert response.response_data == "data"
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertIn("request_data", span.attributes)
self.assertEqual(span.attributes["request_data"], "data")
self.assertIn("response_data", span.attributes)
self.assertEqual(span.attributes["response_data"], "")
finally:
instrumentor.uninstrument()
async def test_hooks_with_exception(self):
instrumentor = GrpcAioInstrumentorClient()
try:
instrumentor.instrument(
request_hook=request_hook_with_exception,
response_hook=response_hook_with_exception,
)
channel = grpc.aio.insecure_channel(
"localhost:25565",
)
stub = test_server_pb2_grpc.GRPCTestServerStub(channel)
response = await simple_method(stub)
assert response.response_data == "data"
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertEqual(span.name, "/GRPCTestServer/SimpleMethod")
finally:
instrumentor.uninstrument()

View File

@ -0,0 +1,149 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import grpc
from tests.protobuf import ( # pylint: disable=no-name-in-module
test_server_pb2_grpc,
)
from opentelemetry import trace
from opentelemetry.instrumentation.grpc import GrpcInstrumentorClient
from opentelemetry.test.test_base import TestBase
from ._client import simple_method
from ._server import create_test_server
# User defined interceptor. Is used in the tests along with the opentelemetry client interceptor.
class Interceptor(
grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor,
grpc.StreamStreamClientInterceptor,
):
def __init__(self):
pass
def intercept_unary_unary(
self, continuation, client_call_details, request
):
return self._intercept_call(continuation, client_call_details, request)
def intercept_unary_stream(
self, continuation, client_call_details, request
):
return self._intercept_call(continuation, client_call_details, request)
def intercept_stream_unary(
self, continuation, client_call_details, request_iterator
):
return self._intercept_call(
continuation, client_call_details, request_iterator
)
def intercept_stream_stream(
self, continuation, client_call_details, request_iterator
):
return self._intercept_call(
continuation, client_call_details, request_iterator
)
@staticmethod
def _intercept_call(
continuation, client_call_details, request_or_iterator
):
return continuation(client_call_details, request_or_iterator)
def request_hook(span, request):
span.set_attribute("request_data", request.request_data)
def response_hook(span, response):
span.set_attribute("response_data", response.response_data)
def request_hook_with_exception(_span, _request):
raise Exception()
def response_hook_with_exception(_span, _response):
raise Exception()
class TestHooks(TestBase):
def setUp(self):
super().setUp()
self.server = create_test_server(25565)
self.server.start()
# use a user defined interceptor along with the opentelemetry client interceptor
self.interceptors = [Interceptor()]
def tearDown(self):
super().tearDown()
self.server.stop(None)
def test_response_and_request_hooks(self):
instrumentor = GrpcInstrumentorClient()
try:
instrumentor.instrument(
request_hook=request_hook,
response_hook=response_hook,
)
channel = grpc.insecure_channel("localhost:25565")
channel = grpc.intercept_channel(channel, *self.interceptors)
stub = test_server_pb2_grpc.GRPCTestServerStub(channel)
simple_method(stub)
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertEqual(span.name, "/GRPCTestServer/SimpleMethod")
self.assertIs(span.kind, trace.SpanKind.CLIENT)
self.assertIn("request_data", span.attributes)
self.assertEqual(span.attributes["request_data"], "data")
self.assertIn("response_data", span.attributes)
self.assertEqual(span.attributes["response_data"], "data")
finally:
instrumentor.uninstrument()
def test_hooks_with_exception(self):
instrumentor = GrpcInstrumentorClient()
try:
instrumentor.instrument(
request_hook=request_hook_with_exception,
response_hook=response_hook_with_exception,
)
channel = grpc.insecure_channel("localhost:25565")
channel = grpc.intercept_channel(channel, *self.interceptors)
stub = test_server_pb2_grpc.GRPCTestServerStub(channel)
simple_method(stub)
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertEqual(span.name, "/GRPCTestServer/SimpleMethod")
self.assertIs(span.kind, trace.SpanKind.CLIENT)
finally:
instrumentor.uninstrument()