From 88befd03015448bca8141984c873de5c58a1ab96 Mon Sep 17 00:00:00 2001 From: Danny Rehl <15615045+colidyre@users.noreply.github.com> Date: Sun, 24 Aug 2025 13:16:56 +0200 Subject: [PATCH] fix: if/else scope in assert statements A statement like `assert x == 1 if False else "foobar"` will be evaluated to assert "foobar", which is always True. This means that checking for a response status code like `assert status_code == 200 if i < 5 else 429` will be evaluated to `assert status_code == 200` for i < 5, but also to `assert 429` for i >= 5 and not `assert status_code == 429`. The fix is to get the right scope for the if/else condition by using parentheses, so that the right response status code will be checked if i >= 5. --- tests/test_fastapi_extension.py | 26 +++++++++++++------------- tests/test_starlette_extension.py | 24 ++++++++++++------------ 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/tests/test_fastapi_extension.py b/tests/test_fastapi_extension.py index 42e6322..aec7c34 100644 --- a/tests/test_fastapi_extension.py +++ b/tests/test_fastapi_extension.py @@ -20,7 +20,7 @@ async def t1(request: Request): client = TestClient(app) for i in range(0, 10): response = client.get("/t1") - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) def test_single_decorator_with_headers(self, build_fastapi_app): app, limiter = build_fastapi_app(key_func=get_ipaddr, headers_enabled=True) @@ -33,7 +33,7 @@ async def t1(request: Request): client = TestClient(app) for i in range(0, 10): response = client.get("/t1") - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) assert ( response.headers.get("X-RateLimit-Limit") is not None if i < 5 else True ) @@ -50,7 +50,7 @@ async def t1(request: Request, response: Response): client = TestClient(app) for i in range(0, 10): response = client.get("/t1") - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) def test_single_decorator_not_response_with_headers(self, build_fastapi_app): app, limiter = build_fastapi_app(key_func=get_ipaddr, headers_enabled=True) @@ -63,7 +63,7 @@ async def t1(request: Request, response: Response): client = TestClient(app) for i in range(0, 10): response = client.get("/t1") - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) assert ( response.headers.get("X-RateLimit-Limit") is not None if i < 5 else True ) @@ -84,7 +84,7 @@ async def t1(request: Request): cli = TestClient(app) for i in range(0, 100): response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"}) - assert response.status_code == 200 if i < 50 else 429 + assert response.status_code == (200 if i < 50 else 429) for i in range(50): assert cli.get("/t1").status_code == 200 @@ -109,7 +109,7 @@ async def t1(request: Request, response: Response): cli = TestClient(app) for i in range(0, 100): response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"}) - assert response.status_code == 200 if i < 50 else 429 + assert response.status_code == (200 if i < 50 else 429) for i in range(50): assert cli.get("/t1").status_code == 200 @@ -134,7 +134,7 @@ async def t1(request: Request, response: Response): cli = TestClient(app) for i in range(0, 100): response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"}) - assert response.status_code == 200 if i < 50 else 429 + assert response.status_code == (200 if i < 50 else 429) for i in range(50): assert cli.get("/t1").status_code == 200 @@ -253,11 +253,11 @@ async def t1(request: Request, response: Response): client = TestClient(app) for i in range(0, 10): response = client.get("/t1") - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) for i in range(0, 20): response = client.get("/t1", headers={"TOKEN": "secret"}) - assert response.status_code == 200 if i < 10 else 429 + assert response.status_code == (200 if i < 10 else 429) def test_disabled_limiter(self, build_fastapi_app): """ @@ -308,10 +308,10 @@ async def t2(request: Request): client = TestClient(app) for i in range(0, 10): response = client.get("/t1") - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) response = client.get("/t2") - assert response.status_code == 200 if i < 3 else 429 + assert response.status_code == (200 if i < 3 else 429) def test_callable_cost(self, build_fastapi_app): app, limiter = build_fastapi_app(key_func=get_ipaddr) @@ -331,10 +331,10 @@ async def t2(request: Request): client = TestClient(app) for i in range(0, 10): response = client.get("/t1", headers={"foo": "10"}) - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) response = client.get("/t2", headers={"foo": "5"}) - assert response.status_code == 200 if i < 6 else 429 + assert response.status_code == (200 if i < 6 else 429) @pytest.mark.parametrize( "key_style", diff --git a/tests/test_starlette_extension.py b/tests/test_starlette_extension.py index 0e26baa..3c40154 100644 --- a/tests/test_starlette_extension.py +++ b/tests/test_starlette_extension.py @@ -23,7 +23,7 @@ async def t1(request: Request): client = TestClient(app) for i in range(0, 10): response = client.get("/t1") - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) if i < 5: assert response.text == "test" @@ -39,7 +39,7 @@ def t1(request: Request): client = TestClient(app) for i in range(0, 10): response = client.get("/t1") - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) if i < 5: assert response.text == "test" @@ -83,7 +83,7 @@ def always_dynamic(request: Request): # Test always false hitting the limit after one hit for i in range(0, 2): response = client.get("/false") - assert response.status_code == 200 if i < 1 else 429 + assert response.status_code == (200 if i < 1 else 429) if i < 1: assert response.text == "test" # Test dynamic not exempting with the correct header @@ -94,7 +94,7 @@ def always_dynamic(request: Request): # Test dynamic exempting with the incorrect header for i in range(0, 2): response = client.get("/dynamic") - assert response.status_code == 200 if i < 1 else 429 + assert response.status_code == (200 if i < 1 else 429) if i < 1: assert response.text == "test" @@ -117,7 +117,7 @@ def t2(request: Request): client = TestClient(app) for i in range(0, 10): response = client.get("/t1") - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) # the shared limit has already been hit via t1 assert client.get("/t2").status_code == 429 @@ -135,7 +135,7 @@ async def t1(request: Request): cli = TestClient(app) for i in range(0, 10): response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"}) - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) for i in range(5): assert cli.get("/t1").status_code == 200 @@ -159,7 +159,7 @@ async def t1(request: Request): cli = TestClient(app) for i in range(0, 10): response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"}) - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) assert response.headers.get("Retry-After") if i < 5 else True for i in range(5): assert cli.get("/t1").status_code == 200 @@ -304,7 +304,7 @@ async def t1(request: Request): cli = TestClient(app) for i in range(0, 10): response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"}) - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) for i in range(5): assert cli.get("/t1").status_code == 200 @@ -332,14 +332,14 @@ async def t2(request: Request): client = TestClient(app) for i in range(0, 10): response = client.get("/t1") - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) if i < 5: assert response.text == "test" else: assert "error" in response.json() response = client.get("/t2") - assert response.status_code == 200 if i < 3 else 429 + assert response.status_code == (200 if i < 3 else 429) if i < 3: assert response.text == "test" else: @@ -365,14 +365,14 @@ async def t2(request: Request): client = TestClient(app) for i in range(0, 10): response = client.get("/t1", headers={"foo": "10"}) - assert response.status_code == 200 if i < 5 else 429 + assert response.status_code == (200 if i < 5 else 429) if i < 5: assert response.text == "test" else: assert "error" in response.json() response = client.get("/t2", headers={"foo": "5"}) - assert response.status_code == 200 if i < 6 else 429 + assert response.status_code == (200 if i < 6 else 429) if i < 6: assert response.text == "test" else: