Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions promptlens/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
from pathlib import Path
from typing import Optional
from collections.abc import Mapping

import click
import yaml
Expand Down Expand Up @@ -35,6 +36,19 @@ def _remove_path_if_exists(path: Path) -> None:
shutil.rmtree(path)


def _load_config_data(config_path: str) -> dict:
"""Load YAML config and enforce a top-level mapping object."""
with open(config_path, "r") as f:
config_data = yaml.safe_load(f)

if not isinstance(config_data, Mapping):
raise ValueError(
"Configuration file must contain a top-level mapping/object"
)

return dict(config_data)


def setup_logging(level: str = "INFO") -> None:
"""Set up logging configuration.

Expand Down Expand Up @@ -99,8 +113,7 @@ def run(
try:
# Load config
console.print(f"\n[cyan]Loading configuration from {config}...[/cyan]")
with open(config, "r") as f:
config_data = yaml.safe_load(f)
config_data = _load_config_data(config)

# Override with CLI options
if golden_set:
Expand Down
24 changes: 23 additions & 1 deletion tests/test_cli_hardening.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pathlib import Path

from promptlens.cli import _remove_path_if_exists
import pytest

from promptlens.cli import _load_config_data, _remove_path_if_exists


def test_remove_path_if_exists_removes_file(tmp_path: Path) -> None:
Expand Down Expand Up @@ -33,3 +35,23 @@ def test_remove_path_if_exists_removes_symlink(tmp_path: Path) -> None:

assert not link.exists()
assert target_dir.exists()


def test_load_config_data_rejects_empty_yaml(tmp_path: Path) -> None:
config = tmp_path / "config.yaml"
config.write_text("", encoding="utf-8")

with pytest.raises(
ValueError, match="top-level mapping/object"
):
_load_config_data(str(config))


def test_load_config_data_rejects_non_mapping_yaml(tmp_path: Path) -> None:
config = tmp_path / "config.yaml"
config.write_text("- item1\n- item2\n", encoding="utf-8")

with pytest.raises(
ValueError, match="top-level mapping/object"
):
_load_config_data(str(config))