From a66a728e50e7566ef031c5087ea6bbb444fece96 Mon Sep 17 00:00:00 2001 From: Thomas LEVEIL Date: Sun, 27 Sep 2020 20:06:50 +0200 Subject: [PATCH] TESTS : cover cases of missing `request: Request` parameter or misuse --- slowapi/extension.py | 6 ++-- tests/test_fastapi_extension.py | 54 ++++++++++++++++++++++++++++++--- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/slowapi/extension.py b/slowapi/extension.py index f76efe5..29e5eed 100644 --- a/slowapi/extension.py +++ b/slowapi/extension.py @@ -562,7 +562,8 @@ class Limiter: async def async_wrapper(*args: Any, **kwargs: Any) -> Response: # get the request object from the decorated endpoint function request = kwargs.get("request", args[idx] if args else None) - assert isinstance(request, Request) + if not isinstance(request, Request): + raise Exception("parameter `request` must be an instance of starlette.requests.Request") if self._auto_check and not getattr( request.state, "_rate_limiting_complete", False @@ -581,7 +582,8 @@ class Limiter: def sync_wrapper(*args: Any, **kwargs: Any) -> Response: # get the request object from the decorated endpoint function request = kwargs.get("request", args[idx] if args else None) - assert isinstance(request, Request) + if not isinstance(request, Request): + raise Exception("parameter `request` must be an instance of starlette.requests.Request") if self._auto_check and not getattr( request.state, "_rate_limiting_complete", False diff --git a/tests/test_fastapi_extension.py b/tests/test_fastapi_extension.py index d943f61..436e871 100644 --- a/tests/test_fastapi_extension.py +++ b/tests/test_fastapi_extension.py @@ -1,11 +1,9 @@ import hiro -from fastapi import FastAPI +import pytest from starlette.requests import Request from starlette.responses import PlainTextResponse -from starlette.routing import Route from starlette.testclient import TestClient -from slowapi.extension import Limiter from slowapi.util import get_ipaddr from tests import TestSlowapi @@ -45,6 +43,52 @@ class TestDecorators(TestSlowapi): assert cli.get("/t1").status_code == 429 assert ( - cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code - == 429 + cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code + == 429 ) + + def test_endpoint_missing_request_param(self): + app, limiter = self.build_fastapi_app(key_func=get_ipaddr) + + with pytest.raises(Exception) as exc_info: + @app.get("/t3") + @limiter.limit("5/minute") + async def t3(): + return PlainTextResponse("test") + assert exc_info.match(r"""^No "request" or "websocket" argument on function .*""") + + def test_endpoint_missing_request_param_sync(self): + app, limiter = self.build_fastapi_app(key_func=get_ipaddr) + + with pytest.raises(Exception) as exc_info: + @app.get("/t3_sync") + @limiter.limit("5/minute") + def t3(): + return PlainTextResponse("test") + assert exc_info.match(r"""^No "request" or "websocket" argument on function .*""") + + def test_endpoint_request_param_invalid(self): + app, limiter = self.build_fastapi_app(key_func=get_ipaddr) + + @app.get("/t4") + @limiter.limit("5/minute") + async def t4(request: str = None): + return PlainTextResponse("test") + + with pytest.raises(Exception) as exc_info: + client = TestClient(app) + client.get("/t4") + assert exc_info.match(r"""parameter `request` must be an instance of starlette.requests.Request""") + + def test_endpoint_request_param_invalid_sync(self): + app, limiter = self.build_fastapi_app(key_func=get_ipaddr) + + @app.get("/t5") + @limiter.limit("5/minute") + def t5(request: str = None): + return PlainTextResponse("test") + + with pytest.raises(Exception) as exc_info: + client = TestClient(app) + client.get("/t5") + assert exc_info.match(r"""parameter `request` must be an instance of starlette.requests.Request""")