Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 56 additions & 11 deletions promptlens/providers/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,61 @@ def __init__(self, config: ProviderConfig) -> None:

self.endpoint = config.endpoint

@staticmethod
def _extract_content(data: Dict[str, Any]) -> str:
"""Extract text content from common HTTP provider response shapes.

Supports Ollama-style, OpenAI-compatible text/chat completions, and
chunked content arrays.
"""
if "response" in data and isinstance(data["response"], str):
return data["response"]

if "text" in data and isinstance(data["text"], str):
return data["text"]

if "content" in data:
content = data["content"]
if isinstance(content, str):
return content
if isinstance(content, list):
parts: List[str] = []
for item in content:
if isinstance(item, dict) and isinstance(item.get("text"), str):
parts.append(item["text"])
elif isinstance(item, str):
parts.append(item)
if parts:
return "".join(parts)

choices = data.get("choices")
if isinstance(choices, list) and choices:
choice = choices[0] if isinstance(choices[0], dict) else {}

if isinstance(choice.get("text"), str):
return choice["text"]

message = choice.get("message")
if isinstance(message, dict):
message_content = message.get("content")
if isinstance(message_content, str):
return message_content
if isinstance(message_content, list):
parts = []
for item in message_content:
if isinstance(item, dict) and isinstance(item.get("text"), str):
parts.append(item["text"])
elif isinstance(item, str):
parts.append(item)
if parts:
return "".join(parts)

delta = choice.get("delta")
if isinstance(delta, dict) and isinstance(delta.get("content"), str):
return delta["content"]

return ""

async def generate(
self,
prompt: str,
Expand Down Expand Up @@ -84,17 +139,7 @@ async def _make_request() -> ModelResponse:
response.raise_for_status()
data = await response.json()

# Extract content (try common response formats)
content = ""
if "response" in data:
content = data["response"] # Ollama format
elif "text" in data:
content = data["text"]
elif "content" in data:
content = data["content"]
elif "choices" in data and len(data["choices"]) > 0:
# OpenAI-compatible format
content = data["choices"][0].get("text", "")
content = self._extract_content(data)

# Local models typically don't provide token counts or cost
return ModelResponse(
Expand Down
48 changes: 48 additions & 0 deletions tests/test_http_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from promptlens.models.config import ProviderConfig
from promptlens.providers.http import HTTPProvider


def _provider() -> HTTPProvider:
    """Build an HTTPProvider wired to a local Ollama-style endpoint for tests."""
    config = ProviderConfig(
        name="http",
        model="test-model",
        endpoint="http://localhost:11434/api/generate",
    )
    return HTTPProvider(config)


def test_extract_content_ollama_shape() -> None:
    """Ollama responses carry the text under a top-level ``response`` key."""
    payload = {"response": "hello"}
    assert _provider()._extract_content(payload) == "hello"


def test_extract_content_openai_text_completion_shape() -> None:
    """OpenAI text completions put the text on ``choices[0].text``."""
    result = _provider()._extract_content({"choices": [{"text": "hello from choices"}]})
    assert result == "hello from choices"


def test_extract_content_openai_chat_message_string_shape() -> None:
    """Chat completions with a plain-string ``message.content`` are extracted."""
    response_body = {"choices": [{"message": {"content": "chat reply"}}]}
    assert _provider()._extract_content(response_body) == "chat reply"


def test_extract_content_openai_chat_message_parts_shape() -> None:
    """Chunked ``message.content`` arrays are concatenated in order."""
    parts = [
        {"type": "text", "text": "Part A"},
        {"type": "text", "text": " + Part B"},
    ]
    payload = {"choices": [{"message": {"content": parts}}]}
    assert _provider()._extract_content(payload) == "Part A + Part B"


def test_extract_content_top_level_content_parts_shape() -> None:
    """A top-level ``content`` chunk array is joined without separators."""
    chunks = [{"type": "text", "text": "foo"}, {"type": "text", "text": "bar"}]
    assert _provider()._extract_content({"content": chunks}) == "foobar"