fix(opentelemetry-instrumentation-celery): attach incoming context on… (#2385)

* fix(opentelemetry-instrumentation-celery): attach incoming context on _trace_prerun

* docs(CHANGELOG): add entry for fix #2385

* fix(opentelemetry-instrumentation-celery): detach context after task is run

* test(opentelemetry-instrumentation-celery): add context utils tests

* fix(opentelemetry-instrumentation-celery): remove duplicated signal registration

* refactor(opentelemetry-instrumentation-celery): fix lint issues

* refactor(opentelemetry-instrumentation-celery): fix types and tests for python 3.8

* refactor(opentelemetry-instrumentation-celery): fix lint issues

* refactor(opentelemetry-instrumentation-celery): fix lint issues

* fix(opentelemetry-instrumentation-celery): attach context only if it is not None

* refactor(opentelemetry-instrumentation-celery): fix lint issues
This commit is contained in:
Malcolm Rebughini
2024-08-01 14:33:29 -07:00
committed by GitHub
parent 4ea9e5a99a
commit c3e9f75fb9
6 changed files with 114 additions and 49 deletions

View File

@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#2756](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2756))
- `opentelemetry-instrumentation-aws-lambda` Fixing w3c baggage support
([#2589](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2589))
- `opentelemetry-instrumentation-celery` propagates baggage
([#2385](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2385))
## Version 1.26.0/0.47b0 (2024-07-23)
@ -119,10 +121,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#2610](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2610))
- `opentelemetry-instrumentation-asgi` Bugfix: Middleware did not set status code attribute on duration metrics for non-recording spans.
([#2627](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2627))
<<<<<<< HEAD
- `opentelemetry-instrumentation-mysql` Add support for `mysql-connector-python` v9 ([#2751](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2751))
=======
>>>>>>> 5a623233 (Changelog update)
- `opentelemetry-instrumentation-mysql` Add support for `mysql-connector-python` v9
([#2751](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2751))
## Version 1.25.0/0.46b0 (2024-05-31)

View File

@ -67,6 +67,7 @@ from billiard import VERSION
from billiard.einfo import ExceptionInfo
from celery import signals # pylint: disable=no-name-in-module
from opentelemetry import context as context_api
from opentelemetry import trace
from opentelemetry.instrumentation.celery import utils
from opentelemetry.instrumentation.celery.package import _instruments
@ -169,6 +170,7 @@ class CeleryInstrumentor(BaseInstrumentor):
self.update_task_duration_time(task_id)
request = task.request
tracectx = extract(request, getter=celery_getter) or None
token = context_api.attach(tracectx) if tracectx is not None else None
logger.debug("prerun signal start task_id=%s", task_id)
@ -179,7 +181,7 @@ class CeleryInstrumentor(BaseInstrumentor):
activation = trace.use_span(span, end_on_exit=True)
activation.__enter__() # pylint: disable=E1101
utils.attach_span(task, task_id, (span, activation))
utils.attach_context(task, task_id, span, activation, token)
def _trace_postrun(self, *args, **kwargs):
task = utils.retrieve_task(kwargs)
@ -191,11 +193,14 @@ class CeleryInstrumentor(BaseInstrumentor):
logger.debug("postrun signal task_id=%s", task_id)
# retrieve and finish the Span
span, activation = utils.retrieve_span(task, task_id)
if span is None:
ctx = utils.retrieve_context(task, task_id)
if ctx is None:
logger.warning("no existing span found for task_id=%s", task_id)
return
span, activation, token = ctx
# request context tags
if span.is_recording():
span.set_attribute(_TASK_TAG_KEY, _TASK_RUN)
@ -204,10 +209,11 @@ class CeleryInstrumentor(BaseInstrumentor):
span.set_attribute(_TASK_NAME_KEY, task.name)
activation.__exit__(None, None, None)
utils.detach_span(task, task_id)
utils.detach_context(task, task_id)
self.update_task_duration_time(task_id)
labels = {"task": task.name, "worker": task.request.hostname}
self._record_histograms(task_id, labels)
context_api.detach(token)
def _trace_before_publish(self, *args, **kwargs):
task = utils.retrieve_task_from_sender(kwargs)
@ -238,7 +244,9 @@ class CeleryInstrumentor(BaseInstrumentor):
activation = trace.use_span(span, end_on_exit=True)
activation.__enter__() # pylint: disable=E1101
utils.attach_span(task, task_id, (span, activation), is_publish=True)
utils.attach_context(
task, task_id, span, activation, None, is_publish=True
)
headers = kwargs.get("headers")
if headers:
@ -253,13 +261,16 @@ class CeleryInstrumentor(BaseInstrumentor):
return
# retrieve and finish the Span
_, activation = utils.retrieve_span(task, task_id, is_publish=True)
if activation is None:
ctx = utils.retrieve_context(task, task_id, is_publish=True)
if ctx is None:
logger.warning("no existing span found for task_id=%s", task_id)
return
_, activation, _ = ctx
activation.__exit__(None, None, None) # pylint: disable=E1101
utils.detach_span(task, task_id, is_publish=True)
utils.detach_context(task, task_id, is_publish=True)
@staticmethod
def _trace_failure(*args, **kwargs):
@ -269,9 +280,14 @@ class CeleryInstrumentor(BaseInstrumentor):
if task is None or task_id is None:
return
# retrieve and pass exception info to activation
span, _ = utils.retrieve_span(task, task_id)
if span is None or not span.is_recording():
ctx = utils.retrieve_context(task, task_id)
if ctx is None:
return
span, _, _ = ctx
if not span.is_recording():
return
status_kwargs = {"status_code": StatusCode.ERROR}
@ -311,8 +327,14 @@ class CeleryInstrumentor(BaseInstrumentor):
if task is None or task_id is None or reason is None:
return
span, _ = utils.retrieve_span(task, task_id)
if span is None or not span.is_recording():
ctx = utils.retrieve_context(task, task_id)
if ctx is None:
return
span, _, _ = ctx
if not span.is_recording():
return
# Add retry reason metadata to span

View File

@ -13,10 +13,13 @@
# limitations under the License.
import logging
from typing import ContextManager, Optional, Tuple
from celery import registry # pylint: disable=no-name-in-module
from celery.app.task import Task
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span
logger = logging.getLogger(__name__)
@ -81,10 +84,12 @@ def set_attributes_from_context(span, context):
elif key == "delivery_info":
# Get also destination from this
routing_key = value.get("routing_key")
if routing_key is not None:
span.set_attribute(
SpanAttributes.MESSAGING_DESTINATION, routing_key
)
value = str(value)
elif key == "id":
@ -114,11 +119,18 @@ def set_attributes_from_context(span, context):
span.set_attribute(attribute_name, value)
def attach_span(task, task_id, span, is_publish=False):
"""Helper to propagate a `Span` for the given `Task` instance. This
function uses a `dict` that stores the Span using the
`(task_id, is_publish)` as a key. This is useful when information must be
propagated from one Celery signal to another.
def attach_context(
task: Optional[Task],
task_id: str,
span: Span,
activation: ContextManager[Span],
token: Optional[object],
is_publish: bool = False,
) -> None:
"""Helper to propagate a `Span`, `ContextManager` and context token
for the given `Task` instance. This function uses a `dict` that stores
the Span using the `(task_id, is_publish)` as a key. This is useful
when information must be propagated from one Celery signal to another.
We use (task_id, is_publish) for the key to ensure that publishing a
task from within another task does not cause any conflicts.
@ -134,36 +146,41 @@ def attach_span(task, task_id, span, is_publish=False):
"""
if task is None:
return
span_dict = getattr(task, CTX_KEY, None)
if span_dict is None:
span_dict = {}
setattr(task, CTX_KEY, span_dict)
span_dict[(task_id, is_publish)] = span
ctx_dict = getattr(task, CTX_KEY, None)
if ctx_dict is None:
ctx_dict = {}
setattr(task, CTX_KEY, ctx_dict)
ctx_dict[(task_id, is_publish)] = (span, activation, token)
def detach_span(task, task_id, is_publish=False):
"""Helper to remove a `Span` in a Celery task when it's propagated.
This function handles tasks where the `Span` is not attached.
def detach_context(task, task_id, is_publish=False) -> None:
"""Helper to remove `Span`, `ContextManager` and context token in a
Celery task when it's propagated.
This function handles tasks where no values are attached to the `Task`.
"""
span_dict = getattr(task, CTX_KEY, None)
if span_dict is None:
return
# See note in `attach_span` for key info
span_dict.pop((task_id, is_publish), (None, None))
# See note in `attach_context` for key info
span_dict.pop((task_id, is_publish), None)
def retrieve_span(task, task_id, is_publish=False):
"""Helper to retrieve an active `Span` stored in a `Task`
instance
def retrieve_context(
task, task_id, is_publish=False
) -> Optional[Tuple[Span, ContextManager[Span], Optional[object]]]:
"""Helper to retrieve an active `Span`, `ContextManager` and context token
stored in a `Task` instance
"""
span_dict = getattr(task, CTX_KEY, None)
if span_dict is None:
return (None, None)
return None
# See note in `attach_span` for key info
return span_dict.get((task_id, is_publish), (None, None))
# See note in `attach_context` for key info
return span_dict.get((task_id, is_publish), None)
def retrieve_task(kwargs):

View File

@ -14,6 +14,8 @@
from celery import Celery
from opentelemetry import baggage
class Config:
result_backend = "rpc"
@ -36,3 +38,8 @@ def task_add(num_a, num_b):
@app.task
def task_raises():
raise CustomError("The task failed!")
@app.task
def task_returns_baggage():
return dict(baggage.get_all())

View File

@ -15,12 +15,13 @@
import threading
import time
from opentelemetry import baggage, context
from opentelemetry.instrumentation.celery import CeleryInstrumentor
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import SpanKind, StatusCode
from .celery_test_tasks import app, task_add, task_raises
from .celery_test_tasks import app, task_add, task_raises, task_returns_baggage
class TestCeleryInstrumentation(TestBase):
@ -168,6 +169,22 @@ class TestCeleryInstrumentation(TestBase):
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 0)
def test_baggage(self):
CeleryInstrumentor().instrument()
ctx = baggage.set_baggage("key", "value")
context.attach(ctx)
task = task_returns_baggage.delay()
timeout = time.time() + 60 * 1 # 1 minute from now
while not task.ready():
if time.time() > timeout:
break
time.sleep(0.05)
self.assertEqual(task.result, {"key": "value"})
class TestCelerySignatureTask(TestBase):
def setUp(self):

View File

@ -167,8 +167,10 @@ class TestUtils(unittest.TestCase):
# propagate and retrieve a Span
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext))
utils.attach_span(fn_task, task_id, span)
span_after = utils.retrieve_span(fn_task, task_id)
utils.attach_context(fn_task, task_id, span, mock.Mock(), "")
ctx = utils.retrieve_context(fn_task, task_id)
self.assertIsNotNone(ctx)
span_after, _, _ = ctx
self.assertIs(span, span_after)
def test_span_delete(self):
@ -180,17 +182,19 @@ class TestUtils(unittest.TestCase):
# propagate a Span
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext))
utils.attach_span(fn_task, task_id, span)
utils.attach_context(fn_task, task_id, span, mock.Mock(), "")
# delete the Span
utils.detach_span(fn_task, task_id)
self.assertEqual(utils.retrieve_span(fn_task, task_id), (None, None))
utils.detach_context(fn_task, task_id)
self.assertEqual(utils.retrieve_context(fn_task, task_id), None)
def test_optional_task_span_attach(self):
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext))
# assert this is a no-op
self.assertIsNone(utils.attach_span(None, task_id, span))
self.assertIsNone(
utils.attach_context(None, task_id, span, mock.Mock(), "")
)
def test_span_delete_empty(self):
# ensure detach_context doesn't raise an exception if no context is present
@ -201,10 +205,8 @@ class TestUtils(unittest.TestCase):
# delete the Span
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
try:
utils.detach_span(fn_task, task_id)
self.assertEqual(
utils.retrieve_span(fn_task, task_id), (None, None)
)
utils.detach_context(fn_task, task_id)
self.assertEqual(utils.retrieve_context(fn_task, task_id), None)
except Exception as ex: # pylint: disable=broad-except
self.fail(f"Exception was raised: {ex}")