diff --git a/compose_runner/aws_lambda/cost_check_handler.py b/compose_runner/aws_lambda/cost_check_handler.py
new file mode 100644
index 0000000..0fcdf7f
--- /dev/null
+++ b/compose_runner/aws_lambda/cost_check_handler.py
@@ -0,0 +1,99 @@
+"""Lambda handler that gates executions on month-to-date AWS spend.
+
+Queries Cost Explorer for the current calendar month's unblended cost and
+reports whether it is still below the limit given in ``COST_LIMIT_USD``.
+"""
+
+from __future__ import annotations
+
+import datetime as _dt
+import logging
+import math
+import os
+from decimal import Decimal
+from typing import Any, Dict
+
+import boto3
+from botocore.exceptions import BotoCoreError, ClientError
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+_CE_CLIENT = boto3.client("ce", region_name=os.environ.get("AWS_REGION", "us-east-1"))
+
+COST_LIMIT_ENV = "COST_LIMIT_USD"
+
+
+def _month_range(today: _dt.date) -> Dict[str, str]:
+    """Return the ``TimePeriod`` dict spanning the 1st of the month through *today*."""
+    start = today.replace(day=1)
+    # Cost Explorer end date is exclusive; add a day to include today.
+    end = today + _dt.timedelta(days=1)
+    return {"Start": start.isoformat(), "End": end.isoformat()}
+
+
+def _current_month_cost() -> Dict[str, Any]:
+    """Fetch the month-to-date unblended cost from Cost Explorer.
+
+    Returns a dict with ``amount`` (float), ``currency`` and the queried
+    ``time_period``. A response without results is reported as zero spend.
+    """
+    # Use UTC explicitly so the window does not depend on the runtime's
+    # local timezone setting.
+    period = _month_range(_dt.datetime.now(_dt.timezone.utc).date())
+    response = _CE_CLIENT.get_cost_and_usage(
+        TimePeriod=period,
+        Granularity="MONTHLY",
+        Metrics=["UnblendedCost"],
+    )
+    results = response.get("ResultsByTime", [])
+    total = results[0]["Total"]["UnblendedCost"] if results else {"Amount": "0", "Unit": "USD"}
+    amount = float(Decimal(total.get("Amount", "0")))
+    currency = total.get("Unit", "USD")
+    return {"amount": amount, "currency": currency, "time_period": period}
+
+
+def handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
+    """Report whether month-to-date spend is still under ``COST_LIMIT_USD``.
+
+    The returned payload's ``allowed`` flag is ``False`` when spend has reached
+    the limit, and also when Cost Explorer cannot be queried (fail closed).
+
+    Raises:
+        RuntimeError: if ``COST_LIMIT_USD`` is unset or is not a finite number.
+    """
+    limit_raw = os.environ.get(COST_LIMIT_ENV)
+    if not limit_raw:
+        raise RuntimeError(f"{COST_LIMIT_ENV} environment variable must be set.")
+
+    try:
+        limit = float(limit_raw)
+    except ValueError as exc:
+        raise RuntimeError(f"Invalid {COST_LIMIT_ENV}: {limit_raw}") from exc
+    # float() also accepts "nan"/"inf": NaN makes every comparison below false,
+    # and neither value is valid JSON for the response payload.
+    if not math.isfinite(limit):
+        raise RuntimeError(f"Invalid {COST_LIMIT_ENV}: {limit_raw}")
+
+    try:
+        cost = _current_month_cost()
+    except (ClientError, BotoCoreError):
+        # logger.exception records the traceback along with the message.
+        logger.exception("Failed to query Cost Explorer")
+        return {
+            "status": "ERROR",
+            "allowed": False,
+            "error": "cost_explorer_unavailable",
+            "limit": limit,
+        }
+
+    amount = cost["amount"]
+    allowed = amount < limit
+    return {
+        "status": "OK",
+        "allowed": allowed,
+        "current_spend": amount,
+        "limit": limit,
+        "currency": cost.get("currency", "USD"),
+        "time_period": cost.get("time_period"),
+    }
diff --git a/infra/cdk/stacks/compose_runner_stack.py b/infra/cdk/stacks/compose_runner_stack.py
index 10a417b..3cc99ee 100644
--- a/infra/cdk/stacks/compose_runner_stack.py
+++ b/infra/cdk/stacks/compose_runner_stack.py
@@ -36,6 +36,7 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs: object) -> Non
         poll_memory_size = int(self.node.try_get_context("pollMemorySize") or 512)
         poll_timeout_seconds = int(self.node.try_get_context("pollTimeoutSeconds") or 30)
         poll_lookback_ms = int(self.node.try_get_context("pollLookbackMs") or 3600000)
+        monthly_spend_limit_usd = float(self.node.try_get_context("monthlySpendLimit") or 100)
 
         task_cpu = int(self.node.try_get_context("taskCpu") or 4096)
         task_memory_mib = int(self.node.try_get_context("taskMemoryMiB") or 30720)
@@ -243,6 +244,31 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs: object) -> Non
             max_attempts=2,
         )
 
+        cost_check_code = lambda_.DockerImageCode.from_image_asset(
+            str(project_root),
+            file="aws_lambda/Dockerfile",
+            cmd=["compose_runner.aws_lambda.cost_check_handler.handler"],
+            build_args=build_args,
+        )
+
+        cost_check_function = lambda_.DockerImageFunction(
+            self,
+            "ComposeRunnerCostCheck",
+            code=cost_check_code,
+            memory_size=256,
+            timeout=Duration.seconds(15),
+            environment={
+                "COST_LIMIT_USD": str(monthly_spend_limit_usd),
+            },
+            description="Blocks executions when monthly spend exceeds the configured limit.",
+        )
+        cost_check_function.add_to_role_policy(
+            iam.PolicyStatement(
+                actions=["ce:GetCostAndUsage"],
+                resources=["*"],
+            )
+        )
+
         run_output = sfn.Pass(
             self,
             "ComposeRunnerOutput",
@@ -256,7 +282,7 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs: object) -> Non
             },
         )
 
-        definition_chain = sfn.Choice(
+        task_selection = sfn.Choice(
             self,
             "SelectFargateTask",
         ).when(
@@ -266,6 +292,28 @@ def __init__(self, scope: Construct, construct_id: str, **kwargs: object) -> Non
             run_task_standard.next(run_output)
         )
 
+        cost_limit_exceeded = sfn.Fail(
+            self,
+            "CostLimitExceeded",
+            cause="Monthly spend limit exceeded.",
+            error="CostLimitExceeded",
+        )
+
+        enforce_cost_limit = sfn.Choice(self, "EnforceMonthlyCostLimit").when(
+            sfn.Condition.boolean_equals("$.cost_check.Payload.allowed", False),
+            cost_limit_exceeded,
+        ).otherwise(task_selection)
+
+        cost_check_step = tasks.LambdaInvoke(
+            self,
+            "CheckMonthlyCost",
+            lambda_function=cost_check_function,
+            payload=sfn.TaskInput.from_object({"stateInput.$": "$"}),
+            result_path="$.cost_check",
+        )
+
+        definition_chain = cost_check_step.next(enforce_cost_limit)
+
         state_machine = sfn.StateMachine(
             self,
             "ComposeRunnerStateMachine",