mirror of
https://github.com/laurentS/slowapi.git
synced 2026-02-04 11:54:29 +08:00
add middleware and exempt decorator
This commit is contained in:
@@ -410,7 +410,7 @@ class Limiter:
|
||||
if failed_limit:
|
||||
raise RateLimitExceeded(failed_limit)
|
||||
|
||||
def __check_request_limit(
|
||||
def _check_request_limit(
|
||||
self,
|
||||
request: Request,
|
||||
endpoint_func: Callable[..., Any],
|
||||
@@ -486,7 +486,7 @@ class Limiter:
|
||||
" in-memory storage"
|
||||
)
|
||||
self._storage_dead = True
|
||||
self.__check_request_limit(request, endpoint_func, in_middleware)
|
||||
self._check_request_limit(request, endpoint_func, in_middleware)
|
||||
else:
|
||||
if self._swallow_errors:
|
||||
self.logger.exception("Failed to rate limit. Swallowing error")
|
||||
@@ -567,7 +567,7 @@ class Limiter:
|
||||
if self._auto_check and not getattr(
|
||||
request.state, "_rate_limiting_complete", False
|
||||
):
|
||||
self.__check_request_limit(request, func, False)
|
||||
self._check_request_limit(request, func, False)
|
||||
request.state._rate_limiting_complete = True
|
||||
response = await func(*args, **kwargs)
|
||||
self._inject_headers(response, request.state.view_rate_limit)
|
||||
@@ -586,7 +586,7 @@ class Limiter:
|
||||
if self._auto_check and not getattr(
|
||||
request.state, "_rate_limiting_complete", False
|
||||
):
|
||||
self.__check_request_limit(request, func, False)
|
||||
self._check_request_limit(request, func, False)
|
||||
request.state._rate_limiting_complete = True
|
||||
response = func(*args, **kwargs)
|
||||
self._inject_headers(response, request.state.view_rate_limit)
|
||||
@@ -659,3 +659,17 @@ class Limiter:
|
||||
error_message=error_message,
|
||||
exempt_when=exempt_when,
|
||||
)
|
||||
|
||||
def exempt(self, obj):
|
||||
"""
|
||||
decorator to mark a view as exempt from
|
||||
rate limits.
|
||||
"""
|
||||
name = "%s.%s" % (obj.__module__, obj.__name__)
|
||||
|
||||
@wraps(obj)
|
||||
def __inner(*a, **k):
|
||||
return obj(*a, **k)
|
||||
|
||||
self._exempt_routes.add(name)
|
||||
return __inner
|
||||
46
slowapi/middleware.py
Normal file
46
slowapi/middleware.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint, DispatchFunction
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
|
||||
|
||||
class SlowAPIMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
app: Starlette = request.app
|
||||
limiter: Limiter = app.state.limiter
|
||||
handler = None
|
||||
for route in app.routes:
|
||||
match, _ = route.matches(request.scope)
|
||||
if match.FULL:
|
||||
handler = route.endpoint
|
||||
# if we can't find the route handler
|
||||
if handler is None:
|
||||
return await call_next(request)
|
||||
|
||||
name = "%s.%s" % (handler.__module__, handler.__name__)
|
||||
# if exempt no need to check
|
||||
if name in limiter._exempt_routes:
|
||||
return await call_next(request)
|
||||
|
||||
# there is a decorator for this route we let the decorator handle it
|
||||
if name in limiter._route_limits:
|
||||
return await call_next(request)
|
||||
|
||||
# let the decorator handle if already in
|
||||
if limiter._auto_check and not getattr(
|
||||
request.state, "_rate_limiting_complete", False
|
||||
):
|
||||
try:
|
||||
limiter._check_request_limit(request, handler, True)
|
||||
except Exception as e:
|
||||
# handle the exception since the global exception handler won't pick it up if we call_next
|
||||
exception_handler = app.exception_handlers.get(type(e), _rate_limit_exceeded_handler)
|
||||
return exception_handler(request, e)
|
||||
# request.state._rate_limiting_complete = True
|
||||
response = await call_next(request)
|
||||
response = limiter._inject_headers(response, request.state.view_rate_limit)
|
||||
return response
|
||||
@@ -11,6 +11,7 @@ from starlette.applications import Starlette
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.extension import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.middleware import SlowAPIMiddleware
|
||||
|
||||
|
||||
class TestSlowapi:
|
||||
@@ -20,6 +21,7 @@ class TestSlowapi:
|
||||
app = Starlette(debug=True)
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
app.add_middleware(SlowAPIMiddleware)
|
||||
|
||||
mock_handler = mock.Mock()
|
||||
mock_handler.level = logging.INFO
|
||||
@@ -30,6 +32,9 @@ class TestSlowapi:
|
||||
limiter_args.setdefault("key_func", get_remote_address)
|
||||
limiter = Limiter(**limiter_args)
|
||||
app = FastAPI()
|
||||
app.state.limiter = limiter
|
||||
app.add_middleware(SlowAPIMiddleware)
|
||||
|
||||
mock_handler = mock.Mock()
|
||||
mock_handler.level = logging.INFO
|
||||
limiter.logger.addHandler(mock_handler)
|
||||
|
||||
@@ -168,3 +168,29 @@ class TestDecorators(TestSlowapi):
|
||||
timeline.forward(retry_after)
|
||||
resp = cli.get("/t1")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_exempt_decorator(self):
|
||||
app, limiter = self.build_starlette_app(
|
||||
headers_enabled=True, key_func=get_remote_address, default_limits=["1/minute"]
|
||||
|
||||
)
|
||||
@app.route("/t1")
|
||||
def t(request: Request):
|
||||
return PlainTextResponse("test")
|
||||
|
||||
with TestClient(app) as cli:
|
||||
resp = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.10"})
|
||||
assert resp.status_code == 200
|
||||
resp2 = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.10"})
|
||||
assert resp2.status_code == 429
|
||||
|
||||
@app.route("/t2")
|
||||
@limiter.exempt
|
||||
def t(request: Request):
|
||||
return PlainTextResponse("test")
|
||||
|
||||
with TestClient(app) as cli:
|
||||
resp = cli.get("/t2", headers={"X_FORWARDED_FOR": "127.0.0.11"})
|
||||
assert resp.status_code == 200
|
||||
resp2 = cli.get("/t2", headers={"X_FORWARDED_FOR": "127.0.0.11"})
|
||||
assert resp2.status_code == 200
|
||||
Reference in New Issue
Block a user