From c3e9f75fb9611d1e2842b4f3fe11b910d7d2ae0c Mon Sep 17 00:00:00 2001 From: Malcolm Rebughini <9681621+malcolmrebughini@users.noreply.github.com> Date: Thu, 1 Aug 2024 14:33:29 -0700 Subject: [PATCH] =?UTF-8?q?fix(opentelemetry-instrumentation-celery):=20at?= =?UTF-8?q?tach=20incoming=20context=20on=E2=80=A6=20(#2385)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(opentelemetry-instrumentation-celery): attach incoming context on _trace_prerun * docs(CHANGELOG): add entry for fix #2385 * fix(opentelemetry-instrumentation-celery): detach context after task is run * test(opentelemetry-instrumentation-celery): add context utils tests * fix(opentelemetry-instrumentation-celery): remove duplicated signal registration * refactor(opentelemetry-instrumentation-celery): fix lint issues * refactor(opentelemetry-instrumentation-celery): fix types and tests for python 3.8 * refactor(opentelemetry-instrumentation-celery): fix lint issues * refactor(opentelemetry-instrumentation-celery): fix lint issues * fix(opentelemetry-instrumentation-celery): attach context only if it is not None * refactor(opentelemetry-instrumentation-celery): fix lint issues --- CHANGELOG.md | 8 +-- .../instrumentation/celery/__init__.py | 48 +++++++++++---- .../instrumentation/celery/utils.py | 59 ++++++++++++------- .../tests/celery_test_tasks.py | 7 +++ .../tests/test_tasks.py | 19 +++++- .../tests/test_utils.py | 22 +++---- 6 files changed, 114 insertions(+), 49 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 237a73786..0984efa3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#2756](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2756)) - `opentelemetry-instrumentation-aws-lambda` Fixing w3c baggage support ([#2589](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2589)) +- 
`opentelemetry-instrumentation-celery` propagates baggage + ([#2385](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2385)) ## Version 1.26.0/0.47b0 (2024-07-23) @@ -119,10 +121,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#2610](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2610)) - `opentelemetry-instrumentation-asgi` Bugfix: Middleware did not set status code attribute on duration metrics for non-recording spans. ([#2627](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2627)) -<<<<<<< HEAD -- `opentelemetry-instrumentation-mysql` Add support for `mysql-connector-python` v9 ([#2751](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2751)) -======= ->>>>>>> 5a623233 (Changelog update) +- `opentelemetry-instrumentation-mysql` Add support for `mysql-connector-python` v9 + ([#2751](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2751)) ## Version 1.25.0/0.46b0 (2024-05-31) diff --git a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py index 10ccca127..39b3bffe6 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py @@ -67,6 +67,7 @@ from billiard import VERSION from billiard.einfo import ExceptionInfo from celery import signals # pylint: disable=no-name-in-module +from opentelemetry import context as context_api from opentelemetry import trace from opentelemetry.instrumentation.celery import utils from opentelemetry.instrumentation.celery.package import _instruments @@ -169,6 +170,7 @@ class CeleryInstrumentor(BaseInstrumentor): self.update_task_duration_time(task_id) request = task.request 
tracectx = extract(request, getter=celery_getter) or None + token = context_api.attach(tracectx) if tracectx is not None else None logger.debug("prerun signal start task_id=%s", task_id) @@ -179,7 +181,7 @@ class CeleryInstrumentor(BaseInstrumentor): activation = trace.use_span(span, end_on_exit=True) activation.__enter__() # pylint: disable=E1101 - utils.attach_span(task, task_id, (span, activation)) + utils.attach_context(task, task_id, span, activation, token) def _trace_postrun(self, *args, **kwargs): task = utils.retrieve_task(kwargs) @@ -191,11 +193,14 @@ class CeleryInstrumentor(BaseInstrumentor): logger.debug("postrun signal task_id=%s", task_id) # retrieve and finish the Span - span, activation = utils.retrieve_span(task, task_id) - if span is None: + ctx = utils.retrieve_context(task, task_id) + + if ctx is None: logger.warning("no existing span found for task_id=%s", task_id) return + span, activation, token = ctx + # request context tags if span.is_recording(): span.set_attribute(_TASK_TAG_KEY, _TASK_RUN) @@ -204,10 +209,11 @@ class CeleryInstrumentor(BaseInstrumentor): span.set_attribute(_TASK_NAME_KEY, task.name) activation.__exit__(None, None, None) - utils.detach_span(task, task_id) + utils.detach_context(task, task_id) self.update_task_duration_time(task_id) labels = {"task": task.name, "worker": task.request.hostname} self._record_histograms(task_id, labels) + context_api.detach(token) def _trace_before_publish(self, *args, **kwargs): task = utils.retrieve_task_from_sender(kwargs) @@ -238,7 +244,9 @@ class CeleryInstrumentor(BaseInstrumentor): activation = trace.use_span(span, end_on_exit=True) activation.__enter__() # pylint: disable=E1101 - utils.attach_span(task, task_id, (span, activation), is_publish=True) + utils.attach_context( + task, task_id, span, activation, None, is_publish=True + ) headers = kwargs.get("headers") if headers: @@ -253,13 +261,16 @@ class CeleryInstrumentor(BaseInstrumentor): return # retrieve and finish the Span - _, 
activation = utils.retrieve_span(task, task_id, is_publish=True) - if activation is None: + ctx = utils.retrieve_context(task, task_id, is_publish=True) + + if ctx is None: logger.warning("no existing span found for task_id=%s", task_id) return + _, activation, _ = ctx + activation.__exit__(None, None, None) # pylint: disable=E1101 - utils.detach_span(task, task_id, is_publish=True) + utils.detach_context(task, task_id, is_publish=True) @staticmethod def _trace_failure(*args, **kwargs): @@ -269,9 +280,14 @@ class CeleryInstrumentor(BaseInstrumentor): if task is None or task_id is None: return - # retrieve and pass exception info to activation - span, _ = utils.retrieve_span(task, task_id) - if span is None or not span.is_recording(): + ctx = utils.retrieve_context(task, task_id) + + if ctx is None: + return + + span, _, _ = ctx + + if not span.is_recording(): return status_kwargs = {"status_code": StatusCode.ERROR} @@ -311,8 +327,14 @@ class CeleryInstrumentor(BaseInstrumentor): if task is None or task_id is None or reason is None: return - span, _ = utils.retrieve_span(task, task_id) - if span is None or not span.is_recording(): + ctx = utils.retrieve_context(task, task_id) + + if ctx is None: + return + + span, _, _ = ctx + + if not span.is_recording(): return # Add retry reason metadata to span diff --git a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py index 6f4f9cbc3..6af310df5 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py +++ b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py @@ -13,10 +13,13 @@ # limitations under the License. 
import logging +from typing import ContextManager, Optional, Tuple from celery import registry # pylint: disable=no-name-in-module +from celery.app.task import Task from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.trace import Span logger = logging.getLogger(__name__) @@ -81,10 +84,12 @@ def set_attributes_from_context(span, context): elif key == "delivery_info": # Get also destination from this routing_key = value.get("routing_key") + if routing_key is not None: span.set_attribute( SpanAttributes.MESSAGING_DESTINATION, routing_key ) + value = str(value) elif key == "id": @@ -114,11 +119,18 @@ def set_attributes_from_context(span, context): span.set_attribute(attribute_name, value) -def attach_span(task, task_id, span, is_publish=False): - """Helper to propagate a `Span` for the given `Task` instance. This - function uses a `dict` that stores the Span using the - `(task_id, is_publish)` as a key. This is useful when information must be - propagated from one Celery signal to another. +def attach_context( + task: Optional[Task], + task_id: str, + span: Span, + activation: ContextManager[Span], + token: Optional[object], + is_publish: bool = False, +) -> None: + """Helper to propagate a `Span`, `ContextManager` and context token + for the given `Task` instance. This function uses a `dict` that stores + all three using `(task_id, is_publish)` as the key. This is useful + when information must be propagated from one Celery signal to another. We use (task_id, is_publish) for the key to ensure that publishing a task from within another task does not cause any conflicts. 
@@ -134,36 +146,41 @@ def attach_span(task, task_id, span, is_publish=False): """ if task is None: return - span_dict = getattr(task, CTX_KEY, None) - if span_dict is None: - span_dict = {} - setattr(task, CTX_KEY, span_dict) - span_dict[(task_id, is_publish)] = span + ctx_dict = getattr(task, CTX_KEY, None) + + if ctx_dict is None: + ctx_dict = {} + setattr(task, CTX_KEY, ctx_dict) + + ctx_dict[(task_id, is_publish)] = (span, activation, token) -def detach_span(task, task_id, is_publish=False): - """Helper to remove a `Span` in a Celery task when it's propagated. - This function handles tasks where the `Span` is not attached. +def detach_context(task, task_id, is_publish=False) -> None: + """Helper to remove `Span`, `ContextManager` and context token in a + Celery task when it's propagated. + This function handles tasks where no values are attached to the `Task`. """ span_dict = getattr(task, CTX_KEY, None) if span_dict is None: return - # See note in `attach_span` for key info - span_dict.pop((task_id, is_publish), (None, None)) + # See note in `attach_context` for key info + span_dict.pop((task_id, is_publish), None) -def retrieve_span(task, task_id, is_publish=False): - """Helper to retrieve an active `Span` stored in a `Task` - instance +def retrieve_context( + task, task_id, is_publish=False +) -> Optional[Tuple[Span, ContextManager[Span], Optional[object]]]: + """Helper to retrieve an active `Span`, `ContextManager` and context token + stored in a `Task` instance """ span_dict = getattr(task, CTX_KEY, None) if span_dict is None: - return (None, None) + return None - # See note in `attach_span` for key info - return span_dict.get((task_id, is_publish), (None, None)) + # See note in `attach_context` for key info + return span_dict.get((task_id, is_publish), None) def retrieve_task(kwargs): diff --git a/instrumentation/opentelemetry-instrumentation-celery/tests/celery_test_tasks.py 
b/instrumentation/opentelemetry-instrumentation-celery/tests/celery_test_tasks.py index 9ac78f6d8..af88f1d4c 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/tests/celery_test_tasks.py +++ b/instrumentation/opentelemetry-instrumentation-celery/tests/celery_test_tasks.py @@ -14,6 +14,8 @@ from celery import Celery +from opentelemetry import baggage + class Config: result_backend = "rpc" @@ -36,3 +38,8 @@ def task_add(num_a, num_b): @app.task def task_raises(): raise CustomError("The task failed!") + + +@app.task +def task_returns_baggage(): + return dict(baggage.get_all()) diff --git a/instrumentation/opentelemetry-instrumentation-celery/tests/test_tasks.py b/instrumentation/opentelemetry-instrumentation-celery/tests/test_tasks.py index 3ac6a5a70..0dc668b11 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/tests/test_tasks.py +++ b/instrumentation/opentelemetry-instrumentation-celery/tests/test_tasks.py @@ -15,12 +15,13 @@ import threading import time +from opentelemetry import baggage, context from opentelemetry.instrumentation.celery import CeleryInstrumentor from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.test.test_base import TestBase from opentelemetry.trace import SpanKind, StatusCode -from .celery_test_tasks import app, task_add, task_raises +from .celery_test_tasks import app, task_add, task_raises, task_returns_baggage class TestCeleryInstrumentation(TestBase): @@ -168,6 +169,22 @@ class TestCeleryInstrumentation(TestBase): spans = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans), 0) + def test_baggage(self): + CeleryInstrumentor().instrument() + + ctx = baggage.set_baggage("key", "value") + context.attach(ctx) + + task = task_returns_baggage.delay() + + timeout = time.time() + 60 * 1 # 1 minute from now + while not task.ready(): + if time.time() > timeout: + break + time.sleep(0.05) + + self.assertEqual(task.result, {"key": "value"}) + class TestCelerySignatureTask(TestBase): 
def setUp(self): diff --git a/instrumentation/opentelemetry-instrumentation-celery/tests/test_utils.py b/instrumentation/opentelemetry-instrumentation-celery/tests/test_utils.py index 55aa3eec1..a2f6e4338 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/tests/test_utils.py +++ b/instrumentation/opentelemetry-instrumentation-celery/tests/test_utils.py @@ -167,8 +167,10 @@ class TestUtils(unittest.TestCase): # propagate and retrieve a Span task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f" span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext)) - utils.attach_span(fn_task, task_id, span) - span_after = utils.retrieve_span(fn_task, task_id) + utils.attach_context(fn_task, task_id, span, mock.Mock(), "") + ctx = utils.retrieve_context(fn_task, task_id) + self.assertIsNotNone(ctx) + span_after, _, _ = ctx self.assertIs(span, span_after) def test_span_delete(self): @@ -180,17 +182,19 @@ class TestUtils(unittest.TestCase): # propagate a Span task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f" span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext)) - utils.attach_span(fn_task, task_id, span) + utils.attach_context(fn_task, task_id, span, mock.Mock(), "") # delete the Span - utils.detach_span(fn_task, task_id) - self.assertEqual(utils.retrieve_span(fn_task, task_id), (None, None)) + utils.detach_context(fn_task, task_id) + self.assertEqual(utils.retrieve_context(fn_task, task_id), None) def test_optional_task_span_attach(self): task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f" span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext)) # assert this is is a no-aop - self.assertIsNone(utils.attach_span(None, task_id, span)) + self.assertIsNone( + utils.attach_context(None, task_id, span, mock.Mock(), "") + ) def test_span_delete_empty(self): # ensure detach_span doesn't raise an exception if span is not present @@ -201,10 +205,8 @@ class TestUtils(unittest.TestCase): # delete the Span task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f" try: 
- utils.detach_span(fn_task, task_id) - self.assertEqual( - utils.retrieve_span(fn_task, task_id), (None, None) - ) + utils.detach_context(fn_task, task_id) + self.assertEqual(utils.retrieve_context(fn_task, task_id), None) except Exception as ex: # pylint: disable=broad-except self.fail(f"Exception was raised: {ex}")