From 0c5671051f4899db57c6975229ba61aaa351898a Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Fri, 17 Apr 2026 23:25:08 -0500 Subject: [PATCH 1/9] refactor(v2-review): extract auth module, clean up credential loading and dead code - Extract `_load_credentials` from both orchestrators into `core/auth.py` (`load_credentials`) - Remove dead `_filter_images_by_dockerfile_context` method from RunOrchestrator - Raise `ConfigurationError` instead of `SystemExit` in BuildOrchestrator config loading - Add CLAUDE.md with codebase guidance and architecture docs - Update tests to cover auth module and refactored error handling Co-authored-by: Claude Sonnet 4 --- CLAUDE.md | 129 ++++++++++++++++++ pyproject.toml | 2 - src/madengine/cli/app.py | 8 +- src/madengine/cli/commands/run.py | 3 + src/madengine/cli/constants.py | 6 +- src/madengine/cli/validators.py | 2 + src/madengine/core/auth.py | 77 +++++++++++ src/madengine/core/context.py | 5 - src/madengine/core/dataprovider.py | 2 - src/madengine/core/docker.py | 16 ++- src/madengine/core/errors.py | 32 +++-- src/madengine/deployment/base.py | 8 +- src/madengine/deployment/common.py | 2 + src/madengine/deployment/factory.py | 9 +- src/madengine/deployment/slurm.py | 20 ++- src/madengine/execution/container_runner.py | 5 +- src/madengine/execution/docker_builder.py | 22 +-- .../orchestration/build_orchestrator.py | 54 +------- .../orchestration/run_orchestrator.py | 98 +------------ src/madengine/reporting/csv_to_email.py | 4 +- src/madengine/reporting/update_perf_csv.py | 10 +- src/madengine/utils/config_parser.py | 2 +- src/madengine/utils/discover_models.py | 12 +- src/madengine/utils/gpu_config.py | 23 ++-- src/madengine/utils/log_formatting.py | 2 +- src/madengine/utils/rocm_tool_manager.py | 11 +- src/madengine/utils/session_tracker.py | 8 +- tests/integration/test_docker_integration.py | 11 +- tests/integration/test_errors.py | 20 +-- tests/unit/test_auth.py | 117 ++++++++++++++++ tests/unit/test_deployment.py | 8 ++ 
tests/unit/test_error_handling.py | 13 +- 32 files changed, 488 insertions(+), 253 deletions(-) create mode 100644 CLAUDE.md create mode 100644 src/madengine/core/auth.py create mode 100644 tests/unit/test_auth.py diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..085b8997 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,129 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Development Setup + +```bash +# Install in development mode with all dependencies +pip install -e ".[dev]" + +# Optional: install Kubernetes support +pip install -e ".[all]" + +# Setup pre-commit hooks +pre-commit install +``` + +## Commands + +```bash +# Run all tests +pytest + +# Run specific test file +pytest tests/unit/test_error_handling.py -v + +# Run specific test class or function +pytest tests/unit/test_error_handling.py::TestErrorPatternMatching -v + +# Run tests with coverage +pytest --cov=src/madengine --cov-report=html + +# Skip slow tests +pytest -m "not slow" + +# Format code +black src/ tests/ +isort src/ tests/ + +# Lint +flake8 src/ tests/ + +# Type check +mypy src/madengine + +# Run all pre-commit checks +pre-commit run --all-files +``` + +## Architecture + +madengine is a CLI tool for running AI/ML models in local Docker, Kubernetes, and SLURM environments. The entry point is `madengine.cli.app:cli_main` (registered as the `madengine` console script). 
+ +### Layer Structure + +**CLI Layer** (`src/madengine/cli/`) +- `app.py` — Typer app wiring, registers 5 commands: `discover`, `build`, `run`, `report`, `database` +- `commands/` — One file per command (build, run, discover, report, database) +- `constants.py` — `ExitCode` enum (`SUCCESS=0`, `FAILURE=1`, `BUILD_FAILURE=2`, `RUN_FAILURE=3`, `INVALID_ARGS=4`) + +**Orchestration Layer** (`src/madengine/orchestration/`) +- `build_orchestrator.py` — `BuildOrchestrator`: discovers models, builds Docker images, writes `build_manifest.json` +- `run_orchestrator.py` — `RunOrchestrator`: reads or triggers builds, infers deployment target, delegates to local or distributed execution + +**Core Layer** (`src/madengine/core/`) +- `context.py` — `Context` class: merges `additional_context` with system detection (GPU vendor, architecture, OS, ROCm path). Uses `ast.literal_eval()` to parse additional_context strings (not `json.loads` — pass Python dict repr, not JSON) +- `console.py` — `Console`: shell execution wrapper with live output support +- `docker.py` — Docker command wrapper + +**Execution Layer** (`src/madengine/execution/`) +- `container_runner.py` — `ContainerRunner`: runs models from manifest via `docker run`, writes results to `perf.csv` +- `docker_builder.py` — `DockerBuilder`: builds images from Dockerfiles +- `container_runner_helpers.py` — Log error pattern scanning, timeout resolution + +**Deployment Layer** (`src/madengine/deployment/`) +- `factory.py` — `DeploymentFactory`: Factory pattern, registers `SlurmDeployment` and `KubernetesDeployment` +- `base.py` — `BaseDeployment` abstract class, `DeploymentConfig` dataclass +- `kubernetes.py` / `slurm.py` — Concrete deployments; target is inferred by Convention over Configuration: presence of `"k8s"` or `"kubernetes"` key → K8s; `"slurm"` key → SLURM; neither → local +- `presets/` — JSON preset files for K8s/SLURM default configurations; auto-merged with minimal user configs +- `config_loader.py` — Loads and 
merges preset JSON with user-supplied config + +**Utils** (`src/madengine/utils/`) +- `discover_models.py` — `DiscoverModels`: three discovery methods: root `models.json`, `scripts/{dir}/models.json`, or `scripts/{dir}/get_models_json.py` (dynamic) +- `gpu_tool_factory.py` / `gpu_tool_manager.py` — GPU vendor abstraction (AMD/NVIDIA) +- `gpu_validator.py` — ROCm installation detection, GPU vendor detection +- `config_parser.py` — `ConfigParser`: parses `--additional-context` and tools config + +**Reporting** (`src/madengine/reporting/`) +- `update_perf_csv.py` — Writes/appends to `perf.csv` and `perf_entry.csv` +- `csv_to_html.py` / `csv_to_email.py` — Report generation + +### Key Data Flows + +1. **Build flow**: CLI → `BuildOrchestrator` → `DiscoverModels` (finds models by tags) → `DockerBuilder` (builds images) → writes `build_manifest.json` + +2. **Run flow**: CLI → `RunOrchestrator` → loads/generates `build_manifest.json` → infers target → `ContainerRunner` (local) or `DeploymentFactory` (K8s/SLURM) → writes `perf.csv` + +3. **`additional_context`**: User JSON/Python-dict string merged into `Context.ctx`. Context is parsed with `ast.literal_eval()`, so values can use Python dict syntax. Keys like `k8s`, `slurm`, `distributed`, `tools`, `pre_scripts`, `post_scripts` drive behavior. + +4. **Model definition**: Models defined in `models.json` with fields: `name`, `tags`, `dockerfile`, `scripts`, `n_gpus`, `args`, `timeout`, `skip_gpu_arch`, etc. + +5. **Script isolation**: During run, `scripts/common/` is populated from the madengine package (pre_scripts, post_scripts, tools) and cleaned up afterwards. The MAD project's own `scripts/` and `docker/` directories are preserved. + +### Deployment Target Inference + +No explicit `"deploy"` field is needed. 
Target is inferred from config structure: +- `"k8s"` or `"kubernetes"` key present → Kubernetes deployment +- `"slurm"` key present → SLURM deployment +- Neither → local Docker execution + +### Test Structure + +``` +tests/ +├── unit/ # Fast isolated tests with mocking +├── integration/ # End-to-end with real Docker/system calls +├── e2e/ # Full workflow tests +└── fixtures/ # Dummy models, scripts, and data for testing +``` + +Pytest config is in `pyproject.toml` under `[tool.pytest.ini_options]`. Test markers: `slow`, `integration`. + +### Code Style + +- Black formatting, 88-character line length +- isort with `profile = "black"` +- Google-style docstrings +- Type hints required for public functions +- Conventional commits: `feat:`, `fix:`, `docs:`, `test:`, `refactor:`, `style:`, `perf:`, `chore:` diff --git a/pyproject.toml b/pyproject.toml index 81fded5e..0c83f30a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,10 +16,8 @@ dependencies = [ "GitPython", "jsondiff", "sqlalchemy", - "setuptools-rust", "paramiko", "tqdm", - "pytest", "typing-extensions", "pymongo", "toml", diff --git a/src/madengine/cli/app.py b/src/madengine/cli/app.py index 66d3256b..2e761f49 100644 --- a/src/madengine/cli/app.py +++ b/src/madengine/cli/app.py @@ -8,6 +8,7 @@ """ import sys +from importlib.metadata import PackageNotFoundError, version as pkg_version import typer from rich.traceback import install @@ -55,9 +56,12 @@ def main( Built with Typer and Rich for a beautiful, production-ready experience. 
""" if version: - # You might want to get the actual version from your package + try: + _version = pkg_version("madengine") + except PackageNotFoundError: + _version = "unknown" console.print( - "🚀 [bold cyan]madengine[/bold cyan] version [green]2.0.0[/green]" + f"🚀 [bold cyan]madengine[/bold cyan] version [green]{_version}[/green]" ) raise typer.Exit() diff --git a/src/madengine/cli/commands/run.py b/src/madengine/cli/commands/run.py index a684973d..09d90772 100644 --- a/src/madengine/cli/commands/run.py +++ b/src/madengine/cli/commands/run.py @@ -194,6 +194,9 @@ def run( # Convert -1 (default) to actual default timeout value (7200 seconds = 2 hours) if timeout == -1: timeout = 7200 + # 0 means "no timeout" per the help text — map to None so subprocess never expires + elif timeout == 0: + timeout = None try: # Check if we're doing execution-only or full workflow diff --git a/src/madengine/cli/constants.py b/src/madengine/cli/constants.py index f32eb024..b437fa30 100644 --- a/src/madengine/cli/constants.py +++ b/src/madengine/cli/constants.py @@ -5,11 +5,13 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
""" +from enum import IntEnum + # Exit codes -class ExitCode: +class ExitCode(IntEnum): """Exit codes for CLI commands.""" - + SUCCESS = 0 FAILURE = 1 BUILD_FAILURE = 2 diff --git a/src/madengine/cli/validators.py b/src/madengine/cli/validators.py index d99e87f7..b4e08e8b 100644 --- a/src/madengine/cli/validators.py +++ b/src/madengine/cli/validators.py @@ -395,6 +395,8 @@ def process_batch_manifest_entries( # If the model was not built (build_new=false), create an entry for it if not build_new: + # Initialize with a safe fallback so the except block can always reference it + dockerfile_matched = "unknown" # Find the model configuration by discovering models with this tag try: # Create a temporary args object to discover the model diff --git a/src/madengine/core/auth.py b/src/madengine/core/auth.py new file mode 100644 index 00000000..e592dd60 --- /dev/null +++ b/src/madengine/core/auth.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +""" +Shared authentication utilities for madengine. + +Centralises credential loading logic used by both BuildOrchestrator and +RunOrchestrator so that fixes and improvements only need to be made once. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +import json +import os +from typing import Dict, Optional + +from madengine.core.errors import ( + ConfigurationError, + create_error_context, + handle_error, +) + + +def load_credentials() -> Optional[Dict]: + """Load credentials from credential.json and environment variables. + + Precedence (highest wins): + 1. ``MAD_DOCKERHUB_USER`` / ``MAD_DOCKERHUB_PASSWORD`` environment vars + (merged into the ``dockerhub`` key of the returned dict) + 2. ``credential.json`` in the current working directory + + Returns: + Credentials dict (keyed by registry name), or ``None`` if no + credentials are found. 
+ """ + credentials: Optional[Dict] = None + + credential_file = "credential.json" + if os.path.exists(credential_file): + try: + with open(credential_file) as f: + credentials = json.load(f) + print( + f"Loaded credentials from {credential_file}: " + f"{list(credentials.keys())}" + ) + except Exception as e: + context = create_error_context( + operation="load_credentials", + component="auth", + file_path=credential_file, + ) + handle_error( + ConfigurationError( + f"Could not load credentials: {e}", + context=context, + suggestions=[ + "Check if credential.json exists and has valid JSON format" + ], + ) + ) + + # Environment variables override / supplement file credentials + docker_hub_user = os.environ.get("MAD_DOCKERHUB_USER") + docker_hub_password = os.environ.get("MAD_DOCKERHUB_PASSWORD") + docker_hub_repo = os.environ.get("MAD_DOCKERHUB_REPO") + + if docker_hub_user and docker_hub_password: + print("Found Docker Hub credentials in environment variables") + if credentials is None: + credentials = {} + credentials["dockerhub"] = { + "username": docker_hub_user, + "password": docker_hub_password, + } + if docker_hub_repo: + credentials["dockerhub"]["repository"] = docker_hub_repo + + return credentials diff --git a/src/madengine/core/context.py b/src/madengine/core/context.py index 24763588..57e160c1 100644 --- a/src/madengine/core/context.py +++ b/src/madengine/core/context.py @@ -395,11 +395,8 @@ def get_gpu_vendor(self) -> str: for amd_smi_path in amd_smi_paths: if os.path.exists(amd_smi_path): try: - # Debug: log to stderr so SLURM node .err captures where we are if killed - print(f"[DEBUG] get_gpu_vendor: trying amd-smi at {amd_smi_path}", file=sys.stderr, flush=True) # Verify amd-smi actually works (180s timeout for slow GPU initialization) result = self.console.sh(f"{amd_smi_path} list > /dev/null 2>&1 && echo 'AMD' || echo ''", timeout=180) - print(f"[DEBUG] get_gpu_vendor: amd-smi returned", file=sys.stderr, flush=True) if result and result.strip() == 
"AMD": return "AMD" except Exception as e: @@ -409,9 +406,7 @@ def get_gpu_vendor(self) -> str: rocm_smi_path = os.path.join(self._rocm_path, "bin", "rocm-smi") if os.path.exists(rocm_smi_path): try: - print(f"[DEBUG] get_gpu_vendor: trying rocm-smi at {rocm_smi_path}", file=sys.stderr, flush=True) result = self.console.sh(f"{rocm_smi_path} --showid > /dev/null 2>&1 && echo 'AMD' || echo ''", timeout=180) - print(f"[DEBUG] get_gpu_vendor: rocm-smi returned", file=sys.stderr, flush=True) if result and result.strip() == "AMD": return "AMD" except Exception as e: diff --git a/src/madengine/core/dataprovider.py b/src/madengine/core/dataprovider.py index c0df24a5..809c4425 100644 --- a/src/madengine/core/dataprovider.py +++ b/src/madengine/core/dataprovider.py @@ -164,8 +164,6 @@ def check_source(self, config: typing.Dict) -> bool: # get the base directory of the current file. BASE_DIR = os.path.dirname(os.path.realpath(__file__)) - print("DEBUG - BASE_DIR::", BASE_DIR) - print("DEBUG - self.config[path]::", self.config["path"]) # check if the path exists in the base directory. # if os.path.exists(BASE_DIR + "/../" + self.config["path"]): diff --git a/src/madengine/core/docker.py b/src/madengine/core/docker.py index 42f88263..f9b5c6c9 100644 --- a/src/madengine/core/docker.py +++ b/src/madengine/core/docker.py @@ -7,6 +7,7 @@ """ # built-in modules import os +import shlex import typing # user-defined modules @@ -32,7 +33,7 @@ def __init__( mounts: typing.Optional[typing.List] = None, envVars: typing.Optional[typing.Dict] = None, keep_alive: bool = False, - console: Console = Console(), + console: Console = None, ) -> None: """Constructor of the Docker class. 
@@ -52,13 +53,14 @@ def __init__( self.docker_sha = None self.keep_alive = keep_alive cwd = os.getcwd() - self.console = console + self.console = console if console is not None else Console() self.userid = self.console.sh("id -u") self.groupid = self.console.sh("id -g") # check if container name exists + container_name_quoted = shlex.quote(container_name) container_name_exists = self.console.sh( - "docker container ps -a | grep " + container_name + " | wc -l" + "docker container ps -a | grep " + container_name_quoted + " | wc -l" ) # if container name exists, clean it up automatically if container_name_exists != "0": @@ -67,11 +69,11 @@ def __init__( ) # Stop the container (with timeout) self.console.sh( - f"docker stop -t 1 {container_name} 2>/dev/null || true" + f"docker stop -t 1 {container_name_quoted} 2>/dev/null || true" ) # Remove the container self.console.sh( - f"docker rm -f {container_name} 2>/dev/null || true" + f"docker rm -f {container_name_quoted} 2>/dev/null || true" ) print(f"✓ Cleaned up existing container '{container_name}'") @@ -93,7 +95,7 @@ def __init__( # add envVars if envVars is not None: for evar in envVars.keys(): - command += "-e " + evar + "=" + envVars[evar] + " " + command += "-e " + evar + "=" + shlex.quote(str(envVars[evar])) + " " command += "--workdir /myworkspace/ " command += "--name " + container_name + " " @@ -123,7 +125,7 @@ def sh(self, command: str, timeout: int = 60, secret: bool = False) -> str: """ # run as root! 
return self.console.sh( - "docker exec " + self.docker_sha + ' bash -c "' + command + '"', + "docker exec " + self.docker_sha + " bash -c " + shlex.quote(command), timeout=timeout, secret=secret, ) diff --git a/src/madengine/core/errors.py b/src/madengine/core/errors.py index 18ba92f8..2aaf43d0 100644 --- a/src/madengine/core/errors.py +++ b/src/madengine/core/errors.py @@ -83,14 +83,14 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg ) -class ConnectionError(MADEngineError): +class NetworkError(MADEngineError): """Connection and network errors.""" - + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): super().__init__( - message, - ErrorCategory.CONNECTION, - context, + message, + ErrorCategory.CONNECTION, + context, recoverable=True, **kwargs ) @@ -122,10 +122,6 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg ) -# Backward compatibility alias -RuntimeError = ExecutionError - - class BuildError(MADEngineError): """Build and compilation errors.""" @@ -191,14 +187,14 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg ) -class TimeoutError(MADEngineError): +class DeploymentTimeoutError(MADEngineError): """Timeout and duration errors.""" - + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): super().__init__( - message, - ErrorCategory.TIMEOUT, - context, + message, + ErrorCategory.TIMEOUT, + context, recoverable=True, **kwargs ) @@ -387,4 +383,10 @@ def create_error_context( phase=phase, component=component, **kwargs - ) \ No newline at end of file + ) + + +# Backward-compatible aliases for renamed error classes. +# These avoid shadowing builtins.ConnectionError and builtins.TimeoutError. 
+ConnectionError = NetworkError # noqa: A001 +TimeoutError = DeploymentTimeoutError # noqa: A001 \ No newline at end of file diff --git a/src/madengine/deployment/base.py b/src/madengine/deployment/base.py index 52bbe02f..d4beefeb 100644 --- a/src/madengine/deployment/base.py +++ b/src/madengine/deployment/base.py @@ -631,11 +631,13 @@ def _write_to_perf_csv(self, perf_data: Dict[str, Any]) -> None: row_to_write = perf_data with open(perf_csv_path, "a", newline="", encoding="utf-8") as f: - writer = csv.DictWriter(f, fieldnames=headers, extrasaction="ignore") - if not file_exists: - writer.writeheader() if file_exists and existing_header: + # File already has a header — write a plain row using csv.writer + # to preserve the exact column order captured in row_to_write csv.writer(f).writerow(row_to_write) else: + # New file — write header then the data row via DictWriter + writer = csv.DictWriter(f, fieldnames=headers, extrasaction="ignore") + writer.writeheader() writer.writerow(row_to_write) diff --git a/src/madengine/deployment/common.py b/src/madengine/deployment/common.py index 93ae1881..5b898960 100644 --- a/src/madengine/deployment/common.py +++ b/src/madengine/deployment/common.py @@ -8,6 +8,7 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ +import functools import subprocess from typing import Any, Dict, List, Optional @@ -84,6 +85,7 @@ def normalize_launcher(launcher_type: Optional[str], deployment_type: str) -> st return "docker" +@functools.lru_cache(maxsize=None) def is_rocprofv3_available() -> bool: """ Check if rocprofv3 is available on the system. 
diff --git a/src/madengine/deployment/factory.py b/src/madengine/deployment/factory.py index 9391d3a3..1988259e 100644 --- a/src/madengine/deployment/factory.py +++ b/src/madengine/deployment/factory.py @@ -88,8 +88,13 @@ def register_default_deployments(): DeploymentFactory.register("k8s", KubernetesDeployment) DeploymentFactory.register("kubernetes", KubernetesDeployment) except ImportError: - # Kubernetes library not installed, skip registration - pass + import warnings + warnings.warn( + "Kubernetes deployment target is unavailable: the 'kubernetes' library is not " + "installed. Install it with: pip install madengine[all]", + ImportWarning, + stacklevel=2, + ) # Auto-register on module import diff --git a/src/madengine/deployment/slurm.py b/src/madengine/deployment/slurm.py index a45f83d3..5550a4ec 100644 --- a/src/madengine/deployment/slurm.py +++ b/src/madengine/deployment/slurm.py @@ -1019,20 +1019,26 @@ def _check_job_completion(self, job_id: str) -> DeploymentResult: message=f"Job {job_id} failed: {status}", ) - # Fallback - assume completed - self.console.print(f"[dim yellow]Warning: Could not get status for job {job_id}, assuming success[/dim yellow]") + # sacct returned non-zero — status unknown, do not assume success + self.console.print( + f"[yellow]Warning: sacct returned non-zero for job {job_id} " + f"(exit code {result.returncode}). Status cannot be verified.[/yellow]" + ) return DeploymentResult( - status=DeploymentStatus.SUCCESS, + status=DeploymentStatus.FAILED, deployment_id=job_id, - message=f"Job {job_id} completed (assumed)", + message=f"Job {job_id} status unknown: sacct exited with code {result.returncode}", ) except Exception as e: - self.console.print(f"[dim yellow]Warning: Exception checking job {job_id}: {e}[/dim yellow]") + self.console.print( + f"[yellow]Warning: Exception checking job {job_id} status: {e}. 
" + f"Status cannot be verified.[/yellow]" + ) return DeploymentResult( - status=DeploymentStatus.SUCCESS, + status=DeploymentStatus.FAILED, deployment_id=job_id, - message=f"Job {job_id} completed (status unavailable)", + message=f"Job {job_id} status unknown: {e}", ) def _build_perf_entry_from_aggregated( diff --git a/src/madengine/execution/container_runner.py b/src/madengine/execution/container_runner.py index d5f27cf0..1eca35bb 100644 --- a/src/madengine/execution/container_runner.py +++ b/src/madengine/execution/container_runner.py @@ -441,8 +441,8 @@ def login_to_registry(self, registry: str, credentials: typing.Dict = None) -> N username = str(creds["username"]) password = str(creds["password"]) - # Perform docker login - login_command = f"echo '{password}' | docker login" + # Perform docker login — shlex.quote handles passwords with special chars + login_command = f"echo {shlex.quote(password)} | docker login" if registry and registry.lower() not in ["docker.io", "dockerhub"]: login_command += f" {registry}" @@ -604,7 +604,6 @@ def get_env_arg(self, run_env: typing.Dict) -> str: for env_arg in self.context.ctx["docker_env_vars"].keys(): env_args += f"--env {env_arg}='{str(self.context.ctx['docker_env_vars'][env_arg])}' " - print(f"Env arguments: {env_args}") return env_args def get_mount_arg(self, mount_datapaths: typing.List) -> str: diff --git a/src/madengine/execution/docker_builder.py b/src/madengine/execution/docker_builder.py index f769b85f..5d4f4abd 100644 --- a/src/madengine/execution/docker_builder.py +++ b/src/madengine/execution/docker_builder.py @@ -7,7 +7,9 @@ and then distributed to remote nodes for execution. """ +import glob import os +import shlex import time import json import re @@ -67,7 +69,7 @@ def get_context_path(self, info: typing.Dict) -> str: return "." 
return "./docker" - def get_build_arg(self, run_build_arg: typing.Dict = {}) -> str: + def get_build_arg(self, run_build_arg: typing.Optional[typing.Dict] = None) -> str: """Get the build arguments. Args: @@ -76,6 +78,8 @@ def get_build_arg(self, run_build_arg: typing.Dict = {}) -> str: Returns: str: The build arguments. """ + if run_build_arg is None: + run_build_arg = {} if not run_build_arg and "docker_build_arg" not in self.context.ctx: return "" @@ -84,14 +88,14 @@ def get_build_arg(self, run_build_arg: typing.Dict = {}) -> str: build_args += ( "--build-arg " + build_arg - + "='" - + self.context.ctx["docker_build_arg"][build_arg] - + "' " + + "=" + + shlex.quote(self.context.ctx["docker_build_arg"][build_arg]) + + " " ) if run_build_arg: for key, value in run_build_arg.items(): - build_args += "--build-arg " + key + "='" + value + "' " + build_args += "--build-arg " + key + "=" + shlex.quote(value) + " " return build_args @@ -300,8 +304,8 @@ def login_to_registry(self, registry: str, credentials: typing.Dict = None) -> N username = str(creds["username"]) password = str(creds["password"]) - # Perform docker login - login_command = f"echo '{password}' | docker login" + # Perform docker login — shlex.quote handles passwords with special chars + login_command = f"echo {shlex.quote(password)} | docker login" if registry and registry.lower() not in ["docker.io", "dockerhub"]: login_command += f" {registry}" @@ -604,8 +608,10 @@ def _check_dockerfile_has_gpu_variables(self, model_info: typing.Dict) -> typing def _get_dockerfiles_for_model(self, model_info: typing.Dict) -> typing.List[str]: """Get dockerfiles for a model.""" try: + # Quote the dockerfile path to prevent shell injection + dockerfile_quoted = shlex.quote(model_info["dockerfile"]) all_dockerfiles = self.console.sh( - f"ls {model_info['dockerfile']}.*" + f"ls {dockerfile_quoted}.*" ).split("\n") dockerfiles = {} diff --git a/src/madengine/orchestration/build_orchestrator.py 
b/src/madengine/orchestration/build_orchestrator.py index da06f91f..61c1f3da 100644 --- a/src/madengine/orchestration/build_orchestrator.py +++ b/src/madengine/orchestration/build_orchestrator.py @@ -20,6 +20,7 @@ from madengine.core.console import Console from madengine.core.context import Context from madengine.core.additional_context_defaults import apply_build_context_defaults +from madengine.core.auth import load_credentials from madengine.core.errors import ( BuildError, ConfigurationError, @@ -104,9 +105,8 @@ def __init__(self, args, additional_context: Optional[Dict] = None): # 4. Add 'deploy' field for internal use self.additional_context = ConfigLoader.load_config(self.additional_context) except ValueError as e: - # Configuration validation error - fail fast - self.rich_console.print(f"[red]Configuration Error: {e}[/red]") - raise SystemExit(1) + # Re-raise as ConfigurationError so the CLI layer handles the exit code + raise ConfigurationError(str(e)) except Exception as e: # Other errors during config loading - warn but continue self.rich_console.print(f"[yellow]Warning: Could not apply config defaults: {e}[/yellow]") @@ -131,53 +131,7 @@ def __init__(self, args, additional_context: Optional[Dict] = None): ) # Load credentials if available - self.credentials = self._load_credentials() - - def _load_credentials(self) -> Optional[Dict]: - """Load credentials from credential.json and environment variables.""" - credentials = None - - # Try loading from file - credential_file = "credential.json" - if os.path.exists(credential_file): - try: - with open(credential_file) as f: - credentials = json.load(f) - print(f"Loaded credentials from {credential_file}: {list(credentials.keys())}") - except Exception as e: - context = create_error_context( - operation="load_credentials", - component="BuildOrchestrator", - file_path=credential_file, - ) - handle_error( - ConfigurationError( - f"Could not load credentials: {e}", - context=context, - suggestions=[ - "Check if 
credential.json exists and has valid JSON format" - ], - ) - ) - - # Override with environment variables if present - docker_hub_user = os.environ.get("MAD_DOCKERHUB_USER") - docker_hub_password = os.environ.get("MAD_DOCKERHUB_PASSWORD") - docker_hub_repo = os.environ.get("MAD_DOCKERHUB_REPO") - - if docker_hub_user and docker_hub_password: - print("Found Docker Hub credentials in environment variables") - if credentials is None: - credentials = {} - - credentials["dockerhub"] = { - "username": docker_hub_user, - "password": docker_hub_password, - } - if docker_hub_repo: - credentials["dockerhub"]["repository"] = docker_hub_repo - - return credentials + self.credentials = load_credentials() def _copy_scripts(self): """[DEPRECATED] Copy common scripts to model directories. diff --git a/src/madengine/orchestration/run_orchestrator.py b/src/madengine/orchestration/run_orchestrator.py index 6725a457..4c2c0ec8 100644 --- a/src/madengine/orchestration/run_orchestrator.py +++ b/src/madengine/orchestration/run_orchestrator.py @@ -21,6 +21,7 @@ from rich.panel import Panel from madengine.core.console import Console +from madengine.core.auth import load_credentials from madengine.core.context import Context from madengine.core.dataprovider import Data from madengine.core.errors import ( @@ -554,7 +555,7 @@ def _execute_local(self, manifest_file: str, timeout: int) -> Dict: from madengine.execution.container_runner import ContainerRunner # Load credentials - credentials = self._load_credentials() + credentials = load_credentials() # Restore context from manifest if present if "context" in manifest: @@ -992,35 +993,6 @@ def ignore_cache_files(directory, files): # Note: K8s and Slurm deployments have their own script handling mechanisms # and do not rely on this local filesystem operation - def _load_credentials(self) -> Optional[Dict]: - """Load credentials from credential.json and environment.""" - credentials = None - - credential_file = "credential.json" - if 
os.path.exists(credential_file): - try: - with open(credential_file) as f: - credentials = json.load(f) - except Exception as e: - print(f"Warning: Could not load credentials: {e}") - - # Override with environment variables - docker_hub_user = os.environ.get("MAD_DOCKERHUB_USER") - docker_hub_password = os.environ.get("MAD_DOCKERHUB_PASSWORD") - docker_hub_repo = os.environ.get("MAD_DOCKERHUB_REPO") - - if docker_hub_user and docker_hub_password: - if credentials is None: - credentials = {} - credentials["dockerhub"] = { - "username": docker_hub_user, - "password": docker_hub_password, - } - if docker_hub_repo: - credentials["dockerhub"]["repository"] = docker_hub_repo - - return credentials - def _filter_images_by_gpu_compatibility( self, built_images: Dict, runtime_gpu_vendor: str, runtime_gpu_arch: str ) -> Dict: @@ -1133,70 +1105,4 @@ def _infer_deployment_target(self, config: Dict) -> str: else: return "local" - def _filter_images_by_dockerfile_context(self, built_images: Dict) -> Dict: - """Filter images by dockerfile context matching runtime context. - - This implements the legacy behavior where dockerfiles are filtered - at runtime based on their CONTEXT header matching the current runtime context. - - Args: - built_images: Dictionary of built images from manifest - - Returns: - Dictionary of images that match the runtime context - """ - if not self.context: - return built_images - - compatible_images = {} - - for image_name, image_info in built_images.items(): - dockerfile = image_info.get("dockerfile", "") - - if not dockerfile: - # No dockerfile info, include by default (legacy compatibility) - compatible_images[image_name] = image_info - continue - - # Check if dockerfile exists - if not os.path.exists(dockerfile): - self.rich_console.print( - f"[dim] Warning: Dockerfile {dockerfile} not found. 
Including by default.[/dim]" - ) - compatible_images[image_name] = image_info - continue - - # Read dockerfile context header - try: - dockerfile_context_str = self.console.sh( - f"head -n5 {dockerfile} | grep '# CONTEXT ' | sed 's/# CONTEXT //g'" - ).strip() - - if not dockerfile_context_str: - # No context header, include by default - compatible_images[image_name] = image_info - continue - - # Create a dict with this dockerfile and its context - dockerfile_dict = {dockerfile: dockerfile_context_str} - - # Use context.filter() to check if this dockerfile matches runtime context - filtered = self.context.filter(dockerfile_dict) - - if filtered: - # Dockerfile matches runtime context - compatible_images[image_name] = image_info - else: - self.rich_console.print( - f"[dim] Skipping {image_name}: dockerfile context doesn't match runtime context[/dim]" - ) - - except Exception as e: - # If we can't read the dockerfile, include it by default - self.rich_console.print( - f"[dim] Warning: Could not read context for {dockerfile}: {e}. Including by default.[/dim]" - ) - compatible_images[image_name] = image_info - - return compatible_images diff --git a/src/madengine/reporting/csv_to_email.py b/src/madengine/reporting/csv_to_email.py index 0902ef00..4b21bc17 100644 --- a/src/madengine/reporting/csv_to_email.py +++ b/src/madengine/reporting/csv_to_email.py @@ -9,7 +9,7 @@ import os import argparse import logging -from typing import List, Tuple +from typing import List, Optional, Tuple import pandas as pd @@ -60,7 +60,7 @@ def csv_to_html_section(file_path: str) -> Tuple[str, str]: def convert_directory_csvs_to_html( directory_path: str, output_file: str = "run_results.html" -) -> str: +) -> Optional[str]: """Convert all CSV files in a directory to a single HTML file. 
Args: diff --git a/src/madengine/reporting/update_perf_csv.py b/src/madengine/reporting/update_perf_csv.py index 0859c9c0..f298efa2 100644 --- a/src/madengine/reporting/update_perf_csv.py +++ b/src/madengine/reporting/update_perf_csv.py @@ -62,7 +62,7 @@ def flatten_tags(perf_entry: dict): The performance entry with flattened tags. """ # flatten tags to a string, if tags is a list. - if type(perf_entry["tags"]) == list: + if isinstance(perf_entry["tags"], list): perf_entry["tags"] = ",".join(str(item) for item in perf_entry["tags"]) @@ -192,6 +192,9 @@ def handle_single_result(perf_csv_df: pd.DataFrame, single_result: str) -> pd.Da AssertionError: If the number of columns in the performance csv DataFrame is not equal """ single_result_json = read_json(single_result) + # Remove non-scalar fields that are not perf.csv columns (e.g. configs list). + # See handle_exception_result for rationale. + single_result_json.pop("configs", None) perf_entry_dict_to_csv(single_result_json) single_result_df = pd.DataFrame(single_result_json, index=[0]) if perf_csv_df.empty: @@ -226,6 +229,11 @@ def handle_exception_result( AssertionError: If there is already an entry for the model in the performance csv DataFrame. """ exception_result_json = read_json(exception_result) + # Remove non-scalar fields that are not perf.csv columns (e.g. configs list) + # before constructing a single-row DataFrame with index=[0]. + # pd.DataFrame(dict_with_list_value, index=[0]) raises ValueError when any + # dict value is a list whose length != 1. 
+ exception_result_json.pop("configs", None) perf_entry_dict_to_csv(exception_result_json) exception_result_df = pd.DataFrame(exception_result_json, index=[0]) if perf_csv_df.empty: diff --git a/src/madengine/utils/config_parser.py b/src/madengine/utils/config_parser.py index ec988570..04e71f9c 100644 --- a/src/madengine/utils/config_parser.py +++ b/src/madengine/utils/config_parser.py @@ -184,7 +184,7 @@ def _walk_up_between( current = os.path.abspath(start_dir) stop = os.path.abspath(stop_dir) - while current.startswith(stop): + while current == stop or current.startswith(stop + os.sep): parent = os.path.dirname(current) if parent == current: # Reached root break diff --git a/src/madengine/utils/discover_models.py b/src/madengine/utils/discover_models.py index 4c3c9201..0cf0438e 100644 --- a/src/madengine/utils/discover_models.py +++ b/src/madengine/utils/discover_models.py @@ -85,6 +85,7 @@ def _setup_model_dir_if_needed(self) -> None: # Only copy if MODEL_DIR points to a different directory (not current dir) if model_dir_abs != cwd_abs: + import shlex import subprocess from pathlib import Path @@ -121,7 +122,7 @@ def _setup_model_dir_if_needed(self) -> None: copied_count = 0 for src_path, item_name, item_type in items_to_copy: try: - cmd = f"cp -vLR --preserve=all {src_path} {cwd_abs}/" + cmd = f"cp -vLR --preserve=all {shlex.quote(str(src_path))} {shlex.quote(str(cwd_abs))}/" result = subprocess.run( cmd, shell=True, capture_output=True, text=True, check=True ) @@ -216,9 +217,12 @@ def discover_models(self) -> None: custom_model_list = get_models_json.list_models() for custom_model in custom_model_list: - assert isinstance( - custom_model, CustomModel - ), "Please use or subclass madengine.utils.discover_models.CustomModel to define your custom model." + if not isinstance(custom_model, CustomModel): + raise TypeError( + "Please use or subclass " + "madengine.utils.discover_models.CustomModel " + "to define your custom model." 
+ ) # Update model name using backslash-separated path custom_model.name = dirname + "/" + custom_model.name # Defer updating script and dockerfile paths until update_model is called diff --git a/src/madengine/utils/gpu_config.py b/src/madengine/utils/gpu_config.py index ff6aabc8..4b3c4143 100644 --- a/src/madengine/utils/gpu_config.py +++ b/src/madengine/utils/gpu_config.py @@ -14,9 +14,12 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ +import logging import warnings from typing import Dict, Any, Optional, Tuple +logger = logging.getLogger(__name__) + class GPUConfigResolver: """ @@ -157,17 +160,18 @@ def _extract_gpu_count( # Warn if multiple GPU fields found if len(found_fields) > 1: field_list = ", ".join([f"{name}={val}" for name, val in found_fields]) - print( - f"⚠️ Multiple GPU fields in {context}: {field_list}. " - f"Using {found_fields[0][0]}={found_fields[0][1]}" + logger.warning( + "Multiple GPU fields in %s: %s. Using %s=%s", + context, field_list, found_fields[0][0], found_fields[0][1], ) # Convert to int (handle string values like "8") try: return int(found_fields[0][1]) except (ValueError, TypeError): - print( - f"⚠️ Invalid GPU count in {context}: {found_fields[0][1]}. Using default." + logger.warning( + "Invalid GPU count in %s: %s. 
Using default.", + context, found_fields[0][1], ) return None @@ -231,10 +235,9 @@ def _validate_consistency( if is_deployment_override: # This is normal - deployment config overriding model default - # Use print instead of warnings.warn for cleaner output - print( - f"ℹ️ GPU configuration override: {sources[0][0]}={sources[0][1]} " - f"(overriding model default: {mismatch_details.split(',')[-1].strip()})" + logger.info( + "GPU configuration override: %s=%s (overriding model default: %s)", + sources[0][0], sources[0][1], mismatch_details.split(",")[-1].strip(), ) else: # Potentially unexpected mismatch - use warning for actual errors @@ -302,7 +305,7 @@ def resolve_runtime_gpus( validate=True, ) - print(f"ℹ️ Resolved GPU count: {gpu_count} (from {source})") + logger.info("Resolved GPU count: %s (from %s)", gpu_count, source) return gpu_count diff --git a/src/madengine/utils/log_formatting.py b/src/madengine/utils/log_formatting.py index 31673c93..d7b6c5f5 100644 --- a/src/madengine/utils/log_formatting.py +++ b/src/madengine/utils/log_formatting.py @@ -82,7 +82,7 @@ def format_dataframe_for_log( header += f"📏 Shape: {df.shape[0]} rows × {df.shape[1]} columns\n" if truncated_rows: - header += f"⚠️ Display truncated: showing first {max_rows} rows\n" + header += f"⚠️ Display truncated: showing last {max_rows} rows\n" header += f"{'='*80}\n" diff --git a/src/madengine/utils/rocm_tool_manager.py b/src/madengine/utils/rocm_tool_manager.py index 439f7da2..60870d29 100644 --- a/src/madengine/utils/rocm_tool_manager.py +++ b/src/madengine/utils/rocm_tool_manager.py @@ -199,14 +199,15 @@ def execute_command( self._log_debug(f"Command succeeded: {command[:50]}...") return stdout - # Log primary failure - self._log_warning(f"Primary command failed: {command[:50]}... Error: {stderr}") - + # Capture primary error before attempting fallback (fallback overwrites stderr) + primary_stderr = stderr + self._log_warning(f"Primary command failed: {command[:50]}... 
Error: {primary_stderr}") + # Try fallback if provided if fallback_command: self._log_info(f"Trying fallback command: {fallback_command[:50]}...") success, stdout, stderr = self._execute_shell_command(fallback_command, timeout) - + if success: self._log_warning("Fallback command succeeded (primary tool may be missing or misconfigured)") return stdout @@ -215,7 +216,7 @@ def execute_command( raise RuntimeError( f"Both primary and fallback commands failed.\n" f"Primary: {command}\n" - f"Primary error: {stderr}\n" + f"Primary error: {primary_stderr}\n" f"Fallback: {fallback_command}\n" f"Fallback error: {stderr}" ) diff --git a/src/madengine/utils/session_tracker.py b/src/madengine/utils/session_tracker.py index 6ddd1d92..d6163d74 100644 --- a/src/madengine/utils/session_tracker.py +++ b/src/madengine/utils/session_tracker.py @@ -47,11 +47,12 @@ def start_session(self) -> int: The starting row number (number of rows in CSV before this session) """ if self.perf_csv_path.exists(): - # Count existing rows (excluding header) + # Count existing data rows (excluding header and blank lines) with open(self.perf_csv_path, 'r') as f: lines = f.readlines() + non_empty = [l for l in lines if l.strip()] # Subtract 1 for header row - self.session_start_row = max(0, len(lines) - 1) + self.session_start_row = max(0, len(non_empty) - 1) else: # No existing file, start at 0 self.session_start_row = 0 @@ -85,7 +86,8 @@ def get_session_row_count(self) -> int: with open(self.perf_csv_path, 'r') as f: lines = f.readlines() - current_row_count = max(0, len(lines) - 1) # Exclude header + non_empty = [l for l in lines if l.strip()] + current_row_count = max(0, len(non_empty) - 1) # Exclude header return current_row_count - self.session_start_row diff --git a/tests/integration/test_docker_integration.py b/tests/integration/test_docker_integration.py index 14041e3e..a7421d6b 100644 --- a/tests/integration/test_docker_integration.py +++ b/tests/integration/test_docker_integration.py @@ -8,6 +8,7 
@@ # built-in modules import os import json +import shlex import tempfile import unittest.mock from unittest.mock import patch, MagicMock, mock_open @@ -169,8 +170,8 @@ def test_get_build_arg_with_context_args( result = builder.get_build_arg() - assert "--build-arg ARG1='value1'" in result - assert "--build-arg ARG2='value2'" in result + assert f"--build-arg ARG1={shlex.quote('value1')}" in result + assert f"--build-arg ARG2={shlex.quote('value2')}" in result @patch.object(Context, "get_gpu_vendor", return_value="AMD") @patch.object(Context, "get_system_ngpus", return_value=1) @@ -188,7 +189,7 @@ def test_get_build_arg_with_run_args( run_build_arg = {"RUNTIME_ARG": "runtime_value"} result = builder.get_build_arg(run_build_arg) - assert "--build-arg RUNTIME_ARG='runtime_value'" in result + assert f"--build-arg RUNTIME_ARG={shlex.quote('runtime_value')}" in result @patch.object(Context, "get_gpu_vendor", return_value="AMD") @patch.object(Context, "get_system_ngpus", return_value=1) @@ -207,8 +208,8 @@ def test_get_build_arg_with_both_args( run_build_arg = {"RUNTIME_ARG": "runtime_value"} result = builder.get_build_arg(run_build_arg) - assert "--build-arg CONTEXT_ARG='context_value'" in result - assert "--build-arg RUNTIME_ARG='runtime_value'" in result + assert f"--build-arg CONTEXT_ARG={shlex.quote('context_value')}" in result + assert f"--build-arg RUNTIME_ARG={shlex.quote('runtime_value')}" in result @patch.object(Context, "get_gpu_vendor", return_value="AMD") @patch.object(Context, "get_system_ngpus", return_value=1) diff --git a/tests/integration/test_errors.py b/tests/integration/test_errors.py index e325e6ac..c0a88876 100644 --- a/tests/integration/test_errors.py +++ b/tests/integration/test_errors.py @@ -126,7 +126,7 @@ def test_error_logging_integration(self): def test_error_context_serialization(self): """Error context can be serialized for logging.""" - from madengine.core.errors import RuntimeError + from madengine.core.errors import ExecutionError 
context = create_error_context( operation="model_execution", @@ -136,7 +136,7 @@ def test_error_context_serialization(self): node_id="worker-node-01", additional_info={"container_id": "abc123", "gpu_count": 2}, ) - error = RuntimeError("Model execution failed", context=context) + error = ExecutionError("Model execution failed", context=context) data = json.dumps(error.context.__dict__, default=str) assert "model_execution" in data and "ContainerRunner" in data and "abc123" in data @@ -184,28 +184,28 @@ def test_error_hierarchy_consistency(self): """All error types inherit MADEngineError and have context/category/recoverable.""" from madengine.core.errors import ( ValidationError, - ConnectionError, + NetworkError, AuthenticationError, - RuntimeError, + ExecutionError, BuildError, DiscoveryError, OrchestrationError, RunnerError, ConfigurationError, - TimeoutError, + DeploymentTimeoutError, ) for error_class in [ ValidationError, - ConnectionError, + NetworkError, AuthenticationError, - RuntimeError, + ExecutionError, BuildError, DiscoveryError, OrchestrationError, RunnerError, ConfigurationError, - TimeoutError, + DeploymentTimeoutError, ]: err = error_class("Test error message") assert isinstance(err, MADEngineError) @@ -245,9 +245,9 @@ def test_error_suggestions_and_recovery(self): def test_nested_error_handling(self): """Nested errors with cause chain are handled.""" - from madengine.core.errors import RuntimeError as MADRuntimeError, OrchestrationError + from madengine.core.errors import ExecutionError as MADRuntimeError, OrchestrationError, NetworkError - orig = ConnectionError("Network timeout") + orig = NetworkError("Network timeout") runtime = MADRuntimeError("Operation failed", cause=orig) final = OrchestrationError("Orchestration failed", cause=runtime) assert final.cause == runtime and runtime.cause == orig diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py new file mode 100644 index 00000000..767ee761 --- /dev/null +++ 
b/tests/unit/test_auth.py @@ -0,0 +1,117 @@ +"""Unit tests for madengine.core.auth module.""" + +import json +import os +from unittest.mock import mock_open, patch, MagicMock + +import pytest + +from madengine.core.auth import load_credentials + + +class TestLoadCredentials: + """Tests for load_credentials().""" + + @patch("madengine.core.auth.os.path.exists", return_value=True) + @patch( + "builtins.open", + new_callable=mock_open, + read_data='{"dockerhub": {"username": "user", "password": "pass"}}', + ) + def test_load_credentials_from_file(self, mock_file, mock_exists): + """Valid credential.json is loaded and returned.""" + result = load_credentials() + assert result is not None + assert "dockerhub" in result + assert result["dockerhub"]["username"] == "user" + assert result["dockerhub"]["password"] == "pass" + + @patch("madengine.core.auth.os.path.exists", return_value=False) + @patch.dict(os.environ, {}, clear=True) + def test_load_credentials_no_file_no_env(self, mock_exists): + """Returns None when no credential file and no env vars.""" + result = load_credentials() + assert result is None + + @patch("madengine.core.auth.os.path.exists", return_value=True) + @patch("builtins.open", new_callable=mock_open, read_data="not valid json{{{") + def test_load_credentials_malformed_json(self, mock_file, mock_exists): + """Malformed credential.json is handled gracefully (returns None).""" + # The function logs the error via handle_error but does not re-raise + result = load_credentials() + # credentials should be None since the file parse failed and no env vars + assert result is None + + @patch("madengine.core.auth.os.path.exists", return_value=False) + @patch.dict( + os.environ, + {"MAD_DOCKERHUB_USER": "envuser", "MAD_DOCKERHUB_PASSWORD": "envpass"}, + clear=True, + ) + def test_load_credentials_env_vars_only(self, mock_exists): + """Credentials from env vars when no file exists.""" + result = load_credentials() + assert result is not None + assert "dockerhub" in 
result + assert result["dockerhub"]["username"] == "envuser" + assert result["dockerhub"]["password"] == "envpass" + assert "repository" not in result["dockerhub"] + + @patch("madengine.core.auth.os.path.exists", return_value=True) + @patch( + "builtins.open", + new_callable=mock_open, + read_data='{"dockerhub": {"username": "fileuser", "password": "filepass"}}', + ) + @patch.dict( + os.environ, + {"MAD_DOCKERHUB_USER": "envuser", "MAD_DOCKERHUB_PASSWORD": "envpass"}, + clear=True, + ) + def test_load_credentials_env_overrides_file(self, mock_file, mock_exists): + """Env vars override file credentials for dockerhub key.""" + result = load_credentials() + assert result is not None + assert result["dockerhub"]["username"] == "envuser" + assert result["dockerhub"]["password"] == "envpass" + + @patch("madengine.core.auth.os.path.exists", return_value=False) + @patch.dict( + os.environ, + { + "MAD_DOCKERHUB_USER": "envuser", + "MAD_DOCKERHUB_PASSWORD": "envpass", + "MAD_DOCKERHUB_REPO": "myrepo/images", + }, + clear=True, + ) + def test_load_credentials_env_with_repo(self, mock_exists): + """MAD_DOCKERHUB_REPO is included when set.""" + result = load_credentials() + assert result is not None + assert result["dockerhub"]["repository"] == "myrepo/images" + + @patch("madengine.core.auth.os.path.exists", return_value=False) + @patch.dict( + os.environ, + {"MAD_DOCKERHUB_USER": "envuser"}, + clear=True, + ) + def test_load_credentials_env_user_only_no_password(self, mock_exists): + """Only MAD_DOCKERHUB_USER without PASSWORD does not create dockerhub entry.""" + result = load_credentials() + # Without both user and password, dockerhub credentials are not created + assert result is None + + @patch("madengine.core.auth.os.path.exists", return_value=True) + @patch( + "builtins.open", + new_callable=mock_open, + read_data='{"custom_registry": {"token": "abc123"}}', + ) + def test_load_credentials_non_dockerhub_registry(self, mock_file, mock_exists): + """Non-dockerhub registries 
in credential.json are preserved.""" + result = load_credentials() + assert result is not None + assert "custom_registry" in result + assert result["custom_registry"]["token"] == "abc123" diff --git a/tests/unit/test_deployment.py b/tests/unit/test_deployment.py index a71c75e8..d51d94b9 100644 --- a/tests/unit/test_deployment.py +++ b/tests/unit/test_deployment.py @@ -85,6 +85,14 @@ def test_false_for_rocm_trace_lite(self): class TestIsRocprofv3Available: """is_rocprofv3_available (mocked subprocess).""" + def setup_method(self): + # Clear the lru_cache so each test starts with a fresh result + is_rocprofv3_available.cache_clear() + + def teardown_method(self): + # Restore clean cache state after each test + is_rocprofv3_available.cache_clear() + def test_returns_true_when_help_succeeds(self): with patch("madengine.deployment.common.subprocess.run") as m: m.return_value = MagicMock(returncode=0) diff --git a/tests/unit/test_error_handling.py b/tests/unit/test_error_handling.py index 1fa808e4..45422c34 100644 --- a/tests/unit/test_error_handling.py +++ b/tests/unit/test_error_handling.py @@ -24,16 +24,15 @@ ErrorContext, MADEngineError, ValidationError, - ConnectionError, + NetworkError, AuthenticationError, ExecutionError, - RuntimeError, # Backward compatibility alias BuildError, DiscoveryError, OrchestrationError, RunnerError, ConfigurationError, - TimeoutError, + DeploymentTimeoutError, ErrorHandler, set_error_handler, get_error_handler, @@ -87,7 +86,7 @@ def test_base_madengine_error(self): @pytest.mark.parametrize("error_class,category,recoverable,message", [ (ValidationError, ErrorCategory.VALIDATION, True, "Invalid input"), - (ConnectionError, ErrorCategory.CONNECTION, True, "Connection failed"), + (NetworkError, ErrorCategory.CONNECTION, True, "Connection failed"), (BuildError, ErrorCategory.BUILD, False, "Build failed"), (RunnerError, ErrorCategory.RUNNER, True, "Runner execution failed"), (AuthenticationError, ErrorCategory.AUTHENTICATION, True, "Auth 
failed"), @@ -110,9 +109,9 @@ def test_error_with_cause(self): assert mad_error.cause == original_error assert str(mad_error) == "Runtime failure" - def test_backward_compatibility_alias(self): - """Test that RuntimeError alias still works.""" - error = RuntimeError("Test error") + def test_execution_error_is_mad_engine_error(self): + """Test that ExecutionError is a MADEngineError.""" + error = ExecutionError("Test error") assert isinstance(error, ExecutionError) assert isinstance(error, MADEngineError) From d40b407469d22f319512eaecbf8360b7871e9901 Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Sat, 18 Apr 2026 12:04:45 -0500 Subject: [PATCH 2/9] refactor(v2-review): extract shared login_to_registry and remove dead code - Add login_to_registry() to core/auth.py; DockerBuilder and ContainerRunner delegate to it (raise_on_failure=True/False respectively), eliminating ~120 lines of duplicated logic - Remove unused functions: find_and_replace_pattern, substring_found (ops.py), highlight_log_section (log_formatting.py), SessionTracker.get_session_start and SessionTracker.load_marker (session_tracker.py) - Clean up unused imports across core, deployment, execution, orchestration, and utils modules Co-Authored-By: Claude Sonnet 4 --- src/madengine/core/auth.py | 94 +++++++++++++++++++ src/madengine/core/context.py | 6 +- src/madengine/core/errors.py | 50 +++++----- src/madengine/core/timeout.py | 1 - src/madengine/deployment/base.py | 2 +- src/madengine/deployment/config_loader.py | 1 - src/madengine/deployment/kubernetes.py | 6 +- src/madengine/execution/container_runner.py | 77 +++------------ src/madengine/execution/docker_builder.py | 77 ++------------- .../orchestration/build_orchestrator.py | 2 - .../orchestration/run_orchestrator.py | 1 - .../scripts/common/tools/amd_smi_utils.py | 2 +- .../scripts/common/tools/rocm_smi_utils.py | 2 +- src/madengine/utils/gpu_validator.py | 2 +- src/madengine/utils/log_formatting.py | 32 ------- src/madengine/utils/ops.py | 43 
--------- src/madengine/utils/session_tracker.py | 26 ----- 17 files changed, 143 insertions(+), 281 deletions(-) diff --git a/src/madengine/core/auth.py b/src/madengine/core/auth.py index e592dd60..8efceeb1 100644 --- a/src/madengine/core/auth.py +++ b/src/madengine/core/auth.py @@ -10,6 +10,7 @@ import json import os +import shlex from typing import Dict, Optional from madengine.core.errors import ( @@ -75,3 +76,96 @@ def load_credentials() -> Optional[Dict]: credentials["dockerhub"]["repository"] = docker_hub_repo return credentials + + +def login_to_registry( + registry: str, + credentials: Optional[Dict], + console, + rich_console, + raise_on_failure: bool = True, +) -> None: + """Login to a Docker registry. + + This is the single shared implementation used by both DockerBuilder + and ContainerRunner. + + Args: + registry: Registry URL (e.g., "localhost:5000", "docker.io", or empty + for DockerHub). + credentials: Credentials dictionary keyed by registry name. + console: A ``Console`` instance for shell execution. + rich_console: A Rich ``Console`` instance for formatted output. + raise_on_failure: If ``True`` (default), re-raise on login failure. + Set to ``False`` when the caller can fall back to pulling + public images. 
+ """ + if not credentials: + rich_console.print( + "[yellow]No credentials provided for registry login[/yellow]" + ) + return + + registry_key = registry if registry else "dockerhub" + + # Normalise docker.io → dockerhub + if registry and registry.lower() == "docker.io": + registry_key = "dockerhub" + + if registry_key not in credentials: + error_msg = f"No credentials found for registry: {registry_key}" + if registry_key == "dockerhub": + error_msg += ( + f"\nPlease add dockerhub credentials to credential.json:\n" + "{\n" + ' "dockerhub": {\n' + ' "repository": "your-repository",\n' + ' "username": "your-dockerhub-username",\n' + ' "password": "your-dockerhub-password-or-token"\n' + " }\n" + "}" + ) + else: + error_msg += ( + f"\nPlease add {registry_key} credentials to credential.json:\n" + "{\n" + f' "{registry_key}": {{\n' + f' "repository": "your-repository",\n' + f' "username": "your-{registry_key}-username",\n' + f' "password": "your-{registry_key}-password"\n' + " }\n" + "}" + ) + rich_console.print(f"[red]{error_msg}[/red]") + raise RuntimeError(error_msg) + + creds = credentials[registry_key] + + if "username" not in creds or "password" not in creds: + error_msg = ( + f"Invalid credentials format for registry: {registry_key}" + f"\nCredentials must contain 'username' and 'password' fields" + ) + rich_console.print(f"[red]{error_msg}[/red]") + raise RuntimeError(error_msg) + + username = str(creds["username"]) + password = str(creds["password"]) + + login_command = f"echo {shlex.quote(password)} | docker login" + if registry and registry.lower() not in ["docker.io", "dockerhub"]: + login_command += f" {registry}" + login_command += f" --username {username} --password-stdin" + + try: + console.sh(login_command, secret=True) + rich_console.print( + f"[green]Successfully logged in to registry: " + f"{registry or 'DockerHub'}[/green]" + ) + except Exception as e: + rich_console.print( + f"[red]Failed to login to registry {registry}: {e}[/red]" + ) + if 
raise_on_failure: + raise diff --git a/src/madengine/core/context.py b/src/madengine/core/context.py index 57e160c1..6d8089ff 100644 --- a/src/madengine/core/context.py +++ b/src/madengine/core/context.py @@ -23,7 +23,7 @@ # third-party modules from madengine.core.console import Console from madengine.core.constants import get_rocm_path -from madengine.utils.gpu_validator import validate_rocm_installation, GPUInstallationError, GPUVendor +from madengine.utils.gpu_validator import GPUVendor from madengine.utils.gpu_tool_factory import get_gpu_tool_manager from madengine.utils.gpu_tool_manager import BaseGPUToolManager @@ -434,11 +434,11 @@ def get_host_os(self) -> str: "if [ -f \"$(which apt)\" ]; then echo 'HOST_UBUNTU'; elif [ -f \"$(which yum)\" ]; then echo 'HOST_CENTOS'; elif [ -f \"$(which zypper)\" ]; then echo 'HOST_SLES'; elif [ -f \"$(which tdnf)\" ]; then echo 'HOST_AZURE'; else echo 'Unable to detect Host OS'; fi || true" ) - def get_numa_balancing(self) -> bool: + def get_numa_balancing(self) -> typing.Union[str, bool]: """Get NUMA balancing. Returns: - bool: The output of the shell command. + Union[str, bool]: The shell command output as a string, or False if the path does not exist. Raises: RuntimeError: If the NUMA balancing is not enabled or disabled. diff --git a/src/madengine/core/errors.py b/src/madengine/core/errors.py index 2aaf43d0..6a0757ab 100644 --- a/src/madengine/core/errors.py +++ b/src/madengine/core/errors.py @@ -7,7 +7,6 @@ """ import logging -import traceback from dataclasses import dataclass from typing import Optional, Any, Dict, List from enum import Enum @@ -16,14 +15,13 @@ from rich.console import Console from rich.panel import Panel from rich.text import Text - from rich.table import Table except ImportError: raise ImportError("Rich is required for error handling. 
Install with: pip install rich") class ErrorCategory(Enum): """Error category enumeration for classification.""" - + VALIDATION = "validation" CONNECTION = "connection" AUTHENTICATION = "authentication" @@ -72,12 +70,12 @@ def __init__( class ValidationError(MADEngineError): """Validation and input errors.""" - + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): super().__init__( - message, - ErrorCategory.VALIDATION, - context, + message, + ErrorCategory.VALIDATION, + context, recoverable=True, **kwargs ) @@ -98,12 +96,12 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg class AuthenticationError(MADEngineError): """Authentication and credential errors.""" - + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): super().__init__( - message, - ErrorCategory.AUTHENTICATION, - context, + message, + ErrorCategory.AUTHENTICATION, + context, recoverable=True, **kwargs ) @@ -150,12 +148,12 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg class OrchestrationError(MADEngineError): """Distributed orchestration errors.""" - + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): super().__init__( - message, - ErrorCategory.ORCHESTRATION, - context, + message, + ErrorCategory.ORCHESTRATION, + context, recoverable=False, **kwargs ) @@ -163,12 +161,12 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg class RunnerError(MADEngineError): """Distributed runner errors.""" - + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): super().__init__( - message, - ErrorCategory.RUNNER, - context, + message, + ErrorCategory.RUNNER, + context, recoverable=True, **kwargs ) @@ -176,12 +174,12 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg class ConfigurationError(MADEngineError): """Configuration and setup errors.""" - + def 
__init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): super().__init__( - message, - ErrorCategory.CONFIGURATION, - context, + message, + ErrorCategory.CONFIGURATION, + context, recoverable=True, **kwargs ) @@ -384,9 +382,3 @@ def create_error_context( component=component, **kwargs ) - - -# Backward-compatible aliases for renamed error classes. -# These avoid shadowing builtins.ConnectionError and builtins.TimeoutError. -ConnectionError = NetworkError # noqa: A001 -TimeoutError = DeploymentTimeoutError # noqa: A001 \ No newline at end of file diff --git a/src/madengine/core/timeout.py b/src/madengine/core/timeout.py index 0f72bd84..7fbdcb2e 100644 --- a/src/madengine/core/timeout.py +++ b/src/madengine/core/timeout.py @@ -7,7 +7,6 @@ """ # built-in modules import signal -import typing class Timeout: diff --git a/src/madengine/deployment/base.py b/src/madengine/deployment/base.py index d4beefeb..b5df7ea8 100644 --- a/src/madengine/deployment/base.py +++ b/src/madengine/deployment/base.py @@ -14,7 +14,7 @@ from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional from jinja2 import Environment, FileSystemLoader from rich.console import Console diff --git a/src/madengine/deployment/config_loader.py b/src/madengine/deployment/config_loader.py index fdbf3c94..06d8a1b1 100644 --- a/src/madengine/deployment/config_loader.py +++ b/src/madengine/deployment/config_loader.py @@ -11,7 +11,6 @@ """ import json -import os from pathlib import Path from typing import Any, Callable, Dict, Optional from copy import deepcopy diff --git a/src/madengine/deployment/kubernetes.py b/src/madengine/deployment/kubernetes.py index 3430c9d0..92be5549 100644 --- a/src/madengine/deployment/kubernetes.py +++ b/src/madengine/deployment/kubernetes.py @@ -37,17 +37,13 @@ from .base import BaseDeployment, DeploymentConfig, DeploymentResult, 
DeploymentStatus, create_jinja_env from .common import ( - VALID_LAUNCHERS, configure_multi_node_profiling, - is_rocprofv3_available, normalize_launcher, ) from .config_loader import ConfigLoader, apply_deployment_config from .k8s_secrets import ( CONFIGMAP_MAX_BYTES, - SECRETS_STRATEGY_EXISTING, SECRETS_STRATEGY_FROM_LOCAL, - SECRETS_STRATEGY_OMIT, create_or_update_secrets_from_credentials, delete_job_secrets_if_exist, estimate_configmap_payload_bytes, @@ -58,7 +54,7 @@ ) from madengine.core.dataprovider import Data from madengine.core.context import Context -from madengine.core.errors import ConfigurationError, create_error_context +from madengine.core.errors import ConfigurationError from madengine.utils.gpu_config import resolve_runtime_gpus from madengine.utils.path_utils import get_madengine_root, scripts_base_dir_from from madengine.utils.run_details import flatten_tags_in_place, get_build_number, get_pipeline diff --git a/src/madengine/execution/container_runner.py b/src/madengine/execution/container_runner.py index 1eca35bb..5195dfdc 100644 --- a/src/madengine/execution/container_runner.py +++ b/src/madengine/execution/container_runner.py @@ -17,6 +17,7 @@ import warnings from rich.console import Console as RichConsole from contextlib import redirect_stdout, redirect_stderr +from madengine.core.auth import login_to_registry from madengine.core.console import Console from madengine.core.context import Context from madengine.core.docker import Docker @@ -389,72 +390,16 @@ def load_build_manifest( def login_to_registry(self, registry: str, credentials: typing.Dict = None) -> None: """Login to a Docker registry for pulling images. - Args: - registry: Registry URL (e.g., "localhost:5000", "docker.io") - credentials: Optional credentials dictionary containing username/password + Delegates to :func:`madengine.core.auth.login_to_registry`. + Does not raise on failure so public images can still be pulled. 
""" - if not credentials: - self.rich_console.print("[yellow]No credentials provided for registry login[/yellow]") - return - - # Check if registry credentials are available - registry_key = registry if registry else "dockerhub" - - # Handle docker.io as dockerhub - if registry and registry.lower() == "docker.io": - registry_key = "dockerhub" - - if registry_key not in credentials: - error_msg = f"No credentials found for registry: {registry_key}" - if registry_key == "dockerhub": - error_msg += f"\nPlease add dockerhub credentials to credential.json:\n" - error_msg += "{\n" - error_msg += ' "dockerhub": {\n' - error_msg += ' "repository": "your-repository",\n' - error_msg += ' "username": "your-dockerhub-username",\n' - error_msg += ' "password": "your-dockerhub-password-or-token"\n' - error_msg += " }\n" - error_msg += "}" - else: - error_msg += ( - f"\nPlease add {registry_key} credentials to credential.json:\n" - ) - error_msg += "{\n" - error_msg += f' "{registry_key}": {{\n' - error_msg += f' "repository": "your-repository",\n' - error_msg += f' "username": "your-{registry_key}-username",\n' - error_msg += f' "password": "your-{registry_key}-password"\n' - error_msg += " }\n" - error_msg += "}" - print(error_msg) - raise RuntimeError(error_msg) - - creds = credentials[registry_key] - - if "username" not in creds or "password" not in creds: - error_msg = f"Invalid credentials format for registry: {registry_key}" - error_msg += f"\nCredentials must contain 'username' and 'password' fields" - print(error_msg) - raise RuntimeError(error_msg) - - # Ensure credential values are strings - username = str(creds["username"]) - password = str(creds["password"]) - - # Perform docker login — shlex.quote handles passwords with special chars - login_command = f"echo {shlex.quote(password)} | docker login" - - if registry and registry.lower() not in ["docker.io", "dockerhub"]: - login_command += f" {registry}" - - login_command += f" --username {username} --password-stdin" - 
- try: - self.console.sh(login_command, secret=True) - self.rich_console.print(f"[green]✅ Successfully logged in to registry: {registry or 'DockerHub'}[/green]") - except Exception as e: - self.rich_console.print(f"[red]❌ Failed to login to registry {registry}: {e}[/red]") - # Don't raise exception here, as public images might still be pullable + login_to_registry( + registry, + credentials, + console=self.console, + rich_console=self.rich_console, + raise_on_failure=False, + ) def pull_image( self, @@ -493,7 +438,7 @@ def pull_image( try: self.console.sh(f"docker rmi -f {registry_image} 2>/dev/null || true") print(f"✓ Removed cached image layers") - except: + except Exception: pass # It's okay if image doesn't exist try: diff --git a/src/madengine/execution/docker_builder.py b/src/madengine/execution/docker_builder.py index 5d4f4abd..b900a359 100644 --- a/src/madengine/execution/docker_builder.py +++ b/src/madengine/execution/docker_builder.py @@ -7,7 +7,6 @@ and then distributed to remote nodes for execution. """ -import glob import os import shlex import time @@ -16,14 +15,13 @@ import typing from contextlib import redirect_stdout, redirect_stderr from rich.console import Console as RichConsole +from madengine.core.auth import login_to_registry from madengine.core.console import Console from madengine.core.context import Context from madengine.utils.ops import PythonicTee from madengine.execution.dockerfile_utils import ( - is_compilation_arch_compatible, is_target_arch_compatible_with_variable, parse_dockerfile_gpu_variables, - parse_gpu_variable_value, ) @@ -252,72 +250,15 @@ def build_image( def login_to_registry(self, registry: str, credentials: typing.Dict = None) -> None: """Login to a Docker registry. - Args: - registry: Registry URL (e.g., "localhost:5000", "docker.io", or empty for DockerHub) - credentials: Optional credentials dictionary containing username/password + Delegates to :func:`madengine.core.auth.login_to_registry`. 
""" - if not credentials: - print("No credentials provided for registry login") - return - - # Check if registry credentials are available - registry_key = registry if registry else "dockerhub" - - # Handle docker.io as dockerhub - if registry and registry.lower() == "docker.io": - registry_key = "dockerhub" - - if registry_key not in credentials: - error_msg = f"No credentials found for registry: {registry_key}" - if registry_key == "dockerhub": - error_msg += f"\nPlease add dockerhub credentials to credential.json:\n" - error_msg += "{\n" - error_msg += ' "dockerhub": {\n' - error_msg += ' "repository": "your-repository",\n' - error_msg += ' "username": "your-dockerhub-username",\n' - error_msg += ' "password": "your-dockerhub-password-or-token"\n' - error_msg += " }\n" - error_msg += "}" - else: - error_msg += ( - f"\nPlease add {registry_key} credentials to credential.json:\n" - ) - error_msg += "{\n" - error_msg += f' "{registry_key}": {{\n' - error_msg += f' "repository": "your-repository",\n' - error_msg += f' "username": "your-{registry_key}-username",\n' - error_msg += f' "password": "your-{registry_key}-password"\n' - error_msg += " }\n" - error_msg += "}" - self.rich_console.print(f"[red]{error_msg}[/red]") - raise RuntimeError(error_msg) - - creds = credentials[registry_key] - - if "username" not in creds or "password" not in creds: - error_msg = f"Invalid credentials format for registry: {registry_key}" - error_msg += f"\nCredentials must contain 'username' and 'password' fields" - self.rich_console.print(f"[red]{error_msg}[/red]") - raise RuntimeError(error_msg) - - # Ensure credential values are strings - username = str(creds["username"]) - password = str(creds["password"]) - - # Perform docker login — shlex.quote handles passwords with special chars - login_command = f"echo {shlex.quote(password)} | docker login" - - if registry and registry.lower() not in ["docker.io", "dockerhub"]: - login_command += f" {registry}" - - login_command += f" 
--username {username} --password-stdin" - - try: - self.console.sh(login_command, secret=True) - self.rich_console.print(f"[green]✅ Successfully logged in to registry: {registry or 'DockerHub'}[/green]") - except Exception as e: - self.rich_console.print(f"[red]❌ Failed to login to registry {registry}: {e}[/red]") - raise + login_to_registry( + registry, + credentials, + console=self.console, + rich_console=self.rich_console, + raise_on_failure=True, + ) def push_image( self, diff --git a/src/madengine/orchestration/build_orchestrator.py b/src/madengine/orchestration/build_orchestrator.py index 61c1f3da..3f67ff29 100644 --- a/src/madengine/orchestration/build_orchestrator.py +++ b/src/madengine/orchestration/build_orchestrator.py @@ -10,7 +10,6 @@ import json import os -import shutil from pathlib import Path from typing import Dict, List, Optional @@ -26,7 +25,6 @@ ConfigurationError, DiscoveryError, create_error_context, - handle_error, ) from madengine.utils.discover_models import DiscoverModels from madengine.execution.docker_builder import DockerBuilder diff --git a/src/madengine/orchestration/run_orchestrator.py b/src/madengine/orchestration/run_orchestrator.py index 4c2c0ec8..681eb4c9 100644 --- a/src/madengine/orchestration/run_orchestrator.py +++ b/src/madengine/orchestration/run_orchestrator.py @@ -29,7 +29,6 @@ ConfigurationError, ExecutionError, create_error_context, - handle_error, ) from madengine.core.constants import get_rocm_path from madengine.utils.session_tracker import SessionTracker diff --git a/src/madengine/scripts/common/tools/amd_smi_utils.py b/src/madengine/scripts/common/tools/amd_smi_utils.py index 05975257..e0e48096 100644 --- a/src/madengine/scripts/common/tools/amd_smi_utils.py +++ b/src/madengine/scripts/common/tools/amd_smi_utils.py @@ -152,7 +152,7 @@ def check_if_secondary_die(self, device: int) -> bool: avg_power = power_info.get('average_socket_power', -1) if current_power == 0 and avg_power == 0: return True - except: + except 
Exception: # If we can't get power info, might be secondary die return True diff --git a/src/madengine/scripts/common/tools/rocm_smi_utils.py b/src/madengine/scripts/common/tools/rocm_smi_utils.py index 92dff9f2..dd73219b 100644 --- a/src/madengine/scripts/common/tools/rocm_smi_utils.py +++ b/src/madengine/scripts/common/tools/rocm_smi_utils.py @@ -38,7 +38,7 @@ def __init__(self, mode) -> None: raise ImportError('Driver not initialized (amdgpu not found in modules)') exit(0) self.rocm6 = True - except: + except Exception: rocm_smi.initializeRsmi() def get_power(self, device: int) -> str: diff --git a/src/madengine/utils/gpu_validator.py b/src/madengine/utils/gpu_validator.py index 7014268a..8429891e 100644 --- a/src/madengine/utils/gpu_validator.py +++ b/src/madengine/utils/gpu_validator.py @@ -10,7 +10,7 @@ import subprocess import os -from typing import Dict, List, Tuple, Optional +from typing import List, Tuple, Optional from dataclasses import dataclass from enum import Enum diff --git a/src/madengine/utils/log_formatting.py b/src/madengine/utils/log_formatting.py index d7b6c5f5..14a0eed5 100644 --- a/src/madengine/utils/log_formatting.py +++ b/src/madengine/utils/log_formatting.py @@ -9,10 +9,8 @@ """ import pandas as pd -import typing from rich.table import Table from rich.console import Console as RichConsole -from rich.text import Text def format_dataframe_for_log( @@ -209,33 +207,3 @@ def print_dataframe_beautiful( # Fallback to simple but nice formatting formatted_output = format_dataframe_for_log(df, title) print(formatted_output) - - -def highlight_log_section(title: str, content: str, style: str = "info") -> str: - """ - Create a highlighted log section with borders and styling. 
- - Args: - title: Section title - content: Section content - style: Style type ('info', 'success', 'warning', 'error') - - Returns: - str: Formatted log section - """ - styles = { - "info": {"emoji": "ℹ️", "border": "-"}, - "success": {"emoji": "✅", "border": "="}, - "warning": {"emoji": "⚠️", "border": "!"}, - "error": {"emoji": "❌", "border": "#"}, - } - - style_config = styles.get(style, styles["info"]) - emoji = style_config["emoji"] - border_char = style_config["border"] - - border = border_char * 80 - header = f"\n{border}\n{emoji} {title.upper()}\n{border}" - footer = f"{border}\n" - - return f"{header}\n{content}\n{footer}" diff --git a/src/madengine/utils/ops.py b/src/madengine/utils/ops.py index 0b8ab077..cd717fec 100644 --- a/src/madengine/utils/ops.py +++ b/src/madengine/utils/ops.py @@ -5,15 +5,12 @@ functions: PythonicTee: Class to both write and display stream, in "live" mode - find_and_replace_pattern: Find and replace a substring in a dictionary - substring_found: Check if a substring is found in the dictionary file_print: Write and flush file Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ # built-in modules import typing -import re import sys @@ -53,46 +50,6 @@ def flush(self) -> None: self.stdio.flush() -def find_and_replace_pattern( - dictionary: typing.Dict, substring: str, replacement: str -) -> typing.Dict: - """Find and replace a substring in a dictionary. - - Args: - dictionary: The dictionary. - substring: The substring to find. - replacement: The replacement string. - - Returns: - The updated dictionary. - """ - updated_dict = {} - # iterate over the dictionary, replace the substring with the replacement string. 
- for key, value in dictionary.items(): - updated_key = str(key).replace(substring, replacement) - updated_value = str(value).replace(substring, replacement) - updated_dict[updated_key] = updated_value - - return updated_dict - - -def substring_found(dictionary: typing.Dict, substring: str) -> bool: - """Check if a substring is found in the dictionary. - - Args: - dictionary: The dictionary. - substring: The substring to find. - - Returns: - True if the substring is found, False otherwise. - """ - # iterate over the dictionary, check if the substring is found in the key or value. - for key, value in dictionary.items(): - if substring in str(key) or substring in str(value): - return True - return False - - def file_print(write_str: str, filename: str, mode: str = "a") -> None: """Write and flush file. diff --git a/src/madengine/utils/session_tracker.py b/src/madengine/utils/session_tracker.py index d6163d74..4449e496 100644 --- a/src/madengine/utils/session_tracker.py +++ b/src/madengine/utils/session_tracker.py @@ -62,15 +62,6 @@ def start_session(self) -> int: return self.session_start_row - def get_session_start(self) -> Optional[int]: - """ - Get the session start row. - - Returns: - Session start row number, or None if session not started - """ - return self.session_start_row - def get_session_row_count(self) -> int: """ Get the number of rows added during this session. @@ -103,23 +94,6 @@ def _save_marker(self, start_row: int): with open(self.marker_file, 'w') as f: f.write(str(start_row)) - def load_marker(self) -> Optional[int]: - """ - Load session start marker from file. - - Uses the marker file path from this instance's perf_csv_path. - - Returns: - Session start row, or None if file doesn't exist - """ - if self.marker_file.exists(): - try: - with open(self.marker_file, 'r') as f: - return int(f.read().strip()) - except (ValueError, IOError): - return None - return None - def cleanup_marker(self): """ Remove session marker file for this instance. 
From 70c5cb289db9900d9b6203dceb11cadd6d393f83 Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Sun, 19 Apr 2026 20:28:15 -0500 Subject: [PATCH 3/9] fix(review): resolve GitHub Copilot review comments --- src/madengine/core/auth.py | 23 ++++++++++++++------- src/madengine/deployment/base.py | 3 ++- src/madengine/deployment/factory.py | 3 ++- src/madengine/deployment/slurm.py | 4 ++-- src/madengine/execution/container_runner.py | 5 +++-- tests/unit/test_auth.py | 5 +---- 6 files changed, 25 insertions(+), 18 deletions(-) diff --git a/src/madengine/core/auth.py b/src/madengine/core/auth.py index 8efceeb1..f8caf116 100644 --- a/src/madengine/core/auth.py +++ b/src/madengine/core/auth.py @@ -96,9 +96,10 @@ def login_to_registry( credentials: Credentials dictionary keyed by registry name. console: A ``Console`` instance for shell execution. rich_console: A Rich ``Console`` instance for formatted output. - raise_on_failure: If ``True`` (default), re-raise on login failure. - Set to ``False`` when the caller can fall back to pulling - public images. + raise_on_failure: If ``True`` (default), raise ``RuntimeError`` on any + failure (missing key, invalid format, or docker login error). + Set to ``False`` to log and return instead, allowing the caller + to fall back to pulling public images.
""" if not credentials: rich_console.print( @@ -137,7 +138,9 @@ def login_to_registry( "}" ) rich_console.print(f"[red]{error_msg}[/red]") - raise RuntimeError(error_msg) + if raise_on_failure: + raise RuntimeError(error_msg) + return creds = credentials[registry_key] @@ -147,15 +150,19 @@ def login_to_registry( f"\nCredentials must contain 'username' and 'password' fields" ) rich_console.print(f"[red]{error_msg}[/red]") - raise RuntimeError(error_msg) + if raise_on_failure: + raise RuntimeError(error_msg) + return username = str(creds["username"]) password = str(creds["password"]) - login_command = f"echo {shlex.quote(password)} | docker login" + quoted_password = shlex.quote(password) + quoted_username = shlex.quote(username) + login_command = f"printf %s {quoted_password} | docker login" if registry and registry.lower() not in ["docker.io", "dockerhub"]: - login_command += f" {registry}" - login_command += f" --username {username} --password-stdin" + login_command += f" {shlex.quote(str(registry))}" + login_command += f" --username {quoted_username} --password-stdin" try: console.sh(login_command, secret=True) diff --git a/src/madengine/deployment/base.py b/src/madengine/deployment/base.py index b5df7ea8..00306505 100644 --- a/src/madengine/deployment/base.py +++ b/src/madengine/deployment/base.py @@ -43,6 +43,7 @@ class DeploymentStatus(Enum): SUCCESS = "success" FAILED = "failed" CANCELLED = "cancelled" + UNKNOWN = "unknown" @dataclass @@ -224,7 +225,7 @@ def _monitor_until_complete(self, deployment_id: str) -> DeploymentResult: while True: status = self.monitor(deployment_id) - if status.status in [DeploymentStatus.SUCCESS, DeploymentStatus.FAILED]: + if status.status in [DeploymentStatus.SUCCESS, DeploymentStatus.FAILED, DeploymentStatus.UNKNOWN]: return status # Still running, wait and check again diff --git a/src/madengine/deployment/factory.py b/src/madengine/deployment/factory.py index 1988259e..e2b503d6 100644 --- a/src/madengine/deployment/factory.py 
+++ b/src/madengine/deployment/factory.py @@ -91,7 +91,8 @@ def register_default_deployments(): import warnings warnings.warn( "Kubernetes deployment target is unavailable: the 'kubernetes' library is not " - "installed. Install it with: pip install madengine[all]", + "installed. Install it with: pip install madengine[kubernetes] " + "(or pip install madengine[all]).", ImportWarning, stacklevel=2, ) diff --git a/src/madengine/deployment/slurm.py b/src/madengine/deployment/slurm.py index 5550a4ec..82afa0aa 100644 --- a/src/madengine/deployment/slurm.py +++ b/src/madengine/deployment/slurm.py @@ -1025,7 +1025,7 @@ def _check_job_completion(self, job_id: str) -> DeploymentResult: f"(exit code {result.returncode}). Status cannot be verified.[/yellow]" ) return DeploymentResult( - status=DeploymentStatus.FAILED, + status=DeploymentStatus.UNKNOWN, deployment_id=job_id, message=f"Job {job_id} status unknown: sacct exited with code {result.returncode}", ) @@ -1036,7 +1036,7 @@ def _check_job_completion(self, job_id: str) -> DeploymentResult: f"Status cannot be verified.[/yellow]" ) return DeploymentResult( - status=DeploymentStatus.FAILED, + status=DeploymentStatus.UNKNOWN, deployment_id=job_id, message=f"Job {job_id} status unknown: {e}", ) diff --git a/src/madengine/execution/container_runner.py b/src/madengine/execution/container_runner.py index 5195dfdc..2b2c3d46 100644 --- a/src/madengine/execution/container_runner.py +++ b/src/madengine/execution/container_runner.py @@ -542,12 +542,13 @@ def get_env_arg(self, run_env: typing.Dict) -> str: # Add custom environment variables if run_env: for env_arg in run_env: - env_args += f"--env {env_arg}='{str(run_env[env_arg])}' " + env_args += f"--env {env_arg}={shlex.quote(str(run_env[env_arg]))} " # Add context environment variables if "docker_env_vars" in self.context.ctx: for env_arg in self.context.ctx["docker_env_vars"].keys(): - env_args += f"--env {env_arg}='{str(self.context.ctx['docker_env_vars'][env_arg])}' " + value = 
self.context.ctx["docker_env_vars"][env_arg] + env_args += f"--env {env_arg}={shlex.quote(str(value))} " return env_args diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index 767ee761..bd34e14d 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -1,10 +1,7 @@ """Unit tests for madengine.core.auth module.""" -import json import os -from unittest.mock import mock_open, patch, MagicMock - -import pytest +from unittest.mock import mock_open, patch from madengine.core.auth import load_credentials From 85774a3008d3241e9ff7311337517726f2186d8a Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Mon, 20 Apr 2026 13:12:25 -0500 Subject: [PATCH 4/9] fix(security): resolve GitHub Copilot review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - auth: pass registry password via MAD_REGISTRY_PASSWORD env var instead of printf argv to prevent process-list exposure; wrap all build-arg keys and values with str() before shlex.quote() to handle non-string config values - auth: gate all RuntimeError raises (missing key, invalid format, login failure) on raise_on_failure so ContainerRunner can fall through to public image pulls - auth: add TestLoginToRegistry unit tests covering all failure/success paths and raise_on_failure behaviour - docker: replace grep-based container existence check with docker ps --filter name=^/$ to avoid regex metachar false positives and substring matches - factory: change ImportWarning → UserWarning so the missing-kubernetes message is visible by default - factory: recommend madengine[kubernetes] (plus madengine[all]) in the install hint Co-Authored-By: Claude Sonnet 4 --- src/madengine/core/auth.py | 9 ++- src/madengine/core/docker.py | 8 +- src/madengine/deployment/factory.py | 2 +- src/madengine/execution/docker_builder.py | 6 +- tests/unit/test_auth.py | 93 ++++++++++++++++++++++- 5 files changed, 106 insertions(+), 12 deletions(-) diff --git a/src/madengine/core/auth.py 
b/src/madengine/core/auth.py index f8caf116..5b7acbc3 100644 --- a/src/madengine/core/auth.py +++ b/src/madengine/core/auth.py @@ -157,15 +157,18 @@ def login_to_registry( username = str(creds["username"]) password = str(creds["password"]) - quoted_password = shlex.quote(password) + # Pass the password via an environment variable so it never appears in + # the process argument list (visible via /proc or ps to other users). quoted_username = shlex.quote(username) - login_command = f"printf %s {quoted_password} | docker login" + login_command = "printf %s \"$MAD_REGISTRY_PASSWORD\" | docker login" if registry and registry.lower() not in ["docker.io", "dockerhub"]: login_command += f" {shlex.quote(str(registry))}" login_command += f" --username {quoted_username} --password-stdin" + login_env = {**os.environ, "MAD_REGISTRY_PASSWORD": password} + try: - console.sh(login_command, secret=True) + console.sh(login_command, secret=True, env=login_env) rich_console.print( f"[green]Successfully logged in to registry: " f"{registry or 'DockerHub'}[/green]" diff --git a/src/madengine/core/docker.py b/src/madengine/core/docker.py index f9b5c6c9..13cf1cb0 100644 --- a/src/madengine/core/docker.py +++ b/src/madengine/core/docker.py @@ -57,13 +57,15 @@ def __init__( self.userid = self.console.sh("id -u") self.groupid = self.console.sh("id -g") - # check if container name exists + # check if container name exists — use an exact-match filter so names + # containing regex metacharacters (e.g. ".", "[") cannot produce false + # positives, and substring matches are avoided entirely. container_name_quoted = shlex.quote(container_name) container_name_exists = self.console.sh( - "docker container ps -a | grep " + container_name_quoted + " | wc -l" + f"docker container ps -aq --filter name=^/{container_name_quoted}$" ) # if container name exists, clean it up automatically - if container_name_exists != "0": + if container_name_exists: print( f"⚠️ Container '{container_name}' already exists. 
Cleaning up..." ) diff --git a/src/madengine/deployment/factory.py b/src/madengine/deployment/factory.py index e2b503d6..dea54557 100644 --- a/src/madengine/deployment/factory.py +++ b/src/madengine/deployment/factory.py @@ -93,7 +93,7 @@ def register_default_deployments(): "Kubernetes deployment target is unavailable: the 'kubernetes' library is not " "installed. Install it with: pip install madengine[kubernetes] " "(or pip install madengine[all]).", - ImportWarning, + UserWarning, stacklevel=2, ) diff --git a/src/madengine/execution/docker_builder.py b/src/madengine/execution/docker_builder.py index b900a359..56f33d6d 100644 --- a/src/madengine/execution/docker_builder.py +++ b/src/madengine/execution/docker_builder.py @@ -85,15 +85,15 @@ def get_build_arg(self, run_build_arg: typing.Optional[typing.Dict] = None) -> s for build_arg in self.context.ctx["docker_build_arg"].keys(): build_args += ( "--build-arg " - + build_arg + + shlex.quote(str(build_arg)) + "=" - + shlex.quote(self.context.ctx["docker_build_arg"][build_arg]) + + shlex.quote(str(self.context.ctx["docker_build_arg"][build_arg])) + " " ) if run_build_arg: for key, value in run_build_arg.items(): - build_args += "--build-arg " + key + "=" + shlex.quote(value) + " " + build_args += "--build-arg " + shlex.quote(str(key)) + "=" + shlex.quote(str(value)) + " " return build_args diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index bd34e14d..cf6100b9 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -1,9 +1,9 @@ """Unit tests for madengine.core.auth module.""" import os -from unittest.mock import mock_open, patch +from unittest.mock import MagicMock, mock_open, patch -from madengine.core.auth import load_credentials +from madengine.core.auth import load_credentials, login_to_registry class TestLoadCredentials: @@ -112,3 +112,92 @@ def test_load_credentials_non_dockerhub_registry(self, mock_file, mock_exists): assert result is not None assert "custom_registry" in result 
assert result["custom_registry"]["token"] == "abc123" + + +class TestLoginToRegistry: + """Tests for login_to_registry().""" + + def _mocks(self): + console = MagicMock() + rich_console = MagicMock() + return console, rich_console + + def test_no_credentials_returns_early(self): + """Passing None credentials logs a warning and returns without error.""" + console, rich_console = self._mocks() + login_to_registry("docker.io", None, console, rich_console) + console.sh.assert_not_called() + + def test_missing_registry_key_raises_when_raise_on_failure(self): + """RuntimeError raised when registry key absent and raise_on_failure=True.""" + console, rich_console = self._mocks() + credentials = {"other_registry": {"username": "u", "password": "p"}} + try: + login_to_registry("myregistry.io", credentials, console, rich_console, raise_on_failure=True) + assert False, "Expected RuntimeError" + except RuntimeError as e: + assert "myregistry.io" in str(e) + console.sh.assert_not_called() + + def test_missing_registry_key_returns_when_not_raise_on_failure(self): + """Returns silently when registry key absent and raise_on_failure=False.""" + console, rich_console = self._mocks() + credentials = {"other_registry": {"username": "u", "password": "p"}} + login_to_registry("myregistry.io", credentials, console, rich_console, raise_on_failure=False) + console.sh.assert_not_called() + + def test_invalid_credentials_format_raises(self): + """RuntimeError raised when username/password fields missing.""" + console, rich_console = self._mocks() + credentials = {"dockerhub": {"token": "abc"}} + try: + login_to_registry("docker.io", credentials, console, rich_console, raise_on_failure=True) + assert False, "Expected RuntimeError" + except RuntimeError as e: + assert "username" in str(e) or "password" in str(e) + console.sh.assert_not_called() + + def test_invalid_credentials_format_returns_when_not_raise_on_failure(self): + """Returns silently when credentials format invalid and 
raise_on_failure=False.""" + console, rich_console = self._mocks() + credentials = {"dockerhub": {"token": "abc"}} + login_to_registry("docker.io", credentials, console, rich_console, raise_on_failure=False) + console.sh.assert_not_called() + + def test_docker_io_normalised_to_dockerhub(self): + """docker.io registry is looked up under the 'dockerhub' key.""" + console, rich_console = self._mocks() + credentials = {"dockerhub": {"username": "user", "password": "pass"}} + login_to_registry("docker.io", credentials, console, rich_console) + console.sh.assert_called_once() + cmd = console.sh.call_args[0][0] + # docker.io should not appear in the login command (uses default DockerHub endpoint) + assert "docker.io" not in cmd + + def test_custom_registry_included_in_command(self): + """Non-DockerHub registry URL is included in the login command.""" + console, rich_console = self._mocks() + credentials = {"myregistry.io": {"username": "user", "password": "pass"}} + login_to_registry("myregistry.io", credentials, console, rich_console) + console.sh.assert_called_once() + cmd = console.sh.call_args[0][0] + assert "myregistry.io" in cmd + + def test_login_failure_raises_when_raise_on_failure(self): + """docker login error is re-raised when raise_on_failure=True.""" + console, rich_console = self._mocks() + console.sh.side_effect = RuntimeError("auth failed") + credentials = {"dockerhub": {"username": "user", "password": "pass"}} + try: + login_to_registry(None, credentials, console, rich_console, raise_on_failure=True) + assert False, "Expected RuntimeError" + except RuntimeError: + pass + + def test_login_failure_suppressed_when_not_raise_on_failure(self): + """docker login error is suppressed when raise_on_failure=False.""" + console, rich_console = self._mocks() + console.sh.side_effect = RuntimeError("auth failed") + credentials = {"dockerhub": {"username": "user", "password": "pass"}} + login_to_registry(None, credentials, console, rich_console, raise_on_failure=False) 
+ # Should not propagate the exception From c5c03f92744ad83f7519fcbda5d90f9fc84a274c Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Tue, 21 Apr 2026 09:47:19 -0500 Subject: [PATCH 5/9] fix(timeout): handle None/0 sentinel so --timeout 0 disables timeout MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Timeout.__enter__ called signal.alarm(None) when --timeout 0 was passed, because the CLI correctly maps 0 → None but Timeout had no guard for it. Add early-return in __enter__/__exit__ when seconds is falsy, and improve small-value formatting in the perf table display. Tests: add TestTimeout (None/0/positive) and a resolve_run_timeout None passthrough case to prevent regression. Co-Authored-By: Claude Sonnet 4 --- src/madengine/cli/utils.py | 4 +++- src/madengine/core/timeout.py | 4 ++++ tests/unit/test_execution.py | 27 +++++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/madengine/cli/utils.py b/src/madengine/cli/utils.py index 16610278..75e026b7 100644 --- a/src/madengine/cli/utils.py +++ b/src/madengine/cli/utils.py @@ -425,8 +425,10 @@ def format_performance(perf): return f"{val:,.0f}" elif val >= 10: return f"{val:.1f}" + elif val >= 0.01: + return f"{val:.4f}" else: - return f"{val:.2f}" + return f"{val:.4g}" except (ValueError, TypeError): return str(perf) diff --git a/src/madengine/core/timeout.py b/src/madengine/core/timeout.py index 7fbdcb2e..51aac2c9 100644 --- a/src/madengine/core/timeout.py +++ b/src/madengine/core/timeout.py @@ -41,9 +41,13 @@ def handle_timeout(self, signum, frame) -> None: def __enter__(self) -> None: """Enter the context manager.""" + if not self.seconds: + return signal.signal(signal.SIGALRM, self.handle_timeout) signal.alarm(self.seconds) def __exit__(self, type, value, traceback) -> None: """Exit the context manager.""" + if not self.seconds: + return signal.alarm(0) diff --git a/tests/unit/test_execution.py b/tests/unit/test_execution.py index 
da8adb23..dc18121e 100644 --- a/tests/unit/test_execution.py +++ b/tests/unit/test_execution.py @@ -2,6 +2,7 @@ import pytest +from madengine.core.timeout import Timeout from madengine.execution.container_runner_helpers import ( _docker_image_ref_for_log_naming, make_run_log_file_path, @@ -16,6 +17,26 @@ ) +# ---- Timeout ---- + +class TestTimeout: + """Timeout context manager: None/0 must not arm signal.alarm.""" + + def test_none_seconds_does_not_raise(self): + with Timeout(None): + pass # must not crash + + def test_zero_seconds_does_not_raise(self): + with Timeout(0): + pass # must not crash + + def test_positive_seconds_raises_on_expiry(self): + with pytest.raises(TimeoutError): + with Timeout(1): + import time + time.sleep(2) + + # ---- container_runner_helpers ---- class TestResolveRunTimeout: @@ -41,6 +62,12 @@ def test_custom_default_cli(self): assert resolve_run_timeout({"timeout": 100}, 5000, default_cli_timeout=5000) == 100 assert resolve_run_timeout({"timeout": 100}, 7200, default_cli_timeout=5000) == 7200 + def test_no_timeout_sentinel_none_passthrough(self): + # --timeout 0 is converted to None by the CLI; resolve_run_timeout must + # pass None through unchanged (model timeout must NOT override "no timeout"). 
+ assert resolve_run_timeout({"timeout": 3600}, None) is None + assert resolve_run_timeout({}, None) is None + class TestDockerImageRefForLogNaming: """_docker_image_ref_for_log_naming: CI tag extraction vs stable non-ci refs.""" From e23ea689910eeb1ba0b6956afff75cf8eebf61a3 Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Tue, 21 Apr 2026 17:54:14 -0500 Subject: [PATCH 6/9] test(v2-review): clean up test suite quality issues MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove dead imports across 14 test files (sys, csv, json, MagicMock, mock_open, call, tempfile, ConfigurationError, DiscoveryError, pytest where unused, generate_additional_context_for_machine) - Restore subprocess/Mock imports that were incorrectly removed - Replace try/assert False/except antipattern with pytest.raises() in test_auth.py (3 occurrences), adding match= to validate error messages - Narrow 5 bare except: clauses to except Exception: in utils.py and test_gpu_management.py - Delete pass-only dead test test_dockerfile_executed_if_contexts_keys_are_not_common - Remove duplicate tests: test_validate_additional_context_invalid_json (kept in test_validators.py) and test_filter_images_by_gpu_architecture (kept in test_orchestrator_workflows.py) - Rename misnamed @pytest.fixture test_dir → tmp_dir in test_reporting_superset.py to prevent pytest collecting fixtures as test functions - Reclassify test_profiling_tools_config.py: unit/ → integration/ (reads real disk files); test_errors.py: integration/ → unit/ (pure mocks) - Add test_sh_live_output to test_console_integration.py for live_output=True path Co-Authored-By: Claude Sonnet 4 --- tests/e2e/test_build_workflows.py | 1 - tests/e2e/test_data_workflows.py | 1 - tests/e2e/test_execution_features.py | 1 - tests/e2e/test_profiling_workflows.py | 3 - tests/e2e/test_run_workflows.py | 8 -- tests/e2e/test_scripting_workflows.py | 4 - tests/fixtures/utils.py | 8 +- 
tests/integration/test_console_integration.py | 12 +- tests/integration/test_container_execution.py | 3 - tests/integration/test_docker_integration.py | 2 - tests/integration/test_gpu_management.py | 5 +- .../integration/test_platform_integration.py | 23 +--- .../test_profiling_tools_config.py | 2 +- tests/unit/test_auth.py | 16 +-- tests/unit/test_cli.py | 14 +-- tests/unit/test_constants.py | 2 - tests/unit/test_container_runner.py | 2 - tests/unit/test_context_logic.py | 2 +- tests/unit/test_database_mongodb.py | 3 +- tests/unit/test_error_handling.py | 13 +-- tests/{integration => unit}/test_errors.py | 8 +- tests/unit/test_reporting.py | 3 - tests/unit/test_reporting_superset.py | 110 +++++++++--------- 23 files changed, 79 insertions(+), 167 deletions(-) rename tests/{unit => integration}/test_profiling_tools_config.py (96%) rename tests/{integration => unit}/test_errors.py (97%) diff --git a/tests/e2e/test_build_workflows.py b/tests/e2e/test_build_workflows.py index aec87459..c464d776 100644 --- a/tests/e2e/test_build_workflows.py +++ b/tests/e2e/test_build_workflows.py @@ -12,7 +12,6 @@ # built-in modules import os -import sys import csv import json import pandas as pd diff --git a/tests/e2e/test_data_workflows.py b/tests/e2e/test_data_workflows.py index 42e564e6..b83232d5 100644 --- a/tests/e2e/test_data_workflows.py +++ b/tests/e2e/test_data_workflows.py @@ -5,7 +5,6 @@ # built-in modules import os -import sys import csv import re import json diff --git a/tests/e2e/test_execution_features.py b/tests/e2e/test_execution_features.py index fae4cb70..4d7fd601 100644 --- a/tests/e2e/test_execution_features.py +++ b/tests/e2e/test_execution_features.py @@ -7,7 +7,6 @@ import json import os import re -import csv import time from tests.fixtures.utils import BASE_DIR, MODEL_DIR diff --git a/tests/e2e/test_profiling_workflows.py b/tests/e2e/test_profiling_workflows.py index ef31cb3f..74925ae2 100644 --- a/tests/e2e/test_profiling_workflows.py +++ 
b/tests/e2e/test_profiling_workflows.py @@ -8,9 +8,6 @@ # built-in modules import os import re -import sys -import csv -import json # third-party modules import pytest diff --git a/tests/e2e/test_run_workflows.py b/tests/e2e/test_run_workflows.py index bf770891..4cf199d5 100644 --- a/tests/e2e/test_run_workflows.py +++ b/tests/e2e/test_run_workflows.py @@ -5,7 +5,6 @@ # built-in modules import os -import sys import csv # third-party modules @@ -104,13 +103,6 @@ def test_all_dockerfiles_matching_context_executed( + " ".join(foundDockerfiles) ) - def test_dockerfile_executed_if_contexts_keys_are_not_common(self): - """ - Dockerfile is executed even if all context keys are not common but common keys match - """ - # already tested in test_dockerfile_picked_on_detected_context_0 - pass - @pytest.mark.parametrize( "clean_test_temp_files", [DEFAULT_CLEAN_FILES], indirect=True ) diff --git a/tests/e2e/test_scripting_workflows.py b/tests/e2e/test_scripting_workflows.py index 0f44d7d3..3c163f1a 100644 --- a/tests/e2e/test_scripting_workflows.py +++ b/tests/e2e/test_scripting_workflows.py @@ -6,12 +6,9 @@ # built-in modules import os import re -import csv -import time # 3rd party modules import pytest -import json # project modules from tests.fixtures.utils import BASE_DIR, MODEL_DIR @@ -19,7 +16,6 @@ from tests.fixtures.utils import clean_test_temp_files from tests.fixtures.utils import DEFAULT_CLEAN_FILES from tests.fixtures.utils import is_nvidia -from tests.fixtures.utils import generate_additional_context_for_machine class TestPrePostScriptsFunctionality: diff --git a/tests/fixtures/utils.py b/tests/fixtures/utils.py index 21198b90..6590a2d7 100644 --- a/tests/fixtures/utils.py +++ b/tests/fixtures/utils.py @@ -101,7 +101,7 @@ def clean_test_temp_files(request): capture_output=True, timeout=30 ) - except: + except Exception: pass # Ignore cleanup errors before test yield @@ -123,7 +123,7 @@ def clean_test_temp_files(request): capture_output=True, timeout=30 ) - except: + 
except Exception: pass # Ignore cleanup errors after test @@ -235,7 +235,7 @@ def get_gpu_nodeid_map() -> dict: node_id = str(gpu_info["node_id"]) gpu_id = gpu_info["gpu"] gpu_map[node_id] = gpu_id - except: + except Exception: # Fall back to older rocm-smi tools try: rocm_version_str = console.sh("hipconfig --version") @@ -265,7 +265,7 @@ def get_gpu_nodeid_map() -> dict: gpu_id = int(line.split()[0]) node_id = line.split()[1] gpu_map[node_id] = gpu_id - except: + except Exception: # If all else fails, return empty map pass diff --git a/tests/integration/test_console_integration.py b/tests/integration/test_console_integration.py index e6a700a0..3d1951fe 100644 --- a/tests/integration/test_console_integration.py +++ b/tests/integration/test_console_integration.py @@ -5,14 +5,6 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ -# built-in modules -import subprocess -import typing - -# third-party modules -import pytest -import typing_extensions - # project modules from madengine.core import console @@ -58,3 +50,7 @@ def test_sh_env(self): def test_sh_verbose(self): obj = console.Console(shellVerbose=False) assert obj.sh("echo MAD Engine") == "MAD Engine" + + def test_sh_live_output(self): + obj = console.Console(live_output=True) + assert obj.sh("echo MAD Engine") == "MAD Engine" diff --git a/tests/integration/test_container_execution.py b/tests/integration/test_container_execution.py index c11c2755..b0d11b3b 100644 --- a/tests/integration/test_container_execution.py +++ b/tests/integration/test_container_execution.py @@ -11,7 +11,6 @@ # built-in modules import os import json -import tempfile import unittest.mock from unittest.mock import patch, MagicMock, mock_open @@ -22,8 +21,6 @@ from madengine.execution.container_runner import ContainerRunner from madengine.core.context import Context from madengine.core.console import Console -from madengine.core.dataprovider import Data -from tests.fixtures.utils import BASE_DIR, MODEL_DIR class 
TestContainerRunner: diff --git a/tests/integration/test_docker_integration.py b/tests/integration/test_docker_integration.py index a7421d6b..66455536 100644 --- a/tests/integration/test_docker_integration.py +++ b/tests/integration/test_docker_integration.py @@ -10,7 +10,6 @@ import json import shlex import tempfile -import unittest.mock from unittest.mock import patch, MagicMock, mock_open # third-party modules @@ -20,7 +19,6 @@ from madengine.execution.docker_builder import DockerBuilder from madengine.core.context import Context from madengine.core.console import Console -from tests.fixtures.utils import BASE_DIR, MODEL_DIR class TestDockerBuilder: diff --git a/tests/integration/test_gpu_management.py b/tests/integration/test_gpu_management.py index ef18a810..6d60cdd7 100644 --- a/tests/integration/test_gpu_management.py +++ b/tests/integration/test_gpu_management.py @@ -13,8 +13,7 @@ import json import stat import pytest -import unittest.mock -from unittest.mock import Mock, MagicMock, patch, call, mock_open +from unittest.mock import Mock, patch from madengine.utils.gpu_tool_manager import BaseGPUToolManager from madengine.utils.rocm_tool_manager import ROCmToolManager, ROCM_VERSION_THRESHOLD @@ -37,7 +36,7 @@ def is_amd_gpu(): import subprocess result = subprocess.run(['rocm-smi'], capture_output=True, timeout=5) return result.returncode == 0 - except: + except Exception: return False diff --git a/tests/integration/test_platform_integration.py b/tests/integration/test_platform_integration.py index 4f24d0c7..82519a89 100644 --- a/tests/integration/test_platform_integration.py +++ b/tests/integration/test_platform_integration.py @@ -9,7 +9,6 @@ import json import os -import tempfile from pathlib import Path from unittest.mock import MagicMock, patch, mock_open import pytest @@ -23,7 +22,7 @@ ) from madengine.orchestration.build_orchestrator import BuildOrchestrator from madengine.orchestration.run_orchestrator import RunOrchestrator -from madengine.core.errors 
import BuildError, ConfigurationError, DiscoveryError +from madengine.core.errors import BuildError # ============================================================================ @@ -662,26 +661,6 @@ def test_is_compilation_arch_compatible(self): assert not is_compilation_arch_compatible("gfx908", "gfx942") assert is_compilation_arch_compatible("foo", "foo") - def test_filter_images_by_gpu_architecture(self): - mock_args = MagicMock() - mock_args.additional_context = '{"gpu_vendor": "AMD", "guest_os": "UBUNTU"}' - mock_args.additional_context_file = None - mock_args.tags = [] - mock_args.live_output = True - mock_args.data_config_file_name = "data.json" - mock_args.force_mirror_local = None - run_orch = RunOrchestrator(mock_args) - built = {"img1": {"gpu_architecture": "gfx908", "gpu_vendor": "AMD"}, "img2": {"gpu_architecture": "gfx90a", "gpu_vendor": "AMD"}} - filtered = run_orch._filter_images_by_gpu_architecture(built, "gfx908") - assert "img1" in filtered and "img2" not in filtered - built_legacy = {"img1": {"gpu_architecture": "gfx908"}, "img2": {"gpu_architecture": "gfx90a", "gpu_vendor": "AMD"}} - filtered = run_orch._filter_images_by_gpu_architecture(built_legacy, "gfx908") - assert "img1" in filtered - built_nomatch = {"img1": {"gpu_architecture": "gfx90a", "gpu_vendor": "AMD"}, "img2": {"gpu_architecture": "gfx942", "gpu_vendor": "AMD"}} - assert len(run_orch._filter_images_by_gpu_architecture(built_nomatch, "gfx908")) == 0 - built_all = {"img1": {"gpu_architecture": "gfx908", "gpu_vendor": "AMD"}, "img2": {"gpu_architecture": "gfx908", "gpu_vendor": "AMD"}} - assert len(run_orch._filter_images_by_gpu_architecture(built_all, "gfx908")) == 2 - if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/unit/test_profiling_tools_config.py b/tests/integration/test_profiling_tools_config.py similarity index 96% rename from tests/unit/test_profiling_tools_config.py rename to tests/integration/test_profiling_tools_config.py index 
579f0402..7dc36f94 100644 --- a/tests/unit/test_profiling_tools_config.py +++ b/tests/integration/test_profiling_tools_config.py @@ -1,4 +1,4 @@ -"""Unit tests for rocm_trace_lite: tools.json entry and apply_tools wiring (no Docker).""" +"""Integration tests for rocm_trace_lite: tools.json entry and apply_tools wiring.""" import json from pathlib import Path diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index cf6100b9..47d4a6c7 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -1,6 +1,7 @@ """Unit tests for madengine.core.auth module.""" import os +import pytest from unittest.mock import MagicMock, mock_open, patch from madengine.core.auth import load_credentials, login_to_registry @@ -132,11 +133,8 @@ def test_missing_registry_key_raises_when_raise_on_failure(self): """RuntimeError raised when registry key absent and raise_on_failure=True.""" console, rich_console = self._mocks() credentials = {"other_registry": {"username": "u", "password": "p"}} - try: + with pytest.raises(RuntimeError, match="myregistry.io"): login_to_registry("myregistry.io", credentials, console, rich_console, raise_on_failure=True) - assert False, "Expected RuntimeError" - except RuntimeError as e: - assert "myregistry.io" in str(e) console.sh.assert_not_called() def test_missing_registry_key_returns_when_not_raise_on_failure(self): @@ -150,11 +148,8 @@ def test_invalid_credentials_format_raises(self): """RuntimeError raised when username/password fields missing.""" console, rich_console = self._mocks() credentials = {"dockerhub": {"token": "abc"}} - try: + with pytest.raises(RuntimeError, match="username|password"): login_to_registry("docker.io", credentials, console, rich_console, raise_on_failure=True) - assert False, "Expected RuntimeError" - except RuntimeError as e: - assert "username" in str(e) or "password" in str(e) console.sh.assert_not_called() def test_invalid_credentials_format_returns_when_not_raise_on_failure(self): @@ -188,11 +183,8 @@ def 
test_login_failure_raises_when_raise_on_failure(self): console, rich_console = self._mocks() console.sh.side_effect = RuntimeError("auth failed") credentials = {"dockerhub": {"username": "user", "password": "pass"}} - try: + with pytest.raises(RuntimeError, match="auth failed"): login_to_registry(None, credentials, console, rich_console, raise_on_failure=True) - assert False, "Expected RuntimeError" - except RuntimeError: - pass def test_login_failure_suppressed_when_not_raise_on_failure(self): """docker login error is suppressed when raise_on_failure=False.""" diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index e3f3e832..164ac7a5 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -16,12 +16,9 @@ import importlib import json import os -import sys from io import StringIO import tempfile -import unittest.mock -from pathlib import Path -from unittest.mock import MagicMock, Mock, patch, mock_open +from unittest.mock import MagicMock, patch # third-party modules import pytest @@ -245,15 +242,6 @@ def test_validate_additional_context_valid_file(self): finally: os.unlink(temp_file) - def test_validate_additional_context_invalid_json(self): - """Test validation with invalid JSON.""" - with patch("madengine.cli.validators.console") as mock_console: - with pytest.raises(typer.Exit) as exc_info: - validate_additional_context("invalid json") - - assert exc_info.value.exit_code == ExitCode.INVALID_ARGS - mock_console.print.assert_called() - def test_validate_additional_context_defaults_fill_partial_fields(self): """Missing gpu_vendor or guest_os is filled from defaults (no error).""" from madengine.core.additional_context_defaults import ( diff --git a/tests/unit/test_constants.py b/tests/unit/test_constants.py index a7230369..248c8ac9 100644 --- a/tests/unit/test_constants.py +++ b/tests/unit/test_constants.py @@ -4,8 +4,6 @@ import os from unittest.mock import patch -import pytest - from madengine.core.constants import ( NAS_NODES, MAD_AWS_S3, 
diff --git a/tests/unit/test_container_runner.py b/tests/unit/test_container_runner.py index 10078f39..86a56111 100644 --- a/tests/unit/test_container_runner.py +++ b/tests/unit/test_container_runner.py @@ -6,8 +6,6 @@ import tempfile from unittest.mock import MagicMock, patch -import pytest - from madengine.execution.container_runner import ContainerRunner diff --git a/tests/unit/test_context_logic.py b/tests/unit/test_context_logic.py index 7f50f491..90bb5560 100644 --- a/tests/unit/test_context_logic.py +++ b/tests/unit/test_context_logic.py @@ -7,7 +7,7 @@ """ import pytest -from unittest.mock import Mock, patch +from unittest.mock import patch from madengine.core.context import Context diff --git a/tests/unit/test_database_mongodb.py b/tests/unit/test_database_mongodb.py index de64c1e4..b9f0aa65 100644 --- a/tests/unit/test_database_mongodb.py +++ b/tests/unit/test_database_mongodb.py @@ -11,9 +11,8 @@ import os import tempfile from pathlib import Path -from unittest.mock import MagicMock, patch, Mock +from unittest.mock import MagicMock, patch import pytest -import pandas as pd from madengine.database.mongodb import ( MongoDBConfig, diff --git a/tests/unit/test_error_handling.py b/tests/unit/test_error_handling.py index 45422c34..dc210a0b 100644 --- a/tests/unit/test_error_handling.py +++ b/tests/unit/test_error_handling.py @@ -6,18 +6,11 @@ context management, Rich console integration, and error propagation. 
""" -import pytest -import json -import io import re -from unittest.mock import Mock, patch, MagicMock -from rich.console import Console -from rich.text import Text -# Add src to path for imports -import sys -import os -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) +import pytest +from unittest.mock import Mock +from rich.console import Console from madengine.core.errors import ( ErrorCategory, diff --git a/tests/integration/test_errors.py b/tests/unit/test_errors.py similarity index 97% rename from tests/integration/test_errors.py rename to tests/unit/test_errors.py index c0a88876..078e2d57 100644 --- a/tests/integration/test_errors.py +++ b/tests/unit/test_errors.py @@ -1,4 +1,4 @@ -"""Integration tests for error handling: CLI integration, workflow, unified system, backward compat. +"""Unit tests for error handling: CLI integration, workflow, unified system, backward compat. Merged from test_cli_error_integration and test_error_system_integration. Deduplicated: single setup_logging/handler test, one context serialization test. 
@@ -6,14 +6,10 @@ import json import os -import sys -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import Mock, patch import pytest -# Ensure src on path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) - from madengine.core.errors import ( ErrorHandler, MADEngineError, diff --git a/tests/unit/test_reporting.py b/tests/unit/test_reporting.py index 88b1eb80..844137e9 100644 --- a/tests/unit/test_reporting.py +++ b/tests/unit/test_reporting.py @@ -3,10 +3,7 @@ import os import tempfile -import pytest - import json - import pandas as pd from madengine.reporting.update_perf_csv import ( diff --git a/tests/unit/test_reporting_superset.py b/tests/unit/test_reporting_superset.py index a107d3dc..622c9b75 100644 --- a/tests/unit/test_reporting_superset.py +++ b/tests/unit/test_reporting_superset.py @@ -28,9 +28,9 @@ class TestConfigParser: """Test cases for ConfigParser functionality.""" - + @pytest.fixture - def test_dir(self): + def tmp_dir(self): """Create temporary directory for tests.""" temp_dir = tempfile.mkdtemp() yield temp_dir @@ -136,7 +136,7 @@ def test_config_parser_match_config_to_result(self, config_file): assert matched['model'] == 'dummy/model-1' assert matched['benchmark'] == 'throughput' - def test_config_parser_json_file(self, test_dir): + def test_config_parser_json_file(self, tmp_dir): """Test loading JSON config file.""" # Create a JSON config file json_config = { @@ -145,10 +145,10 @@ def test_config_parser_json_file(self, test_dir): "epochs": 10 } - json_path = os.path.join(test_dir, "config.json") + json_path = os.path.join(tmp_dir, "config.json") with open(json_path, 'w') as f: json.dump(json_config, f) - + parser = ConfigParser() configs = parser.load_config_file(json_path) @@ -160,9 +160,9 @@ def test_config_parser_json_file(self, test_dir): class TestPerfEntrySuperGeneration: """Test cases for perf_super.json generation (cumulative).""" - + @pytest.fixture - def test_dir(self): + def 
tmp_dir(self): """Create temporary directory for tests.""" temp_dir = tempfile.mkdtemp() yield temp_dir @@ -181,7 +181,7 @@ def fixtures_dir(self): 'dummy' ) - def test_perf_entry_super_json_structure(self, test_dir, fixtures_dir): + def test_perf_entry_super_json_structure(self, tmp_dir, fixtures_dir): """Test that perf_super.json has the correct structure.""" # Create mock data common_info = { @@ -211,22 +211,22 @@ def test_perf_entry_super_json_structure(self, test_dir, fixtures_dir): "build_number": "1", "additional_docker_run_options": "", } - + # Create common_info.json - common_info_path = os.path.join(test_dir, "common_info.json") + common_info_path = os.path.join(tmp_dir, "common_info.json") with open(common_info_path, 'w') as f: json.dump(common_info, f) # Create results CSV - results_csv = os.path.join(test_dir, "perf_dummy_super.csv") + results_csv = os.path.join(tmp_dir, "perf_dummy_super.csv") with open(results_csv, 'w') as f: f.write("model,performance,metric,status\n") f.write("dummy/model-1,1234.56,tokens/s,SUCCESS\n") f.write("dummy/model-2,2345.67,requests/s,SUCCESS\n") f.write("dummy/model-3,345.78,ms,SUCCESS\n") - + # Generate perf_super.json (cumulative) - perf_super_path = os.path.join(test_dir, "perf_super.json") + perf_super_path = os.path.join(tmp_dir, "perf_super.json") update_perf_super_json( perf_super_json=perf_super_path, @@ -275,7 +275,7 @@ def test_perf_entry_super_json_structure(self, test_dir, fixtures_dir): assert 'datatype' in configs assert 'max_tokens' in configs - def test_perf_entry_super_config_matching(self, test_dir, fixtures_dir): + def test_perf_entry_super_config_matching(self, tmp_dir, fixtures_dir): """Test that configs are correctly matched for all results.""" # Create mock data common_info = { @@ -306,20 +306,20 @@ def test_perf_entry_super_config_matching(self, test_dir, fixtures_dir): "additional_docker_run_options": "", } - common_info_path = os.path.join(test_dir, "common_info_super.json") + common_info_path = 
os.path.join(tmp_dir, "common_info_super.json") with open(common_info_path, 'w') as f: json.dump(common_info, f) - + # Create results CSV - results_csv = os.path.join(test_dir, "perf_dummy_super.csv") + results_csv = os.path.join(tmp_dir, "perf_dummy_super.csv") with open(results_csv, 'w') as f: f.write("model,performance,metric,benchmark\n") f.write("dummy/model-1,1234.56,tokens/s,throughput\n") f.write("dummy/model-2,2345.67,requests/s,serving\n") f.write("dummy/model-3,345.78,ms,latency\n") - - perf_super_path = os.path.join(test_dir, "perf_super.json") - + + perf_super_path = os.path.join(tmp_dir, "perf_super.json") + update_perf_super_json( perf_super_json=perf_super_path, multiple_results=results_csv, @@ -327,7 +327,7 @@ def test_perf_entry_super_config_matching(self, test_dir, fixtures_dir): model_name="dummy_perf_super", scripts_base_dir=fixtures_dir ) - + # Load and verify matching with open(perf_super_path, 'r') as f: data = json.load(f) @@ -352,7 +352,7 @@ def test_perf_entry_super_config_matching(self, test_dir, fixtures_dir): assert configs['benchmark'] in ['throughput', 'serving', 'latency'] assert configs['datatype'] in ['float16', 'float32', 'bfloat16'] - def test_perf_entry_super_no_config(self, test_dir, fixtures_dir): + def test_perf_entry_super_no_config(self, tmp_dir, fixtures_dir): """Test handling when no config file is specified.""" # Create mock data without config common_info = { @@ -383,17 +383,17 @@ def test_perf_entry_super_no_config(self, test_dir, fixtures_dir): "additional_docker_run_options": "", } - common_info_path = os.path.join(test_dir, "common_info_super.json") + common_info_path = os.path.join(tmp_dir, "common_info_super.json") with open(common_info_path, 'w') as f: json.dump(common_info, f) - + # Create results CSV - results_csv = os.path.join(test_dir, "perf_dummy_super.csv") + results_csv = os.path.join(tmp_dir, "perf_dummy_super.csv") with open(results_csv, 'w') as f: f.write("model,performance,metric\n") 
f.write("dummy-no-config,1234.56,tokens/s\n") - - perf_super_path = os.path.join(test_dir, "perf_super.json") + + perf_super_path = os.path.join(tmp_dir, "perf_super.json") update_perf_super_json( perf_super_json=perf_super_path, @@ -413,7 +413,7 @@ def test_perf_entry_super_no_config(self, test_dir, fixtures_dir): assert data[0]['configs'] is None, \ "configs should be None when no config file specified" - def test_perf_entry_super_multi_results(self, test_dir, fixtures_dir): + def test_perf_entry_super_multi_results(self, tmp_dir, fixtures_dir): """Test handling of multiple result metrics.""" common_info = { "pipeline": "dummy_test", @@ -443,19 +443,19 @@ def test_perf_entry_super_multi_results(self, test_dir, fixtures_dir): "additional_docker_run_options": "", } - common_info_path = os.path.join(test_dir, "common_info_super.json") + common_info_path = os.path.join(tmp_dir, "common_info_super.json") with open(common_info_path, 'w') as f: json.dump(common_info, f) - + # Create results CSV with extra metrics - results_csv = os.path.join(test_dir, "perf_multi_metrics.csv") + results_csv = os.path.join(tmp_dir, "perf_multi_metrics.csv") with open(results_csv, 'w') as f: f.write("model,performance,metric,throughput,latency_mean_ms,latency_p50_ms,latency_p90_ms,gpu_memory_used_mb\n") f.write("model-1,1234.56,tokens/s,1234.56,8.1,7.9,12.3,12288\n") f.write("model-2,2345.67,requests/s,2345.67,4.3,4.1,6.8,16384\n") - perf_super_path = os.path.join(test_dir, "perf_super.json") - + perf_super_path = os.path.join(tmp_dir, "perf_super.json") + update_perf_super_json( perf_super_json=perf_super_path, multiple_results=results_csv, @@ -490,7 +490,7 @@ def test_perf_entry_super_multi_results(self, test_dir, fixtures_dir): assert multi_results['latency_mean_ms'] == 8.1 assert multi_results['gpu_memory_used_mb'] == 12288 - def test_perf_entry_super_deployment_fields(self, test_dir, fixtures_dir): + def test_perf_entry_super_deployment_fields(self, tmp_dir, fixtures_dir): """Test 
that all deployment-related fields are present.""" common_info = { "pipeline": "dummy_test", @@ -520,18 +520,18 @@ def test_perf_entry_super_deployment_fields(self, test_dir, fixtures_dir): "additional_docker_run_options": "", } - common_info_path = os.path.join(test_dir, "common_info_super.json") + common_info_path = os.path.join(tmp_dir, "common_info_super.json") with open(common_info_path, 'w') as f: json.dump(common_info, f) - + # Create results CSV - results_csv = os.path.join(test_dir, "perf_deployment.csv") + results_csv = os.path.join(tmp_dir, "perf_deployment.csv") with open(results_csv, 'w') as f: f.write("model,performance,metric\n") f.write("multi-node-test,5000.0,tokens/s\n") - perf_super_path = os.path.join(test_dir, "perf_super.json") - + perf_super_path = os.path.join(tmp_dir, "perf_super.json") + update_perf_super_json( perf_super_json=perf_super_path, multiple_results=results_csv, @@ -566,16 +566,16 @@ def test_perf_entry_super_deployment_fields(self, test_dir, fixtures_dir): class TestPerfSuperCSVGeneration: """Test cases for CSV generation from perf_super.json.""" - + @pytest.fixture - def test_dir(self): + def tmp_dir(self): """Create temporary directory for tests.""" temp_dir = tempfile.mkdtemp() yield temp_dir if os.path.exists(temp_dir): shutil.rmtree(temp_dir) - - def test_csv_generation_from_json(self, test_dir): + + def test_csv_generation_from_json(self, tmp_dir): """Test CSV generation from perf_super.json.""" # Create a sample perf_super.json data = [ @@ -599,13 +599,13 @@ def test_csv_generation_from_json(self, test_dir): } ] - json_path = os.path.join(test_dir, "perf_super.json") + json_path = os.path.join(tmp_dir, "perf_super.json") with open(json_path, 'w') as f: json.dump(data, f) - + # Change to test directory original_dir = os.getcwd() - os.chdir(test_dir) + os.chdir(tmp_dir) try: # Generate CSVs @@ -642,7 +642,7 @@ def test_csv_generation_from_json(self, test_dir): finally: os.chdir(original_dir) - def 
test_csv_handles_none_values(self, test_dir): + def test_csv_handles_none_values(self, tmp_dir): """Test that CSV generation handles None values correctly.""" data = [ { @@ -653,13 +653,13 @@ def test_csv_handles_none_values(self, test_dir): "multi_results": None, } ] - - json_path = os.path.join(test_dir, "perf_super.json") + + json_path = os.path.join(tmp_dir, "perf_super.json") with open(json_path, 'w') as f: json.dump(data, f) - + original_dir = os.getcwd() - os.chdir(test_dir) + os.chdir(tmp_dir) try: update_perf_super_csv( @@ -677,7 +677,7 @@ def test_csv_handles_none_values(self, test_dir): finally: os.chdir(original_dir) - def test_csv_multiple_entries_in_entry_file(self, test_dir): + def test_csv_multiple_entries_in_entry_file(self, tmp_dir): """Test that perf_entry_super.csv can contain multiple entries from current run. This tests the fix for the issue where perf_entry.csv and perf_entry.json @@ -735,12 +735,12 @@ def test_csv_multiple_entries_in_entry_file(self, test_dir): } ] - json_path = os.path.join(test_dir, "perf_super.json") + json_path = os.path.join(tmp_dir, "perf_super.json") with open(json_path, 'w') as f: json.dump(data, f) - + original_dir = os.getcwd() - os.chdir(test_dir) + os.chdir(tmp_dir) try: # Generate CSVs with num_entries=4 (simulating 4 entries added in current run) From c9161ea2156d84aaf4f75b23dfe1b2459fd40cee Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Thu, 23 Apr 2026 21:13:21 -0500 Subject: [PATCH 7/9] fix(review): address Copilot review comments - docker.py: apply re.escape() to container name before embedding in Docker --filter regex so metacharacters (e.g. 
".") are treated as literals, preventing false-positive container matches - run.py: include 0 in timeout validation error message; introduce timeout_display so panels show "disabled" instead of "Nones" when --timeout 0 is used - auth.py: fix login_to_registry signature from registry: str to registry: Optional[str], matching actual call sites and implementation Co-Authored-By: Claude Sonnet 4 --- src/madengine/cli/commands/run.py | 8 +++++--- src/madengine/core/auth.py | 6 +++--- src/madengine/core/docker.py | 4 +++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/madengine/cli/commands/run.py b/src/madengine/cli/commands/run.py index 09d90772..282d41df 100644 --- a/src/madengine/cli/commands/run.py +++ b/src/madengine/cli/commands/run.py @@ -174,7 +174,7 @@ def run( # Input validation if timeout < -1: console.print( - "❌ [red]Timeout must be -1 (default) or a positive integer[/red]" + "❌ [red]Timeout must be -1 (default), 0 (no timeout), or a positive integer[/red]" ) raise typer.Exit(ExitCode.INVALID_ARGS) @@ -198,6 +198,8 @@ def run( elif timeout == 0: timeout = None + timeout_display = "disabled" if timeout is None else f"{timeout}s" + try: # Check if we're doing execution-only or full workflow manifest_exists = manifest_file and os.path.exists(manifest_file) @@ -214,7 +216,7 @@ def run( f"🚀 [bold cyan]Running Models (Execution Only)[/bold cyan]\n" f"Manifest: [yellow]{manifest_file}[/yellow]\n" f"Registry: [yellow]{registry or 'Auto-detected'}[/yellow]\n" - f"Timeout: [yellow]{timeout if timeout != -1 else 'Default'}[/yellow]s", + f"Timeout: [yellow]{timeout_display}[/yellow]", title="Execution Configuration", border_style="green", ) @@ -317,7 +319,7 @@ def run( f"🔨🚀 [bold cyan]Complete Workflow (Build + Run)[/bold cyan]\n" f"Tags: [yellow]{', '.join(processed_tags) if processed_tags else 'All models'}[/yellow]\n" f"Registry: [yellow]{registry or 'Local only'}[/yellow]\n" - f"Timeout: [yellow]{timeout if timeout != -1 else 
'Default'}[/yellow]s" + f"Timeout: [yellow]{timeout_display}[/yellow]" f"{skip_note}", title="Workflow Configuration", border_style="magenta", diff --git a/src/madengine/core/auth.py b/src/madengine/core/auth.py index 5b7acbc3..bf26a128 100644 --- a/src/madengine/core/auth.py +++ b/src/madengine/core/auth.py @@ -79,7 +79,7 @@ def load_credentials() -> Optional[Dict]: def login_to_registry( - registry: str, + registry: Optional[str], credentials: Optional[Dict], console, rich_console, @@ -91,8 +91,8 @@ def login_to_registry( and ContainerRunner. Args: - registry: Registry URL (e.g., "localhost:5000", "docker.io", or empty - for DockerHub). + registry: Registry URL (e.g., "localhost:5000", "docker.io"), or + ``None``/empty string to target DockerHub. credentials: Credentials dictionary keyed by registry name. console: A ``Console`` instance for shell execution. rich_console: A Rich ``Console`` instance for formatted output. diff --git a/src/madengine/core/docker.py b/src/madengine/core/docker.py index 13cf1cb0..6cb1d096 100644 --- a/src/madengine/core/docker.py +++ b/src/madengine/core/docker.py @@ -7,6 +7,7 @@ """ # built-in modules import os +import re import shlex import typing @@ -61,8 +62,9 @@ def __init__( # containing regex metacharacters (e.g. ".", "[") cannot produce false # positives, and substring matches are avoided entirely. 
container_name_quoted = shlex.quote(container_name) + container_name_regex = shlex.quote(f"^/{re.escape(container_name)}$") container_name_exists = self.console.sh( - f"docker container ps -aq --filter name=^/{container_name_quoted}$" + f"docker container ps -aq --filter name={container_name_regex}" ) # if container name exists, clean it up automatically if container_name_exists: From 7f0c617a3c996765fd4ab6a738b1ab3d967f1963 Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Thu, 23 Apr 2026 21:42:35 -0500 Subject: [PATCH 8/9] docs(changelog): update [2.0.1] with all changes from coketaste/v2-review-fix Co-Authored-By: Claude Sonnet 4 --- CHANGELOG.md | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fc3e5cb1..ab05ae60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **E2E tests — `test_docker_gpus` pre-script OOM on MI350X**: The `run_rocenv_tool.sh` system-env pre-script was being OOM-killed (exit 137) inside Docker on gfx950 nodes with 6 GPUs bound, failing a test whose purpose is only GPU binding verification. Fixed by correcting the `gen_sys_env_details` condition in `container_runner.py` — the old `or` made the context key a no-op since `generate_sys_env_details` defaults to `True` — and passing `gen_sys_env_details: False` in the test's `additional_context`. +- **`--timeout 0` crashing with `signal.alarm(None)`**: `Timeout.__enter__` called `signal.alarm(None)` when `--timeout 0` was passed because the CLI correctly maps `0 → None` but `Timeout` had no guard for a falsy value. Added early-return in `__enter__`/`__exit__` when `seconds` is `None` or `0`. Also fixed the run command panels printing `Nones` for timeout when `--timeout 0` was used; they now display `disabled`. 
+ +- **Docker container name regex false positives**: The `docker ps --filter name=^/<container_name>$` exact-match filter embedded the container name directly into the regex without escaping, so names containing metacharacters (e.g. `.`, `[`) could match unintended containers. Applied `re.escape()` to the name before building the filter pattern. + +- **`login_to_registry` type annotation**: The `registry` parameter was typed as `str` but the implementation handled `None` and callers (including tests) passed `None` to mean DockerHub. Corrected to `Optional[str]`. + +- **Registry password process-list exposure**: `docker login` was invoked with the password in the argument list (visible via `/proc` or `ps`). Changed to pass it via a `MAD_REGISTRY_PASSWORD` environment variable consumed through `printf %s "$MAD_REGISTRY_PASSWORD" | docker login --password-stdin`. + +- **`login_to_registry` — `raise_on_failure` not fully honoured**: Missing-key and invalid-format errors in `login_to_registry` always raised `RuntimeError` regardless of `raise_on_failure`. All three failure paths (missing registry key, invalid credential format, docker login error) are now gated on `raise_on_failure`, allowing `ContainerRunner` to fall through to public image pulls. + +- **Kubernetes missing-package warning invisible**: `DeploymentFactory` raised `ImportWarning` when the `kubernetes` package was absent, which Python silences by default. Changed to `UserWarning` so the install hint is always visible. + ### Changed - **Model discovery — scope-based tag selection**: Replaced the `strict` mode flag on `DiscoverModels` with a cleaner scope-based rule that applies uniformly to both `madengine run` and `madengine build`: - `--tags all` and `--tags scope/all` continue to select all models globally or within a scope respectively.
- Removed `strict_discovery` parameter from `BuildOrchestrator.execute()` and the corresponding call in `RunOrchestrator._build_phase()` as they are no longer needed. +- **Shared `login_to_registry` utility**: Extracted duplicated Docker registry login logic (~120 lines) from `DockerBuilder` and `ContainerRunner` into `core/auth.py::login_to_registry()`. Both classes now delegate to it. `DockerBuilder` uses `raise_on_failure=True`; `ContainerRunner` uses `raise_on_failure=False` to allow fallback to public images. + +- **Centralised credential loading**: Extracted `_load_credentials` from `BuildOrchestrator` and `RunOrchestrator` into `core/auth.py::load_credentials()`. Environment variables (`MAD_DOCKERHUB_USER`, `MAD_DOCKERHUB_PASSWORD`, `MAD_DOCKERHUB_REPO`) take precedence over `credential.json`. + +- **Dead code removal**: Removed unused functions `find_and_replace_pattern` and `substring_found` (`utils/ops.py`), `highlight_log_section` (`utils/log_formatting.py`), `SessionTracker.get_session_start` and `SessionTracker.load_marker` (`utils/session_tracker.py`), and the unused `_filter_images_by_dockerfile_context` method from `RunOrchestrator`. + +- **`ConfigurationError` instead of `SystemExit` in orchestrator config loading**: `BuildOrchestrator` now raises a structured `ConfigurationError` (with suggestions) instead of calling `sys.exit()` directly when configuration loading fails. + +### Security + +- **Registry password no longer in process argument list**: Docker login commands previously passed the password as a CLI argument visible to other users via `/proc` or `ps`. All registry logins now inject the password through a dedicated `MAD_REGISTRY_PASSWORD` environment variable and use `--password-stdin`. + +- **`build-arg` values shell-quoted**: All Docker `--build-arg` key/value pairs are now wrapped with `str()` before `shlex.quote()` to prevent shell injection from non-string config values. 
+ +### Tests + +- **New `TestTimeout` suite**: Covers `None`, `0`, and positive-second cases for `Timeout.__enter__`/`__exit__`, plus a `resolve_run_timeout` passthrough regression test. + +- **New `TestLoginToRegistry` suite**: Covers all success and failure paths of `login_to_registry`, including `raise_on_failure=True/False` behaviour, missing registry key, invalid credential format, and `docker.io` normalisation. + +- **Test suite cleanup**: Removed dead imports across 14 test files; replaced `try/assert False/except` antipattern with `pytest.raises()` (with `match=`); narrowed 5 bare `except:` clauses to `except Exception:`; deleted a pass-only dead test; removed duplicate tests; reclassified `test_profiling_tools_config.py` from unit to integration (reads real disk files) and `test_errors.py` from integration to unit (pure mocks). + ## [2.0.0] - 2026-04-09 ### Overview From 8032bbb28c53be045c2624a64028390dbc7f7597 Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Fri, 24 Apr 2026 15:01:14 -0500 Subject: [PATCH 9/9] fix(review): address additional Copilot review comments - Validate env var names against POSIX regex before injecting into docker commands (docker.py, container_runner.py) - Allow None/0 timeout in Timeout class to disable signal-based timeout - Retry sacct up to 3 times with 5s delay to handle transient SLURM accounting DB lag after job completion Co-Authored-By: Claude Sonnet 4 --- src/madengine/core/docker.py | 8 +++-- src/madengine/core/timeout.py | 10 +++--- src/madengine/deployment/slurm.py | 35 ++++++++++++++++----- src/madengine/execution/container_runner.py | 6 ++++ 4 files changed, 46 insertions(+), 13 deletions(-) diff --git a/src/madengine/core/docker.py b/src/madengine/core/docker.py index 6cb1d096..115b9448 100644 --- a/src/madengine/core/docker.py +++ b/src/madengine/core/docker.py @@ -97,8 +97,11 @@ def __init__( command += "-v " + cwd + ":/myworkspace/ " # add envVars + _env_key_re = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") if 
envVars is not None: for evar in envVars.keys(): + if not _env_key_re.match(evar): + raise ValueError(f"Invalid environment variable name: {evar!r}") command += "-e " + evar + "=" + shlex.quote(str(envVars[evar])) + " " command += "--workdir /myworkspace/ " @@ -111,9 +114,10 @@ def __init__( command += "cat " self.console.sh(command) - # find container sha + # find container sha — use the same exact-match filter as the existence + # check above to avoid false positives from substring/regex matches. self.docker_sha = self.console.sh( - "docker ps -aqf 'name=" + container_name + "' " + f"docker ps -aqf name={container_name_regex}" ) def sh(self, command: str, timeout: int = 60, secret: bool = False) -> str: diff --git a/src/madengine/core/timeout.py b/src/madengine/core/timeout.py index 51aac2c9..68e83834 100644 --- a/src/madengine/core/timeout.py +++ b/src/madengine/core/timeout.py @@ -7,22 +7,24 @@ """ # built-in modules import signal +from typing import Optional class Timeout: """Class to handle timeouts. Attributes: - seconds (int): The timeout in seconds. + seconds (Optional[int]): The timeout in seconds, or None/0 to disable. """ - def __init__(self, seconds: int = 15) -> None: + def __init__(self, seconds: Optional[int] = 15) -> None: """Constructor of the Timeout class. Args: - seconds (int): The timeout in seconds. + seconds (Optional[int]): The timeout in seconds. None or 0 disables + the timeout. Negative values are treated as no timeout. """ - self.seconds = seconds + self.seconds = seconds if seconds and seconds > 0 else None def handle_timeout(self, signum, frame) -> None: """Handle timeout. 
diff --git a/src/madengine/deployment/slurm.py b/src/madengine/deployment/slurm.py index 82afa0aa..1679b087 100644 --- a/src/madengine/deployment/slurm.py +++ b/src/madengine/deployment/slurm.py @@ -13,6 +13,7 @@ import os import subprocess +import time from pathlib import Path from typing import Any, Dict, List, Optional @@ -980,14 +981,34 @@ def _show_log_summary(self, job_id: str, success: bool = True): self.console.print(f"[dim yellow]Note: Could not locate log files: {e}[/dim yellow]") def _check_job_completion(self, job_id: str) -> DeploymentResult: - """Check completed job status using sacct (locally).""" + """Check completed job status using sacct (locally). + + sacct can transiently return non-zero immediately after a job leaves + the queue because SLURM's accounting database may not yet be updated. + Retry up to _SACCT_RETRIES times with _SACCT_RETRY_DELAY seconds + between attempts before declaring the status UNKNOWN. + """ + _SACCT_RETRIES = 3 + _SACCT_RETRY_DELAY = 5 # seconds + try: - result = subprocess.run( - ["sacct", "-j", job_id, "-n", "-X", "-o", "State"], - capture_output=True, - text=True, - timeout=10, - ) + result = None + for attempt in range(1, _SACCT_RETRIES + 1): + result = subprocess.run( + ["sacct", "-j", job_id, "-n", "-X", "-o", "State"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + break + if attempt < _SACCT_RETRIES: + self.console.print( + f"[dim yellow]sacct returned non-zero for job {job_id} " + f"(attempt {attempt}/{_SACCT_RETRIES}), retrying in " + f"{_SACCT_RETRY_DELAY}s...[/dim yellow]" + ) + time.sleep(_SACCT_RETRY_DELAY) if result.returncode == 0: status = result.stdout.strip().upper() diff --git a/src/madengine/execution/container_runner.py b/src/madengine/execution/container_runner.py index b4383119..8794fa80 100644 --- a/src/madengine/execution/container_runner.py +++ b/src/madengine/execution/container_runner.py @@ -536,6 +536,8 @@ def get_cpu_arg(self) -> str: cpus = 
self.context.ctx["docker_cpus"].replace(" ", "") return f"--cpuset-cpus {cpus} " + _ENV_KEY_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + def get_env_arg(self, run_env: typing.Dict) -> str: """Get the environment arguments for docker run.""" env_args = "" @@ -543,11 +545,15 @@ def get_env_arg(self, run_env: typing.Dict) -> str: # Add custom environment variables if run_env: for env_arg in run_env: + if not self._ENV_KEY_RE.match(env_arg): + raise ValueError(f"Invalid environment variable name: {env_arg!r}") env_args += f"--env {env_arg}={shlex.quote(str(run_env[env_arg]))} " # Add context environment variables if "docker_env_vars" in self.context.ctx: for env_arg in self.context.ctx["docker_env_vars"].keys(): + if not self._ENV_KEY_RE.match(env_arg): + raise ValueError(f"Invalid environment variable name: {env_arg!r}") value = self.context.ctx["docker_env_vars"][env_arg] env_args += f"--env {env_arg}={shlex.quote(str(value))} "