From b0a33823c27f9833b0332b5c1f1ed6a7d6d53a99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Gryta?= Date: Sun, 28 Jan 2024 22:45:35 +0100 Subject: [PATCH 1/3] Make request parameter unnecessary --- slowapi/extension.py | 2 ++ slowapi/util.py | 66 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/slowapi/extension.py b/slowapi/extension.py index ac14a06..25b999a 100644 --- a/slowapi/extension.py +++ b/slowapi/extension.py @@ -33,6 +33,7 @@ from typing_extensions import Literal from .errors import RateLimitExceeded +from .util import add_request_signature from .wrappers import Limit, LimitGroup # used to annotate get_app_config method @@ -656,6 +657,7 @@ def __limit_decorator( _scope = scope if shared else None def decorator(func: Callable[..., Response]): + func = add_request_signature(func) keyfunc = key_func or self._key_func name = f"{func.__module__}.{func.__name__}" dynamic_limit = None diff --git a/slowapi/util.py b/slowapi/util.py index c44faa9..dfd92be 100644 --- a/slowapi/util.py +++ b/slowapi/util.py @@ -1,4 +1,9 @@ +from asyncio import iscoroutinefunction +from functools import wraps +from inspect import signature, Parameter + from starlette.requests import Request +from typing import Callable, List def get_ipaddr(request: Request) -> str: @@ -25,3 +30,64 @@ def get_remote_address(request: Request) -> str: return "127.0.0.1" return request.client.host + + +def get_request_param(func: Callable) -> List[Parameter]: + """Retrieve list of parameters that are a Request""" + sig = signature(func) + params = list(sig.parameters.values()) + return [param for param in params if param.annotation == Request] + + +def add_request_signature(func: Callable): + """Adds starlette.Request argument to function's signature so that it'll be accessible to custom decorators""" + + def scrap_req(func: Callable, args, kwargs): + if getattr(func, "scrap_req", False): + req_param = get_request_param(func)[0] + try: + del kwargs[req_param.name] + except KeyError: + # Request is not in kwargs for some reason delete from args + # Deletion index: 0 + del args[0] + return args, kwargs + + if iscoroutinefunction(func): + + @wraps(func) + async def wrapper(*args, **kwargs): + args, kwargs = scrap_req(func, args, kwargs) + return await func(*args, **kwargs) + + else: + + @wraps(func) + def wrapper(*args, **kwargs): + args, kwargs = scrap_req(func, args, kwargs) + return func(*args, **kwargs) + + sig = signature(func) + params = list(sig.parameters.values()) + + rq = get_request_param(func) + if len(rq) == 1: + if not hasattr(func, "scrap_req"): # Ignore if already set + func.scrap_req = False + else: + func.scrap_req = True + name = "request" # Slowapi should allow for request to be anything <- param name generator + param_names = [pname.name for pname in params] + if name not in param_names: + func.req = name + + req = Parameter(name=name, kind=Parameter.POSITIONAL_OR_KEYWORD, annotation=Request) + params.insert(0, req) + sig = sig.replace(parameters=params) + func.__signature__ = sig + else: + fname = f"{func.__module__}.{func.__name__}" + raise Exception(f"Remove 'request' argument from function {fname}" + f" or add [request : starlette.Request] manually.") + + return wrapper \ No newline at end of file From c5aa2d554b9ddeacca1a820cb9fc74be3a830d62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Gryta?= Date: Sun, 28 Jan 2024 23:10:55 +0100 Subject: [PATCH 2/3] Black reformatted, removed and fixed unit tests --- slowapi/util.py | 12 +++++--- tests/test_fastapi_extension.py | 52 +++++++-------------------------- 2 files changed, 19 insertions(+), 45 deletions(-) diff --git a/slowapi/util.py b/slowapi/util.py index dfd92be..28b220e 100644 --- a/slowapi/util.py +++ b/slowapi/util.py @@ -81,13 +81,17 @@ def wrapper(*args, **kwargs): if name not in param_names: func.req = name - req = Parameter(name=name, kind=Parameter.POSITIONAL_OR_KEYWORD, annotation=Request) + req = Parameter( + name=name, kind=Parameter.POSITIONAL_OR_KEYWORD, annotation=Request + ) params.insert(0, req) sig = sig.replace(parameters=params) func.__signature__ = sig else: fname = f"{func.__module__}.{func.__name__}" - raise Exception(f"Remove 'request' argument from function {fname}" - f" or add [request : starlette.Request] manually.") + raise Exception( + f"Remove 'request' argument from function {fname}" + f" or add [request : starlette.Request] manually." + ) - return wrapper \ No newline at end of file + return wrapper diff --git a/tests/test_fastapi_extension.py b/tests/test_fastapi_extension.py index 42e6322..734c0cf 100644 --- a/tests/test_fastapi_extension.py +++ b/tests/test_fastapi_extension.py @@ -144,47 +144,18 @@ async def t1(request: Request, response: Response): == 429 ) - def test_endpoint_missing_request_param(self, build_fastapi_app): - app, limiter = build_fastapi_app(key_func=get_ipaddr) - - with pytest.raises(Exception) as exc_info: - - @app.get("/t3") - @limiter.limit("5/minute") - async def t3(): - return PlainTextResponse("test") - - assert exc_info.match( - r"""^No "request" or "websocket" argument on function .*""" - ) - - def test_endpoint_missing_request_param_sync(self, build_fastapi_app): + def test_endpoint_request_param_invalid(self, build_fastapi_app): app, limiter = build_fastapi_app(key_func=get_ipaddr) with pytest.raises(Exception) as exc_info: - @app.get("/t3_sync") + @app.get("/t4") @limiter.limit("5/minute") - def t3(): + async def t4(request: str = None): return PlainTextResponse("test") assert exc_info.match( - r"""^No "request" or "websocket" argument on function .*""" - ) - - def test_endpoint_request_param_invalid(self, build_fastapi_app): - app, limiter = build_fastapi_app(key_func=get_ipaddr) - - @app.get("/t4") - @limiter.limit("5/minute") - async def t4(request: str = None): - return PlainTextResponse("test") - - with pytest.raises(Exception) as exc_info: - client = TestClient(app) - client.get("/t4") - assert exc_info.match( - r"""parameter `request` must be an instance of starlette.requests.Request""" + r"Remove 'request' argument from function tests.test_fastapi_extension.t4 or add \[request : starlette.Request\] manually" ) def test_endpoint_response_param_invalid(self, build_fastapi_app): @@ -205,16 +176,15 @@ async def t4(request: Request, response: str = None): def test_endpoint_request_param_invalid_sync(self, build_fastapi_app): app, limiter = build_fastapi_app(key_func=get_ipaddr) - @app.get("/t5") - @limiter.limit("5/minute") - def t5(request: str = None): - return PlainTextResponse("test") - with pytest.raises(Exception) as exc_info: - client = TestClient(app) - client.get("/t5") + + @app.get("/t5") + @limiter.limit("5/minute") + def t5(request: str = None): + return PlainTextResponse("test") + assert exc_info.match( - r"""parameter `request` must be an instance of starlette.requests.Request""" + r"Remove 'request' argument from function tests.test_fastapi_extension.t5 or add \[request : starlette.Request\] manually" ) def test_endpoint_response_param_invalid_sync(self, build_fastapi_app): From a6d4c014d4b61ef6151d8041a07a9afd42b4b3ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Gryta?= Date: Sun, 28 Jan 2024 23:24:20 +0100 Subject: [PATCH 3/3] Update util.py --- slowapi/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/slowapi/util.py b/slowapi/util.py index 28b220e..a78a360 100644 --- a/slowapi/util.py +++ b/slowapi/util.py @@ -71,7 +71,7 @@ def wrapper(*args, **kwargs): params = list(sig.parameters.values()) rq = get_request_param(func) - if len(rq) == 1: + if len(rq) >= 1: if not hasattr(func, "scrap_req"): # Ignore if already set func.scrap_req = False else: