diff --git a/promptlens/models/config.py b/promptlens/models/config.py index 7c85798..f289389 100644 --- a/promptlens/models/config.py +++ b/promptlens/models/config.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator class ProviderConfig(BaseModel): @@ -48,6 +48,22 @@ class ModelConfig(BaseModel): max_tokens: int = 1024 additional_params: Dict[str, Any] = Field(default_factory=dict) + @field_validator("temperature") + @classmethod + def validate_temperature(cls, value: float) -> float: + """Ensure temperature stays within common model bounds.""" + if not 0 <= value <= 2: + raise ValueError("temperature must be between 0 and 2") + return value + + @field_validator("max_tokens") + @classmethod + def validate_max_tokens(cls, value: int) -> int: + """Ensure max_tokens is positive.""" + if value <= 0: + raise ValueError("max_tokens must be greater than 0") + return value + class JudgeConfig(BaseModel): """Configuration for the judge. @@ -82,6 +98,38 @@ class ExecutionConfig(BaseModel): retry_delay_seconds: float = 1.0 timeout_seconds: int = 60 + @field_validator("parallel_requests") + @classmethod + def validate_parallel_requests(cls, value: int) -> int: + """Ensure at least one parallel request worker.""" + if value <= 0: + raise ValueError("parallel_requests must be greater than 0") + return value + + @field_validator("retry_attempts") + @classmethod + def validate_retry_attempts(cls, value: int) -> int: + """Ensure retry attempts are non-negative.""" + if value < 0: + raise ValueError("retry_attempts cannot be negative") + return value + + @field_validator("retry_delay_seconds") + @classmethod + def validate_retry_delay(cls, value: float) -> float: + """Ensure retry delay is non-negative.""" + if value < 0: + raise ValueError("retry_delay_seconds cannot be negative") + return value + + @field_validator("timeout_seconds") + @classmethod + def validate_timeout_seconds(cls, value: int) -> int: + """Ensure timeout is positive.""" + if value <= 0: + raise ValueError("timeout_seconds must be greater than 0") + return value + class OutputConfig(BaseModel): """Configuration for output settings. @@ -96,6 +144,25 @@ class OutputConfig(BaseModel): formats: List[str] = Field(default_factory=lambda: ["html", "json"]) run_name: Optional[str] = None + @field_validator("formats") + @classmethod + def validate_formats(cls, formats: List[str]) -> List[str]: + """Normalize and validate export formats.""" + if not formats: + raise ValueError("output.formats must include at least one format") + + allowed_formats = {"json", "csv", "md", "html"} + normalized: List[str] = [] + for format_name in formats: + normalized_name = format_name.lower() + if normalized_name not in allowed_formats: + allowed = ", ".join(sorted(allowed_formats)) + raise ValueError(f"unsupported output format '{format_name}'. Allowed: {allowed}") + if normalized_name not in normalized: + normalized.append(normalized_name) + + return normalized + class RunConfig(BaseModel): """Complete run configuration. @@ -114,6 +181,15 @@ class RunConfig(BaseModel): execution: ExecutionConfig = Field(default_factory=ExecutionConfig) output: OutputConfig = Field(default_factory=OutputConfig) + @field_validator("models") + @classmethod + def validate_models(cls, models: List[ModelConfig]) -> List[ModelConfig]: + """Require at least one model configuration.""" + if not models: + raise ValueError("models must include at least one model configuration") + return models + + class Config: """Pydantic config.""" json_schema_extra = { diff --git a/tests/test_config_validation.py b/tests/test_config_validation.py new file mode 100644 index 0000000..89dbfa2 --- /dev/null +++ b/tests/test_config_validation.py @@ -0,0 +1,42 @@ +"""Tests for configuration hardening and validation.""" + +import pytest +from pydantic import ValidationError + +from promptlens.models.config import ModelConfig, OutputConfig, RunConfig + + +def test_model_config_rejects_invalid_temperature() -> None: + with pytest.raises(ValidationError, match="temperature must be between 0 and 2"): + ModelConfig( + name="Bad Temp", + provider="openai", + model="gpt-4o-mini", + temperature=2.5, + ) + + +def test_model_config_rejects_non_positive_max_tokens() -> None: + with pytest.raises(ValidationError, match="max_tokens must be greater than 0"): + ModelConfig( + name="Bad Tokens", + provider="openai", + model="gpt-4o-mini", + max_tokens=0, + ) + + +def test_output_formats_are_normalized_and_deduplicated() -> None: + output = OutputConfig(formats=["JSON", "html", "json"]) + + assert output.formats == ["json", "html"] + + +def test_output_formats_reject_unknown_values() -> None: + with pytest.raises(ValidationError, match="unsupported output format"): + OutputConfig(formats=["pdf"]) + + +def test_run_config_requires_at_least_one_model() -> None: + with pytest.raises(ValidationError, match="models must include at least one model configuration"): + RunConfig(golden_set="tests.yaml", models=[])