TESTS : cover cases of missing request: Request parameter or misuse

This commit is contained in:
Thomas LEVEIL
2020-09-27 20:06:50 +02:00
parent 4c4141a7db
commit a66a728e50
2 changed files with 53 additions and 7 deletions

View File

@@ -562,7 +562,8 @@ class Limiter:
async def async_wrapper(*args: Any, **kwargs: Any) -> Response:
# get the request object from the decorated endpoint function
request = kwargs.get("request", args[idx] if args else None)
assert isinstance(request, Request)
if not isinstance(request, Request):
raise Exception("parameter `request` must be an instance of starlette.requests.Request")
if self._auto_check and not getattr(
request.state, "_rate_limiting_complete", False
@@ -581,7 +582,8 @@ class Limiter:
def sync_wrapper(*args: Any, **kwargs: Any) -> Response:
# get the request object from the decorated endpoint function
request = kwargs.get("request", args[idx] if args else None)
assert isinstance(request, Request)
if not isinstance(request, Request):
raise Exception("parameter `request` must be an instance of starlette.requests.Request")
if self._auto_check and not getattr(
request.state, "_rate_limiting_complete", False

View File

@@ -1,11 +1,9 @@
import hiro
from fastapi import FastAPI
import pytest
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient
from slowapi.extension import Limiter
from slowapi.util import get_ipaddr
from tests import TestSlowapi
@@ -45,6 +43,52 @@ class TestDecorators(TestSlowapi):
assert cli.get("/t1").status_code == 429
assert (
cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code
== 429
cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code
== 429
)
def test_endpoint_missing_request_param(self):
app, limiter = self.build_fastapi_app(key_func=get_ipaddr)
with pytest.raises(Exception) as exc_info:
@app.get("/t3")
@limiter.limit("5/minute")
async def t3():
return PlainTextResponse("test")
assert exc_info.match(r"""^No "request" or "websocket" argument on function .*""")
def test_endpoint_missing_request_param_sync(self):
app, limiter = self.build_fastapi_app(key_func=get_ipaddr)
with pytest.raises(Exception) as exc_info:
@app.get("/t3_sync")
@limiter.limit("5/minute")
def t3():
return PlainTextResponse("test")
assert exc_info.match(r"""^No "request" or "websocket" argument on function .*""")
def test_endpoint_request_param_invalid(self):
app, limiter = self.build_fastapi_app(key_func=get_ipaddr)
@app.get("/t4")
@limiter.limit("5/minute")
async def t4(request: str = None):
return PlainTextResponse("test")
with pytest.raises(Exception) as exc_info:
client = TestClient(app)
client.get("/t4")
assert exc_info.match(r"""parameter `request` must be an instance of starlette.requests.Request""")
def test_endpoint_request_param_invalid_sync(self):
app, limiter = self.build_fastapi_app(key_func=get_ipaddr)
@app.get("/t5")
@limiter.limit("5/minute")
def t5(request: str = None):
return PlainTextResponse("test")
with pytest.raises(Exception) as exc_info:
client = TestClient(app)
client.get("/t5")
assert exc_info.match(r"""parameter `request` must be an instance of starlette.requests.Request""")