78 changes: 77 additions & 1 deletion promptlens/models/config.py
@@ -2,7 +2,7 @@

from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator


class ProviderConfig(BaseModel):
@@ -48,6 +48,22 @@ class ModelConfig(BaseModel):
    max_tokens: int = 1024
    additional_params: Dict[str, Any] = Field(default_factory=dict)

    @field_validator("temperature")
    @classmethod
    def validate_temperature(cls, value: float) -> float:
        """Ensure temperature stays within common model bounds."""
        if not 0 <= value <= 2:
            raise ValueError("temperature must be between 0 and 2")
        return value

    @field_validator("max_tokens")
    @classmethod
    def validate_max_tokens(cls, value: int) -> int:
        """Ensure max_tokens is positive."""
        if value <= 0:
            raise ValueError("max_tokens must be greater than 0")
        return value


class JudgeConfig(BaseModel):
"""Configuration for the judge.
@@ -82,6 +98,38 @@ class ExecutionConfig(BaseModel):
    retry_delay_seconds: float = 1.0
    timeout_seconds: int = 60

    @field_validator("parallel_requests")
    @classmethod
    def validate_parallel_requests(cls, value: int) -> int:
        """Ensure at least one parallel request worker."""
        if value <= 0:
            raise ValueError("parallel_requests must be greater than 0")
        return value

    @field_validator("retry_attempts")
    @classmethod
    def validate_retry_attempts(cls, value: int) -> int:
        """Ensure retry attempts are non-negative."""
        if value < 0:
            raise ValueError("retry_attempts cannot be negative")
        return value

    @field_validator("retry_delay_seconds")
    @classmethod
    def validate_retry_delay(cls, value: float) -> float:
        """Ensure retry delay is non-negative."""
        if value < 0:
            raise ValueError("retry_delay_seconds cannot be negative")
        return value

    @field_validator("timeout_seconds")
    @classmethod
    def validate_timeout_seconds(cls, value: int) -> int:
        """Ensure timeout is positive."""
        if value <= 0:
            raise ValueError("timeout_seconds must be greater than 0")
        return value

class OutputConfig(BaseModel):
"""Configuration for output settings.
@@ -96,6 +144,25 @@ class OutputConfig(BaseModel):
    formats: List[str] = Field(default_factory=lambda: ["html", "json"])
    run_name: Optional[str] = None

    @field_validator("formats")
    @classmethod
    def validate_formats(cls, formats: List[str]) -> List[str]:
        """Normalize and validate export formats."""
        if not formats:
            raise ValueError("output.formats must include at least one format")

        allowed_formats = {"json", "csv", "md", "html"}
        normalized: List[str] = []
        for format_name in formats:
            normalized_name = format_name.lower()
            if normalized_name not in allowed_formats:
                allowed = ", ".join(sorted(allowed_formats))
                raise ValueError(f"unsupported output format '{format_name}'. Allowed: {allowed}")
            if normalized_name not in normalized:
                normalized.append(normalized_name)

        return normalized


class RunConfig(BaseModel):
"""Complete run configuration.
@@ -114,6 +181,15 @@ class RunConfig(BaseModel):
    execution: ExecutionConfig = Field(default_factory=ExecutionConfig)
    output: OutputConfig = Field(default_factory=OutputConfig)

    @field_validator("models")
    @classmethod
    def validate_models(cls, models: List[ModelConfig]) -> List[ModelConfig]:
        """Require at least one model configuration."""
        if not models:
            raise ValueError("models must include at least one model configuration")
        return models


    class Config:
        """Pydantic config."""
        json_schema_extra = {
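Taken together, these validators make a run configuration fail fast at construction time rather than midway through a run. Below is a minimal sketch of a configuration that passes them, assuming (as the new tests suggest) that golden_set and models are the only required RunConfig fields; the provider and model strings are illustrative only:

    from promptlens.models.config import ModelConfig, RunConfig

    # Illustrative values; the validators above require a non-empty models list,
    # 0 <= temperature <= 2, and max_tokens > 0.
    config = RunConfig(
        golden_set="tests.yaml",
        models=[
            ModelConfig(
                name="baseline",
                provider="openai",
                model="gpt-4o-mini",
                temperature=0.2,
                max_tokens=512,
            )
        ],
    )

The judge, execution, and output sections fall back to their Field(default_factory=...) defaults when omitted.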
42 changes: 42 additions & 0 deletions tests/test_config_validation.py
@@ -0,0 +1,42 @@
"""Tests for configuration hardening and validation."""

import pytest
from pydantic import ValidationError

from promptlens.models.config import ModelConfig, OutputConfig, RunConfig


def test_model_config_rejects_invalid_temperature() -> None:
with pytest.raises(ValidationError, match="temperature must be between 0 and 2"):
ModelConfig(
name="Bad Temp",
provider="openai",
model="gpt-4o-mini",
temperature=2.5,
)


def test_model_config_rejects_non_positive_max_tokens() -> None:
with pytest.raises(ValidationError, match="max_tokens must be greater than 0"):
ModelConfig(
name="Bad Tokens",
provider="openai",
model="gpt-4o-mini",
max_tokens=0,
)


def test_output_formats_are_normalized_and_deduplicated() -> None:
output = OutputConfig(formats=["JSON", "html", "json"])

assert output.formats == ["json", "html"]


def test_output_formats_reject_unknown_values() -> None:
with pytest.raises(ValidationError, match="unsupported output format"):
OutputConfig(formats=["pdf"])


def test_run_config_requires_at_least_one_model() -> None:
with pytest.raises(ValidationError, match="models must include at least one model configuration"):
RunConfig(golden_set="tests.yaml", models=[])
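
Assuming pytest is the project's test runner (the new file imports it directly), the added cases can be run in isolation with:

    pytest tests/test_config_validation.py -q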