Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion cads_processing_api_service/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class RateLimitsRouteParamConfig(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra="allow")


class RateLimitsConfig(pydantic.BaseModel):
class RateLimitsUserConfig(pydantic.BaseModel):
default: RateLimitsRouteConfig = pydantic.Field(
default=RateLimitsRouteConfig(), validate_default=True
)
Expand Down Expand Up @@ -151,6 +151,14 @@ class RateLimitsConfig(pydantic.BaseModel):
)


class RateLimitsConfig(RateLimitsUserConfig):
"""Rate limits configuration for the service."""

unauthenticated: RateLimitsUserConfig = pydantic.Field(
default=RateLimitsUserConfig(), validate_default=True
)


def load_rate_limits(rate_limits_file: str | None) -> RateLimitsConfig:
rate_limits = RateLimitsConfig()
if rate_limits_file is not None:
Expand Down
32 changes: 28 additions & 4 deletions cads_processing_api_service/limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


def get_rate_limits(
rate_limits_config: config.RateLimitsConfig,
rate_limits_config: config.RateLimitsUserConfig,
route: str,
method: str,
request_origin: str,
Expand All @@ -49,7 +49,7 @@ def get_rate_limits(


def get_rate_limits_defaulted(
rate_limits_config: config.RateLimitsConfig,
rate_limits_config: config.RateLimitsUserConfig,
route: str,
method: str,
request_origin: str,
Expand All @@ -70,6 +70,30 @@ def get_rate_limits_defaulted(
return rate_limits


def get_rate_limits_for_user(
rate_limits_config: config.RateLimitsConfig,
user_uid: str,
route: str,
method: str,
request_origin: str,
route_param: str | None = None,
) -> list[str]:
rate_limits = []
if user_uid == "unauthenticated":
rate_limits = get_rate_limits_defaulted(
rate_limits_config.unauthenticated,
route,
method,
request_origin,
route_param,
)
if not rate_limits:
rate_limits = get_rate_limits_defaulted(
rate_limits_config, route, method, request_origin, route_param
)
return rate_limits


def check_rate_limits_for_user(
user_uid: str, rate_limits: list[limits.RateLimitItem]
) -> None:
Expand Down Expand Up @@ -104,8 +128,8 @@ def check_rate_limits(
"""Check if the rate limits are exceeded."""
request_origin = auth_info.request_origin
user_uid = auth_info.user_uid
rate_limits = get_rate_limits_defaulted(
rate_limits_config, route, method, request_origin, route_param
rate_limits = get_rate_limits_for_user(
rate_limits_config, user_uid, route, method, request_origin, route_param
)
rate_limits_parsed = [limits.parse(rate_limit) for rate_limit in rate_limits]
check_rate_limits_for_user(user_uid, rate_limits_parsed)
Expand Down
90 changes: 90 additions & 0 deletions tests/test_30_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,96 @@ def test_get_rate_limits_undefined() -> None:
assert rate_limits == exp_rate_limits


def test_get_rate_limits_for_user_unauthenticated() -> None:
rate_limits = {
"default": {
"get": {"api": ["5/second"]},
"post": {"api": ["10/second"]},
},
"/jobs/{job_id}": {"delete": {"api": ["1/second"]}},
"unauthenticated": {
"default": {"post": {"api": ["2/second"]}},
"/jobs/{job_id}": {"get": {"api": ["3/second"]}},
},
}
rate_limits_config = config.RateLimitsConfig.model_validate(rate_limits)

route = "jobs_jobsid"
method = "get"
request_origin = "api"
user_uid = "unauthenticated"
rate_limits = cads_processing_api_service.limits.get_rate_limits_for_user(
rate_limits_config, user_uid, route, method, request_origin
)
exp_rate_limits = ["3/second"]
assert rate_limits == exp_rate_limits

route = "jobs_jobsid"
method = "post"
request_origin = "api"
user_uid = "unauthenticated"
rate_limits = cads_processing_api_service.limits.get_rate_limits_for_user(
rate_limits_config, user_uid, route, method, request_origin
)
exp_rate_limits = ["2/second"]
assert rate_limits == exp_rate_limits

route = "jobs_jobsid"
method = "delete"
request_origin = "api"
user_uid = "unauthenticated"
rate_limits = cads_processing_api_service.limits.get_rate_limits_for_user(
rate_limits_config, user_uid, route, method, request_origin
)
exp_rate_limits = ["1/second"]
assert rate_limits == exp_rate_limits


def test_get_rate_limits_for_user_authenticated() -> None:
rate_limits = {
"default": {
"get": {"api": ["5/second"]},
"post": {"api": ["10/second"]},
},
"/jobs/{job_id}": {"delete": {"api": ["1/second"]}},
"unauthenticated": {
"default": {"post": {"api": ["2/second"]}},
"/jobs/{job_id}": {"get": {"api": ["3/second"]}},
},
}
rate_limits_config = config.RateLimitsConfig.model_validate(rate_limits)

route = "jobs_jobsid"
method = "get"
request_origin = "api"
user_uid = "user_uid"
rate_limits = cads_processing_api_service.limits.get_rate_limits_for_user(
rate_limits_config, user_uid, route, method, request_origin
)
exp_rate_limits = ["5/second"]
assert rate_limits == exp_rate_limits

route = "jobs_jobsid"
method = "post"
request_origin = "api"
user_uid = "user_uid"
rate_limits = cads_processing_api_service.limits.get_rate_limits_for_user(
rate_limits_config, user_uid, route, method, request_origin
)
exp_rate_limits = ["10/second"]
assert rate_limits == exp_rate_limits

route = "jobs_jobsid"
method = "delete"
request_origin = "api"
user_uid = "user_uid"
rate_limits = cads_processing_api_service.limits.get_rate_limits_for_user(
rate_limits_config, user_uid, route, method, request_origin
)
exp_rate_limits = ["1/second"]
assert rate_limits == exp_rate_limits


def test_check_rate_limits_for_user() -> None:
rate_limit_ids = ["1/second"]
rate_limits = [limits.parse(rate_limit_id) for rate_limit_id in rate_limit_ids]
Expand Down
Loading