diff --git a/slowapi/extension.py b/slowapi/extension.py index e3e33af..242b2c0 100644 --- a/slowapi/extension.py +++ b/slowapi/extension.py @@ -503,7 +503,7 @@ class Limiter: if name in self._dynamic_route_limits: for lim in self._dynamic_route_limits[name]: try: - dynamic_limits.extend(list(lim)) + dynamic_limits.extend(list(lim.with_request(request))) except ValueError as e: self.logger.error( "failed to load ratelimit for view function %s (%s)", diff --git a/slowapi/wrappers.py b/slowapi/wrappers.py index fd9b61d..b87d108 100644 --- a/slowapi/wrappers.py +++ b/slowapi/wrappers.py @@ -1,3 +1,4 @@ +import inspect from typing import Callable, Iterator, List, Optional, Union from limits import RateLimitItem, parse_many # type: ignore @@ -74,13 +75,22 @@ class LimitGroup(object): self.error_message = error_message self.exempt_when = exempt_when self.override_defaults = override_defaults + self.request = None def __iter__(self) -> Iterator[Limit]: - limit_items: List[RateLimitItem] = parse_many( - self.__limit_provider() - if callable(self.__limit_provider) - else self.__limit_provider - ) + if callable(self.__limit_provider): + if "key" in inspect.signature(self.__limit_provider).parameters.keys(): + assert ( + "request" in inspect.signature(self.key_function).parameters.keys() + ), f"Limit provider function {self.key_function.__name__} needs a `request` argument" + if self.request is None: + raise Exception("`request` object can't be None") + limit_raw = self.__limit_provider(self.key_function(self.request)) + else: + limit_raw = self.__limit_provider() + else: + limit_raw = self.__limit_provider + limit_items: List[RateLimitItem] = parse_many(limit_raw) for limit in limit_items: yield Limit( limit, @@ -92,3 +102,7 @@ class LimitGroup(object): self.exempt_when, self.override_defaults, ) + + def with_request(self, request): + self.request = request + return self diff --git a/tests/test_fastapi_extension.py b/tests/test_fastapi_extension.py index 60bb72d..5c837a0 100644 --- a/tests/test_fastapi_extension.py +++ b/tests/test_fastapi_extension.py @@ -232,6 +232,33 @@ class TestDecorators(TestSlowapi): r"""parameter `response` must be an instance of starlette.responses.Response""" ) + def test_dynamic_limit_provider_depending_on_key(self): + def custom_key_func(request: Request): + if request.headers.get("TOKEN") == "secret": + return "admin" + return "user" + + def dynamic_limit_provider(key: str): + if key == "admin": + return "10/minute" + return "5/minute" + + app, limiter = self.build_fastapi_app(key_func=custom_key_func) + + @app.get("/t1") + @limiter.limit(dynamic_limit_provider) + async def t1(request: Request, response: Response): + return {"key": "value"} + + client = TestClient(app) + for i in range(0, 10): + response = client.get("/t1") + assert response.status_code == 200 if i < 5 else 429 + + for i in range(0, 20): + response = client.get("/t1", headers={"TOKEN": "secret"}) + assert response.status_code == 200 if i < 10 else 429 + def test_disabled_limiter(self): """ Check that the limiter does nothing if disabled (both sync and async)