Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion cads_processing_api_service/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class RateLimitsRouteParamConfig(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra="allow")


class RateLimitsConfig(pydantic.BaseModel):
class RateLimitsUserConfig(pydantic.BaseModel):
default: RateLimitsRouteConfig = pydantic.Field(
default=RateLimitsRouteConfig(), validate_default=True
)
Expand Down Expand Up @@ -151,6 +151,18 @@ class RateLimitsConfig(pydantic.BaseModel):
)


class RateLimitsConfig(pydantic.BaseModel):
"""Rate limits configuration."""

auth: RateLimitsUserConfig = pydantic.Field(
default=RateLimitsUserConfig(),
description="Rate limits for authenticated users.",
)
anon: RateLimitsUserConfig = pydantic.Field(
default=RateLimitsUserConfig(), description="Rate limits for anonymous users."
)


def load_rate_limits(rate_limits_file: str | None) -> RateLimitsConfig:
rate_limits = RateLimitsConfig()
if rate_limits_file is not None:
Expand Down
14 changes: 9 additions & 5 deletions cads_processing_api_service/limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,16 @@

def get_rate_limits(
rate_limits_config: config.RateLimitsConfig,
user_type: str,
route: str,
method: str,
request_origin: str,
route_param: str | None = None,
) -> list[str]:
"""Get the rate limits for a specific route and method."""
rate_limits = rate_limits_config.model_dump()
route_rate_limits: dict[str, Any] = rate_limits.get(route, {})
user_type_rate_limits: dict[str, Any] = rate_limits.get(user_type, {})
route_rate_limits: dict[str, Any] = user_type_rate_limits.get(route, {})
if route_param is not None:
route_param_rate_limits: dict[str, Any] = route_rate_limits.get(route_param, {})
else:
Expand All @@ -50,22 +52,23 @@ def get_rate_limits(

def get_rate_limits_defaulted(
rate_limits_config: config.RateLimitsConfig,
user_type: str,
route: str,
method: str,
request_origin: str,
route_param: str | None = None,
) -> list[str]:
"""Get the rate limits for a specific route and method, with defaults."""
rate_limits = get_rate_limits(
rate_limits_config, route, method, request_origin, route_param
rate_limits_config, user_type, route, method, request_origin, route_param
)
if not rate_limits:
rate_limits = get_rate_limits(
rate_limits_config, route, method, request_origin, "default"
rate_limits_config, user_type, route, method, request_origin, "default"
)
if not rate_limits:
rate_limits = get_rate_limits(
rate_limits_config, "default", method, request_origin
rate_limits_config, user_type, "default", method, request_origin
)
return rate_limits

Expand Down Expand Up @@ -104,8 +107,9 @@ def check_rate_limits(
"""Check if the rate limits are exceeded."""
request_origin = auth_info.request_origin
user_uid = auth_info.user_uid
user_type = "anon" if auth_info.user_uid == "unauthenticated" else "auth"
rate_limits = get_rate_limits_defaulted(
rate_limits_config, route, method, request_origin, route_param
rate_limits_config, user_type, route, method, request_origin, route_param
)
rate_limits_parsed = [limits.parse(rate_limit) for rate_limit in rate_limits]
check_rate_limits_for_user(user_uid, rate_limits_parsed)
Expand Down
21 changes: 13 additions & 8 deletions tests/test_10_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,13 @@ def test_load_rate_limits(tmp_path: pathlib.Path, caplog) -> None:

rate_limits_file = str(tmp_path / "rate-limits.yaml")
rate_limits = {
"/jobs/{job_id}": {"get": {"api": ["1/second"], "ui": ["2/second"]}},
"/processes/{process_id}/constraints": {
"default": {"get": {"api": ["1/second"], "ui": ["2/second"]}},
"process-id": {"post": {"api": ["1/second"], "ui": ["2/second"]}},
},
"auth": {
"/jobs/{job_id}": {"get": {"api": ["1/second"], "ui": ["2/second"]}},
"/processes/{process_id}/constraints": {
"default": {"get": {"api": ["1/second"], "ui": ["2/second"]}},
"process-id": {"post": {"api": ["1/second"], "ui": ["2/second"]}},
},
}
}
with open(rate_limits_file, "w") as file:
yaml.dump(rate_limits, file)
Expand All @@ -87,7 +89,8 @@ def test_load_rate_limits(tmp_path: pathlib.Path, caplog) -> None:
"post": {"api": [], "ui": []},
"delete": {"api": [], "ui": []},
}
assert loaded_rate_limits["jobs_jobsid"] == expected_jobs_limits
assert "auth" in loaded_rate_limits
assert loaded_rate_limits["auth"]["jobs_jobsid"] == expected_jobs_limits
expected_process_constraints_limits = {
"default": {
"get": {"api": ["1/second"], "ui": ["2/second"]},
Expand All @@ -101,13 +104,15 @@ def test_load_rate_limits(tmp_path: pathlib.Path, caplog) -> None:
},
}
assert (
loaded_rate_limits["processes_processid_constraints"]
loaded_rate_limits["auth"]["processes_processid_constraints"]
== expected_process_constraints_limits
)

rate_limits_file = str(tmp_path / "invalid-rate-limits.yaml")
rate_limits = {
"/jobs/{job_id}": {"get": {"api": ["invalid_limit"]}},
"auth": {
"/jobs/{job_id}": {"get": {"api": ["invalid_limit"]}},
}
}
with open(rate_limits_file, "w") as file:
yaml.dump(rate_limits, file)
Expand Down
82 changes: 47 additions & 35 deletions tests/test_30_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,68 +22,76 @@


def test_get_rate_limits() -> None:
rate_limits = {"/jobs/{job_id}": {"get": {"api": ["2/second"]}}}
rate_limits = {"auth": {"/jobs/{job_id}": {"get": {"api": ["2/second"]}}}}
rate_limits_config = config.RateLimitsConfig(**rate_limits)

user_type = "auth"
route = "jobs_jobsid"
method = "get"
request_origin = "api"
rate_limits = cads_processing_api_service.limits.get_rate_limits(
rate_limits_config, route, method, request_origin
rate_limits_config, user_type, route, method, request_origin
)
exp_rate_limits = ["2/second"]
assert rate_limits == exp_rate_limits


def test_get_rate_limits_route_param() -> None:
rate_limits = {
"/processes/{process_id}/execution": {
"process_id": {"post": {"api": ["2/second"]}}
"auth": {
"/processes/{process_id}/execution": {
"process_id": {"post": {"api": ["2/second"]}}
}
}
}
rate_limits_config = config.RateLimitsConfig(**rate_limits)

user_type = "auth"
route = "processes_processid_execution"
route_param = "process_id"
method = "post"
request_origin = "api"
rate_limits = cads_processing_api_service.limits.get_rate_limits(
rate_limits_config, route, method, request_origin, route_param
rate_limits_config, user_type, route, method, request_origin, route_param
)
exp_rate_limits = ["2/second"]
assert rate_limits == exp_rate_limits


def test_get_rate_limits_defaulted_actual_value() -> None:
rate_limits = {
"/jobs/{job_id}": {"get": {"api": ["2/second"]}},
"default": {"get": {"api": ["1/second"]}},
"auth": {
"/jobs/{job_id}": {"get": {"api": ["2/second"]}},
"default": {"get": {"api": ["1/second"]}},
}
}
rate_limits_config = config.RateLimitsConfig(**rate_limits)

user_type = "auth"
route = "jobs_jobsid"
method = "get"
request_origin = "api"
rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
rate_limits_config, route, method, request_origin
rate_limits_config, user_type, route, method, request_origin
)
exp_rate_limits = ["2/second"]
assert rate_limits == exp_rate_limits


def test_get_rate_limits_defaulted_default_value() -> None:
rate_limits = {
"/jobs/{job_id}": {"post": {"api": ["2/second"]}},
"/jobs": {"get": {"api": ["2/second"]}},
"default": {"post": {"ui": ["1/second"]}},
"auth": {
"/jobs/{job_id}": {"post": {"api": ["2/second"]}},
"/jobs": {"get": {"api": ["2/second"]}},
"default": {"post": {"ui": ["1/second"]}},
}
}
rate_limits_config = config.RateLimitsConfig(**rate_limits)

user_type = "auth"
route = "jobs_jobsid"
method = "post"
request_origin = "ui"
rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
rate_limits_config, route, method, request_origin
rate_limits_config, user_type, route, method, request_origin
)
exp_rate_limits = ["1/second"]
assert rate_limits == exp_rate_limits
Expand All @@ -92,7 +100,7 @@ def test_get_rate_limits_defaulted_default_value() -> None:
method = "post"
request_origin = "ui"
rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
rate_limits_config, route, method, request_origin
rate_limits_config, user_type, route, method, request_origin
)
exp_rate_limits = ["1/second"]
assert rate_limits == exp_rate_limits
Expand All @@ -101,48 +109,52 @@ def test_get_rate_limits_defaulted_default_value() -> None:
method = "post"
request_origin = "ui"
rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
rate_limits_config, route, method, request_origin
rate_limits_config, user_type, route, method, request_origin
)
exp_rate_limits = ["1/second"]
assert rate_limits == exp_rate_limits


def test_get_rate_limits_defaulted_route_param_actual_value() -> None:
rate_limits = {
"/processes/{process_id}/execution": {
"test_process_id": {"post": {"api": ["2/second"]}}
},
"default": {"post": {"ui": ["1/second"]}},
"auth": {
"/processes/{process_id}/execution": {
"test_process_id": {"post": {"api": ["2/second"]}}
},
"default": {"post": {"ui": ["1/second"]}},
}
}
rate_limits_config = config.RateLimitsConfig(**rate_limits)

user_type = "auth"
route = "processes_processid_execution"
method = "post"
request_origin = "api"
route_param = "test_process_id"
rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
rate_limits_config, route, method, request_origin, route_param
rate_limits_config, user_type, route, method, request_origin, route_param
)
exp_rate_limits = ["2/second"]
assert rate_limits == exp_rate_limits


def test_get_rate_limits_defaulted_route_param_default_value() -> None:
rate_limits = {
"/processes/{process_id}/execution": {
"test_process_id": {"post": {"api": ["2/second"]}},
"default": {"post": {"api": ["1/second"]}},
},
"default": {"post": {"ui": ["1/minute"]}},
"auth": {
"/processes/{process_id}/execution": {
"test_process_id": {"post": {"api": ["2/second"]}},
"default": {"post": {"api": ["1/second"]}},
},
"default": {"post": {"ui": ["1/minute"]}},
}
}
rate_limits_config = config.RateLimitsConfig(**rate_limits)

user_type = "auth"
route = "processes_processid_execution"
method = "post"
request_origin = "api"
route_param = "missing_test_process_id"
rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
rate_limits_config, route, method, request_origin, route_param
rate_limits_config, user_type, route, method, request_origin, route_param
)
exp_rate_limits = ["1/second"]
assert rate_limits == exp_rate_limits
Expand All @@ -152,21 +164,21 @@ def test_get_rate_limits_defaulted_route_param_default_value() -> None:
request_origin = "ui"
route_param = "missing_test_process_id"
rate_limits = cads_processing_api_service.limits.get_rate_limits_defaulted(
rate_limits_config, route, method, request_origin, route_param
rate_limits_config, user_type, route, method, request_origin, route_param
)
exp_rate_limits = ["1/minute"]
assert rate_limits == exp_rate_limits


def test_get_rate_limits_undefined() -> None:
rate_limits = {"/jobs": {"get": {"api": ["2/second"]}}}
rate_limits = {"auth": {"/jobs": {"get": {"api": ["2/second"]}}}}
rate_limits_config = config.RateLimitsConfig.model_validate(rate_limits)

user_type = "auth"
route = "jobs"
method = "get"
request_origin = "ui"
rate_limits = cads_processing_api_service.limits.get_rate_limits(
rate_limits_config, route, method, request_origin
rate_limits_config, user_type, route, method, request_origin
)
exp_rate_limits = []
assert rate_limits == exp_rate_limits
Expand All @@ -175,7 +187,7 @@ def test_get_rate_limits_undefined() -> None:
method = "post"
request_origin = "ui"
rate_limits = cads_processing_api_service.limits.get_rate_limits(
rate_limits_config, route, method, request_origin
rate_limits_config, user_type, route, method, request_origin
)
exp_rate_limits = []
assert rate_limits == exp_rate_limits
Expand All @@ -184,7 +196,7 @@ def test_get_rate_limits_undefined() -> None:
method = "get"
request_origin = "ui"
rate_limits = cads_processing_api_service.limits.get_rate_limits(
rate_limits_config, route, method, request_origin
rate_limits_config, user_type, route, method, request_origin
)
exp_rate_limits = []
assert rate_limits == exp_rate_limits
Expand Down
Loading