fastapi: fix wrapping of middlewares (#3012)

* fastapi: fix wrapping of middlewares

* fix import, super

* add test

* changelog

* lint

* lint

* fix

* ci

* fix wip

* fix

* fix

* lint

* lint

* Exit?

* Update test_fastapi_instrumentation.py

Co-authored-by: Riccardo Magliocchetti <riccardo.magliocchetti@gmail.com>

* remove break

* fix

* remove dunders

* add test

* lint

* add endpoint to class

* fmt

* pr feedback

* move type ignores

* fix sphinx?

* Update CHANGELOG.md

* update fastapi versions

* fix?

* generate

* stop passing on user-supplied error handler

This prevents potential side effects, such as logging, to be executed
more than once per request handler exception.

* fix ci

Signed-off-by: emdneto <9735060+emdneto@users.noreply.github.com>

* fix ruff

Signed-off-by: emdneto <9735060+emdneto@users.noreply.github.com>

* remove unused funcs

Co-authored-by: Emídio Neto <9735060+emdneto@users.noreply.github.com>

* fix lint,ruff

Signed-off-by: emdneto <9735060+emdneto@users.noreply.github.com>

* fix changelog

Signed-off-by: emdneto <9735060+emdneto@users.noreply.github.com>

* add changelog note

Signed-off-by: emdneto <9735060+emdneto@users.noreply.github.com>

* fix conflicts with main

Signed-off-by: emdneto <9735060+emdneto@users.noreply.github.com>

---------

Signed-off-by: emdneto <9735060+emdneto@users.noreply.github.com>
Co-authored-by: Riccardo Magliocchetti <riccardo.magliocchetti@gmail.com>
Co-authored-by: Alexander Dorn <ad@not.one>
Co-authored-by: Emídio Neto <9735060+emdneto@users.noreply.github.com>
This commit is contained in:
Adrian Garcia Badaracco
2025-05-20 07:00:56 -07:00
committed by GitHub
parent dbdff31220
commit 4d6893e8fa
11 changed files with 181 additions and 109 deletions

View File

@ -35,7 +35,7 @@ dependencies = [
[project.optional-dependencies]
instruments = [
"fastapi ~= 0.58",
"fastapi ~= 0.92",
]
[project.entry-points.opentelemetry_instrumentor]

View File

@ -182,11 +182,16 @@ API
from __future__ import annotations
import functools
import logging
import types
from typing import Collection, Literal
import fastapi
from starlette.applications import Starlette
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.routing import Match
from starlette.types import ASGIApp
from opentelemetry.instrumentation._semconv import (
_get_schema_url,
@ -203,9 +208,9 @@ from opentelemetry.instrumentation.asgi.types import (
from opentelemetry.instrumentation.fastapi.package import _instruments
from opentelemetry.instrumentation.fastapi.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.metrics import get_meter
from opentelemetry.metrics import MeterProvider, get_meter
from opentelemetry.semconv.attributes.http_attributes import HTTP_ROUTE
from opentelemetry.trace import get_tracer
from opentelemetry.trace import TracerProvider, get_tracer
from opentelemetry.util.http import (
get_excluded_urls,
parse_excluded_urls,
@ -226,13 +231,13 @@ class FastAPIInstrumentor(BaseInstrumentor):
@staticmethod
def instrument_app(
app,
app: fastapi.FastAPI,
server_request_hook: ServerRequestHook = None,
client_request_hook: ClientRequestHook = None,
client_response_hook: ClientResponseHook = None,
tracer_provider=None,
meter_provider=None,
excluded_urls=None,
tracer_provider: TracerProvider | None = None,
meter_provider: MeterProvider | None = None,
excluded_urls: str | None = None,
http_capture_headers_server_request: list[str] | None = None,
http_capture_headers_server_response: list[str] | None = None,
http_capture_headers_sanitize_fields: list[str] | None = None,
@ -284,21 +289,56 @@ class FastAPIInstrumentor(BaseInstrumentor):
schema_url=_get_schema_url(sem_conv_opt_in_mode),
)
app.add_middleware(
OpenTelemetryMiddleware,
excluded_urls=excluded_urls,
default_span_details=_get_default_span_details,
server_request_hook=server_request_hook,
client_request_hook=client_request_hook,
client_response_hook=client_response_hook,
# Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation
tracer=tracer,
meter=meter,
http_capture_headers_server_request=http_capture_headers_server_request,
http_capture_headers_server_response=http_capture_headers_server_response,
http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields,
exclude_spans=exclude_spans,
# Instead of using `app.add_middleware` we monkey patch `build_middleware_stack` to insert our middleware
# as the outermost middleware.
# Otherwise `OpenTelemetryMiddleware` would have unhandled exceptions tearing through it and would not be able
# to faithfully record what is returned to the client since it technically cannot know what `ServerErrorMiddleware` is going to do.
def build_middleware_stack(self: Starlette) -> ASGIApp:
inner_server_error_middleware: ASGIApp = ( # type: ignore
self._original_build_middleware_stack() # type: ignore
)
otel_middleware = OpenTelemetryMiddleware(
inner_server_error_middleware,
excluded_urls=excluded_urls,
default_span_details=_get_default_span_details,
server_request_hook=server_request_hook,
client_request_hook=client_request_hook,
client_response_hook=client_response_hook,
# Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation
tracer=tracer,
meter=meter,
http_capture_headers_server_request=http_capture_headers_server_request,
http_capture_headers_server_response=http_capture_headers_server_response,
http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields,
exclude_spans=exclude_spans,
)
# Wrap in an outer layer of ServerErrorMiddleware so that any exceptions raised in OpenTelemetryMiddleware
# are handled.
# This should not happen unless there is a bug in OpenTelemetryMiddleware, but if there is we don't want that
# to impact the user's application just because we wrapped the middlewares in this order.
if isinstance(
inner_server_error_middleware, ServerErrorMiddleware
): # usually true
outer_server_error_middleware = ServerErrorMiddleware(
app=otel_middleware,
)
else:
# Something else seems to have patched things, or maybe Starlette changed.
# Just create a default ServerErrorMiddleware.
outer_server_error_middleware = ServerErrorMiddleware(
app=otel_middleware
)
return outer_server_error_middleware
app._original_build_middleware_stack = app.build_middleware_stack
app.build_middleware_stack = types.MethodType(
functools.wraps(app.build_middleware_stack)(
build_middleware_stack
),
app,
)
app._is_instrumented_by_opentelemetry = True
if app not in _InstrumentedFastAPI._instrumented_fastapi_apps:
_InstrumentedFastAPI._instrumented_fastapi_apps.add(app)
@ -309,11 +349,12 @@ class FastAPIInstrumentor(BaseInstrumentor):
@staticmethod
def uninstrument_app(app: fastapi.FastAPI):
app.user_middleware = [
x
for x in app.user_middleware
if x.cls is not OpenTelemetryMiddleware
]
original_build_middleware_stack = getattr(
app, "_original_build_middleware_stack", None
)
if original_build_middleware_stack:
app.build_middleware_stack = original_build_middleware_stack
del app._original_build_middleware_stack
app.middleware_stack = app.build_middleware_stack()
app._is_instrumented_by_opentelemetry = False
@ -341,12 +382,7 @@ class FastAPIInstrumentor(BaseInstrumentor):
_InstrumentedFastAPI._http_capture_headers_sanitize_fields = (
kwargs.get("http_capture_headers_sanitize_fields")
)
_excluded_urls = kwargs.get("excluded_urls")
_InstrumentedFastAPI._excluded_urls = (
_excluded_urls_from_env
if _excluded_urls is None
else parse_excluded_urls(_excluded_urls)
)
_InstrumentedFastAPI._excluded_urls = kwargs.get("excluded_urls")
_InstrumentedFastAPI._meter_provider = kwargs.get("meter_provider")
_InstrumentedFastAPI._exclude_spans = kwargs.get("exclude_spans")
fastapi.FastAPI = _InstrumentedFastAPI
@ -365,43 +401,29 @@ class _InstrumentedFastAPI(fastapi.FastAPI):
_server_request_hook: ServerRequestHook = None
_client_request_hook: ClientRequestHook = None
_client_response_hook: ClientResponseHook = None
_http_capture_headers_server_request: list[str] | None = None
_http_capture_headers_server_response: list[str] | None = None
_http_capture_headers_sanitize_fields: list[str] | None = None
_exclude_spans: list[Literal["receive", "send"]] | None = None
_instrumented_fastapi_apps = set()
_sem_conv_opt_in_mode = _StabilityMode.DEFAULT
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
tracer = get_tracer(
__name__,
__version__,
_InstrumentedFastAPI._tracer_provider,
schema_url=_get_schema_url(
_InstrumentedFastAPI._sem_conv_opt_in_mode
),
FastAPIInstrumentor.instrument_app(
self,
server_request_hook=self._server_request_hook,
client_request_hook=self._client_request_hook,
client_response_hook=self._client_response_hook,
tracer_provider=self._tracer_provider,
meter_provider=self._meter_provider,
excluded_urls=self._excluded_urls,
http_capture_headers_server_request=self._http_capture_headers_server_request,
http_capture_headers_server_response=self._http_capture_headers_server_response,
http_capture_headers_sanitize_fields=self._http_capture_headers_sanitize_fields,
exclude_spans=self._exclude_spans,
)
meter = get_meter(
__name__,
__version__,
_InstrumentedFastAPI._meter_provider,
schema_url=_get_schema_url(
_InstrumentedFastAPI._sem_conv_opt_in_mode
),
)
self.add_middleware(
OpenTelemetryMiddleware,
excluded_urls=_InstrumentedFastAPI._excluded_urls,
default_span_details=_get_default_span_details,
server_request_hook=_InstrumentedFastAPI._server_request_hook,
client_request_hook=_InstrumentedFastAPI._client_request_hook,
client_response_hook=_InstrumentedFastAPI._client_response_hook,
# Pass in tracer/meter to get __name__and __version__ of fastapi instrumentation
tracer=tracer,
meter=meter,
http_capture_headers_server_request=_InstrumentedFastAPI._http_capture_headers_server_request,
http_capture_headers_server_response=_InstrumentedFastAPI._http_capture_headers_server_response,
http_capture_headers_sanitize_fields=_InstrumentedFastAPI._http_capture_headers_sanitize_fields,
exclude_spans=_InstrumentedFastAPI._exclude_spans,
)
self._is_instrumented_by_opentelemetry = True
_InstrumentedFastAPI._instrumented_fastapi_apps.add(self)
def __del__(self):

View File

@ -13,7 +13,7 @@
# limitations under the License.
_instruments = ("fastapi ~= 0.58",)
_instruments = ("fastapi ~= 0.92",)
_supports_metrics = True

View File

@ -15,6 +15,7 @@
# pylint: disable=too-many-lines
import unittest
from contextlib import ExitStack
from timeit import default_timer
from unittest.mock import Mock, call, patch
@ -183,9 +184,14 @@ class TestBaseFastAPI(TestBase):
self._instrumentor = otel_fastapi.FastAPIInstrumentor()
self._app = self._create_app()
self._app.add_middleware(HTTPSRedirectMiddleware)
self._client = TestClient(self._app)
self._client = TestClient(self._app, base_url="https://testserver:443")
# run the lifespan, initialize the middleware stack
# this is more in-line with what happens in a real application when the server starts up
self._exit_stack = ExitStack()
self._exit_stack.enter_context(self._client)
def tearDown(self):
self._exit_stack.close()
super().tearDown()
self.env_patch.stop()
self.exclude_patch.stop()
@ -218,11 +224,19 @@ class TestBaseFastAPI(TestBase):
async def _():
return {"message": "ok"}
@app.get("/error")
async def _():
raise UnhandledException("This is an unhandled exception")
app.mount("/sub", app=sub_app)
return app
class UnhandledException(Exception):
pass
class TestBaseManualFastAPI(TestBaseFastAPI):
@classmethod
def setUpClass(cls):
@ -233,6 +247,27 @@ class TestBaseManualFastAPI(TestBaseFastAPI):
super(TestBaseManualFastAPI, cls).setUpClass()
def test_fastapi_unhandled_exception(self):
"""If the application has an unhandled error the instrumentation should capture that a 500 response is returned."""
try:
resp = self._client.get("/error")
assert (
resp.status_code == 500
), resp.content # pragma: no cover, for debugging this test if an exception is _not_ raised
except UnhandledException:
pass
else:
self.fail("Expected UnhandledException")
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 3)
span = spans[0]
assert span.name == "GET /error http send"
assert span.attributes[HTTP_STATUS_CODE] == 500
span = spans[2]
assert span.name == "GET /error"
assert span.attributes[HTTP_TARGET] == "/error"
def test_sub_app_fastapi_call(self):
"""
This test is to ensure that a span in case of a sub app targeted contains the correct server url
@ -975,6 +1010,10 @@ class TestFastAPIManualInstrumentation(TestBaseManualFastAPI):
async def _():
return {"message": "ok"}
@app.get("/error")
async def _():
raise UnhandledException("This is an unhandled exception")
app.mount("/sub", app=sub_app)
return app
@ -1137,9 +1176,11 @@ class TestAutoInstrumentation(TestBaseAutoFastAPI):
def test_mulitple_way_instrumentation(self):
self._instrumentor.instrument_app(self._app)
count = 0
for middleware in self._app.user_middleware:
if middleware.cls is OpenTelemetryMiddleware:
app = self._app.middleware_stack
while app is not None:
if isinstance(app, OpenTelemetryMiddleware):
count += 1
app = getattr(app, "app", None)
self.assertEqual(count, 1)
def test_uninstrument_after_instrument(self):