add middleware and exempt decorator

This commit is contained in:
Reuben Thomas-Davis
2020-07-04 03:16:46 +01:00
parent 4c4141a7db
commit 6580bb0624
4 changed files with 95 additions and 4 deletions

View File

@@ -410,7 +410,7 @@ class Limiter:
if failed_limit:
raise RateLimitExceeded(failed_limit)
def __check_request_limit(
def _check_request_limit(
self,
request: Request,
endpoint_func: Callable[..., Any],
@@ -486,7 +486,7 @@ class Limiter:
" in-memory storage"
)
self._storage_dead = True
self.__check_request_limit(request, endpoint_func, in_middleware)
self._check_request_limit(request, endpoint_func, in_middleware)
else:
if self._swallow_errors:
self.logger.exception("Failed to rate limit. Swallowing error")
@@ -567,7 +567,7 @@ class Limiter:
if self._auto_check and not getattr(
request.state, "_rate_limiting_complete", False
):
self.__check_request_limit(request, func, False)
self._check_request_limit(request, func, False)
request.state._rate_limiting_complete = True
response = await func(*args, **kwargs)
self._inject_headers(response, request.state.view_rate_limit)
@@ -586,7 +586,7 @@ class Limiter:
if self._auto_check and not getattr(
request.state, "_rate_limiting_complete", False
):
self.__check_request_limit(request, func, False)
self._check_request_limit(request, func, False)
request.state._rate_limiting_complete = True
response = func(*args, **kwargs)
self._inject_headers(response, request.state.view_rate_limit)
@@ -659,3 +659,17 @@ class Limiter:
error_message=error_message,
exempt_when=exempt_when,
)
def exempt(self, obj):
"""
decorator to mark a view as exempt from
rate limits.
"""
name = "%s.%s" % (obj.__module__, obj.__name__)
@wraps(obj)
def __inner(*a, **k):
return obj(*a, **k)
self._exempt_routes.add(name)
return __inner

46
slowapi/middleware.py Normal file
View File

@@ -0,0 +1,46 @@
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint, DispatchFunction
from starlette.requests import Request
from starlette.responses import Response
from slowapi import Limiter, _rate_limit_exceeded_handler
class SlowAPIMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
app: Starlette = request.app
limiter: Limiter = app.state.limiter
handler = None
for route in app.routes:
match, _ = route.matches(request.scope)
if match.FULL:
handler = route.endpoint
# if we can't find the route handler
if handler is None:
return await call_next(request)
name = "%s.%s" % (handler.__module__, handler.__name__)
# if exempt no need to check
if name in limiter._exempt_routes:
return await call_next(request)
# there is a decorator for this route we let the decorator handle it
if name in limiter._route_limits:
return await call_next(request)
# let the decorator handle if already in
if limiter._auto_check and not getattr(
request.state, "_rate_limiting_complete", False
):
try:
limiter._check_request_limit(request, handler, True)
except Exception as e:
# handle the exception since the global exception handler won't pick it up if we call_next
exception_handler = app.exception_handlers.get(type(e), _rate_limit_exceeded_handler)
return exception_handler(request, e)
# request.state._rate_limiting_complete = True
response = await call_next(request)
response = limiter._inject_headers(response, request.state.view_rate_limit)
return response

View File

@@ -11,6 +11,7 @@ from starlette.applications import Starlette
from slowapi.errors import RateLimitExceeded
from slowapi.extension import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.middleware import SlowAPIMiddleware
class TestSlowapi:
@@ -20,6 +21,7 @@ class TestSlowapi:
app = Starlette(debug=True)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)
mock_handler = mock.Mock()
mock_handler.level = logging.INFO
@@ -30,6 +32,9 @@ class TestSlowapi:
limiter_args.setdefault("key_func", get_remote_address)
limiter = Limiter(**limiter_args)
app = FastAPI()
app.state.limiter = limiter
app.add_middleware(SlowAPIMiddleware)
mock_handler = mock.Mock()
mock_handler.level = logging.INFO
limiter.logger.addHandler(mock_handler)

View File

@@ -168,3 +168,29 @@ class TestDecorators(TestSlowapi):
timeline.forward(retry_after)
resp = cli.get("/t1")
assert resp.status_code == 200
def test_exempt_decorator(self):
app, limiter = self.build_starlette_app(
headers_enabled=True, key_func=get_remote_address, default_limits=["1/minute"]
)
@app.route("/t1")
def t(request: Request):
return PlainTextResponse("test")
with TestClient(app) as cli:
resp = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.10"})
assert resp.status_code == 200
resp2 = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.10"})
assert resp2.status_code == 429
@app.route("/t2")
@limiter.exempt
def t(request: Request):
return PlainTextResponse("test")
with TestClient(app) as cli:
resp = cli.get("/t2", headers={"X_FORWARDED_FOR": "127.0.0.11"})
assert resp.status_code == 200
resp2 = cli.get("/t2", headers={"X_FORWARDED_FOR": "127.0.0.11"})
assert resp2.status_code == 200