From ebd6a9574564a96205b10469307e1e666a9e3e4a Mon Sep 17 00:00:00 2001
From: Mouse
Date: Thu, 9 Apr 2026 01:33:10 -0700
Subject: [PATCH] fix: harden CLI config loading and override handling

---
Review note: apply_run_overrides now builds a new 'output' mapping instead
of updating the caller's nested dict through the shallow copy, and a
regression test covers this. Blob hashes for the two new files are omitted
because their content changed after the original commit.

 promptlens/cli.py                | 10 ++----
 promptlens/config_loading.py     | 62 ++++++++++++++++++++++++++++++++
 tests/test_cli_config_loading.py | 58 ++++++++++++++++++++++++++++++
 3 files changed, 123 insertions(+), 7 deletions(-)
 create mode 100644 promptlens/config_loading.py
 create mode 100644 tests/test_cli_config_loading.py

diff --git a/promptlens/cli.py b/promptlens/cli.py
index 2d34ce3..d41de0f 100644
--- a/promptlens/cli.py
+++ b/promptlens/cli.py
@@ -7,12 +7,12 @@
 from typing import Optional
 
 import click
-import yaml
 from dotenv import load_dotenv
 from rich.console import Console
 from rich.logging import RichHandler
 
 from promptlens import __version__
+from promptlens.config_loading import apply_run_overrides, load_config_data
 from promptlens.exporters.csv_exporter import CSVExporter
 from promptlens.exporters.html_exporter import HTMLExporter
 from promptlens.exporters.json_exporter import JSONExporter
@@ -90,14 +90,10 @@ def run(
     try:
         # Load config
         console.print(f"\n[cyan]Loading configuration from {config}...[/cyan]")
-        with open(config, "r") as f:
-            config_data = yaml.safe_load(f)
+        config_data = load_config_data(config)
 
         # Override with CLI options
-        if golden_set:
-            config_data["golden_set"] = golden_set
-        if output_dir:
-            config_data["output"]["directory"] = output_dir
+        config_data = apply_run_overrides(config_data, golden_set, output_dir)
 
         # Parse config
         try:
diff --git a/promptlens/config_loading.py b/promptlens/config_loading.py
new file mode 100644
--- /dev/null
+++ b/promptlens/config_loading.py
@@ -0,0 +1,62 @@
+"""Config loading helpers for CLI commands."""
+
+from typing import Any, Dict, Optional
+
+import yaml
+
+
+def load_config_data(config_path: str) -> Dict[str, Any]:
+    """Load raw config data from a YAML file and validate its root.
+
+    Raises:
+        ValueError: If the file is empty or its root is not a mapping.
+    """
+    with open(config_path, "r", encoding="utf-8") as f:
+        config_data = yaml.safe_load(f)
+
+    if config_data is None:
+        raise ValueError(
+            "Configuration file is empty. "
+            "Provide a YAML mapping with at least 'golden_set' and 'models'."
+        )
+
+    if not isinstance(config_data, dict):
+        raise ValueError(
+            "Configuration root must be a YAML mapping/object "
+            f"(got {type(config_data).__name__})."
+        )
+
+    return config_data
+
+
+def apply_run_overrides(
+    config_data: Dict[str, Any],
+    golden_set: Optional[str],
+    output_dir: Optional[str],
+) -> Dict[str, Any]:
+    """Return a copy of the config with CLI overrides applied.
+
+    The caller's mapping, including its nested 'output' section, is left
+    unmodified. Overrides that are None or empty are ignored.
+
+    Raises:
+        ValueError: If output_dir is given and the existing 'output'
+            section is not a mapping.
+    """
+    merged = dict(config_data)
+
+    if golden_set:
+        merged["golden_set"] = golden_set
+
+    if output_dir:
+        output_config = merged.get("output")
+        if output_config is None:
+            merged["output"] = {"directory": output_dir}
+        elif isinstance(output_config, dict):
+            # dict(config_data) is only a shallow copy, so build a new
+            # 'output' mapping instead of updating the caller's in place.
+            merged["output"] = {**output_config, "directory": output_dir}
+        else:
+            raise ValueError("'output' section must be a mapping/object if provided.")
+
+    return merged
diff --git a/tests/test_cli_config_loading.py b/tests/test_cli_config_loading.py
new file mode 100644
--- /dev/null
+++ b/tests/test_cli_config_loading.py
@@ -0,0 +1,58 @@
+"""Tests for CLI config loading and override hardening."""
+
+from pathlib import Path
+
+import pytest
+
+from promptlens.config_loading import apply_run_overrides, load_config_data
+
+
+def test_load_config_data_rejects_empty_yaml(tmp_path: Path) -> None:
+    config_path = tmp_path / "empty.yaml"
+    config_path.write_text("", encoding="utf-8")
+
+    with pytest.raises(ValueError, match="Configuration file is empty"):
+        load_config_data(str(config_path))
+
+
+def test_load_config_data_rejects_non_mapping_root(tmp_path: Path) -> None:
+    config_path = tmp_path / "list.yaml"
+    config_path.write_text("- not\n- a\n- mapping\n", encoding="utf-8")
+
+    with pytest.raises(ValueError, match="Configuration root must be a YAML mapping"):
+        load_config_data(str(config_path))
+
+
+def test_apply_run_overrides_creates_output_mapping_when_missing() -> None:
+    config = {
+        "golden_set": "tests.yaml",
+        "models": [{"name": "m", "provider": "openai", "model": "gpt-4"}],
+    }
+
+    merged = apply_run_overrides(config, golden_set=None, output_dir="./out")
+
+    assert merged["output"]["directory"] == "./out"
+
+
+def test_apply_run_overrides_rejects_invalid_output_section_type() -> None:
+    config = {
+        "golden_set": "tests.yaml",
+        "models": [{"name": "m", "provider": "openai", "model": "gpt-4"}],
+        "output": "bad-type",
+    }
+
+    with pytest.raises(ValueError, match="'output' section must be a mapping"):
+        apply_run_overrides(config, golden_set=None, output_dir="./out")
+
+
+def test_apply_run_overrides_does_not_mutate_input_output_section() -> None:
+    config = {
+        "golden_set": "tests.yaml",
+        "models": [{"name": "m", "provider": "openai", "model": "gpt-4"}],
+        "output": {"directory": "./original", "format": "json"},
+    }
+
+    merged = apply_run_overrides(config, golden_set=None, output_dir="./out")
+
+    assert merged["output"] == {"directory": "./out", "format": "json"}
+    assert config["output"] == {"directory": "./original", "format": "json"}