diff --git a/tests/integration/llm/providers/test_anthropic.py b/tests/integration/llm/providers/test_anthropic.py
index b5c6af3..7a46517 100644
--- a/tests/integration/llm/providers/test_anthropic.py
+++ b/tests/integration/llm/providers/test_anthropic.py
@@ -37,3 +37,35 @@ def test_anthropic_provider_get_response_live_call():
             wait_time=1,
         )
     )
+
+
+@pytest.mark.integration
+def test_anthropic_provider_web_search():
+    """It invokes web search and returns current information."""
+    from utils.llm.model_registry import (
+        _get_api_key_for_provider,  # type: ignore[import]
+    )
+    from utils.llm.providers.anthropic import AnthropicProvider  # type: ignore[import]
+
+    api_key = _get_api_key_for_provider(AnthropicProvider)
+    assert api_key is not None, "API key should be configured by fixture"
+    provider = anthropic_module.AnthropicProvider(api_key=api_key)
+
+    # Ask a question with a stable, verifiable answer
+    prompt = "What is the official website URL for Python? Just respond with the URL."
+    response = provider.get_response(
+        ANTHROPIC_MODEL,
+        prompt,
+        tools=[
+            {
+                "type": "web_search_20250305",
+                "name": "web_search",
+                "max_uses": 1,
+            }
+        ],
+        max_tokens=256,
+        wait_time=1,
+    )
+
+    assert isinstance(response, str)
+    assert "python.org" in response.lower()
diff --git a/tests/integration/llm/providers/test_openai.py b/tests/integration/llm/providers/test_openai.py
index 7486827..72decba 100644
--- a/tests/integration/llm/providers/test_openai.py
+++ b/tests/integration/llm/providers/test_openai.py
@@ -35,3 +35,29 @@ def test_openai_provider_get_response_live_call():
             wait_time=1,
         )
     )
+
+
+@pytest.mark.integration
+def test_openai_provider_web_search():
+    """It invokes web search and returns current information."""
+    from utils.llm.model_registry import (
+        _get_api_key_for_provider,  # type: ignore[import]
+    )
+    from utils.llm.providers.openai import OpenAIProvider  # type: ignore[import]
+
+    api_key = _get_api_key_for_provider(OpenAIProvider)
+    assert api_key is not None, "API key should be configured by fixture"
+    provider = openai_module.OpenAIProvider(api_key=api_key)
+
+    # Ask a question with a stable, verifiable answer
+    prompt = "What is the official website URL for Python? Just respond with the URL."
+    response = provider.get_response(
+        OPENAI_MODEL,
+        prompt,
+        tools=[{"type": "web_search"}],
+        max_tokens=256,
+        wait_time=1,
+    )
+
+    assert isinstance(response, str)
+    assert "python.org" in response.lower()
diff --git a/utils/llm/providers/anthropic.py b/utils/llm/providers/anthropic.py
index afd9b7f..2bb8892 100644
--- a/utils/llm/providers/anthropic.py
+++ b/utils/llm/providers/anthropic.py
@@ -42,6 +42,7 @@ def _call_model(self, model: "Model", prompt: str, **options: Any) -> str:
         temperature = options.get("temperature")
         max_tokens = options.get("max_tokens")
         assert max_tokens is not None, "max_tokens is required for Anthropic models."
+        tools = options.get("tools")
         model_name = model.full_name

         call_args: dict[str, Any] = {
@@ -56,8 +57,15 @@ def _call_model(self, model: "Model", prompt: str, **options: Any) -> str:
         }
         if temperature is not None:
             call_args["temperature"] = temperature
+        if tools is not None:
+            call_args["tools"] = tools

         with self._anthropic_console.messages.stream(**call_args) as stream:
             stream.until_done()
-            return stream.get_final_message().content[0].text
+            message = stream.get_final_message()
+            # Join all text blocks (web search responses interleave tool-use
+            # and text blocks, so content[0] is not guaranteed to be text)
+            text_blocks = [block for block in message.content if block.type == "text"]
+            if not text_blocks:
+                return ""
+            return "".join(block.text for block in text_blocks)

diff --git a/utils/llm/providers/openai.py b/utils/llm/providers/openai.py
index 44e1b1a..d2a9a40 100644
--- a/utils/llm/providers/openai.py
+++ b/utils/llm/providers/openai.py
@@ -36,25 +36,25 @@ def __init__(self, *, api_key: str | None = None, default_wait_time: int | None
         self._openai_client = OpenAI(api_key=api_key)

     def _call_model(self, model: "Model", prompt: str, **options: Any) -> str:
-        temperature = options.get("temperature", 0.8)
-        max_tokens = options.get("max_tokens")
         model_name = model.full_name

-        # OpenAI doesn't support temperature for reasoning models
-        if model.reasoning_model:
-            request_payload: Dict[str, Any] = {
-                "model": model_name,
-                "input": prompt,
-            }
-        else:
-            request_payload = {
-                "model": model_name,
-                "input": prompt,
-                "temperature": temperature,
-            }
-
-        if max_tokens is not None:
-            request_payload["max_output_tokens"] = max_tokens
+        request_payload: Dict[str, Any] = {
+            "model": model_name,
+            "input": prompt,
+        }
+
+        # Add default temperature for non-reasoning models (unless explicitly provided)
+        if not model.reasoning_model and "temperature" not in options:
+            request_payload["temperature"] = 0.8
+
+        # Pass through all options, with special handling for some keys
+        for key, value in options.items():
+            if key == "max_tokens":
+                request_payload["max_output_tokens"] = value
+            elif key == "temperature" and model.reasoning_model:
+                continue  # OpenAI doesn't support temperature for reasoning models
+            else:
+                request_payload[key] = value

         response = self._openai_client.responses.create(**request_payload)
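
For reviewers, a minimal standalone sketch of the situation the anthropic.py extraction change handles, written directly against the anthropic SDK rather than this repo's wrapper. The model name and prompt are illustrative placeholders, not values from this codebase: with the web_search_20250305 tool enabled, message.content interleaves server_tool_use and web_search_tool_result blocks with text blocks, so indexing content[0] is no longer safe.

import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment
with client.messages.stream(
    model="claude-sonnet-4-20250514",  # placeholder model name
    max_tokens=256,
    messages=[{"role": "user", "content": "What is the official website URL for Python?"}],
    tools=[{"type": "web_search_20250305", "name": "web_search", "max_uses": 1}],
) as stream:
    stream.until_done()
    message = stream.get_final_message()

# Same extraction as the patched _call_model: join every text block,
# skipping the tool-use and tool-result blocks in between.
print("".join(block.text for block in message.content if block.type == "text"))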
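
Similarly, a sketch of what the openai.py option pass-through ends up forwarding for the web search test, assuming the Responses API. The tool type mirrors the test above, the model name is a placeholder, and output_text is the SDK's convenience accessor for the concatenated text output:

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
response = client.responses.create(
    model="gpt-4o",  # placeholder model name
    input="What is the official website URL for Python? Just respond with the URL.",
    tools=[{"type": "web_search"}],  # forwarded verbatim by the options loop
    max_output_tokens=256,           # mapped from the wrapper's max_tokens option
)
print(response.output_text)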