From 3cc311712a2fc5f493801dbe788e3ab4c14c843a Mon Sep 17 00:00:00 2001
From: Radu Suciu
Date: Fri, 14 Nov 2025 21:45:14 -0800
Subject: [PATCH] fix: issue with resource validation assuming Fargate #34

As described in issue #34, validation assumed a Fargate compute
environment. We can detect the compute environment by looking up the
queue, and only run the Fargate specific validation when necessary.

Fargate validation now raises a WorkflowError (instead of an unhandled
ValueError) when the vCPU count has no valid memory size, and EC2
validation only enforces the 4 MiB memory minimum required by AWS Batch.
---
 .../batch_client.py      |  18 +++
 .../batch_job_builder.py | 125 ++++++++++++++++--
 2 files changed, 133 insertions(+), 10 deletions(-)

diff --git a/snakemake_executor_plugin_aws_batch/batch_client.py b/snakemake_executor_plugin_aws_batch/batch_client.py
index 40a8e5d..7146a16 100644
--- a/snakemake_executor_plugin_aws_batch/batch_client.py
+++ b/snakemake_executor_plugin_aws_batch/batch_client.py
@@ -72,3 +72,21 @@ def terminate_job(self, **kwargs):
         :return: The response from the terminate_job method.
         """
         return self.client.terminate_job(**kwargs)
+
+    def describe_job_queues(self, **kwargs):
+        """
+        Describe job queues in AWS Batch.
+
+        :param kwargs: The keyword arguments to pass to the describe_job_queues method.
+        :return: The response from the describe_job_queues method.
+        """
+        return self.client.describe_job_queues(**kwargs)
+
+    def describe_compute_environments(self, **kwargs):
+        """
+        Describe compute environments in AWS Batch.
+
+        :param kwargs: The keyword arguments to pass to the describe_compute_environments method.
+        :return: The response from the describe_compute_environments method.
+        """
+        return self.client.describe_compute_environments(**kwargs)
diff --git a/snakemake_executor_plugin_aws_batch/batch_job_builder.py b/snakemake_executor_plugin_aws_batch/batch_job_builder.py
index cff9f8b..cab0775 100644
--- a/snakemake_executor_plugin_aws_batch/batch_job_builder.py
+++ b/snakemake_executor_plugin_aws_batch/batch_job_builder.py
@@ -30,6 +30,8 @@ def __init__(
         self.job_command = job_command
         self.batch_client = batch_client
         self.created_job_defs = []
+        # Determine platform from job queue
+        self.platform = self._get_platform_from_queue()
 
     def _make_container_command(self, remote_command: str) -> List[str]:
         """
@@ -37,27 +39,130 @@ def _make_container_command(self, remote_command: str) -> List[str]:
         """
         return ["/bin/bash", "-c", remote_command]
 
-    def _validate_resources(self, vcpu: str, mem: str) -> tuple[str, str]:
-        """Validates vcpu and meme conform to Batch EC2 cpu/mem relationship
+    def _get_platform_from_queue(self) -> str:
+        """
+        Determine the platform (EC2 or FARGATE) from the job queue's compute environments.
+        Only the first compute environment of the queue is inspected. Any lookup
+        failure is logged as a warning and the platform falls back to EC2.
 
-        https://docs.aws.amazon.com/batch/latest/APIReference/API_ResourceRequirement.html
+        :return: Platform capability string (EC2 or FARGATE)
         """
-        vcpu = int(vcpu)
-        mem = int(mem)
+        try:
+            # Query the job queue
+            queue_response = self.batch_client.describe_job_queues(
+                jobQueues=[self.settings.job_queue]
+            )
+
+            if not queue_response.get("jobQueues"):
+                self.logger.warning(
+                    f"Job queue {self.settings.job_queue} not found. Defaulting to EC2."
+                )
+                return BATCH_JOB_PLATFORM_CAPABILITIES.EC2.value
+
+            job_queue = queue_response["jobQueues"][0]
+            compute_env_order = job_queue.get("computeEnvironmentOrder", [])
+
+            if not compute_env_order:
+                self.logger.warning(
+                    f"No compute environments found for queue {self.settings.job_queue}. "
+                    "Defaulting to EC2."
+                )
+                return BATCH_JOB_PLATFORM_CAPABILITIES.EC2.value
+
+            # Get the first compute environment ARN
+            compute_env_arn = compute_env_order[0]["computeEnvironment"]
+
+            # Query the compute environment to get its type
+            env_response = self.batch_client.describe_compute_environments(
+                computeEnvironments=[compute_env_arn]
+            )
+
+            if not env_response.get("computeEnvironments"):
+                self.logger.warning(
+                    f"Compute environment {compute_env_arn} not found. Defaulting to EC2."
+                )
+                return BATCH_JOB_PLATFORM_CAPABILITIES.EC2.value
+
+            compute_env = env_response["computeEnvironments"][0]
+
+            # Check if it's a Fargate environment
+            # Fargate environments have computeResources.type == "FARGATE" or "FARGATE_SPOT"
+            compute_resources = compute_env.get("computeResources", {})
+            resource_type = compute_resources.get("type", "")
+
+            if resource_type in ["FARGATE", "FARGATE_SPOT"]:
+                self.logger.info(
+                    f"Detected FARGATE platform from queue {self.settings.job_queue}"
+                )
+                return BATCH_JOB_PLATFORM_CAPABILITIES.FARGATE.value
+            else:
+                self.logger.info(
+                    f"Detected EC2 platform from queue {self.settings.job_queue}"
+                )
+                return BATCH_JOB_PLATFORM_CAPABILITIES.EC2.value
+        except Exception as e:
+            self.logger.warning(
+                f"Failed to determine platform from queue: {e}. Defaulting to EC2."
+            )
+            return BATCH_JOB_PLATFORM_CAPABILITIES.EC2.value
+
+    def _validate_fargate_resources(self, vcpu: int, mem: int) -> tuple[str, str]:
+        """Validates vcpu and memory conform to Fargate requirements.
+
+        Fargate requires strict memory/vCPU combinations.
+        https://docs.aws.amazon.com/batch/latest/userguide/fargate.html
+
+        :return: The (vcpu, mem) pair to request, as strings. Memory is replaced
+            by the minimum allowed value when the requested value is not valid.
+        :raises WorkflowError: If the vCPU count is not valid on Fargate.
+        """
 
         if mem in VALID_RESOURCES_MAPPING:
             if vcpu in VALID_RESOURCES_MAPPING[mem]:
                 return str(vcpu), str(mem)
             else:
-                raise WorkflowError(f"Invalid vCPU value {vcpu} for memory {mem} MB")
+                raise WorkflowError(f"Invalid vCPU value {vcpu} for memory {mem} MB on Fargate")
         else:
-            min_mem = min([m for m, v in VALID_RESOURCES_MAPPING.items() if vcpu in v])
+            # Memory sizes accepted for this vCPU count; empty when the vCPU
+            # count itself is unsupported, which min() would report as ValueError.
+            valid_mems = [m for m, v in VALID_RESOURCES_MAPPING.items() if vcpu in v]
+            if not valid_mems:
+                raise WorkflowError(f"vCPU value {vcpu} is not supported on Fargate")
+            min_mem = min(valid_mems)
             self.logger.warning(
-                f"Memory value {mem} MB is invalid for vCPU {vcpu}."
+                f"Memory value {mem} MB is invalid for vCPU {vcpu} on Fargate. "
                 f"Setting memory to minimum allowed value {min_mem} MB."
             )
             return str(vcpu), str(min_mem)
 
+    def _validate_ec2_resources(self, vcpu: int, mem: int) -> tuple[str, str]:
+        """Validates vcpu and memory for EC2 compute environments.
+
+        EC2 allows flexible resource allocation - just basic sanity checks.
+        https://docs.aws.amazon.com/batch/latest/userguide/compute_environment_parameters.html
+
+        :raises WorkflowError: If vcpu is below 1 or mem is below 4 MiB.
+        """
+        if vcpu < 1:
+            raise WorkflowError(f"vCPU must be at least 1, got {vcpu}")
+        # AWS Batch requires at least 4 MiB of memory per job.
+        if mem < 4:
+            raise WorkflowError(f"Memory must be at least 4 MiB, got {mem} MiB")
+        return str(vcpu), str(mem)
+
+    def _validate_resources(self, vcpu: str, mem: str) -> tuple[str, str]:
+        """Validates vcpu and memory based on platform requirements.
+
+        https://docs.aws.amazon.com/batch/latest/APIReference/API_ResourceRequirement.html
+        """
+        vcpu_int = int(vcpu)
+        mem_int = int(mem)
+
+        if self.platform == BATCH_JOB_PLATFORM_CAPABILITIES.FARGATE.value:
+            return self._validate_fargate_resources(vcpu_int, mem_int)
+        else:
+            return self._validate_ec2_resources(vcpu_int, mem_int)
+
     def build_job_definition(self):
         job_uuid = str(uuid.uuid4())
         job_name = f"snakejob-{self.job.name}-{job_uuid}"
@@ -66,7 +171,7 @@ def build_job_definition(self):
         # Validate and convert resources
         gpu = max(0, int(self.job.resources.get("_gpus", 0)))
         vcpu = max(1, int(self.job.resources.get("_cores", 1)))  # Default to 1 vCPU
-        mem = max(1, int(self.job.resources.get("mem_mb", 2048)))  # Default to 2048 MiB
+        mem = max(1, int(self.job.resources.get("mem_mb", 1024)))  # Default to 1024 MiB
         vcpu_str, mem_str = self._validate_resources(str(vcpu), str(mem))
         gpu_str = str(gpu)
 
@@ -111,7 +216,7 @@ def build_job_definition(self):
             containerProperties=container_properties,
             timeout=timeout,
             tags=tags,
-            platformCapabilities=[BATCH_JOB_PLATFORM_CAPABILITIES.EC2.value],
+            platformCapabilities=[self.platform],
         )
         self.created_job_defs.append(job_def)
         return job_def, job_name