README.md (7 changes: 4 additions & 3 deletions)

@@ -248,10 +248,9 @@ models:
 - name: "Local Llama"
   provider: http
   model: llama3.1:8b
+  endpoint: "http://localhost:11434/api/generate"
   temperature: 0.7
   max_tokens: 1024
-  additional_params:
-    endpoint: "http://localhost:11434/api/generate"
 ```

 **Setup:**

@@ -270,10 +269,12 @@ models:
 - name: "Display Name"       # Human-readable name
   provider: anthropic        # anthropic, openai, google, http
   model: model-identifier    # Model ID
+  endpoint: "http://..."     # Optional, for HTTP provider
   temperature: 0.7           # 0.0-1.0
   max_tokens: 1024           # Maximum output tokens
   additional_params:         # Provider-specific params
     endpoint: "http://..."   # For HTTP provider
+    # endpoint under additional_params is still accepted for backwards compatibility
     custom_option: "value"
 ```

 ### Judge
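For illustration (not part of the PR), a minimal Python sketch showing that a new-style entry validates directly into `ModelConfig`, with `endpoint` as a first-class field:

```python
# Sketch only: the new-style YAML entry from the README, as a Python dict.
from promptlens.models.config import ModelConfig

entry = {
    "name": "Local Llama",
    "provider": "http",
    "model": "llama3.1:8b",
    "endpoint": "http://localhost:11434/api/generate",
    "temperature": 0.7,
    "max_tokens": 1024,
}

config = ModelConfig(**entry)
assert config.endpoint == "http://localhost:11434/api/generate"
assert config.additional_params == {}  # endpoint no longer needs to live here
```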
examples/configs/local_model.yaml (3 changes: 1 addition & 2 deletions)

@@ -13,10 +13,9 @@ models:
 - name: "Local Llama 3.1 8B"
   provider: http
   model: llama3.1:8b
+  endpoint: "http://localhost:11434/api/generate"
   temperature: 0.7
   max_tokens: 1024
-  additional_params:
-    endpoint: "http://localhost:11434/api/generate"

 judge:
   provider: anthropic
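A hedged sketch of loading this example config; it assumes PyYAML is installed and that the file keeps the top-level `models:` list shown above (promptlens may well have its own loader):

```python
# Sketch only: read examples/configs/local_model.yaml into ModelConfig objects.
# Assumes PyYAML (pip install pyyaml); not necessarily how promptlens loads configs.
import yaml

from promptlens.models.config import ModelConfig

with open("examples/configs/local_model.yaml") as f:
    raw = yaml.safe_load(f)

models = [ModelConfig(**entry) for entry in raw["models"]]
assert models[0].endpoint == "http://localhost:11434/api/generate"
```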
promptlens/models/config.py (4 changes: 4 additions & 0 deletions)

@@ -36,6 +36,8 @@ class ModelConfig(BaseModel):
         name: Display name for the model
         provider: Provider name
         model: Model identifier
+        endpoint: Optional endpoint URL (for HTTP/local providers)
+        timeout: Optional request timeout in seconds
         temperature: Sampling temperature
         max_tokens: Maximum tokens to generate
         additional_params: Provider-specific parameters

@@ -44,6 +46,8 @@
     name: str
     provider: str
     model: str
+    endpoint: Optional[str] = None
+    timeout: Optional[int] = None
     temperature: float = 0.7
     max_tokens: int = 1024
     additional_params: Dict[str, Any] = Field(default_factory=dict)
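Since both new fields default to `None`, configs written before this change keep validating; a quick sketch of the expected behavior (values inferred from the defaults in the diff above):

```python
# Sketch only: the new optional fields are backwards compatible by default.
from promptlens.models.config import ModelConfig

cfg = ModelConfig(name="Display Name", provider="anthropic", model="model-identifier")
assert cfg.endpoint is None and cfg.timeout is None       # new fields, optional
assert cfg.temperature == 0.7 and cfg.max_tokens == 1024  # unchanged defaults
assert cfg.additional_params == {}
```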
promptlens/providers/factory.py (12 changes: 11 additions & 1 deletion)

@@ -41,13 +41,23 @@ def get_provider(model_config: ModelConfig) -> BaseProvider:
f"Available providers: {available}"
)

# Copy params so we can normalize legacy keys without mutating input
additional_params = dict(model_config.additional_params)

# Backwards compatibility: allow endpoint under additional_params for HTTP provider
endpoint = model_config.endpoint
if provider_name == "http" and endpoint is None:
endpoint = additional_params.pop("endpoint", None)

# Convert ModelConfig to ProviderConfig
provider_config = ProviderConfig(
name=provider_name,
model=model_config.model,
endpoint=endpoint,
timeout=model_config.timeout or 60,
temperature=model_config.temperature,
max_tokens=model_config.max_tokens,
additional_params=model_config.additional_params,
additional_params=additional_params,
)

# Get provider class and instantiate
Expand Down
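The normalization rule is small enough to restate in isolation. A self-contained sketch (`resolve_endpoint` is a hypothetical name for illustration, not part of the PR):

```python
# Sketch only: the legacy-endpoint promotion performed inside get_provider.
from typing import Any, Dict, Optional, Tuple


def resolve_endpoint(
    provider_name: str,
    endpoint: Optional[str],
    additional_params: Dict[str, Any],
) -> Tuple[Optional[str], Dict[str, Any]]:
    params = dict(additional_params)  # copy: never mutate the caller's dict
    if provider_name == "http" and endpoint is None:
        endpoint = params.pop("endpoint", None)  # promote the legacy key
    return endpoint, params


# Legacy shape: endpoint nested under additional_params is promoted and removed.
ep, params = resolve_endpoint("http", None, {"endpoint": "http://localhost:11434/api/generate"})
assert ep == "http://localhost:11434/api/generate" and "endpoint" not in params

# A top-level endpoint wins; the nested key is then left as-is.
ep, params = resolve_endpoint("http", "http://primary:8080/api", {"endpoint": "http://legacy"})
assert ep == "http://primary:8080/api" and params["endpoint"] == "http://legacy"
```

Popping the key, rather than just reading it, keeps the endpoint from being forwarded twice (once as a real field and once inside `additional_params`), which the new test below asserts.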
tests/test_provider_factory_hardening.py (36 changes: 36 additions & 0 deletions)

@@ -0,0 +1,36 @@
+from promptlens.models.config import ModelConfig
+from promptlens.providers.factory import get_provider
+from promptlens.providers.http import HTTPProvider
+
+
+def test_http_provider_supports_legacy_endpoint_in_additional_params() -> None:
+    provider = get_provider(
+        ModelConfig(
+            name="Local model",
+            provider="http",
+            model="llama3.1:8b",
+            additional_params={
+                "endpoint": "http://localhost:11434/api/generate",
+                "temperature": 0.2,
+            },
+        )
+    )
+
+    assert isinstance(provider, HTTPProvider)
+    assert provider.endpoint == "http://localhost:11434/api/generate"
+    assert "endpoint" not in provider.config.additional_params
+    assert provider.config.additional_params["temperature"] == 0.2
+
+
+def test_http_provider_accepts_top_level_endpoint() -> None:
+    provider = get_provider(
+        ModelConfig(
+            name="Local model",
+            provider="http",
+            model="llama3.1:8b",
+            endpoint="http://localhost:11434/api/generate",
+        )
+    )
+
+    assert isinstance(provider, HTTPProvider)
+    assert provider.endpoint == "http://localhost:11434/api/generate"
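Assuming the project's test runner is pytest (consistent with the bare `assert` style above), the new tests can be run with `pytest tests/test_provider_factory_hardening.py`.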