diff --git a/CHANGELOG.md b/CHANGELOG.md index 028b4ee63..007954745 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -80,6 +80,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#2753](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2753)) - `opentelemetry-instrumentation-grpc` Fix grpc supported version ([#2845](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2845)) +- `opentelemetry-instrumentation-asyncio` fix `AttributeError` in + `AsyncioInstrumentor.trace_to_thread` when `func` is a `functools.partial` instance + ([#2911](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2911)) ## Version 1.26.0/0.47b0 (2024-07-23) diff --git a/instrumentation/opentelemetry-instrumentation-asyncio/src/opentelemetry/instrumentation/asyncio/__init__.py b/instrumentation/opentelemetry-instrumentation-asyncio/src/opentelemetry/instrumentation/asyncio/__init__.py index a6cc6b044..2d1b063df 100644 --- a/instrumentation/opentelemetry-instrumentation-asyncio/src/opentelemetry/instrumentation/asyncio/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asyncio/src/opentelemetry/instrumentation/asyncio/__init__.py @@ -78,6 +78,7 @@ API """ import asyncio +import functools import sys from asyncio import futures from timeit import default_timer @@ -231,14 +232,15 @@ class AsyncioInstrumentor(BaseInstrumentor): def trace_to_thread(self, func: callable): """Trace a function.""" start = default_timer() + func_name = getattr(func, "__name__", None) + if func_name is None and isinstance(func, functools.partial): + func_name = func.func.__name__ span = ( - self._tracer.start_span( - f"{ASYNCIO_PREFIX} to_thread-" + func.__name__ - ) - if func.__name__ in self._to_thread_name_to_trace + self._tracer.start_span(f"{ASYNCIO_PREFIX} to_thread-" + func_name) + if func_name in self._to_thread_name_to_trace else None ) - attr = {"type": "to_thread", "name": func.__name__} + attr = {"type": "to_thread", "name": func_name} exception = None try: attr["state"] = "finished" diff --git a/instrumentation/opentelemetry-instrumentation-asyncio/tests/test_asyncio_to_thread.py b/instrumentation/opentelemetry-instrumentation-asyncio/tests/test_asyncio_to_thread.py index 3d795d8ae..35191d3d0 100644 --- a/instrumentation/opentelemetry-instrumentation-asyncio/tests/test_asyncio_to_thread.py +++ b/instrumentation/opentelemetry-instrumentation-asyncio/tests/test_asyncio_to_thread.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio +import functools import sys from unittest import skipIf from unittest.mock import patch @@ -72,3 +73,37 @@ class TestAsyncioToThread(TestBase): for point in metric.data.data_points: self.assertEqual(point.attributes["type"], "to_thread") self.assertEqual(point.attributes["name"], "multiply") + + @skipIf( + sys.version_info < (3, 9), "to_thread is only available in Python 3.9+" + ) + def test_to_thread_partial_func(self): + def multiply(x, y): + return x * y + + double = functools.partial(multiply, 2) + + async def to_thread(): + result = await asyncio.to_thread(double, 3) + assert result == 6 + + with self._tracer.start_as_current_span("root"): + asyncio.run(to_thread()) + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 2) + assert spans[0].name == "asyncio to_thread-multiply" + for metric in ( + self.memory_metrics_reader.get_metrics_data() + .resource_metrics[0] + .scope_metrics[0] + .metrics + ): + if metric.name == "asyncio.process.duration": + for point in metric.data.data_points: + self.assertEqual(point.attributes["type"], "to_thread") + self.assertEqual(point.attributes["name"], "multiply") + if metric.name == "asyncio.process.created": + for point in metric.data.data_points: + self.assertEqual(point.attributes["type"], "to_thread") + self.assertEqual(point.attributes["name"], "multiply")