httpx: rewrite patching to use wrapt instead of subclassing client (#2909)

httpx: rewrote patching to use wrapt instead of subclassing client

Porting of httpx instrumentation to patch async transport methods instead of substituting the client. That is because the current approach will instrument httpx by instantianting another client with a custom transport class and this will race with code already subclassing. This one uses wrapt to patch the default httpx transport classes.

---------

Co-authored-by: Emídio Neto <9735060+emdneto@users.noreply.github.com>
This commit is contained in:
Riccardo Magliocchetti
2024-10-29 21:33:35 +01:00
committed by GitHub
parent 7cbe58691a
commit f6b68d0c02
4 changed files with 352 additions and 132 deletions

View File

@ -41,6 +41,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#2871](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2871)) ([#2871](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2871))
- `opentelemetry-instrumentation` Don't fail distro loading if instrumentor raises ImportError, instead skip them - `opentelemetry-instrumentation` Don't fail distro loading if instrumentor raises ImportError, instead skip them
([#2923](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2923)) ([#2923](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2923))
- `opentelemetry-instrumentation-httpx` Rewrote instrumentation to use wrapt instead of subclassing
([#2909](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2909))
## Version 1.27.0/0.48b0 (2024-08-28) ## Version 1.27.0/0.48b0 (2024-08-28)

View File

@ -29,6 +29,7 @@ dependencies = [
"opentelemetry-instrumentation == 0.49b0.dev", "opentelemetry-instrumentation == 0.49b0.dev",
"opentelemetry-semantic-conventions == 0.49b0.dev", "opentelemetry-semantic-conventions == 0.49b0.dev",
"opentelemetry-util-http == 0.49b0.dev", "opentelemetry-util-http == 0.49b0.dev",
"wrapt >= 1.0.0, < 2.0.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# pylint: disable=too-many-lines
""" """
Usage Usage
----- -----
@ -194,9 +195,11 @@ API
import logging import logging
import typing import typing
from asyncio import iscoroutinefunction from asyncio import iscoroutinefunction
from functools import partial
from types import TracebackType from types import TracebackType
import httpx import httpx
from wrapt import wrap_function_wrapper
from opentelemetry.instrumentation._semconv import ( from opentelemetry.instrumentation._semconv import (
_get_schema_url, _get_schema_url,
@ -217,6 +220,7 @@ from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import ( from opentelemetry.instrumentation.utils import (
http_status_to_status_code, http_status_to_status_code,
is_http_instrumentation_enabled, is_http_instrumentation_enabled,
unwrap,
) )
from opentelemetry.propagate import inject from opentelemetry.propagate import inject
from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
@ -731,44 +735,211 @@ class HTTPXClientInstrumentor(BaseInstrumentor):
``async_request_hook``: Async ``request_hook`` for ``httpx.AsyncClient`` ``async_request_hook``: Async ``request_hook`` for ``httpx.AsyncClient``
``async_response_hook``: Async``response_hook`` for ``httpx.AsyncClient`` ``async_response_hook``: Async``response_hook`` for ``httpx.AsyncClient``
""" """
self._original_client = httpx.Client tracer_provider = kwargs.get("tracer_provider")
self._original_async_client = httpx.AsyncClient
request_hook = kwargs.get("request_hook") request_hook = kwargs.get("request_hook")
response_hook = kwargs.get("response_hook") response_hook = kwargs.get("response_hook")
async_request_hook = kwargs.get("async_request_hook") async_request_hook = kwargs.get("async_request_hook")
async_response_hook = kwargs.get("async_response_hook") async_request_hook = (
if callable(request_hook):
_InstrumentedClient._request_hook = request_hook
if callable(async_request_hook) and iscoroutinefunction(
async_request_hook async_request_hook
): if iscoroutinefunction(async_request_hook)
_InstrumentedAsyncClient._request_hook = async_request_hook else None
if callable(response_hook): )
_InstrumentedClient._response_hook = response_hook async_response_hook = kwargs.get("async_response_hook")
if callable(async_response_hook) and iscoroutinefunction( async_response_hook = (
async_response_hook async_response_hook
): if iscoroutinefunction(async_response_hook)
_InstrumentedAsyncClient._response_hook = async_response_hook else None
tracer_provider = kwargs.get("tracer_provider") )
_InstrumentedClient._tracer_provider = tracer_provider
_InstrumentedAsyncClient._tracer_provider = tracer_provider _OpenTelemetrySemanticConventionStability._initialize()
# Intentionally using a private attribute here, see: sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode(
# https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2538#discussion_r1610603719 _OpenTelemetryStabilitySignalType.HTTP,
httpx.Client = httpx._api.Client = _InstrumentedClient )
httpx.AsyncClient = _InstrumentedAsyncClient tracer = get_tracer(
__name__,
instrumenting_library_version=__version__,
tracer_provider=tracer_provider,
schema_url=_get_schema_url(sem_conv_opt_in_mode),
)
wrap_function_wrapper(
"httpx",
"HTTPTransport.handle_request",
partial(
self._handle_request_wrapper,
tracer=tracer,
sem_conv_opt_in_mode=sem_conv_opt_in_mode,
request_hook=request_hook,
response_hook=response_hook,
),
)
wrap_function_wrapper(
"httpx",
"AsyncHTTPTransport.handle_async_request",
partial(
self._handle_async_request_wrapper,
tracer=tracer,
sem_conv_opt_in_mode=sem_conv_opt_in_mode,
async_request_hook=async_request_hook,
async_response_hook=async_response_hook,
),
)
def _uninstrument(self, **kwargs): def _uninstrument(self, **kwargs):
httpx.Client = httpx._api.Client = self._original_client unwrap(httpx.HTTPTransport, "handle_request")
httpx.AsyncClient = self._original_async_client unwrap(httpx.AsyncHTTPTransport, "handle_async_request")
_InstrumentedClient._tracer_provider = None
_InstrumentedClient._request_hook = None
_InstrumentedClient._response_hook = None
_InstrumentedAsyncClient._tracer_provider = None
_InstrumentedAsyncClient._request_hook = None
_InstrumentedAsyncClient._response_hook = None
@staticmethod @staticmethod
def _handle_request_wrapper( # pylint: disable=too-many-locals
wrapped,
instance,
args,
kwargs,
tracer,
sem_conv_opt_in_mode,
request_hook,
response_hook,
):
if not is_http_instrumentation_enabled():
return wrapped(*args, **kwargs)
method, url, headers, stream, extensions = _extract_parameters(
args, kwargs
)
method_original = method.decode()
span_name = _get_default_span_name(method_original)
span_attributes = {}
# apply http client response attributes according to semconv
_apply_request_client_attributes_to_span(
span_attributes,
url,
method_original,
sem_conv_opt_in_mode,
)
request_info = RequestInfo(method, url, headers, stream, extensions)
with tracer.start_as_current_span(
span_name, kind=SpanKind.CLIENT, attributes=span_attributes
) as span:
exception = None
if callable(request_hook):
request_hook(span, request_info)
_inject_propagation_headers(headers, args, kwargs)
try:
response = wrapped(*args, **kwargs)
except Exception as exc: # pylint: disable=W0703
exception = exc
response = getattr(exc, "response", None)
if isinstance(response, (httpx.Response, tuple)):
status_code, headers, stream, extensions, http_version = (
_extract_response(response)
)
if span.is_recording():
# apply http client response attributes according to semconv
_apply_response_client_attributes_to_span(
span,
status_code,
http_version,
sem_conv_opt_in_mode,
)
if callable(response_hook):
response_hook(
span,
request_info,
ResponseInfo(status_code, headers, stream, extensions),
)
if exception:
if span.is_recording() and _report_new(sem_conv_opt_in_mode):
span.set_attribute(
ERROR_TYPE, type(exception).__qualname__
)
raise exception.with_traceback(exception.__traceback__)
return response
@staticmethod
async def _handle_async_request_wrapper( # pylint: disable=too-many-locals
wrapped,
instance,
args,
kwargs,
tracer,
sem_conv_opt_in_mode,
async_request_hook,
async_response_hook,
):
if not is_http_instrumentation_enabled():
return await wrapped(*args, **kwargs)
method, url, headers, stream, extensions = _extract_parameters(
args, kwargs
)
method_original = method.decode()
span_name = _get_default_span_name(method_original)
span_attributes = {}
# apply http client response attributes according to semconv
_apply_request_client_attributes_to_span(
span_attributes,
url,
method_original,
sem_conv_opt_in_mode,
)
request_info = RequestInfo(method, url, headers, stream, extensions)
with tracer.start_as_current_span(
span_name, kind=SpanKind.CLIENT, attributes=span_attributes
) as span:
exception = None
if callable(async_request_hook):
await async_request_hook(span, request_info)
_inject_propagation_headers(headers, args, kwargs)
try:
response = await wrapped(*args, **kwargs)
except Exception as exc: # pylint: disable=W0703
exception = exc
response = getattr(exc, "response", None)
if isinstance(response, (httpx.Response, tuple)):
status_code, headers, stream, extensions, http_version = (
_extract_response(response)
)
if span.is_recording():
# apply http client response attributes according to semconv
_apply_response_client_attributes_to_span(
span,
status_code,
http_version,
sem_conv_opt_in_mode,
)
if callable(async_response_hook):
await async_response_hook(
span,
request_info,
ResponseInfo(status_code, headers, stream, extensions),
)
if exception:
if span.is_recording() and _report_new(sem_conv_opt_in_mode):
span.set_attribute(
ERROR_TYPE, type(exception).__qualname__
)
raise exception.with_traceback(exception.__traceback__)
return response
def instrument_client( def instrument_client(
self,
client: typing.Union[httpx.Client, httpx.AsyncClient], client: typing.Union[httpx.Client, httpx.AsyncClient],
tracer_provider: TracerProvider = None, tracer_provider: TracerProvider = None,
request_hook: typing.Union[ request_hook: typing.Union[
@ -788,67 +959,88 @@ class HTTPXClientInstrumentor(BaseInstrumentor):
response_hook: A hook that receives the span, request, and response response_hook: A hook that receives the span, request, and response
that is called right before the span ends that is called right before the span ends
""" """
# pylint: disable=protected-access
if not hasattr(client, "_is_instrumented_by_opentelemetry"):
client._is_instrumented_by_opentelemetry = False
if not client._is_instrumented_by_opentelemetry: if getattr(client, "_is_instrumented_by_opentelemetry", False):
if isinstance(client, httpx.Client):
client._original_transport = client._transport
client._original_mounts = client._mounts.copy()
transport = client._transport or httpx.HTTPTransport()
client._transport = SyncOpenTelemetryTransport(
transport,
tracer_provider=tracer_provider,
request_hook=request_hook,
response_hook=response_hook,
)
client._is_instrumented_by_opentelemetry = True
client._mounts.update(
{
url_pattern: (
SyncOpenTelemetryTransport(
transport,
tracer_provider=tracer_provider,
request_hook=request_hook,
response_hook=response_hook,
)
if transport is not None
else transport
)
for url_pattern, transport in client._original_mounts.items()
}
)
if isinstance(client, httpx.AsyncClient):
transport = client._transport or httpx.AsyncHTTPTransport()
client._original_mounts = client._mounts.copy()
client._transport = AsyncOpenTelemetryTransport(
transport,
tracer_provider=tracer_provider,
request_hook=request_hook,
response_hook=response_hook,
)
client._is_instrumented_by_opentelemetry = True
client._mounts.update(
{
url_pattern: (
AsyncOpenTelemetryTransport(
transport,
tracer_provider=tracer_provider,
request_hook=request_hook,
response_hook=response_hook,
)
if transport is not None
else transport
)
for url_pattern, transport in client._original_mounts.items()
}
)
else:
_logger.warning( _logger.warning(
"Attempting to instrument Httpx client while already instrumented" "Attempting to instrument Httpx client while already instrumented"
) )
return
_OpenTelemetrySemanticConventionStability._initialize()
sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode(
_OpenTelemetryStabilitySignalType.HTTP,
)
tracer = get_tracer(
__name__,
instrumenting_library_version=__version__,
tracer_provider=tracer_provider,
schema_url=_get_schema_url(sem_conv_opt_in_mode),
)
if iscoroutinefunction(request_hook):
async_request_hook = request_hook
request_hook = None
else:
# request_hook already set
async_request_hook = None
if iscoroutinefunction(response_hook):
async_response_hook = response_hook
response_hook = None
else:
# response_hook already set
async_response_hook = None
if hasattr(client._transport, "handle_request"):
wrap_function_wrapper(
client._transport,
"handle_request",
partial(
self._handle_request_wrapper,
tracer=tracer,
sem_conv_opt_in_mode=sem_conv_opt_in_mode,
request_hook=request_hook,
response_hook=response_hook,
),
)
for transport in client._mounts.values():
wrap_function_wrapper(
transport,
"handle_request",
partial(
self._handle_request_wrapper,
tracer=tracer,
sem_conv_opt_in_mode=sem_conv_opt_in_mode,
request_hook=request_hook,
response_hook=response_hook,
),
)
client._is_instrumented_by_opentelemetry = True
if hasattr(client._transport, "handle_async_request"):
wrap_function_wrapper(
client._transport,
"handle_async_request",
partial(
self._handle_async_request_wrapper,
tracer=tracer,
sem_conv_opt_in_mode=sem_conv_opt_in_mode,
async_request_hook=async_request_hook,
async_response_hook=async_response_hook,
),
)
for transport in client._mounts.values():
wrap_function_wrapper(
transport,
"handle_async_request",
partial(
self._handle_async_request_wrapper,
tracer=tracer,
sem_conv_opt_in_mode=sem_conv_opt_in_mode,
async_request_hook=async_request_hook,
async_response_hook=async_response_hook,
),
)
client._is_instrumented_by_opentelemetry = True
@staticmethod @staticmethod
def uninstrument_client( def uninstrument_client(
@ -859,15 +1051,13 @@ class HTTPXClientInstrumentor(BaseInstrumentor):
Args: Args:
client: The httpx Client or AsyncClient instance client: The httpx Client or AsyncClient instance
""" """
if hasattr(client, "_original_transport"): if hasattr(client._transport, "handle_request"):
client._transport = client._original_transport unwrap(client._transport, "handle_request")
del client._original_transport for transport in client._mounts.values():
unwrap(transport, "handle_request")
client._is_instrumented_by_opentelemetry = False
elif hasattr(client._transport, "handle_async_request"):
unwrap(client._transport, "handle_async_request")
for transport in client._mounts.values():
unwrap(transport, "handle_async_request")
client._is_instrumented_by_opentelemetry = False client._is_instrumented_by_opentelemetry = False
if hasattr(client, "_original_mounts"):
client._mounts = client._original_mounts.copy()
del client._original_mounts
else:
_logger.warning(
"Attempting to uninstrument Httpx "
"client while already uninstrumented"
)

View File

@ -21,6 +21,7 @@ from unittest import mock
import httpx import httpx
import respx import respx
from wrapt import ObjectProxy
import opentelemetry.instrumentation.httpx import opentelemetry.instrumentation.httpx
from opentelemetry import trace from opentelemetry import trace
@ -171,6 +172,7 @@ class BaseTestCases:
super().tearDown() super().tearDown()
self.env_patch.stop() self.env_patch.stop()
respx.stop() respx.stop()
HTTPXClientInstrumentor().uninstrument()
def assert_span( def assert_span(
self, exporter: "SpanExporter" = None, num_spans: int = 1 self, exporter: "SpanExporter" = None, num_spans: int = 1
@ -204,7 +206,7 @@ class BaseTestCases:
self.assertEqual(span.name, "GET") self.assertEqual(span.name, "GET")
self.assertEqual( self.assertEqual(
span.attributes, dict(span.attributes),
{ {
SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_METHOD: "GET",
SpanAttributes.HTTP_URL: self.URL, SpanAttributes.HTTP_URL: self.URL,
@ -228,7 +230,7 @@ class BaseTestCases:
self.assertIs(span.kind, trace.SpanKind.CLIENT) self.assertIs(span.kind, trace.SpanKind.CLIENT)
self.assertEqual(span.name, "HTTP") self.assertEqual(span.name, "HTTP")
self.assertEqual( self.assertEqual(
span.attributes, dict(span.attributes),
{ {
SpanAttributes.HTTP_METHOD: "_OTHER", SpanAttributes.HTTP_METHOD: "_OTHER",
SpanAttributes.HTTP_URL: self.URL, SpanAttributes.HTTP_URL: self.URL,
@ -252,7 +254,7 @@ class BaseTestCases:
self.assertIs(span.kind, trace.SpanKind.CLIENT) self.assertIs(span.kind, trace.SpanKind.CLIENT)
self.assertEqual(span.name, "HTTP") self.assertEqual(span.name, "HTTP")
self.assertEqual( self.assertEqual(
span.attributes, dict(span.attributes),
{ {
HTTP_REQUEST_METHOD: "_OTHER", HTTP_REQUEST_METHOD: "_OTHER",
URL_FULL: self.URL, URL_FULL: self.URL,
@ -292,7 +294,7 @@ class BaseTestCases:
SpanAttributes.SCHEMA_URL, SpanAttributes.SCHEMA_URL,
) )
self.assertEqual( self.assertEqual(
span.attributes, dict(span.attributes),
{ {
HTTP_REQUEST_METHOD: "GET", HTTP_REQUEST_METHOD: "GET",
URL_FULL: url, URL_FULL: url,
@ -327,7 +329,7 @@ class BaseTestCases:
) )
self.assertEqual( self.assertEqual(
span.attributes, dict(span.attributes),
{ {
SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_METHOD: "GET",
HTTP_REQUEST_METHOD: "GET", HTTP_REQUEST_METHOD: "GET",
@ -454,7 +456,7 @@ class BaseTestCases:
span = self.assert_span() span = self.assert_span()
self.assertEqual( self.assertEqual(
span.attributes, dict(span.attributes),
{ {
SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_METHOD: "GET",
SpanAttributes.HTTP_URL: self.URL, SpanAttributes.HTTP_URL: self.URL,
@ -510,7 +512,7 @@ class BaseTestCases:
span = self.assert_span() span = self.assert_span()
self.assertEqual( self.assertEqual(
span.attributes, dict(span.attributes),
{ {
HTTP_REQUEST_METHOD: "GET", HTTP_REQUEST_METHOD: "GET",
URL_FULL: url, URL_FULL: url,
@ -531,7 +533,7 @@ class BaseTestCases:
span = self.assert_span() span = self.assert_span()
self.assertEqual( self.assertEqual(
span.attributes, dict(span.attributes),
{ {
SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_METHOD: "GET",
HTTP_REQUEST_METHOD: "GET", HTTP_REQUEST_METHOD: "GET",
@ -632,7 +634,7 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!") self.assertEqual(result.text, "Hello!")
span = self.assert_span() span = self.assert_span()
self.assertEqual( self.assertEqual(
span.attributes, dict(span.attributes),
{ {
SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_METHOD: "GET",
SpanAttributes.HTTP_URL: self.URL, SpanAttributes.HTTP_URL: self.URL,
@ -741,8 +743,10 @@ class BaseTestCases:
def setUp(self): def setUp(self):
super().setUp() super().setUp()
HTTPXClientInstrumentor().instrument()
self.client = self.create_client() self.client = self.create_client()
HTTPXClientInstrumentor().instrument_client(self.client)
def tearDown(self):
HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().uninstrument()
def create_proxy_mounts(self): def create_proxy_mounts(self):
@ -755,14 +759,25 @@ class BaseTestCases:
), ),
} }
def assert_proxy_mounts(self, mounts, num_mounts, transport_type): def assert_proxy_mounts(self, mounts, num_mounts, transport_type=None):
self.assertEqual(len(mounts), num_mounts) self.assertEqual(len(mounts), num_mounts)
for transport in mounts: for transport in mounts:
with self.subTest(transport): with self.subTest(transport):
if transport_type:
self.assertIsInstance( self.assertIsInstance(
transport, transport,
transport_type, transport_type,
) )
else:
handler = getattr(transport, "handle_request", None)
if not handler:
handler = getattr(
transport, "handle_async_request"
)
self.assertTrue(
isinstance(handler, ObjectProxy)
and getattr(handler, "__wrapped__")
)
def test_custom_tracer_provider(self): def test_custom_tracer_provider(self):
resource = resources.Resource.create({}) resource = resources.Resource.create({})
@ -778,7 +793,6 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!") self.assertEqual(result.text, "Hello!")
span = self.assert_span(exporter=exporter) span = self.assert_span(exporter=exporter)
self.assertIs(span.resource, resource) self.assertIs(span.resource, resource)
HTTPXClientInstrumentor().uninstrument()
def test_response_hook(self): def test_response_hook(self):
response_hook_key = ( response_hook_key = (
@ -797,7 +811,7 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!") self.assertEqual(result.text, "Hello!")
span = self.assert_span() span = self.assert_span()
self.assertEqual( self.assertEqual(
span.attributes, dict(span.attributes),
{ {
SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_METHOD: "GET",
SpanAttributes.HTTP_URL: self.URL, SpanAttributes.HTTP_URL: self.URL,
@ -805,7 +819,6 @@ class BaseTestCases:
HTTP_RESPONSE_BODY: "Hello!", HTTP_RESPONSE_BODY: "Hello!",
}, },
) )
HTTPXClientInstrumentor().uninstrument()
def test_response_hook_sync_async_kwargs(self): def test_response_hook_sync_async_kwargs(self):
HTTPXClientInstrumentor().instrument( HTTPXClientInstrumentor().instrument(
@ -819,7 +832,7 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!") self.assertEqual(result.text, "Hello!")
span = self.assert_span() span = self.assert_span()
self.assertEqual( self.assertEqual(
span.attributes, dict(span.attributes),
{ {
SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_METHOD: "GET",
SpanAttributes.HTTP_URL: self.URL, SpanAttributes.HTTP_URL: self.URL,
@ -827,7 +840,6 @@ class BaseTestCases:
HTTP_RESPONSE_BODY: "Hello!", HTTP_RESPONSE_BODY: "Hello!",
}, },
) )
HTTPXClientInstrumentor().uninstrument()
def test_request_hook(self): def test_request_hook(self):
request_hook_key = ( request_hook_key = (
@ -846,7 +858,6 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!") self.assertEqual(result.text, "Hello!")
span = self.assert_span() span = self.assert_span()
self.assertEqual(span.name, "GET" + self.URL) self.assertEqual(span.name, "GET" + self.URL)
HTTPXClientInstrumentor().uninstrument()
def test_request_hook_sync_async_kwargs(self): def test_request_hook_sync_async_kwargs(self):
HTTPXClientInstrumentor().instrument( HTTPXClientInstrumentor().instrument(
@ -860,7 +871,6 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!") self.assertEqual(result.text, "Hello!")
span = self.assert_span() span = self.assert_span()
self.assertEqual(span.name, "GET" + self.URL) self.assertEqual(span.name, "GET" + self.URL)
HTTPXClientInstrumentor().uninstrument()
def test_request_hook_no_span_update(self): def test_request_hook_no_span_update(self):
HTTPXClientInstrumentor().instrument( HTTPXClientInstrumentor().instrument(
@ -873,7 +883,6 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!") self.assertEqual(result.text, "Hello!")
span = self.assert_span() span = self.assert_span()
self.assertEqual(span.name, "GET") self.assertEqual(span.name, "GET")
HTTPXClientInstrumentor().uninstrument()
def test_not_recording(self): def test_not_recording(self):
with mock.patch("opentelemetry.trace.INVALID_SPAN") as mock_span: with mock.patch("opentelemetry.trace.INVALID_SPAN") as mock_span:
@ -891,7 +900,6 @@ class BaseTestCases:
self.assertTrue(mock_span.is_recording.called) self.assertTrue(mock_span.is_recording.called)
self.assertFalse(mock_span.set_attribute.called) self.assertFalse(mock_span.set_attribute.called)
self.assertFalse(mock_span.set_status.called) self.assertFalse(mock_span.set_status.called)
HTTPXClientInstrumentor().uninstrument()
def test_suppress_instrumentation_new_client(self): def test_suppress_instrumentation_new_client(self):
HTTPXClientInstrumentor().instrument() HTTPXClientInstrumentor().instrument()
@ -901,7 +909,6 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!") self.assertEqual(result.text, "Hello!")
self.assert_span(num_spans=0) self.assert_span(num_spans=0)
HTTPXClientInstrumentor().uninstrument()
def test_instrument_client(self): def test_instrument_client(self):
client = self.create_client() client = self.create_client()
@ -929,8 +936,6 @@ class BaseTestCases:
self.URL, self.URL,
) )
HTTPXClientInstrumentor().uninstrument()
def test_uninstrument(self): def test_uninstrument(self):
HTTPXClientInstrumentor().instrument() HTTPXClientInstrumentor().instrument()
HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().uninstrument()
@ -980,9 +985,7 @@ class BaseTestCases:
self.assert_proxy_mounts( self.assert_proxy_mounts(
client._mounts.values(), client._mounts.values(),
2, 2,
(SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport),
) )
HTTPXClientInstrumentor().uninstrument()
def test_instrument_client_with_proxy(self): def test_instrument_client_with_proxy(self):
proxy_mounts = self.create_proxy_mounts() proxy_mounts = self.create_proxy_mounts()
@ -999,7 +1002,6 @@ class BaseTestCases:
self.assert_proxy_mounts( self.assert_proxy_mounts(
client._mounts.values(), client._mounts.values(),
2, 2,
(SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport),
) )
HTTPXClientInstrumentor().uninstrument_client(client) HTTPXClientInstrumentor().uninstrument_client(client)
@ -1010,7 +1012,6 @@ class BaseTestCases:
self.assert_proxy_mounts( self.assert_proxy_mounts(
client._mounts.values(), client._mounts.values(),
2, 2,
(SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport),
) )
HTTPXClientInstrumentor().uninstrument_client(client) HTTPXClientInstrumentor().uninstrument_client(client)
@ -1180,6 +1181,21 @@ class TestSyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest):
def create_proxy_transport(self, url): def create_proxy_transport(self, url):
return httpx.HTTPTransport(proxy=httpx.Proxy(url)) return httpx.HTTPTransport(proxy=httpx.Proxy(url))
def test_can_instrument_subclassed_client(self):
class CustomClient(httpx.Client):
pass
client = CustomClient()
self.assertFalse(
isinstance(client._transport.handle_request, ObjectProxy)
)
HTTPXClientInstrumentor().instrument()
self.assertTrue(
isinstance(client._transport.handle_request, ObjectProxy)
)
class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest): class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest):
response_hook = staticmethod(_async_response_hook) response_hook = staticmethod(_async_response_hook)
@ -1188,10 +1204,8 @@ class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
HTTPXClientInstrumentor().instrument()
self.client = self.create_client()
self.client2 = self.create_client() self.client2 = self.create_client()
HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument_client(self.client2)
def create_client( def create_client(
self, self,
@ -1245,7 +1259,6 @@ class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest):
SpanAttributes.HTTP_STATUS_CODE: 200, SpanAttributes.HTTP_STATUS_CODE: 200,
}, },
) )
HTTPXClientInstrumentor().uninstrument()
def test_async_request_hook_does_nothing_if_not_coroutine(self): def test_async_request_hook_does_nothing_if_not_coroutine(self):
HTTPXClientInstrumentor().instrument( HTTPXClientInstrumentor().instrument(
@ -1258,4 +1271,18 @@ class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest):
self.assertEqual(result.text, "Hello!") self.assertEqual(result.text, "Hello!")
span = self.assert_span() span = self.assert_span()
self.assertEqual(span.name, "GET") self.assertEqual(span.name, "GET")
HTTPXClientInstrumentor().uninstrument()
def test_can_instrument_subclassed_async_client(self):
class CustomAsyncClient(httpx.AsyncClient):
pass
client = CustomAsyncClient()
self.assertFalse(
isinstance(client._transport.handle_async_request, ObjectProxy)
)
HTTPXClientInstrumentor().instrument()
self.assertTrue(
isinstance(client._transport.handle_async_request, ObjectProxy)
)