From f6b68d0c024cf40d15c08062a18bf70ea73847e6 Mon Sep 17 00:00:00 2001 From: Riccardo Magliocchetti Date: Tue, 29 Oct 2024 21:33:35 +0100 Subject: [PATCH] httpx: rewrite patching to use wrapt instead of subclassing client (#2909) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit httpx: rewrote patching to use wrapt instead of subclassing client Porting of httpx instrumentation to patch async transport methods instead of substituting the client. That is because the current approach will instrument httpx by instantianting another client with a custom transport class and this will race with code already subclassing. This one uses wrapt to patch the default httpx transport classes. --------- Co-authored-by: Emídio Neto <9735060+emdneto@users.noreply.github.com> --- CHANGELOG.md | 2 + .../pyproject.toml | 1 + .../instrumentation/httpx/__init__.py | 382 +++++++++++++----- .../tests/test_httpx_integration.py | 99 +++-- 4 files changed, 352 insertions(+), 132 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ed4671d55..7597e6064 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#2871](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2871)) - `opentelemetry-instrumentation` Don't fail distro loading if instrumentor raises ImportError, instead skip them ([#2923](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2923)) +- `opentelemetry-instrumentation-httpx` Rewrote instrumentation to use wrapt instead of subclassing + ([#2909](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2909)) ## Version 1.27.0/0.48b0 (2024-08-28) diff --git a/instrumentation/opentelemetry-instrumentation-httpx/pyproject.toml b/instrumentation/opentelemetry-instrumentation-httpx/pyproject.toml index 599091716..c986fac4a 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/pyproject.toml +++ b/instrumentation/opentelemetry-instrumentation-httpx/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "opentelemetry-instrumentation == 0.49b0.dev", "opentelemetry-semantic-conventions == 0.49b0.dev", "opentelemetry-util-http == 0.49b0.dev", + "wrapt >= 1.0.0, < 2.0.0", ] [project.optional-dependencies] diff --git a/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py b/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py index b9b9a31d3..d3a2cecfe 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=too-many-lines """ Usage ----- @@ -194,9 +195,11 @@ API import logging import typing from asyncio import iscoroutinefunction +from functools import partial from types import TracebackType import httpx +from wrapt import wrap_function_wrapper from opentelemetry.instrumentation._semconv import ( _get_schema_url, @@ -217,6 +220,7 @@ from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import ( http_status_to_status_code, is_http_instrumentation_enabled, + unwrap, ) from opentelemetry.propagate import inject from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE @@ -731,44 +735,211 @@ class HTTPXClientInstrumentor(BaseInstrumentor): ``async_request_hook``: Async ``request_hook`` for ``httpx.AsyncClient`` ``async_response_hook``: Async``response_hook`` for ``httpx.AsyncClient`` """ - self._original_client = httpx.Client - self._original_async_client = httpx.AsyncClient + tracer_provider = kwargs.get("tracer_provider") request_hook = kwargs.get("request_hook") response_hook = kwargs.get("response_hook") async_request_hook = kwargs.get("async_request_hook") - async_response_hook = kwargs.get("async_response_hook") - if callable(request_hook): - _InstrumentedClient._request_hook = request_hook - if callable(async_request_hook) and iscoroutinefunction( + async_request_hook = ( async_request_hook - ): - _InstrumentedAsyncClient._request_hook = async_request_hook - if callable(response_hook): - _InstrumentedClient._response_hook = response_hook - if callable(async_response_hook) and iscoroutinefunction( + if iscoroutinefunction(async_request_hook) + else None + ) + async_response_hook = kwargs.get("async_response_hook") + async_response_hook = ( async_response_hook - ): - _InstrumentedAsyncClient._response_hook = async_response_hook - tracer_provider = kwargs.get("tracer_provider") - _InstrumentedClient._tracer_provider = tracer_provider - _InstrumentedAsyncClient._tracer_provider = tracer_provider - # Intentionally using a private attribute here, see: - # https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2538#discussion_r1610603719 - httpx.Client = httpx._api.Client = _InstrumentedClient - httpx.AsyncClient = _InstrumentedAsyncClient + if iscoroutinefunction(async_response_hook) + else None + ) + + _OpenTelemetrySemanticConventionStability._initialize() + sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode( + _OpenTelemetryStabilitySignalType.HTTP, + ) + tracer = get_tracer( + __name__, + instrumenting_library_version=__version__, + tracer_provider=tracer_provider, + schema_url=_get_schema_url(sem_conv_opt_in_mode), + ) + + wrap_function_wrapper( + "httpx", + "HTTPTransport.handle_request", + partial( + self._handle_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + request_hook=request_hook, + response_hook=response_hook, + ), + ) + wrap_function_wrapper( + "httpx", + "AsyncHTTPTransport.handle_async_request", + partial( + self._handle_async_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + async_request_hook=async_request_hook, + async_response_hook=async_response_hook, + ), + ) def _uninstrument(self, **kwargs): - httpx.Client = httpx._api.Client = self._original_client - httpx.AsyncClient = self._original_async_client - _InstrumentedClient._tracer_provider = None - _InstrumentedClient._request_hook = None - _InstrumentedClient._response_hook = None - _InstrumentedAsyncClient._tracer_provider = None - _InstrumentedAsyncClient._request_hook = None - _InstrumentedAsyncClient._response_hook = None + unwrap(httpx.HTTPTransport, "handle_request") + unwrap(httpx.AsyncHTTPTransport, "handle_async_request") @staticmethod + def _handle_request_wrapper( # pylint: disable=too-many-locals + wrapped, + instance, + args, + kwargs, + tracer, + sem_conv_opt_in_mode, + request_hook, + response_hook, + ): + if not is_http_instrumentation_enabled(): + return wrapped(*args, **kwargs) + + method, url, headers, stream, extensions = _extract_parameters( + args, kwargs + ) + method_original = method.decode() + span_name = _get_default_span_name(method_original) + span_attributes = {} + # apply http client response attributes according to semconv + _apply_request_client_attributes_to_span( + span_attributes, + url, + method_original, + sem_conv_opt_in_mode, + ) + + request_info = RequestInfo(method, url, headers, stream, extensions) + + with tracer.start_as_current_span( + span_name, kind=SpanKind.CLIENT, attributes=span_attributes + ) as span: + exception = None + if callable(request_hook): + request_hook(span, request_info) + + _inject_propagation_headers(headers, args, kwargs) + + try: + response = wrapped(*args, **kwargs) + except Exception as exc: # pylint: disable=W0703 + exception = exc + response = getattr(exc, "response", None) + + if isinstance(response, (httpx.Response, tuple)): + status_code, headers, stream, extensions, http_version = ( + _extract_response(response) + ) + + if span.is_recording(): + # apply http client response attributes according to semconv + _apply_response_client_attributes_to_span( + span, + status_code, + http_version, + sem_conv_opt_in_mode, + ) + if callable(response_hook): + response_hook( + span, + request_info, + ResponseInfo(status_code, headers, stream, extensions), + ) + + if exception: + if span.is_recording() and _report_new(sem_conv_opt_in_mode): + span.set_attribute( + ERROR_TYPE, type(exception).__qualname__ + ) + raise exception.with_traceback(exception.__traceback__) + + return response + + @staticmethod + async def _handle_async_request_wrapper( # pylint: disable=too-many-locals + wrapped, + instance, + args, + kwargs, + tracer, + sem_conv_opt_in_mode, + async_request_hook, + async_response_hook, + ): + if not is_http_instrumentation_enabled(): + return await wrapped(*args, **kwargs) + + method, url, headers, stream, extensions = _extract_parameters( + args, kwargs + ) + method_original = method.decode() + span_name = _get_default_span_name(method_original) + span_attributes = {} + # apply http client response attributes according to semconv + _apply_request_client_attributes_to_span( + span_attributes, + url, + method_original, + sem_conv_opt_in_mode, + ) + + request_info = RequestInfo(method, url, headers, stream, extensions) + + with tracer.start_as_current_span( + span_name, kind=SpanKind.CLIENT, attributes=span_attributes + ) as span: + exception = None + if callable(async_request_hook): + await async_request_hook(span, request_info) + + _inject_propagation_headers(headers, args, kwargs) + + try: + response = await wrapped(*args, **kwargs) + except Exception as exc: # pylint: disable=W0703 + exception = exc + response = getattr(exc, "response", None) + + if isinstance(response, (httpx.Response, tuple)): + status_code, headers, stream, extensions, http_version = ( + _extract_response(response) + ) + + if span.is_recording(): + # apply http client response attributes according to semconv + _apply_response_client_attributes_to_span( + span, + status_code, + http_version, + sem_conv_opt_in_mode, + ) + + if callable(async_response_hook): + await async_response_hook( + span, + request_info, + ResponseInfo(status_code, headers, stream, extensions), + ) + + if exception: + if span.is_recording() and _report_new(sem_conv_opt_in_mode): + span.set_attribute( + ERROR_TYPE, type(exception).__qualname__ + ) + raise exception.with_traceback(exception.__traceback__) + + return response + def instrument_client( + self, client: typing.Union[httpx.Client, httpx.AsyncClient], tracer_provider: TracerProvider = None, request_hook: typing.Union[ @@ -788,67 +959,88 @@ class HTTPXClientInstrumentor(BaseInstrumentor): response_hook: A hook that receives the span, request, and response that is called right before the span ends """ - # pylint: disable=protected-access - if not hasattr(client, "_is_instrumented_by_opentelemetry"): - client._is_instrumented_by_opentelemetry = False - if not client._is_instrumented_by_opentelemetry: - if isinstance(client, httpx.Client): - client._original_transport = client._transport - client._original_mounts = client._mounts.copy() - transport = client._transport or httpx.HTTPTransport() - client._transport = SyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, - request_hook=request_hook, - response_hook=response_hook, - ) - client._is_instrumented_by_opentelemetry = True - client._mounts.update( - { - url_pattern: ( - SyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, - request_hook=request_hook, - response_hook=response_hook, - ) - if transport is not None - else transport - ) - for url_pattern, transport in client._original_mounts.items() - } - ) - - if isinstance(client, httpx.AsyncClient): - transport = client._transport or httpx.AsyncHTTPTransport() - client._original_mounts = client._mounts.copy() - client._transport = AsyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, - request_hook=request_hook, - response_hook=response_hook, - ) - client._is_instrumented_by_opentelemetry = True - client._mounts.update( - { - url_pattern: ( - AsyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, - request_hook=request_hook, - response_hook=response_hook, - ) - if transport is not None - else transport - ) - for url_pattern, transport in client._original_mounts.items() - } - ) - else: + if getattr(client, "_is_instrumented_by_opentelemetry", False): _logger.warning( "Attempting to instrument Httpx client while already instrumented" ) + return + + _OpenTelemetrySemanticConventionStability._initialize() + sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode( + _OpenTelemetryStabilitySignalType.HTTP, + ) + tracer = get_tracer( + __name__, + instrumenting_library_version=__version__, + tracer_provider=tracer_provider, + schema_url=_get_schema_url(sem_conv_opt_in_mode), + ) + + if iscoroutinefunction(request_hook): + async_request_hook = request_hook + request_hook = None + else: + # request_hook already set + async_request_hook = None + + if iscoroutinefunction(response_hook): + async_response_hook = response_hook + response_hook = None + else: + # response_hook already set + async_response_hook = None + + if hasattr(client._transport, "handle_request"): + wrap_function_wrapper( + client._transport, + "handle_request", + partial( + self._handle_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + request_hook=request_hook, + response_hook=response_hook, + ), + ) + for transport in client._mounts.values(): + wrap_function_wrapper( + transport, + "handle_request", + partial( + self._handle_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + request_hook=request_hook, + response_hook=response_hook, + ), + ) + client._is_instrumented_by_opentelemetry = True + if hasattr(client._transport, "handle_async_request"): + wrap_function_wrapper( + client._transport, + "handle_async_request", + partial( + self._handle_async_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + async_request_hook=async_request_hook, + async_response_hook=async_response_hook, + ), + ) + for transport in client._mounts.values(): + wrap_function_wrapper( + transport, + "handle_async_request", + partial( + self._handle_async_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + async_request_hook=async_request_hook, + async_response_hook=async_response_hook, + ), + ) + client._is_instrumented_by_opentelemetry = True @staticmethod def uninstrument_client( @@ -859,15 +1051,13 @@ class HTTPXClientInstrumentor(BaseInstrumentor): Args: client: The httpx Client or AsyncClient instance """ - if hasattr(client, "_original_transport"): - client._transport = client._original_transport - del client._original_transport + if hasattr(client._transport, "handle_request"): + unwrap(client._transport, "handle_request") + for transport in client._mounts.values(): + unwrap(transport, "handle_request") + client._is_instrumented_by_opentelemetry = False + elif hasattr(client._transport, "handle_async_request"): + unwrap(client._transport, "handle_async_request") + for transport in client._mounts.values(): + unwrap(transport, "handle_async_request") client._is_instrumented_by_opentelemetry = False - if hasattr(client, "_original_mounts"): - client._mounts = client._original_mounts.copy() - del client._original_mounts - else: - _logger.warning( - "Attempting to uninstrument Httpx " - "client while already uninstrumented" - ) diff --git a/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py b/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py index 0d055515e..07699700c 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py +++ b/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py @@ -21,6 +21,7 @@ from unittest import mock import httpx import respx +from wrapt import ObjectProxy import opentelemetry.instrumentation.httpx from opentelemetry import trace @@ -171,6 +172,7 @@ class BaseTestCases: super().tearDown() self.env_patch.stop() respx.stop() + HTTPXClientInstrumentor().uninstrument() def assert_span( self, exporter: "SpanExporter" = None, num_spans: int = 1 @@ -204,7 +206,7 @@ class BaseTestCases: self.assertEqual(span.name, "GET") self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -228,7 +230,7 @@ class BaseTestCases: self.assertIs(span.kind, trace.SpanKind.CLIENT) self.assertEqual(span.name, "HTTP") self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "_OTHER", SpanAttributes.HTTP_URL: self.URL, @@ -252,7 +254,7 @@ class BaseTestCases: self.assertIs(span.kind, trace.SpanKind.CLIENT) self.assertEqual(span.name, "HTTP") self.assertEqual( - span.attributes, + dict(span.attributes), { HTTP_REQUEST_METHOD: "_OTHER", URL_FULL: self.URL, @@ -292,7 +294,7 @@ class BaseTestCases: SpanAttributes.SCHEMA_URL, ) self.assertEqual( - span.attributes, + dict(span.attributes), { HTTP_REQUEST_METHOD: "GET", URL_FULL: url, @@ -327,7 +329,7 @@ class BaseTestCases: ) self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", HTTP_REQUEST_METHOD: "GET", @@ -454,7 +456,7 @@ class BaseTestCases: span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -510,7 +512,7 @@ class BaseTestCases: span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { HTTP_REQUEST_METHOD: "GET", URL_FULL: url, @@ -531,7 +533,7 @@ class BaseTestCases: span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", HTTP_REQUEST_METHOD: "GET", @@ -632,7 +634,7 @@ class BaseTestCases: self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -741,8 +743,10 @@ class BaseTestCases: def setUp(self): super().setUp() - HTTPXClientInstrumentor().instrument() self.client = self.create_client() + HTTPXClientInstrumentor().instrument_client(self.client) + + def tearDown(self): HTTPXClientInstrumentor().uninstrument() def create_proxy_mounts(self): @@ -755,14 +759,25 @@ class BaseTestCases: ), } - def assert_proxy_mounts(self, mounts, num_mounts, transport_type): + def assert_proxy_mounts(self, mounts, num_mounts, transport_type=None): self.assertEqual(len(mounts), num_mounts) for transport in mounts: with self.subTest(transport): - self.assertIsInstance( - transport, - transport_type, - ) + if transport_type: + self.assertIsInstance( + transport, + transport_type, + ) + else: + handler = getattr(transport, "handle_request", None) + if not handler: + handler = getattr( + transport, "handle_async_request" + ) + self.assertTrue( + isinstance(handler, ObjectProxy) + and getattr(handler, "__wrapped__") + ) def test_custom_tracer_provider(self): resource = resources.Resource.create({}) @@ -778,7 +793,6 @@ class BaseTestCases: self.assertEqual(result.text, "Hello!") span = self.assert_span(exporter=exporter) self.assertIs(span.resource, resource) - HTTPXClientInstrumentor().uninstrument() def test_response_hook(self): response_hook_key = ( @@ -797,7 +811,7 @@ class BaseTestCases: self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -805,7 +819,6 @@ class BaseTestCases: HTTP_RESPONSE_BODY: "Hello!", }, ) - HTTPXClientInstrumentor().uninstrument() def test_response_hook_sync_async_kwargs(self): HTTPXClientInstrumentor().instrument( @@ -819,7 +832,7 @@ class BaseTestCases: self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -827,7 +840,6 @@ class BaseTestCases: HTTP_RESPONSE_BODY: "Hello!", }, ) - HTTPXClientInstrumentor().uninstrument() def test_request_hook(self): request_hook_key = ( @@ -846,7 +858,6 @@ class BaseTestCases: self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET" + self.URL) - HTTPXClientInstrumentor().uninstrument() def test_request_hook_sync_async_kwargs(self): HTTPXClientInstrumentor().instrument( @@ -860,7 +871,6 @@ class BaseTestCases: self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET" + self.URL) - HTTPXClientInstrumentor().uninstrument() def test_request_hook_no_span_update(self): HTTPXClientInstrumentor().instrument( @@ -873,7 +883,6 @@ class BaseTestCases: self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET") - HTTPXClientInstrumentor().uninstrument() def test_not_recording(self): with mock.patch("opentelemetry.trace.INVALID_SPAN") as mock_span: @@ -891,7 +900,6 @@ class BaseTestCases: self.assertTrue(mock_span.is_recording.called) self.assertFalse(mock_span.set_attribute.called) self.assertFalse(mock_span.set_status.called) - HTTPXClientInstrumentor().uninstrument() def test_suppress_instrumentation_new_client(self): HTTPXClientInstrumentor().instrument() @@ -901,7 +909,6 @@ class BaseTestCases: self.assertEqual(result.text, "Hello!") self.assert_span(num_spans=0) - HTTPXClientInstrumentor().uninstrument() def test_instrument_client(self): client = self.create_client() @@ -929,8 +936,6 @@ class BaseTestCases: self.URL, ) - HTTPXClientInstrumentor().uninstrument() - def test_uninstrument(self): HTTPXClientInstrumentor().instrument() HTTPXClientInstrumentor().uninstrument() @@ -980,9 +985,7 @@ class BaseTestCases: self.assert_proxy_mounts( client._mounts.values(), 2, - (SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport), ) - HTTPXClientInstrumentor().uninstrument() def test_instrument_client_with_proxy(self): proxy_mounts = self.create_proxy_mounts() @@ -999,7 +1002,6 @@ class BaseTestCases: self.assert_proxy_mounts( client._mounts.values(), 2, - (SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport), ) HTTPXClientInstrumentor().uninstrument_client(client) @@ -1010,7 +1012,6 @@ class BaseTestCases: self.assert_proxy_mounts( client._mounts.values(), 2, - (SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport), ) HTTPXClientInstrumentor().uninstrument_client(client) @@ -1180,6 +1181,21 @@ class TestSyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest): def create_proxy_transport(self, url): return httpx.HTTPTransport(proxy=httpx.Proxy(url)) + def test_can_instrument_subclassed_client(self): + class CustomClient(httpx.Client): + pass + + client = CustomClient() + self.assertFalse( + isinstance(client._transport.handle_request, ObjectProxy) + ) + + HTTPXClientInstrumentor().instrument() + + self.assertTrue( + isinstance(client._transport.handle_request, ObjectProxy) + ) + class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest): response_hook = staticmethod(_async_response_hook) @@ -1188,10 +1204,8 @@ class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest): def setUp(self): super().setUp() - HTTPXClientInstrumentor().instrument() - self.client = self.create_client() self.client2 = self.create_client() - HTTPXClientInstrumentor().uninstrument() + HTTPXClientInstrumentor().instrument_client(self.client2) def create_client( self, @@ -1245,7 +1259,6 @@ class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest): SpanAttributes.HTTP_STATUS_CODE: 200, }, ) - HTTPXClientInstrumentor().uninstrument() def test_async_request_hook_does_nothing_if_not_coroutine(self): HTTPXClientInstrumentor().instrument( @@ -1258,4 +1271,18 @@ class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET") - HTTPXClientInstrumentor().uninstrument() + + def test_can_instrument_subclassed_async_client(self): + class CustomAsyncClient(httpx.AsyncClient): + pass + + client = CustomAsyncClient() + self.assertFalse( + isinstance(client._transport.handle_async_request, ObjectProxy) + ) + + HTTPXClientInstrumentor().instrument() + + self.assertTrue( + isinstance(client._transport.handle_async_request, ObjectProxy) + )