Skip to content

Commit 0db7641

Browse files
committed
fix(models): Add non-streaming support and resolve type conflicts
This commit introduces a streaming parameter to OpenAIModel to allow for non-streaming responses. The initial implementation revealed a type incompatibility in the inheriting LiteLLMModel. This has been resolved by updating LiteLLMConfig to be consistent with the parent OpenAIConfig, ensuring all pre-commit checks pass. The associated unit tests for OpenAIModel have also been improved to verify the non-streaming behavior.
1 parent ded0934 commit 0db7641

File tree

4 files changed

+195
-46
lines changed

4 files changed

+195
-46
lines changed

src/strands/models/litellm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None: # type:
7272
self._apply_proxy_prefix()
7373

7474
@override
75-
def get_config(self) -> LiteLLMConfig:
75+
def get_config(self) -> LiteLLMConfig: # type: ignore[override]
7676
"""Get the LiteLLM model configuration.
7777
7878
Returns:

src/strands/models/openai.py

Lines changed: 124 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,13 @@ class OpenAIConfig(TypedDict, total=False):
5050
params: Model parameters (e.g., max_tokens).
5151
For a complete list of supported parameters, see
5252
https://platform.openai.com/docs/api-reference/chat/create.
53+
streaming: Optional flag to indicate whether provider streaming should be used.
54+
If omitted, defaults to True (preserves existing behaviour).
5355
"""
5456

5557
model_id: str
5658
params: Optional[dict[str, Any]]
59+
streaming: bool | None
5760

5861
def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
5962
"""Initialize provider instance.
@@ -332,7 +335,7 @@ def format_request(
332335
messages, system_prompt, system_prompt_content=system_prompt_content
333336
),
334337
"model": self.config["model_id"],
335-
"stream": True,
338+
"stream": self.config.get("streaming", True),
336339
"stream_options": {"include_usage": True},
337340
"tools": [
338341
{
@@ -422,6 +425,68 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
422425
case _:
423426
raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
424427

428+
def _convert_non_streaming_to_streaming(self, response: Any) -> list[StreamEvent]:
429+
"""Convert a provider non-streaming response into streaming-style events.
430+
431+
This helper intentionally *does not* emit the initial message_start/content_start events,
432+
because the caller (stream) already yields them to preserve parity with streaming flow.
433+
"""
434+
events: list[StreamEvent] = []
435+
436+
# Extract main text content from first choice if available
437+
if getattr(response, "choices", None):
438+
choice = response.choices[0]
439+
content = None
440+
if hasattr(choice, "message") and hasattr(choice.message, "content"):
441+
content = choice.message.content
442+
443+
# handle str content
444+
if isinstance(content, str):
445+
events.append(self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": content}))
446+
# handle list content (list of blocks/dicts)
447+
elif isinstance(content, list):
448+
for block in content:
449+
if isinstance(block, dict):
450+
# reasoning content
451+
if "reasoningContent" in block and isinstance(block["reasoningContent"], dict):
452+
try:
453+
text = block["reasoningContent"]["reasoningText"]["text"]
454+
events.append(
455+
self.format_chunk(
456+
{"chunk_type": "content_delta", "data_type": "reasoning_content", "data": text}
457+
)
458+
)
459+
except Exception:
460+
# fall back to keeping the block as text if malformed
461+
pass
462+
# text block
463+
elif "text" in block:
464+
events.append(
465+
self.format_chunk(
466+
{"chunk_type": "content_delta", "data_type": "text", "data": block["text"]}
467+
)
468+
)
469+
# ignore other block types for now
470+
elif isinstance(block, str):
471+
events.append(
472+
self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": block})
473+
)
474+
475+
# content stop
476+
events.append(self.format_chunk({"chunk_type": "content_stop"}))
477+
478+
# message stop — convert finish reason if available
479+
stop_reason = None
480+
if getattr(response, "choices", None):
481+
stop_reason = getattr(response.choices[0], "finish_reason", None)
482+
events.append(self.format_chunk({"chunk_type": "message_stop", "data": stop_reason or "stop"}))
483+
484+
# metadata (usage) if present
485+
if getattr(response, "usage", None):
486+
events.append(self.format_chunk({"chunk_type": "metadata", "data": response.usage}))
487+
488+
return events
489+
425490
@override
426491
async def stream(
427492
self,
@@ -480,57 +545,71 @@ async def stream(
480545
finish_reason = None # Store finish_reason for later use
481546
event = None # Initialize for scope safety
482547

483-
async for event in response:
484-
# Defensive: skip events with empty or missing choices
485-
if not getattr(event, "choices", None):
486-
continue
487-
choice = event.choices[0]
488-
489-
if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
490-
chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
491-
for chunk in chunks:
492-
yield chunk
493-
yield self.format_chunk(
494-
{
495-
"chunk_type": "content_delta",
496-
"data_type": data_type,
497-
"data": choice.delta.reasoning_content,
498-
}
499-
)
500-
501-
if choice.delta.content:
502-
chunks, data_type = self._stream_switch_content("text", data_type)
503-
for chunk in chunks:
504-
yield chunk
548+
streaming = self.config.get("streaming", True)
549+
550+
if streaming:
551+
# response is an async iterator when streaming=True
552+
async for event in response:
553+
# skip events with empty or missing choices
554+
if not getattr(event, "choices", None):
555+
continue
556+
choice = event.choices[0]
557+
558+
if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
559+
chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
560+
for chunk in chunks:
561+
yield chunk
562+
yield self.format_chunk(
563+
{
564+
"chunk_type": "content_delta",
565+
"data_type": data_type,
566+
"data": choice.delta.reasoning_content,
567+
}
568+
)
569+
570+
if choice.delta.content:
571+
chunks, data_type = self._stream_switch_content("text", data_type)
572+
for chunk in chunks:
573+
yield chunk
574+
yield self.format_chunk(
575+
{"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content}
576+
)
577+
578+
for tool_call in choice.delta.tool_calls or []:
579+
tool_calls.setdefault(tool_call.index, []).append(tool_call)
580+
581+
if choice.finish_reason:
582+
finish_reason = choice.finish_reason # Store for use outside loop
583+
if data_type:
584+
yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
585+
break
586+
587+
for tool_deltas in tool_calls.values():
505588
yield self.format_chunk(
506-
{"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content}
589+
{"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
507590
)
508591

509-
for tool_call in choice.delta.tool_calls or []:
510-
tool_calls.setdefault(tool_call.index, []).append(tool_call)
511-
512-
if choice.finish_reason:
513-
finish_reason = choice.finish_reason # Store for use outside loop
514-
if data_type:
515-
yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
516-
break
517-
518-
for tool_deltas in tool_calls.values():
519-
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
520-
521-
for tool_delta in tool_deltas:
522-
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
592+
for tool_delta in tool_deltas:
593+
yield self.format_chunk(
594+
{"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
595+
)
523596

524-
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
597+
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
525598

526-
yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason or "end_turn"})
599+
yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason or "end_turn"})
527600

528-
# Skip remaining events as we don't have use for anything except the final usage payload
529-
async for event in response:
530-
_ = event
601+
# Skip remaining events
602+
async for event in response:
603+
_ = event
531604

532-
if event and hasattr(event, "usage") and event.usage:
533-
yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
605+
if event and hasattr(event, "usage") and event.usage:
606+
yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
607+
else:
608+
# Non-streaming provider response — convert to streaming-style events.
609+
# We manually emit the content_start event here to align with the streaming path
610+
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
611+
for ev in self._convert_non_streaming_to_streaming(response):
612+
yield ev
534613

535614
logger.debug("finished streaming response from model")
536615

tests/strands/models/test_openai.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,52 @@ async def test_stream(openai_client, model_id, model, agenerator, alist):
614614
openai_client.chat.completions.create.assert_called_once_with(**expected_request)
615615

616616

617+
@pytest.mark.asyncio
618+
async def test_stream_respects_streaming_flag(openai_client, model_id, alist):
619+
# Model configured to NOT stream
620+
model = OpenAIModel(client_args={}, model_id=model_id, params={"max_tokens": 1}, streaming=False)
621+
622+
# Mock a non-streaming response object
623+
mock_choice = unittest.mock.Mock()
624+
mock_choice.finish_reason = "stop"
625+
mock_choice.message = unittest.mock.Mock()
626+
mock_choice.message.content = "non-stream result"
627+
mock_response = unittest.mock.Mock()
628+
mock_response.choices = [mock_choice]
629+
mock_response.usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
630+
631+
openai_client.chat.completions.create = unittest.mock.AsyncMock(return_value=mock_response)
632+
633+
# Consume the generator and verify the events
634+
response_gen = model.stream([{"role": "user", "content": [{"text": "hi"}]}])
635+
tru_events = await alist(response_gen)
636+
637+
expected_request = {
638+
"max_tokens": 1,
639+
"model": model_id,
640+
"messages": [{"role": "user", "content": [{"text": "hi", "type": "text"}]}],
641+
"stream": False,
642+
"stream_options": {"include_usage": True},
643+
"tools": [],
644+
}
645+
openai_client.chat.completions.create.assert_called_once_with(**expected_request)
646+
647+
exp_events = [
648+
{"messageStart": {"role": "assistant"}},
649+
{"contentBlockStart": {"start": {}}},
650+
{"contentBlockDelta": {"delta": {"text": "non-stream result"}}},
651+
{"contentBlockStop": {}},
652+
{"messageStop": {"stopReason": "end_turn"}},
653+
{
654+
"metadata": {
655+
"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
656+
"metrics": {"latencyMs": 0},
657+
}
658+
},
659+
]
660+
assert tru_events == exp_events
661+
662+
617663
@pytest.mark.asyncio
618664
async def test_stream_empty(openai_client, model_id, model, agenerator, alist):
619665
mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)

tests_integ/models/test_model_openai.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,27 @@ def test_system_prompt_backward_compatibility_integration(model):
257257

258258
# The response should contain our specific system prompt instruction
259259
assert "BACKWARD_COMPAT_TEST" in result.message["content"][0]["text"]
260+
261+
262+
@pytest.mark.asyncio
263+
async def test_openai_non_streaming(alist):
264+
"""Integration test for non-streaming OpenAI responses."""
265+
model = OpenAIModel(
266+
model_id="gpt-4o-mini",
267+
streaming=False,
268+
client_args={"api_key": os.getenv("OPENAI_API_KEY")},
269+
)
270+
271+
response_gen = model.stream([{"role": "user", "content": [{"text": "hi"}]}])
272+
events = await alist(response_gen)
273+
274+
# In non-streaming mode, we expect a consolidated response converted to stream events.
275+
# The exact number of events can vary slightly, but the core structure should be consistent.
276+
assert len(events) >= 5, "Should receive at least 5 events for a non-streaming response"
277+
278+
assert events[0] == {"messageStart": {"role": "assistant"}}, "First event should be messageStart"
279+
assert events[1] == {"contentBlockStart": {"start": {}}}, "Second event should be contentBlockStart"
280+
assert "contentBlockDelta" in events[2], "Third event should be contentBlockDelta"
281+
assert "text" in events[2]["contentBlockDelta"]["delta"], "Delta should contain text"
282+
assert events[3] == {"contentBlockStop": {}}, "Fourth event should be contentBlockStop"
283+
assert "messageStop" in events[4], "Fifth event should be messageStop"

0 commit comments

Comments
 (0)