From d5a10f1c73bae554273c3dc9e7fbee350385d345 Mon Sep 17 00:00:00 2001 From: Nikolai Petukhov Date: Fri, 27 Feb 2026 12:37:50 -0300 Subject: [PATCH 1/3] Add aws container runner and refactor --- agent/worker/container_runner/aws_config.json | 7 + agent/worker/container_runner/aws_runner.py | 306 ++++++++++ agent/worker/container_runner/aws_utils.py | 560 ++++++++++++++++++ .../container_runner/container_runner.py | 82 +++ agent/worker/container_runner/local.py | 182 ++++++ agent/worker/task_app.py | 109 ++-- agent/worker/task_dockerized.py | 34 +- requirements.txt | 2 + 8 files changed, 1190 insertions(+), 92 deletions(-) create mode 100644 agent/worker/container_runner/aws_config.json create mode 100644 agent/worker/container_runner/aws_runner.py create mode 100644 agent/worker/container_runner/aws_utils.py create mode 100644 agent/worker/container_runner/container_runner.py create mode 100644 agent/worker/container_runner/local.py diff --git a/agent/worker/container_runner/aws_config.json b/agent/worker/container_runner/aws_config.json new file mode 100644 index 0000000..29c0d08 --- /dev/null +++ b/agent/worker/container_runner/aws_config.json @@ -0,0 +1,7 @@ +{ + "cluster": "gpu-cloud", + "ecr_host": "952013232994.dkr.ecr.us-east-1.amazonaws.com", + "capacity_provider": "Infra-ECS-Cluster-gpu-cloud-bff8b004-ManagedInstancesCapacityProvider-znOnkWBWcrTv", + "task_definition": "cpu-task-3gb", + "mirroring_image_task_definition": "mirror-image" +} \ No newline at end of file diff --git a/agent/worker/container_runner/aws_runner.py b/agent/worker/container_runner/aws_runner.py new file mode 100644 index 0000000..b81b4ad --- /dev/null +++ b/agent/worker/container_runner/aws_runner.py @@ -0,0 +1,306 @@ +import json +from pathlib import Path +import shlex +import time +from worker.container_runner.container_runner import ( + BaseContainer, + BaseContainerRunner, +) +import os +from typing import Dict, Generator, List, Literal, Optional, Union + +from 
worker.container_runner.aws_utils import ( + mirror_image_to_ecr, + run_container_ec2, + get_boto3_client, + stream_task_logs, + ECSConfig, +) + + +import json +import uuid + +import construct as c +import websocket + + +def parse_mem_limit_to_bytes(mem_limit) -> int: + if isinstance(mem_limit, (int, float)): + return int(mem_limit) + + if isinstance(mem_limit, str): + if mem_limit == "": + return None + mem_limit = mem_limit.strip().lower() + units = {"b": 1, "k": 1024, "m": 1024**2, "g": 1024**3} + if mem_limit[-1] in units: + return int(float(mem_limit[:-1]) * units[mem_limit[-1]]) + return int(mem_limit) # no unit, assume bytes + + if mem_limit is None: + return None + + raise ValueError(f"Unsupported mem_limit type: {type(mem_limit)}") +class AWSContainerExec: + def __init__(self, session: dict, exec_id: str): + self._session = session + self._exec_id = exec_id + self._exit_code = None + self._connection = self._init_connection() + + def _init_connection(self): + if self._session is None: + raise RuntimeError("No active exec session found") + + connection = websocket.create_connection(self._session["streamUrl"]) + init_payload = { + "MessageSchemaVersion": "1.0", + "RequestId": str(uuid.uuid4()), + "TokenValue": self._session["tokenValue"], + } + connection.send(json.dumps(init_payload)) + return connection + + def stream_logs(self) -> Generator[str, None, None]: + AgentMessageHeader = c.Struct( + "HeaderLength" / c.Int32ub, + "MessageType" / c.PaddedString(32, "ascii"), + ) + AgentMessagePayload = c.Struct( + "PayloadLength" / c.Int32ub, + "Payload" / c.PaddedString(c.this.PayloadLength, "ascii"), + ) + + try: + while True: + response = self._connection.recv() + message = AgentMessageHeader.parse(response) + message_type = message.MessageType.strip() + + if "channel_closed" in message_type: + break + + if "output_stream_data" in message_type: + payload_message = AgentMessagePayload.parse( + response[message.HeaderLength :] + ) + for line in 
payload_message.Payload.splitlines(): + yield line + + if "exit_code" in message_type: + payload_message = AgentMessagePayload.parse( + response[message.HeaderLength :] + ) + self._exit_code = int(payload_message.Payload.strip()) + + finally: + self.close() + + def get_exit_code(self) -> Optional[int]: + return self._exit_code + + def close(self): + if self._connection is not None: + try: + self._connection.close() + except Exception: + pass + self._connection = None + + def __del__(self): + self.close() + + +class AWSContainer(BaseContainer): + def __init__( + self, + task_arn: str, + container_name: str, + task_definition_arn: str, + ecs_config: ECSConfig, + ): + self._task_arn = task_arn + self._container_name = container_name + self._task_definition_arn = task_definition_arn + self._ecs_config = ecs_config + self._ecs_client = get_boto3_client("ecs", self._ecs_config.region) + self._logs_token = None + self._session = None + + def _describe_task(self) -> dict: + response = self._ecs_client.describe_tasks( + cluster=self._ecs_config.cluster, tasks=[self._task_arn] + ) + if not response["tasks"]: + raise RuntimeError(f"Task {self._task_arn} not found") + return response["tasks"][0] + + def _get_status(self) -> str: + """Returns ECS last status: PROVISIONING, PENDING, RUNNING, DEPROVISIONING, STOPPED.""" + return self._describe_task()["lastStatus"] + + def stop(self, *, timeout: Optional[float] = None): + self._ecs_client.stop_task( + cluster=self._ecs_config.cluster, + task=self._task_arn, + reason="Stopped by AWSContainer.stop()", + ) + + def wait( + self, + *, + timeout: Optional[float] = None, + condition: Literal["not-running", "next-exit", "removed"] = None, + ) -> Dict: + start = time.time() + poll_interval = 1 + + while True: + task = self._describe_task() + status = task["lastStatus"] + + is_done = ( + status == "STOPPED" + if condition in (None, "not-running", "next-exit", "removed") + else False + ) + + if is_done: + containers = task.get("containers", 
[]) + exit_code = containers[0].get("exitCode", 0) if containers else 0 + return {"StatusCode": exit_code} + + if timeout is not None and (time.time() - start) > timeout: + raise TimeoutError( + f"Task {self._task_arn} did not stop within {timeout}s" + ) + + time.sleep(poll_interval) + + def exec(self, command) -> AWSContainerExec: + exec_id = str(uuid.uuid4()).replace("-", "")[:8] + pid_file = f"/tmp/{exec_id}.pid" + # shlex.quote handles all special characters safely + inner = f"{command} & echo $! > {pid_file}" + wrapped = f"bash -c {shlex.quote(inner)}" + exec_resp = self._ecs_client.execute_command( + cluster=self._ecs_config.cluster, + task=self._task_arn, + container=self._container_name, + command=wrapped, + interactive=True, + ) + return AWSContainerExec(session=exec_resp["session"], exec_id=exec_id) + + def exec_kill(self, exec_id: str): + self._ecs_client.execute_command( + cluster=self._ecs_config.cluster, + task=self._task_arn, + container=self._container_name, + command=f'bash -c "kill $(cat /tmp/{exec_id}.pid)"', + interactive=False, + ) + + def remove(self, *, v: bool = False, link: bool = False, force: bool = False): + if force: + try: + self.stop() + except Exception: + pass + + def is_running(self) -> bool: + try: + return self._get_status() == "RUNNING" + except Exception: + return False + + def is_alive(self) -> bool: + try: + return self._get_status() not in ("STOPPED", "DEPROVISIONING") + except Exception: + return False + + def stream_container_logs(self) -> Generator[str, None, None]: + yield from stream_task_logs( + self._ecs_client, + self._ecs_config.region, + self._ecs_config.cluster, + self._task_arn, + self._task_definition_arn, + self._container_name, + self._logs_token, + ) + + def get_exit_code(self) -> Optional[int]: + task = self._describe_task() + if task["lastStatus"] != "STOPPED": + return None + containers = task.get("containers", []) + if not containers: + return None + return containers[0].get("exitCode") + + +class 
AWSContainerRunner(BaseContainerRunner): + def __init__(self): + aws_config_path = os.environ.get( + "AWS_CONFIG_PATH", Path(__file__).parent / "aws_config.json" + ) + with open(aws_config_path, "r") as f: + aws_config = json.load(f) + self.ecs_config = ECSConfig( + region=aws_config.get("region", "us-east-1"), + cluster=aws_config["cluster"], + capacity_provider=aws_config["capacity_provider"], + task_definition=aws_config["task_definition"], + ecr_host=aws_config["ecr_host"], + mirroring_image_task_definition=aws_config[ + "mirroring_image_task_definition" + ], + ) + + def prepare_image(self, image): + mirror_image_to_ecr(image, self.ecs_config) + + def spawn_container( + self, + image, + *, + runtime: str = None, # not used + entrypoint: List = None, + detach: bool = True, # not used + name: str = None, # not used + remove: bool = False, # not used + volumes: Dict = None, # not used + environment: Dict = None, + labels: Dict = None, + shm_size: int = None, + stdin_open: bool = False, # not used + tty: bool = False, # not used + cpu_limit: int = None, + mem_limit: Union[str, int] = None, + memswap_limit: int = None, # not used + network: str = None, # not used + ipc_mode: str = None, + security_opt: List[str] = None, # not used + ) -> AWSContainer: + memory = parse_mem_limit_to_bytes(mem_limit) + shm_size = parse_mem_limit_to_bytes(shm_size) + if ipc_mode == "": + ipc_mode = None + task_arn, container_name, task_definition_arn = run_container_ec2( + docker_image_name=image, + entrypoint=entrypoint, + command=None, + ecs_config=self.ecs_config, + env_vars=environment, + cpu=cpu_limit, + memory=memory, + tags=[{"key": k, "value": v} for k, v in (labels or {}).items()], + shm_size=shm_size, + ipc_mode=ipc_mode, + ) + return AWSContainer( + task_arn, container_name, task_definition_arn, self.ecs_config + ) diff --git a/agent/worker/container_runner/aws_utils.py b/agent/worker/container_runner/aws_utils.py new file mode 100644 index 0000000..d1dac74 --- /dev/null +++ 
b/agent/worker/container_runner/aws_utils.py @@ -0,0 +1,560 @@ +import boto3 +import os +import time +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union + + +@dataclass +class ECSConfig: + cluster: str + capacity_provider: str + task_definition: str + ecr_host: str + mirroring_image_task_definition: str + region: str = "us-east-1" + + +def get_boto3_client(service: str, region: str): + """Helper to create boto3 clients with consistent credentials.""" + return boto3.client( + service, + region_name=region, + aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], + aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], + ) + + +def get_default_network_config(ec2_client, assign_public_ip: bool = True) -> dict: + vpcs = ec2_client.describe_vpcs(Filters=[{"Name": "isDefault", "Values": ["true"]}]) + vpc_id = vpcs["Vpcs"][0]["VpcId"] + + subnets = ec2_client.describe_subnets( + Filters=[{"Name": "vpc-id", "Values": [vpc_id]}] + ) + subnet_ids = [s["SubnetId"] for s in subnets["Subnets"]] + + sgs = ec2_client.describe_security_groups( + Filters=[ + {"Name": "vpc-id", "Values": [vpc_id]}, + {"Name": "group-name", "Values": ["default"]}, + ] + ) + sg_id = sgs["SecurityGroups"][0]["GroupId"] + + config = { + "awsvpcConfiguration": { + "subnets": subnet_ids, + "securityGroups": [sg_id], + } + } + + if assign_public_ip: + config["awsvpcConfiguration"]["assignPublicIp"] = "ENABLED" + + return config + + +def _parse_image_to_ecr_path( + docker_image_name: str, ecr_host: str +) -> tuple[str, str, str]: + """ + Parse a docker image name into ECR components. + + supervisely/base-py-sdk-light:6.73.527 + -> repository_name: supervisely/base-py-sdk-light + -> image_tag: 6.73.527 + -> target_image: {ecr_host}/supervisely/base-py-sdk-light:6.73.527 + + Returns (repository_name, image_tag, target_image) + """ + # Strip any existing registry prefix (anything before the first slash that contains a dot or colon) + parts = docker_image_name.split("/") + if "." 
in parts[0] or ":" in parts[0]: + image_path = "/".join(parts[1:]) + else: + image_path = docker_image_name + + if ":" in image_path: + repository_name, image_tag = image_path.rsplit(":", 1) + else: + repository_name = image_path + image_tag = "latest" + + target_image = f"{ecr_host}/{repository_name}:{image_tag}" + return repository_name, image_tag, target_image + + +def _ensure_ecr_repository(ecr_client, repository_name: str): + """Create ECR repository if it doesn't exist. Handles nested names like org/repo.""" + try: + ecr_client.create_repository(repositoryName=repository_name) + print(f"Created ECR repository: {repository_name}") + except ecr_client.exceptions.RepositoryAlreadyExistsException: + pass + + +def _image_exists_in_ecr(ecr_client, repository_name: str, image_tag: str) -> bool: + try: + ecr_client.describe_images( + repositoryName=repository_name, imageIds=[{"imageTag": image_tag}] + ) + return True + except ecr_client.exceptions.ImageNotFoundException: + return False + except ecr_client.exceptions.RepositoryNotFoundException: + return False + + +def _create_task_definition_revision( + ecs_client, + base_task_definition: str, + new_image: str = None, + entrypoint: list[str] = None, + cpu: int = None, + memory: int = None, + gpu: int = None, + shm_size: int = None, + ipc_mode: str = None, +) -> tuple[str, str]: + task_def_response = ecs_client.describe_task_definition( + taskDefinition=base_task_definition + ) + task_def = task_def_response["taskDefinition"] + + container_name = task_def["containerDefinitions"][0]["name"] + + container_definitions = [] + for container in task_def["containerDefinitions"]: + container_copy = container.copy() + if new_image is not None: + container_copy["image"] = new_image + if entrypoint is not None: + container_copy["entryPoint"] = entrypoint + if cpu is not None: + container_copy["cpu"] = cpu + if memory is not None: + container_copy["memory"] = memory + if gpu is not None: + container_copy["resourceRequirements"] = 
[ + {"type": "GPU", "value": str(gpu)} + ] + if shm_size is not None: + linux_params = container_copy.get("linuxParameters", {}) + linux_params["sharedMemorySize"] = shm_size + container_copy["linuxParameters"] = linux_params + container_definitions.append(container_copy) + + register_params = { + "family": task_def["family"], + "containerDefinitions": container_definitions, + } + + if ipc_mode is not None: + register_params["ipcMode"] = ipc_mode + + optional_fields = [ + "taskRoleArn", + "executionRoleArn", + "networkMode", + "volumes", + "placementConstraints", + "requiresCompatibilities", + "cpu", + "memory", + "tags", + "pidMode", + "ipcMode", + "proxyConfiguration", + "inferenceAccelerators", + "ephemeralStorage", + "runtimePlatform", + ] + for field in optional_fields: + if field in task_def and field not in register_params: + register_params[field] = task_def[field] + + new_task_def = ecs_client.register_task_definition(**register_params) + arn = new_task_def["taskDefinition"]["taskDefinitionArn"] + print(f"Registered task definition revision: {arn}") + return arn, container_name + + +def _get_task_log_config( + ecs_client, task_definition_arn: str, container_name: str +) -> dict | None: + """ + Extract CloudWatch log configuration for a container from a task definition. + Returns dict with {log_group, log_stream_prefix, region} or None if not configured. 
+ """ + task_def = ecs_client.describe_task_definition(taskDefinition=task_definition_arn) + for container in task_def["taskDefinition"]["containerDefinitions"]: + if container["name"] == container_name: + log_config = container.get("logConfiguration", {}) + if log_config.get("logDriver") == "awslogs": + options = log_config.get("options", {}) + return { + "log_group": options.get("awslogs-group"), + "log_stream_prefix": options.get("awslogs-stream-prefix", "ecs"), + "region": options.get("awslogs-region"), + } + return None + + +def _stream_task_logs( + logs_client, log_group: str, log_stream: str, next_token: str = None +) -> str | None: + """Stream new log events from a CloudWatch log stream since the last token.""" + while True: + kwargs = { + "logGroupName": log_group, + "logStreamName": log_stream, + "startFromHead": True, + } + if next_token: + kwargs["nextToken"] = next_token + del kwargs["startFromHead"] # mutually exclusive with nextToken + + try: + response = logs_client.get_log_events(**kwargs) + except logs_client.exceptions.ResourceNotFoundException: + break + + events = response.get("events", []) + for event in events: + print(event["message"]) + + new_token = response.get("nextForwardToken") + if new_token == next_token or not events: + return new_token + next_token = new_token + + +def _wait_for_task_and_logs( + ecs_client, + region: str, + cluster: str, + task_arn: str, + task_definition_arn: str, + container_name: str, + poll_interval: int = 1, +): + """Wait for ECS task to finish, streaming CloudWatch logs when available.""" + logs_client = get_boto3_client("logs", region) + log_config = _get_task_log_config(ecs_client, task_definition_arn, container_name) + + task_id = task_arn.split("/")[-1] + next_log_token = None + + if log_config: + log_stream = f"{log_config['log_stream_prefix']}/{container_name}/{task_id}" + print(f"Streaming logs from {log_config['log_group']}/{log_stream}") + + while True: + response = 
ecs_client.describe_tasks(cluster=cluster, tasks=[task_arn]) + task = response["tasks"][0] + status = task["lastStatus"] + print(f"Task status: {status}") + + if log_config: + next_log_token = _stream_task_logs( + logs_client, + log_config["log_group"], + log_stream, + next_token=next_log_token, + ) + + if status == "STOPPED": + stopped_reason = task.get("stoppedReason", "") + if stopped_reason: + print(f"Task stopped reason: {stopped_reason}") + + containers = task.get("containers", []) + for container in containers: + if container.get("exitCode", 0) != 0: + raise RuntimeError( + f"Task {task_arn} failed: container '{container['name']}' " + f"exited with code {container['exitCode']}. " + f"Reason: {container.get('reason', 'unknown')}" + ) + + if stopped_reason and not containers: + raise RuntimeError( + f"Task {task_arn} failed before container started: {stopped_reason}" + ) + + return + + time.sleep(poll_interval) + + +def _collect_log_lines( + logs_client, log_group: str, log_stream: str, next_token: str = None +) -> tuple[str | None, list[str]]: + """Fetch new log lines since last token. 
Returns (next_token, lines).""" + lines = [] + while True: + kwargs = { + "logGroupName": log_group, + "logStreamName": log_stream, + "startFromHead": True, + } + if next_token: + kwargs["nextToken"] = next_token + del kwargs["startFromHead"] + + try: + response = logs_client.get_log_events(**kwargs) + except logs_client.exceptions.ResourceNotFoundException: + return next_token, lines + + events = response.get("events", []) + lines.extend(event["message"] for event in events) + + new_token = response.get("nextForwardToken") + if new_token == next_token or not events: + return new_token, lines + next_token = new_token + + +def stream_task_logs( + ecs_client, + region: str, + cluster: str, + task_arn: str, + task_definition_arn: str, + container_name: str, + poll_interval: int = 1, + next_log_token: str = None, +): + """Yield log lines from a running ECS task until it stops.""" + logs_client = get_boto3_client("logs", region) + log_config = _get_task_log_config(ecs_client, task_definition_arn, container_name) + + if not log_config: + raise RuntimeError("No CloudWatch log configuration found for container") + + task_id = task_arn.split("/")[-1] + log_stream = f"{log_config['log_stream_prefix']}/{container_name}/{task_id}" + + while True: + status = ecs_client.describe_tasks(cluster=cluster, tasks=[task_arn])["tasks"][ + 0 + ]["lastStatus"] + + next_log_token, lines = _collect_log_lines( + logs_client, log_config["log_group"], log_stream, next_log_token + ) + yield from lines + + if status == "STOPPED": + return + + time.sleep(poll_interval) + + +def run_container_fargate( + ecs_config: ECSConfig, + docker_image_name: str = None, + entrypoint: Union[str, List[str]] = None, + command: Union[str, List[str]] = None, + env_vars: dict = None, + cpu: int = None, + memory: int = None, + tags: List[Dict] = None, + wait: bool = False, +) -> str: + """Run a container on Fargate (FARGATE launch type).""" + if isinstance(command, str): + command = command.split() if command else 
[] + if isinstance(entrypoint, str): + entrypoint = entrypoint.split() if entrypoint else [] + + ecs_client = get_boto3_client("ecs", ecs_config.region) + ec2_client = get_boto3_client("ec2", ecs_config.region) + + task_definition_arn, container_name = _create_task_definition_revision( + ecs_client, + ecs_config.mirroring_image_task_definition, + new_image=docker_image_name, + entrypoint=entrypoint, + cpu=cpu, + memory=memory, + ) + + container_overrides = { + "name": container_name, + "command": command, + "environment": [{"name": k, "value": v} for k, v in (env_vars or {}).items()], + } + container_overrides = { + k: v for k, v in container_overrides.items() if v + } # Remove empty fields + + response = ecs_client.run_task( + cluster=ecs_config.cluster, + taskDefinition=task_definition_arn, + launchType="FARGATE", + networkConfiguration=get_default_network_config( + ec2_client, assign_public_ip=True + ), + overrides={"containerOverrides": [container_overrides]}, + tags=tags or [], + ) + + if not response["tasks"]: + failures = response.get("failures", []) + raise RuntimeError(f"Failed to start Fargate task: {failures}") + + task_arn = response["tasks"][0]["taskArn"] + print(f"Started Fargate task: {task_arn}") + + if wait: + _wait_for_task_and_logs( + ecs_client, + ecs_config.region, + ecs_config.cluster, + task_arn, + task_definition_arn, + container_name, + ) + + return task_arn + + +def mirror_image_to_ecr( + docker_image_name: str, + ecs_config: ECSConfig, +) -> str: + """ + Ensure a Docker image is mirrored to ECR. + Returns the ECR image URI. 
+ """ + ecr_client = get_boto3_client("ecr", ecs_config.region) + repository_name, image_tag, target_image = _parse_image_to_ecr_path( + docker_image_name, ecs_config.ecr_host + ) + + print(f"Checking ECR for {target_image}...") + _ensure_ecr_repository(ecr_client, repository_name) + + if _image_exists_in_ecr(ecr_client, repository_name, image_tag): + print(f"Image already exists in ECR: {target_image}") + return target_image + + print( + f"Image not found in ECR. Launching mirroring task for {docker_image_name}..." + ) + run_container_fargate( + ecs_config=ecs_config, + docker_image_name=None, + entrypoint=None, + command=None, + env_vars={ + "AWS_REGION": ecs_config.region, + "ECR_HOST": ecs_config.ecr_host, + "SOURCE_IMAGE": docker_image_name, + "TARGET_IMAGE": target_image, + "AWS_ACCESS_KEY_ID": os.environ["AWS_ACCESS_KEY_ID"], + "AWS_SECRET_ACCESS_KEY": os.environ["AWS_SECRET_ACCESS_KEY"], + }, + wait=True, + ) + + print(f"Mirroring complete: {target_image}") + return target_image + + +def run_container_ec2( + docker_image_name: str, + entrypoint: Union[str, List[str]], + command: Union[str, List[str]], + ecs_config: ECSConfig, + env_vars: dict = None, + cpu: int = None, + memory: int = None, + gpu: int = None, + shm_size: int = None, + ipc_mode: str = None, + tags: List[Dict] = None, + wait: bool = False, +) -> Tuple[str, str, str]: + """Run a container using the EC2 capacity provider.""" + if isinstance(command, str): + command = [command] + if isinstance(entrypoint, str): + entrypoint = entrypoint.split() if entrypoint else [] + + ecs_client = get_boto3_client("ecs", ecs_config.region) + ec2_client = get_boto3_client("ec2", ecs_config.region) + + task_definition_arn, container_name = _create_task_definition_revision( + ecs_client, + ecs_config.task_definition, + docker_image_name, + entrypoint=entrypoint, + cpu=cpu, + memory=memory, + gpu=gpu, + shm_size=shm_size, + ipc_mode=ipc_mode, + ) + + container_overrides = { + "name": container_name, + "command": 
command, + "environment": [{"name": k, "value": v} for k, v in (env_vars or {}).items()], + } + container_overrides = {k: v for k, v in container_overrides.items() if v} + + response = ecs_client.run_task( + cluster=ecs_config.cluster, + enableExecuteCommand=True, + taskDefinition=task_definition_arn, + capacityProviderStrategy=[ + {"capacityProvider": ecs_config.capacity_provider, "weight": 1} + ], + networkConfiguration=get_default_network_config( + ec2_client, assign_public_ip=False + ), + overrides={"containerOverrides": [container_overrides]}, + tags=tags or [], + ) + + if not response["tasks"]: + failures = response.get("failures", []) + raise RuntimeError(f"Failed to start EC2 task: {failures}") + + task_arn = response["tasks"][0]["taskArn"] + print(f"Started EC2 task: {task_arn}") + + if wait: + _wait_for_task_and_logs( + ecs_client, + ecs_config.region, + ecs_config.cluster, + task_arn, + task_definition_arn, + container_name, + ) + + return task_arn, container_name, task_definition_arn + + +def run( + docker_image_name: str, + entrypoint: str, + command: Union[str, List[str]], + ecs_config: ECSConfig, + env_vars: dict = None, +) -> str: + """Full pipeline: mirror image to ECR, create task def revision, run and wait.""" + ecr_image = mirror_image_to_ecr(docker_image_name, ecs_config) + return run_container_ec2( + docker_image_name=ecr_image, + entrypoint=entrypoint, + command=command, + ecs_config=ecs_config, + env_vars=env_vars, + wait=True, + ) diff --git a/agent/worker/container_runner/container_runner.py b/agent/worker/container_runner/container_runner.py new file mode 100644 index 0000000..4c1c9d2 --- /dev/null +++ b/agent/worker/container_runner/container_runner.py @@ -0,0 +1,82 @@ +from abc import ABC, abstractmethod +from typing import Dict, Generator, List, Literal, Optional + + +class BaseContainerExec(ABC): + @abstractmethod + def stream_logs(self) -> Generator[str, None, None]: + raise NotImplementedError() + + @abstractmethod + def 
get_exit_code(self) -> Optional[int]: + raise NotImplementedError() + + +class BaseContainer(ABC): + @abstractmethod + def stop(self, *, timeout: Optional[float] = None): + raise NotImplementedError() + + @abstractmethod + def wait( + self, + *, + timeout: Optional[float] = None, + condition: Literal["not-running", "next-exit", "removed"] = None, + ) -> Dict: + raise NotImplementedError() + + @abstractmethod + def remove(self, *, v: bool = False, link: bool = False, force: bool = False): + raise NotImplementedError() + + @abstractmethod + def is_running(self) -> bool: + raise NotImplementedError() + + @abstractmethod + def is_alive(self) -> bool: + raise NotImplementedError() + + @abstractmethod + def exec(self, command) -> BaseContainerExec: + raise NotImplementedError() + + @abstractmethod + def exec_kill(self, exec_id: str): + raise NotImplementedError() + + @abstractmethod + def get_exit_code(self) -> Optional[int]: + raise NotImplementedError() + + +class BaseContainerRunner(ABC): + @abstractmethod + def prepare_image(self, image: str): + raise NotImplementedError() + + @abstractmethod + def spawn_container( + self, + image, + *, + runtime: str = None, + entrypoint: List = None, + detach: bool = True, + name: str = None, + remove: bool = False, + volumes: Dict = None, + environment: Dict = None, + labels: Dict = None, + shm_size: int = None, + stdin_open: bool = False, + tty: bool = False, + cpu_limit: int = None, + mem_limit: int = None, + memswap_limit: int = None, + network: str = None, + ipc_mode: str = None, + security_opt: List[str] = None, + ) -> BaseContainer: + raise NotImplementedError() diff --git a/agent/worker/container_runner/local.py b/agent/worker/container_runner/local.py new file mode 100644 index 0000000..ac64bca --- /dev/null +++ b/agent/worker/container_runner/local.py @@ -0,0 +1,182 @@ +from logging import Logger +from typing import Callable, Dict, Generator, List, Literal, Optional +import docker +from worker import constants, 
docker_utils +from worker.task_dockerized import ErrorReport +from worker.container_runner.container_runner import ( + BaseContainerExec, + BaseContainerRunner, + BaseContainer, +) +from docker.models.containers import Container + +from worker.agent_utils import ( + TaskDirCleaner, + filter_log_line, + pip_req_satisfied_filter, + post_get_request_filter, + convert_millicores_to_cpu_quota, +) + +import supervisely as sly + +from docker.errors import APIError, NotFound, DockerException + + +class LocalContainerExec(BaseContainerExec): + def __init__(self, docker_client: docker.DockerClient, exec_id: str): + self._docker_client = docker_client + self._exec_id = exec_id + + def stream_logs(self) -> Generator[str, None, None]: + for log_line in self._docker_client.api.exec_start(self._exec_id, stream=True): + yield log_line.decode("utf-8") + + def get_exit_code(self) -> int: + exec_info = self._docker_client.api.exec_inspect(self._exec_id) + exit_code = exec_info["ExitCode"] + return exit_code + + +class LocalContainer(BaseContainer): + def __init__(self, container: Container, docker_client: docker.DockerClient): + self._container = container + self._docker_client = docker_client + + def stop(self, *, timeout: Optional[float] = None): + self._container.stop(timeout=timeout) + + def wait( + self, + *, + timeout: Optional[float] = None, + condition: Literal["not-running", "next-exit", "removed"] = None, + ) -> Dict: + result = self._container.wait(timeout=timeout, condition=condition) + return result + + def remove(self, *, v: bool = False, link: bool = False, force: bool = False): + return self._container.remove(v=v, link=link, force=force) + + def is_running(self) -> bool: + if self._container is None: + return False + try: + self._container.reload() + return self._container.status == "running" + except NotFound: + return False + + def is_alive(self): + return self.is_running() + + def exec(self, command) -> LocalContainerExec: + exec_id = 
self._docker_client.api.exec_create( + self._container.id, + cmd=command, + ) + return LocalContainerExec(self._docker_client, exec_id) + + def exec_kill(self, exec_id: str): + exec_info = self._docker_client.api.exec_inspect(exec_id) + if exec_info["Running"] == True: + pid = exec_info["Pid"] + self._container.exec_run(cmd="kill {}".format(pid)) + else: + return + + def get_exit_code(self): + self._container.reload() + return self._container.attrs["State"]["ExitCode"] + + +class LocalContainerRunner(BaseContainerRunner): + def __init__(self, docker_client: docker.DockerClient, logger: Logger): + self.docker_client = docker_client + self.logger = logger + + self._container: LocalContainer = None + + def prepare_image(self, image): + docker_utils.docker_pull_if_needed( + self.docker_client, + image, + constants.PULL_POLICY(), + self.logger, + ) + + # self.sync_pip_cache() + + def spawn_container( + self, + image, + *, + runtime: str = None, + entrypoint: List = None, + detach: bool = True, + name: str = None, + remove: bool = False, + volumes: Dict = None, + environment: Dict = None, + labels: Dict = None, + shm_size: int = None, + stdin_open: bool = False, + tty: bool = False, + cpu_limit: int = None, + mem_limit: int = None, + memswap_limit: int = None, + network: str = None, + ipc_mode: str = None, + security_opt: List[str] = None, + ) -> LocalContainer: + print("Spawning container with image: {}".format(image)) + if cpu_limit is None: + cpu_limit = constants.CPU_LIMIT() + if cpu_limit is not None: + cpu_quota = convert_millicores_to_cpu_quota(cpu_limit) + else: + cpu_quota = None + container = self.docker_client.containers.run( + image, + runtime=runtime, + entrypoint=entrypoint, + detach=detach, + name=name, + remove=remove, # TODO: check autoremove + volumes=volumes, + environment=environment, + labels=labels, + shm_size=shm_size, + stdin_open=stdin_open, + tty=tty, + cpu_quota=cpu_quota, + mem_limit=mem_limit, + memswap_limit=memswap_limit, + network=network, 
+ ipc_mode=ipc_mode, + security_opt=security_opt, + ) + container.reload() + self._container = LocalContainer(container, self.docker_client) + self.logger.debug( + "After spawning. Container status: {}".format(str(container.status)) + ) + self.logger.info( + "Docker container is spawned", + extra={ + "container_id": container.id, + "container_name": container.name, + }, + ) + return self._container + + def exec(self, command, environment=None): + self._exec_id = self.docker_client.api.exec_create( + self._container._container.id, + cmd=command, + environment=environment, + ) + + def stream_logs(self) -> Generator[str, None, None]: + for log_line in self._container._container.logs(stream=True): + yield log_line.decode("utf-8") diff --git a/agent/worker/task_app.py b/agent/worker/task_app.py index 0f2a6b1..4295498 100644 --- a/agent/worker/task_app.py +++ b/agent/worker/task_app.py @@ -12,6 +12,9 @@ import copy from docker.errors import APIError, NotFound, DockerException +from worker.container_runner.aws_runner import AWSContainerRunner +from worker.container_runner.container_runner import BaseContainerExec +from worker.container_runner.local import LocalContainer, LocalContainerRunner from slugify import slugify from pathlib import Path from packaging import version @@ -101,7 +104,7 @@ def __init__(self, *args, **kwargs): self.dir_task_src_container = None self.dir_apps_cache_host = None self.dir_apps_cache_container = None - self._exec_id = None + self.exec: BaseContainerExec = None self.app_info = None self._path_cache_host = None self._need_sync_pip_cache = False @@ -466,6 +469,9 @@ def get_requirements_path(self): return requirements_path def sync_pip_cache(self): + if not isinstance(self._container, LocalContainer): + return + version = self.app_info.get("version", "master") module_id = self.app_info.get("moduleId") @@ -495,10 +501,10 @@ def sync_pip_cache(self): "app_session_id": str(self.info["task_id"]), }, ) - 
self.install_pip_requirements(container_id=self._container.id) + self.install_pip_requirements() # @TODO: handle 404 not found - bits, stat = self._container.get_archive(_LINUX_DEFAULT_PIP_CACHE_DIR) + bits, stat = self._container._container.get_archive(_LINUX_DEFAULT_PIP_CACHE_DIR) self.logger.info( "Download initial pip cache from dockerimage", extra={ @@ -525,14 +531,15 @@ def sync_pip_cache(self): @handle_exceptions def find_or_run_container(self): add_labels = {"sly_app": "1", "app_session_id": str(self.info["task_id"])} - docker_utils.docker_pull_if_needed( - self._docker_api, - self.docker_image_name, - constants.PULL_POLICY(), - self.logger, - ) + if os.environ.get("IS_AWS", "false").lower() == "true": + self.logger.info("AWS environment detected, using AWSContainerRunner") + self.container_runner = AWSContainerRunner() + else: + self.logger.info("Using LocalContainerRunner") + self.container_runner = LocalContainerRunner(self._docker_api, self.logger) + + self.container_runner.prepare_image(self.docker_image_name) - self.sync_pip_cache() if self._container is None: try: self.spawn_container(add_envs=self.main_step_envs(), add_labels=add_labels) @@ -543,7 +550,7 @@ def find_or_run_container(self): orig_runtime = self.docker_runtime if ( - is_runtime_err + is_runtime_err and (self.docker_runtime == "nvidia") and (self._gpu_config is GPUFlag.preferred) ): @@ -580,9 +587,9 @@ def find_or_run_container(self): ) raise api_ex - if constants.OFFLINE_MODE() is False: + if constants.OFFLINE_MODE() is False and isinstance(self._container, LocalContainer): self.logger.info("Double check pip cache for old agents") - self.install_pip_requirements(container_id=self._container.id) + self.install_pip_requirements() self.logger.info("pip second install for old agents is finished") def get_spawn_entrypoint(self): @@ -595,42 +602,11 @@ def get_spawn_entrypoint(self): entrypoint = ["/usr/bin/timeout", "--kill-after", "30s", f"{timeout}s"] + entrypoint return entrypoint - def 
_exec_command(self, command, add_envs=None, container_id=None): - SINGLE_ENV_VAR_LIMIT = 65536 - add_envs = sly.take_with_default(add_envs, {}) - all_envs = { - "LOG_LEVEL": "DEBUG", - "LANG": "C.UTF-8", - "PYTHONUNBUFFERED": "1", - constants._HTTP_PROXY: constants.HTTP_PROXY(), - constants._HTTPS_PROXY: constants.HTTPS_PROXY(), - constants._NO_PROXY: constants.NO_PROXY(), - "HOST_TASK_DIR": self.dir_task_host, - "TASK_ID": self.info["task_id"], - "SERVER_ADDRESS": self.info["server_address"], - "API_TOKEN": self.info["api_token"], - "AGENT_TOKEN": constants.TOKEN(), - "PIP_ROOT_USER_ACTION": "ignore", - **add_envs, - } - oversized_envs = {} - for key, value in all_envs.items(): - value_bytes = len(str(value).encode("utf-8")) - if value_bytes > SINGLE_ENV_VAR_LIMIT: - oversized_envs[key] = value_bytes - if oversized_envs: - self.logger.warning("Oversized environment variables found!", extra={"envs": oversized_envs}) - for key in oversized_envs.keys(): - all_envs[key] = sly.LARGE_ENV_PLACEHOLDER - self._exec_id = self._docker_api.api.exec_create( - self._container.id if container_id is None else container_id, - cmd=command, - environment=all_envs, - ) - self._logs_output = self._docker_api.api.exec_start(self._exec_id, stream=True) + def _exec_command(self, command): + self.exec = self._container.exec(command=command) + self._logs_output = self.exec.stream_logs() - def exec_command(self, add_envs=None, command=None): - add_envs = sly.take_with_default(add_envs, {}) + def exec_command(self, command=None): main_script_path = os.path.join( self.dir_task_src_container, self.app_config.get("main_script", "src/main.py"), @@ -644,14 +620,14 @@ def exec_command(self, add_envs=None, command=None): f'bash -c "cd {self.dir_task_src_container} && {self.app_config["entrypoint"]}"' ) self.logger.info("command to run", extra={"command": command}) - self._exec_command(command, add_envs) + self._exec_command(command) # change pulling progress to app progress progress_dummy = 
sly.Progress("Application is started ...", 1, ext_logger=self.logger) progress_dummy.iter_done_report() self.logger.info("command is running", extra={"command": command}) - def install_pip_requirements(self, container_id=None): + def install_pip_requirements(self): if self._need_sync_pip_cache is True: self.logger.info("Installing app requirements") progress_dummy = sly.Progress( @@ -661,7 +637,7 @@ def install_pip_requirements(self, container_id=None): command = "pip3 install --disable-pip-version-check --upgrade setuptools==69.0.0" self.logger.info(f"PIP command: {command}") - self._exec_command(command, add_envs=self.main_step_envs(), container_id=container_id) + self._exec_command(command) self.process_logs() # --root-user-action=ignore @@ -669,10 +645,10 @@ def install_pip_requirements(self, container_id=None): self.dir_task_src_container, self._requirements_path_relative ) self.logger.info(f"PIP command: {command}") - self._exec_command(command, add_envs=self.main_step_envs(), container_id=container_id) + self._exec_command(command) self.process_logs() - pip_install_exec_info = self._docker_api.api.exec_inspect(self._exec_id) + pip_install_exec_info = self.exec.get_exit_code() if pip_install_exec_info["ExitCode"] != 0: raise RuntimeError("Pip install failed") @@ -683,11 +659,7 @@ def is_container_alive(self): if self._container is None: return False - try: - self._container.reload() - return self._container.status == "running" - except NotFound: - return False + return self._container.is_alive() def main_step(self): api = Api(self.info["server_address"], self.info["api_token"]) @@ -707,7 +679,7 @@ def main_step(self): self.find_or_run_container() if self.is_container_alive(): - self.exec_command(add_envs=self.main_step_envs()) + self.exec_command() logs_cnt = self.process_logs() if logs_cnt == 0: @@ -843,7 +815,9 @@ def _decode(bytes: bytes): # @TODO: parse multiline logs correctly (including exceptions) for log_line_arr in self._logs_output: - for log_part 
in _decode(log_line_arr).splitlines(): + if not isinstance(log_line_arr, str): + log_line_arr = _decode(log_line_arr) + for log_part in log_line_arr.splitlines(): yield log_part def process_logs(self, logs_arr=None): @@ -883,17 +857,13 @@ def _stop_wait_container(self): return self.exec_stop() def exec_stop(self): - exec_info = self._docker_api.api.exec_inspect(self._exec_id) - if exec_info["Running"] == True: - pid = exec_info["Pid"] - self._container.exec_run(cmd="kill {}".format(pid)) - else: - return + self._container.exec_kill(self.exec.exec_id) def get_exit_status(self): - exec_info = self._docker_api.api.exec_inspect(self._exec_id) - exit_code = exec_info["ExitCode"] - return exit_code + if self.exec is not None: + return self.exec.get_exit_code() + if self._container is not None: + return self._container.get_exit_code() def _drop_container(self): if self.is_isolate(): @@ -902,7 +872,6 @@ def _drop_container(self): self.exec_stop() def drop_container_and_check_status(self): - self._container.reload() status = self.get_exit_status() if self.is_isolate(): diff --git a/agent/worker/task_dockerized.py b/agent/worker/task_dockerized.py index 58f6eea..53f2b4d 100644 --- a/agent/worker/task_dockerized.py +++ b/agent/worker/task_dockerized.py @@ -23,6 +23,7 @@ ) from worker.task_sly import TaskSly from docker.models.containers import Container +from worker.container_runner.container_runner import BaseContainerRunner, BaseContainer class TaskStep(Enum): NOTHING = 0 @@ -60,7 +61,8 @@ def __init__(self, *args, **kwargs): self._docker_api: docker.DockerClient = None # must be set by someone - self._container: Container = None + self.container_runner: BaseContainerRunner = None + self._container: BaseContainer = None self._container_lock = Lock() # to drop container from different threads self.docker_image_name = None @@ -214,6 +216,10 @@ def spawn_container(self, add_envs=None, add_labels=None, entrypoint_func=None): constants._HTTPS_PROXY.lower(): 
constants.HTTPS_PROXY(), constants._NO_PROXY.lower(): constants.NO_PROXY(), "PIP_ROOT_USER_ACTION": "ignore", + "TASK_ID": self.info["task_id"], + "SERVER_ADDRESS": self.info.get("server_address"), + "API_TOKEN": self.info.get("api_token"), + "AGENT_TOKEN": constants.TOKEN(), **add_envs, } if constants.SSL_CERT_FILE() is not None: @@ -228,12 +234,7 @@ def spawn_container(self, add_envs=None, add_labels=None, entrypoint_func=None): self.info["task_id"], constants.TASKS_DOCKER_LABEL() ) - cpu_quota = self.info.get("limits", {}).get("cpu", None) - if cpu_quota is None: - cpu_quota = constants.CPU_LIMIT() - if cpu_quota is not None: - cpu_quota = convert_millicores_to_cpu_quota(cpu_quota) - + cpu_limit = self.info.get("limits", {}).get("cpu", None) mem_limit = self.info.get("limits", {}).get("memory", None) if mem_limit is None: mem_limit = constants.MEM_LIMIT() @@ -248,7 +249,8 @@ def spawn_container(self, add_envs=None, add_labels=None, entrypoint_func=None): self.logger.warning("Oversized environment variables found. Such envs would be removed!", extra={"envs": oversized_envs}) for key in oversized_envs.keys(): all_environments[key] = sly.LARGE_ENV_PLACEHOLDER - self._container = self._docker_api.containers.run( + + self._container = self.container_runner.spawn_container( self.docker_image_name, runtime=self.docker_runtime, entrypoint=entrypoint_func(), @@ -266,24 +268,13 @@ def spawn_container(self, add_envs=None, add_labels=None, entrypoint_func=None): shm_size=constants.SHM_SIZE(), stdin_open=False, tty=False, - cpu_quota=cpu_quota, + cpu_limit=cpu_limit, mem_limit=mem_limit, memswap_limit=mem_limit, network=constants.DOCKER_NET(), ipc_mode=ipc_mode, security_opt=constants.SECURITY_OPT(), ) - self._container.reload() - self.logger.debug( - "After spawning. 
Container status: {}".format(str(self._container.status)) - ) - self.logger.info( - "Docker container is spawned", - extra={ - "container_id": self._container.id, - "container_name": self._container.name, - }, - ) finally: self._container_lock.release() @@ -342,9 +333,8 @@ def call_event_function(self, jlog): def process_logs(self): logs_found = False - for log_line in self._container.logs(stream=True): + for log_line in self.container_runner.stream_logs(): logs_found = True - log_line = log_line.decode("utf-8") msg, res_log, lvl = self.parse_log_line(log_line) output = self.call_event_function(res_log) self._process_report(msg) diff --git a/requirements.txt b/requirements.txt index c7e0778..6af2ab3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,8 @@ python-slugify==6.1.2 nvidia-ml-py==12.535.77 httpx>=0.26.0 filelock==3.13.1 +boto3>=1.42.47 +construct # grpcio installed from system packages (python3-grpcio) # grpcio-tools removed due to protobuf version conflict with supervisely[agent] From f1309d9ac11d6f223613811fd56d4dfca7466b25 Mon Sep 17 00:00:00 2001 From: Nikolai Petukhov Date: Mon, 2 Mar 2026 08:43:02 -0300 Subject: [PATCH 2/3] wip --- agent/worker/container_runner/aws_runner.py | 33 ++++++++++++++++++--- agent/worker/container_runner/aws_utils.py | 7 ----- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/agent/worker/container_runner/aws_runner.py b/agent/worker/container_runner/aws_runner.py index b81b4ad..11a25b0 100644 --- a/agent/worker/container_runner/aws_runner.py +++ b/agent/worker/container_runner/aws_runner.py @@ -178,12 +178,38 @@ def wait( time.sleep(poll_interval) + def wait_for_execute(self, timeout: float = 300, poll_interval: float = 2): + start = time.time() + while True: + task = self._describe_task() + status = task["lastStatus"] + + if status == "STOPPED": + stopped_reason = task.get("stoppedReason", "unknown") + raise RuntimeError(f"Task stopped before reaching RUNNING state: {stopped_reason}") + if 
time.time() - start > timeout: + raise TimeoutError(f"Task did not reach RUNNING state within {timeout}s (last status: {status})") + + if status == "RUNNING": + containers = task.get("containers", []) + for container in containers: + if container["name"] == self._container_name: + managed_agents = container.get("managedAgents", []) + for agent in managed_agents: + if agent["name"] == "ExecuteCommandAgent" and agent["lastStatus"] == "RUNNING": + return + + time.sleep(poll_interval) + def exec(self, command) -> AWSContainerExec: exec_id = str(uuid.uuid4()).replace("-", "")[:8] pid_file = f"/tmp/{exec_id}.pid" - # shlex.quote handles all special characters safely inner = f"{command} & echo $! > {pid_file}" wrapped = f"bash -c {shlex.quote(inner)}" + + print("Waiting for container to be ready for exec...") + self.wait_for_execute() + exec_resp = self._ecs_client.execute_command( cluster=self._ecs_config.cluster, task=self._task_arn, @@ -275,7 +301,7 @@ def spawn_container( volumes: Dict = None, # not used environment: Dict = None, labels: Dict = None, - shm_size: int = None, + shm_size: int = None, # not used stdin_open: bool = False, # not used tty: bool = False, # not used cpu_limit: int = None, @@ -286,9 +312,9 @@ def spawn_container( security_opt: List[str] = None, # not used ) -> AWSContainer: memory = parse_mem_limit_to_bytes(mem_limit) - shm_size = parse_mem_limit_to_bytes(shm_size) if ipc_mode == "": ipc_mode = None + environment = {k: str(v) for k, v in (environment or {}).items()} task_arn, container_name, task_definition_arn = run_container_ec2( docker_image_name=image, entrypoint=entrypoint, @@ -298,7 +324,6 @@ def spawn_container( cpu=cpu_limit, memory=memory, tags=[{"key": k, "value": v} for k, v in (labels or {}).items()], - shm_size=shm_size, ipc_mode=ipc_mode, ) return AWSContainer( diff --git a/agent/worker/container_runner/aws_utils.py b/agent/worker/container_runner/aws_utils.py index d1dac74..218ac95 100644 --- 
a/agent/worker/container_runner/aws_utils.py +++ b/agent/worker/container_runner/aws_utils.py @@ -114,7 +114,6 @@ def _create_task_definition_revision( cpu: int = None, memory: int = None, gpu: int = None, - shm_size: int = None, ipc_mode: str = None, ) -> tuple[str, str]: task_def_response = ecs_client.describe_task_definition( @@ -139,10 +138,6 @@ def _create_task_definition_revision( container_copy["resourceRequirements"] = [ {"type": "GPU", "value": str(gpu)} ] - if shm_size is not None: - linux_params = container_copy.get("linuxParameters", {}) - linux_params["sharedMemorySize"] = shm_size - container_copy["linuxParameters"] = linux_params container_definitions.append(container_copy) register_params = { @@ -474,7 +469,6 @@ def run_container_ec2( cpu: int = None, memory: int = None, gpu: int = None, - shm_size: int = None, ipc_mode: str = None, tags: List[Dict] = None, wait: bool = False, @@ -496,7 +490,6 @@ def run_container_ec2( cpu=cpu, memory=memory, gpu=gpu, - shm_size=shm_size, ipc_mode=ipc_mode, ) From 85577c7562e3269294596283152d47046fe2a072 Mon Sep 17 00:00:00 2001 From: Nikolai Petukhov Date: Mon, 2 Mar 2026 16:13:36 -0300 Subject: [PATCH 3/3] add readme, docstring and demo --- agent/worker/container_runner/aws_runner.py | 308 +++++++++++++++++-- agent/worker/container_runner/aws_utils.py | 319 ++++++++++++++++++-- agent/worker/container_runner/demo.py | 136 +++++++++ agent/worker/container_runner/local.py | 24 +- agent/worker/container_runner/readme.md | 63 ++++ 5 files changed, 790 insertions(+), 60 deletions(-) create mode 100644 agent/worker/container_runner/demo.py create mode 100644 agent/worker/container_runner/readme.md diff --git a/agent/worker/container_runner/aws_runner.py b/agent/worker/container_runner/aws_runner.py index 11a25b0..004fff1 100644 --- a/agent/worker/container_runner/aws_runner.py +++ b/agent/worker/container_runner/aws_runner.py @@ -1,31 +1,52 @@ import json -from pathlib import Path +import os import shlex import time 
-from worker.container_runner.container_runner import ( - BaseContainer, - BaseContainerRunner, -) -import os +import uuid +from pathlib import Path from typing import Dict, Generator, List, Literal, Optional, Union +import construct as c +import websocket from worker.container_runner.aws_utils import ( + ECSConfig, + get_boto3_client, mirror_image_to_ecr, run_container_ec2, - get_boto3_client, stream_task_logs, - ECSConfig, ) +from worker.container_runner.container_runner import BaseContainer, BaseContainerRunner -import json -import uuid +def parse_mem_limit_to_bytes(mem_limit) -> int: + """Convert a memory limit value to an integer number of bytes. -import construct as c -import websocket + Accepts the same formats used by Docker's ``--memory`` flag. + Args: + mem_limit: Memory limit as one of: -def parse_mem_limit_to_bytes(mem_limit) -> int: + - ``int`` or ``float`` — treated directly as bytes. + - ``str`` — a numeric string optionally suffixed with a unit: + ``b`` (bytes), ``k`` (kibibytes), ``m`` (mebibytes), or + ``g`` (gibibytes). Case-insensitive. An empty string returns + ``None``. + - ``None`` — returns ``None`` (no limit). + + Returns: + The memory limit in bytes as an ``int``, or ``None`` if no limit was + specified. + + Raises: + ValueError: If ``mem_limit`` is of an unsupported type. + + Examples:: + + parse_mem_limit_to_bytes("512m") # -> 536870912 + parse_mem_limit_to_bytes("2g") # -> 2147483648 + parse_mem_limit_to_bytes(1048576) # -> 1048576 + parse_mem_limit_to_bytes(None) # -> None + """ if isinstance(mem_limit, (int, float)): return int(mem_limit) @@ -42,7 +63,22 @@ def parse_mem_limit_to_bytes(mem_limit) -> int: return None raise ValueError(f"Unsupported mem_limit type: {type(mem_limit)}") + + class AWSContainerExec: + """A handle to a command execution session inside a running ECS container. + + Wraps the SSM WebSocket session returned by ``ecs.execute_command``. 
Use + :meth:`stream_logs` to iterate over output lines and :meth:`get_exit_code` + to retrieve the process exit code after the stream is exhausted. + + Args: + session: The ``session`` dict from the ECS ``execute_command`` response, + containing ``streamUrl`` and ``tokenValue``. + exec_id: Short identifier for this exec (used to locate the PID file + written by :meth:`AWSContainer.exec`). + """ + def __init__(self, session: dict, exec_id: str): self._session = session self._exec_id = exec_id @@ -50,6 +86,14 @@ def __init__(self, session: dict, exec_id: str): self._connection = self._init_connection() def _init_connection(self): + """Open the SSM WebSocket connection and send the authentication token. + + Returns: + An open ``websocket.WebSocket`` connection. + + Raises: + RuntimeError: If ``session`` is ``None``. + """ if self._session is None: raise RuntimeError("No active exec session found") @@ -63,6 +107,17 @@ def _init_connection(self): return connection def stream_logs(self) -> Generator[str, None, None]: + """Yield output lines from the remote command until it finishes. + + Parses the binary SSM agent framing format, extracts ``output_stream_data`` + frames, and splits their payloads into individual lines. Sets + ``_exit_code`` when an ``exit_code`` frame is received. Closes the + WebSocket connection when the stream ends or on any exception. + + Yields: + Individual output lines (strings) produced by the remote command, + in order. + """ AgentMessageHeader = c.Struct( "HeaderLength" / c.Int32ub, "MessageType" / c.PaddedString(32, "ascii"), @@ -98,9 +153,19 @@ def stream_logs(self) -> Generator[str, None, None]: self.close() def get_exit_code(self) -> Optional[int]: + """Return the exit code of the remote command, if available. + + Returns: + The integer exit code set when an ``exit_code`` frame is received, + or ``None`` if the stream has not yet delivered that frame. 
+ """ return self._exit_code def close(self): + """Close the WebSocket connection, ignoring any errors. + + Safe to call multiple times; subsequent calls are no-ops. + """ if self._connection is not None: try: self._connection.close() @@ -113,6 +178,20 @@ def __del__(self): class AWSContainer(BaseContainer): + """A handle to a running ECS task, implementing the ``BaseContainer`` interface. + + Provides lifecycle management (stop, wait, remove), command execution via + ECS Execute Command, log streaming, and status inspection for a single ECS + task. + + Args: + task_arn: ARN of the ECS task this object represents. + container_name: Name of the primary container within the task. + task_definition_arn: ARN of the task definition revision used to launch + the task (needed for log configuration lookup). + ecs_config: ECS/ECR configuration for the cluster and region. + """ + def __init__( self, task_arn: str, @@ -129,6 +208,14 @@ def __init__( self._session = None def _describe_task(self) -> dict: + """Fetch the current ECS task description from the API. + + Returns: + The task description dict as returned by ``ecs.describe_tasks``. + + Raises: + RuntimeError: If the task is not found in the cluster. + """ response = self._ecs_client.describe_tasks( cluster=self._ecs_config.cluster, tasks=[self._task_arn] ) @@ -137,10 +224,23 @@ def _describe_task(self) -> dict: return response["tasks"][0] def _get_status(self) -> str: - """Returns ECS last status: PROVISIONING, PENDING, RUNNING, DEPROVISIONING, STOPPED.""" + """Return the current ECS ``lastStatus`` of the task. + + Returns: + One of ``"PROVISIONING"``, ``"PENDING"``, ``"RUNNING"``, + ``"DEPROVISIONING"``, or ``"STOPPED"``. + """ return self._describe_task()["lastStatus"] def stop(self, *, timeout: Optional[float] = None): + """Send a stop request to the ECS task. + + The task transitions to ``STOPPED`` asynchronously. Use :meth:`wait` + to block until it has fully stopped. + + Args: + timeout: Unused. 
Present for interface compatibility. + """ self._ecs_client.stop_task( cluster=self._ecs_config.cluster, task=self._task_arn, @@ -153,6 +253,25 @@ def wait( timeout: Optional[float] = None, condition: Literal["not-running", "next-exit", "removed"] = None, ) -> Dict: + """Block until the ECS task reaches ``STOPPED`` status. + + Polls the task status every second until the task stops or the optional + timeout elapses. + + Args: + timeout: Maximum number of seconds to wait. If ``None``, waits + indefinitely. + condition: Accepted values are ``"not-running"``, ``"next-exit"``, + ``"removed"``, or ``None``; all are treated equivalently and + resolve when the task reaches ``STOPPED``. + + Returns: + A dict ``{"StatusCode": exit_code}`` where ``exit_code`` is the + exit code of the first container (``0`` if not available). + + Raises: + TimeoutError: If the task does not stop within ``timeout`` seconds. + """ start = time.time() poll_interval = 1 @@ -179,6 +298,21 @@ def wait( time.sleep(poll_interval) def wait_for_execute(self, timeout: float = 300, poll_interval: float = 2): + """Block until the ECS Execute Command agent is ready inside the container. + + Polls the task description until the ``ExecuteCommandAgent`` managed + agent reports ``RUNNING`` status for the target container. + + Args: + timeout: Maximum seconds to wait before raising. Defaults to + ``300``. + poll_interval: Seconds between status polls. Defaults to ``2``. + + Raises: + RuntimeError: If the task stops before the agent becomes ready. + TimeoutError: If the agent does not become ready within ``timeout`` + seconds. 
+ """ start = time.time() while True: task = self._describe_task() @@ -186,9 +320,13 @@ def wait_for_execute(self, timeout: float = 300, poll_interval: float = 2): if status == "STOPPED": stopped_reason = task.get("stoppedReason", "unknown") - raise RuntimeError(f"Task stopped before reaching RUNNING state: {stopped_reason}") + raise RuntimeError( + f"Task stopped before reaching RUNNING state: {stopped_reason}" + ) if time.time() - start > timeout: - raise TimeoutError(f"Task did not reach RUNNING state within {timeout}s (last status: {status})") + raise TimeoutError( + f"Task did not reach RUNNING state within {timeout}s (last status: {status})" + ) if status == "RUNNING": containers = task.get("containers", []) @@ -196,12 +334,28 @@ def wait_for_execute(self, timeout: float = 300, poll_interval: float = 2): if container["name"] == self._container_name: managed_agents = container.get("managedAgents", []) for agent in managed_agents: - if agent["name"] == "ExecuteCommandAgent" and agent["lastStatus"] == "RUNNING": + if ( + agent["name"] == "ExecuteCommandAgent" + and agent["lastStatus"] == "RUNNING" + ): return time.sleep(poll_interval) def exec(self, command) -> AWSContainerExec: + """Execute a shell command inside the running container. + + Wraps the command so it runs in the background and writes its PID to a + temp file (enabling later cancellation via :meth:`exec_kill`), then + opens an ECS Execute Command interactive session. + + Args: + command: Shell command string to run inside the container. + + Returns: + An :class:`AWSContainerExec` handle whose :meth:`~AWSContainerExec.stream_logs` + method yields the command's output. + """ exec_id = str(uuid.uuid4()).replace("-", "")[:8] pid_file = f"/tmp/{exec_id}.pid" inner = f"{command} & echo $! 
> {pid_file}" @@ -209,7 +363,7 @@ def exec(self, command) -> AWSContainerExec: print("Waiting for container to be ready for exec...") self.wait_for_execute() - + exec_resp = self._ecs_client.execute_command( cluster=self._ecs_config.cluster, task=self._task_arn, @@ -220,6 +374,16 @@ def exec(self, command) -> AWSContainerExec: return AWSContainerExec(session=exec_resp["session"], exec_id=exec_id) def exec_kill(self, exec_id: str): + """Kill a background command previously started by :meth:`exec`. + + Sends ``kill`` to the PID recorded in ``/tmp/{exec_id}.pid`` inside the + container. + + Args: + exec_id: The exec ID returned implicitly via + :class:`AWSContainerExec` (the 8-character hex string used as + the PID file name). + """ self._ecs_client.execute_command( cluster=self._ecs_config.cluster, task=self._task_arn, @@ -229,6 +393,18 @@ def exec_kill(self, exec_id: str): ) def remove(self, *, v: bool = False, link: bool = False, force: bool = False): + """Remove the container, optionally stopping it first. + + ECS tasks are not explicitly deleted; this method only stops the task + when ``force=True``. Present for interface compatibility with local + container runners. + + Args: + v: Unused (volume removal flag in Docker API). + link: Unused (link removal flag in Docker API). + force: If ``True``, attempt to stop the task before removal. + Errors during stop are silently ignored. + """ if force: try: self.stop() @@ -236,18 +412,38 @@ def remove(self, *, v: bool = False, link: bool = False, force: bool = False): pass def is_running(self) -> bool: + """Check whether the task is currently in ``RUNNING`` status. + + Returns: + ``True`` if the task status is ``"RUNNING"``; ``False`` otherwise + or if the status check raises an exception. + """ try: return self._get_status() == "RUNNING" except Exception: return False def is_alive(self) -> bool: + """Check whether the task has not yet stopped or begun deprovisioning. 
+ + Returns: + ``True`` if the task status is neither ``"STOPPED"`` nor + ``"DEPROVISIONING"``; ``False`` otherwise or on any exception. + """ try: return self._get_status() not in ("STOPPED", "DEPROVISIONING") except Exception: return False def stream_container_logs(self) -> Generator[str, None, None]: + """Yield CloudWatch log lines from the container until the task stops. + + Delegates to :func:`~worker.container_runner.aws_utils.stream_task_logs`, + resuming from the last pagination token if called multiple times. + + Yields: + Individual log message strings in chronological order. + """ yield from stream_task_logs( self._ecs_client, self._ecs_config.region, @@ -259,6 +455,12 @@ def stream_container_logs(self) -> Generator[str, None, None]: ) def get_exit_code(self) -> Optional[int]: + """Return the exit code of the first container, if the task has stopped. + + Returns: + The integer exit code, or ``None`` if the task has not yet stopped + or no container exit code is available. + """ task = self._describe_task() if task["lastStatus"] != "STOPPED": return None @@ -269,6 +471,23 @@ def get_exit_code(self) -> Optional[int]: class AWSContainerRunner(BaseContainerRunner): + """A ``BaseContainerRunner`` implementation that runs containers on AWS ECS (EC2 launch type). + + Reads its configuration from a JSON file (default path: + ``aws_config.json`` in the same directory, overridable via the + ``AWS_CONFIG_PATH`` environment variable) and exposes the standard + :meth:`prepare_image` and :meth:`spawn_container` interface. + + Expected keys in the config file: + + - ``cluster`` — ECS cluster name or ARN. + - ``capacity_provider`` — EC2 capacity provider name. + - ``task_definition`` — base task definition family or ARN. + - ``ecr_host`` — ECR registry host. + - ``mirroring_image_task_definition`` — task definition used for image mirroring. + - ``region`` *(optional)* — AWS region; defaults to ``"us-east-1"``. 
+ """ + def __init__(self): aws_config_path = os.environ.get( "AWS_CONFIG_PATH", Path(__file__).parent / "aws_config.json" @@ -286,12 +505,22 @@ def __init__(self): ], ) - def prepare_image(self, image): + def prepare_image(self, image: str): + """Ensure a Docker image is available in ECR, mirroring it if necessary. + + Delegates to :func:`~worker.container_runner.aws_utils.mirror_image_to_ecr`. + Should be called before :meth:`spawn_container` to avoid cold-start + delays when the image has not been mirrored yet. + + Args: + image: Source Docker image reference (e.g. + ``"supervisely/base-py-sdk-light:6.73.527"``). + """ mirror_image_to_ecr(image, self.ecs_config) def spawn_container( self, - image, + image: str, *, runtime: str = None, # not used entrypoint: List = None, @@ -301,7 +530,7 @@ def spawn_container( volumes: Dict = None, # not used environment: Dict = None, labels: Dict = None, - shm_size: int = None, # not used + shm_size: int = None, # not used stdin_open: bool = False, # not used tty: bool = False, # not used cpu_limit: int = None, @@ -311,6 +540,43 @@ def spawn_container( ipc_mode: str = None, security_opt: List[str] = None, # not used ) -> AWSContainer: + """Launch a new ECS task and return a handle to the running container. + + Converts Docker-style parameters to their ECS equivalents, registers a + new task definition revision, and starts the task via the configured + EC2 capacity provider. Several Docker-specific parameters are accepted + for interface compatibility but are silently ignored. + + Args: + image: Docker image URI to run (should already be mirrored to ECR + via :meth:`prepare_image`). + runtime: Ignored (Docker container runtime flag). + entrypoint: Container entrypoint as a list of strings. + detach: Ignored (tasks always start detached on ECS). + name: Ignored (ECS tasks are identified by ARN, not name). + remove: Ignored. + volumes: Ignored (volume mounts are not supported in this runner). 
+ environment: Environment variables to inject as ``{"KEY": "value"}`` + pairs. All values are coerced to strings. + labels: Resource tags to apply to the ECS task, converted to + ``[{"key": ..., "value": ...}]`` format. + shm_size: Ignored. + stdin_open: Ignored. + tty: Ignored. + cpu_limit: CPU units for the container. Passed directly to the task + definition revision. + mem_limit: Memory limit in Docker format (e.g. ``"512m"``, ``"2g"``, + or an integer number of bytes). Converted via + :func:`parse_mem_limit_to_bytes`. + memswap_limit: Ignored. + network: Ignored. + ipc_mode: IPC mode for the task (e.g. ``"host"``). An empty string + is treated as ``None``. + security_opt: Ignored. + + Returns: + An :class:`AWSContainer` handle to the newly started task. + """ memory = parse_mem_limit_to_bytes(mem_limit) if ipc_mode == "": ipc_mode = None diff --git a/agent/worker/container_runner/aws_utils.py b/agent/worker/container_runner/aws_utils.py index 218ac95..8e690b1 100644 --- a/agent/worker/container_runner/aws_utils.py +++ b/agent/worker/container_runner/aws_utils.py @@ -1,12 +1,25 @@ -import boto3 import os import time from dataclasses import dataclass from typing import Dict, List, Tuple, Union +import boto3 + @dataclass class ECSConfig: + """Configuration for interacting with AWS ECS and ECR. + + Attributes: + cluster: Name or ARN of the ECS cluster to run tasks on. + capacity_provider: Name of the EC2 capacity provider used for EC2 launch type tasks. + task_definition: Base task definition family or ARN used for EC2 tasks. + ecr_host: ECR registry host (e.g. ``123456789.dkr.ecr.us-east-1.amazonaws.com``). + mirroring_image_task_definition: Base task definition used for image mirroring + (Fargate) tasks. + region: AWS region. Defaults to ``"us-east-1"``. 
+ """ + cluster: str capacity_provider: str task_definition: str @@ -16,7 +29,15 @@ class ECSConfig: def get_boto3_client(service: str, region: str): - """Helper to create boto3 clients with consistent credentials.""" + """Create a boto3 client authenticated via environment variables. + + Args: + service: AWS service name (e.g. ``"ecs"``, ``"ecr"``, ``"ec2"``). + region: AWS region name (e.g. ``"us-east-1"``). + + Returns: + A boto3 client for the specified service and region. + """ return boto3.client( service, region_name=region, @@ -26,6 +47,18 @@ def get_boto3_client(service: str, region: str): def get_default_network_config(ec2_client, assign_public_ip: bool = True) -> dict: + """Build an ECS ``awsvpcConfiguration`` dict from the account's default VPC. + + Discovers the default VPC, all of its subnets, and the default security group, + then assembles the network configuration expected by ``ecs_client.run_task``. + + Args: + ec2_client: A boto3 EC2 client. + assign_public_ip: Whether to enable ``assignPublicIp``. Defaults to ``True``. + + Returns: + A dict suitable for passing as ``networkConfiguration`` to ``run_task``. + """ vpcs = ec2_client.describe_vpcs(Filters=[{"Name": "isDefault", "Values": ["true"]}]) vpc_id = vpcs["Vpcs"][0]["VpcId"] @@ -58,15 +91,27 @@ def get_default_network_config(ec2_client, assign_public_ip: bool = True) -> dic def _parse_image_to_ecr_path( docker_image_name: str, ecr_host: str ) -> tuple[str, str, str]: - """ - Parse a docker image name into ECR components. + """Parse a Docker image name into its ECR repository components. + + Strips any existing registry prefix and splits the image path into a + repository name and tag, then constructs the full ECR URI. 
+ + Example:: - supervisely/base-py-sdk-light:6.73.527 - -> repository_name: supervisely/base-py-sdk-light - -> image_tag: 6.73.527 - -> target_image: {ecr_host}/supervisely/base-py-sdk-light:6.73.527 + _parse_image_to_ecr_path( + "supervisely/base-py-sdk-light:6.73.527", + "123.dkr.ecr.us-east-1.amazonaws.com" + ) + # -> ("supervisely/base-py-sdk-light", "6.73.527", + # "123.dkr.ecr.us-east-1.amazonaws.com/supervisely/base-py-sdk-light:6.73.527") + + Args: + docker_image_name: Docker image reference, optionally including a registry + prefix (e.g. ``"docker.io/library/ubuntu:22.04"``). + ecr_host: ECR registry host to use as the target registry prefix. - Returns (repository_name, image_tag, target_image) + Returns: + A 3-tuple of ``(repository_name, image_tag, target_image)``. """ # Strip any existing registry prefix (anything before the first slash that contains a dot or colon) parts = docker_image_name.split("/") @@ -86,7 +131,16 @@ def _parse_image_to_ecr_path( def _ensure_ecr_repository(ecr_client, repository_name: str): - """Create ECR repository if it doesn't exist. Handles nested names like org/repo.""" + """Create an ECR repository if it does not already exist. + + Silently ignores ``RepositoryAlreadyExistsException`` so this function is + safe to call unconditionally before pushing an image. + + Args: + ecr_client: A boto3 ECR client. + repository_name: Repository name, which may contain slashes for + namespaced repos (e.g. ``"org/repo"``). + """ try: ecr_client.create_repository(repositoryName=repository_name) print(f"Created ECR repository: {repository_name}") @@ -95,6 +149,17 @@ def _ensure_ecr_repository(ecr_client, repository_name: str): def _image_exists_in_ecr(ecr_client, repository_name: str, image_tag: str) -> bool: + """Check whether a specific image tag exists in an ECR repository. + + Args: + ecr_client: A boto3 ECR client. + repository_name: Name of the ECR repository to query. + image_tag: Image tag to look up (e.g. 
``"latest"`` or ``"1.2.3"``). + + Returns: + ``True`` if the image tag is found; ``False`` if the repository or + image does not exist. + """ try: ecr_client.describe_images( repositoryName=repository_name, imageIds=[{"imageTag": image_tag}] @@ -116,6 +181,33 @@ def _create_task_definition_revision( gpu: int = None, ipc_mode: str = None, ) -> tuple[str, str]: + """Register a new revision of a task definition with optional overrides. + + Fetches the most recent active revision of ``base_task_definition``, applies + the requested overrides to the first container, copies all supported optional + fields, and registers the result as a new revision. + + Args: + ecs_client: A boto3 ECS client. + base_task_definition: Family name or ARN of the task definition to clone. + new_image: Docker image URI to set on the first container. If ``None``, + the existing image is preserved. + entrypoint: Override for the container ``entryPoint``. If ``None``, the + existing entrypoint is preserved. + cpu: CPU units to assign to the first container. If ``None``, the + existing value is preserved. + memory: Memory (MiB) to assign to the first container. If ``None``, the + existing value is preserved. + gpu: Number of GPUs to request via ``resourceRequirements``. If ``None``, + no GPU requirement is set. + ipc_mode: IPC mode for the task (e.g. ``"host"``). Overrides any value + in the base definition when provided. + + Returns: + A 2-tuple of ``(task_definition_arn, container_name)`` where + ``task_definition_arn`` is the ARN of the newly registered revision and + ``container_name`` is the name of the first container. + """ task_def_response = ecs_client.describe_task_definition( taskDefinition=base_task_definition ) @@ -178,9 +270,19 @@ def _create_task_definition_revision( def _get_task_log_config( ecs_client, task_definition_arn: str, container_name: str ) -> dict | None: - """ - Extract CloudWatch log configuration for a container from a task definition. 
- Returns dict with {log_group, log_stream_prefix, region} or None if not configured. + """Extract the CloudWatch Logs configuration for a named container. + + Looks up the task definition and returns the ``awslogs`` log driver options + for ``container_name`` if they are present. + + Args: + ecs_client: A boto3 ECS client. + task_definition_arn: Full ARN of the task definition to inspect. + container_name: Name of the container whose log config to retrieve. + + Returns: + A dict with keys ``log_group``, ``log_stream_prefix``, and ``region`` + if the container uses the ``awslogs`` driver; otherwise ``None``. """ task_def = ecs_client.describe_task_definition(taskDefinition=task_definition_arn) for container in task_def["taskDefinition"]["containerDefinitions"]: @@ -199,7 +301,22 @@ def _get_task_log_config( def _stream_task_logs( logs_client, log_group: str, log_stream: str, next_token: str = None ) -> str | None: - """Stream new log events from a CloudWatch log stream since the last token.""" + """Print new log events from a CloudWatch log stream and return the next token. + + Paginates through all available events since ``next_token``, printing each + message to stdout. Stops when no new events are returned. + + Args: + logs_client: A boto3 CloudWatch Logs client. + log_group: Name of the CloudWatch log group. + log_stream: Name of the log stream within the group. + next_token: Pagination token from a previous call. Pass ``None`` to + start from the beginning of the stream. + + Returns: + The ``nextForwardToken`` to pass on the next call, or ``None`` if the + stream was not found. + """ while True: kwargs = { "logGroupName": log_group, @@ -234,7 +351,26 @@ def _wait_for_task_and_logs( container_name: str, poll_interval: int = 1, ): - """Wait for ECS task to finish, streaming CloudWatch logs when available.""" + """Block until an ECS task stops, streaming its CloudWatch logs to stdout. + + Polls the task status every ``poll_interval`` seconds. 
On each iteration, + any new log events are printed. Raises ``RuntimeError`` if any container + exits with a non-zero code or if the task stops before the container starts. + + Args: + ecs_client: A boto3 ECS client. + region: AWS region of the task and log group. + cluster: Name or ARN of the ECS cluster. + task_arn: ARN of the running task to monitor. + task_definition_arn: ARN of the task definition (used to resolve the + log configuration). + container_name: Name of the container whose logs to stream. + poll_interval: Seconds to wait between status polls. Defaults to ``1``. + + Raises: + RuntimeError: If any container exits with a non-zero exit code, or if + the task stops before a container starts. + """ logs_client = get_boto3_client("logs", region) log_config = _get_task_log_config(ecs_client, task_definition_arn, container_name) @@ -286,7 +422,22 @@ def _wait_for_task_and_logs( def _collect_log_lines( logs_client, log_group: str, log_stream: str, next_token: str = None ) -> tuple[str | None, list[str]]: - """Fetch new log lines since last token. Returns (next_token, lines).""" + """Fetch new log lines from a CloudWatch log stream since the last token. + + Paginates until no new events are available and collects all message strings. + + Args: + logs_client: A boto3 CloudWatch Logs client. + log_group: Name of the CloudWatch log group. + log_stream: Name of the log stream within the group. + next_token: Pagination token from a previous call. Pass ``None`` to + start from the beginning of the stream. + + Returns: + A 2-tuple of ``(next_token, lines)`` where ``next_token`` is the forward + pagination token for subsequent calls and ``lines`` is a list of log + message strings collected in this call. + """ lines = [] while True: kwargs = { @@ -322,7 +473,31 @@ def stream_task_logs( poll_interval: int = 1, next_log_token: str = None, ): - """Yield log lines from a running ECS task until it stops.""" + """Yield log lines from a running ECS task until it stops. 
+ + Polls the task status and CloudWatch log stream in a loop, yielding each + new log line as it becomes available. Returns when the task reaches + ``STOPPED`` status. + + Args: + ecs_client: A boto3 ECS client. + region: AWS region of the task and log group. + cluster: Name or ARN of the ECS cluster. + task_arn: ARN of the running task to tail. + task_definition_arn: ARN of the task definition (used to resolve the + log configuration). + container_name: Name of the container whose logs to stream. + poll_interval: Seconds to wait between log/status polls. Defaults to ``1``. + next_log_token: Optional starting pagination token. Pass ``None`` to + stream from the beginning. + + Yields: + Individual log message strings in chronological order. + + Raises: + RuntimeError: If no CloudWatch log configuration is found for the + specified container. + """ logs_client = get_boto3_client("logs", region) log_config = _get_task_log_config(ecs_client, task_definition_arn, container_name) @@ -359,7 +534,38 @@ def run_container_fargate( tags: List[Dict] = None, wait: bool = False, ) -> str: - """Run a container on Fargate (FARGATE launch type).""" + """Run a container on AWS Fargate and optionally wait for it to finish. + + Creates a new task definition revision from + ``ecs_config.mirroring_image_task_definition``, applies the provided + overrides, and launches it with the ``FARGATE`` launch type using the + account's default VPC network configuration. + + Args: + ecs_config: ECS/ECR configuration including cluster and task definition + details. + docker_image_name: Docker image URI to run. If ``None``, the image from + the base task definition is used. + entrypoint: Container entrypoint. A string is split on whitespace into a + list. Defaults to the base task definition's entrypoint. + command: Container command. A string is split on whitespace into a list. + Defaults to the base task definition's command. 
+ env_vars: Additional environment variables to inject into the container + as ``{"KEY": "value"}`` pairs. + cpu: CPU units to allocate. Defaults to the base task definition value. + memory: Memory in MiB to allocate. Defaults to the base task definition + value. + tags: List of ECS resource tags in ``[{"key": ..., "value": ...}]`` form. + wait: If ``True``, block until the task stops and stream its logs. + Defaults to ``False``. + + Returns: + The ARN of the started Fargate task. + + Raises: + RuntimeError: If ECS reports failures and no task ARN is returned, or + if ``wait=True`` and the task exits with a non-zero status. + """ if isinstance(command, str): command = command.split() if command else [] if isinstance(entrypoint, str): @@ -421,9 +627,25 @@ def mirror_image_to_ecr( docker_image_name: str, ecs_config: ECSConfig, ) -> str: - """ - Ensure a Docker image is mirrored to ECR. - Returns the ECR image URI. + """Ensure a public Docker image is mirrored into ECR, pulling it if needed. + + Checks whether the image already exists in ECR. If it does, returns the ECR + URI immediately. If not, launches a Fargate mirroring task (using + ``ecs_config.mirroring_image_task_definition``) that pulls the source image + and pushes it to ECR, then waits for the task to finish. + + Args: + docker_image_name: Source Docker image reference, e.g. + ``"supervisely/base-py-sdk-light:6.73.527"``. + ecs_config: ECS/ECR configuration including the ECR host and cluster + details. + + Returns: + The full ECR image URI (e.g. + ``"123.dkr.ecr.us-east-1.amazonaws.com/supervisely/base-py-sdk-light:6.73.527"``). + + Raises: + RuntimeError: If the mirroring task fails or exits with a non-zero status. 
""" ecr_client = get_boto3_client("ecr", ecs_config.region) repository_name, image_tag, target_image = _parse_image_to_ecr_path( @@ -473,7 +695,38 @@ def run_container_ec2( tags: List[Dict] = None, wait: bool = False, ) -> Tuple[str, str, str]: - """Run a container using the EC2 capacity provider.""" + """Run a container via the EC2 capacity provider and optionally wait for it. + + Creates a new task definition revision from ``ecs_config.task_definition``, + applies the provided overrides, and launches the task using + ``ecs_config.capacity_provider`` with ``enableExecuteCommand`` enabled. + + Args: + docker_image_name: Docker image URI to run. + entrypoint: Container entrypoint. A string is split on whitespace into a + list. + command: Container command. A single string is wrapped in a list. + ecs_config: ECS/ECR configuration including the cluster and capacity + provider details. + env_vars: Additional environment variables to inject into the container + as ``{"KEY": "value"}`` pairs. + cpu: CPU units for the first container. Defaults to the base task + definition value. + memory: Memory in MiB for the first container. Defaults to the base task + definition value. + gpu: Number of GPUs to request. Defaults to no GPU requirement. + ipc_mode: IPC mode for the task (e.g. ``"host"``). + tags: List of ECS resource tags in ``[{"key": ..., "value": ...}]`` form. + wait: If ``True``, block until the task stops and stream its logs. + Defaults to ``False``. + + Returns: + A 3-tuple of ``(task_arn, container_name, task_definition_arn)``. + + Raises: + RuntimeError: If ECS reports failures and no task ARN is returned, or + if ``wait=True`` and the task exits with a non-zero status. 
+ """ if isinstance(command, str): command = [command] if isinstance(entrypoint, str): @@ -541,7 +794,27 @@ def run( ecs_config: ECSConfig, env_vars: dict = None, ) -> str: - """Full pipeline: mirror image to ECR, create task def revision, run and wait.""" + """Mirror an image to ECR, then run it on EC2 and wait for completion. + + Convenience wrapper that combines :func:`mirror_image_to_ecr` and + :func:`run_container_ec2` into a single call. The function blocks until the + task finishes and raises on failure. + + Args: + docker_image_name: Source Docker image to mirror and run (e.g. + ``"supervisely/base-py-sdk-light:6.73.527"``). + entrypoint: Container entrypoint string (split on whitespace). + command: Container command as a string or list of strings. + ecs_config: ECS/ECR configuration used for both mirroring and running. + env_vars: Environment variables to inject into the container as + ``{"KEY": "value"}`` pairs. + + Returns: + The ARN of the completed EC2 task. + + Raises: + RuntimeError: If mirroring or the EC2 task fails. + """ ecr_image = mirror_image_to_ecr(docker_image_name, ecs_config) return run_container_ec2( docker_image_name=ecr_image, diff --git a/agent/worker/container_runner/demo.py b/agent/worker/container_runner/demo.py new file mode 100644 index 0000000..6c56d2a --- /dev/null +++ b/agent/worker/container_runner/demo.py @@ -0,0 +1,136 @@ +"""demo.py — Run a customisable Docker container on AWS via AWSContainerRunner. + +Configuration +------------- +Environment variables (required unless noted): + + CONTAINER_IMAGE Docker image to run, e.g. ``ubuntu:22.04``. + CONTAINER_ENTRYPOINT (optional) Entrypoint override as a shell string, + e.g. ``"python -u script.py"``. If omitted, the image's + default entrypoint is used. + +Container environment variables are loaded from a ``run.env`` file in the +current working directory (if it exists). Each non-empty, non-comment line +must be in ``KEY=VALUE`` format. 
+ +Usage +----- + # Minimal + CONTAINER_IMAGE=ubuntu:22.04 python demo.py + + # With entrypoint and env file + CONTAINER_IMAGE=myrepo/myapp:1.0 CONTAINER_ENTRYPOINT="python main.py" python demo.py +""" + +import os +import sys +from pathlib import Path + +from .aws_runner import AWSContainerRunner + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def load_env_file(path: Path) -> dict: + """Parse a ``KEY=VALUE`` env file and return a dict of variables. + + Lines that are empty or start with ``#`` are ignored. Inline comments + (anything after an unquoted ``#``) are *not* stripped — values are taken + verbatim after the first ``=``. + + Args: + path: Path to the env file. + + Returns: + A ``{key: value}`` dict. Returns an empty dict if the file does not + exist. + """ + env: dict = {} + if not path.exists(): + print(f"[demo] No env file found at '{path}', running without extra vars.") + return env + + with path.open() as f: + for lineno, raw in enumerate(f, start=1): + line = raw.rstrip("\n") + if not line or line.lstrip().startswith("#"): + continue + if "=" not in line: + print(f"[demo] Warning: skipping malformed line {lineno} in '{path}': {line!r}") + continue + key, value = line.split("=", 1) + env[key.strip()] = value + return env + + +def parse_entrypoint(raw: str | None) -> list[str] | None: + """Split a shell entrypoint string into a list, or return ``None``. + + Args: + raw: Entrypoint string (e.g. ``"python -u script.py"``), or ``None``. + + Returns: + A list of strings suitable for ``spawn_container(entrypoint=...)``, or + ``None`` if ``raw`` is empty or ``None``. 
+ """ + if not raw: + return None + import shlex + return shlex.split(raw) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + # --- Read configuration -------------------------------------------------- + image = os.environ.get("CONTAINER_IMAGE", "").strip() + if not image: + print("Error: CONTAINER_IMAGE environment variable is required.", file=sys.stderr) + sys.exit(1) + + entrypoint = parse_entrypoint(os.environ.get("CONTAINER_ENTRYPOINT")) + environment = load_env_file(Path("run.env")) + + # --- Summary ------------------------------------------------------------- + print("[demo] Launch configuration:") + print(f" Image : {image}") + print(f" Entrypoint : {entrypoint or '(image default)'}") + print(f" Env vars : {list(environment.keys()) or '(none)'}") + print() + + # --- Run ----------------------------------------------------------------- + runner = AWSContainerRunner() + + print("[demo] Mirroring image to ECR (skipped if already present)...") + runner.prepare_image(image) + + print("[demo] Spawning container...") + container = runner.spawn_container( + image=image, + entrypoint=entrypoint, + environment=environment, + detach=True, + ) + print(f"[demo] Task started: {container._task_arn}") + + # --- Stream logs --------------------------------------------------------- + print("[demo] Streaming container logs (Ctrl-C to detach):\n") + try: + for line in container.stream_container_logs(): + print(line) + except KeyboardInterrupt: + print("\n[demo] Detached from log stream. 
Task is still running.") + return + + # --- Exit code ----------------------------------------------------------- + exit_code = container.get_exit_code() + print(f"\n[demo] Container finished with exit code: {exit_code}") + sys.exit(exit_code if exit_code is not None else 0) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/agent/worker/container_runner/local.py b/agent/worker/container_runner/local.py index ac64bca..60656f6 100644 --- a/agent/worker/container_runner/local.py +++ b/agent/worker/container_runner/local.py @@ -1,26 +1,18 @@ from logging import Logger -from typing import Callable, Dict, Generator, List, Literal, Optional +from typing import Dict, Generator, List, Literal, Optional + import docker +import supervisely as sly +from docker.errors import APIError, DockerException, NotFound +from docker.models.containers import Container from worker import constants, docker_utils -from worker.task_dockerized import ErrorReport +from worker.agent_utils import convert_millicores_to_cpu_quota from worker.container_runner.container_runner import ( + BaseContainer, BaseContainerExec, BaseContainerRunner, - BaseContainer, ) -from docker.models.containers import Container - -from worker.agent_utils import ( - TaskDirCleaner, - filter_log_line, - pip_req_satisfied_filter, - post_get_request_filter, - convert_millicores_to_cpu_quota, -) - -import supervisely as sly - -from docker.errors import APIError, NotFound, DockerException +from worker.task_dockerized import ErrorReport class LocalContainerExec(BaseContainerExec): diff --git a/agent/worker/container_runner/readme.md b/agent/worker/container_runner/readme.md new file mode 100644 index 0000000..8446b91 --- /dev/null +++ b/agent/worker/container_runner/readme.md @@ -0,0 +1,63 @@ +# AWS Container Runner +A thin wrapper around AWS ECS that lets you launch, monitor, and interact with Docker containers on EC2-backed ECS clusters using the same interface as a local Docker runner. 
+
+## How it works
+The wrapper is built around three layers:
+
+### 1. `aws_utils.py` — Low-level AWS primitives
+Stateless functions that talk directly to the AWS APIs:
+
+- `mirror_image_to_ecr` — Checks whether a public Docker image already exists in your ECR registry. If not, it launches a short-lived Fargate task that pulls the image and pushes it to ECR. This ensures all images are served from within your AWS account, avoiding external registry rate limits and improving pull latency.
+- `run_container_ec2` — Clones a base ECS task definition, applies overrides (image, entrypoint, CPU, memory, GPU, environment variables), registers a new revision, and starts the task using your EC2 capacity provider.
+- `stream_task_logs` — Tails CloudWatch Logs for a running task, yielding lines as they arrive.
+
+### 2. `aws_runner.py` — High-level container interface
+Two classes that implement the `BaseContainer` / `BaseContainerRunner` interface:
+
+`AWSContainerRunner` is the entry point. It reads cluster configuration from a JSON file and exposes two methods:
+
+- `prepare_image(image)` — Mirror an image to ECR before running it.
+- `spawn_container(image, ...)` — Launch a container and return an `AWSContainer` handle.
+
+`AWSContainer` is the handle returned by `spawn_container`. It wraps a single ECS task and provides:
+
+| Method | Description |
+| --- | --- |
+| `is_running()` | Returns `True` if the task is in `RUNNING` state. |
+| `is_alive()` | Returns `True` if the task has not yet stopped or begun deprovisioning. |
+| `wait(timeout, condition)` | Blocks until the task stops; returns `{"StatusCode": exit_code}`. |
+| `stop()` | Sends a stop request to the task (async). |
+| `remove(force)` | Optionally stops the task; present for interface compatibility. |
+| `stream_container_logs()` | Yields CloudWatch log lines until the task stops. |
+| `get_exit_code()` | Returns the container exit code, or `None` if still running. |
+| `exec(command)` | Runs a shell command inside the container via ECS Execute Command. |
+| `exec_kill(exec_id)` | Kills a background command started by `exec`. |
+`AWSContainerExec` is returned by `AWSContainer.exec`. It opens an SSM WebSocket session and exposes:
+
+- `stream_logs()` — Yields output lines from the remote command.
+- `get_exit_code()` — Returns the command's exit code after the stream ends.
+
+## Configuration
+
+### AWS config file
+`AWSContainerRunner` reads its cluster configuration from a JSON file. The default path is `aws_config.json` in the same directory as `aws_runner.py`. Override it with the `AWS_CONFIG_PATH` environment variable.
+
+```json
+{
+    "region": "us-east-1",
+    "cluster": "my-ecs-cluster",
+    "capacity_provider": "my-ec2-capacity-provider",
+    "task_definition": "my-base-task-def",
+    "ecr_host": "123456789012.dkr.ecr.us-east-1.amazonaws.com",
+    "mirroring_image_task_definition": "my-mirror-task-def"
+}
+```
+
+| Key | Description |
+| --- | --- |
+| `region` | AWS region. Defaults to `us-east-1` if omitted. |
+| `cluster` | ECS cluster name or ARN. |
+| `capacity_provider` | EC2 capacity provider used for all container tasks. |
+| `task_definition` | Base task definition cloned for each container run. |
+| `ecr_host` | ECR registry host (the part before the first `/`). |
+| `mirroring_image_task_definition` | Task definition used by the Fargate image-mirroring task. |
+
+### AWS credentials
+Credentials are read from the standard environment variables:
+
+```
+AWS_ACCESS_KEY_ID=...
+AWS_SECRET_ACCESS_KEY=...
+```
+
+## Demo
+
+1. Create a `run.env` file (optional)
+
+Add any environment variables you want injected into the container, one per line:
+
+```
+# run.env
+MY_VAR=hello
+ANOTHER_VAR=world
+```
+
+2. Set required environment variables
+
+```bash
+export AWS_ACCESS_KEY_ID=...
+export AWS_SECRET_ACCESS_KEY=...
+
+# Required: the image to run
+export CONTAINER_IMAGE=ubuntu:22.04
+
+# Optional: override the image's default entrypoint
+export CONTAINER_ENTRYPOINT="bash -c 'echo hello && sleep 5'"
+```
+
+3. Run
+
+```bash
+python demo.py
+```
\ No newline at end of file