diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..39c44cc --- /dev/null +++ b/.envrc @@ -0,0 +1,3 @@ + + +source .venv/bin/activate diff --git a/docs/templating.md b/docs/templating.md new file mode 100644 index 0000000..4e533fb --- /dev/null +++ b/docs/templating.md @@ -0,0 +1,231 @@ +# Template Variable Substitution + +Rompy supports template variable substitution in YAML configuration files using `${VAR}` syntax. This allows you to: + +- Use environment variables in configs +- Set default values for missing variables +- Process datetime values with filters +- Share configs across different environments + +## Syntax + +### Basic Substitution + +```yaml +output_dir: "${OUTPUT_ROOT}/my_run" +run_id: "${RUN_ID}" +``` + +### Default Values + +Provide fallback values when variables are not set: + +```yaml +output_dir: "${OUTPUT_ROOT:-./output}/my_run" +timeout: "${JOB_TIMEOUT:-3600}" +threads: "${NUM_THREADS:-4}" +``` + +### Type Conversion + +When a YAML value is **exactly** one template expression, type conversion is automatic: + +```yaml +timeout: "${TIMEOUT}" # "3600" → 3600 (int) +debug: "${DEBUG}" # "true" → True (bool) +pi: "${PI}" # "3.14" → 3.14 (float) +``` + +Embedded templates always produce strings: + +```yaml +path: "/data/${USER}/file" # Always string +``` + +## Datetime Filters + +### Available Filters + +| Filter | Description | Example | +|--------|-------------|---------| +| `as_datetime` | Parse ISO-8601 datetime | `${CYCLE\|as_datetime}` | +| `strftime` | Format datetime | `${CYCLE\|strftime:%Y%m%d}` | +| `shift` | Add/subtract time | `${CYCLE\|shift:-1d}` | + +### Filter Chaining + +Combine filters with `|`: + +```yaml +previous_day: "${CYCLE|as_datetime|shift:-1d|strftime:%Y-%m-%d}" +``` + +### Datetime Examples + +```yaml +cycle_date: "${CYCLE|as_datetime}" +filename: "wind_${CYCLE|strftime:%Y%m%d}.nc" +prev_cycle: "${CYCLE|as_datetime|shift:-1d}" +end_time: "${CYCLE|as_datetime|shift:+24h}" +``` + +### Shift Syntax + +Time 
deltas use: `[+|-]<number><unit>` + +Units: +- `d` = days +- `h` = hours +- `m` = minutes +- `s` = seconds + +Examples: +- `+1d` = add 1 day +- `-6h` = subtract 6 hours +- `+30m` = add 30 minutes + +## Complete Examples + +### Basic Config + +```yaml +run_id: "cycle_${CYCLE|strftime:%Y%m%d}" + +period: + start: "${CYCLE}" + end: "${CYCLE|as_datetime|shift:+1d}" + interval: "1H" + +output_dir: "${OUTPUT_ROOT:-./output}/cycle_${CYCLE|strftime:%Y%m%d}" + +input_files: + wind: "${DATA_ROOT}/wind/wind_${CYCLE|strftime:%Y%m%d}.nc" + wave: "${DATA_ROOT}/wave/wave_${CYCLE|strftime:%Y%m%d}.nc" +``` + +Usage: +```bash +export CYCLE=2023-01-01T00:00:00 +export DATA_ROOT=/scratch/data +rompy generate config.yml +``` + +### Backend Config + +```yaml +type: local +timeout: "${JOB_TIMEOUT:-3600}" +command: "python run_model.py" + +env_vars: + OMP_NUM_THREADS: "${NUM_THREADS:-4}" + WORK_DIR: "${WORK_DIR}" +``` + +Usage: +```bash +export WORK_DIR=/scratch/my_job +export NUM_THREADS=8 +rompy run config.yml --backend-config backend.yml +``` + +### Lookback Pattern + +Access previous time periods: + +```yaml +input_files: + current: "${DATA_ROOT}/data_${CYCLE|strftime:%Y%m%d}.nc" + previous: "${DATA_ROOT}/data_${CYCLE|as_datetime|shift:-1d|strftime:%Y%m%d}.nc" + week_ago: "${DATA_ROOT}/data_${CYCLE|as_datetime|shift:-7d|strftime:%Y%m%d}.nc" +``` + +### Nested Directory Structures + +```yaml +output_dir: "${DATA_ROOT}/output/${CYCLE|strftime:%Y/%m/%d}" +``` + +With `CYCLE=2023-01-15T00:00:00` and `DATA_ROOT=/data` → `/data/output/2023/01/15` + +## How It Works + +Template rendering happens **after YAML parsing** but **before Pydantic validation**: + +``` +1. Load YAML file → dict +2. Render templates: ${VAR} → actual values +3. 
Pydantic validation: dict → ModelRun object +``` + +This ensures: +- Type safety (Pydantic sees resolved values) +- Clear error messages (template errors before validation errors) +- Datetime objects work with Pydantic models + +## Error Handling + +### Missing Variables + +By default, missing variables cause an error: + +```yaml +path: "${MISSING_VAR}" # Error: Variable 'MISSING_VAR' not found +``` + +Use defaults to make variables optional: + +```yaml +path: "${OPTIONAL_VAR:-/default/path}" # OK if OPTIONAL_VAR not set +``` + +### Invalid Filters + +Unknown filters produce clear errors: + +```yaml +date: "${CYCLE|unknown_filter}" # Error: Unknown filter 'unknown_filter' +``` + +### Datetime Parsing + +Invalid datetime strings fail early: + +```yaml +date: "${CYCLE|as_datetime}" # Error if CYCLE is not ISO-8601 format +``` + +## Tips + +### Quote Values with `:-` + +YAML interprets `:` as mapping syntax. Quote defaults containing colons: + +```yaml +path: "${VAR:-/path/with:colon}" # GOOD - quoted +path: ${VAR:-/path/with:colon} # BAD - YAML parse error +``` + +### Environment Variables + +Templates use `os.environ` by default: + +```bash +export MY_VAR=value +rompy generate config.yml # ${MY_VAR} resolved automatically +``` + +### Separation from Jinja2 + +Don't confuse with rompy's existing Jinja2 templates (used for model control files): + +- `${VAR}` = **Config templating** (pre-load, env vars) +- `{{runtime.var}}` = **File templating** (post-load, Python objects) + +They serve different purposes and run at different times. 
+ +## See Also + +- Example configs: `examples/configs/templated_*.yml` +- Tests: `tests/test_templating.py` +- Implementation: `src/rompy/templating.py` diff --git a/examples/configs/templated_advanced.yml b/examples/configs/templated_advanced.yml new file mode 100644 index 0000000..78fcfed --- /dev/null +++ b/examples/configs/templated_advanced.yml @@ -0,0 +1,35 @@ +# Example: Advanced Templated Configuration +# Shows datetime processing and complex filter chains + +# Environment variables: +# - CYCLE: ISO datetime (e.g., 2023-01-15T00:00:00) +# - DATA_ROOT: Base data directory +# (The forecast length is fixed at 24h below: filter arguments cannot contain nested template expressions.) + +run_id: "forecast_${CYCLE|strftime:%Y%m%d_%H%M}" + +period: + start: "${CYCLE}" + end: "${CYCLE|as_datetime|shift:+24h}" + interval: "1H" + +output_dir: "${DATA_ROOT}/output/${CYCLE|strftime:%Y/%m/%d}" + +input_files: + wind: + current: "${DATA_ROOT}/wind/wind_${CYCLE|strftime:%Y%m%d}.nc" + previous: "${DATA_ROOT}/wind/wind_${CYCLE|as_datetime|shift:-1d|strftime:%Y%m%d}.nc" + + wave: + current: "${DATA_ROOT}/wave/wave_${CYCLE|strftime:%Y%m%d}.nc" + forecast: "${DATA_ROOT}/wave/wave_${CYCLE|as_datetime|shift:+1d|strftime:%Y%m%d}.nc" + +config: + lookback_days: 3 + lookback_start: "${CYCLE|as_datetime|shift:-3d}" + +# Usage with datetime arithmetic: +# export CYCLE=2023-01-15T12:00:00 +# export DATA_ROOT=/data/ocean +# (edit the shift argument of period.end to change the forecast length) +# rompy generate examples/configs/templated_advanced.yml diff --git a/examples/configs/templated_backend.yml b/examples/configs/templated_backend.yml new file mode 100644 index 0000000..8f32ecb --- /dev/null +++ b/examples/configs/templated_backend.yml @@ -0,0 +1,23 @@ +# Example: Templated Backend Configuration +# Demonstrates template variable substitution in backend configs + +# Environment variables: +# - NUM_THREADS: Number of CPU threads (defaults to 4) +# - JOB_TIMEOUT: Timeout in seconds (defaults to 3600) +# - WORK_DIR: Working directory path + +type: local + 
+timeout: "${JOB_TIMEOUT:-3600}" + +command: "python run_model.py" + +env_vars: + OMP_NUM_THREADS: "${NUM_THREADS:-4}" + MODEL_TYPE: "production" + WORK_DIR: "${WORK_DIR}" + +# Usage: +# export WORK_DIR=/scratch/my_job +# export NUM_THREADS=8 +# rompy run config.yml --backend-config templated_backend.yml diff --git a/examples/configs/templated_modelrun.yml b/examples/configs/templated_modelrun.yml new file mode 100644 index 0000000..8e173c8 --- /dev/null +++ b/examples/configs/templated_modelrun.yml @@ -0,0 +1,28 @@ +# Example: Templated ModelRun Configuration +# This demonstrates template variable substitution in rompy configs + +# Environment variables required: +# - CYCLE: ISO datetime string (e.g., 2023-01-01T00:00:00) +# - DATA_ROOT: Path to input data directory +# - OUTPUT_ROOT: (optional) Path to output directory (defaults to ./output) + +run_id: "cycle_${CYCLE|strftime:%Y%m%d}" + +period: + start: "${CYCLE}" + end: "${CYCLE|as_datetime|shift:+1d}" + interval: "1H" + +output_dir: "${OUTPUT_ROOT:-./output}/cycle_${CYCLE|strftime:%Y%m%d}" +delete_existing: true + +# Example input files with templated paths +input_files: + wind: "${DATA_ROOT}/wind/wind_${CYCLE|strftime:%Y%m%d}.nc" + wave: "${DATA_ROOT}/wave/wave_${CYCLE|strftime:%Y%m%d}.nc" + +# Usage: +# export CYCLE=2023-01-01T00:00:00 +# export DATA_ROOT=/scratch/data +# export OUTPUT_ROOT=/scratch/output # Optional +# rompy generate examples/configs/templated_modelrun.yml -v diff --git a/src/rompy/cli.py b/src/rompy/cli.py index f7ea4ab..76bf896 100644 --- a/src/rompy/cli.py +++ b/src/rompy/cli.py @@ -21,6 +21,7 @@ from rompy.backends import DockerConfig, LocalConfig, SlurmConfig from rompy.logging import LogFormat, LoggingConfig, LogLevel, get_logger from rompy.model import PIPELINE_BACKENDS, POSTPROCESSORS, RUN_BACKENDS, ModelRun +from rompy.templating import render_templates # Initialize the logger logger = get_logger(__name__) @@ -165,11 +166,20 @@ def load_config( try: config = yaml.safe_load(content) 
logger.info("Parsed config as YAML") - return config except yaml.YAMLError as e: logger.error(f"Failed to parse config as JSON or YAML: {e}") raise click.UsageError("Config file is not valid JSON or YAML") + # Render template variables in config + try: + config = render_templates(config, context=dict(os.environ), strict=True) + logger.debug("Template variables rendered successfully") + except Exception as e: + logger.error(f"Failed to render template variables: {e}") + raise click.UsageError(f"Template rendering error: {e}") + + return config + def print_version(ctx, param, value): """Callback to print version and exit.""" diff --git a/src/rompy/core/config.py b/src/rompy/core/config.py index e9cfd88..438b172 100644 --- a/src/rompy/core/config.py +++ b/src/rompy/core/config.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Literal, Optional, Any +from typing import Literal, Optional from pydantic import Field diff --git a/src/rompy/core/source.py b/src/rompy/core/source.py index fbc589b..4e72985 100644 --- a/src/rompy/core/source.py +++ b/src/rompy/core/source.py @@ -1,6 +1,5 @@ """Rompy source objects.""" -import logging from abc import ABC, abstractmethod from functools import cached_property from pathlib import Path diff --git a/src/rompy/run/docker.py b/src/rompy/run/docker.py index c9df7b0..ff08b6b 100644 --- a/src/rompy/run/docker.py +++ b/src/rompy/run/docker.py @@ -8,7 +8,6 @@ import json import logging import pathlib -import subprocess import time from typing import TYPE_CHECKING, Dict, List, Optional @@ -289,7 +288,7 @@ def _run_container( # Note: When remove=True, client.containers.run() returns None # If you need to capture output, you'd need to set remove=False and manually remove client.containers.run(**container_config) - + logger.info("Model run completed successfully") return True diff --git a/src/rompy/run/slurm.py b/src/rompy/run/slurm.py index dea6de4..e3ed36a 100644 --- a/src/rompy/run/slurm.py +++ 
b/src/rompy/run/slurm.py @@ -9,8 +9,7 @@ import subprocess import tempfile import time -from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from rompy.backends import SlurmConfig @@ -86,7 +85,7 @@ def _create_job_script( """ # Determine the working directory for the job work_dir = config.working_dir if config.working_dir else staging_dir - + # Create the job script content script_lines = [ "#!/bin/bash", @@ -96,13 +95,13 @@ def _create_job_script( # Add SBATCH directives from configuration if config.job_name: script_lines.append(f"#SBATCH --job-name={config.job_name}") - + if config.output_file: script_lines.append(f"#SBATCH --output={config.output_file}") else: # Default output file with job ID script_lines.append(f"#SBATCH --output={work_dir}/slurm-%j.out") - + if config.error_file: script_lines.append(f"#SBATCH --error={config.error_file}") else: @@ -111,21 +110,21 @@ def _create_job_script( if config.queue: script_lines.append(f"#SBATCH --partition={config.queue}") - + script_lines.append(f"#SBATCH --nodes={config.nodes}") script_lines.append(f"#SBATCH --ntasks={config.ntasks}") script_lines.append(f"#SBATCH --cpus-per-task={config.cpus_per_task}") script_lines.append(f"#SBATCH --time={config.time_limit}") - + if config.account: script_lines.append(f"#SBATCH --account={config.account}") - + if config.qos: script_lines.append(f"#SBATCH --qos={config.qos}") - + if config.reservation: script_lines.append(f"#SBATCH --reservation={config.reservation}") - + if config.mail_type and config.mail_user: script_lines.append(f"#SBATCH --mail-type={config.mail_type}") script_lines.append(f"#SBATCH --mail-user={config.mail_user}") @@ -134,28 +133,32 @@ def _create_job_script( for option in config.additional_options: script_lines.append(f"#SBATCH {option}") - script_lines.extend([ - "", - "# Change to working directory", - f"cd {work_dir}", - "", - "# Set environment variables", - 
]) + script_lines.extend( + [ + "", + "# Change to working directory", + f"cd {work_dir}", + "", + "# Set environment variables", + ] + ) # Add environment variables for key, value in config.env_vars.items(): script_lines.append(f"export {key}={value}") # Add the actual command to run the model - script_lines.extend([ - "", - "# Execute command in the workspace", - config.command, - ]) + script_lines.extend( + [ + "", + "# Execute command in the workspace", + config.command, + ] + ) # Create temporary job script file - with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as f: - f.write('\n'.join(script_lines)) + with tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) as f: + f.write("\n".join(script_lines)) script_path = f.name logger.debug(f"SLURM job script created at: {script_path}") @@ -174,13 +177,11 @@ def _submit_job(self, job_script: str) -> Optional[str]: """ try: # Check if sbatch command is available - result = subprocess.run( - ["which", "sbatch"], - capture_output=True, - text=True - ) + result = subprocess.run(["which", "sbatch"], capture_output=True, text=True) if result.returncode != 0 or not result.stdout.strip(): - logger.error("sbatch command not found. SLURM may not be installed or in PATH.") + logger.error( + "sbatch command not found. SLURM may not be installed or in PATH." + ) return None # Check if SLURM controller is responsive @@ -188,18 +189,17 @@ def _submit_job(self, job_script: str) -> Optional[str]: ["scontrol", "--help"], capture_output=True, text=True, - timeout=10 # Don't wait too long + timeout=10, # Don't wait too long ) if result.returncode != 0: - logger.error("SLURM controller is not responsive. scontrol command failed.") + logger.error( + "SLURM controller is not responsive. scontrol command failed." 
+ ) return None # Submit the job using sbatch result = subprocess.run( - ["sbatch", job_script], - capture_output=True, - text=True, - check=True + ["sbatch", job_script], capture_output=True, text=True, check=True ) # Extract job ID from sbatch output (format: "Submitted batch job ") @@ -213,7 +213,9 @@ def _submit_job(self, job_script: str) -> Optional[str]: return None except subprocess.TimeoutExpired: - logger.error("SLURM controller check timed out. SLURM may not be properly configured.") + logger.error( + "SLURM controller check timed out. SLURM may not be properly configured." + ) return None except subprocess.CalledProcessError as e: logger.error(f"Failed to submit SLURM job: {e.stderr}") @@ -243,8 +245,17 @@ def _wait_for_completion(self, job_id: str, config: "SlurmConfig") -> bool: # Terminal states that indicate job completion (successful or failed) # Using SLURM job states: https://slurm.schedmd.com/squeue.html#SECTION_JOB-STATE-CODES - terminal_states = {'BOOT_FAIL', 'CANCELLED', 'COMPLETED', 'DEADLINE', 'FAILED', - 'NODE_FAIL', 'OUT_OF_MEMORY', 'PREEMPTED', 'TIMEOUT'} + terminal_states = { + "BOOT_FAIL", + "CANCELLED", + "COMPLETED", + "DEADLINE", + "FAILED", + "NODE_FAIL", + "OUT_OF_MEMORY", + "PREEMPTED", + "TIMEOUT", + } # Start time for timeout check start_time = time.time() @@ -253,57 +264,68 @@ def _wait_for_completion(self, job_id: str, config: "SlurmConfig") -> bool: # Check if we've exceeded the timeout elapsed_time = time.time() - start_time if elapsed_time > config.timeout: - logger.error(f"Timeout waiting for job {job_id} after {config.timeout} seconds") - + logger.error( + f"Timeout waiting for job {job_id} after {config.timeout} seconds" + ) + # Let SLURM handle job cancellation according to its configured policies return False # Get job status using scontrol for more reliable detection try: result = subprocess.run( - ['scontrol', 'show', 'job', job_id], + ["scontrol", "show", "job", job_id], capture_output=True, text=True, - check=True 
+ check=True, ) - + # Parse the output to get the job state output = result.stdout - if 'JobState=' in output: - state = output.split('JobState=')[1].split()[0].split('_')[0] # Extract state like 'RUNNING', 'COMPLETED', etc. + if "JobState=" in output: + state = ( + output.split("JobState=")[1].split()[0].split("_")[0] + ) # Extract state like 'RUNNING', 'COMPLETED', etc. else: # If JobState is not found, we might have an issue with parsing - logger.warning(f"Could not determine job state from output for job {job_id}") + logger.warning( + f"Could not determine job state from output for job {job_id}" + ) state = None - - if state is None: # If job state can't be determined, check if job is not found - if 'slurm_load_jobs error' in output or 'Invalid job id' in output.lower(): + + if ( + state is None + ): # If job state can't be determined, check if job is not found + if ( + "slurm_load_jobs error" in output + or "Invalid job id" in output.lower() + ): logger.info(f"Job {job_id} not found - likely completed") return True # Assume successful completion if job ID is invalid - + if state in terminal_states: - if state == 'COMPLETED': # Completed successfully + if state == "COMPLETED": # Completed successfully logger.info(f"SLURM job {job_id} completed successfully") return True - elif state == 'CANCELLED': # Cancelled + elif state == "CANCELLED": # Cancelled logger.warning(f"SLURM job {job_id} was cancelled") return False - elif state == 'FAILED': # Failed + elif state == "FAILED": # Failed logger.error(f"SLURM job {job_id} failed") return False - elif state == 'TIMEOUT': # Timeout + elif state == "TIMEOUT": # Timeout logger.error(f"SLURM job {job_id} timed out") return False - elif state == 'BOOT_FAIL': # Boot failure + elif state == "BOOT_FAIL": # Boot failure logger.error(f"SLURM job {job_id} failed to boot") return False - elif state == 'NODE_FAIL': # Node failure + elif state == "NODE_FAIL": # Node failure logger.error(f"SLURM job {job_id} failed due to node 
failure") return False - elif state == 'OUT_OF_MEMORY': # Out of memory + elif state == "OUT_OF_MEMORY": # Out of memory logger.error(f"SLURM job {job_id} ran out of memory") return False - elif state == 'PREEMPTED': # Preempted + elif state == "PREEMPTED": # Preempted logger.error(f"SLURM job {job_id} was preempted") return False else: diff --git a/src/rompy/templating.py b/src/rompy/templating.py new file mode 100644 index 0000000..e7bd69f --- /dev/null +++ b/src/rompy/templating.py @@ -0,0 +1,405 @@ +""" +Template variable substitution for YAML configurations. + +This module provides ${VAR} style variable substitution for rompy config files, +supporting environment variables, defaults, and datetime processing filters. + +Examples: + Basic substitution: + filename: "/path/to/wind_input_${CYCLE}.nc" + + With defaults: + output_dir: "${OUTPUT_ROOT:-./output}/runs" + + Datetime filters: + run_id: "cycle_${CYCLE|strftime:%Y%m%d}" + prev_cycle: "${CYCLE|as_datetime|shift:-1d}" +""" + +import os +import re +from datetime import datetime, timedelta +from typing import Any, Dict, Optional + +from rompy.logging import get_logger + +logger = get_logger(__name__) + +# Regex pattern for ${VAR}, ${VAR:-default}, ${VAR|filter:arg|filter2} +# Captures: variable name, default value (optional), filter chain (optional) +TEMPLATE_PATTERN = re.compile( + r"\$\{([A-Za-z_][A-Za-z0-9_]*)" # Variable name (group 1) + r"(?::-((?:[^|}]|\\[|}])*?))?" # Optional default :- (group 2) + r"(?:\|([^}]+))?" # Optional filter chain | (group 3) + r"\}" +) + + +class TemplateError(Exception): + """Raised when template rendering fails.""" + + pass + + +class TemplateContext: + """Context for template variable resolution. + + Wraps a dict (typically os.environ) and provides variable lookup with + type conversion and default handling. + """ + + def __init__(self, context: Optional[Dict[str, Any]] = None): + """Initialize template context. 
+ + Args: + context: Dict of variables (defaults to os.environ) + """ + self.context = context if context is not None else dict(os.environ) + + def get(self, name: str, default: Optional[str] = None) -> Any: + """Get variable value with optional default. + + Args: + name: Variable name + default: Default value if variable not found + + Returns: + Variable value exactly as stored in the context (no type conversion here) + + Raises: + TemplateError: If variable not found and no default + """ + if name in self.context: + return self.context[name] + + if default is not None: + return default + + raise TemplateError( + f"Variable '${{{name}}}' not found in context and no default provided" + ) + + def set(self, name: str, value: Any): + """Set variable in context (for nested resolution).""" + self.context[name] = value + + +def parse_datetime(value: Any, fmt: Optional[str] = None) -> datetime: + """Parse datetime from string. + + Args: + value: String or datetime object + fmt: strptime format (if None, tries ISO-8601) + + Returns: + datetime object + + Raises: + TemplateError: If parsing fails + """ + if isinstance(value, datetime): + return value + + if not isinstance(value, str): + raise TemplateError( + f"Cannot parse datetime from {type(value).__name__}: {value}" + ) + + # Try ISO-8601 first + if fmt is None: + try: + # Handle various ISO formats + for iso_fmt in [ + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%dT%H:%M:%S.%f", + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%d", + ]: + try: + return datetime.strptime(value, iso_fmt) + except ValueError: + continue + # Try fromisoformat as fallback + return datetime.fromisoformat(value.replace("Z", "+00:00")) + except (ValueError, AttributeError) as e: + raise TemplateError(f"Cannot parse ISO datetime from '{value}': {e}") from e + else: + # Use provided format + try: + return datetime.strptime(value, fmt) + except ValueError as e: + raise TemplateError( + f"Cannot parse datetime from '{value}' with format '{fmt}': {e}" + ) from e + + +def shift_datetime(dt: datetime, delta_str: str) -> datetime: + 
"""Shift datetime by delta string. + + Args: + dt: datetime object + delta_str: Delta string like "+1d", "-6h", "+30m" + + Returns: + Shifted datetime + + Raises: + TemplateError: If delta format invalid + """ + match = re.match(r"^([+-]?)(\d+)([dhms])$", delta_str) + if not match: + raise TemplateError( + f"Invalid shift format '{delta_str}'. Expected format: [+|-]<number><unit> " + f"(e.g., '+1d', '-6h', '+30m')" + ) + + sign, amount, unit = match.groups() + amount = int(amount) + if sign == "-": + amount = -amount + + if unit == "d": + delta = timedelta(days=amount) + elif unit == "h": + delta = timedelta(hours=amount) + elif unit == "m": + delta = timedelta(minutes=amount) + elif unit == "s": + delta = timedelta(seconds=amount) + else: + raise TemplateError(f"Unknown time unit '{unit}' (expected d, h, m, s)") + + return dt + delta + + +def apply_filter(value: Any, filter_spec: str) -> Any: + """Apply a single filter to a value. + + Args: + value: Input value + filter_spec: Filter specification (e.g., "strftime:%Y%m%d", "shift:-1d") + + Returns: + Filtered value + + Raises: + TemplateError: If filter unknown or fails + """ + # Parse filter name and argument + if ":" in filter_spec: + filter_name, filter_arg = filter_spec.split(":", 1) + else: + filter_name = filter_spec + filter_arg = None + + filter_name = filter_name.strip() + + # Apply filter + if filter_name == "as_datetime": + # Parse datetime with optional format + return parse_datetime(value, filter_arg) + + elif filter_name == "strftime": + # Format datetime + if not filter_arg: + raise TemplateError( + "Filter 'strftime' requires format argument (e.g., strftime:%Y%m%d)" + ) + + # Ensure value is datetime + if not isinstance(value, datetime): + value = parse_datetime(value) + + return value.strftime(filter_arg) + + elif filter_name == "shift": + # Shift datetime by delta + if not filter_arg: + raise TemplateError( + "Filter 'shift' requires delta argument (e.g., shift:-1d)" + ) + + # Ensure value is datetime + if not 
isinstance(value, datetime): + value = parse_datetime(value) + + return shift_datetime(value, filter_arg) + + else: + raise TemplateError( + f"Unknown filter '{filter_name}'. Available filters: as_datetime, strftime, shift" + ) + + +def apply_filters(value: Any, filter_chain: str) -> Any: + """Apply a chain of filters to a value. + + Args: + value: Input value + filter_chain: Pipe-separated filters (e.g., "as_datetime|shift:-1d|strftime:%Y%m%d") + + Returns: + Filtered value + """ + filters = filter_chain.split("|") + for filter_spec in filters: + filter_spec = filter_spec.strip() + if filter_spec: + value = apply_filter(value, filter_spec) + + return value + + +def render_string(template: str, context: TemplateContext, strict: bool = True) -> Any: + """Render a single template string. + + Args: + template: Template string (may contain multiple ${...} expressions) + context: Template context + strict: Raise error on unresolved variables + + Returns: + Rendered value (preserves type if template is exactly one expression) + """ + # Check if template is exactly one expression (for type preservation) + exact_match = TEMPLATE_PATTERN.fullmatch(template) + + if exact_match: + # Single expression - preserve type + var_name, default, filter_chain = exact_match.groups() + + # Get variable value + try: + value = context.get(var_name, default) + except TemplateError: + if strict: + raise + return template # Keep unresolved + + # Apply filters if present + if filter_chain: + try: + value = apply_filters(value, filter_chain) + except TemplateError as e: + if strict: + raise TemplateError( + f"Filter error in '${{{var_name}|{filter_chain}}}': {e}" + ) from e + return template # Keep unresolved + # Return filtered value as-is (no type conversion for filtered values) + return value + + # Type conversion for string values (mimics bash-like behavior) + if isinstance(value, str): + # Try bool conversion ("1"/"0" are left to the int conversion below so counts stay numeric) + if value.lower() in ("true", "yes"): + return True + elif value.lower() in 
("false", "no"): + return False + + # Try int conversion + try: + return int(value) + except ValueError: + pass + + # Try float conversion + try: + return float(value) + except ValueError: + pass + + return value + + else: + # Multiple expressions or embedded - always stringify + def replace_match(match): + var_name = match.group(1) + default = match.group(2) + filter_chain = match.group(3) + + # Get variable value + try: + value = context.get(var_name, default) + except TemplateError: + if strict: + raise + return match.group(0) # Keep original ${...} + + # Apply filters if present + if filter_chain: + try: + value = apply_filters(value, filter_chain) + except TemplateError as e: + if strict: + raise TemplateError( + f"Filter error in '${{{var_name}|{filter_chain}}}': {e}" + ) from e + return match.group(0) # Keep original ${...} + + # Always stringify in embedded context + return str(value) + + return TEMPLATE_PATTERN.sub(replace_match, template) + + +def render_templates( + data: Any, + context: Optional[Dict[str, Any]] = None, + strict: bool = True, +) -> Any: + """Recursively render templates in data structure. + + Walks dict/list/str recursively and replaces ${VAR} expressions. 
+ + Args: + data: Data structure (dict, list, str, or primitive) + context: Variable context (defaults to os.environ) + strict: Raise error on unresolved variables + + Returns: + Rendered data structure + + Examples: + >>> render_templates({"path": "/data/${USER}/file"}) + {"path": "/data/john/file"} + + >>> render_templates({"cycle": "${CYCLE|as_datetime|shift:-1d}"}) + {"cycle": datetime(2023, 1, 1)} + """ + ctx = TemplateContext(context) + return _render_recursive(data, ctx, strict) + + +def _render_recursive(data: Any, context: TemplateContext, strict: bool) -> Any: + """Recursively render templates (internal).""" + + if isinstance(data, dict): + # Render dict recursively + result = {} + for key, value in data.items(): + # Render key if it contains templates (rare but possible) + if isinstance(key, str) and "${" in key: + try: + key = render_string(key, context, strict) + except TemplateError as e: + logger.warning(f"Failed to render dict key '{key}': {e}") + + # Render value + result[key] = _render_recursive(value, context, strict) + + return result + + elif isinstance(data, list): + # Render list recursively + return [_render_recursive(item, context, strict) for item in data] + + elif isinstance(data, str): + # Render string if it contains templates + if "${" in data: + return render_string(data, context, strict) + return data + + else: + # Primitives (int, float, bool, None) pass through + return data diff --git a/tests/backends/test_pydantic_backends.py b/tests/backends/test_pydantic_backends.py index 42ffa13..9a24894 100644 --- a/tests/backends/test_pydantic_backends.py +++ b/tests/backends/test_pydantic_backends.py @@ -406,7 +406,6 @@ def test_local_config_integration(self, mock_model_run): def test_docker_config_integration(self, mock_model_run): """Test DockerConfig integration with DockerRunBackend.""" import tempfile - import docker config = DockerConfig( image="test:latest", diff --git a/tests/backends/test_slurm_backend.py 
b/tests/backends/test_slurm_backend.py index 40ab471..ee0be3f 100644 --- a/tests/backends/test_slurm_backend.py +++ b/tests/backends/test_slurm_backend.py @@ -5,12 +5,10 @@ provides proper validation, and integrates with the SLURM execution backend. """ -import shutil import subprocess -import sys from pathlib import Path from tempfile import TemporaryDirectory -from unittest.mock import MagicMock, mock_open, patch +from unittest.mock import MagicMock, patch import os import tempfile import pytest @@ -23,10 +21,7 @@ def is_slurm_available(): """Check if SLURM is available on the system.""" try: result = subprocess.run( - ["which", "sbatch"], - capture_output=True, - text=True, - timeout=5 + ["which", "sbatch"], capture_output=True, text=True, timeout=5 ) return result.returncode == 0 and bool(result.stdout.strip()) except (subprocess.TimeoutExpired, FileNotFoundError, subprocess.SubprocessError): @@ -35,8 +30,7 @@ def is_slurm_available(): # Skip tests that require SLURM if it's not available requires_slurm = pytest.mark.skipif( - not is_slurm_available(), - reason="SLURM is not available on this system" + not is_slurm_available(), reason="SLURM is not available on this system" ) @@ -122,24 +116,28 @@ def test_time_limit_validation(self): ] for time_limit in valid_time_limits: - config = SlurmConfig(queue="test", command="python run_model.py", time_limit=time_limit) + config = SlurmConfig( + queue="test", command="python run_model.py", time_limit=time_limit + ) assert config.time_limit == time_limit # Invalid time limits (format-based validation) invalid_time_limits = [ - "00:00", # Missing seconds - "invalid", # Not matching format - "1:1:1", # Not in HH:MM:SS format (only 1 digit for each part) - "25-00-00", # Wrong separator - "12345:00:00", # Too many digits for hours (5 digits instead of max 4) - "23:5", # Missing seconds part - ":23:59", # Missing hours - "23::59", # Missing minutes + "00:00", # Missing seconds + "invalid", # Not matching format + "1:1:1", # 
Not in HH:MM:SS format (only 1 digit for each part) + "25-00-00", # Wrong separator + "12345:00:00", # Too many digits for hours (5 digits instead of max 4) + "23:5", # Missing seconds part + ":23:59", # Missing hours + "23::59", # Missing minutes ] for time_limit in invalid_time_limits: with pytest.raises(ValidationError): - SlurmConfig(queue="test", command="python run_model.py", time_limit=time_limit) + SlurmConfig( + queue="test", command="python run_model.py", time_limit=time_limit + ) def test_additional_options_validation(self): """Test additional options validation.""" @@ -147,12 +145,18 @@ def test_additional_options_validation(self): config = SlurmConfig( queue="test", command="python run_model.py", - additional_options=["--gres=gpu:1", "--exclusive", "--mem-per-cpu=2048"] + additional_options=["--gres=gpu:1", "--exclusive", "--mem-per-cpu=2048"], ) - assert config.additional_options == ["--gres=gpu:1", "--exclusive", "--mem-per-cpu=2048"] + assert config.additional_options == [ + "--gres=gpu:1", + "--exclusive", + "--mem-per-cpu=2048", + ] # Empty list should be valid - config = SlurmConfig(queue="test", command="python run_model.py", additional_options=[]) + config = SlurmConfig( + queue="test", command="python run_model.py", additional_options=[] + ) assert config.additional_options == [] def test_get_backend_class(self): @@ -209,16 +213,24 @@ def test_field_boundaries(self): # Test out of bounds with pytest.raises(ValidationError): - SlurmConfig(queue="test", command="python run_model.py", nodes=0) # Min nodes is 1 + SlurmConfig( + queue="test", command="python run_model.py", nodes=0 + ) # Min nodes is 1 with pytest.raises(ValidationError): - SlurmConfig(queue="test", command="python run_model.py", nodes=101) # Max nodes is 100 + SlurmConfig( + queue="test", command="python run_model.py", nodes=101 + ) # Max nodes is 100 with pytest.raises(ValidationError): - SlurmConfig(queue="test", command="python run_model.py", cpus_per_task=0) # Min cpus_per_task 
is 1 + SlurmConfig( + queue="test", command="python run_model.py", cpus_per_task=0 + ) # Min cpus_per_task is 1 with pytest.raises(ValidationError): - SlurmConfig(queue="test", command="python run_model.py", cpus_per_task=129) # Max cpus_per_task is 128 + SlurmConfig( + queue="test", command="python run_model.py", cpus_per_task=129 + ) # Max cpus_per_task is 128 def test_command_field(self): """Test the command field validation and functionality.""" @@ -270,20 +282,22 @@ def basic_config(self): def test_create_job_script(self, mock_model_run, basic_config): """Test the _create_job_script method.""" from rompy.run.slurm import SlurmRunBackend - + backend = SlurmRunBackend() - + with TemporaryDirectory() as staging_dir: # Create the job script - script_path = backend._create_job_script(mock_model_run, basic_config, staging_dir) - + script_path = backend._create_job_script( + mock_model_run, basic_config, staging_dir + ) + # Verify the file was created assert os.path.exists(script_path) - + # Read and check the contents - with open(script_path, 'r') as f: + with open(script_path, "r") as f: content = f.read() - + # Check for SLURM directives assert "#!/bin/bash" in content assert "#SBATCH --partition=general" in content @@ -291,7 +305,7 @@ def test_create_job_script(self, mock_model_run, basic_config): assert "#SBATCH --ntasks=1" in content assert "#SBATCH --cpus-per-task=2" in content assert "#SBATCH --time=01:00:00" in content - + # Clean up if os.path.exists(script_path): os.remove(script_path) @@ -299,7 +313,7 @@ def test_create_job_script(self, mock_model_run, basic_config): def test_create_job_script_with_all_options(self, mock_model_run): """Test the _create_job_script method with all options.""" from rompy.run.slurm import SlurmRunBackend - + config = SlurmConfig( queue="priority", nodes=2, @@ -319,15 +333,17 @@ def test_create_job_script_with_all_options(self, mock_model_run): timeout=86400, env_vars={"OMP_NUM_THREADS": "8", "MY_VAR": "value"}, ) - + backend 
= SlurmRunBackend() - + with TemporaryDirectory() as staging_dir: - script_path = backend._create_job_script(mock_model_run, config, staging_dir) - - with open(script_path, 'r') as f: + script_path = backend._create_job_script( + mock_model_run, config, staging_dir + ) + + with open(script_path, "r") as f: content = f.read() - + # Check for all SBATCH directives assert "#SBATCH --partition=priority" in content assert "#SBATCH --nodes=2" in content @@ -344,11 +360,11 @@ def test_create_job_script_with_all_options(self, mock_model_run): assert "#SBATCH --mail-user=test@example.com" in content assert "#SBATCH --gres=gpu:1" in content assert "#SBATCH --exclusive" in content - + # Check for environment variables assert "export OMP_NUM_THREADS=8" in content assert "export MY_VAR=value" in content - + # Clean up if os.path.exists(script_path): os.remove(script_path) @@ -370,9 +386,11 @@ def test_create_job_script_with_command(self, mock_model_run): backend = SlurmRunBackend() with TemporaryDirectory() as staging_dir: - script_path = backend._create_job_script(mock_model_run, config, staging_dir) + script_path = backend._create_job_script( + mock_model_run, config, staging_dir + ) - with open(script_path, 'r') as f: + with open(script_path, "r") as f: content = f.read() # Check that the command is in the script @@ -386,7 +404,6 @@ def test_create_job_script_with_command(self, mock_model_run): if os.path.exists(script_path): os.remove(script_path) - def test_submit_job(self, basic_config): """Test the _submit_job method.""" from rompy.run.slurm import SlurmRunBackend @@ -394,7 +411,7 @@ def test_submit_job(self, basic_config): backend = SlurmRunBackend() # Create a simple job script - with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) as f: f.write("#!/bin/bash\n#SBATCH --job-name=test\n") script_path = f.name @@ -409,7 +426,9 @@ def test_submit_job(self, basic_config): # 
Second call: scontrol --help - return success MagicMock(returncode=0, stdout="scontrol help text"), # Third call: sbatch command - return success - MagicMock(returncode=0, stdout="Submitted batch job 12345", stderr="") + MagicMock( + returncode=0, stdout="Submitted batch job 12345", stderr="" + ), ] job_id = backend._submit_job(script_path) @@ -430,7 +449,7 @@ def test_submit_job_failure(self, basic_config): backend = SlurmRunBackend() # Create a simple job script - with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) as f: f.write("#!/bin/bash\n#SBATCH --job-name=test\n") script_path = f.name @@ -444,7 +463,9 @@ def test_submit_job_failure(self, basic_config): # Second call: scontrol --help - return success MagicMock(returncode=0, stdout="scontrol help text"), # Third call: sbatch command - return failure - subprocess.CalledProcessError(1, "sbatch", stderr="SLURM submission failed") + subprocess.CalledProcessError( + 1, "sbatch", stderr="SLURM submission failed" + ), ] job_id = backend._submit_job(script_path) @@ -469,16 +490,12 @@ def test_wait_for_completion_completed(self, basic_config): mock_run.side_effect = [ # Running state from scontrol MagicMock( - stdout="JobState=RUNNING\nOtherInfo=...", - stderr="", - returncode=0 + stdout="JobState=RUNNING\nOtherInfo=...", stderr="", returncode=0 ), # Completed state from scontrol MagicMock( - stdout="JobState=COMPLETED\nOtherInfo=...", - stderr="", - returncode=0 - ) + stdout="JobState=COMPLETED\nOtherInfo=...", stderr="", returncode=0 + ), ] result = backend._wait_for_completion("12345", basic_config) @@ -495,9 +512,7 @@ def test_wait_for_completion_failed(self, basic_config): # Mock subprocess.run for scontrol to return failed state with patch("subprocess.run") as mock_run: mock_result = MagicMock( - stdout="JobState=FAILED\nOtherInfo=...", - stderr="", - returncode=0 + stdout="JobState=FAILED\nOtherInfo=...", 
stderr="", returncode=0 ) mock_run.return_value = mock_result @@ -512,7 +527,7 @@ def test_wait_for_completion_timeout(self): config = SlurmConfig( queue="test", - command="python run_model.py", # Added required command field + command="python run_model.py", # Added required command field timeout=60, # Minimum valid timeout value nodes=1, ntasks=1, @@ -540,7 +555,7 @@ def scontrol_side_effect(*args, **kwargs): return MagicMock( stdout="JobState=RUNNING\nOtherInfo=...", stderr="", - returncode=0 + returncode=0, ) mock_run.side_effect = scontrol_side_effect @@ -561,9 +576,11 @@ def test_run_method_success(self, mock_model_run, basic_config): with TemporaryDirectory() as staging_dir: # Mock the internal methods - with patch.object(backend, '_create_job_script') as mock_create_script, \ - patch.object(backend, '_submit_job') as mock_submit, \ - patch.object(backend, '_wait_for_completion') as mock_wait: + with ( + patch.object(backend, "_create_job_script") as mock_create_script, + patch.object(backend, "_submit_job") as mock_submit, + patch.object(backend, "_wait_for_completion") as mock_wait, + ): # Mock the methods to return expected values mock_create_script.return_value = "/tmp/job_script.sh" @@ -588,8 +605,10 @@ def test_run_method_job_submit_failure(self, mock_model_run, basic_config): with TemporaryDirectory() as staging_dir: # Mock the internal methods - with patch.object(backend, '_create_job_script') as mock_create_script, \ - patch.object(backend, '_submit_job') as mock_submit: + with ( + patch.object(backend, "_create_job_script") as mock_create_script, + patch.object(backend, "_submit_job") as mock_submit, + ): # Mock the methods mock_create_script.return_value = "/tmp/job_script.sh" diff --git a/tests/test_templating.py b/tests/test_templating.py new file mode 100644 index 0000000..484d1f5 --- /dev/null +++ b/tests/test_templating.py @@ -0,0 +1,414 @@ +import os +from datetime import datetime + +import pytest + +from rompy.templating import ( + 
TemplateContext, + TemplateError, + apply_filter, + apply_filters, + parse_datetime, + render_string, + render_templates, + shift_datetime, +) + + +class TestTemplateContext: + def test_get_from_context(self): + ctx = TemplateContext({"USER": "john", "HOME": "/home/john"}) + assert ctx.get("USER") == "john" + assert ctx.get("HOME") == "/home/john" + + def test_get_with_default(self): + ctx = TemplateContext({}) + assert ctx.get("MISSING", "default_value") == "default_value" + + def test_get_missing_strict(self): + ctx = TemplateContext({}) + with pytest.raises(TemplateError, match="Variable.*not found"): + ctx.get("MISSING") + + def test_set_variable(self): + ctx = TemplateContext({}) + ctx.set("NEW_VAR", "new_value") + assert ctx.get("NEW_VAR") == "new_value" + + def test_defaults_to_environ(self): + ctx = TemplateContext() + assert isinstance(ctx.context, dict) + + +class TestParseDatetime: + def test_parse_iso_datetime(self): + dt = parse_datetime("2023-01-01T12:00:00") + assert dt == datetime(2023, 1, 1, 12, 0, 0) + + def test_parse_iso_date(self): + dt = parse_datetime("2023-01-01") + assert dt == datetime(2023, 1, 1) + + def test_parse_with_microseconds(self): + dt = parse_datetime("2023-01-01T12:00:00.123456") + assert dt == datetime(2023, 1, 1, 12, 0, 0, 123456) + + def test_parse_with_custom_format(self): + dt = parse_datetime("01/15/2023", fmt="%m/%d/%Y") + assert dt == datetime(2023, 1, 15) + + def test_parse_invalid_format(self): + with pytest.raises(TemplateError, match="Cannot parse.*datetime"): + parse_datetime("not-a-date") + + def test_parse_already_datetime(self): + original = datetime(2023, 1, 1) + result = parse_datetime(original) + assert result == original + + +class TestShiftDatetime: + def test_shift_days_positive(self): + dt = datetime(2023, 1, 1) + shifted = shift_datetime(dt, "+1d") + assert shifted == datetime(2023, 1, 2) + + def test_shift_days_negative(self): + dt = datetime(2023, 1, 1) + shifted = shift_datetime(dt, "-1d") + assert 
shifted == datetime(2022, 12, 31) + + def test_shift_hours(self): + dt = datetime(2023, 1, 1, 12, 0) + shifted = shift_datetime(dt, "+6h") + assert shifted == datetime(2023, 1, 1, 18, 0) + + def test_shift_minutes(self): + dt = datetime(2023, 1, 1, 12, 0) + shifted = shift_datetime(dt, "+30m") + assert shifted == datetime(2023, 1, 1, 12, 30) + + def test_shift_seconds(self): + dt = datetime(2023, 1, 1, 12, 0, 0) + shifted = shift_datetime(dt, "+90s") + assert shifted == datetime(2023, 1, 1, 12, 1, 30) + + def test_shift_no_sign(self): + dt = datetime(2023, 1, 1) + shifted = shift_datetime(dt, "1d") + assert shifted == datetime(2023, 1, 2) + + def test_shift_invalid_format(self): + dt = datetime(2023, 1, 1) + with pytest.raises(TemplateError, match="Invalid shift format"): + shift_datetime(dt, "bad") + + def test_shift_invalid_unit(self): + dt = datetime(2023, 1, 1) + with pytest.raises(TemplateError, match="Invalid shift format"): + shift_datetime(dt, "1x") + + +class TestApplyFilter: + def test_as_datetime_filter(self): + result = apply_filter("2023-01-01T12:00:00", "as_datetime") + assert result == datetime(2023, 1, 1, 12, 0, 0) + + def test_as_datetime_with_format(self): + result = apply_filter("01/15/2023", "as_datetime:%m/%d/%Y") + assert result == datetime(2023, 1, 15) + + def test_strftime_filter(self): + dt = datetime(2023, 1, 15, 12, 30) + result = apply_filter(dt, "strftime:%Y%m%d") + assert result == "20230115" + + def test_strftime_parses_string(self): + result = apply_filter("2023-01-15", "strftime:%Y%m%d") + assert result == "20230115" + + def test_strftime_missing_arg(self): + with pytest.raises(TemplateError, match="requires format argument"): + apply_filter(datetime(2023, 1, 1), "strftime") + + def test_shift_filter(self): + dt = datetime(2023, 1, 1) + result = apply_filter(dt, "shift:-1d") + assert result == datetime(2022, 12, 31) + + def test_shift_parses_string(self): + result = apply_filter("2023-01-01", "shift:+1d") + assert result == 
datetime(2023, 1, 2) + + def test_shift_missing_arg(self): + with pytest.raises(TemplateError, match="requires delta argument"): + apply_filter(datetime(2023, 1, 1), "shift") + + def test_unknown_filter(self): + with pytest.raises(TemplateError, match="Unknown filter"): + apply_filter("value", "unknown_filter") + + +class TestApplyFilters: + def test_filter_chain(self): + result = apply_filters("2023-01-01", "as_datetime|shift:-1d|strftime:%Y%m%d") + assert result == "20221231" + + def test_single_filter(self): + result = apply_filters("2023-01-01", "as_datetime") + assert result == datetime(2023, 1, 1) + + def test_empty_filter(self): + result = apply_filters("value", "") + assert result == "value" + + +class TestRenderString: + def test_simple_substitution(self): + ctx = TemplateContext({"USER": "john"}) + result = render_string("Hello ${USER}", ctx) + assert result == "Hello john" + + def test_multiple_substitutions(self): + ctx = TemplateContext({"USER": "john", "HOME": "/home/john"}) + result = render_string("${USER} lives in ${HOME}", ctx) + assert result == "john lives in /home/john" + + def test_exact_match_preserves_type_int(self): + ctx = TemplateContext({"TIMEOUT": "3600"}) + result = render_string("${TIMEOUT}", ctx) + assert result == 3600 + assert isinstance(result, int) + + def test_exact_match_preserves_type_bool_true(self): + ctx = TemplateContext({"DEBUG": "true"}) + result = render_string("${DEBUG}", ctx) + assert result is True + + def test_exact_match_preserves_type_bool_false(self): + ctx = TemplateContext({"DEBUG": "false"}) + result = render_string("${DEBUG}", ctx) + assert result is False + + def test_exact_match_preserves_type_float(self): + ctx = TemplateContext({"PI": "3.14"}) + result = render_string("${PI}", ctx) + assert result == 3.14 + assert isinstance(result, float) + + def test_embedded_always_string(self): + ctx = TemplateContext({"NUM": "42"}) + result = render_string("value_${NUM}", ctx) + assert result == "value_42" + assert 
isinstance(result, str) + + def test_default_value(self): + ctx = TemplateContext({}) + result = render_string("${MISSING:-default}", ctx) + assert result == "default" + + def test_filter_in_exact_match(self): + ctx = TemplateContext({"CYCLE": "2023-01-01"}) + result = render_string("${CYCLE|as_datetime}", ctx) + assert result == datetime(2023, 1, 1) + + def test_filter_chain_in_exact_match(self): + ctx = TemplateContext({"CYCLE": "2023-01-01"}) + result = render_string("${CYCLE|as_datetime|shift:-1d|strftime:%Y%m%d}", ctx) + assert result == "20221231" + + def test_filter_in_embedded(self): + ctx = TemplateContext({"CYCLE": "2023-01-01T12:00:00"}) + result = render_string("wind_${CYCLE|strftime:%Y%m%d}.nc", ctx) + assert result == "wind_20230101.nc" + + def test_missing_variable_strict(self): + ctx = TemplateContext({}) + with pytest.raises(TemplateError, match="not found"): + render_string("${MISSING}", ctx, strict=True) + + def test_missing_variable_non_strict(self): + ctx = TemplateContext({}) + result = render_string("${MISSING}", ctx, strict=False) + assert result == "${MISSING}" + + def test_filter_error_strict(self): + ctx = TemplateContext({"VAR": "value"}) + with pytest.raises(TemplateError, match="Filter error"): + render_string("${VAR|unknown_filter}", ctx, strict=True) + + def test_filter_error_non_strict(self): + ctx = TemplateContext({"VAR": "value"}) + result = render_string("${VAR|unknown_filter}", ctx, strict=False) + assert result == "${VAR|unknown_filter}" + + +class TestRenderTemplates: + def test_render_dict(self): + data = {"user": "${USER}", "home": "${HOME}"} + result = render_templates(data, {"USER": "john", "HOME": "/home/john"}) + assert result == {"user": "john", "home": "/home/john"} + + def test_render_nested_dict(self): + data = {"outer": {"inner": "${VAR}"}} + result = render_templates(data, {"VAR": "value"}) + assert result == {"outer": {"inner": "value"}} + + def test_render_list(self): + data = ["${VAR1}", "${VAR2}"] + result = 
render_templates(data, {"VAR1": "a", "VAR2": "b"}) + assert result == ["a", "b"] + + def test_render_mixed_structure(self): + data = { + "files": ["${DIR}/file1", "${DIR}/file2"], + "config": {"timeout": "${TIMEOUT}"}, + } + result = render_templates(data, {"DIR": "/data", "TIMEOUT": "3600"}) + assert result == { + "files": ["/data/file1", "/data/file2"], + "config": {"timeout": 3600}, + } + + def test_primitives_pass_through(self): + data = {"int": 42, "float": 3.14, "bool": True, "none": None} + result = render_templates(data, {}) + assert result == data + + def test_no_templates(self): + data = {"key": "value", "number": 42} + result = render_templates(data, {}) + assert result == data + + def test_defaults_to_environ(self): + os.environ["TEST_VAR"] = "test_value" + try: + data = {"var": "${TEST_VAR}"} + result = render_templates(data) + assert result == {"var": "test_value"} + finally: + del os.environ["TEST_VAR"] + + def test_datetime_filters_in_config(self): + data = { + "run_id": "cycle_${CYCLE|strftime:%Y%m%d}", + "start": "${CYCLE|as_datetime}", + "end": "${CYCLE|as_datetime|shift:+1d}", + } + result = render_templates(data, {"CYCLE": "2023-01-01T00:00:00"}) + assert result == { + "run_id": "cycle_20230101", + "start": datetime(2023, 1, 1), + "end": datetime(2023, 1, 2), + } + + def test_realistic_config(self): + data = { + "run_id": "cycle_${CYCLE|strftime:%Y%m%d}", + "period": { + "start": "${CYCLE}", + "end": "${CYCLE|as_datetime|shift:+1d}", + "interval": "1H", + }, + "output_dir": "${OUTPUT_ROOT:-./output}/cycle_${CYCLE|strftime:%Y%m%d}", + "input_files": { + "wind": "${DATA_ROOT}/wind/wind_${CYCLE|strftime:%Y%m%d}.nc", + "wave": "${DATA_ROOT}/wave/wave_${CYCLE|strftime:%Y%m%d}.nc", + }, + } + context = { + "CYCLE": "2023-01-01T00:00:00", + "DATA_ROOT": "/scratch/data", + } + result = render_templates(data, context) + assert result["run_id"] == "cycle_20230101" + assert result["period"]["start"] == "2023-01-01T00:00:00" + assert 
result["period"]["end"] == datetime(2023, 1, 2) + assert result["output_dir"] == "./output/cycle_20230101" + assert result["input_files"]["wind"] == "/scratch/data/wind/wind_20230101.nc" + + def test_render_dict_keys(self): + data = {"prefix_${VAR}": "value"} + result = render_templates(data, {"VAR": "key"}) + assert result == {"prefix_key": "value"} + + def test_strict_mode_raises(self): + data = {"key": "${MISSING}"} + with pytest.raises(TemplateError, match="not found"): + render_templates(data, {}, strict=True) + + def test_non_strict_keeps_unresolved(self): + data = {"key": "${MISSING}"} + result = render_templates(data, {}, strict=False) + assert result == {"key": "${MISSING}"} + + +class TestEdgeCases: + def test_empty_dict(self): + result = render_templates({}, {}) + assert result == {} + + def test_empty_list(self): + result = render_templates([], {}) + assert result == [] + + def test_empty_string(self): + result = render_templates("", {}) + assert result == "" + + def test_nested_empty_structures(self): + data = {"empty_dict": {}, "empty_list": [], "empty_str": ""} + result = render_templates(data, {}) + assert result == data + + def test_special_characters_in_values(self): + ctx = TemplateContext({"VAR": "value/with/slashes"}) + result = render_string("${VAR}", ctx) + assert result == "value/with/slashes" + + def test_default_with_special_chars(self): + ctx = TemplateContext({}) + result = render_string("${MISSING:-/default/path}", ctx) + assert result == "/default/path" + + def test_unicode_in_values(self): + ctx = TemplateContext({"VAR": "unicode_日本語"}) + result = render_string("${VAR}", ctx) + assert result == "unicode_日本語" + + def test_bool_strings_case_insensitive(self): + for val in ["true", "True", "TRUE", "yes", "Yes", "YES"]: + ctx = TemplateContext({"BOOL": val}) + assert render_string("${BOOL}", ctx) is True + + for val in ["false", "False", "FALSE", "no", "No", "NO"]: + ctx = TemplateContext({"BOOL": val}) + assert render_string("${BOOL}", 
ctx) is False + + +class TestIntegration: + def test_rompy_use_case_wind_input(self): + ctx = TemplateContext({"CYCLE": "2023-01-15T00:00:00"}) + result = render_string("/path/to/wind_input_${CYCLE|strftime:%Y%m%d}.nc", ctx) + assert result == "/path/to/wind_input_20230115.nc" + + def test_rompy_use_case_lookback(self): + ctx = TemplateContext({"CYCLE": "2023-01-15T00:00:00"}) + result = render_string("${CYCLE|as_datetime|shift:-1d|strftime:%Y-%m-%d}", ctx) + assert result == "2023-01-14" + + def test_backend_config_with_env_vars(self): + data = { + "type": "local", + "timeout": "${TIMEOUT:-3600}", + "command": "python run.py", + "env_vars": { + "OMP_NUM_THREADS": "${NUM_THREADS:-4}", + "DATA_DIR": "${DATA_ROOT}/inputs", + }, + } + context = {"DATA_ROOT": "/scratch"} + result = render_templates(data, context) + assert result["timeout"] == 3600 + assert result["env_vars"]["OMP_NUM_THREADS"] == 4 + assert result["env_vars"]["DATA_DIR"] == "/scratch/inputs"