Skip to content

Commit 6a3f111

Browse files
committed
fix(models): Add non-streaming support and resolve type conflicts
This commit introduces a streaming parameter to OpenAIModel to allow for non-streaming responses. The initial implementation revealed a type incompatibility in the inheriting LiteLLMModel. This has been resolved by updating LiteLLMConfig to be consistent with the parent OpenAIConfig, ensuring all pre-commit checks pass. The associated unit tests for OpenAIModel have also been improved to verify the non-streaming behavior.
1 parent 104ecb5 commit 6a3f111

File tree

4 files changed

+185
-37
lines changed

4 files changed

+185
-37
lines changed

src/strands/models/litellm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None: # type:
7171
self._apply_proxy_prefix()
7272

7373
@override
74-
def get_config(self) -> LiteLLMConfig:
74+
def get_config(self) -> LiteLLMConfig: # type: ignore[override]
7575
"""Get the LiteLLM model configuration.
7676
7777
Returns:

src/strands/models/openai.py

Lines changed: 114 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,13 @@ class OpenAIConfig(TypedDict, total=False):
5050
params: Model parameters (e.g., max_tokens).
5151
For a complete list of supported parameters, see
5252
https://platform.openai.com/docs/api-reference/chat/create.
53+
streaming: Optional flag to indicate whether provider streaming should be used.
54+
If omitted, defaults to True (preserves existing behaviour).
5355
"""
5456

5557
model_id: str
5658
params: Optional[dict[str, Any]]
59+
streaming: bool | None
5760

5861
def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
5962
"""Initialize provider instance.
@@ -263,7 +266,7 @@ def format_request(
263266
return {
264267
"messages": self.format_request_messages(messages, system_prompt),
265268
"model": self.config["model_id"],
266-
"stream": True,
269+
"stream": self.config.get("streaming", True),
267270
"stream_options": {"include_usage": True},
268271
"tools": [
269272
{
@@ -352,6 +355,68 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
352355
case _:
353356
raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
354357

358+
def _convert_non_streaming_to_streaming(self, response: Any) -> list[StreamEvent]:
359+
"""Convert a provider non-streaming response into streaming-style events.
360+
361+
This helper intentionally *does not* emit the initial message_start/content_start events,
362+
because the caller (stream) already yields them to preserve parity with streaming flow.
363+
"""
364+
events: list[StreamEvent] = []
365+
366+
# Extract main text content from first choice if available
367+
if getattr(response, "choices", None):
368+
choice = response.choices[0]
369+
content = None
370+
if hasattr(choice, "message") and hasattr(choice.message, "content"):
371+
content = choice.message.content
372+
373+
# handle str content
374+
if isinstance(content, str):
375+
events.append(self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": content}))
376+
# handle list content (list of blocks/dicts)
377+
elif isinstance(content, list):
378+
for block in content:
379+
if isinstance(block, dict):
380+
# reasoning content
381+
if "reasoningContent" in block and isinstance(block["reasoningContent"], dict):
382+
try:
383+
text = block["reasoningContent"]["reasoningText"]["text"]
384+
events.append(
385+
self.format_chunk(
386+
{"chunk_type": "content_delta", "data_type": "reasoning_content", "data": text}
387+
)
388+
)
389+
except Exception:
390+
# fall back to keeping the block as text if malformed
391+
pass
392+
# text block
393+
elif "text" in block:
394+
events.append(
395+
self.format_chunk(
396+
{"chunk_type": "content_delta", "data_type": "text", "data": block["text"]}
397+
)
398+
)
399+
# ignore other block types for now
400+
elif isinstance(block, str):
401+
events.append(
402+
self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": block})
403+
)
404+
405+
# content stop
406+
events.append(self.format_chunk({"chunk_type": "content_stop"}))
407+
408+
# message stop — convert finish reason if available
409+
stop_reason = None
410+
if getattr(response, "choices", None):
411+
stop_reason = getattr(response.choices[0], "finish_reason", None)
412+
events.append(self.format_chunk({"chunk_type": "message_stop", "data": stop_reason or "stop"}))
413+
414+
# metadata (usage) if present
415+
if getattr(response, "usage", None):
416+
events.append(self.format_chunk({"chunk_type": "metadata", "data": response.usage}))
417+
418+
return events
419+
355420
@override
356421
async def stream(
357422
self,
@@ -409,50 +474,63 @@ async def stream(
409474

410475
tool_calls: dict[int, list[Any]] = {}
411476

412-
async for event in response:
413-
# Defensive: skip events with empty or missing choices
414-
if not getattr(event, "choices", None):
415-
continue
416-
choice = event.choices[0]
417-
418-
if choice.delta.content:
419-
yield self.format_chunk(
420-
{"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
421-
)
422-
423-
if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
424-
yield self.format_chunk(
425-
{
426-
"chunk_type": "content_delta",
427-
"data_type": "reasoning_content",
428-
"data": choice.delta.reasoning_content,
429-
}
430-
)
477+
streaming = self.config.get("streaming", True)
478+
479+
if streaming:
480+
# response is an async iterator when streaming=True
481+
async for event in response:
482+
# Defensive: skip events with empty or missing choices
483+
if not getattr(event, "choices", None):
484+
continue
485+
choice = event.choices[0]
486+
487+
if choice.delta.content:
488+
yield self.format_chunk(
489+
{"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
490+
)
491+
492+
if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
493+
yield self.format_chunk(
494+
{
495+
"chunk_type": "content_delta",
496+
"data_type": "reasoning_content",
497+
"data": choice.delta.reasoning_content,
498+
}
499+
)
431500

432-
for tool_call in choice.delta.tool_calls or []:
433-
tool_calls.setdefault(tool_call.index, []).append(tool_call)
501+
for tool_call in choice.delta.tool_calls or []:
502+
tool_calls.setdefault(tool_call.index, []).append(tool_call)
434503

435-
if choice.finish_reason:
436-
break
504+
if choice.finish_reason:
505+
break
437506

438-
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
507+
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
439508

440-
for tool_deltas in tool_calls.values():
441-
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
509+
for tool_deltas in tool_calls.values():
510+
yield self.format_chunk(
511+
{"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
512+
)
442513

443-
for tool_delta in tool_deltas:
444-
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
514+
for tool_delta in tool_deltas:
515+
yield self.format_chunk(
516+
{"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
517+
)
445518

446-
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
519+
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
447520

448-
yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
521+
yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
449522

450-
# Skip remaining events as we don't have use for anything except the final usage payload
451-
async for event in response:
452-
_ = event
523+
# Skip remaining events as we don't have use for anything except the final usage payload
524+
async for event in response:
525+
_ = event
453526

454-
if event.usage:
455-
yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
527+
if event.usage:
528+
yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
529+
else:
530+
# Non-streaming provider response — convert to streaming-style events (excluding the initial
531+
# message_start/content_start because we already emitted them above).
532+
for ev in self._convert_non_streaming_to_streaming(response):
533+
yield ev
456534

457535
logger.debug("finished streaming response from model")
458536

tests/strands/models/test_openai.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,52 @@ async def test_stream(openai_client, model_id, model, agenerator, alist):
612612
openai_client.chat.completions.create.assert_called_once_with(**expected_request)
613613

614614

615+
@pytest.mark.asyncio
616+
async def test_stream_respects_streaming_flag(openai_client, model_id, alist):
617+
# Model configured to NOT stream
618+
model = OpenAIModel(client_args={}, model_id=model_id, params={"max_tokens": 1}, streaming=False)
619+
620+
# Mock a non-streaming response object
621+
mock_choice = unittest.mock.Mock()
622+
mock_choice.finish_reason = "stop"
623+
mock_choice.message = unittest.mock.Mock()
624+
mock_choice.message.content = "non-stream result"
625+
mock_response = unittest.mock.Mock()
626+
mock_response.choices = [mock_choice]
627+
mock_response.usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
628+
629+
openai_client.chat.completions.create = unittest.mock.AsyncMock(return_value=mock_response)
630+
631+
# Consume the generator and verify the events
632+
response_gen = model.stream([{"role": "user", "content": [{"text": "hi"}]}])
633+
tru_events = await alist(response_gen)
634+
635+
expected_request = {
636+
"max_tokens": 1,
637+
"model": model_id,
638+
"messages": [{"role": "user", "content": [{"text": "hi", "type": "text"}]}],
639+
"stream": False,
640+
"stream_options": {"include_usage": True},
641+
"tools": [],
642+
}
643+
openai_client.chat.completions.create.assert_called_once_with(**expected_request)
644+
645+
exp_events = [
646+
{"messageStart": {"role": "assistant"}},
647+
{"contentBlockStart": {"start": {}}},
648+
{"contentBlockDelta": {"delta": {"text": "non-stream result"}}},
649+
{"contentBlockStop": {}},
650+
{"messageStop": {"stopReason": "end_turn"}},
651+
{
652+
"metadata": {
653+
"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
654+
"metrics": {"latencyMs": 0},
655+
}
656+
},
657+
]
658+
assert tru_events == exp_events
659+
660+
615661
@pytest.mark.asyncio
616662
async def test_stream_empty(openai_client, model_id, model, agenerator, alist):
617663
mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)

tests_integ/models/test_model_openai.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,27 @@ def test_rate_limit_throttling_integration_no_retries(model):
221221
# Verify it's a rate limit error
222222
error_message = str(exc_info.value).lower()
223223
assert "rate limit" in error_message or "tokens per min" in error_message
224+
225+
226+
@pytest.mark.asyncio
227+
async def test_openai_non_streaming(alist):
228+
"""Integration test for non-streaming OpenAI responses."""
229+
model = OpenAIModel(
230+
model_id="gpt-4o-mini",
231+
streaming=False,
232+
client_args={"api_key": os.getenv("OPENAI_API_KEY")},
233+
)
234+
235+
response_gen = model.stream([{"role": "user", "content": [{"text": "hi"}]}])
236+
events = await alist(response_gen)
237+
238+
# In non-streaming mode, we expect a consolidated response converted to stream events.
239+
# The exact number of events can vary slightly, but the core structure should be consistent.
240+
assert len(events) >= 5, "Should receive at least 5 events for a non-streaming response"
241+
242+
assert events[0] == {"messageStart": {"role": "assistant"}}, "First event should be messageStart"
243+
assert events[1] == {"contentBlockStart": {"start": {}}}, "Second event should be contentBlockStart"
244+
assert "contentBlockDelta" in events[2], "Third event should be contentBlockDelta"
245+
assert "text" in events[2]["contentBlockDelta"]["delta"], "Delta should contain text"
246+
assert events[3] == {"contentBlockStop": {}}, "Fourth event should be contentBlockStop"
247+
assert "messageStop" in events[4], "Fifth event should be messageStop"

0 commit comments

Comments
 (0)