mirror of
https://github.com/laurentS/slowapi.git
synced 2026-02-04 11:54:29 +08:00
✨ add override_defaults parameter and fix merge test
This commit is contained in:
@@ -164,15 +164,15 @@ class Limiter:
|
||||
|
||||
for limit in set(default_limits):
|
||||
self._default_limits.extend(
|
||||
[LimitGroup(limit, self._key_func, None, False, None, None, None)]
|
||||
[LimitGroup(limit, self._key_func, None, False, None, None, None, False)]
|
||||
)
|
||||
for limit in application_limits:
|
||||
self._application_limits.extend(
|
||||
[LimitGroup(limit, self._key_func, "global", False, None, None, None)]
|
||||
[LimitGroup(limit, self._key_func, "global", False, None, None, None, False)]
|
||||
)
|
||||
for limit in in_memory_fallback:
|
||||
self._in_memory_fallback.extend(
|
||||
[LimitGroup(limit, self._key_func, None, False, None, None, None)]
|
||||
[LimitGroup(limit, self._key_func, None, False, None, None, None, False)]
|
||||
)
|
||||
self._route_limits: Dict[str, List[Limit]] = {}
|
||||
self._dynamic_route_limits: Dict[str, List[LimitGroup]] = {}
|
||||
@@ -471,9 +471,12 @@ class Limiter:
|
||||
else []
|
||||
)
|
||||
all_limits += route_limits
|
||||
combined_defaults = all(
|
||||
not limit.override_defaults for limit in route_limits
|
||||
)
|
||||
if not route_limits and not (
|
||||
in_middleware and name in self.__marked_for_limiting
|
||||
):
|
||||
) or combined_defaults:
|
||||
all_limits += list(itertools.chain(*self._default_limits))
|
||||
# actually check the limits, so far we've only computed the list of limits to check
|
||||
self.__evaluate_limits(request, endpoint, all_limits)
|
||||
@@ -503,6 +506,7 @@ class Limiter:
|
||||
methods: Optional[List[str]] = None,
|
||||
error_message: Optional[str] = None,
|
||||
exempt_when: Optional[Callable[..., bool]] = None,
|
||||
override_defaults: bool = True
|
||||
) -> Callable[..., Any]:
|
||||
|
||||
_scope = scope if shared else None
|
||||
@@ -521,6 +525,7 @@ class Limiter:
|
||||
methods,
|
||||
error_message,
|
||||
exempt_when,
|
||||
override_defaults,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
@@ -533,6 +538,7 @@ class Limiter:
|
||||
methods,
|
||||
error_message,
|
||||
exempt_when,
|
||||
override_defaults,
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
@@ -604,6 +610,7 @@ class Limiter:
|
||||
methods: Optional[List[str]] = None,
|
||||
error_message: Optional[str] = None,
|
||||
exempt_when: Optional[Callable[..., bool]] = None,
|
||||
override_defaults: bool = True
|
||||
) -> Callable:
|
||||
"""
|
||||
decorator to be used for rate limiting individual routes.
|
||||
@@ -619,6 +626,7 @@ class Limiter:
|
||||
:param error_message: string (or callable that returns one) to override the
|
||||
error message used in the response.
|
||||
:param exempt_when:
|
||||
:param override_defaults: whether to override the default limits (default: True)
|
||||
:return:
|
||||
"""
|
||||
return self.__limit_decorator(
|
||||
@@ -628,6 +636,7 @@ class Limiter:
|
||||
methods=methods,
|
||||
error_message=error_message,
|
||||
exempt_when=exempt_when,
|
||||
override_defaults=override_defaults
|
||||
)
|
||||
|
||||
def shared_limit(
|
||||
@@ -637,6 +646,7 @@ class Limiter:
|
||||
key_func: Optional[Callable[..., str]] = None,
|
||||
error_message: Optional[str] = None,
|
||||
exempt_when: Optional[Callable[..., bool]] = None,
|
||||
override_defaults: bool = True,
|
||||
) -> Callable:
|
||||
"""
|
||||
decorator to be applied to multiple routes sharing the same rate limit.
|
||||
@@ -650,6 +660,7 @@ class Limiter:
|
||||
:param error_message: string (or callable that returns one) to override the
|
||||
error message used in the response.
|
||||
:param exempt_when:
|
||||
:param override_defaults: whether to override the default limits (default: True)
|
||||
"""
|
||||
return self.__limit_decorator(
|
||||
limit_value,
|
||||
@@ -658,6 +669,7 @@ class Limiter:
|
||||
scope,
|
||||
error_message=error_message,
|
||||
exempt_when=exempt_when,
|
||||
override_defaults=override_defaults,
|
||||
)
|
||||
|
||||
def exempt(self, obj):
|
||||
|
||||
@@ -17,6 +17,7 @@ class Limit(object):
|
||||
methods: Optional[List[str]],
|
||||
error_message: Optional[Union[str, Callable[..., str]]],
|
||||
exempt_when: Optional[Callable[..., bool]],
|
||||
override_defaults: bool
|
||||
) -> None:
|
||||
self.limit = limit
|
||||
self.key_func = key_func
|
||||
@@ -25,6 +26,7 @@ class Limit(object):
|
||||
self.methods = methods
|
||||
self.error_message = error_message
|
||||
self.exempt_when = exempt_when
|
||||
self.override_defaults = override_defaults
|
||||
|
||||
@property
|
||||
def is_exempt(self) -> bool:
|
||||
@@ -62,6 +64,7 @@ class LimitGroup(object):
|
||||
methods: Optional[List[str]],
|
||||
error_message: Optional[Union[str, Callable[..., str]]],
|
||||
exempt_when: Optional[Callable[..., bool]],
|
||||
override_defaults: bool,
|
||||
):
|
||||
self.__limit_provider = limit_provider
|
||||
self.__scope = scope
|
||||
@@ -70,6 +73,7 @@ class LimitGroup(object):
|
||||
self.methods = methods and [m.lower() for m in methods] or methods
|
||||
self.error_message = error_message
|
||||
self.exempt_when = exempt_when
|
||||
self.override_defaults = override_defaults
|
||||
|
||||
def __iter__(self) -> Iterator[Limit]:
|
||||
limit_items: List[RateLimitItem] = parse_many(
|
||||
@@ -86,4 +90,5 @@ class LimitGroup(object):
|
||||
self.methods,
|
||||
self.error_message,
|
||||
self.exempt_when,
|
||||
self.override_defaults,
|
||||
)
|
||||
|
||||
@@ -70,8 +70,8 @@ class TestDecorators(TestSlowapi):
|
||||
def test_multiple_decorators(self):
|
||||
app, limiter = self.build_starlette_app(key_func=get_ipaddr)
|
||||
|
||||
@limiter.limit("100 per minute", lambda: "test")
|
||||
@limiter.limit("50/minute") # per ip as per default key_func
|
||||
@limiter.limit("10 per minute", lambda: "test")
|
||||
@limiter.limit("5/minute") # per ip as per default key_func
|
||||
async def t1(request: Request):
|
||||
return PlainTextResponse("test")
|
||||
|
||||
@@ -79,16 +79,16 @@ class TestDecorators(TestSlowapi):
|
||||
|
||||
with hiro.Timeline().freeze() as timeline:
|
||||
cli = TestClient(app)
|
||||
for i in range(0, 100):
|
||||
for i in range(0, 10):
|
||||
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
|
||||
assert response.status_code == 200 if i < 50 else 429
|
||||
for i in range(50):
|
||||
assert response.status_code == 200 if i < 5 else 429
|
||||
for i in range(5):
|
||||
assert cli.get("/t1").status_code == 200
|
||||
|
||||
assert cli.get("/t1").status_code == 429
|
||||
assert (
|
||||
cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code
|
||||
== 429
|
||||
cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code
|
||||
== 429
|
||||
)
|
||||
|
||||
def test_headers_no_breach(self):
|
||||
@@ -196,10 +196,13 @@ class TestDecorators(TestSlowapi):
|
||||
resp2 = cli.get("/t2", headers={"X_FORWARDED_FOR": "127.0.0.10"})
|
||||
assert resp2.status_code == 200
|
||||
|
||||
# todo: more tests - see https://github.com/alisaifee/flask-limiter/blob/55df08f14143a7e918fc033067a494248ab6b0c5/tests/test_decorators.py#L187
|
||||
def test_default_and_decorator_limit_merging(self):
|
||||
app, limiter = self.build_starlette_app(key_func=get_ipaddr, default_limits=["50/minute"])
|
||||
# test pool has 100 reqs left
|
||||
app, limiter = self.build_starlette_app(key_func=lambda: "test", default_limits=["10/minute"])
|
||||
|
||||
@limiter.limit("100 per minute", key_func=lambda: "lest")
|
||||
# ip pool has 50 reqs for 127.0.0.14
|
||||
@limiter.limit("5 per minute", key_func=get_ipaddr, override_defaults=False)
|
||||
async def t1(request: Request):
|
||||
return PlainTextResponse("test")
|
||||
|
||||
@@ -207,13 +210,15 @@ class TestDecorators(TestSlowapi):
|
||||
|
||||
with hiro.Timeline().freeze() as timeline:
|
||||
cli = TestClient(app)
|
||||
for i in range(0, 100):
|
||||
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.14"})
|
||||
assert response.status_code == 200 if i < 50 else 429
|
||||
for i in range(50):
|
||||
for i in range(0, 10):
|
||||
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
|
||||
assert response.status_code == 200 if i < 5 else 429
|
||||
# now ip pool for 127.0.0.14 has 0 reqs left
|
||||
for i in range(5):
|
||||
assert cli.get("/t1").status_code == 200
|
||||
# now test pool has 0 reqs left
|
||||
|
||||
assert cli.get("/t1").status_code == 429
|
||||
assert (
|
||||
cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.15"}).status_code
|
||||
cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code
|
||||
== 429)
|
||||
|
||||
Reference in New Issue
Block a user