mirror of
https://github.com/laurentS/slowapi.git
synced 2026-03-13 09:10:20 +08:00
Merge pull request #7 from Rested/master
add middleware and exempt decorator
This commit is contained in:
9
.gitignore
vendored
Normal file
9
.gitignore
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
# testing
|
||||
.pytest_cache
|
||||
__pycache__
|
||||
|
||||
# typing
|
||||
.mypy_cache
|
||||
|
||||
# editors
|
||||
.idea
|
||||
@@ -186,15 +186,27 @@ class Limiter:
|
||||
|
||||
for limit in set(default_limits):
|
||||
self._default_limits.extend(
|
||||
[LimitGroup(limit, self._key_func, None, False, None, None, None)]
|
||||
[
|
||||
LimitGroup(
|
||||
limit, self._key_func, None, False, None, None, None, False
|
||||
)
|
||||
]
|
||||
)
|
||||
for limit in application_limits:
|
||||
self._application_limits.extend(
|
||||
[LimitGroup(limit, self._key_func, "global", False, None, None, None)]
|
||||
[
|
||||
LimitGroup(
|
||||
limit, self._key_func, "global", False, None, None, None, False
|
||||
)
|
||||
]
|
||||
)
|
||||
for limit in in_memory_fallback:
|
||||
self._in_memory_fallback.extend(
|
||||
[LimitGroup(limit, self._key_func, None, False, None, None, None)]
|
||||
[
|
||||
LimitGroup(
|
||||
limit, self._key_func, None, False, None, None, None, False
|
||||
)
|
||||
]
|
||||
)
|
||||
self._route_limits: Dict[str, List[Limit]] = {}
|
||||
self._dynamic_route_limits: Dict[str, List[LimitGroup]] = {}
|
||||
@@ -257,7 +269,7 @@ class Limiter:
|
||||
if not self._application_limits and app_limits:
|
||||
self._application_limits = [
|
||||
LimitGroup(
|
||||
app_limits, self._key_func, "global", False, None, None, None
|
||||
app_limits, self._key_func, "global", False, None, None, None, False
|
||||
)
|
||||
]
|
||||
|
||||
@@ -266,7 +278,9 @@ class Limiter:
|
||||
)
|
||||
if not self._default_limits and conf_limits:
|
||||
self._default_limits = [
|
||||
LimitGroup(conf_limits, self._key_func, None, False, None, None, None)
|
||||
LimitGroup(
|
||||
conf_limits, self._key_func, None, False, None, None, None, False
|
||||
)
|
||||
]
|
||||
fallback_enabled = self.get_app_config(C.IN_MEMORY_FALLBACK_ENABLED, False)
|
||||
fallback_limits: Optional[StrOrCallableStr] = self.get_app_config(
|
||||
@@ -275,7 +289,14 @@ class Limiter:
|
||||
if not self._in_memory_fallback and fallback_limits:
|
||||
self._in_memory_fallback = [
|
||||
LimitGroup(
|
||||
fallback_limits, self._key_func, None, False, None, None, None
|
||||
fallback_limits,
|
||||
self._key_func,
|
||||
None,
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
)
|
||||
]
|
||||
if not self._in_memory_fallback_enabled:
|
||||
@@ -432,7 +453,7 @@ class Limiter:
|
||||
if failed_limit:
|
||||
raise RateLimitExceeded(failed_limit)
|
||||
|
||||
def __check_request_limit(
|
||||
def _check_request_limit(
|
||||
self,
|
||||
request: Request,
|
||||
endpoint_func: Callable[..., Any],
|
||||
@@ -493,8 +514,13 @@ class Limiter:
|
||||
else []
|
||||
)
|
||||
all_limits += route_limits
|
||||
if not route_limits and not (
|
||||
in_middleware and name in self.__marked_for_limiting
|
||||
combined_defaults = all(
|
||||
not limit.override_defaults for limit in route_limits
|
||||
)
|
||||
if (
|
||||
not route_limits
|
||||
and not (in_middleware and name in self.__marked_for_limiting)
|
||||
or combined_defaults
|
||||
):
|
||||
all_limits += list(itertools.chain(*self._default_limits))
|
||||
# actually check the limits, so far we've only computed the list of limits to check
|
||||
@@ -508,7 +534,7 @@ class Limiter:
|
||||
" in-memory storage"
|
||||
)
|
||||
self._storage_dead = True
|
||||
self.__check_request_limit(request, endpoint_func, in_middleware)
|
||||
self._check_request_limit(request, endpoint_func, in_middleware)
|
||||
else:
|
||||
if self._swallow_errors:
|
||||
self.logger.exception("Failed to rate limit. Swallowing error")
|
||||
@@ -525,6 +551,7 @@ class Limiter:
|
||||
methods: Optional[List[str]] = None,
|
||||
error_message: Optional[str] = None,
|
||||
exempt_when: Optional[Callable[..., bool]] = None,
|
||||
override_defaults: bool = True,
|
||||
) -> Callable[..., Any]:
|
||||
|
||||
_scope = scope if shared else None
|
||||
@@ -543,6 +570,7 @@ class Limiter:
|
||||
methods,
|
||||
error_message,
|
||||
exempt_when,
|
||||
override_defaults,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
@@ -555,6 +583,7 @@ class Limiter:
|
||||
methods,
|
||||
error_message,
|
||||
exempt_when,
|
||||
override_defaults,
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
@@ -592,7 +621,7 @@ class Limiter:
|
||||
if self._auto_check and not getattr(
|
||||
request.state, "_rate_limiting_complete", False
|
||||
):
|
||||
self.__check_request_limit(request, func, False)
|
||||
self._check_request_limit(request, func, False)
|
||||
request.state._rate_limiting_complete = True
|
||||
response = await func(*args, **kwargs) # type: ignore
|
||||
self._inject_headers(response, request.state.view_rate_limit)
|
||||
@@ -614,7 +643,7 @@ class Limiter:
|
||||
if self._auto_check and not getattr(
|
||||
request.state, "_rate_limiting_complete", False
|
||||
):
|
||||
self.__check_request_limit(request, func, False)
|
||||
self._check_request_limit(request, func, False)
|
||||
request.state._rate_limiting_complete = True
|
||||
response = func(*args, **kwargs)
|
||||
self._inject_headers(response, request.state.view_rate_limit)
|
||||
@@ -632,6 +661,7 @@ class Limiter:
|
||||
methods: Optional[List[str]] = None,
|
||||
error_message: Optional[str] = None,
|
||||
exempt_when: Optional[Callable[..., bool]] = None,
|
||||
override_defaults: bool = True,
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator to be used for rate limiting individual routes.
|
||||
@@ -648,6 +678,7 @@ class Limiter:
|
||||
error message used in the response.
|
||||
* **exempt_when**: function returning a boolean indicating whether to exempt
|
||||
the route from the limit
|
||||
* **override_defaults**: whether to override the default limits (default: True)
|
||||
"""
|
||||
return self.__limit_decorator(
|
||||
limit_value,
|
||||
@@ -656,6 +687,7 @@ class Limiter:
|
||||
methods=methods,
|
||||
error_message=error_message,
|
||||
exempt_when=exempt_when,
|
||||
override_defaults=override_defaults,
|
||||
)
|
||||
|
||||
def shared_limit(
|
||||
@@ -665,6 +697,7 @@ class Limiter:
|
||||
key_func: Optional[Callable[..., str]] = None,
|
||||
error_message: Optional[str] = None,
|
||||
exempt_when: Optional[Callable[..., bool]] = None,
|
||||
override_defaults: bool = True,
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator to be applied to multiple routes sharing the same rate limit.
|
||||
@@ -683,6 +716,7 @@ class Limiter:
|
||||
error message used in the response.
|
||||
* **exempt_when**: function returning a boolean indicating whether to exempt
|
||||
the route from the limit
|
||||
* **override_defaults**: whether to override the default limits (default: True)
|
||||
"""
|
||||
return self.__limit_decorator(
|
||||
limit_value,
|
||||
@@ -691,4 +725,18 @@ class Limiter:
|
||||
scope,
|
||||
error_message=error_message,
|
||||
exempt_when=exempt_when,
|
||||
override_defaults=override_defaults,
|
||||
)
|
||||
|
||||
def exempt(self, obj):
|
||||
"""
|
||||
Decorator to mark a view as exempt from rate limits.
|
||||
"""
|
||||
name = "%s.%s" % (obj.__module__, obj.__name__)
|
||||
|
||||
@wraps(obj)
|
||||
def __inner(*a, **k):
|
||||
return obj(*a, **k)
|
||||
|
||||
self._exempt_routes.add(name)
|
||||
return __inner
|
||||
|
||||
56
slowapi/middleware.py
Normal file
56
slowapi/middleware.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from typing import Union
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware.base import (
|
||||
BaseHTTPMiddleware,
|
||||
RequestResponseEndpoint,
|
||||
DispatchFunction,
|
||||
)
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Route, BaseRoute, WebSocketRoute
|
||||
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
|
||||
|
||||
class SlowAPIMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
app: Starlette = request.app
|
||||
limiter: Limiter = app.state.limiter
|
||||
handler = None
|
||||
for route in app.routes:
|
||||
match, _ = route.matches(request.scope)
|
||||
if match.FULL and hasattr(route, "endpoint"):
|
||||
handler = route.endpoint # type: ignore
|
||||
# if we can't find the route handler
|
||||
if handler is None:
|
||||
return await call_next(request)
|
||||
|
||||
name = "%s.%s" % (handler.__module__, handler.__name__)
|
||||
# if exempt no need to check
|
||||
if name in limiter._exempt_routes:
|
||||
return await call_next(request)
|
||||
|
||||
# there is a decorator for this route we let the decorator handle it
|
||||
if name in limiter._route_limits:
|
||||
return await call_next(request)
|
||||
|
||||
# let the decorator handle if already in
|
||||
if limiter._auto_check and not getattr(
|
||||
request.state, "_rate_limiting_complete", False
|
||||
):
|
||||
try:
|
||||
limiter._check_request_limit(request, handler, True)
|
||||
except Exception as e:
|
||||
# handle the exception since the global exception handler won't pick it up if we call_next
|
||||
exception_handler = app.exception_handlers.get(
|
||||
type(e), _rate_limit_exceeded_handler
|
||||
)
|
||||
return exception_handler(request, e)
|
||||
# request.state._rate_limiting_complete = True
|
||||
response = await call_next(request)
|
||||
response = limiter._inject_headers(response, request.state.view_rate_limit)
|
||||
return response
|
||||
return await call_next(request)
|
||||
@@ -17,6 +17,7 @@ class Limit(object):
|
||||
methods: Optional[List[str]],
|
||||
error_message: Optional[Union[str, Callable[..., str]]],
|
||||
exempt_when: Optional[Callable[..., bool]],
|
||||
override_defaults: bool,
|
||||
) -> None:
|
||||
self.limit = limit
|
||||
self.key_func = key_func
|
||||
@@ -25,6 +26,7 @@ class Limit(object):
|
||||
self.methods = methods
|
||||
self.error_message = error_message
|
||||
self.exempt_when = exempt_when
|
||||
self.override_defaults = override_defaults
|
||||
|
||||
@property
|
||||
def is_exempt(self) -> bool:
|
||||
@@ -62,6 +64,7 @@ class LimitGroup(object):
|
||||
methods: Optional[List[str]],
|
||||
error_message: Optional[Union[str, Callable[..., str]]],
|
||||
exempt_when: Optional[Callable[..., bool]],
|
||||
override_defaults: bool,
|
||||
):
|
||||
self.__limit_provider = limit_provider
|
||||
self.__scope = scope
|
||||
@@ -70,6 +73,7 @@ class LimitGroup(object):
|
||||
self.methods = methods and [m.lower() for m in methods] or methods
|
||||
self.error_message = error_message
|
||||
self.exempt_when = exempt_when
|
||||
self.override_defaults = override_defaults
|
||||
|
||||
def __iter__(self) -> Iterator[Limit]:
|
||||
limit_items: List[RateLimitItem] = parse_many(
|
||||
@@ -86,4 +90,5 @@ class LimitGroup(object):
|
||||
self.methods,
|
||||
self.error_message,
|
||||
self.exempt_when,
|
||||
self.override_defaults,
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ from starlette.applications import Starlette
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.extension import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.middleware import SlowAPIMiddleware
|
||||
|
||||
|
||||
class TestSlowapi:
|
||||
@@ -20,6 +21,7 @@ class TestSlowapi:
|
||||
app = Starlette(debug=True)
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
app.add_middleware(SlowAPIMiddleware)
|
||||
|
||||
mock_handler = mock.Mock()
|
||||
mock_handler.level = logging.INFO
|
||||
@@ -30,6 +32,9 @@ class TestSlowapi:
|
||||
limiter_args.setdefault("key_func", get_remote_address)
|
||||
limiter = Limiter(**limiter_args)
|
||||
app = FastAPI()
|
||||
app.state.limiter = limiter
|
||||
app.add_middleware(SlowAPIMiddleware)
|
||||
|
||||
mock_handler = mock.Mock()
|
||||
mock_handler.level = logging.INFO
|
||||
limiter.logger.addHandler(mock_handler)
|
||||
|
||||
@@ -70,8 +70,8 @@ class TestDecorators(TestSlowapi):
|
||||
def test_multiple_decorators(self):
|
||||
app, limiter = self.build_starlette_app(key_func=get_ipaddr)
|
||||
|
||||
@limiter.limit("100 per minute", lambda: "test")
|
||||
@limiter.limit("50/minute") # per ip as per default key_func
|
||||
@limiter.limit("10 per minute", lambda: "test")
|
||||
@limiter.limit("5/minute") # per ip as per default key_func
|
||||
async def t1(request: Request):
|
||||
return PlainTextResponse("test")
|
||||
|
||||
@@ -79,10 +79,10 @@ class TestDecorators(TestSlowapi):
|
||||
|
||||
with hiro.Timeline().freeze() as timeline:
|
||||
cli = TestClient(app)
|
||||
for i in range(0, 100):
|
||||
for i in range(0, 10):
|
||||
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
|
||||
assert response.status_code == 200 if i < 50 else 429
|
||||
for i in range(50):
|
||||
assert response.status_code == 200 if i < 5 else 429
|
||||
for i in range(5):
|
||||
assert cli.get("/t1").status_code == 200
|
||||
|
||||
assert cli.get("/t1").status_code == 429
|
||||
@@ -168,3 +168,57 @@ class TestDecorators(TestSlowapi):
|
||||
timeline.forward(retry_after)
|
||||
resp = cli.get("/t1")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_exempt_decorator(self):
|
||||
app, limiter = self.build_starlette_app(
|
||||
headers_enabled=True,
|
||||
key_func=get_remote_address,
|
||||
default_limits=["1/minute"],
|
||||
)
|
||||
|
||||
@app.route("/t1")
|
||||
def t(request: Request):
|
||||
return PlainTextResponse("test")
|
||||
|
||||
with TestClient(app) as cli:
|
||||
resp = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.10"})
|
||||
assert resp.status_code == 200
|
||||
resp2 = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.10"})
|
||||
assert resp2.status_code == 429
|
||||
|
||||
@app.route("/t2")
|
||||
@limiter.exempt
|
||||
def t(request: Request):
|
||||
return PlainTextResponse("test")
|
||||
|
||||
with TestClient(app) as cli:
|
||||
resp = cli.get("/t2", headers={"X_FORWARDED_FOR": "127.0.0.10"})
|
||||
assert resp.status_code == 200
|
||||
resp2 = cli.get("/t2", headers={"X_FORWARDED_FOR": "127.0.0.10"})
|
||||
assert resp2.status_code == 200
|
||||
|
||||
# todo: more tests - see https://github.com/alisaifee/flask-limiter/blob/55df08f14143a7e918fc033067a494248ab6b0c5/tests/test_decorators.py#L187
|
||||
def test_default_and_decorator_limit_merging(self):
|
||||
app, limiter = self.build_starlette_app(
|
||||
key_func=lambda: "test", default_limits=["10/minute"]
|
||||
)
|
||||
|
||||
@limiter.limit("5 per minute", key_func=get_ipaddr, override_defaults=False)
|
||||
async def t1(request: Request):
|
||||
return PlainTextResponse("test")
|
||||
|
||||
app.add_route("/t1", t1)
|
||||
|
||||
with hiro.Timeline().freeze() as timeline:
|
||||
cli = TestClient(app)
|
||||
for i in range(0, 10):
|
||||
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
|
||||
assert response.status_code == 200 if i < 5 else 429
|
||||
for i in range(5):
|
||||
assert cli.get("/t1").status_code == 200
|
||||
|
||||
assert cli.get("/t1").status_code == 429
|
||||
assert (
|
||||
cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code
|
||||
== 429
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user