diff --git a/instrumentation/opentelemetry-instrumentation-pika/src/opentelemetry/instrumentation/pika/utils.py b/instrumentation/opentelemetry-instrumentation-pika/src/opentelemetry/instrumentation/pika/utils.py index 7ad144057..699085af6 100644 --- a/instrumentation/opentelemetry-instrumentation-pika/src/opentelemetry/instrumentation/pika/utils.py +++ b/instrumentation/opentelemetry-instrumentation-pika/src/opentelemetry/instrumentation/pika/utils.py @@ -41,15 +41,20 @@ def _decorate_callback( ) -> Any: if not properties: properties = BasicProperties() + if properties.headers is None: + properties.headers = {} + ctx = propagate.extract(properties.headers, getter=_pika_getter) + if not ctx: + ctx = context.get_current() span = _get_span( tracer, channel, properties, task_name=task_name, + ctx=ctx, operation=MessagingOperationValues.RECEIVE, ) with trace.use_span(span, end_on_exit=True): - propagate.inject(properties.headers) retval = callback(channel, method, properties, body) return retval @@ -70,11 +75,13 @@ def _decorate_basic_publish( ) -> Any: if not properties: properties = BasicProperties() + ctx = context.get_current() span = _get_span( tracer, channel, properties, task_name="(temporary)", + ctx=ctx, operation=None, ) if not span: @@ -97,11 +104,9 @@ def _get_span( channel: Channel, properties: BasicProperties, task_name: str, + ctx: context.Context, operation: Optional[MessagingOperationValues] = None, ) -> Optional[Span]: - if properties.headers is None: - properties.headers = {} - ctx = propagate.extract(properties.headers, getter=_pika_getter) if context.get_value("suppress_instrumentation") or context.get_value( _SUPPRESS_INSTRUMENTATION_KEY ): diff --git a/instrumentation/opentelemetry-instrumentation-pika/tests/test_utils.py b/instrumentation/opentelemetry-instrumentation-pika/tests/test_utils.py index d8ce6d536..fce5d49d1 100644 --- a/instrumentation/opentelemetry-instrumentation-pika/tests/test_utils.py +++ b/instrumentation/opentelemetry-instrumentation-pika/tests/test_utils.py @@ -23,9 +23,7 @@ class TestUtils(TestCase): @mock.patch("opentelemetry.context.get_value") @mock.patch("opentelemetry.instrumentation.pika.utils._generate_span_name") @mock.patch("opentelemetry.instrumentation.pika.utils._enrich_span") - @mock.patch("opentelemetry.propagate.extract") def test_get_span( - extract: mock.MagicMock, enrich_span: mock.MagicMock, generate_span_name: mock.MagicMock, get_value: mock.MagicMock, @@ -35,21 +33,19 @@ class TestUtils(TestCase): properties = mock.MagicMock() task_name = "test.test" get_value.return_value = None - _ = utils._get_span(tracer, channel, properties, task_name) - extract.assert_called_once() + ctx = mock.MagicMock() + _ = utils._get_span(tracer, channel, properties, task_name, ctx) generate_span_name.assert_called_once() tracer.start_span.assert_called_once_with( - context=extract.return_value, name=generate_span_name.return_value + context=ctx, name=generate_span_name.return_value ) enrich_span.assert_called_once() @mock.patch("opentelemetry.context.get_value") @mock.patch("opentelemetry.instrumentation.pika.utils._generate_span_name") @mock.patch("opentelemetry.instrumentation.pika.utils._enrich_span") - @mock.patch("opentelemetry.propagate.extract") def test_get_span_suppressed( self, - extract: mock.MagicMock, enrich_span: mock.MagicMock, generate_span_name: mock.MagicMock, get_value: mock.MagicMock, @@ -59,10 +55,11 @@ class TestUtils(TestCase): properties = mock.MagicMock() task_name = "test.test" get_value.return_value = True - span = utils._get_span(tracer, channel, properties, task_name) + ctx = mock.MagicMock() + span = utils._get_span(tracer, channel, properties, task_name, ctx) self.assertEqual(span, None) - extract.assert_called_once() generate_span_name.assert_not_called() + enrich_span.assert_not_called() def test_generate_span_name_no_operation(self) -> None: task_name = "test.test"