84 changes: 58 additions & 26 deletions src/strands/models/litellm.py
@@ -13,6 +13,7 @@
from pydantic import BaseModel
from typing_extensions import Unpack, override

from ..tools import convert_pydantic_to_tool_spec
from ..types.content import ContentBlock, Messages
from ..types.exceptions import ContextWindowOverflowException
from ..types.streaming import StreamEvent
@@ -202,6 +203,10 @@ async def structured_output(
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
"""Get structured output from the model.

Some models do not support native structured output via response_format.
When requests go through a proxy, we may not be able to determine support, so we
fall back to using tool calling to achieve structured output.

Args:
output_model: The output model to use for the agent.
prompt: The prompt messages to use for the agent.
@@ -211,42 +216,69 @@
Yields:
Model events with the last being the structured output.
"""
if supports_response_schema(self.get_config()["model_id"]):
    logger.debug("structuring output using response schema")
    result = await self._structured_output_using_response_schema(output_model, prompt, system_prompt)
else:
    logger.debug("model does not support response schema, structuring output using tool approach")
    result = await self._structured_output_using_tool(output_model, prompt, system_prompt)

yield {"output": result}

async def _structured_output_using_response_schema(
    self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None
) -> T:
    """Get structured output using native response_format support."""
    try:
        response = await litellm.acompletion(
            **self.client_args,
            model=self.get_config()["model_id"],
            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
            response_format=output_model,
        )
    except ContextWindowExceededError as e:
        logger.warning("litellm client raised context window overflow in structured_output")
        raise ContextWindowOverflowException(e) from e

    if not response.choices:
        raise ValueError("No choices found in the response.")
    if len(response.choices) > 1:
        raise ValueError("Multiple choices found in the response.")

    choice = response.choices[0]
    try:
        # Parse the message content as JSON
        content_data = json.loads(choice.message.content)
        # Instantiate the output model with the parsed data
        return output_model(**content_data)
    except (json.JSONDecodeError, TypeError, ValueError) as e:
        raise ValueError(f"Failed to parse or load content into model: {e}") from e

async def _structured_output_using_tool(
    self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None
) -> T:
    """Get structured output using tool calling fallback."""
    tool_spec = convert_pydantic_to_tool_spec(output_model)
    request = self.format_request(prompt, [tool_spec], system_prompt, cast(ToolChoice, {"any": {}}))
    args = {**self.client_args, **request, "stream": False}

    try:
        response = await litellm.acompletion(**args)
    except ContextWindowExceededError as e:
        logger.warning("litellm client raised context window overflow in structured_output")
        raise ContextWindowOverflowException(e) from e

    if len(response.choices) > 1:
        raise ValueError("Multiple choices found in the response.")
    if not response.choices or response.choices[0].finish_reason != "tool_calls":
        raise ValueError("No tool_calls found in response")

    choice = response.choices[0]
    try:
        # Parse the first tool call's arguments as JSON
        tool_call = choice.message.tool_calls[0]
        tool_call_data = json.loads(tool_call.function.arguments)
        # Instantiate the output model with the parsed data
        return output_model(**tool_call_data)
    except (json.JSONDecodeError, TypeError, ValueError) as e:
        raise ValueError(f"Failed to parse or load content into model: {e}") from e

def _apply_proxy_prefix(self) -> None:
"""Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True.
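For readers of this diff: the fallback described in the structured_output docstring is transparent to callers. Below is a minimal consumption sketch that is not part of the PR; it assumes the LiteLLMModel class from this module, an illustrative Person output model, and a placeholder model id.

import asyncio

import pydantic

from strands.models.litellm import LiteLLMModel


class Person(pydantic.BaseModel):
    name: str
    age: int


async def main() -> None:
    # Placeholder model id; any litellm-routable model is handled the same way.
    model = LiteLLMModel(model_id="gpt-4o")
    messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]

    # structured_output is an async generator; whether it uses response_format or the
    # tool-calling fallback is decided internally via supports_response_schema().
    last_event = None
    async for event in model.structured_output(Person, messages):
        last_event = event

    # The final event carries the parsed pydantic instance under "output".
    print(last_event["output"])


asyncio.run(main())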
22 changes: 17 additions & 5 deletions tests/strands/models/test_litellm.py
@@ -292,15 +292,27 @@ async def test_structured_output(litellm_acompletion, model, test_output_model_c


@pytest.mark.asyncio
async def test_structured_output_unsupported_model(litellm_acompletion, model, test_output_model_cls, alist):
    messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]

    mock_tool_call = unittest.mock.Mock()
    mock_tool_call.function.arguments = '{"name": "John", "age": 30}'

    mock_choice = unittest.mock.Mock()
    mock_choice.finish_reason = "tool_calls"
    mock_choice.message.tool_calls = [mock_tool_call]
    mock_response = unittest.mock.Mock()
    mock_response.choices = [mock_choice]

    litellm_acompletion.return_value = mock_response

    with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=False):
        stream = model.structured_output(test_output_model_cls, messages)
        events = await alist(stream)
        tru_result = events[-1]

    exp_result = {"output": test_output_model_cls(name="John", age=30)}
    assert tru_result == exp_result


def test_config_validation_warns_on_unknown_keys(litellm_acompletion, captured_warnings):
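As a side note on the mocks above: the tool-calling fallback relies only on the response shape sketched below (first choice finishing with "tool_calls", JSON arguments on the first tool call). The snippet is illustrative and standalone; Person and the SimpleNamespace stand-ins are not part of the PR.

import json
from types import SimpleNamespace

import pydantic


class Person(pydantic.BaseModel):
    name: str
    age: int


# SimpleNamespace objects mimic the litellm completion response the fallback parses.
response = SimpleNamespace(
    choices=[
        SimpleNamespace(
            finish_reason="tool_calls",
            message=SimpleNamespace(
                tool_calls=[SimpleNamespace(function=SimpleNamespace(arguments='{"name": "John", "age": 30}'))]
            ),
        )
    ]
)

choice = response.choices[0]
assert choice.finish_reason == "tool_calls"
# The forced tool call carries the structured output as a JSON string in function.arguments.
person = Person(**json.loads(choice.message.tool_calls[0].function.arguments))
print(person)  # name='John' age=30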
61 changes: 61 additions & 0 deletions tests_integ/models/test_model_litellm.py
@@ -1,3 +1,5 @@
import unittest.mock

import pydantic
import pytest

@@ -40,6 +42,37 @@ class Weather(pydantic.BaseModel):
return Weather(time="12:00", weather="sunny")


class Location(pydantic.BaseModel):
    """Location information."""

    city: str = pydantic.Field(description="The city name")
    country: str = pydantic.Field(description="The country name")


class WeatherCondition(pydantic.BaseModel):
    """Weather condition details."""

    condition: str = pydantic.Field(description="The weather condition (e.g., 'sunny', 'rainy', 'cloudy')")
    temperature: int = pydantic.Field(description="Temperature in Celsius")


class NestedWeather(pydantic.BaseModel):
    """Weather report with nested location and condition information."""

    time: str = pydantic.Field(description="The time in HH:MM format")
    location: Location = pydantic.Field(description="Location information")
    weather: WeatherCondition = pydantic.Field(description="Weather condition details")


@pytest.fixture
def nested_weather():
    return NestedWeather(
        time="12:00",
        location=Location(city="New York", country="USA"),
        weather=WeatherCondition(condition="sunny", temperature=25),
    )


@pytest.fixture
def yellow_color():
class Color(pydantic.BaseModel):
@@ -134,3 +167,31 @@ def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color):
tru_color = agent.structured_output(type(yellow_color), content)
exp_color = yellow_color
assert tru_color == exp_color


def test_structured_output_unsupported_model(model, nested_weather):
    # Mock supports_response_schema to return False to test fallback mechanism
    with (
        unittest.mock.patch.multiple(
            "strands.models.litellm",
            supports_response_schema=unittest.mock.DEFAULT,
        ) as mocks,
        unittest.mock.patch.object(
            model, "_structured_output_using_tool", wraps=model._structured_output_using_tool
        ) as mock_tool,
        unittest.mock.patch.object(
            model, "_structured_output_using_response_schema", wraps=model._structured_output_using_response_schema
        ) as mock_schema,
    ):
        mocks["supports_response_schema"].return_value = False

        # Test that structured output still works via tool calling fallback
        agent = Agent(model=model)
        prompt = "The time is 12:00 in New York, USA and the weather is sunny with temperature 25 degrees Celsius"
        tru_weather = agent.structured_output(NestedWeather, prompt)
        exp_weather = nested_weather
        assert tru_weather == exp_weather

        # Verify that the tool method was called and schema method was not
        mock_tool.assert_called_once()
        mock_schema.assert_not_called()
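For comparison, the same fallback path without any mocking would look roughly like the sketch below. It assumes a provider for which supports_response_schema() genuinely returns False, reuses the NestedWeather model defined earlier in this test module, and assumes the Agent import path.

from strands import Agent  # import path assumed; the integration tests construct Agent(model=...) the same way


def structured_weather_report(model) -> NestedWeather:
    agent = Agent(model=model)
    prompt = "The time is 12:00 in New York, USA and the weather is sunny with temperature 25 degrees Celsius"
    # If the provider lacks response_format support, structured_output silently takes the
    # tool-calling path and still returns a fully populated NestedWeather instance.
    return agent.structured_output(NestedWeather, prompt)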