diff --git a/slowapi/extension.py b/slowapi/extension.py index 5f3bd02..b9ba27a 100644 --- a/slowapi/extension.py +++ b/slowapi/extension.py @@ -164,15 +164,15 @@ class Limiter: for limit in set(default_limits): self._default_limits.extend( - [LimitGroup(limit, self._key_func, None, False, None, None, None)] + [LimitGroup(limit, self._key_func, None, False, None, None, None, False)] ) for limit in application_limits: self._application_limits.extend( - [LimitGroup(limit, self._key_func, "global", False, None, None, None)] + [LimitGroup(limit, self._key_func, "global", False, None, None, None, False)] ) for limit in in_memory_fallback: self._in_memory_fallback.extend( - [LimitGroup(limit, self._key_func, None, False, None, None, None)] + [LimitGroup(limit, self._key_func, None, False, None, None, None, False)] ) self._route_limits: Dict[str, List[Limit]] = {} self._dynamic_route_limits: Dict[str, List[LimitGroup]] = {} @@ -471,9 +471,12 @@ class Limiter: else [] ) all_limits += route_limits + combined_defaults = all( + not limit.override_defaults for limit in route_limits + ) if not route_limits and not ( in_middleware and name in self.__marked_for_limiting - ): + ) or combined_defaults: all_limits += list(itertools.chain(*self._default_limits)) # actually check the limits, so far we've only computed the list of limits to check self.__evaluate_limits(request, endpoint, all_limits) @@ -503,6 +506,7 @@ class Limiter: methods: Optional[List[str]] = None, error_message: Optional[str] = None, exempt_when: Optional[Callable[..., bool]] = None, + override_defaults: bool = True ) -> Callable[..., Any]: _scope = scope if shared else None @@ -521,6 +525,7 @@ class Limiter: methods, error_message, exempt_when, + override_defaults, ) else: try: @@ -533,6 +538,7 @@ class Limiter: methods, error_message, exempt_when, + override_defaults, ) ) except ValueError as e: @@ -604,6 +610,7 @@ class Limiter: methods: Optional[List[str]] = None, error_message: Optional[str] = None, exempt_when: Optional[Callable[..., bool]] = None, + override_defaults: bool = True ) -> Callable: """ decorator to be used for rate limiting individual routes. @@ -619,6 +626,7 @@ class Limiter: :param error_message: string (or callable that returns one) to override the error message used in the response. :param exempt_when: + :param override_defaults: whether to override the default limits (default: True) :return: """ return self.__limit_decorator( @@ -628,6 +636,7 @@ class Limiter: methods=methods, error_message=error_message, exempt_when=exempt_when, + override_defaults=override_defaults ) def shared_limit( @@ -637,6 +646,7 @@ class Limiter: key_func: Optional[Callable[..., str]] = None, error_message: Optional[str] = None, exempt_when: Optional[Callable[..., bool]] = None, + override_defaults: bool = True, ) -> Callable: """ decorator to be applied to multiple routes sharing the same rate limit. @@ -650,6 +660,7 @@ class Limiter: :param error_message: string (or callable that returns one) to override the error message used in the response. :param exempt_when: + :param override_defaults: whether to override the default limits (default: True) """ return self.__limit_decorator( limit_value, @@ -658,6 +669,7 @@ class Limiter: scope, error_message=error_message, exempt_when=exempt_when, + override_defaults=override_defaults, ) def exempt(self, obj): diff --git a/slowapi/wrappers.py b/slowapi/wrappers.py index 5dda55d..4fe02c4 100644 --- a/slowapi/wrappers.py +++ b/slowapi/wrappers.py @@ -17,6 +17,7 @@ class Limit(object): methods: Optional[List[str]], error_message: Optional[Union[str, Callable[..., str]]], exempt_when: Optional[Callable[..., bool]], + override_defaults: bool ) -> None: self.limit = limit self.key_func = key_func @@ -25,6 +26,7 @@ class Limit(object): self.methods = methods self.error_message = error_message self.exempt_when = exempt_when + self.override_defaults = override_defaults @property def is_exempt(self) -> bool: @@ -62,6 +64,7 @@ class LimitGroup(object): methods: Optional[List[str]], error_message: Optional[Union[str, Callable[..., str]]], exempt_when: Optional[Callable[..., bool]], + override_defaults: bool, ): self.__limit_provider = limit_provider self.__scope = scope @@ -70,6 +73,7 @@ class LimitGroup(object): self.methods = methods and [m.lower() for m in methods] or methods self.error_message = error_message self.exempt_when = exempt_when + self.override_defaults = override_defaults def __iter__(self) -> Iterator[Limit]: limit_items: List[RateLimitItem] = parse_many( @@ -86,4 +90,5 @@ class LimitGroup(object): self.methods, self.error_message, self.exempt_when, + self.override_defaults, ) diff --git a/tests/test_starlette_extension.py b/tests/test_starlette_extension.py index 3146db8..572cf30 100644 --- a/tests/test_starlette_extension.py +++ b/tests/test_starlette_extension.py @@ -70,8 +70,8 @@ class TestDecorators(TestSlowapi): def test_multiple_decorators(self): app, limiter = self.build_starlette_app(key_func=get_ipaddr) - @limiter.limit("100 per minute", lambda: "test") - @limiter.limit("50/minute") # per ip as per default key_func + @limiter.limit("10 per minute", lambda: "test") + @limiter.limit("5/minute") # per ip as per default key_func async def t1(request: Request): return PlainTextResponse("test") @@ -79,16 +79,16 @@ class TestDecorators(TestSlowapi): with hiro.Timeline().freeze() as timeline: cli = TestClient(app) - for i in range(0, 100): + for i in range(0, 10): response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"}) - assert response.status_code == 200 if i < 50 else 429 - for i in range(50): + assert response.status_code == 200 if i < 5 else 429 + for i in range(5): assert cli.get("/t1").status_code == 200 assert cli.get("/t1").status_code == 429 assert ( - cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code - == 429 + cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code + == 429 ) def test_headers_no_breach(self): @@ -196,10 +196,13 @@ class TestDecorators(TestSlowapi): resp2 = cli.get("/t2", headers={"X_FORWARDED_FOR": "127.0.0.10"}) assert resp2.status_code == 200 + # todo: more tests - see https://github.com/alisaifee/flask-limiter/blob/55df08f14143a7e918fc033067a494248ab6b0c5/tests/test_decorators.py#L187 def test_default_and_decorator_limit_merging(self): - app, limiter = self.build_starlette_app(key_func=get_ipaddr, default_limits=["50/minute"]) + # test pool has 100 reqs left + app, limiter = self.build_starlette_app(key_func=lambda: "test", default_limits=["10/minute"]) - @limiter.limit("100 per minute", key_func=lambda: "lest") + # ip pool has 50 reqs for 127.0.0.14 + @limiter.limit("5 per minute", key_func=get_ipaddr, override_defaults=False) async def t1(request: Request): return PlainTextResponse("test") @@ -207,13 +210,15 @@ class TestDecorators(TestSlowapi): with hiro.Timeline().freeze() as timeline: cli = TestClient(app) - for i in range(0, 100): - response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.14"}) - assert response.status_code == 200 if i < 50 else 429 - for i in range(50): + for i in range(0, 10): + response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"}) + assert response.status_code == 200 if i < 5 else 429 + # now ip pool for 127.0.0.14 has 0 reqs left + for i in range(5): assert cli.get("/t1").status_code == 200 + # now test pool has 0 reqs left assert cli.get("/t1").status_code == 429 assert ( - cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.15"}).status_code + cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}).status_code == 429)