Merge pull request #80 from maratsarbasov/master

Implement dynamic limits depending on key.
This commit is contained in:
Trevor Currie
2022-01-18 16:09:29 -08:00
committed by GitHub
3 changed files with 47 additions and 6 deletions

View File

@@ -503,7 +503,7 @@ class Limiter:
if name in self._dynamic_route_limits:
for lim in self._dynamic_route_limits[name]:
try:
dynamic_limits.extend(list(lim))
dynamic_limits.extend(list(lim.with_request(request)))
except ValueError as e:
self.logger.error(
"failed to load ratelimit for view function %s (%s)",

View File

@@ -1,3 +1,4 @@
import inspect
from typing import Callable, Iterator, List, Optional, Union
from limits import RateLimitItem, parse_many # type: ignore
@@ -74,13 +75,22 @@ class LimitGroup(object):
self.error_message = error_message
self.exempt_when = exempt_when
self.override_defaults = override_defaults
self.request = None
def __iter__(self) -> Iterator[Limit]:
limit_items: List[RateLimitItem] = parse_many(
self.__limit_provider()
if callable(self.__limit_provider)
else self.__limit_provider
)
if callable(self.__limit_provider):
if "key" in inspect.signature(self.__limit_provider).parameters.keys():
assert (
"request" in inspect.signature(self.key_function).parameters.keys()
), f"Limit provider function {self.key_function.__name__} needs a `request` argument"
if self.request is None:
raise Exception("`request` object can't be None")
limit_raw = self.__limit_provider(self.key_function(self.request))
else:
limit_raw = self.__limit_provider()
else:
limit_raw = self.__limit_provider
limit_items: List[RateLimitItem] = parse_many(limit_raw)
for limit in limit_items:
yield Limit(
limit,
@@ -92,3 +102,7 @@ class LimitGroup(object):
self.exempt_when,
self.override_defaults,
)
def with_request(self, request):
self.request = request
return self

View File

@@ -232,6 +232,33 @@ class TestDecorators(TestSlowapi):
r"""parameter `response` must be an instance of starlette.responses.Response"""
)
def test_dynamic_limit_provider_depending_on_key(self):
def custom_key_func(request: Request):
if request.headers.get("TOKEN") == "secret":
return "admin"
return "user"
def dynamic_limit_provider(key: str):
if key == "admin":
return "10/minute"
return "5/minute"
app, limiter = self.build_fastapi_app(key_func=custom_key_func)
@app.get("/t1")
@limiter.limit(dynamic_limit_provider)
async def t1(request: Request, response: Response):
return {"key": "value"}
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t1")
assert response.status_code == 200 if i < 5 else 429
for i in range(0, 20):
response = client.get("/t1", headers={"TOKEN": "secret"})
assert response.status_code == 200 if i < 10 else 429
def test_disabled_limiter(self):
"""
Check that the limiter does nothing if disabled (both sync and async)