🎨 format with black

This commit is contained in:
Reuben Thomas-Davis
2020-10-01 19:13:50 +01:00
parent 3753de763c
commit 32205c5b25
4 changed files with 44 additions and 20 deletions

View File

@@ -186,15 +186,27 @@ class Limiter:
for limit in set(default_limits):
self._default_limits.extend(
[LimitGroup(limit, self._key_func, None, False, None, None, None, False)]
[
LimitGroup(
limit, self._key_func, None, False, None, None, None, False
)
]
)
for limit in application_limits:
self._application_limits.extend(
[LimitGroup(limit, self._key_func, "global", False, None, None, None, False)]
[
LimitGroup(
limit, self._key_func, "global", False, None, None, None, False
)
]
)
for limit in in_memory_fallback:
self._in_memory_fallback.extend(
[LimitGroup(limit, self._key_func, None, False, None, None, None, False)]
[
LimitGroup(
limit, self._key_func, None, False, None, None, None, False
)
]
)
self._route_limits: Dict[str, List[Limit]] = {}
self._dynamic_route_limits: Dict[str, List[LimitGroup]] = {}
@@ -496,9 +508,11 @@ class Limiter:
combined_defaults = all(
not limit.override_defaults for limit in route_limits
)
if not route_limits and not (
in_middleware and name in self.__marked_for_limiting
) or combined_defaults:
if (
not route_limits
and not (in_middleware and name in self.__marked_for_limiting)
or combined_defaults
):
all_limits += list(itertools.chain(*self._default_limits))
# actually check the limits, so far we've only computed the list of limits to check
self.__evaluate_limits(request, endpoint, all_limits)
@@ -528,7 +542,7 @@ class Limiter:
methods: Optional[List[str]] = None,
error_message: Optional[str] = None,
exempt_when: Optional[Callable[..., bool]] = None,
override_defaults: bool = True
override_defaults: bool = True,
) -> Callable[..., Any]:
_scope = scope if shared else None
@@ -638,7 +652,7 @@ class Limiter:
methods: Optional[List[str]] = None,
error_message: Optional[str] = None,
exempt_when: Optional[Callable[..., bool]] = None,
override_defaults: bool = True
override_defaults: bool = True,
) -> Callable:
"""
Decorator to be used for rate limiting individual routes.
@@ -664,7 +678,7 @@ class Limiter:
methods=methods,
error_message=error_message,
exempt_when=exempt_when,
override_defaults=override_defaults
override_defaults=override_defaults,
)
def shared_limit(
@@ -716,4 +730,4 @@ class Limiter:
return obj(*a, **k)
self._exempt_routes.add(name)
return __inner
return __inner

View File

@@ -1,5 +1,9 @@
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint, DispatchFunction
from starlette.middleware.base import (
BaseHTTPMiddleware,
RequestResponseEndpoint,
DispatchFunction,
)
from starlette.requests import Request
from starlette.responses import Response
@@ -8,7 +12,7 @@ from slowapi import Limiter, _rate_limit_exceeded_handler
class SlowAPIMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
app: Starlette = request.app
limiter: Limiter = app.state.limiter
@@ -32,13 +36,15 @@ class SlowAPIMiddleware(BaseHTTPMiddleware):
# let the decorator handle if already in
if limiter._auto_check and not getattr(
request.state, "_rate_limiting_complete", False
request.state, "_rate_limiting_complete", False
):
try:
limiter._check_request_limit(request, handler, True)
except Exception as e:
# handle the exception since the global exception handler won't pick it up if we call_next
exception_handler = app.exception_handlers.get(type(e), _rate_limit_exceeded_handler)
exception_handler = app.exception_handlers.get(
type(e), _rate_limit_exceeded_handler
)
return exception_handler(request, e)
# request.state._rate_limiting_complete = True
response = await call_next(request)

View File

@@ -17,7 +17,7 @@ class Limit(object):
methods: Optional[List[str]],
error_message: Optional[Union[str, Callable[..., str]]],
exempt_when: Optional[Callable[..., bool]],
override_defaults: bool
override_defaults: bool,
) -> None:
self.limit = limit
self.key_func = key_func

View File

@@ -171,8 +171,9 @@ class TestDecorators(TestSlowapi):
def test_exempt_decorator(self):
app, limiter = self.build_starlette_app(
headers_enabled=True, key_func=get_remote_address, default_limits=["1/minute"]
headers_enabled=True,
key_func=get_remote_address,
default_limits=["1/minute"],
)
@app.route("/t1")
@@ -199,7 +200,9 @@ class TestDecorators(TestSlowapi):
# todo: more tests - see https://github.com/alisaifee/flask-limiter/blob/55df08f14143a7e918fc033067a494248ab6b0c5/tests/test_decorators.py#L187
def test_default_and_decorator_limit_merging(self):
# test pool has 100 reqs left
app, limiter = self.build_starlette_app(key_func=lambda: "test", default_limits=["10/minute"])
app, limiter = self.build_starlette_app(
key_func=lambda: "test", default_limits=["10/minute"]
)
# ip pool has 50 reqs for 127.0.0.14
@limiter.limit("5 per minute", key_func=get_ipaddr, override_defaults=False)
@@ -220,5 +223,6 @@ class TestDecorators(TestSlowapi):
assert cli.get("/t1").status_code == 429
assert (
cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code
== 429)
cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code
== 429
)