diff --git a/cads_processing_api_service/config.py b/cads_processing_api_service/config.py index 233b423..3ff41aa 100644 --- a/cads_processing_api_service/config.py +++ b/cads_processing_api_service/config.py @@ -116,7 +116,7 @@ class RateLimitsRouteParamConfig(pydantic.BaseModel): model_config = pydantic.ConfigDict(extra="allow") -class RateLimitsConfig(pydantic.BaseModel): +class RateLimitsUserConfig(pydantic.BaseModel): default: RateLimitsRouteConfig = pydantic.Field( default=RateLimitsRouteConfig(), validate_default=True ) @@ -151,6 +151,18 @@ class RateLimitsConfig(pydantic.BaseModel): ) +class RateLimitsConfig(pydantic.BaseModel): + """Rate limits configuration.""" + + auth: RateLimitsUserConfig = pydantic.Field( + default=RateLimitsUserConfig(), + description="Rate limits for authenticated users.", + ) + anon: RateLimitsUserConfig = pydantic.Field( + default=RateLimitsUserConfig(), description="Rate limits for anonymous users." + ) + + def load_rate_limits(rate_limits_file: str | None) -> RateLimitsConfig: rate_limits = RateLimitsConfig() if rate_limits_file is not None: diff --git a/cads_processing_api_service/limits.py b/cads_processing_api_service/limits.py index 20e7b46..cce26bc 100644 --- a/cads_processing_api_service/limits.py +++ b/cads_processing_api_service/limits.py @@ -31,6 +31,7 @@ def get_rate_limits( rate_limits_config: config.RateLimitsConfig, + user_type: str, route: str, method: str, request_origin: str, @@ -38,7 +39,8 @@ def get_rate_limits( ) -> list[str]: """Get the rate limits for a specific route and method.""" rate_limits = rate_limits_config.model_dump() - route_rate_limits: dict[str, Any] = rate_limits.get(route, {}) + user_type_rate_limits: dict[str, Any] = rate_limits.get(user_type, {}) + route_rate_limits: dict[str, Any] = user_type_rate_limits.get(route, {}) if route_param is not None: route_param_rate_limits: dict[str, Any] = route_rate_limits.get(route_param, {}) else: @@ -50,6 +52,7 @@ def get_rate_limits( def get_rate_limits_defaulted( rate_limits_config: config.RateLimitsConfig, + user_type: str, route: str, method: str, request_origin: str, @@ -57,15 +60,15 @@ def get_rate_limits_defaulted( ) -> list[str]: """Get the rate limits for a specific route and method, with defaults.""" rate_limits = get_rate_limits( - rate_limits_config, route, method, request_origin, route_param + rate_limits_config, user_type, route, method, request_origin, route_param ) if not rate_limits: rate_limits = get_rate_limits( - rate_limits_config, route, method, request_origin, "default" + rate_limits_config, user_type, route, method, request_origin, "default" ) if not rate_limits: rate_limits = get_rate_limits( - rate_limits_config, "default", method, request_origin + rate_limits_config, user_type, "default", method, request_origin ) return rate_limits @@ -104,8 +107,9 @@ def check_rate_limits( """Check if the rate limits are exceeded.""" request_origin = auth_info.request_origin user_uid = auth_info.user_uid + user_type = "anon" if auth_info.user_uid == "unauthenticated" else "auth" rate_limits = get_rate_limits_defaulted( - rate_limits_config, route, method, request_origin, route_param + rate_limits_config, user_type, route, method, request_origin, route_param ) rate_limits_parsed = [limits.parse(rate_limit) for rate_limit in rate_limits] check_rate_limits_for_user(user_uid, rate_limits_parsed) diff --git a/tests/test_10_config.py b/tests/test_10_config.py index 3ded98d..1f3294d 100644 --- a/tests/test_10_config.py +++ b/tests/test_10_config.py @@ -73,11 +73,13 @@ def test_load_rate_limits(tmp_path: pathlib.Path, caplog) -> None: rate_limits_file = str(tmp_path / "rate-limits.yaml") rate_limits = { - "/jobs/{job_id}": {"get": {"api": ["1/second"], "ui": ["2/second"]}}, - "/processes/{process_id}/constraints": { - "default": {"get": {"api": ["1/second"], "ui": ["2/second"]}}, - "process-id": {"post": {"api": ["1/second"], "ui": ["2/second"]}}, - }, + "auth": { + "/jobs/{job_id}": {"get": {"api": ["1/second"], "ui": ["2/second"]}}, + "/processes/{process_id}/constraints": { + "default": {"get": {"api": ["1/second"], "ui": ["2/second"]}}, + "process-id": {"post": {"api": ["1/second"], "ui": ["2/second"]}}, + }, + } } with open(rate_limits_file, "w") as file: yaml.dump(rate_limits, file) @@ -87,7 +89,8 @@ def test_load_rate_limits(tmp_path: pathlib.Path, caplog) -> None: "post": {"api": [], "ui": []}, "delete": {"api": [], "ui": []}, } - assert loaded_rate_limits["jobs_jobsid"] == expected_jobs_limits + assert "auth" in loaded_rate_limits + assert loaded_rate_limits["auth"]["jobs_jobsid"] == expected_jobs_limits expected_process_constraints_limits = { "default": { "get": {"api": ["1/second"], "ui": ["2/second"]}, @@ -101,13 +104,15 @@ def test_load_rate_limits(tmp_path: pathlib.Path, caplog) -> None: }, } assert ( - loaded_rate_limits["processes_processid_constraints"] + loaded_rate_limits["auth"]["processes_processid_constraints"] == expected_process_constraints_limits ) rate_limits_file = str(tmp_path / "invalid-rate-limits.yaml") rate_limits = { - "/jobs/{job_id}": {"get": {"api": ["invalid_limit"]}}, + "auth": { + "/jobs/{job_id}": {"get": {"api": ["invalid_limit"]}}, + } } with open(rate_limits_file, "w") as file: yaml.dump(rate_limits, file) diff --git a/tests/test_30_limits.py b/tests/test_30_limits.py index 43a5d35..67764e6 100644 --- a/tests/test_30_limits.py +++ b/tests/test_30_limits.py @@ -22,14 +22,15 @@ def test_get_rate_limits() -> None: - rate_limits = {"/jobs/{job_id}": {"get": {"api": ["2/second"]}}} + rate_limits = {"auth": {"/jobs/{job_id}": {"get": {"api": ["2/second"]}}}} rate_limits_config = config.RateLimitsConfig(**rate_limits) + user_type = "auth" route = "jobs_jobsid" method = "get" request_origin = "api" rate_limits = cads_processing_api_service.limits.get_rate_limits( - rate_limits_config, route, method, request_origin + rate_limits_config, user_type, route, method, request_origin ) exp_rate_limits = ["2/second"] assert rate_limits == exp_rate_limits @@ -37,18 +38,21 @@ def test_get_rate_limits() -> None: def test_get_rate_limits_route_param() -> None: rate_limits = { - "/processes/{process_id}/execution": { - "process_id": {"post": {"api": ["2/second"]}} + "auth": { + "/processes/{process_id}/execution": { + "process_id": {"post": {"api": ["2/second"]}} + } } } rate_limits_config = config.RateLimitsConfig(**rate_limits) + user_type = "auth" route = "processes_processid_execution" route_param = "process_id" method = "post" request_origin = "api" rate_limits = cads_processing_api_service.limits.get_rate_limits( - rate_limits_config, route, method, request_origin, route_param + rate_limits_config, user_type, route, method, request_origin, route_param ) exp_rate_limits = ["2/second"] assert rate_limits == exp_rate_limits @@ -56,16 +60,18 @@ def test_get_rate_limits_route_param() -> None: def test_get_rate_limits_defaulted_actual_value() -> None: rate_limits = { - "/jobs/{job_id}": {"get": {"api": ["2/second"]}}, - "default": {"get": {"api": ["1/second"]}}, + "auth": { + "/jobs/{job_id}": {"get": {"api": ["2/second"]}}, + "default": {"get": {"api": ["1/second"]}}, + } } rate_limits_config = config.RateLimitsConfig(**rate_limits) - + user_type = "auth" route = "jobs_jobsid" method = "get" request_origin = "api" rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted( - rate_limits_config, route, method, request_origin + rate_limits_config, user_type, route, method, request_origin ) exp_rate_limits = ["2/second"] assert rate_limits == exp_rate_limits @@ -73,17 +79,19 @@ def test_get_rate_limits_defaulted_actual_value() -> None: def test_get_rate_limits_defaulted_default_value() -> None: rate_limits = { - "/jobs/{job_id}": {"post": {"api": ["2/second"]}}, - "/jobs": {"get": {"api": ["2/second"]}}, - "default": {"post": {"ui": ["1/second"]}}, + "auth": { + "/jobs/{job_id}": {"post": {"api": ["2/second"]}}, + "/jobs": {"get": {"api": ["2/second"]}}, + "default": {"post": {"ui": ["1/second"]}}, + } } rate_limits_config = config.RateLimitsConfig(**rate_limits) - + user_type = "auth" route = "jobs_jobsid" method = "post" request_origin = "ui" rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted( - rate_limits_config, route, method, request_origin + rate_limits_config, user_type, route, method, request_origin ) exp_rate_limits = ["1/second"] assert rate_limits == exp_rate_limits @@ -92,7 +100,7 @@ def test_get_rate_limits_defaulted_default_value() -> None: method = "post" request_origin = "ui" rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted( - rate_limits_config, route, method, request_origin + rate_limits_config, user_type, route, method, request_origin ) exp_rate_limits = ["1/second"] assert rate_limits == exp_rate_limits @@ -101,7 +109,7 @@ def test_get_rate_limits_defaulted_default_value() -> None: method = "post" request_origin = "ui" rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted( - rate_limits_config, route, method, request_origin + rate_limits_config, user_type, route, method, request_origin ) exp_rate_limits = ["1/second"] assert rate_limits == exp_rate_limits @@ -109,19 +117,21 @@ def test_get_rate_limits_defaulted_default_value() -> None: def test_get_rate_limits_defaulted_route_param_actual_value() -> None: rate_limits = { - "/processes/{process_id}/execution": { - "test_process_id": {"post": {"api": ["2/second"]}} - }, - "default": {"post": {"ui": ["1/second"]}}, + "auth": { + "/processes/{process_id}/execution": { + "test_process_id": {"post": {"api": ["2/second"]}} + }, + "default": {"post": {"ui": ["1/second"]}}, + } } rate_limits_config = config.RateLimitsConfig(**rate_limits) - + user_type = "auth" route = "processes_processid_execution" method = "post" request_origin = "api" route_param = "test_process_id" rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted( - rate_limits_config, route, method, request_origin, route_param + rate_limits_config, user_type, route, method, request_origin, route_param ) exp_rate_limits = ["2/second"] assert rate_limits == exp_rate_limits @@ -129,20 +139,22 @@ def test_get_rate_limits_defaulted_route_param_actual_value() -> None: def test_get_rate_limits_defaulted_route_param_default_value() -> None: rate_limits = { - "/processes/{process_id}/execution": { - "test_process_id": {"post": {"api": ["2/second"]}}, - "default": {"post": {"api": ["1/second"]}}, - }, - "default": {"post": {"ui": ["1/minute"]}}, + "auth": { + "/processes/{process_id}/execution": { + "test_process_id": {"post": {"api": ["2/second"]}}, + "default": {"post": {"api": ["1/second"]}}, + }, + "default": {"post": {"ui": ["1/minute"]}}, + } } rate_limits_config = config.RateLimitsConfig(**rate_limits) - + user_type = "auth" route = "processes_processid_execution" method = "post" request_origin = "api" route_param = "missing_test_process_id" rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted( - rate_limits_config, route, method, request_origin, route_param + rate_limits_config, user_type, route, method, request_origin, route_param ) exp_rate_limits = ["1/second"] assert rate_limits == exp_rate_limits @@ -152,21 +164,21 @@ def test_get_rate_limits_defaulted_route_param_default_value() -> None: request_origin = "ui" route_param = "missing_test_process_id" rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted( - rate_limits_config, route, method, request_origin, route_param + rate_limits_config, user_type, route, method, request_origin, route_param ) exp_rate_limits = ["1/minute"] assert rate_limits == exp_rate_limits def test_get_rate_limits_undefined() -> None: - rate_limits = {"/jobs": {"get": {"api": ["2/second"]}}} + rate_limits = {"auth": {"/jobs": {"get": {"api": ["2/second"]}}}} rate_limits_config = config.RateLimitsConfig.model_validate(rate_limits) - + user_type = "auth" route = "jobs" method = "get" request_origin = "ui" rate_limits = cads_processing_api_service.limits.get_rate_limits( - rate_limits_config, route, method, request_origin + rate_limits_config, user_type, route, method, request_origin ) exp_rate_limits = [] assert rate_limits == exp_rate_limits @@ -175,7 +187,7 @@ def test_get_rate_limits_undefined() -> None: method = "post" request_origin = "ui" rate_limits = cads_processing_api_service.limits.get_rate_limits( - rate_limits_config, route, method, request_origin + rate_limits_config, user_type, route, method, request_origin ) exp_rate_limits = [] assert rate_limits == exp_rate_limits @@ -184,7 +196,7 @@ def test_get_rate_limits_undefined() -> None: method = "get" request_origin = "ui" rate_limits = cads_processing_api_service.limits.get_rate_limits( - rate_limits_config, route, method, request_origin + rate_limits_config, user_type, route, method, request_origin ) exp_rate_limits = [] assert rate_limits == exp_rate_limits