From 33899c0ef1ce5f9ee13b51acda6ea83fa7a38038 Mon Sep 17 00:00:00 2001 From: Marat Sarbasov Date: Sat, 15 Jan 2022 16:02:55 +0300 Subject: [PATCH 1/4] Enable dynamic limits dependant on key. --- slowapi/extension.py | 2 +- slowapi/wrappers.py | 23 ++++++++++++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/slowapi/extension.py b/slowapi/extension.py index e3e33af..242b2c0 100644 --- a/slowapi/extension.py +++ b/slowapi/extension.py @@ -503,7 +503,7 @@ class Limiter: if name in self._dynamic_route_limits: for lim in self._dynamic_route_limits[name]: try: - dynamic_limits.extend(list(lim)) + dynamic_limits.extend(list(lim.with_request(request))) except ValueError as e: self.logger.error( "failed to load ratelimit for view function %s (%s)", diff --git a/slowapi/wrappers.py b/slowapi/wrappers.py index fd9b61d..6e78b41 100644 --- a/slowapi/wrappers.py +++ b/slowapi/wrappers.py @@ -1,3 +1,4 @@ +import inspect from typing import Callable, Iterator, List, Optional, Union from limits import RateLimitItem, parse_many # type: ignore @@ -74,13 +75,21 @@ class LimitGroup(object): self.error_message = error_message self.exempt_when = exempt_when self.override_defaults = override_defaults + self.request = None def __iter__(self) -> Iterator[Limit]: - limit_items: List[RateLimitItem] = parse_many( - self.__limit_provider() - if callable(self.__limit_provider) - else self.__limit_provider - ) + if callable(self.__limit_provider): + if "key" in inspect.signature(self.__limit_provider).parameters.keys(): + assert ( + "request" in inspect.signature(self.key_function).parameters.keys() + ) + assert self.request + limit_raw = self.__limit_provider(self.key_function(self.request)) + else: + limit_raw = self.__limit_provider() + else: + limit_raw = self.__limit_provider + limit_items: List[RateLimitItem] = parse_many(limit_raw) for limit in limit_items: yield Limit( limit, @@ -92,3 +101,7 @@ class LimitGroup(object): self.exempt_when, self.override_defaults, ) + + def with_request(self, request): + self.request = request + return self From 59d2cdd9de748289524376f5fae9fe7c3eaa4186 Mon Sep 17 00:00:00 2001 From: Marat Sarbasov Date: Sun, 16 Jan 2022 04:37:10 +0300 Subject: [PATCH 2/4] Add test_dynamic_limit_provider_depending_on_key --- tests/test_fastapi_extension.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test_fastapi_extension.py b/tests/test_fastapi_extension.py index 60bb72d..5c837a0 100644 --- a/tests/test_fastapi_extension.py +++ b/tests/test_fastapi_extension.py @@ -232,6 +232,33 @@ class TestDecorators(TestSlowapi): r"""parameter `response` must be an instance of starlette.responses.Response""" ) + def test_dynamic_limit_provider_depending_on_key(self): + def custom_key_func(request: Request): + if request.headers.get("TOKEN") == "secret": + return "admin" + return "user" + + def dynamic_limit_provider(key: str): + if key == "admin": + return "10/minute" + return "5/minute" + + app, limiter = self.build_fastapi_app(key_func=custom_key_func) + + @app.get("/t1") + @limiter.limit(dynamic_limit_provider) + async def t1(request: Request, response: Response): + return {"key": "value"} + + client = TestClient(app) + for i in range(0, 10): + response = client.get("/t1") + assert response.status_code == 200 if i < 5 else 429 + + for i in range(0, 20): + response = client.get("/t1", headers={"TOKEN": "secret"}) + assert response.status_code == 200 if i < 10 else 429 + def test_disabled_limiter(self): """ Check that the limiter does nothing if disabled (both sync and async) From 5e78adde2fbd499c78bb0c558b21d150218faf19 Mon Sep 17 00:00:00 2001 From: Marat Sarbasov Date: Wed, 19 Jan 2022 00:51:52 +0300 Subject: [PATCH 3/4] Replace assert with raise ... --- slowapi/wrappers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/slowapi/wrappers.py b/slowapi/wrappers.py index 6e78b41..a67cb8d 100644 --- a/slowapi/wrappers.py +++ b/slowapi/wrappers.py @@ -82,8 +82,9 @@ class LimitGroup(object): if "key" in inspect.signature(self.__limit_provider).parameters.keys(): assert ( "request" in inspect.signature(self.key_function).parameters.keys() - ) - assert self.request + ), "If limit provider function depends on `key` argument, key function must accept `request`" + if self.request is None: + raise Exception("`request` object can't be None") limit_raw = self.__limit_provider(self.key_function(self.request)) else: limit_raw = self.__limit_provider() From e994f06f19cccbb2ef17d277425047a87e47c646 Mon Sep 17 00:00:00 2001 From: Marat Sarbasov Date: Wed, 19 Jan 2022 02:37:18 +0300 Subject: [PATCH 4/4] Clear assert message Co-authored-by: Laurent Savaete --- slowapi/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/slowapi/wrappers.py b/slowapi/wrappers.py index a67cb8d..b87d108 100644 --- a/slowapi/wrappers.py +++ b/slowapi/wrappers.py @@ -82,7 +82,7 @@ class LimitGroup(object): if "key" in inspect.signature(self.__limit_provider).parameters.keys(): assert ( "request" in inspect.signature(self.key_function).parameters.keys() - ), "If limit provider function depends on `key` argument, key function must accept `request`" + ), f"Limit provider function {self.key_function.__name__} needs a `request` argument" if self.request is None: raise Exception("`request` object can't be None") limit_raw = self.__limit_provider(self.key_function(self.request))