Bugfix: Pika basicConsume context propagation (#766)

This commit is contained in:
oxeye-yuval
2021-10-21 20:50:52 +03:00
committed by GitHub
parent ae7a415f47
commit 3ff06da2fb
5 changed files with 103 additions and 79 deletions

View File

@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `opentelemetry-instrumentation-asgi` now explicitly depends on asgiref as it uses the package instead of instrumenting it. - `opentelemetry-instrumentation-asgi` now explicitly depends on asgiref as it uses the package instead of instrumenting it.
([#765](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/765)) ([#765](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/765))
- `opentelemetry-instrumentation-pika` now propagates context to basic_consume callback
([#766](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/766))
## [1.6.2-0.25b2](https://github.com/open-telemetry/opentelemetry-python/releases/tag/v1.6.2-0.25b2) - 2021-10-19 ## [1.6.2-0.25b2](https://github.com/open-telemetry/opentelemetry-python/releases/tag/v1.6.2-0.25b2) - 2021-10-19

View File

@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from logging import getLogger from logging import getLogger
from typing import Any, Callable, Collection, Dict, Optional from typing import Any, Collection, Dict, Optional
import wrapt import wrapt
from pika.adapters import BlockingConnection from pika.adapters import BlockingConnection
from pika.channel import Channel from pika.adapters.blocking_connection import BlockingChannel
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
@ -35,18 +35,25 @@ _FUNCTIONS_TO_UNINSTRUMENT = ["basic_publish"]
class PikaInstrumentor(BaseInstrumentor): # type: ignore class PikaInstrumentor(BaseInstrumentor): # type: ignore
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
@staticmethod @staticmethod
def _instrument_consumers( def _instrument_blocking_channel_consumers(
consumers_dict: Dict[str, Callable[..., Any]], tracer: Tracer channel: BlockingChannel, tracer: Tracer
) -> Any: ) -> Any:
for key, callback in consumers_dict.items(): for consumer_tag, consumer_info in channel._consumer_infos.items():
decorated_callback = utils._decorate_callback( decorated_callback = utils._decorate_callback(
callback, tracer, key consumer_info.on_message_callback, tracer, consumer_tag
) )
setattr(decorated_callback, "_original_callback", callback)
consumers_dict[key] = decorated_callback setattr(
decorated_callback,
"_original_callback",
consumer_info.on_message_callback,
)
consumer_info.on_message_callback = decorated_callback
@staticmethod @staticmethod
def _instrument_basic_publish(channel: Channel, tracer: Tracer) -> None: def _instrument_basic_publish(
channel: BlockingChannel, tracer: Tracer
) -> None:
original_function = getattr(channel, "basic_publish") original_function = getattr(channel, "basic_publish")
decorated_function = utils._decorate_basic_publish( decorated_function = utils._decorate_basic_publish(
original_function, channel, tracer original_function, channel, tracer
@ -57,13 +64,13 @@ class PikaInstrumentor(BaseInstrumentor): # type: ignore
@staticmethod @staticmethod
def _instrument_channel_functions( def _instrument_channel_functions(
channel: Channel, tracer: Tracer channel: BlockingChannel, tracer: Tracer
) -> None: ) -> None:
if hasattr(channel, "basic_publish"): if hasattr(channel, "basic_publish"):
PikaInstrumentor._instrument_basic_publish(channel, tracer) PikaInstrumentor._instrument_basic_publish(channel, tracer)
@staticmethod @staticmethod
def _uninstrument_channel_functions(channel: Channel) -> None: def _uninstrument_channel_functions(channel: BlockingChannel) -> None:
for function_name in _FUNCTIONS_TO_UNINSTRUMENT: for function_name in _FUNCTIONS_TO_UNINSTRUMENT:
if not hasattr(channel, function_name): if not hasattr(channel, function_name):
continue continue
@ -73,8 +80,10 @@ class PikaInstrumentor(BaseInstrumentor): # type: ignore
unwrap(channel, "basic_consume") unwrap(channel, "basic_consume")
@staticmethod @staticmethod
# Make sure that the spans are created inside hash them set as parent and not as brothers
def instrument_channel( def instrument_channel(
channel: Channel, tracer_provider: Optional[TracerProvider] = None, channel: BlockingChannel,
tracer_provider: Optional[TracerProvider] = None,
) -> None: ) -> None:
if not hasattr(channel, "_is_instrumented_by_opentelemetry"): if not hasattr(channel, "_is_instrumented_by_opentelemetry"):
channel._is_instrumented_by_opentelemetry = False channel._is_instrumented_by_opentelemetry = False
@ -84,18 +93,14 @@ class PikaInstrumentor(BaseInstrumentor): # type: ignore
) )
return return
tracer = trace.get_tracer(__name__, __version__, tracer_provider) tracer = trace.get_tracer(__name__, __version__, tracer_provider)
if not hasattr(channel, "_impl"): PikaInstrumentor._instrument_blocking_channel_consumers(
_LOG.error("Could not find implementation for provided channel!") channel, tracer
return )
if channel._impl._consumers:
PikaInstrumentor._instrument_consumers(
channel._impl._consumers, tracer
)
PikaInstrumentor._decorate_basic_consume(channel, tracer) PikaInstrumentor._decorate_basic_consume(channel, tracer)
PikaInstrumentor._instrument_channel_functions(channel, tracer) PikaInstrumentor._instrument_channel_functions(channel, tracer)
@staticmethod @staticmethod
def uninstrument_channel(channel: Channel) -> None: def uninstrument_channel(channel: BlockingChannel) -> None:
if ( if (
not hasattr(channel, "_is_instrumented_by_opentelemetry") not hasattr(channel, "_is_instrumented_by_opentelemetry")
or not channel._is_instrumented_by_opentelemetry or not channel._is_instrumented_by_opentelemetry
@ -104,12 +109,12 @@ class PikaInstrumentor(BaseInstrumentor): # type: ignore
"Attempting to uninstrument Pika channel while already uninstrumented!" "Attempting to uninstrument Pika channel while already uninstrumented!"
) )
return return
if not hasattr(channel, "_impl"):
_LOG.error("Could not find implementation for provided channel!") for consumers_tag, client_info in channel._consumer_infos.items():
return if hasattr(client_info.on_message_callback, "_original_callback"):
for key, callback in channel._impl._consumers.items(): channel._consumer_infos[
if hasattr(callback, "_original_callback"): consumers_tag
channel._impl._consumers[key] = callback._original_callback ] = client_info.on_message_callback._original_callback
PikaInstrumentor._uninstrument_channel_functions(channel) PikaInstrumentor._uninstrument_channel_functions(channel)
def _decorate_channel_function( def _decorate_channel_function(
@ -123,28 +128,15 @@ class PikaInstrumentor(BaseInstrumentor): # type: ignore
wrapt.wrap_function_wrapper(BlockingConnection, "channel", wrapper) wrapt.wrap_function_wrapper(BlockingConnection, "channel", wrapper)
@staticmethod @staticmethod
def _decorate_basic_consume(channel, tracer: Optional[Tracer]) -> None: def _decorate_basic_consume(
channel: BlockingChannel, tracer: Optional[Tracer]
) -> None:
def wrapper(wrapped, instance, args, kwargs): def wrapper(wrapped, instance, args, kwargs):
if not hasattr(channel, "_impl"):
_LOG.error(
"Could not find implementation for provided channel!"
)
return wrapped(*args, **kwargs)
current_keys = set(channel._impl._consumers.keys())
return_value = wrapped(*args, **kwargs) return_value = wrapped(*args, **kwargs)
new_key_list = list(
set(channel._impl._consumers.keys()) - current_keys PikaInstrumentor._instrument_blocking_channel_consumers(
channel, tracer
) )
if not new_key_list:
_LOG.error("Could not find added callback")
return return_value
new_key = new_key_list[0]
callback = channel._impl._consumers[new_key]
decorated_callback = utils._decorate_callback(
callback, tracer, new_key
)
setattr(decorated_callback, "_original_callback", callback)
channel._impl._consumers[new_key] = decorated_callback
return return_value return return_value
wrapt.wrap_function_wrapper(channel, "basic_consume", wrapper) wrapt.wrap_function_wrapper(channel, "basic_consume", wrapper)

View File

@ -46,17 +46,23 @@ def _decorate_callback(
ctx = propagate.extract(properties.headers, getter=_pika_getter) ctx = propagate.extract(properties.headers, getter=_pika_getter)
if not ctx: if not ctx:
ctx = context.get_current() ctx = context.get_current()
token = context.attach(ctx)
span = _get_span( span = _get_span(
tracer, tracer,
channel, channel,
properties, properties,
destination=method.exchange
if method.exchange
else method.routing_key,
span_kind=SpanKind.CONSUMER, span_kind=SpanKind.CONSUMER,
task_name=task_name, task_name=task_name,
ctx=ctx,
operation=MessagingOperationValues.RECEIVE, operation=MessagingOperationValues.RECEIVE,
) )
with trace.use_span(span, end_on_exit=True): try:
retval = callback(channel, method, properties, body) with trace.use_span(span, end_on_exit=True):
retval = callback(channel, method, properties, body)
finally:
context.detach(token)
return retval return retval
return decorated_callback return decorated_callback
@ -78,14 +84,13 @@ def _decorate_basic_publish(
properties = BasicProperties(headers={}) properties = BasicProperties(headers={})
if properties.headers is None: if properties.headers is None:
properties.headers = {} properties.headers = {}
ctx = context.get_current()
span = _get_span( span = _get_span(
tracer, tracer,
channel, channel,
properties, properties,
destination=exchange if exchange else routing_key,
span_kind=SpanKind.PRODUCER, span_kind=SpanKind.PRODUCER,
task_name="(temporary)", task_name="(temporary)",
ctx=ctx,
operation=None, operation=None,
) )
if not span: if not span:
@ -108,8 +113,8 @@ def _get_span(
channel: Channel, channel: Channel,
properties: BasicProperties, properties: BasicProperties,
task_name: str, task_name: str,
destination: str,
span_kind: SpanKind, span_kind: SpanKind,
ctx: context.Context,
operation: Optional[MessagingOperationValues] = None, operation: Optional[MessagingOperationValues] = None,
) -> Optional[Span]: ) -> Optional[Span]:
if context.get_value("suppress_instrumentation") or context.get_value( if context.get_value("suppress_instrumentation") or context.get_value(
@ -118,9 +123,7 @@ def _get_span(
return None return None
task_name = properties.type if properties.type else task_name task_name = properties.type if properties.type else task_name
span = tracer.start_span( span = tracer.start_span(
context=ctx, name=_generate_span_name(destination, operation), kind=span_kind,
name=_generate_span_name(task_name, operation),
kind=span_kind,
) )
if span.is_recording(): if span.is_recording():
_enrich_span(span, channel, properties, task_name, operation) _enrich_span(span, channel, properties, task_name, operation)

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from unittest import TestCase, mock from unittest import TestCase, mock
from pika.adapters import BaseConnection, BlockingConnection from pika.adapters import BlockingConnection
from pika.channel import Channel from pika.channel import Channel
from wrapt import BoundFunctionWrapper from wrapt import BoundFunctionWrapper
@ -24,9 +24,10 @@ from opentelemetry.trace import Tracer
class TestPika(TestCase): class TestPika(TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.channel = mock.MagicMock(spec=Channel) self.channel = mock.MagicMock(spec=Channel)
self.channel._impl = mock.MagicMock(spec=BaseConnection) consumer_info = mock.MagicMock()
consumer_info.on_message_callback = mock.MagicMock()
self.channel._consumer_infos = {"consumer-tag": consumer_info}
self.mock_callback = mock.MagicMock() self.mock_callback = mock.MagicMock()
self.channel._impl._consumers = {"mock_key": self.mock_callback}
def test_instrument_api(self) -> None: def test_instrument_api(self) -> None:
instrumentation = PikaInstrumentor() instrumentation = PikaInstrumentor()
@ -49,11 +50,11 @@ class TestPika(TestCase):
"opentelemetry.instrumentation.pika.PikaInstrumentor._decorate_basic_consume" "opentelemetry.instrumentation.pika.PikaInstrumentor._decorate_basic_consume"
) )
@mock.patch( @mock.patch(
"opentelemetry.instrumentation.pika.PikaInstrumentor._instrument_consumers" "opentelemetry.instrumentation.pika.PikaInstrumentor._instrument_blocking_channel_consumers"
) )
def test_instrument( def test_instrument(
self, self,
instrument_consumers: mock.MagicMock, instrument_blocking_channel_consumers: mock.MagicMock,
instrument_basic_consume: mock.MagicMock, instrument_basic_consume: mock.MagicMock,
instrument_channel_functions: mock.MagicMock, instrument_channel_functions: mock.MagicMock,
): ):
@ -61,7 +62,7 @@ class TestPika(TestCase):
assert hasattr( assert hasattr(
self.channel, "_is_instrumented_by_opentelemetry" self.channel, "_is_instrumented_by_opentelemetry"
), "channel is not marked as instrumented!" ), "channel is not marked as instrumented!"
instrument_consumers.assert_called_once() instrument_blocking_channel_consumers.assert_called_once()
instrument_basic_consume.assert_called_once() instrument_basic_consume.assert_called_once()
instrument_channel_functions.assert_called_once() instrument_channel_functions.assert_called_once()
@ -71,18 +72,18 @@ class TestPika(TestCase):
) -> None: ) -> None:
tracer = mock.MagicMock(spec=Tracer) tracer = mock.MagicMock(spec=Tracer)
expected_decoration_calls = [ expected_decoration_calls = [
mock.call(value, tracer, key) mock.call(value.on_message_callback, tracer, key)
for key, value in self.channel._impl._consumers.items() for key, value in self.channel._consumer_infos.items()
] ]
PikaInstrumentor._instrument_consumers( PikaInstrumentor._instrument_blocking_channel_consumers(
self.channel._impl._consumers, tracer self.channel, tracer
) )
decorate_callback.assert_has_calls( decorate_callback.assert_has_calls(
calls=expected_decoration_calls, any_order=True calls=expected_decoration_calls, any_order=True
) )
assert all( assert all(
hasattr(callback, "_original_callback") hasattr(callback, "_original_callback")
for callback in self.channel._impl._consumers.values() for callback in self.channel._consumer_infos.values()
) )
@mock.patch( @mock.patch(

View File

@ -38,15 +38,15 @@ class TestUtils(TestCase):
channel = mock.MagicMock() channel = mock.MagicMock()
properties = mock.MagicMock() properties = mock.MagicMock()
task_name = "test.test" task_name = "test.test"
destination = "myqueue"
span_kind = mock.MagicMock(spec=SpanKind) span_kind = mock.MagicMock(spec=SpanKind)
get_value.return_value = None get_value.return_value = None
ctx = mock.MagicMock()
_ = utils._get_span( _ = utils._get_span(
tracer, channel, properties, task_name, span_kind, ctx tracer, channel, properties, task_name, destination, span_kind
) )
generate_span_name.assert_called_once() generate_span_name.assert_called_once()
tracer.start_span.assert_called_once_with( tracer.start_span.assert_called_once_with(
context=ctx, name=generate_span_name.return_value, kind=span_kind name=generate_span_name.return_value, kind=span_kind
) )
enrich_span.assert_called_once() enrich_span.assert_called_once()
@ -185,6 +185,7 @@ class TestUtils(TestCase):
tracer = mock.MagicMock() tracer = mock.MagicMock()
channel = mock.MagicMock(spec=Channel) channel = mock.MagicMock(spec=Channel)
method = mock.MagicMock(spec=Basic.Deliver) method = mock.MagicMock(spec=Basic.Deliver)
method.exchange = "test_exchange"
properties = mock.MagicMock() properties = mock.MagicMock()
mock_body = b"mock_body" mock_body = b"mock_body"
decorated_callback = utils._decorate_callback( decorated_callback = utils._decorate_callback(
@ -198,9 +199,9 @@ class TestUtils(TestCase):
tracer, tracer,
channel, channel,
properties, properties,
destination=method.exchange,
span_kind=SpanKind.CONSUMER, span_kind=SpanKind.CONSUMER,
task_name=mock_task_name, task_name=mock_task_name,
ctx=extract.return_value,
operation=MessagingOperationValues.RECEIVE, operation=MessagingOperationValues.RECEIVE,
) )
use_span.assert_called_once_with( use_span.assert_called_once_with(
@ -213,35 +214,33 @@ class TestUtils(TestCase):
@mock.patch("opentelemetry.instrumentation.pika.utils._get_span") @mock.patch("opentelemetry.instrumentation.pika.utils._get_span")
@mock.patch("opentelemetry.propagate.inject") @mock.patch("opentelemetry.propagate.inject")
@mock.patch("opentelemetry.context.get_current")
@mock.patch("opentelemetry.trace.use_span") @mock.patch("opentelemetry.trace.use_span")
def test_decorate_basic_publish( def test_decorate_basic_publish(
self, self,
use_span: mock.MagicMock, use_span: mock.MagicMock,
get_current: mock.MagicMock,
inject: mock.MagicMock, inject: mock.MagicMock,
get_span: mock.MagicMock, get_span: mock.MagicMock,
) -> None: ) -> None:
callback = mock.MagicMock() callback = mock.MagicMock()
tracer = mock.MagicMock() tracer = mock.MagicMock()
channel = mock.MagicMock(spec=Channel) channel = mock.MagicMock(spec=Channel)
method = mock.MagicMock(spec=Basic.Deliver) exchange_name = "test-exchange"
routing_key = "test-routing-key"
properties = mock.MagicMock() properties = mock.MagicMock()
mock_body = b"mock_body" mock_body = b"mock_body"
decorated_basic_publish = utils._decorate_basic_publish( decorated_basic_publish = utils._decorate_basic_publish(
callback, channel, tracer callback, channel, tracer
) )
retval = decorated_basic_publish( retval = decorated_basic_publish(
channel, method, mock_body, properties exchange_name, routing_key, mock_body, properties
) )
get_current.assert_called_once()
get_span.assert_called_once_with( get_span.assert_called_once_with(
tracer, tracer,
channel, channel,
properties, properties,
destination=exchange_name,
span_kind=SpanKind.PRODUCER, span_kind=SpanKind.PRODUCER,
task_name="(temporary)", task_name="(temporary)",
ctx=get_current.return_value,
operation=None, operation=None,
) )
use_span.assert_called_once_with( use_span.assert_called_once_with(
@ -250,20 +249,18 @@ class TestUtils(TestCase):
get_span.return_value.is_recording.assert_called_once() get_span.return_value.is_recording.assert_called_once()
inject.assert_called_once_with(properties.headers) inject.assert_called_once_with(properties.headers)
callback.assert_called_once_with( callback.assert_called_once_with(
channel, method, mock_body, properties, False exchange_name, routing_key, mock_body, properties, False
) )
self.assertEqual(retval, callback.return_value) self.assertEqual(retval, callback.return_value)
@mock.patch("opentelemetry.instrumentation.pika.utils._get_span") @mock.patch("opentelemetry.instrumentation.pika.utils._get_span")
@mock.patch("opentelemetry.propagate.inject") @mock.patch("opentelemetry.propagate.inject")
@mock.patch("opentelemetry.context.get_current")
@mock.patch("opentelemetry.trace.use_span") @mock.patch("opentelemetry.trace.use_span")
@mock.patch("pika.spec.BasicProperties.__new__") @mock.patch("pika.spec.BasicProperties.__new__")
def test_decorate_basic_publish_no_properties( def test_decorate_basic_publish_no_properties(
self, self,
basic_properties: mock.MagicMock, basic_properties: mock.MagicMock,
use_span: mock.MagicMock, use_span: mock.MagicMock,
get_current: mock.MagicMock,
inject: mock.MagicMock, inject: mock.MagicMock,
get_span: mock.MagicMock, get_span: mock.MagicMock,
) -> None: ) -> None:
@ -277,10 +274,39 @@ class TestUtils(TestCase):
) )
retval = decorated_basic_publish(channel, method, body=mock_body) retval = decorated_basic_publish(channel, method, body=mock_body)
basic_properties.assert_called_once_with(BasicProperties, headers={}) basic_properties.assert_called_once_with(BasicProperties, headers={})
get_current.assert_called_once()
use_span.assert_called_once_with( use_span.assert_called_once_with(
get_span.return_value, end_on_exit=True get_span.return_value, end_on_exit=True
) )
get_span.return_value.is_recording.assert_called_once() get_span.return_value.is_recording.assert_called_once()
inject.assert_called_once_with(basic_properties.return_value.headers) inject.assert_called_once_with(basic_properties.return_value.headers)
self.assertEqual(retval, callback.return_value) self.assertEqual(retval, callback.return_value)
@staticmethod
@mock.patch("opentelemetry.instrumentation.pika.utils._get_span")
def test_decorate_basic_publish_published_message_to_queue(
get_span: mock.MagicMock,
) -> None:
callback = mock.MagicMock()
tracer = mock.MagicMock()
channel = mock.MagicMock(spec=Channel)
exchange_name = ""
routing_key = "test-routing-key"
properties = mock.MagicMock()
mock_body = b"mock_body"
decorated_basic_publish = utils._decorate_basic_publish(
callback, channel, tracer
)
decorated_basic_publish(
exchange_name, routing_key, mock_body, properties
)
get_span.assert_called_once_with(
tracer,
channel,
properties,
destination=routing_key,
span_kind=SpanKind.PRODUCER,
task_name="(temporary)",
operation=None,
)