fix(opentelemetry-instrumentation-celery): attach incoming context on… (#2385)

* fix(opentelemetry-instrumentation-celery): attach incoming context on _trace_prerun

* docs(CHANGELOG): add entry for fix #2385

* fix(opentelemetry-instrumentation-celery): detach context after task is run

* test(opentelemetry-instrumentation-celery): add context utils tests

* fix(opentelemetry-instrumentation-celery): remove duplicated signal registration

* refactor(opentelemetry-instrumentation-celery): fix lint issues

* refactor(opentelemetry-instrumentation-celery): fix types and tests for python 3.8

* refactor(opentelemetry-instrumentation-celery): fix lint issues

* refactor(opentelemetry-instrumentation-celery): fix lint issues

* fix(opentelemetry-instrumentation-celery): attach context only if it is not None

* refactor(opentelemetry-instrumentation-celery): fix lint issues
This commit is contained in:
Malcolm Rebughini
2024-08-01 14:33:29 -07:00
committed by GitHub
parent 4ea9e5a99a
commit c3e9f75fb9
6 changed files with 114 additions and 49 deletions

View File

@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#2756](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2756)) ([#2756](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2756))
- `opentelemetry-instrumentation-aws-lambda` Fixing w3c baggage support - `opentelemetry-instrumentation-aws-lambda` Fixing w3c baggage support
([#2589](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2589)) ([#2589](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2589))
- `opentelemetry-instrumentation-celery` propagates baggage
([#2385](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2385))
## Version 1.26.0/0.47b0 (2024-07-23) ## Version 1.26.0/0.47b0 (2024-07-23)
@ -119,10 +121,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#2610](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2610)) ([#2610](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2610))
- `opentelemetry-instrumentation-asgi` Bugfix: Middleware did not set status code attribute on duration metrics for non-recording spans. - `opentelemetry-instrumentation-asgi` Bugfix: Middleware did not set status code attribute on duration metrics for non-recording spans.
([#2627](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2627)) ([#2627](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2627))
<<<<<<< HEAD - `opentelemetry-instrumentation-mysql` Add support for `mysql-connector-python` v9
- `opentelemetry-instrumentation-mysql` Add support for `mysql-connector-python` v9 ([#2751](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2751)) ([#2751](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2751))
=======
>>>>>>> 5a623233 (Changelog update)
## Version 1.25.0/0.46b0 (2024-05-31) ## Version 1.25.0/0.46b0 (2024-05-31)

View File

@ -67,6 +67,7 @@ from billiard import VERSION
from billiard.einfo import ExceptionInfo from billiard.einfo import ExceptionInfo
from celery import signals # pylint: disable=no-name-in-module from celery import signals # pylint: disable=no-name-in-module
from opentelemetry import context as context_api
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.instrumentation.celery import utils from opentelemetry.instrumentation.celery import utils
from opentelemetry.instrumentation.celery.package import _instruments from opentelemetry.instrumentation.celery.package import _instruments
@ -169,6 +170,7 @@ class CeleryInstrumentor(BaseInstrumentor):
self.update_task_duration_time(task_id) self.update_task_duration_time(task_id)
request = task.request request = task.request
tracectx = extract(request, getter=celery_getter) or None tracectx = extract(request, getter=celery_getter) or None
token = context_api.attach(tracectx) if tracectx is not None else None
logger.debug("prerun signal start task_id=%s", task_id) logger.debug("prerun signal start task_id=%s", task_id)
@ -179,7 +181,7 @@ class CeleryInstrumentor(BaseInstrumentor):
activation = trace.use_span(span, end_on_exit=True) activation = trace.use_span(span, end_on_exit=True)
activation.__enter__() # pylint: disable=E1101 activation.__enter__() # pylint: disable=E1101
utils.attach_span(task, task_id, (span, activation)) utils.attach_context(task, task_id, span, activation, token)
def _trace_postrun(self, *args, **kwargs): def _trace_postrun(self, *args, **kwargs):
task = utils.retrieve_task(kwargs) task = utils.retrieve_task(kwargs)
@ -191,11 +193,14 @@ class CeleryInstrumentor(BaseInstrumentor):
logger.debug("postrun signal task_id=%s", task_id) logger.debug("postrun signal task_id=%s", task_id)
# retrieve and finish the Span # retrieve and finish the Span
span, activation = utils.retrieve_span(task, task_id) ctx = utils.retrieve_context(task, task_id)
if span is None:
if ctx is None:
logger.warning("no existing span found for task_id=%s", task_id) logger.warning("no existing span found for task_id=%s", task_id)
return return
span, activation, token = ctx
# request context tags # request context tags
if span.is_recording(): if span.is_recording():
span.set_attribute(_TASK_TAG_KEY, _TASK_RUN) span.set_attribute(_TASK_TAG_KEY, _TASK_RUN)
@ -204,10 +209,11 @@ class CeleryInstrumentor(BaseInstrumentor):
span.set_attribute(_TASK_NAME_KEY, task.name) span.set_attribute(_TASK_NAME_KEY, task.name)
activation.__exit__(None, None, None) activation.__exit__(None, None, None)
utils.detach_span(task, task_id) utils.detach_context(task, task_id)
self.update_task_duration_time(task_id) self.update_task_duration_time(task_id)
labels = {"task": task.name, "worker": task.request.hostname} labels = {"task": task.name, "worker": task.request.hostname}
self._record_histograms(task_id, labels) self._record_histograms(task_id, labels)
context_api.detach(token)
def _trace_before_publish(self, *args, **kwargs): def _trace_before_publish(self, *args, **kwargs):
task = utils.retrieve_task_from_sender(kwargs) task = utils.retrieve_task_from_sender(kwargs)
@ -238,7 +244,9 @@ class CeleryInstrumentor(BaseInstrumentor):
activation = trace.use_span(span, end_on_exit=True) activation = trace.use_span(span, end_on_exit=True)
activation.__enter__() # pylint: disable=E1101 activation.__enter__() # pylint: disable=E1101
utils.attach_span(task, task_id, (span, activation), is_publish=True) utils.attach_context(
task, task_id, span, activation, None, is_publish=True
)
headers = kwargs.get("headers") headers = kwargs.get("headers")
if headers: if headers:
@ -253,13 +261,16 @@ class CeleryInstrumentor(BaseInstrumentor):
return return
# retrieve and finish the Span # retrieve and finish the Span
_, activation = utils.retrieve_span(task, task_id, is_publish=True) ctx = utils.retrieve_context(task, task_id, is_publish=True)
if activation is None:
if ctx is None:
logger.warning("no existing span found for task_id=%s", task_id) logger.warning("no existing span found for task_id=%s", task_id)
return return
_, activation, _ = ctx
activation.__exit__(None, None, None) # pylint: disable=E1101 activation.__exit__(None, None, None) # pylint: disable=E1101
utils.detach_span(task, task_id, is_publish=True) utils.detach_context(task, task_id, is_publish=True)
@staticmethod @staticmethod
def _trace_failure(*args, **kwargs): def _trace_failure(*args, **kwargs):
@ -269,9 +280,14 @@ class CeleryInstrumentor(BaseInstrumentor):
if task is None or task_id is None: if task is None or task_id is None:
return return
# retrieve and pass exception info to activation ctx = utils.retrieve_context(task, task_id)
span, _ = utils.retrieve_span(task, task_id)
if span is None or not span.is_recording(): if ctx is None:
return
span, _, _ = ctx
if not span.is_recording():
return return
status_kwargs = {"status_code": StatusCode.ERROR} status_kwargs = {"status_code": StatusCode.ERROR}
@ -311,8 +327,14 @@ class CeleryInstrumentor(BaseInstrumentor):
if task is None or task_id is None or reason is None: if task is None or task_id is None or reason is None:
return return
span, _ = utils.retrieve_span(task, task_id) ctx = utils.retrieve_context(task, task_id)
if span is None or not span.is_recording():
if ctx is None:
return
span, _, _ = ctx
if not span.is_recording():
return return
# Add retry reason metadata to span # Add retry reason metadata to span

View File

@ -13,10 +13,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import ContextManager, Optional, Tuple
from celery import registry # pylint: disable=no-name-in-module from celery import registry # pylint: disable=no-name-in-module
from celery.app.task import Task
from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -81,10 +84,12 @@ def set_attributes_from_context(span, context):
elif key == "delivery_info": elif key == "delivery_info":
# Get also destination from this # Get also destination from this
routing_key = value.get("routing_key") routing_key = value.get("routing_key")
if routing_key is not None: if routing_key is not None:
span.set_attribute( span.set_attribute(
SpanAttributes.MESSAGING_DESTINATION, routing_key SpanAttributes.MESSAGING_DESTINATION, routing_key
) )
value = str(value) value = str(value)
elif key == "id": elif key == "id":
@ -114,11 +119,18 @@ def set_attributes_from_context(span, context):
span.set_attribute(attribute_name, value) span.set_attribute(attribute_name, value)
def attach_span(task, task_id, span, is_publish=False): def attach_context(
"""Helper to propagate a `Span` for the given `Task` instance. This task: Optional[Task],
function uses a `dict` that stores the Span using the task_id: str,
`(task_id, is_publish)` as a key. This is useful when information must be span: Span,
propagated from one Celery signal to another. activation: ContextManager[Span],
token: Optional[object],
is_publish: bool = False,
) -> None:
"""Helper to propagate a `Span`, `ContextManager` and context token
for the given `Task` instance. This function uses a `dict` that stores
the Span using the `(task_id, is_publish)` as a key. This is useful
when information must be propagated from one Celery signal to another.
We use (task_id, is_publish) for the key to ensure that publishing a We use (task_id, is_publish) for the key to ensure that publishing a
task from within another task does not cause any conflicts. task from within another task does not cause any conflicts.
@ -134,36 +146,41 @@ def attach_span(task, task_id, span, is_publish=False):
""" """
if task is None: if task is None:
return return
span_dict = getattr(task, CTX_KEY, None)
if span_dict is None:
span_dict = {}
setattr(task, CTX_KEY, span_dict)
span_dict[(task_id, is_publish)] = span ctx_dict = getattr(task, CTX_KEY, None)
if ctx_dict is None:
ctx_dict = {}
setattr(task, CTX_KEY, ctx_dict)
ctx_dict[(task_id, is_publish)] = (span, activation, token)
def detach_span(task, task_id, is_publish=False): def detach_context(task, task_id, is_publish=False) -> None:
"""Helper to remove a `Span` in a Celery task when it's propagated. """Helper to remove `Span`, `ContextManager` and context token in a
This function handles tasks where the `Span` is not attached. Celery task when it's propagated.
This function handles tasks where no values are attached to the `Task`.
""" """
span_dict = getattr(task, CTX_KEY, None) span_dict = getattr(task, CTX_KEY, None)
if span_dict is None: if span_dict is None:
return return
# See note in `attach_span` for key info # See note in `attach_context` for key info
span_dict.pop((task_id, is_publish), (None, None)) span_dict.pop((task_id, is_publish), None)
def retrieve_span(task, task_id, is_publish=False): def retrieve_context(
"""Helper to retrieve an active `Span` stored in a `Task` task, task_id, is_publish=False
instance ) -> Optional[Tuple[Span, ContextManager[Span], Optional[object]]]:
"""Helper to retrieve an active `Span`, `ContextManager` and context token
stored in a `Task` instance
""" """
span_dict = getattr(task, CTX_KEY, None) span_dict = getattr(task, CTX_KEY, None)
if span_dict is None: if span_dict is None:
return (None, None) return None
# See note in `attach_span` for key info # See note in `attach_context` for key info
return span_dict.get((task_id, is_publish), (None, None)) return span_dict.get((task_id, is_publish), None)
def retrieve_task(kwargs): def retrieve_task(kwargs):

View File

@ -14,6 +14,8 @@
from celery import Celery from celery import Celery
from opentelemetry import baggage
class Config: class Config:
result_backend = "rpc" result_backend = "rpc"
@ -36,3 +38,8 @@ def task_add(num_a, num_b):
@app.task @app.task
def task_raises(): def task_raises():
raise CustomError("The task failed!") raise CustomError("The task failed!")
@app.task
def task_returns_baggage():
return dict(baggage.get_all())

View File

@ -15,12 +15,13 @@
import threading import threading
import time import time
from opentelemetry import baggage, context
from opentelemetry.instrumentation.celery import CeleryInstrumentor from opentelemetry.instrumentation.celery import CeleryInstrumentor
from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import SpanKind, StatusCode from opentelemetry.trace import SpanKind, StatusCode
from .celery_test_tasks import app, task_add, task_raises from .celery_test_tasks import app, task_add, task_raises, task_returns_baggage
class TestCeleryInstrumentation(TestBase): class TestCeleryInstrumentation(TestBase):
@ -168,6 +169,22 @@ class TestCeleryInstrumentation(TestBase):
spans = self.memory_exporter.get_finished_spans() spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 0) self.assertEqual(len(spans), 0)
def test_baggage(self):
CeleryInstrumentor().instrument()
ctx = baggage.set_baggage("key", "value")
context.attach(ctx)
task = task_returns_baggage.delay()
timeout = time.time() + 60 * 1 # 1 minutes from now
while not task.ready():
if time.time() > timeout:
break
time.sleep(0.05)
self.assertEqual(task.result, {"key": "value"})
class TestCelerySignatureTask(TestBase): class TestCelerySignatureTask(TestBase):
def setUp(self): def setUp(self):

View File

@ -167,8 +167,10 @@ class TestUtils(unittest.TestCase):
# propagate and retrieve a Span # propagate and retrieve a Span
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f" task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext)) span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext))
utils.attach_span(fn_task, task_id, span) utils.attach_context(fn_task, task_id, span, mock.Mock(), "")
span_after = utils.retrieve_span(fn_task, task_id) ctx = utils.retrieve_context(fn_task, task_id)
self.assertIsNotNone(ctx)
span_after, _, _ = ctx
self.assertIs(span, span_after) self.assertIs(span, span_after)
def test_span_delete(self): def test_span_delete(self):
@ -180,17 +182,19 @@ class TestUtils(unittest.TestCase):
# propagate a Span # propagate a Span
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f" task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext)) span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext))
utils.attach_span(fn_task, task_id, span) utils.attach_context(fn_task, task_id, span, mock.Mock(), "")
# delete the Span # delete the Span
utils.detach_span(fn_task, task_id) utils.detach_context(fn_task, task_id)
self.assertEqual(utils.retrieve_span(fn_task, task_id), (None, None)) self.assertEqual(utils.retrieve_context(fn_task, task_id), None)
def test_optional_task_span_attach(self): def test_optional_task_span_attach(self):
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f" task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext)) span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext))
# assert this is is a no-aop # assert this is is a no-aop
self.assertIsNone(utils.attach_span(None, task_id, span)) self.assertIsNone(
utils.attach_context(None, task_id, span, mock.Mock(), "")
)
def test_span_delete_empty(self): def test_span_delete_empty(self):
# ensure detach_span doesn't raise an exception if span is not present # ensure detach_span doesn't raise an exception if span is not present
@ -201,10 +205,8 @@ class TestUtils(unittest.TestCase):
# delete the Span # delete the Span
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f" task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
try: try:
utils.detach_span(fn_task, task_id) utils.detach_context(fn_task, task_id)
self.assertEqual( self.assertEqual(utils.retrieve_context(fn_task, task_id), None)
utils.retrieve_span(fn_task, task_id), (None, None)
)
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
self.fail(f"Exception was raised: {ex}") self.fail(f"Exception was raised: {ex}")