mirror of
https://github.com/laurentS/slowapi.git
synced 2026-03-13 09:10:20 +08:00
🎨 format with black
This commit is contained in:
@@ -186,15 +186,27 @@ class Limiter:
|
||||
|
||||
for limit in set(default_limits):
|
||||
self._default_limits.extend(
|
||||
[LimitGroup(limit, self._key_func, None, False, None, None, None, False)]
|
||||
[
|
||||
LimitGroup(
|
||||
limit, self._key_func, None, False, None, None, None, False
|
||||
)
|
||||
]
|
||||
)
|
||||
for limit in application_limits:
|
||||
self._application_limits.extend(
|
||||
[LimitGroup(limit, self._key_func, "global", False, None, None, None, False)]
|
||||
[
|
||||
LimitGroup(
|
||||
limit, self._key_func, "global", False, None, None, None, False
|
||||
)
|
||||
]
|
||||
)
|
||||
for limit in in_memory_fallback:
|
||||
self._in_memory_fallback.extend(
|
||||
[LimitGroup(limit, self._key_func, None, False, None, None, None, False)]
|
||||
[
|
||||
LimitGroup(
|
||||
limit, self._key_func, None, False, None, None, None, False
|
||||
)
|
||||
]
|
||||
)
|
||||
self._route_limits: Dict[str, List[Limit]] = {}
|
||||
self._dynamic_route_limits: Dict[str, List[LimitGroup]] = {}
|
||||
@@ -496,9 +508,11 @@ class Limiter:
|
||||
combined_defaults = all(
|
||||
not limit.override_defaults for limit in route_limits
|
||||
)
|
||||
if not route_limits and not (
|
||||
in_middleware and name in self.__marked_for_limiting
|
||||
) or combined_defaults:
|
||||
if (
|
||||
not route_limits
|
||||
and not (in_middleware and name in self.__marked_for_limiting)
|
||||
or combined_defaults
|
||||
):
|
||||
all_limits += list(itertools.chain(*self._default_limits))
|
||||
# actually check the limits, so far we've only computed the list of limits to check
|
||||
self.__evaluate_limits(request, endpoint, all_limits)
|
||||
@@ -528,7 +542,7 @@ class Limiter:
|
||||
methods: Optional[List[str]] = None,
|
||||
error_message: Optional[str] = None,
|
||||
exempt_when: Optional[Callable[..., bool]] = None,
|
||||
override_defaults: bool = True
|
||||
override_defaults: bool = True,
|
||||
) -> Callable[..., Any]:
|
||||
|
||||
_scope = scope if shared else None
|
||||
@@ -638,7 +652,7 @@ class Limiter:
|
||||
methods: Optional[List[str]] = None,
|
||||
error_message: Optional[str] = None,
|
||||
exempt_when: Optional[Callable[..., bool]] = None,
|
||||
override_defaults: bool = True
|
||||
override_defaults: bool = True,
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator to be used for rate limiting individual routes.
|
||||
@@ -664,7 +678,7 @@ class Limiter:
|
||||
methods=methods,
|
||||
error_message=error_message,
|
||||
exempt_when=exempt_when,
|
||||
override_defaults=override_defaults
|
||||
override_defaults=override_defaults,
|
||||
)
|
||||
|
||||
def shared_limit(
|
||||
@@ -716,4 +730,4 @@ class Limiter:
|
||||
return obj(*a, **k)
|
||||
|
||||
self._exempt_routes.add(name)
|
||||
return __inner
|
||||
return __inner
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint, DispatchFunction
|
||||
from starlette.middleware.base import (
|
||||
BaseHTTPMiddleware,
|
||||
RequestResponseEndpoint,
|
||||
DispatchFunction,
|
||||
)
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
@@ -8,7 +12,7 @@ from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
|
||||
class SlowAPIMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
app: Starlette = request.app
|
||||
limiter: Limiter = app.state.limiter
|
||||
@@ -32,13 +36,15 @@ class SlowAPIMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# let the decorator handle if already in
|
||||
if limiter._auto_check and not getattr(
|
||||
request.state, "_rate_limiting_complete", False
|
||||
request.state, "_rate_limiting_complete", False
|
||||
):
|
||||
try:
|
||||
limiter._check_request_limit(request, handler, True)
|
||||
except Exception as e:
|
||||
# handle the exception since the global exception handler won't pick it up if we call_next
|
||||
exception_handler = app.exception_handlers.get(type(e), _rate_limit_exceeded_handler)
|
||||
exception_handler = app.exception_handlers.get(
|
||||
type(e), _rate_limit_exceeded_handler
|
||||
)
|
||||
return exception_handler(request, e)
|
||||
# request.state._rate_limiting_complete = True
|
||||
response = await call_next(request)
|
||||
|
||||
@@ -17,7 +17,7 @@ class Limit(object):
|
||||
methods: Optional[List[str]],
|
||||
error_message: Optional[Union[str, Callable[..., str]]],
|
||||
exempt_when: Optional[Callable[..., bool]],
|
||||
override_defaults: bool
|
||||
override_defaults: bool,
|
||||
) -> None:
|
||||
self.limit = limit
|
||||
self.key_func = key_func
|
||||
|
||||
@@ -171,8 +171,9 @@ class TestDecorators(TestSlowapi):
|
||||
|
||||
def test_exempt_decorator(self):
|
||||
app, limiter = self.build_starlette_app(
|
||||
headers_enabled=True, key_func=get_remote_address, default_limits=["1/minute"]
|
||||
|
||||
headers_enabled=True,
|
||||
key_func=get_remote_address,
|
||||
default_limits=["1/minute"],
|
||||
)
|
||||
|
||||
@app.route("/t1")
|
||||
@@ -199,7 +200,9 @@ class TestDecorators(TestSlowapi):
|
||||
# todo: more tests - see https://github.com/alisaifee/flask-limiter/blob/55df08f14143a7e918fc033067a494248ab6b0c5/tests/test_decorators.py#L187
|
||||
def test_default_and_decorator_limit_merging(self):
|
||||
# test pool has 100 reqs left
|
||||
app, limiter = self.build_starlette_app(key_func=lambda: "test", default_limits=["10/minute"])
|
||||
app, limiter = self.build_starlette_app(
|
||||
key_func=lambda: "test", default_limits=["10/minute"]
|
||||
)
|
||||
|
||||
# ip pool has 50 reqs for 127.0.0.14
|
||||
@limiter.limit("5 per minute", key_func=get_ipaddr, override_defaults=False)
|
||||
@@ -220,5 +223,6 @@ class TestDecorators(TestSlowapi):
|
||||
|
||||
assert cli.get("/t1").status_code == 429
|
||||
assert (
|
||||
cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code
|
||||
== 429)
|
||||
cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code
|
||||
== 429
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user