From 2eb34641fb54ddeb10f4a65e9aa2d8829de4fdf7 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Wed, 1 Oct 2025 14:17:03 -0400
Subject: [PATCH 1/3] feat(models): use tool for structured_output when
 supports_response_schema=false

---
 src/strands/models/litellm.py            | 71 ++++++++++++++++++------
 tests/strands/models/test_litellm.py     | 22 ++++++--
 tests_integ/models/test_model_litellm.py | 50 +++++++++++++++++
 3 files changed, 121 insertions(+), 22 deletions(-)

diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 005eed3df..ab0b569f2 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -12,6 +12,8 @@
 from pydantic import BaseModel
 from typing_extensions import Unpack, override
 
+from ..event_loop import streaming
+from ..tools import convert_pydantic_to_tool_spec
 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolSpec
@@ -196,6 +198,10 @@ async def structured_output(
     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model.
 
+        Some models do not support native structured output via response_format.
+        When requests are routed through a proxy, we may have no way to determine
+        support, so we fall back to tool calling to achieve structured output.
+
         Args:
             output_model: The output model to use for the agent.
             prompt: The prompt messages to use for the agent.
@@ -205,33 +211,64 @@
         Yields:
             Model events with the last being the structured output.
         """
-        if not supports_response_schema(self.get_config()["model_id"]):
-            raise ValueError("Model does not support response_format")
+        if supports_response_schema(self.get_config()["model_id"]):
+            result = await self._structured_output_using_response_schema(output_model, prompt, system_prompt)
+        else:
+            result = await self._structured_output_using_tool(output_model, prompt, system_prompt)
+
+        yield {"output": result}
 
+    async def _structured_output_using_response_schema(
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None
+    ) -> T:
+        """Get structured output using native response_format support."""
         response = await litellm.acompletion(
             **self.client_args,
             model=self.get_config()["model_id"],
             messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
             response_format=output_model,
         )
-
+
         if len(response.choices) > 1:
             raise ValueError("Multiple choices found in the response.")
+        if not response.choices or response.choices[0].finish_reason != "tool_calls":
+            raise ValueError("No tool_calls found in response")
+
+        choice = response.choices[0]
+        try:
+            # Parse the message content as JSON
+            tool_call_data = json.loads(choice.message.content)
+            # Instantiate the output model with the parsed data
+            return output_model(**tool_call_data)
+        except (json.JSONDecodeError, TypeError, ValueError) as e:
+            raise ValueError(f"Failed to parse or load content into model: {e}") from e
+
+    async def _structured_output_using_tool(
+        self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None
+    ) -> T:
+        """Get structured output using tool calling fallback."""
+        tool_spec = convert_pydantic_to_tool_spec(output_model)
+        request = self.format_request(prompt, [tool_spec], system_prompt, cast(ToolChoice, {"any": {}}))
+        args = {**self.client_args, **request, 'stream': False}
+        response = await litellm.acompletion(**args)
+
+        if len(response.choices) > 1:
+            raise ValueError("Multiple choices found in the response.")
+        if not response.choices or response.choices[0].finish_reason != "tool_calls":
+            raise ValueError("No tool_calls found in response")
+
+        choice = response.choices[0]
+        try:
+            # Parse the tool call content as JSON
+            tool_call = choice.message.tool_calls[0]
+            tool_call_data = json.loads(tool_call.function.arguments)
+            # Instantiate the output model with the parsed data
+            return output_model(**tool_call_data)
+        except (json.JSONDecodeError, TypeError, ValueError) as e:
+            raise ValueError(f"Failed to parse or load content into model: {e}") from e
+
-        # Find the first choice with tool_calls
-        for choice in response.choices:
-            if choice.finish_reason == "tool_calls":
-                try:
-                    # Parse the tool call content as JSON
-                    tool_call_data = json.loads(choice.message.content)
-                    # Instantiate the output model with the parsed data
-                    yield {"output": output_model(**tool_call_data)}
-                    return
-                except (json.JSONDecodeError, TypeError, ValueError) as e:
-                    raise ValueError(f"Failed to parse or load content into model: {e}") from e
-
-        # If no tool_calls found, raise an error
-        raise ValueError("No tool_calls found in response")
+
+
 
     def _apply_proxy_prefix(self) -> None:
         """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True.
 
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index bc81fc819..86da7a023 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -290,15 +290,27 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c
 
 
 @pytest.mark.asyncio
-async def test_structured_output_unsupported_model(litellm_acompletion, model, test_output_model_cls):
+async def test_structured_output_unsupported_model(litellm_acompletion, model, test_output_model_cls, alist):
     messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
 
+    mock_tool_call = unittest.mock.Mock()
+    mock_tool_call.function.arguments = '{"name": "John", "age": 30}'
+
+    mock_choice = unittest.mock.Mock()
+    mock_choice.finish_reason = "tool_calls"
+    mock_choice.message.tool_calls = [mock_tool_call]
+    mock_response = unittest.mock.Mock()
+    mock_response.choices = [mock_choice]
+
+    litellm_acompletion.return_value = mock_response
+
     with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=False):
-        with pytest.raises(ValueError, match="Model does not support response_format"):
-            stream = model.structured_output(test_output_model_cls, messages)
-            await stream.__anext__()
+        stream = model.structured_output(test_output_model_cls, messages)
+        events = await alist(stream)
+        tru_result = events[-1]
 
-    litellm_acompletion.assert_not_called()
+    exp_result = {"output": test_output_model_cls(name="John", age=30)}
+    assert tru_result == exp_result
 
 
 def test_config_validation_warns_on_unknown_keys(litellm_acompletion, captured_warnings):
diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py
index 6cfdd3038..53ee65b5c 100644
--- a/tests_integ/models/test_model_litellm.py
+++ b/tests_integ/models/test_model_litellm.py
@@ -1,5 +1,6 @@
 import pydantic
 import pytest
+import unittest.mock
 
 import strands
 from strands import Agent
@@ -40,6 +41,33 @@ class Weather(pydantic.BaseModel):
     return Weather(time="12:00", weather="sunny")
 
 
+class Location(pydantic.BaseModel):
+    """Location information."""
+    city: str = pydantic.Field(description="The city name")
+    country: str = pydantic.Field(description="The country name")
+
+class WeatherCondition(pydantic.BaseModel):
+    """Weather condition details."""
+    condition: str = pydantic.Field(description="The weather condition (e.g., 'sunny', 'rainy', 'cloudy')")
+    temperature: int = pydantic.Field(description="Temperature in Celsius")
+
+class NestedWeather(pydantic.BaseModel):
+    """Weather report with nested location and condition information."""
+    time: str = pydantic.Field(description="The time in HH:MM format")
+    location: Location = pydantic.Field(description="Location information")
+    weather: WeatherCondition = pydantic.Field(description="Weather condition details")
+
+
+@pytest.fixture
+def nested_weather():
+
+    return NestedWeather(
+        time="12:00",
+        location=Location(city="New York", country="USA"),
+        weather=WeatherCondition(condition="sunny", temperature=25)
+    )
+
+
 @pytest.fixture
 def yellow_color():
     class Color(pydantic.BaseModel):
@@ -134,3 +162,25 @@ def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color):
     tru_color = agent.structured_output(type(yellow_color), content)
     exp_color = yellow_color
     assert tru_color == exp_color
+
+def test_structured_output_unsupported_model(model, nested_weather):
+    # Mock supports_response_schema to return False to test fallback mechanism
+    with unittest.mock.patch.multiple(
+        'strands.models.litellm',
+        supports_response_schema=unittest.mock.DEFAULT,
+    ) as mocks, \
+        unittest.mock.patch.object(model, '_structured_output_using_tool', wraps=model._structured_output_using_tool) as mock_tool, \
+        unittest.mock.patch.object(model, '_structured_output_using_response_schema', wraps=model._structured_output_using_response_schema) as mock_schema:
+
+        mocks['supports_response_schema'].return_value = False
+
+        # Test that structured output still works via tool calling fallback
+        agent = Agent(model=model)
+        prompt = "The time is 12:00 in New York, USA and the weather is sunny with temperature 25 degrees Celsius"
+        tru_weather = agent.structured_output(NestedWeather, prompt)
+        exp_weather = nested_weather
+        assert tru_weather == exp_weather
+
+        # Verify that the tool method was called and schema method was not
+        mock_tool.assert_called_once()
+        mock_schema.assert_not_called()
\ No newline at end of file

From 70d2dfd81770e2a43d995a2a95b39b14ccd0db02 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Wed, 1 Oct 2025 14:27:51 -0400
Subject: [PATCH 2/3] fix: linting

---
 src/strands/models/litellm.py            | 10 ++----
 tests/strands/models/test_litellm.py     |  2 +-
 tests_integ/models/test_model_litellm.py | 39 +++++++++++++++---------
 3 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index ab0b569f2..7c213d4d6 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -12,7 +12,6 @@
 from pydantic import BaseModel
 from typing_extensions import Unpack, override
 
-from ..event_loop import streaming
 from ..tools import convert_pydantic_to_tool_spec
 from ..types.content import ContentBlock, Messages
 from ..types.streaming import StreamEvent
@@ -228,7 +227,7 @@ async def _structured_output_using_response_schema(
             messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
             response_format=output_model,
         )
-
+
         if len(response.choices) > 1:
             raise ValueError("Multiple choices found in the response.")
         if not response.choices or response.choices[0].finish_reason != "tool_calls":
@@ -249,9 +248,9 @@ async def _structured_output_using_tool(
         """Get structured output using tool calling fallback."""
         tool_spec = convert_pydantic_to_tool_spec(output_model)
         request = self.format_request(prompt, [tool_spec], system_prompt, cast(ToolChoice, {"any": {}}))
-        args = {**self.client_args, **request, 'stream': False}
+        args = {**self.client_args, **request, "stream": False}
         response = await litellm.acompletion(**args)
-
+
         if len(response.choices) > 1:
             raise ValueError("Multiple choices found in the response.")
         if not response.choices or response.choices[0].finish_reason != "tool_calls":
@@ -267,9 +266,6 @@ async def _structured_output_using_tool(
         except (json.JSONDecodeError, TypeError, ValueError) as e:
             raise ValueError(f"Failed to parse or load content into model: {e}") from e
 
-
-
-
     def _apply_proxy_prefix(self) -> None:
         """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True.
 
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index 86da7a023..a7c352993 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -295,7 +295,7 @@ async def test_structured_output_unsupported_model(litellm_acompletion, model, t
 
     mock_tool_call = unittest.mock.Mock()
     mock_tool_call.function.arguments = '{"name": "John", "age": 30}'
-
+
     mock_choice = unittest.mock.Mock()
     mock_choice.finish_reason = "tool_calls"
     mock_choice.message.tool_calls = [mock_tool_call]
diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py
index 53ee65b5c..c5a09e3e9 100644
--- a/tests_integ/models/test_model_litellm.py
+++ b/tests_integ/models/test_model_litellm.py
@@ -1,6 +1,7 @@
+import unittest.mock
+
 import pydantic
 import pytest
-import unittest.mock
 
 import strands
 from strands import Agent
@@ -43,16 +44,21 @@ class Weather(pydantic.BaseModel):
 
 class Location(pydantic.BaseModel):
     """Location information."""
+
     city: str = pydantic.Field(description="The city name")
     country: str = pydantic.Field(description="The country name")
 
+
 class WeatherCondition(pydantic.BaseModel):
     """Weather condition details."""
+
     condition: str = pydantic.Field(description="The weather condition (e.g., 'sunny', 'rainy', 'cloudy')")
     temperature: int = pydantic.Field(description="Temperature in Celsius")
 
+
 class NestedWeather(pydantic.BaseModel):
     """Weather report with nested location and condition information."""
+
     time: str = pydantic.Field(description="The time in HH:MM format")
     location: Location = pydantic.Field(description="Location information")
     weather: WeatherCondition = pydantic.Field(description="Weather condition details")
@@ -60,11 +66,10 @@ class NestedWeather(pydantic.BaseModel):
 
 @pytest.fixture
 def nested_weather():
-
     return NestedWeather(
         time="12:00",
         location=Location(city="New York", country="USA"),
-        weather=WeatherCondition(condition="sunny", temperature=25)
+        weather=WeatherCondition(condition="sunny", temperature=25),
     )
 
 
@@ -163,24 +168,30 @@ def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color):
     exp_color = yellow_color
     assert tru_color == exp_color
 
+
 def test_structured_output_unsupported_model(model, nested_weather):
     # Mock supports_response_schema to return False to test fallback mechanism
-    with unittest.mock.patch.multiple(
-        'strands.models.litellm',
-        supports_response_schema=unittest.mock.DEFAULT,
-    ) as mocks, \
-        unittest.mock.patch.object(model, '_structured_output_using_tool', wraps=model._structured_output_using_tool) as mock_tool, \
-        unittest.mock.patch.object(model, '_structured_output_using_response_schema', wraps=model._structured_output_using_response_schema) as mock_schema:
-
-        mocks['supports_response_schema'].return_value = False
-
+    with (
+        unittest.mock.patch.multiple(
+            "strands.models.litellm",
+            supports_response_schema=unittest.mock.DEFAULT,
+        ) as mocks,
+        unittest.mock.patch.object(
+            model, "_structured_output_using_tool", wraps=model._structured_output_using_tool
+        ) as mock_tool,
+        unittest.mock.patch.object(
+            model, "_structured_output_using_response_schema", wraps=model._structured_output_using_response_schema
+        ) as mock_schema,
+    ):
+        mocks["supports_response_schema"].return_value = False
+
         # Test that structured output still works via tool calling fallback
         agent = Agent(model=model)
         prompt = "The time is 12:00 in New York, USA and the weather is sunny with temperature 25 degrees Celsius"
         tru_weather = agent.structured_output(NestedWeather, prompt)
        exp_weather = nested_weather
         assert tru_weather == exp_weather
-
+
         # Verify that the tool method was called and schema method was not
         mock_tool.assert_called_once()
-        mock_schema.assert_not_called()
\ No newline at end of file
+        mock_schema.assert_not_called()

From 3ea515b8dda76fcafd88f8a0941578c33f31ff23 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Thu, 9 Oct 2025 09:30:04 -0400
Subject: [PATCH 3/3] Update litellm.py

---
 src/strands/models/litellm.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 7c213d4d6..98a4b9f6a 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -211,8 +211,10 @@
             Model events with the last being the structured output.
         """
         if supports_response_schema(self.get_config()["model_id"]):
+            logger.debug("structuring output using response schema")
             result = await self._structured_output_using_response_schema(output_model, prompt, system_prompt)
         else:
+            logger.debug("model does not support response schema, structuring output using tool approach")
             result = await self._structured_output_using_tool(output_model, prompt, system_prompt)
 
         yield {"output": result}
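---

A minimal usage sketch of the fallback path added in this series. The model id and the LiteLLMModel constructor below are assumptions for illustration; Agent.structured_output() is the entry point exercised by the integration test above.

    import pydantic

    from strands import Agent
    from strands.models.litellm import LiteLLMModel  # assumed import path for this module's model class

    class Person(pydantic.BaseModel):
        name: str = pydantic.Field(description="The person's name")
        age: int = pydantic.Field(description="The person's age")

    # Hypothetical model id for which litellm's supports_response_schema() returns False.
    # structured_output() then routes through _structured_output_using_tool: it converts
    # Person into a tool spec via convert_pydantic_to_tool_spec, forces a tool call with
    # tool_choice {"any": {}}, and parses the tool call arguments back into a Person.
    model = LiteLLMModel(model_id="ollama/llama3")
    agent = Agent(model=model)

    person = agent.structured_output(Person, "Generate a person named John, age 30")
    assert isinstance(person, Person)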