diff --git a/promptlens/cli.py b/promptlens/cli.py index 235963a..8f9087a 100644 --- a/promptlens/cli.py +++ b/promptlens/cli.py @@ -6,6 +6,7 @@ import sys from pathlib import Path from typing import Optional +from collections.abc import Mapping import click import yaml @@ -35,6 +36,19 @@ def _remove_path_if_exists(path: Path) -> None: shutil.rmtree(path) +def _load_config_data(config_path: str) -> dict: + """Load YAML config and enforce a top-level mapping object.""" + with open(config_path, "r") as f: + config_data = yaml.safe_load(f) + + if not isinstance(config_data, Mapping): + raise ValueError( + "Configuration file must contain a top-level mapping/object" + ) + + return dict(config_data) + + def setup_logging(level: str = "INFO") -> None: """Set up logging configuration. @@ -99,8 +113,7 @@ def run( try: # Load config console.print(f"\n[cyan]Loading configuration from {config}...[/cyan]") - with open(config, "r") as f: - config_data = yaml.safe_load(f) + config_data = _load_config_data(config) # Override with CLI options if golden_set: diff --git a/tests/test_cli_hardening.py b/tests/test_cli_hardening.py index cab3896..b93f640 100644 --- a/tests/test_cli_hardening.py +++ b/tests/test_cli_hardening.py @@ -1,6 +1,8 @@ from pathlib import Path -from promptlens.cli import _remove_path_if_exists +import pytest + +from promptlens.cli import _load_config_data, _remove_path_if_exists def test_remove_path_if_exists_removes_file(tmp_path: Path) -> None: @@ -33,3 +35,23 @@ def test_remove_path_if_exists_removes_symlink(tmp_path: Path) -> None: assert not link.exists() assert target_dir.exists() + + +def test_load_config_data_rejects_empty_yaml(tmp_path: Path) -> None: + config = tmp_path / "config.yaml" + config.write_text("", encoding="utf-8") + + with pytest.raises( + ValueError, match="top-level mapping/object" + ): + _load_config_data(str(config)) + + +def test_load_config_data_rejects_non_mapping_yaml(tmp_path: Path) -> None: + config = tmp_path / "config.yaml" + config.write_text("- item1\n- item2\n", encoding="utf-8") + + with pytest.raises( + ValueError, match="top-level mapping/object" + ): + _load_config_data(str(config))