diff --git a/README.md b/README.md
index 68d076e..8ca347d 100644
--- a/README.md
+++ b/README.md
@@ -248,10 +248,9 @@ models:
   - name: "Local Llama"
     provider: http
     model: llama3.1:8b
+    endpoint: "http://localhost:11434/api/generate"
     temperature: 0.7
     max_tokens: 1024
-    additional_params:
-      endpoint: "http://localhost:11434/api/generate"
 ```
 
 **Setup:**
@@ -270,10 +269,12 @@ models:
   - name: "Display Name"          # Human-readable name
     provider: anthropic           # anthropic, openai, google, http
     model: model-identifier       # Model ID
+    endpoint: "http://..."        # Optional, for HTTP provider
     temperature: 0.7              # 0.0-1.0
     max_tokens: 1024              # Maximum output tokens
     additional_params:            # Provider-specific params
-      endpoint: "http://..."      # For HTTP provider
+      # endpoint under additional_params is still accepted for backwards compatibility
+      custom_option: "value"
 ```
 
 ### Judge
diff --git a/examples/configs/local_model.yaml b/examples/configs/local_model.yaml
index e642597..448d722 100644
--- a/examples/configs/local_model.yaml
+++ b/examples/configs/local_model.yaml
@@ -13,10 +13,9 @@ models:
   - name: "Local Llama 3.1 8B"
     provider: http
     model: llama3.1:8b
+    endpoint: "http://localhost:11434/api/generate"
     temperature: 0.7
     max_tokens: 1024
-    additional_params:
-      endpoint: "http://localhost:11434/api/generate"
 
 judge:
   provider: anthropic
diff --git a/promptlens/models/config.py b/promptlens/models/config.py
index 7c85798..e14bc4c 100644
--- a/promptlens/models/config.py
+++ b/promptlens/models/config.py
@@ -36,6 +36,8 @@ class ModelConfig(BaseModel):
         name: Display name for the model
         provider: Provider name
         model: Model identifier
+        endpoint: Optional endpoint URL (for HTTP/local providers)
+        timeout: Optional request timeout in seconds
         temperature: Sampling temperature
         max_tokens: Maximum tokens to generate
         additional_params: Provider-specific parameters
@@ -44,6 +46,8 @@
     name: str
     provider: str
     model: str
+    endpoint: Optional[str] = None
+    timeout: Optional[int] = None
     temperature: float = 0.7
     max_tokens: int = 1024
     additional_params: Dict[str, Any] = Field(default_factory=dict)
diff --git a/promptlens/providers/factory.py b/promptlens/providers/factory.py
index 4f78907..8eed5bf 100644
--- a/promptlens/providers/factory.py
+++ b/promptlens/providers/factory.py
@@ -41,13 +41,23 @@ def get_provider(model_config: ModelConfig) -> BaseProvider:
             f"Available providers: {available}"
         )
 
+    # Copy params so we can normalize legacy keys without mutating input
+    additional_params = dict(model_config.additional_params)
+
+    # Backwards compatibility: allow endpoint under additional_params for HTTP provider
+    endpoint = model_config.endpoint
+    if provider_name == "http" and endpoint is None:
+        endpoint = additional_params.pop("endpoint", None)
+
     # Convert ModelConfig to ProviderConfig
     provider_config = ProviderConfig(
         name=provider_name,
         model=model_config.model,
+        endpoint=endpoint,
+        timeout=model_config.timeout or 60,
         temperature=model_config.temperature,
         max_tokens=model_config.max_tokens,
-        additional_params=model_config.additional_params,
+        additional_params=additional_params,
     )
 
     # Get provider class and instantiate
diff --git a/tests/test_provider_factory_hardening.py b/tests/test_provider_factory_hardening.py
new file mode 100644
index 0000000..61c9baa
--- /dev/null
+++ b/tests/test_provider_factory_hardening.py
@@ -0,0 +1,36 @@
+from promptlens.models.config import ModelConfig
+from promptlens.providers.factory import get_provider
+from promptlens.providers.http import HTTPProvider
+
+
+def test_http_provider_supports_legacy_endpoint_in_additional_params() -> None:
+    provider = get_provider(
+        ModelConfig(
+            name="Local model",
+            provider="http",
+            model="llama3.1:8b",
+            additional_params={
+                "endpoint": "http://localhost:11434/api/generate",
+                "temperature": 0.2,
+            },
+        )
+    )
+
+    assert isinstance(provider, HTTPProvider)
+    assert provider.endpoint == "http://localhost:11434/api/generate"
+    assert "endpoint" not in provider.config.additional_params
+    assert provider.config.additional_params["temperature"] == 0.2
+
+
+def test_http_provider_accepts_top_level_endpoint() -> None:
+    provider = get_provider(
+        ModelConfig(
+            name="Local model",
+            provider="http",
+            model="llama3.1:8b",
+            endpoint="http://localhost:11434/api/generate",
+        )
+    )
+
+    assert isinstance(provider, HTTPProvider)
+    assert provider.endpoint == "http://localhost:11434/api/generate"
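
A minimal usage sketch of the config surface this patch introduces, for reviewers (not part of the diff). It assumes an Ollama-compatible server is listening at the default local endpoint; `timeout=30` is an arbitrary illustrative value.

```python
# Sketch only: exercises the new top-level endpoint/timeout fields on ModelConfig.
# Assumes a local Ollama-compatible server is serving at the endpoint below.
from promptlens.models.config import ModelConfig
from promptlens.providers.factory import get_provider

config = ModelConfig(
    name="Local Llama",
    provider="http",
    model="llama3.1:8b",
    endpoint="http://localhost:11434/api/generate",
    timeout=30,  # optional; the factory falls back to 60 seconds when unset
)

provider = get_provider(config)  # resolves to HTTPProvider
```

Configs that still nest `endpoint` under `additional_params` keep working for `provider: http`: the factory pops the key into the top-level field, as the new tests assert.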