mirror of
https://github.com/laurentS/slowapi.git
synced 2026-03-13 09:10:20 +08:00
Merge pull request #160 from colin99d/master
feat: allow Requests to be sent to exempt_when
This commit is contained in:
@@ -1,5 +1,11 @@
|
||||
# Change Log
|
||||
|
||||
## [0.1.10] - 2024-06-04
|
||||
|
||||
### Changed
|
||||
|
||||
- Breaking change: allow usage of the request object in the except_when function (thanks @colin99d)
|
||||
|
||||
## [0.1.9] - 2024-02-05
|
||||
|
||||
### Added
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
The starlette extension to rate-limit requests
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
@@ -486,7 +487,7 @@ class Limiter:
|
||||
limit_for_header = None
|
||||
for lim in limits:
|
||||
limit_scope = lim.scope or endpoint
|
||||
if lim.is_exempt:
|
||||
if lim.is_exempt(request):
|
||||
continue
|
||||
if lim.methods is not None and request.method.lower() not in lim.methods:
|
||||
continue
|
||||
@@ -703,11 +704,9 @@ class Limiter:
|
||||
else:
|
||||
self._route_limits.setdefault(name, []).extend(static_limits)
|
||||
|
||||
connection_type: Optional[str] = None
|
||||
sig = inspect.signature(func)
|
||||
for idx, parameter in enumerate(sig.parameters.values()):
|
||||
if parameter.name == "request" or parameter.name == "websocket":
|
||||
connection_type = parameter.name
|
||||
break
|
||||
else:
|
||||
raise Exception(
|
||||
@@ -736,7 +735,8 @@ class Limiter:
|
||||
if not isinstance(response, Response):
|
||||
# get the response object from the decorated endpoint function
|
||||
self._inject_headers(
|
||||
kwargs.get("response"), request.state.view_rate_limit # type: ignore
|
||||
kwargs.get("response"), # type: ignore
|
||||
request.state.view_rate_limit,
|
||||
)
|
||||
else:
|
||||
self._inject_headers(
|
||||
@@ -768,7 +768,8 @@ class Limiter:
|
||||
if not isinstance(response, Response):
|
||||
# get the response object from the decorated endpoint function
|
||||
self._inject_headers(
|
||||
kwargs.get("response"), request.state.view_rate_limit # type: ignore
|
||||
kwargs.get("response"),
|
||||
request.state.view_rate_limit, # type: ignore
|
||||
)
|
||||
else:
|
||||
self._inject_headers(
|
||||
@@ -805,7 +806,7 @@ class Limiter:
|
||||
* **error_message**: string (or callable that returns one) to override the
|
||||
error message used in the response.
|
||||
* **exempt_when**: function returning a boolean indicating whether to exempt
|
||||
the route from the limit
|
||||
the route from the limit. This function can optionally use a Request object.
|
||||
* **cost**: integer (or callable that returns one) which is the cost of a hit
|
||||
* **override_defaults**: whether to override the default limits (default: True)
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@ import inspect
|
||||
from typing import Callable, Iterator, List, Optional, Union
|
||||
|
||||
from limits import RateLimitItem, parse_many # type: ignore
|
||||
from starlette.requests import Request
|
||||
|
||||
|
||||
class Limit(object):
|
||||
@@ -28,16 +29,27 @@ class Limit(object):
|
||||
self.methods = methods
|
||||
self.error_message = error_message
|
||||
self.exempt_when = exempt_when
|
||||
self._exempt_when_takes_request = (
|
||||
self.exempt_when
|
||||
and len(inspect.signature(self.exempt_when).parameters) == 1
|
||||
)
|
||||
self.cost = cost
|
||||
self.override_defaults = override_defaults
|
||||
|
||||
@property
|
||||
def is_exempt(self) -> bool:
|
||||
def is_exempt(self, request: Optional[Request] = None) -> bool:
|
||||
"""
|
||||
Check if the limit is exempt.
|
||||
|
||||
** parameter **
|
||||
* **request**: the request object
|
||||
|
||||
Return True to exempt the route from the limit.
|
||||
"""
|
||||
return self.exempt_when() if self.exempt_when is not None else False
|
||||
if self.exempt_when is None:
|
||||
return False
|
||||
if self._exempt_when_takes_request and request:
|
||||
return self.exempt_when(request)
|
||||
return self.exempt_when()
|
||||
|
||||
@property
|
||||
def scope(self) -> str:
|
||||
|
||||
@@ -43,6 +43,61 @@ class TestDecorators(TestSlowapi):
|
||||
if i < 5:
|
||||
assert response.text == "test"
|
||||
|
||||
def test_exempt_when_argument(self, build_starlette_app):
|
||||
app, limiter = build_starlette_app(key_func=get_ipaddr)
|
||||
|
||||
def return_true():
|
||||
return True
|
||||
|
||||
def return_false():
|
||||
return False
|
||||
|
||||
def dynamic(request: Request):
|
||||
user_agent = request.headers.get("User-Agent")
|
||||
if user_agent is None:
|
||||
return False
|
||||
return user_agent == "exempt"
|
||||
|
||||
@limiter.limit("1/minute", exempt_when=return_true)
|
||||
def always_true(request: Request):
|
||||
return PlainTextResponse("test")
|
||||
|
||||
@limiter.limit("1/minute", exempt_when=return_false)
|
||||
def always_false(request: Request):
|
||||
return PlainTextResponse("test")
|
||||
|
||||
@limiter.limit("1/minute", exempt_when=dynamic)
|
||||
def always_dynamic(request: Request):
|
||||
return PlainTextResponse("test")
|
||||
|
||||
app.add_route("/true", always_true)
|
||||
app.add_route("/false", always_false)
|
||||
app.add_route("/dynamic", always_dynamic)
|
||||
|
||||
client = TestClient(app)
|
||||
# Test always true always exempting
|
||||
for i in range(0, 2):
|
||||
response = client.get("/true")
|
||||
assert response.status_code == 200
|
||||
assert response.text == "test"
|
||||
# Test always false hitting the limit after one hit
|
||||
for i in range(0, 2):
|
||||
response = client.get("/false")
|
||||
assert response.status_code == 200 if i < 1 else 429
|
||||
if i < 1:
|
||||
assert response.text == "test"
|
||||
# Test dynamic not exempting with the correct header
|
||||
for i in range(0, 2):
|
||||
response = client.get("/dynamic", headers={"User-Agent": "exempt"})
|
||||
assert response.status_code == 200
|
||||
assert response.text == "test"
|
||||
# Test dynamic exempting with the incorrect header
|
||||
for i in range(0, 2):
|
||||
response = client.get("/dynamic")
|
||||
assert response.status_code == 200 if i < 1 else 429
|
||||
if i < 1:
|
||||
assert response.text == "test"
|
||||
|
||||
def test_shared_decorator(self, build_starlette_app):
|
||||
app, limiter = build_starlette_app(key_func=get_ipaddr)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user