From 6580bb06243e7911d6d7e855b109f99bf8ec7d19 Mon Sep 17 00:00:00 2001 From: Reuben Thomas-Davis Date: Sat, 4 Jul 2020 03:16:46 +0100 Subject: [PATCH] add middleware and exempt decorator --- slowapi/extension.py | 22 ++++++++++++--- slowapi/middleware.py | 46 +++++++++++++++++++++++++++++++ tests/__init__.py | 5 ++++ tests/test_starlette_extension.py | 26 +++++++++++++++++ 4 files changed, 95 insertions(+), 4 deletions(-) create mode 100644 slowapi/middleware.py diff --git a/slowapi/extension.py b/slowapi/extension.py index f76efe5..5f3bd02 100644 --- a/slowapi/extension.py +++ b/slowapi/extension.py @@ -410,7 +410,7 @@ class Limiter: if failed_limit: raise RateLimitExceeded(failed_limit) - def __check_request_limit( + def _check_request_limit( self, request: Request, endpoint_func: Callable[..., Any], @@ -486,7 +486,7 @@ class Limiter: " in-memory storage" ) self._storage_dead = True - self.__check_request_limit(request, endpoint_func, in_middleware) + self._check_request_limit(request, endpoint_func, in_middleware) else: if self._swallow_errors: self.logger.exception("Failed to rate limit. Swallowing error") @@ -567,7 +567,7 @@ class Limiter: if self._auto_check and not getattr( request.state, "_rate_limiting_complete", False ): - self.__check_request_limit(request, func, False) + self._check_request_limit(request, func, False) request.state._rate_limiting_complete = True response = await func(*args, **kwargs) self._inject_headers(response, request.state.view_rate_limit) @@ -586,7 +586,7 @@ class Limiter: if self._auto_check and not getattr( request.state, "_rate_limiting_complete", False ): - self.__check_request_limit(request, func, False) + self._check_request_limit(request, func, False) request.state._rate_limiting_complete = True response = func(*args, **kwargs) self._inject_headers(response, request.state.view_rate_limit) @@ -659,3 +659,17 @@ class Limiter: error_message=error_message, exempt_when=exempt_when, ) + + def exempt(self, obj): + """ + decorator to mark a view as exempt from + rate limits. + """ + name = "%s.%s" % (obj.__module__, obj.__name__) + + @wraps(obj) + def __inner(*a, **k): + return obj(*a, **k) + + self._exempt_routes.add(name) + return __inner \ No newline at end of file diff --git a/slowapi/middleware.py b/slowapi/middleware.py new file mode 100644 index 0000000..48d85f2 --- /dev/null +++ b/slowapi/middleware.py @@ -0,0 +1,46 @@ +from starlette.applications import Starlette +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint, DispatchFunction +from starlette.requests import Request +from starlette.responses import Response + +from slowapi import Limiter, _rate_limit_exceeded_handler + + +class SlowAPIMiddleware(BaseHTTPMiddleware): + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + app: Starlette = request.app + limiter: Limiter = app.state.limiter + handler = None + for route in app.routes: + match, _ = route.matches(request.scope) + if match.FULL: + handler = route.endpoint + # if we can't find the route handler + if handler is None: + return await call_next(request) + + name = "%s.%s" % (handler.__module__, handler.__name__) + # if exempt no need to check + if name in limiter._exempt_routes: + return await call_next(request) + + # there is a decorator for this route we let the decorator handle it + if name in limiter._route_limits: + return await call_next(request) + + # let the decorator handle if already in + if limiter._auto_check and not getattr( + request.state, "_rate_limiting_complete", False + ): + try: + limiter._check_request_limit(request, handler, True) + except Exception as e: + # handle the exception since the global exception handler won't pick it up if we call_next + exception_handler = app.exception_handlers.get(type(e), _rate_limit_exceeded_handler) + return exception_handler(request, e) + # request.state._rate_limiting_complete = True + response = await call_next(request) + response = limiter._inject_headers(response, request.state.view_rate_limit) + return response diff --git a/tests/__init__.py b/tests/__init__.py index 54d71ff..de43d99 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -11,6 +11,7 @@ from starlette.applications import Starlette from slowapi.errors import RateLimitExceeded from slowapi.extension import Limiter, _rate_limit_exceeded_handler from slowapi.util import get_remote_address +from slowapi.middleware import SlowAPIMiddleware class TestSlowapi: @@ -20,6 +21,7 @@ class TestSlowapi: app = Starlette(debug=True) app.state.limiter = limiter app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + app.add_middleware(SlowAPIMiddleware) mock_handler = mock.Mock() mock_handler.level = logging.INFO @@ -30,6 +32,9 @@ class TestSlowapi: limiter_args.setdefault("key_func", get_remote_address) limiter = Limiter(**limiter_args) app = FastAPI() + app.state.limiter = limiter + app.add_middleware(SlowAPIMiddleware) + mock_handler = mock.Mock() mock_handler.level = logging.INFO limiter.logger.addHandler(mock_handler) diff --git a/tests/test_starlette_extension.py b/tests/test_starlette_extension.py index 14db222..aab87f0 100644 --- a/tests/test_starlette_extension.py +++ b/tests/test_starlette_extension.py @@ -168,3 +168,29 @@ class TestDecorators(TestSlowapi): timeline.forward(retry_after) resp = cli.get("/t1") assert resp.status_code == 200 + + def test_exempt_decorator(self): + app, limiter = self.build_starlette_app( + headers_enabled=True, key_func=get_remote_address, default_limits=["1/minute"] + + ) + @app.route("/t1") + def t(request: Request): + return PlainTextResponse("test") + + with TestClient(app) as cli: + resp = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.10"}) + assert resp.status_code == 200 + resp2 = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.10"}) + assert resp2.status_code == 429 + + @app.route("/t2") + @limiter.exempt + def t(request: Request): + return PlainTextResponse("test") + + with TestClient(app) as cli: + resp = cli.get("/t2", headers={"X_FORWARDED_FOR": "127.0.0.11"}) + assert resp.status_code == 200 + resp2 = cli.get("/t2", headers={"X_FORWARDED_FOR": "127.0.0.11"}) + assert resp2.status_code == 200 \ No newline at end of file