Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions promptlens/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from typing import Optional

import click
import yaml
from dotenv import load_dotenv
from rich.console import Console
from rich.logging import RichHandler

from promptlens import __version__
from promptlens.config_loading import apply_run_overrides, load_config_data
from promptlens.exporters.csv_exporter import CSVExporter
from promptlens.exporters.html_exporter import HTMLExporter
from promptlens.exporters.json_exporter import JSONExporter
Expand Down Expand Up @@ -90,14 +90,10 @@ def run(
try:
# Load config
console.print(f"\n[cyan]Loading configuration from {config}...[/cyan]")
with open(config, "r") as f:
config_data = yaml.safe_load(f)
config_data = load_config_data(config)

# Override with CLI options
if golden_set:
config_data["golden_set"] = golden_set
if output_dir:
config_data["output"]["directory"] = output_dir
config_data = apply_run_overrides(config_data, golden_set, output_dir)

# Parse config
try:
Expand Down
48 changes: 48 additions & 0 deletions promptlens/config_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Config loading helpers for CLI commands."""

from typing import Any, Dict, Optional

import yaml


def load_config_data(config_path: str) -> Dict[str, Any]:
"""Load and validate raw config data from YAML."""
with open(config_path, "r", encoding="utf-8") as f:
config_data = yaml.safe_load(f)

if config_data is None:
raise ValueError(
"Configuration file is empty. "
"Provide a YAML mapping with at least 'golden_set' and 'models'."
)

if not isinstance(config_data, dict):
raise ValueError(
"Configuration root must be a YAML mapping/object "
f"(got {type(config_data).__name__})."
)

return config_data


def apply_run_overrides(
config_data: Dict[str, Any],
golden_set: Optional[str],
output_dir: Optional[str],
) -> Dict[str, Any]:
"""Apply CLI overrides to loaded config data."""
merged = dict(config_data)

if golden_set:
merged["golden_set"] = golden_set

if output_dir:
output_config = merged.get("output")
if output_config is None:
merged["output"] = {"directory": output_dir}
elif isinstance(output_config, dict):
output_config["directory"] = output_dir
else:
raise ValueError("'output' section must be a mapping/object if provided.")

return merged
45 changes: 45 additions & 0 deletions tests/test_cli_config_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Tests for CLI config loading and override hardening."""

from pathlib import Path

import pytest

from promptlens.config_loading import apply_run_overrides, load_config_data


def test_load_config_data_rejects_empty_yaml(tmp_path: Path) -> None:
config_path = tmp_path / "empty.yaml"
config_path.write_text("", encoding="utf-8")

with pytest.raises(ValueError, match="Configuration file is empty"):
load_config_data(str(config_path))


def test_load_config_data_rejects_non_mapping_root(tmp_path: Path) -> None:
config_path = tmp_path / "list.yaml"
config_path.write_text("- not\n- a\n- mapping\n", encoding="utf-8")

with pytest.raises(ValueError, match="Configuration root must be a YAML mapping"):
load_config_data(str(config_path))


def test_apply_run_overrides_creates_output_mapping_when_missing() -> None:
config = {
"golden_set": "tests.yaml",
"models": [{"name": "m", "provider": "openai", "model": "gpt-4"}],
}

merged = apply_run_overrides(config, golden_set=None, output_dir="./out")

assert merged["output"]["directory"] == "./out"


def test_apply_run_overrides_rejects_invalid_output_section_type() -> None:
config = {
"golden_set": "tests.yaml",
"models": [{"name": "m", "provider": "openai", "model": "gpt-4"}],
"output": "bad-type",
}

with pytest.raises(ValueError, match="'output' section must be a mapping"):
apply_run_overrides(config, golden_set=None, output_dir="./out")