mirror of
https://github.com/laurentS/slowapi.git
synced 2026-03-13 09:10:20 +08:00
Merge pull request #80 from maratsarbasov/master
Implement dynamic limits depending on key.
This commit is contained in:
@@ -503,7 +503,7 @@ class Limiter:
|
||||
if name in self._dynamic_route_limits:
|
||||
for lim in self._dynamic_route_limits[name]:
|
||||
try:
|
||||
dynamic_limits.extend(list(lim))
|
||||
dynamic_limits.extend(list(lim.with_request(request)))
|
||||
except ValueError as e:
|
||||
self.logger.error(
|
||||
"failed to load ratelimit for view function %s (%s)",
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
from typing import Callable, Iterator, List, Optional, Union
|
||||
|
||||
from limits import RateLimitItem, parse_many # type: ignore
|
||||
@@ -74,13 +75,22 @@ class LimitGroup(object):
|
||||
self.error_message = error_message
|
||||
self.exempt_when = exempt_when
|
||||
self.override_defaults = override_defaults
|
||||
self.request = None
|
||||
|
||||
def __iter__(self) -> Iterator[Limit]:
|
||||
limit_items: List[RateLimitItem] = parse_many(
|
||||
self.__limit_provider()
|
||||
if callable(self.__limit_provider)
|
||||
else self.__limit_provider
|
||||
)
|
||||
if callable(self.__limit_provider):
|
||||
if "key" in inspect.signature(self.__limit_provider).parameters.keys():
|
||||
assert (
|
||||
"request" in inspect.signature(self.key_function).parameters.keys()
|
||||
), f"Limit provider function {self.key_function.__name__} needs a `request` argument"
|
||||
if self.request is None:
|
||||
raise Exception("`request` object can't be None")
|
||||
limit_raw = self.__limit_provider(self.key_function(self.request))
|
||||
else:
|
||||
limit_raw = self.__limit_provider()
|
||||
else:
|
||||
limit_raw = self.__limit_provider
|
||||
limit_items: List[RateLimitItem] = parse_many(limit_raw)
|
||||
for limit in limit_items:
|
||||
yield Limit(
|
||||
limit,
|
||||
@@ -92,3 +102,7 @@ class LimitGroup(object):
|
||||
self.exempt_when,
|
||||
self.override_defaults,
|
||||
)
|
||||
|
||||
def with_request(self, request):
|
||||
self.request = request
|
||||
return self
|
||||
|
||||
@@ -232,6 +232,33 @@ class TestDecorators(TestSlowapi):
|
||||
r"""parameter `response` must be an instance of starlette.responses.Response"""
|
||||
)
|
||||
|
||||
def test_dynamic_limit_provider_depending_on_key(self):
|
||||
def custom_key_func(request: Request):
|
||||
if request.headers.get("TOKEN") == "secret":
|
||||
return "admin"
|
||||
return "user"
|
||||
|
||||
def dynamic_limit_provider(key: str):
|
||||
if key == "admin":
|
||||
return "10/minute"
|
||||
return "5/minute"
|
||||
|
||||
app, limiter = self.build_fastapi_app(key_func=custom_key_func)
|
||||
|
||||
@app.get("/t1")
|
||||
@limiter.limit(dynamic_limit_provider)
|
||||
async def t1(request: Request, response: Response):
|
||||
return {"key": "value"}
|
||||
|
||||
client = TestClient(app)
|
||||
for i in range(0, 10):
|
||||
response = client.get("/t1")
|
||||
assert response.status_code == 200 if i < 5 else 429
|
||||
|
||||
for i in range(0, 20):
|
||||
response = client.get("/t1", headers={"TOKEN": "secret"})
|
||||
assert response.status_code == 200 if i < 10 else 429
|
||||
|
||||
def test_disabled_limiter(self):
|
||||
"""
|
||||
Check that the limiter does nothing if disabled (both sync and async)
|
||||
|
||||
Reference in New Issue
Block a user