diff --git a/slowapi/extension.py b/slowapi/extension.py index e35d3cc..488f0d9 100644 --- a/slowapi/extension.py +++ b/slowapi/extension.py @@ -1,6 +1,7 @@ """ The starlette extension to rate-limit requests """ + import asyncio import functools import inspect @@ -734,7 +735,8 @@ class Limiter: if not isinstance(response, Response): # get the response object from the decorated endpoint function self._inject_headers( - kwargs.get("response"), request.state.view_rate_limit # type: ignore + kwargs.get("response"), + request.state.view_rate_limit, # type: ignore ) else: self._inject_headers( @@ -766,7 +768,8 @@ class Limiter: if not isinstance(response, Response): # get the response object from the decorated endpoint function self._inject_headers( - kwargs.get("response"), request.state.view_rate_limit # type: ignore + kwargs.get("response"), + request.state.view_rate_limit, # type: ignore ) else: self._inject_headers( @@ -803,7 +806,7 @@ class Limiter: * **error_message**: string (or callable that returns one) to override the error message used in the response. * **exempt_when**: function returning a boolean indicating whether to exempt - the route from the limit + the route from the limit. This function can optionally use a Request object. * **cost**: integer (or callable that returns one) which is the cost of a hit * **override_defaults**: whether to override the default limits (default: True) """ diff --git a/tests/test_starlette_extension.py b/tests/test_starlette_extension.py index 7f21c1d..0e26baa 100644 --- a/tests/test_starlette_extension.py +++ b/tests/test_starlette_extension.py @@ -43,6 +43,61 @@ class TestDecorators(TestSlowapi): if i < 5: assert response.text == "test" + def test_exempt_when_argument(self, build_starlette_app): + app, limiter = build_starlette_app(key_func=get_ipaddr) + + def return_true(): + return True + + def return_false(): + return False + + def dynamic(request: Request): + user_agent = request.headers.get("User-Agent") + if user_agent is None: + return False + return user_agent == "exempt" + + @limiter.limit("1/minute", exempt_when=return_true) + def always_true(request: Request): + return PlainTextResponse("test") + + @limiter.limit("1/minute", exempt_when=return_false) + def always_false(request: Request): + return PlainTextResponse("test") + + @limiter.limit("1/minute", exempt_when=dynamic) + def always_dynamic(request: Request): + return PlainTextResponse("test") + + app.add_route("/true", always_true) + app.add_route("/false", always_false) + app.add_route("/dynamic", always_dynamic) + + client = TestClient(app) + # Test always true always exempting + for i in range(0, 2): + response = client.get("/true") + assert response.status_code == 200 + assert response.text == "test" + # Test always false hitting the limit after one hit + for i in range(0, 2): + response = client.get("/false") + assert response.status_code == 200 if i < 1 else 429 + if i < 1: + assert response.text == "test" + # Test dynamic not exempting with the correct header + for i in range(0, 2): + response = client.get("/dynamic", headers={"User-Agent": "exempt"}) + assert response.status_code == 200 + assert response.text == "test" + # Test dynamic exempting with the incorrect header + for i in range(0, 2): + response = client.get("/dynamic") + assert response.status_code == 200 if i < 1 else 429 + if i < 1: + assert response.text == "test" + def test_shared_decorator(self, build_starlette_app): app, limiter = build_starlette_app(key_func=get_ipaddr)