mirror of
https://github.com/open-telemetry/opentelemetry-python-contrib.git
synced 2025-08-02 19:47:17 +08:00
fastapi: fix wrapping of middlewares (#3012)
* fastapi: fix wrapping of middlewares * fix import, super * add test * changelog * lint * lint * fix * ci * fix wip * fix * fix * lint * lint * Exit? * Update test_fastapi_instrumentation.py Co-authored-by: Riccardo Magliocchetti <riccardo.magliocchetti@gmail.com> * remove break * fix * remove dunders * add test * lint * add endpoint to class * fmt * pr feedback * move type ignores * fix sphinx? * Update CHANGELOG.md * update fastapi versions * fix? * generate * stop passing on user-supplied error handler This prevents potential side effects, such as logging, to be executed more than once per request handler exception. * fix ci Signed-off-by: emdneto <9735060+emdneto@users.noreply.github.com> * fix ruff Signed-off-by: emdneto <9735060+emdneto@users.noreply.github.com> * remove unused funcs Co-authored-by: Emídio Neto <9735060+emdneto@users.noreply.github.com> * fix lint,ruff Signed-off-by: emdneto <9735060+emdneto@users.noreply.github.com> * fix changelog Signed-off-by: emdneto <9735060+emdneto@users.noreply.github.com> * add changelog note Signed-off-by: emdneto <9735060+emdneto@users.noreply.github.com> * fix conflicts with main Signed-off-by: emdneto <9735060+emdneto@users.noreply.github.com> --------- Signed-off-by: emdneto <9735060+emdneto@users.noreply.github.com> Co-authored-by: Riccardo Magliocchetti <riccardo.magliocchetti@gmail.com> Co-authored-by: Alexander Dorn <ad@not.one> Co-authored-by: Emídio Neto <9735060+emdneto@users.noreply.github.com>
This commit is contained in:

committed by
GitHub

parent
dbdff31220
commit
4d6893e8fa
@ -35,7 +35,7 @@ dependencies = [
|
||||
|
||||
[project.optional-dependencies]
|
||||
instruments = [
|
||||
"fastapi ~= 0.58",
|
||||
"fastapi ~= 0.92",
|
||||
]
|
||||
|
||||
[project.entry-points.opentelemetry_instrumentor]
|
||||
|
@ -182,11 +182,16 @@ API
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import types
|
||||
from typing import Collection, Literal
|
||||
|
||||
import fastapi
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware.errors import ServerErrorMiddleware
|
||||
from starlette.routing import Match
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from opentelemetry.instrumentation._semconv import (
|
||||
_get_schema_url,
|
||||
@ -203,9 +208,9 @@ from opentelemetry.instrumentation.asgi.types import (
|
||||
from opentelemetry.instrumentation.fastapi.package import _instruments
|
||||
from opentelemetry.instrumentation.fastapi.version import __version__
|
||||
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
|
||||
from opentelemetry.metrics import get_meter
|
||||
from opentelemetry.metrics import MeterProvider, get_meter
|
||||
from opentelemetry.semconv.attributes.http_attributes import HTTP_ROUTE
|
||||
from opentelemetry.trace import get_tracer
|
||||
from opentelemetry.trace import TracerProvider, get_tracer
|
||||
from opentelemetry.util.http import (
|
||||
get_excluded_urls,
|
||||
parse_excluded_urls,
|
||||
@ -226,13 +231,13 @@ class FastAPIInstrumentor(BaseInstrumentor):
|
||||
|
||||
@staticmethod
|
||||
def instrument_app(
|
||||
app,
|
||||
app: fastapi.FastAPI,
|
||||
server_request_hook: ServerRequestHook = None,
|
||||
client_request_hook: ClientRequestHook = None,
|
||||
client_response_hook: ClientResponseHook = None,
|
||||
tracer_provider=None,
|
||||
meter_provider=None,
|
||||
excluded_urls=None,
|
||||
tracer_provider: TracerProvider | None = None,
|
||||
meter_provider: MeterProvider | None = None,
|
||||
excluded_urls: str | None = None,
|
||||
http_capture_headers_server_request: list[str] | None = None,
|
||||
http_capture_headers_server_response: list[str] | None = None,
|
||||
http_capture_headers_sanitize_fields: list[str] | None = None,
|
||||
@ -284,21 +289,56 @@ class FastAPIInstrumentor(BaseInstrumentor):
|
||||
schema_url=_get_schema_url(sem_conv_opt_in_mode),
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
OpenTelemetryMiddleware,
|
||||
excluded_urls=excluded_urls,
|
||||
default_span_details=_get_default_span_details,
|
||||
server_request_hook=server_request_hook,
|
||||
client_request_hook=client_request_hook,
|
||||
client_response_hook=client_response_hook,
|
||||
# Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation
|
||||
tracer=tracer,
|
||||
meter=meter,
|
||||
http_capture_headers_server_request=http_capture_headers_server_request,
|
||||
http_capture_headers_server_response=http_capture_headers_server_response,
|
||||
http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields,
|
||||
exclude_spans=exclude_spans,
|
||||
# Instead of using `app.add_middleware` we monkey patch `build_middleware_stack` to insert our middleware
|
||||
# as the outermost middleware.
|
||||
# Otherwise `OpenTelemetryMiddleware` would have unhandled exceptions tearing through it and would not be able
|
||||
# to faithfully record what is returned to the client since it technically cannot know what `ServerErrorMiddleware` is going to do.
|
||||
|
||||
def build_middleware_stack(self: Starlette) -> ASGIApp:
|
||||
inner_server_error_middleware: ASGIApp = ( # type: ignore
|
||||
self._original_build_middleware_stack() # type: ignore
|
||||
)
|
||||
otel_middleware = OpenTelemetryMiddleware(
|
||||
inner_server_error_middleware,
|
||||
excluded_urls=excluded_urls,
|
||||
default_span_details=_get_default_span_details,
|
||||
server_request_hook=server_request_hook,
|
||||
client_request_hook=client_request_hook,
|
||||
client_response_hook=client_response_hook,
|
||||
# Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation
|
||||
tracer=tracer,
|
||||
meter=meter,
|
||||
http_capture_headers_server_request=http_capture_headers_server_request,
|
||||
http_capture_headers_server_response=http_capture_headers_server_response,
|
||||
http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields,
|
||||
exclude_spans=exclude_spans,
|
||||
)
|
||||
# Wrap in an outer layer of ServerErrorMiddleware so that any exceptions raised in OpenTelemetryMiddleware
|
||||
# are handled.
|
||||
# This should not happen unless there is a bug in OpenTelemetryMiddleware, but if there is we don't want that
|
||||
# to impact the user's application just because we wrapped the middlewares in this order.
|
||||
if isinstance(
|
||||
inner_server_error_middleware, ServerErrorMiddleware
|
||||
): # usually true
|
||||
outer_server_error_middleware = ServerErrorMiddleware(
|
||||
app=otel_middleware,
|
||||
)
|
||||
else:
|
||||
# Something else seems to have patched things, or maybe Starlette changed.
|
||||
# Just create a default ServerErrorMiddleware.
|
||||
outer_server_error_middleware = ServerErrorMiddleware(
|
||||
app=otel_middleware
|
||||
)
|
||||
return outer_server_error_middleware
|
||||
|
||||
app._original_build_middleware_stack = app.build_middleware_stack
|
||||
app.build_middleware_stack = types.MethodType(
|
||||
functools.wraps(app.build_middleware_stack)(
|
||||
build_middleware_stack
|
||||
),
|
||||
app,
|
||||
)
|
||||
|
||||
app._is_instrumented_by_opentelemetry = True
|
||||
if app not in _InstrumentedFastAPI._instrumented_fastapi_apps:
|
||||
_InstrumentedFastAPI._instrumented_fastapi_apps.add(app)
|
||||
@ -309,11 +349,12 @@ class FastAPIInstrumentor(BaseInstrumentor):
|
||||
|
||||
@staticmethod
|
||||
def uninstrument_app(app: fastapi.FastAPI):
|
||||
app.user_middleware = [
|
||||
x
|
||||
for x in app.user_middleware
|
||||
if x.cls is not OpenTelemetryMiddleware
|
||||
]
|
||||
original_build_middleware_stack = getattr(
|
||||
app, "_original_build_middleware_stack", None
|
||||
)
|
||||
if original_build_middleware_stack:
|
||||
app.build_middleware_stack = original_build_middleware_stack
|
||||
del app._original_build_middleware_stack
|
||||
app.middleware_stack = app.build_middleware_stack()
|
||||
app._is_instrumented_by_opentelemetry = False
|
||||
|
||||
@ -341,12 +382,7 @@ class FastAPIInstrumentor(BaseInstrumentor):
|
||||
_InstrumentedFastAPI._http_capture_headers_sanitize_fields = (
|
||||
kwargs.get("http_capture_headers_sanitize_fields")
|
||||
)
|
||||
_excluded_urls = kwargs.get("excluded_urls")
|
||||
_InstrumentedFastAPI._excluded_urls = (
|
||||
_excluded_urls_from_env
|
||||
if _excluded_urls is None
|
||||
else parse_excluded_urls(_excluded_urls)
|
||||
)
|
||||
_InstrumentedFastAPI._excluded_urls = kwargs.get("excluded_urls")
|
||||
_InstrumentedFastAPI._meter_provider = kwargs.get("meter_provider")
|
||||
_InstrumentedFastAPI._exclude_spans = kwargs.get("exclude_spans")
|
||||
fastapi.FastAPI = _InstrumentedFastAPI
|
||||
@ -365,43 +401,29 @@ class _InstrumentedFastAPI(fastapi.FastAPI):
|
||||
_server_request_hook: ServerRequestHook = None
|
||||
_client_request_hook: ClientRequestHook = None
|
||||
_client_response_hook: ClientResponseHook = None
|
||||
_http_capture_headers_server_request: list[str] | None = None
|
||||
_http_capture_headers_server_response: list[str] | None = None
|
||||
_http_capture_headers_sanitize_fields: list[str] | None = None
|
||||
_exclude_spans: list[Literal["receive", "send"]] | None = None
|
||||
|
||||
_instrumented_fastapi_apps = set()
|
||||
_sem_conv_opt_in_mode = _StabilityMode.DEFAULT
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
tracer = get_tracer(
|
||||
__name__,
|
||||
__version__,
|
||||
_InstrumentedFastAPI._tracer_provider,
|
||||
schema_url=_get_schema_url(
|
||||
_InstrumentedFastAPI._sem_conv_opt_in_mode
|
||||
),
|
||||
FastAPIInstrumentor.instrument_app(
|
||||
self,
|
||||
server_request_hook=self._server_request_hook,
|
||||
client_request_hook=self._client_request_hook,
|
||||
client_response_hook=self._client_response_hook,
|
||||
tracer_provider=self._tracer_provider,
|
||||
meter_provider=self._meter_provider,
|
||||
excluded_urls=self._excluded_urls,
|
||||
http_capture_headers_server_request=self._http_capture_headers_server_request,
|
||||
http_capture_headers_server_response=self._http_capture_headers_server_response,
|
||||
http_capture_headers_sanitize_fields=self._http_capture_headers_sanitize_fields,
|
||||
exclude_spans=self._exclude_spans,
|
||||
)
|
||||
meter = get_meter(
|
||||
__name__,
|
||||
__version__,
|
||||
_InstrumentedFastAPI._meter_provider,
|
||||
schema_url=_get_schema_url(
|
||||
_InstrumentedFastAPI._sem_conv_opt_in_mode
|
||||
),
|
||||
)
|
||||
self.add_middleware(
|
||||
OpenTelemetryMiddleware,
|
||||
excluded_urls=_InstrumentedFastAPI._excluded_urls,
|
||||
default_span_details=_get_default_span_details,
|
||||
server_request_hook=_InstrumentedFastAPI._server_request_hook,
|
||||
client_request_hook=_InstrumentedFastAPI._client_request_hook,
|
||||
client_response_hook=_InstrumentedFastAPI._client_response_hook,
|
||||
# Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation
|
||||
tracer=tracer,
|
||||
meter=meter,
|
||||
http_capture_headers_server_request=_InstrumentedFastAPI._http_capture_headers_server_request,
|
||||
http_capture_headers_server_response=_InstrumentedFastAPI._http_capture_headers_server_response,
|
||||
http_capture_headers_sanitize_fields=_InstrumentedFastAPI._http_capture_headers_sanitize_fields,
|
||||
exclude_spans=_InstrumentedFastAPI._exclude_spans,
|
||||
)
|
||||
self._is_instrumented_by_opentelemetry = True
|
||||
_InstrumentedFastAPI._instrumented_fastapi_apps.add(self)
|
||||
|
||||
def __del__(self):
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
_instruments = ("fastapi ~= 0.58",)
|
||||
_instruments = ("fastapi ~= 0.92",)
|
||||
|
||||
_supports_metrics = True
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import unittest
|
||||
from contextlib import ExitStack
|
||||
from timeit import default_timer
|
||||
from unittest.mock import Mock, call, patch
|
||||
|
||||
@ -183,9 +184,14 @@ class TestBaseFastAPI(TestBase):
|
||||
self._instrumentor = otel_fastapi.FastAPIInstrumentor()
|
||||
self._app = self._create_app()
|
||||
self._app.add_middleware(HTTPSRedirectMiddleware)
|
||||
self._client = TestClient(self._app)
|
||||
self._client = TestClient(self._app, base_url="https://testserver:443")
|
||||
# run the lifespan, initialize the middleware stack
|
||||
# this is more in-line with what happens in a real application when the server starts up
|
||||
self._exit_stack = ExitStack()
|
||||
self._exit_stack.enter_context(self._client)
|
||||
|
||||
def tearDown(self):
|
||||
self._exit_stack.close()
|
||||
super().tearDown()
|
||||
self.env_patch.stop()
|
||||
self.exclude_patch.stop()
|
||||
@ -218,11 +224,19 @@ class TestBaseFastAPI(TestBase):
|
||||
async def _():
|
||||
return {"message": "ok"}
|
||||
|
||||
@app.get("/error")
|
||||
async def _():
|
||||
raise UnhandledException("This is an unhandled exception")
|
||||
|
||||
app.mount("/sub", app=sub_app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class UnhandledException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class TestBaseManualFastAPI(TestBaseFastAPI):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -233,6 +247,27 @@ class TestBaseManualFastAPI(TestBaseFastAPI):
|
||||
|
||||
super(TestBaseManualFastAPI, cls).setUpClass()
|
||||
|
||||
def test_fastapi_unhandled_exception(self):
|
||||
"""If the application has an unhandled error the instrumentation should capture that a 500 response is returned."""
|
||||
try:
|
||||
resp = self._client.get("/error")
|
||||
assert (
|
||||
resp.status_code == 500
|
||||
), resp.content # pragma: no cover, for debugging this test if an exception is _not_ raised
|
||||
except UnhandledException:
|
||||
pass
|
||||
else:
|
||||
self.fail("Expected UnhandledException")
|
||||
|
||||
spans = self.memory_exporter.get_finished_spans()
|
||||
self.assertEqual(len(spans), 3)
|
||||
span = spans[0]
|
||||
assert span.name == "GET /error http send"
|
||||
assert span.attributes[HTTP_STATUS_CODE] == 500
|
||||
span = spans[2]
|
||||
assert span.name == "GET /error"
|
||||
assert span.attributes[HTTP_TARGET] == "/error"
|
||||
|
||||
def test_sub_app_fastapi_call(self):
|
||||
"""
|
||||
This test is to ensure that a span in case of a sub app targeted contains the correct server url
|
||||
@ -975,6 +1010,10 @@ class TestFastAPIManualInstrumentation(TestBaseManualFastAPI):
|
||||
async def _():
|
||||
return {"message": "ok"}
|
||||
|
||||
@app.get("/error")
|
||||
async def _():
|
||||
raise UnhandledException("This is an unhandled exception")
|
||||
|
||||
app.mount("/sub", app=sub_app)
|
||||
|
||||
return app
|
||||
@ -1137,9 +1176,11 @@ class TestAutoInstrumentation(TestBaseAutoFastAPI):
|
||||
def test_mulitple_way_instrumentation(self):
|
||||
self._instrumentor.instrument_app(self._app)
|
||||
count = 0
|
||||
for middleware in self._app.user_middleware:
|
||||
if middleware.cls is OpenTelemetryMiddleware:
|
||||
app = self._app.middleware_stack
|
||||
while app is not None:
|
||||
if isinstance(app, OpenTelemetryMiddleware):
|
||||
count += 1
|
||||
app = getattr(app, "app", None)
|
||||
self.assertEqual(count, 1)
|
||||
|
||||
def test_uninstrument_after_instrument(self):
|
||||
|
Reference in New Issue
Block a user