diff --git a/slowapi/extension.py b/slowapi/extension.py index ac14a06..25b999a 100644 --- a/slowapi/extension.py +++ b/slowapi/extension.py @@ -33,6 +33,7 @@ from typing_extensions import Literal from .errors import RateLimitExceeded +from .util import add_request_signature from .wrappers import Limit, LimitGroup # used to annotate get_app_config method @@ -656,6 +657,7 @@ def __limit_decorator( _scope = scope if shared else None def decorator(func: Callable[..., Response]): + func = add_request_signature(func) keyfunc = key_func or self._key_func name = f"{func.__module__}.{func.__name__}" dynamic_limit = None diff --git a/slowapi/util.py b/slowapi/util.py index c44faa9..a78a360 100644 --- a/slowapi/util.py +++ b/slowapi/util.py @@ -1,4 +1,9 @@ +from asyncio import iscoroutinefunction +from functools import wraps +from inspect import signature, Parameter + from starlette.requests import Request +from typing import Callable, List def get_ipaddr(request: Request) -> str: @@ -25,3 +30,68 @@ def get_remote_address(request: Request) -> str: return "127.0.0.1" return request.client.host + + +def get_request_param(func: Callable) -> List[Parameter]: + """Retrieve list of parameters that are a Request""" + sig = signature(func) + params = list(sig.parameters.values()) + return [param for param in params if param.annotation == Request] + + +def add_request_signature(func: Callable): + """Adds starlette.Request argument to function's signature so that it'll be accessible to custom decorators""" + + def scrap_req(func: Callable, args, kwargs): + if getattr(func, "scrap_req", False): + req_param = get_request_param(func)[0] + try: + del kwargs[req_param.name] + except KeyError: + # Request is not in kwargs for some reason delete from args + # Deletion index: 0 + del args[0] + return args, kwargs + + if iscoroutinefunction(func): + + @wraps(func) + async def wrapper(*args, **kwargs): + args, kwargs = scrap_req(func, args, kwargs) + return await func(*args, **kwargs) + + else: + + @wraps(func) + def wrapper(*args, **kwargs): + args, kwargs = scrap_req(func, args, kwargs) + return func(*args, **kwargs) + + sig = signature(func) + params = list(sig.parameters.values()) + + rq = get_request_param(func) + if len(rq) >= 1: + if not hasattr(func, "scrap_req"): # Ignore if already set + func.scrap_req = False + else: + func.scrap_req = True + name = "request" # Slowapi should allow for request to be anything <- param name generator + param_names = [pname.name for pname in params] + if name not in param_names: + func.req = name + + req = Parameter( + name=name, kind=Parameter.POSITIONAL_OR_KEYWORD, annotation=Request + ) + params.insert(0, req) + sig = sig.replace(parameters=params) + func.__signature__ = sig + else: + fname = f"{func.__module__}.{func.__name__}" + raise Exception( + f"Remove 'request' argument from function {fname}" + f" or add [request : starlette.Request] manually." + ) + + return wrapper diff --git a/tests/test_fastapi_extension.py b/tests/test_fastapi_extension.py index 42e6322..734c0cf 100644 --- a/tests/test_fastapi_extension.py +++ b/tests/test_fastapi_extension.py @@ -144,47 +144,18 @@ async def t1(request: Request, response: Response): == 429 ) - def test_endpoint_missing_request_param(self, build_fastapi_app): - app, limiter = build_fastapi_app(key_func=get_ipaddr) - - with pytest.raises(Exception) as exc_info: - - @app.get("/t3") - @limiter.limit("5/minute") - async def t3(): - return PlainTextResponse("test") - - assert exc_info.match( - r"""^No "request" or "websocket" argument on function .*""" - ) - - def test_endpoint_missing_request_param_sync(self, build_fastapi_app): + def test_endpoint_request_param_invalid(self, build_fastapi_app): app, limiter = build_fastapi_app(key_func=get_ipaddr) with pytest.raises(Exception) as exc_info: - @app.get("/t3_sync") + @app.get("/t4") @limiter.limit("5/minute") - def t3(): + async def t4(request: str = None): return PlainTextResponse("test") assert exc_info.match( - r"""^No "request" or "websocket" argument on function .*""" - ) - - def test_endpoint_request_param_invalid(self, build_fastapi_app): - app, limiter = build_fastapi_app(key_func=get_ipaddr) - - @app.get("/t4") - @limiter.limit("5/minute") - async def t4(request: str = None): - return PlainTextResponse("test") - - with pytest.raises(Exception) as exc_info: - client = TestClient(app) - client.get("/t4") - assert exc_info.match( - r"""parameter `request` must be an instance of starlette.requests.Request""" + r"Remove 'request' argument from function tests.test_fastapi_extension.t4 or add \[request : starlette.Request\] manually" ) def test_endpoint_response_param_invalid(self, build_fastapi_app): @@ -205,16 +176,15 @@ async def t4(request: Request, response: str = None): def test_endpoint_request_param_invalid_sync(self, build_fastapi_app): app, limiter = build_fastapi_app(key_func=get_ipaddr) - @app.get("/t5") - @limiter.limit("5/minute") - def t5(request: str = None): - return PlainTextResponse("test") - with pytest.raises(Exception) as exc_info: - client = TestClient(app) - client.get("/t5") + + @app.get("/t5") + @limiter.limit("5/minute") + def t5(request: str = None): + return PlainTextResponse("test") + assert exc_info.match( - r"""parameter `request` must be an instance of starlette.requests.Request""" + r"Remove 'request' argument from function tests.test_fastapi_extension.t5 or add \[request : starlette.Request\] manually" ) def test_endpoint_response_param_invalid_sync(self, build_fastapi_app):