httpx: rewrite patching to use wrapt instead of subclassing client (#2909)

httpx: rewrote patching to use wrapt instead of subclassing client

Port the httpx instrumentation to patch the sync and async transport methods instead of substituting the client. The previous approach instrumented httpx by instantiating another client with a custom transport class, which races with code that already subclasses the client. This version uses wrapt to patch the default httpx transport classes.

---------

Co-authored-by: Emídio Neto <9735060+emdneto@users.noreply.github.com>
This commit is contained in:
Riccardo Magliocchetti
2024-10-29 21:33:35 +01:00
committed by GitHub
parent 7cbe58691a
commit f6b68d0c02
4 changed files with 352 additions and 132 deletions

View File

@ -41,6 +41,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#2871](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2871))
- `opentelemetry-instrumentation` Don't fail distro loading if instrumentor raises ImportError, instead skip them
([#2923](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2923))
- `opentelemetry-instrumentation-httpx` Rewrote instrumentation to use wrapt instead of subclassing
([#2909](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2909))
## Version 1.27.0/0.48b0 (2024-08-28)

View File

@ -29,6 +29,7 @@ dependencies = [
"opentelemetry-instrumentation == 0.49b0.dev",
"opentelemetry-semantic-conventions == 0.49b0.dev",
"opentelemetry-util-http == 0.49b0.dev",
"wrapt >= 1.0.0, < 2.0.0",
]
[project.optional-dependencies]

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=too-many-lines
"""
Usage
-----
@ -194,9 +195,11 @@ API
import logging
import typing
from asyncio import iscoroutinefunction
from functools import partial
from types import TracebackType
import httpx
from wrapt import wrap_function_wrapper
from opentelemetry.instrumentation._semconv import (
_get_schema_url,
@ -217,6 +220,7 @@ from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import (
http_status_to_status_code,
is_http_instrumentation_enabled,
unwrap,
)
from opentelemetry.propagate import inject
from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
@ -731,44 +735,211 @@ class HTTPXClientInstrumentor(BaseInstrumentor):
``async_request_hook``: Async ``request_hook`` for ``httpx.AsyncClient``
``async_response_hook``: Async``response_hook`` for ``httpx.AsyncClient``
"""
self._original_client = httpx.Client
self._original_async_client = httpx.AsyncClient
tracer_provider = kwargs.get("tracer_provider")
request_hook = kwargs.get("request_hook")
response_hook = kwargs.get("response_hook")
async_request_hook = kwargs.get("async_request_hook")
async_response_hook = kwargs.get("async_response_hook")
if callable(request_hook):
_InstrumentedClient._request_hook = request_hook
if callable(async_request_hook) and iscoroutinefunction(
async_request_hook = (
async_request_hook
):
_InstrumentedAsyncClient._request_hook = async_request_hook
if callable(response_hook):
_InstrumentedClient._response_hook = response_hook
if callable(async_response_hook) and iscoroutinefunction(
if iscoroutinefunction(async_request_hook)
else None
)
async_response_hook = kwargs.get("async_response_hook")
async_response_hook = (
async_response_hook
):
_InstrumentedAsyncClient._response_hook = async_response_hook
tracer_provider = kwargs.get("tracer_provider")
_InstrumentedClient._tracer_provider = tracer_provider
_InstrumentedAsyncClient._tracer_provider = tracer_provider
# Intentionally using a private attribute here, see:
# https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2538#discussion_r1610603719
httpx.Client = httpx._api.Client = _InstrumentedClient
httpx.AsyncClient = _InstrumentedAsyncClient
if iscoroutinefunction(async_response_hook)
else None
)
_OpenTelemetrySemanticConventionStability._initialize()
sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode(
_OpenTelemetryStabilitySignalType.HTTP,
)
tracer = get_tracer(
__name__,
instrumenting_library_version=__version__,
tracer_provider=tracer_provider,
schema_url=_get_schema_url(sem_conv_opt_in_mode),
)
wrap_function_wrapper(
"httpx",
"HTTPTransport.handle_request",
partial(
self._handle_request_wrapper,
tracer=tracer,
sem_conv_opt_in_mode=sem_conv_opt_in_mode,
request_hook=request_hook,
response_hook=response_hook,
),
)
wrap_function_wrapper(
"httpx",
"AsyncHTTPTransport.handle_async_request",
partial(
self._handle_async_request_wrapper,
tracer=tracer,
sem_conv_opt_in_mode=sem_conv_opt_in_mode,
async_request_hook=async_request_hook,
async_response_hook=async_response_hook,
),
)
def _uninstrument(self, **kwargs):
httpx.Client = httpx._api.Client = self._original_client
httpx.AsyncClient = self._original_async_client
_InstrumentedClient._tracer_provider = None
_InstrumentedClient._request_hook = None
_InstrumentedClient._response_hook = None
_InstrumentedAsyncClient._tracer_provider = None
_InstrumentedAsyncClient._request_hook = None
_InstrumentedAsyncClient._response_hook = None
unwrap(httpx.HTTPTransport, "handle_request")
unwrap(httpx.AsyncHTTPTransport, "handle_async_request")
@staticmethod
def _handle_request_wrapper(  # pylint: disable=too-many-locals
    wrapped,
    instance,
    args,
    kwargs,
    tracer,
    sem_conv_opt_in_mode,
    request_hook,
    response_hook,
):
    """Trace one synchronous transport ``handle_request`` call.

    Installed through ``wrap_function_wrapper`` with the keyword-only
    arguments pre-bound via ``functools.partial``.

    Args:
        wrapped: The original ``handle_request`` callable.
        instance: The object ``wrapped`` is bound to (unused here).
        args: Positional arguments of the intercepted call.
        kwargs: Keyword arguments of the intercepted call.
        tracer: Tracer used to start the client span.
        sem_conv_opt_in_mode: Semantic-convention stability mode forwarded
            to the attribute helpers.
        request_hook: Optional callable invoked as
            ``request_hook(span, request_info)`` before the request is sent.
        response_hook: Optional callable invoked as
            ``response_hook(span, request_info, response_info)`` once a
            response is available.

    Returns:
        Whatever ``wrapped`` returns.

    Raises:
        Exception: Any exception raised by ``wrapped`` is re-raised after
            the span has been updated.
    """
    if not is_http_instrumentation_enabled():
        return wrapped(*args, **kwargs)

    method, url, headers, stream, extensions = _extract_parameters(
        args, kwargs
    )
    # presumably ``method`` is bytes here, hence the decode -- TODO confirm
    method_original = method.decode()
    span_name = _get_default_span_name(method_original)
    span_attributes = {}
    # apply http client request attributes according to semconv
    _apply_request_client_attributes_to_span(
        span_attributes,
        url,
        method_original,
        sem_conv_opt_in_mode,
    )

    request_info = RequestInfo(method, url, headers, stream, extensions)

    with tracer.start_as_current_span(
        span_name, kind=SpanKind.CLIENT, attributes=span_attributes
    ) as span:
        exception = None
        if callable(request_hook):
            request_hook(span, request_info)

        _inject_propagation_headers(headers, args, kwargs)

        try:
            response = wrapped(*args, **kwargs)
        except Exception as exc:  # pylint: disable=W0703
            # Defer the re-raise so that a response attached to the
            # exception can still be recorded and passed to the hook.
            exception = exc
            response = getattr(exc, "response", None)

        if isinstance(response, (httpx.Response, tuple)):
            status_code, headers, stream, extensions, http_version = (
                _extract_response(response)
            )

            if span.is_recording():
                # apply http client response attributes according to semconv
                _apply_response_client_attributes_to_span(
                    span,
                    status_code,
                    http_version,
                    sem_conv_opt_in_mode,
                )
            if callable(response_hook):
                response_hook(
                    span,
                    request_info,
                    ResponseInfo(status_code, headers, stream, extensions),
                )

        # Compare against None explicitly: an exception type that defines
        # a falsy __bool__/__len__ must still be re-raised, not swallowed.
        if exception is not None:
            if span.is_recording() and _report_new(sem_conv_opt_in_mode):
                span.set_attribute(
                    ERROR_TYPE, type(exception).__qualname__
                )
            raise exception.with_traceback(exception.__traceback__)

    return response
@staticmethod
async def _handle_async_request_wrapper(  # pylint: disable=too-many-locals
    wrapped,
    instance,
    args,
    kwargs,
    tracer,
    sem_conv_opt_in_mode,
    async_request_hook,
    async_response_hook,
):
    """Trace one asynchronous transport ``handle_async_request`` call.

    Installed through ``wrap_function_wrapper`` with the keyword-only
    arguments pre-bound via ``functools.partial``.

    Args:
        wrapped: The original ``handle_async_request`` coroutine function.
        instance: The object ``wrapped`` is bound to (unused here).
        args: Positional arguments of the intercepted call.
        kwargs: Keyword arguments of the intercepted call.
        tracer: Tracer used to start the client span.
        sem_conv_opt_in_mode: Semantic-convention stability mode forwarded
            to the attribute helpers.
        async_request_hook: Optional callable awaited as
            ``async_request_hook(span, request_info)`` before the request
            is sent.
        async_response_hook: Optional callable awaited as
            ``async_response_hook(span, request_info, response_info)`` once
            a response is available.

    Returns:
        Whatever awaiting ``wrapped`` returns.

    Raises:
        Exception: Any exception raised by ``wrapped`` is re-raised after
            the span has been updated.
    """
    if not is_http_instrumentation_enabled():
        return await wrapped(*args, **kwargs)

    method, url, headers, stream, extensions = _extract_parameters(
        args, kwargs
    )
    # presumably ``method`` is bytes here, hence the decode -- TODO confirm
    method_original = method.decode()
    span_name = _get_default_span_name(method_original)
    span_attributes = {}
    # apply http client request attributes according to semconv
    _apply_request_client_attributes_to_span(
        span_attributes,
        url,
        method_original,
        sem_conv_opt_in_mode,
    )

    request_info = RequestInfo(method, url, headers, stream, extensions)

    with tracer.start_as_current_span(
        span_name, kind=SpanKind.CLIENT, attributes=span_attributes
    ) as span:
        exception = None
        if callable(async_request_hook):
            await async_request_hook(span, request_info)

        _inject_propagation_headers(headers, args, kwargs)

        try:
            response = await wrapped(*args, **kwargs)
        except Exception as exc:  # pylint: disable=W0703
            # Defer the re-raise so that a response attached to the
            # exception can still be recorded and passed to the hook.
            exception = exc
            response = getattr(exc, "response", None)

        if isinstance(response, (httpx.Response, tuple)):
            status_code, headers, stream, extensions, http_version = (
                _extract_response(response)
            )

            if span.is_recording():
                # apply http client response attributes according to semconv
                _apply_response_client_attributes_to_span(
                    span,
                    status_code,
                    http_version,
                    sem_conv_opt_in_mode,
                )

            if callable(async_response_hook):
                await async_response_hook(
                    span,
                    request_info,
                    ResponseInfo(status_code, headers, stream, extensions),
                )

        # Compare against None explicitly: an exception type that defines
        # a falsy __bool__/__len__ must still be re-raised, not swallowed.
        if exception is not None:
            if span.is_recording() and _report_new(sem_conv_opt_in_mode):
                span.set_attribute(
                    ERROR_TYPE, type(exception).__qualname__
                )
            raise exception.with_traceback(exception.__traceback__)

    return response
def instrument_client(
self,
client: typing.Union[httpx.Client, httpx.AsyncClient],
tracer_provider: TracerProvider = None,
request_hook: typing.Union[
@ -788,67 +959,88 @@ class HTTPXClientInstrumentor(BaseInstrumentor):
response_hook: A hook that receives the span, request, and response
that is called right before the span ends
"""
# pylint: disable=protected-access
if not hasattr(client, "_is_instrumented_by_opentelemetry"):
client._is_instrumented_by_opentelemetry = False
if not client._is_instrumented_by_opentelemetry:
if isinstance(client, httpx.Client):
client._original_transport = client._transport
client._original_mounts = client._mounts.copy()
transport = client._transport or httpx.HTTPTransport()
client._transport = SyncOpenTelemetryTransport(
transport,
tracer_provider=tracer_provider,
request_hook=request_hook,
response_hook=response_hook,
)
client._is_instrumented_by_opentelemetry = True
client._mounts.update(
{
url_pattern: (
SyncOpenTelemetryTransport(
transport,
tracer_provider=tracer_provider,
request_hook=request_hook,
response_hook=response_hook,
)
if transport is not None
else transport
)
for url_pattern, transport in client._original_mounts.items()
}
)
if isinstance(client, httpx.AsyncClient):
transport = client._transport or httpx.AsyncHTTPTransport()
client._original_mounts = client._mounts.copy()
client._transport = AsyncOpenTelemetryTransport(
transport,
tracer_provider=tracer_provider,
request_hook=request_hook,
response_hook=response_hook,
)
client._is_instrumented_by_opentelemetry = True
client._mounts.update(
{
url_pattern: (
AsyncOpenTelemetryTransport(
transport,
tracer_provider=tracer_provider,
request_hook=request_hook,
response_hook=response_hook,
)
if transport is not None
else transport
)
for url_pattern, transport in client._original_mounts.items()
}
)
else:
if getattr(client, "_is_instrumented_by_opentelemetry", False):
_logger.warning(
"Attempting to instrument Httpx client while already instrumented"
)
return
_OpenTelemetrySemanticConventionStability._initialize()
sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode(
_OpenTelemetryStabilitySignalType.HTTP,
)
tracer = get_tracer(
__name__,
instrumenting_library_version=__version__,
tracer_provider=tracer_provider,
schema_url=_get_schema_url(sem_conv_opt_in_mode),
)
if iscoroutinefunction(request_hook):
async_request_hook = request_hook
request_hook = None
else:
# request_hook already set
async_request_hook = None
if iscoroutinefunction(response_hook):
async_response_hook = response_hook
response_hook = None
else:
# response_hook already set
async_response_hook = None
if hasattr(client._transport, "handle_request"):
wrap_function_wrapper(
client._transport,
"handle_request",
partial(
self._handle_request_wrapper,
tracer=tracer,
sem_conv_opt_in_mode=sem_conv_opt_in_mode,
request_hook=request_hook,
response_hook=response_hook,
),
)
for transport in client._mounts.values():
wrap_function_wrapper(
transport,
"handle_request",
partial(
self._handle_request_wrapper,
tracer=tracer,
sem_conv_opt_in_mode=sem_conv_opt_in_mode,
request_hook=request_hook,
response_hook=response_hook,
),
)
client._is_instrumented_by_opentelemetry = True
if hasattr(client._transport, "handle_async_request"):
wrap_function_wrapper(
client._transport,
"handle_async_request",
partial(
self._handle_async_request_wrapper,
tracer=tracer,
sem_conv_opt_in_mode=sem_conv_opt_in_mode,
async_request_hook=async_request_hook,
async_response_hook=async_response_hook,
),
)
for transport in client._mounts.values():
wrap_function_wrapper(
transport,
"handle_async_request",
partial(
self._handle_async_request_wrapper,
tracer=tracer,
sem_conv_opt_in_mode=sem_conv_opt_in_mode,
async_request_hook=async_request_hook,
async_response_hook=async_response_hook,
),
)
client._is_instrumented_by_opentelemetry = True
@staticmethod
def uninstrument_client(
@ -859,15 +1051,13 @@ class HTTPXClientInstrumentor(BaseInstrumentor):
Args:
client: The httpx Client or AsyncClient instance
"""
if hasattr(client, "_original_transport"):
client._transport = client._original_transport
del client._original_transport
if hasattr(client._transport, "handle_request"):
unwrap(client._transport, "handle_request")
for transport in client._mounts.values():
unwrap(transport, "handle_request")
client._is_instrumented_by_opentelemetry = False
elif hasattr(client._transport, "handle_async_request"):
unwrap(client._transport, "handle_async_request")
for transport in client._mounts.values():
unwrap(transport, "handle_async_request")
client._is_instrumented_by_opentelemetry = False
if hasattr(client, "_original_mounts"):
client._mounts = client._original_mounts.copy()
del client._original_mounts
else:
_logger.warning(
"Attempting to uninstrument Httpx "
"client while already uninstrumented"
)

View File

@ -21,6 +21,7 @@ from unittest import mock
import httpx
import respx
from wrapt import ObjectProxy
import opentelemetry.instrumentation.httpx
from opentelemetry import trace
@ -171,6 +172,7 @@ class BaseTestCases:
super().tearDown()
self.env_patch.stop()
respx.stop()
HTTPXClientInstrumentor().uninstrument()
def assert_span(
self, exporter: "SpanExporter" = None, num_spans: int = 1
@ -204,7 +206,7 @@ class BaseTestCases:
self.assertEqual(span.name, "GET")
self.assertEqual(
span.attributes,
dict(span.attributes),
{
SpanAttributes.HTTP_METHOD: "GET",
SpanAttributes.HTTP_URL: self.URL,
@ -228,7 +230,7 @@ class BaseTestCases:
self.assertIs(span.kind, trace.SpanKind.CLIENT)
self.assertEqual(span.name, "HTTP")
self.assertEqual(
span.attributes,
dict(span.attributes),
{
SpanAttributes.HTTP_METHOD: "_OTHER",
SpanAttributes.HTTP_URL: self.URL,
@ -252,7 +254,7 @@ class BaseTestCases:
self.assertIs(span.kind, trace.SpanKind.CLIENT)
self.assertEqual(span.name, "HTTP")
self.assertEqual(
span.attributes,
dict(span.attributes),
{
HTTP_REQUEST_METHOD: "_OTHER",
URL_FULL: self.URL,
@ -292,7 +294,7 @@ class BaseTestCases:
SpanAttributes.SCHEMA_URL,
)
self.assertEqual(
span.attributes,
dict(span.attributes),
{
HTTP_REQUEST_METHOD: "GET",
URL_FULL: url,
@ -327,7 +329,7 @@ class BaseTestCases:
)
self.assertEqual(
span.attributes,
dict(span.attributes),
{
SpanAttributes.HTTP_METHOD: "GET",
HTTP_REQUEST_METHOD: "GET",
@ -454,7 +456,7 @@ class BaseTestCases:
span = self.assert_span()
self.assertEqual(
span.attributes,
dict(span.attributes),
{
SpanAttributes.HTTP_METHOD: "GET",
SpanAttributes.HTTP_URL: self.URL,
@ -510,7 +512,7 @@ class BaseTestCases:
span = self.assert_span()
self.assertEqual(
span.attributes,
dict(span.attributes),
{
HTTP_REQUEST_METHOD: "GET",
URL_FULL: url,
@ -531,7 +533,7 @@ class BaseTestCases:
span = self.assert_span()
self.assertEqual(
span.attributes,
dict(span.attributes),
{
SpanAttributes.HTTP_METHOD: "GET",
HTTP_REQUEST_METHOD: "GET",
@ -632,7 +634,7 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!")
span = self.assert_span()
self.assertEqual(
span.attributes,
dict(span.attributes),
{
SpanAttributes.HTTP_METHOD: "GET",
SpanAttributes.HTTP_URL: self.URL,
@ -741,8 +743,10 @@ class BaseTestCases:
def setUp(self):
super().setUp()
HTTPXClientInstrumentor().instrument()
self.client = self.create_client()
HTTPXClientInstrumentor().instrument_client(self.client)
def tearDown(self):
    # Undo any HTTPXClientInstrumentor instrumentation applied by the test.
    # NOTE(review): super().tearDown() is not called here -- confirm that
    # skipping the parent class cleanup is intentional.
    HTTPXClientInstrumentor().uninstrument()
def create_proxy_mounts(self):
@ -755,14 +759,25 @@ class BaseTestCases:
),
}
def assert_proxy_mounts(self, mounts, num_mounts, transport_type):
def assert_proxy_mounts(self, mounts, num_mounts, transport_type=None):
self.assertEqual(len(mounts), num_mounts)
for transport in mounts:
with self.subTest(transport):
self.assertIsInstance(
transport,
transport_type,
)
if transport_type:
self.assertIsInstance(
transport,
transport_type,
)
else:
handler = getattr(transport, "handle_request", None)
if not handler:
handler = getattr(
transport, "handle_async_request"
)
self.assertTrue(
isinstance(handler, ObjectProxy)
and getattr(handler, "__wrapped__")
)
def test_custom_tracer_provider(self):
resource = resources.Resource.create({})
@ -778,7 +793,6 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!")
span = self.assert_span(exporter=exporter)
self.assertIs(span.resource, resource)
HTTPXClientInstrumentor().uninstrument()
def test_response_hook(self):
response_hook_key = (
@ -797,7 +811,7 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!")
span = self.assert_span()
self.assertEqual(
span.attributes,
dict(span.attributes),
{
SpanAttributes.HTTP_METHOD: "GET",
SpanAttributes.HTTP_URL: self.URL,
@ -805,7 +819,6 @@ class BaseTestCases:
HTTP_RESPONSE_BODY: "Hello!",
},
)
HTTPXClientInstrumentor().uninstrument()
def test_response_hook_sync_async_kwargs(self):
HTTPXClientInstrumentor().instrument(
@ -819,7 +832,7 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!")
span = self.assert_span()
self.assertEqual(
span.attributes,
dict(span.attributes),
{
SpanAttributes.HTTP_METHOD: "GET",
SpanAttributes.HTTP_URL: self.URL,
@ -827,7 +840,6 @@ class BaseTestCases:
HTTP_RESPONSE_BODY: "Hello!",
},
)
HTTPXClientInstrumentor().uninstrument()
def test_request_hook(self):
request_hook_key = (
@ -846,7 +858,6 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!")
span = self.assert_span()
self.assertEqual(span.name, "GET" + self.URL)
HTTPXClientInstrumentor().uninstrument()
def test_request_hook_sync_async_kwargs(self):
HTTPXClientInstrumentor().instrument(
@ -860,7 +871,6 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!")
span = self.assert_span()
self.assertEqual(span.name, "GET" + self.URL)
HTTPXClientInstrumentor().uninstrument()
def test_request_hook_no_span_update(self):
HTTPXClientInstrumentor().instrument(
@ -873,7 +883,6 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!")
span = self.assert_span()
self.assertEqual(span.name, "GET")
HTTPXClientInstrumentor().uninstrument()
def test_not_recording(self):
with mock.patch("opentelemetry.trace.INVALID_SPAN") as mock_span:
@ -891,7 +900,6 @@ class BaseTestCases:
self.assertTrue(mock_span.is_recording.called)
self.assertFalse(mock_span.set_attribute.called)
self.assertFalse(mock_span.set_status.called)
HTTPXClientInstrumentor().uninstrument()
def test_suppress_instrumentation_new_client(self):
HTTPXClientInstrumentor().instrument()
@ -901,7 +909,6 @@ class BaseTestCases:
self.assertEqual(result.text, "Hello!")
self.assert_span(num_spans=0)
HTTPXClientInstrumentor().uninstrument()
def test_instrument_client(self):
client = self.create_client()
@ -929,8 +936,6 @@ class BaseTestCases:
self.URL,
)
HTTPXClientInstrumentor().uninstrument()
def test_uninstrument(self):
HTTPXClientInstrumentor().instrument()
HTTPXClientInstrumentor().uninstrument()
@ -980,9 +985,7 @@ class BaseTestCases:
self.assert_proxy_mounts(
client._mounts.values(),
2,
(SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport),
)
HTTPXClientInstrumentor().uninstrument()
def test_instrument_client_with_proxy(self):
proxy_mounts = self.create_proxy_mounts()
@ -999,7 +1002,6 @@ class BaseTestCases:
self.assert_proxy_mounts(
client._mounts.values(),
2,
(SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport),
)
HTTPXClientInstrumentor().uninstrument_client(client)
@ -1010,7 +1012,6 @@ class BaseTestCases:
self.assert_proxy_mounts(
client._mounts.values(),
2,
(SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport),
)
HTTPXClientInstrumentor().uninstrument_client(client)
@ -1180,6 +1181,21 @@ class TestSyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest):
def create_proxy_transport(self, url):
return httpx.HTTPTransport(proxy=httpx.Proxy(url))
def test_can_instrument_subclassed_client(self):
    """A client subclass created before instrument() still gets patched."""

    class CustomClient(httpx.Client):
        pass

    client = CustomClient()
    # Before instrumenting, the transport handler is not a wrapt proxy.
    self.assertNotIsInstance(
        client._transport.handle_request, ObjectProxy
    )

    HTTPXClientInstrumentor().instrument()

    # After instrumenting, the same client's handler is a wrapt proxy.
    self.assertIsInstance(client._transport.handle_request, ObjectProxy)
class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest):
response_hook = staticmethod(_async_response_hook)
@ -1188,10 +1204,8 @@ class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest):
def setUp(self):
super().setUp()
HTTPXClientInstrumentor().instrument()
self.client = self.create_client()
self.client2 = self.create_client()
HTTPXClientInstrumentor().uninstrument()
HTTPXClientInstrumentor().instrument_client(self.client2)
def create_client(
self,
@ -1245,7 +1259,6 @@ class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest):
SpanAttributes.HTTP_STATUS_CODE: 200,
},
)
HTTPXClientInstrumentor().uninstrument()
def test_async_request_hook_does_nothing_if_not_coroutine(self):
HTTPXClientInstrumentor().instrument(
@ -1258,4 +1271,18 @@ class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest):
self.assertEqual(result.text, "Hello!")
span = self.assert_span()
self.assertEqual(span.name, "GET")
HTTPXClientInstrumentor().uninstrument()
def test_can_instrument_subclassed_async_client(self):
    """An async client subclass created before instrument() still gets patched."""

    class CustomAsyncClient(httpx.AsyncClient):
        pass

    client = CustomAsyncClient()
    # Before instrumenting, the transport handler is not a wrapt proxy.
    self.assertNotIsInstance(
        client._transport.handle_async_request, ObjectProxy
    )

    HTTPXClientInstrumentor().instrument()

    # After instrumenting, the same client's handler is a wrapt proxy.
    self.assertIsInstance(
        client._transport.handle_async_request, ObjectProxy
    )