Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion promptlens/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
"""LLM provider implementations."""

from promptlens.providers.base import BaseProvider
from promptlens.providers.factory import get_provider


def get_provider(*args, **kwargs):
"""Lazily import and dispatch provider factory.

Keeps package import lightweight for environments/tests that don't have every
optional provider dependency installed.
"""
from promptlens.providers.factory import get_provider as _get_provider

return _get_provider(*args, **kwargs)


__all__ = ["BaseProvider", "get_provider"]
6 changes: 4 additions & 2 deletions promptlens/providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,13 @@ async def _make_request() -> ModelResponse:
stop_reason=response.stop_reason,
)

max_attempts, initial_delay = self.get_retry_settings(kwargs)

try:
return await retry_with_exponential_backoff(
func=_make_request,
max_attempts=3,
initial_delay=1.0,
max_attempts=max_attempts,
initial_delay=initial_delay,
)
except Exception as e:
logger.error(f"Anthropic request failed: {e}")
Expand Down
34 changes: 33 additions & 1 deletion promptlens/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Base provider interface for LLM providers."""

from abc import ABC, abstractmethod
from typing import List, Optional
from typing import List, Optional, Tuple

from promptlens.models.config import ProviderConfig
from promptlens.models.result import ModelResponse
Expand Down Expand Up @@ -91,3 +91,35 @@ def supports_tools(self) -> bool:
True if the provider supports tools, False otherwise
"""
return False # Default: most providers don't support tools yet

def get_retry_settings(self, kwargs: dict) -> Tuple[int, float]:
    """Resolve retry settings from runtime kwargs.

    Args:
        kwargs: Runtime kwargs passed to ``generate``.

    Returns:
        Tuple of (max_attempts, initial_delay_seconds).
    """
    raw_attempts = kwargs.get("retry_attempts")
    raw_delay = kwargs.get("retry_delay_seconds")

    # Treat a missing key and an explicit None the same way, so callers that
    # forward optional config values without filtering out Nones still get
    # the defaults instead of a TypeError from int(None)/float(None).
    max_attempts = 3 if raw_attempts is None else int(raw_attempts)
    initial_delay = 1.0 if raw_delay is None else float(raw_delay)

    # Guard rails to prevent invalid runtime values from breaking retries.
    max_attempts = max(max_attempts, 1)
    initial_delay = max(initial_delay, 0.0)

    return max_attempts, initial_delay

def get_timeout_seconds(self, kwargs: dict) -> int:
    """Resolve the request timeout in seconds from runtime kwargs.

    Args:
        kwargs: Runtime kwargs passed to ``generate``.

    Returns:
        Request timeout in seconds.
    """
    requested = int(kwargs.get("timeout_seconds", self.config.timeout))
    if requested > 0:
        return requested
    # A non-positive override is invalid; fall back to the provider config.
    return self.config.timeout
6 changes: 4 additions & 2 deletions promptlens/providers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,13 @@ async def _make_request() -> ModelResponse:
timestamp=datetime.utcnow(),
)

max_attempts, initial_delay = self.get_retry_settings(kwargs)

try:
return await retry_with_exponential_backoff(
func=_make_request,
max_attempts=3,
initial_delay=1.0,
max_attempts=max_attempts,
initial_delay=initial_delay,
)
except Exception as e:
logger.error(f"Google request failed: {e}")
Expand Down
10 changes: 7 additions & 3 deletions promptlens/providers/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ async def _make_request() -> ModelResponse:
if "max_tokens" not in payload and "num_predict" not in payload:
payload["num_predict"] = kwargs.get("max_tokens", self.config.max_tokens)

timeout_seconds = self.get_timeout_seconds(kwargs)

async with aiohttp.ClientSession() as session:
async with session.post(
self.endpoint,
json=payload,
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
timeout=aiohttp.ClientTimeout(total=timeout_seconds),
) as response:
response.raise_for_status()
data = await response.json()
Expand Down Expand Up @@ -109,11 +111,13 @@ async def _make_request() -> ModelResponse:
timestamp=datetime.utcnow(),
)

max_attempts, initial_delay = self.get_retry_settings(kwargs)

try:
return await retry_with_exponential_backoff(
func=_make_request,
max_attempts=3,
initial_delay=1.0,
max_attempts=max_attempts,
initial_delay=initial_delay,
)
except Exception as e:
logger.error(f"HTTP request failed: {e}")
Expand Down
6 changes: 4 additions & 2 deletions promptlens/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,13 @@ async def _make_request() -> ModelResponse:
stop_reason=finish_reason,
)

max_attempts, initial_delay = self.get_retry_settings(kwargs)

try:
return await retry_with_exponential_backoff(
func=_make_request,
max_attempts=3,
initial_delay=1.0,
max_attempts=max_attempts,
initial_delay=initial_delay,
)
except Exception as e:
logger.error(f"OpenAI request failed: {e}")
Expand Down
10 changes: 7 additions & 3 deletions promptlens/providers/you.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,14 @@ async def _make_request() -> ModelResponse:
if "max_tokens" not in payload:
payload["max_tokens"] = kwargs.get("max_tokens", self.config.max_tokens)

timeout_seconds = self.get_timeout_seconds(kwargs)

async with aiohttp.ClientSession() as session:
async with session.post(
self.base_url,
headers=headers,
json=payload,
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
timeout=aiohttp.ClientTimeout(total=timeout_seconds),
) as response:
response.raise_for_status()
data = await response.json()
Expand Down Expand Up @@ -126,11 +128,13 @@ async def _make_request() -> ModelResponse:
timestamp=datetime.utcnow(),
)

max_attempts, initial_delay = self.get_retry_settings(kwargs)

try:
return await retry_with_exponential_backoff(
func=_make_request,
max_attempts=3,
initial_delay=1.0,
max_attempts=max_attempts,
initial_delay=initial_delay,
)
except Exception as e:
logger.error(f"You.com request failed: {e}")
Expand Down
5 changes: 4 additions & 1 deletion promptlens/runners/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,10 @@ async def _evaluate_single(
# Generate response (pass tools if provided)
model_response = await provider.generate(
test_case.query,
tools=test_case.tools if test_case.tools else None
tools=test_case.tools if test_case.tools else None,
retry_attempts=self.config.execution.retry_attempts,
retry_delay_seconds=self.config.execution.retry_delay_seconds,
timeout_seconds=self.config.execution.timeout_seconds,
)

# Judge the response (only if generation succeeded)
Expand Down
63 changes: 63 additions & 0 deletions tests/test_provider_runtime_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Tests for runtime retry/timeout settings resolution on providers."""

from promptlens.models.config import ProviderConfig
from promptlens.models.result import ModelResponse
from promptlens.providers.base import BaseProvider


class DummyProvider(BaseProvider):
    """Minimal concrete provider for exercising BaseProvider helper methods."""

    @property
    def provider_name(self) -> str:
        return "dummy"

    def estimate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
        # The dummy provider is free.
        return 0.0

    async def generate(self, prompt: str, tools=None, **kwargs):  # pragma: no cover
        # Echo the prompt back; never awaited by the helper-method tests.
        return ModelResponse(
            content=prompt,
            model=self.config.model,
            provider=self.provider_name,
            latency_ms=0.0,
        )


def test_get_retry_settings_uses_runtime_values() -> None:
    """Runtime kwargs should win over the built-in retry defaults."""
    provider = DummyProvider(ProviderConfig(name="dummy", model="dummy-model"))

    result = provider.get_retry_settings(
        {"retry_attempts": 5, "retry_delay_seconds": 2.5}
    )

    assert result == (5, 2.5)


def test_get_retry_settings_clamps_invalid_values() -> None:
    """Out-of-range runtime values are clamped rather than passed through."""
    provider = DummyProvider(ProviderConfig(name="dummy", model="dummy-model"))

    attempts, delay = provider.get_retry_settings(
        {"retry_attempts": 0, "retry_delay_seconds": -1}
    )

    assert (attempts, delay) == (1, 0.0)


def test_get_timeout_seconds_prefers_runtime_override() -> None:
    """A positive runtime timeout overrides the provider config."""
    config = ProviderConfig(name="dummy", model="dummy-model", timeout=42)
    provider = DummyProvider(config)

    assert provider.get_timeout_seconds({"timeout_seconds": 12}) == 12


def test_get_timeout_seconds_falls_back_to_provider_config() -> None:
    """Missing or non-positive runtime timeouts fall back to the config value."""
    config = ProviderConfig(name="dummy", model="dummy-model", timeout=42)
    provider = DummyProvider(config)

    assert provider.get_timeout_seconds({"timeout_seconds": 0}) == 42
    assert provider.get_timeout_seconds({}) == 42