diff --git a/promptlens/providers/http.py b/promptlens/providers/http.py index 20a130a..9543169 100644 --- a/promptlens/providers/http.py +++ b/promptlens/providers/http.py @@ -35,6 +35,61 @@ def __init__(self, config: ProviderConfig) -> None: self.endpoint = config.endpoint + @staticmethod + def _extract_content(data: Dict[str, Any]) -> str: + """Extract text content from common HTTP provider response shapes. + + Supports Ollama-style, OpenAI-compatible text/chat completions, and + chunked content arrays. + """ + if "response" in data and isinstance(data["response"], str): + return data["response"] + + if "text" in data and isinstance(data["text"], str): + return data["text"] + + if "content" in data: + content = data["content"] + if isinstance(content, str): + return content + if isinstance(content, list): + parts: List[str] = [] + for item in content: + if isinstance(item, dict) and isinstance(item.get("text"), str): + parts.append(item["text"]) + elif isinstance(item, str): + parts.append(item) + if parts: + return "".join(parts) + + choices = data.get("choices") + if isinstance(choices, list) and choices: + choice = choices[0] if isinstance(choices[0], dict) else {} + + if isinstance(choice.get("text"), str): + return choice["text"] + + message = choice.get("message") + if isinstance(message, dict): + message_content = message.get("content") + if isinstance(message_content, str): + return message_content + if isinstance(message_content, list): + parts = [] + for item in message_content: + if isinstance(item, dict) and isinstance(item.get("text"), str): + parts.append(item["text"]) + elif isinstance(item, str): + parts.append(item) + if parts: + return "".join(parts) + + delta = choice.get("delta") + if isinstance(delta, dict) and isinstance(delta.get("content"), str): + return delta["content"] + + return "" + async def generate( self, prompt: str, @@ -84,17 +139,7 @@ async def _make_request() -> ModelResponse: response.raise_for_status() data = await response.json() - # Extract content (try common response formats) - content = "" - if "response" in data: - content = data["response"] # Ollama format - elif "text" in data: - content = data["text"] - elif "content" in data: - content = data["content"] - elif "choices" in data and len(data["choices"]) > 0: - # OpenAI-compatible format - content = data["choices"][0].get("text", "") + content = self._extract_content(data) # Local models typically don't provide token counts or cost return ModelResponse( diff --git a/tests/test_http_provider.py b/tests/test_http_provider.py new file mode 100644 index 0000000..a995ee1 --- /dev/null +++ b/tests/test_http_provider.py @@ -0,0 +1,48 @@ +from promptlens.models.config import ProviderConfig +from promptlens.providers.http import HTTPProvider + + +def _provider() -> HTTPProvider: + return HTTPProvider( + ProviderConfig(name="http", model="test-model", endpoint="http://localhost:11434/api/generate") + ) + + +def test_extract_content_ollama_shape() -> None: + provider = _provider() + assert provider._extract_content({"response": "hello"}) == "hello" + + +def test_extract_content_openai_text_completion_shape() -> None: + provider = _provider() + payload = {"choices": [{"text": "hello from choices"}]} + assert provider._extract_content(payload) == "hello from choices" + + +def test_extract_content_openai_chat_message_string_shape() -> None: + provider = _provider() + payload = {"choices": [{"message": {"content": "chat reply"}}]} + assert provider._extract_content(payload) == "chat reply" + + +def test_extract_content_openai_chat_message_parts_shape() -> None: + provider = _provider() + payload = { + "choices": [ + { + "message": { + "content": [ + {"type": "text", "text": "Part A"}, + {"type": "text", "text": " + Part B"}, + ] + } + } + ] + } + assert provider._extract_content(payload) == "Part A + Part B" + + +def test_extract_content_top_level_content_parts_shape() -> None: + provider = _provider() + payload = {"content": [{"type": "text", "text": "foo"}, {"type": "text", "text": "bar"}]} + assert provider._extract_content(payload) == "foobar"