Improved logs, added a test

This commit is contained in:
Colin Delahunty
2024-06-27 12:47:12 -04:00
parent c1681754cb
commit 57223b72d7
2 changed files with 61 additions and 3 deletions

View File

@@ -1,6 +1,7 @@
"""
The starlette extension to rate-limit requests
"""
import asyncio
import functools
import inspect
@@ -734,7 +735,8 @@ class Limiter:
if not isinstance(response, Response):
# get the response object from the decorated endpoint function
self._inject_headers(
kwargs.get("response"), request.state.view_rate_limit # type: ignore
kwargs.get("response"),
request.state.view_rate_limit, # type: ignore
)
else:
self._inject_headers(
@@ -766,7 +768,8 @@ class Limiter:
if not isinstance(response, Response):
# get the response object from the decorated endpoint function
self._inject_headers(
kwargs.get("response"), request.state.view_rate_limit # type: ignore
kwargs.get("response"),
request.state.view_rate_limit, # type: ignore
)
else:
self._inject_headers(
@@ -803,7 +806,7 @@ class Limiter:
* **error_message**: string (or callable that returns one) to override the
error message used in the response.
* **exempt_when**: function returning a boolean indicating whether to exempt
the route from the limit
the route from the limit. This function can optionally use a Request object.
* **cost**: integer (or callable that returns one) which is the cost of a hit
* **override_defaults**: whether to override the default limits (default: True)
"""

View File

@@ -43,6 +43,61 @@ class TestDecorators(TestSlowapi):
if i < 5:
assert response.text == "test"
def test_exempt_when_argument(self, build_starlette_app):
app, limiter = build_starlette_app(key_func=get_ipaddr)
def return_true():
return True
def return_false():
return False
def dynamic(request: Request):
user_agent = request.headers.get("User-Agent")
if user_agent is None:
return False
return user_agent == "exempt"
@limiter.limit("1/minute", exempt_when=return_true)
def always_true(request: Request):
return PlainTextResponse("test")
@limiter.limit("1/minute", exempt_when=return_false)
def always_false(request: Request):
return PlainTextResponse("test")
@limiter.limit("1/minute", exempt_when=dynamic)
def always_dynamic(request: Request):
return PlainTextResponse("test")
app.add_route("/true", always_true)
app.add_route("/false", always_false)
app.add_route("/dynamic", always_dynamic)
client = TestClient(app)
# Test always true always exempting
for i in range(0, 2):
response = client.get("/true")
assert response.status_code == 200
assert response.text == "test"
# Test always false hitting the limit after one hit
for i in range(0, 2):
response = client.get("/false")
assert response.status_code == 200 if i < 1 else 429
if i < 1:
assert response.text == "test"
# Test dynamic not exempting with the correct header
for i in range(0, 2):
response = client.get("/dynamic", headers={"User-Agent": "exempt"})
assert response.status_code == 200
assert response.text == "test"
# Test dynamic exempting with the incorrect header
for i in range(0, 2):
response = client.get("/dynamic")
assert response.status_code == 200 if i < 1 else 429
if i < 1:
assert response.text == "test"
def test_shared_decorator(self, build_starlette_app):
app, limiter = build_starlette_app(key_func=get_ipaddr)