Commit 2f670b4

fix(models): Add non-streaming support and resolve type conflicts
This commit introduces a streaming parameter to OpenAIModel to allow for non-streaming responses. The initial implementation revealed a type incompatibility in the inheriting LiteLLMModel. This has been resolved by updating LiteLLMConfig to be consistent with the parent OpenAIConfig, ensuring all pre-commit checks pass. The associated unit tests for OpenAIModel have also been improved to verify the non-streaming behavior.
1 parent eef11cc · commit 2f670b4
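
As a rough usage sketch (not taken from the repository; import paths are inferred from the file paths below, and the model id and client_args values are placeholders), the new flag is passed at construction time and defaults to streaming:

from strands.models.litellm import LiteLLMModel
from strands.models.openai import OpenAIModel

# streaming defaults to True (the previous behaviour); False requests a single
# non-streaming completion that stream() converts back into the usual events.
openai_model = OpenAIModel(client_args={"api_key": "..."}, model_id="gpt-4o", streaming=False)
litellm_model = LiteLLMModel(client_args={}, model_id="example-model-id", streaming=False)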

File tree

3 files changed, +164 -36 lines

src/strands/models/litellm.py
src/strands/models/openai.py
tests/strands/models/test_openai.py

src/strands/models/litellm.py

Lines changed: 3 additions & 0 deletions

@@ -35,10 +35,13 @@ class LiteLLMConfig(TypedDict, total=False):
         params: Model parameters (e.g., max_tokens).
             For a complete list of supported parameters, see
             https://docs.litellm.ai/docs/completion/input#input-params-1.
+        streaming: Optional flag to indicate whether provider streaming should be used.
+            If omitted, defaults to True (preserves existing behaviour).
     """

     model_id: str
     params: Optional[dict[str, Any]]
+    streaming: Optional[bool]

     def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[LiteLLMConfig]) -> None:
         """Initialize provider instance.

src/strands/models/openai.py

Lines changed: 115 additions & 36 deletions
@@ -50,10 +50,13 @@ class OpenAIConfig(TypedDict, total=False):
         params: Model parameters (e.g., max_tokens).
             For a complete list of supported parameters, see
             https://platform.openai.com/docs/api-reference/chat/create.
+        streaming: Optional flag to indicate whether provider streaming should be used.
+            If omitted, defaults to True (preserves existing behaviour).
     """

     model_id: str
     params: Optional[dict[str, Any]]
+    streaming: Optional[bool]

     def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
         """Initialize provider instance.
@@ -263,7 +266,8 @@ def format_request(
         return {
             "messages": self.format_request_messages(messages, system_prompt),
             "model": self.config["model_id"],
-            "stream": True,
+            # Use configured streaming flag; default True to preserve previous behavior.
+            "stream": bool(self.get_config().get("streaming", True)),
             "stream_options": {"include_usage": True},
             "tools": [
                 {
@@ -352,6 +356,68 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
             case _:
                 raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")

+    def _convert_non_streaming_to_streaming(self, response: Any) -> list[StreamEvent]:
+        """Convert a provider non-streaming response into streaming-style events.
+
+        This helper intentionally *does not* emit the initial message_start/content_start events,
+        because the caller (stream) already yields them to preserve parity with streaming flow.
+        """
+        events: list[StreamEvent] = []
+
+        # Extract main text content from first choice if available
+        if getattr(response, "choices", None):
+            choice = response.choices[0]
+            content = None
+            if hasattr(choice, "message") and hasattr(choice.message, "content"):
+                content = choice.message.content
+
+            # handle str content
+            if isinstance(content, str):
+                events.append(self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": content}))
+            # handle list content (list of blocks/dicts)
+            elif isinstance(content, list):
+                for block in content:
+                    if isinstance(block, dict):
+                        # reasoning content
+                        if "reasoningContent" in block and isinstance(block["reasoningContent"], dict):
+                            try:
+                                text = block["reasoningContent"]["reasoningText"]["text"]
+                                events.append(
+                                    self.format_chunk(
+                                        {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": text}
+                                    )
+                                )
+                            except Exception:
+                                # fall back to keeping the block as text if malformed
+                                pass
+                        # text block
+                        elif "text" in block:
+                            events.append(
+                                self.format_chunk(
+                                    {"chunk_type": "content_delta", "data_type": "text", "data": block["text"]}
+                                )
+                            )
+                        # ignore other block types for now
+                    elif isinstance(block, str):
+                        events.append(
+                            self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": block})
+                        )
+
+        # content stop
+        events.append(self.format_chunk({"chunk_type": "content_stop"}))
+
+        # message stop — convert finish reason if available
+        stop_reason = None
+        if getattr(response, "choices", None):
+            stop_reason = getattr(response.choices[0], "finish_reason", None)
+        events.append(self.format_chunk({"chunk_type": "message_stop", "data": stop_reason or "stop"}))
+
+        # metadata (usage) if present
+        if getattr(response, "usage", None):
+            events.append(self.format_chunk({"chunk_type": "metadata", "data": response.usage}))
+
+        return events
+
     @override
     async def stream(
         self,
@@ -409,50 +475,63 @@ async def stream(

         tool_calls: dict[int, list[Any]] = {}

-        async for event in response:
-            # Defensive: skip events with empty or missing choices
-            if not getattr(event, "choices", None):
-                continue
-            choice = event.choices[0]
-
-            if choice.delta.content:
-                yield self.format_chunk(
-                    {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
-                )
-
-            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
-                yield self.format_chunk(
-                    {
-                        "chunk_type": "content_delta",
-                        "data_type": "reasoning_content",
-                        "data": choice.delta.reasoning_content,
-                    }
-                )
+        streaming = bool(self.get_config().get("streaming", True))
+
+        if streaming:
+            # response is an async iterator when streaming=True
+            async for event in response:
+                # Defensive: skip events with empty or missing choices
+                if not getattr(event, "choices", None):
+                    continue
+                choice = event.choices[0]
+
+                if choice.delta.content:
+                    yield self.format_chunk(
+                        {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
+                    )
+
+                if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
+                    yield self.format_chunk(
+                        {
+                            "chunk_type": "content_delta",
+                            "data_type": "reasoning_content",
+                            "data": choice.delta.reasoning_content,
+                        }
+                    )

-            for tool_call in choice.delta.tool_calls or []:
-                tool_calls.setdefault(tool_call.index, []).append(tool_call)
+                for tool_call in choice.delta.tool_calls or []:
+                    tool_calls.setdefault(tool_call.index, []).append(tool_call)

-            if choice.finish_reason:
-                break
+                if choice.finish_reason:
+                    break

-        yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

-        for tool_deltas in tool_calls.values():
-            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
+            for tool_deltas in tool_calls.values():
+                yield self.format_chunk(
+                    {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
+                )

-            for tool_delta in tool_deltas:
-                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
+                for tool_delta in tool_deltas:
+                    yield self.format_chunk(
+                        {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
+                    )

-            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+                yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})

-        yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
+            yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})

-        # Skip remaining events as we don't have use for anything except the final usage payload
-        async for event in response:
-            _ = event
+            # Skip remaining events as we don't have use for anything except the final usage payload
+            async for event in response:
+                _ = event

-        if event.usage:
-            yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
+            if event.usage:
+                yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
+        else:
+            # Non-streaming provider response — convert to streaming-style events (excluding the initial
+            # message_start/content_start because we already emitted them above).
+            for ev in self._convert_non_streaming_to_streaming(response):
+                yield ev

         logger.debug("finished streaming response from model")
tests/strands/models/test_openai.py

Lines changed: 46 additions & 0 deletions
@@ -612,6 +612,52 @@ async def test_stream(openai_client, model_id, model, agenerator, alist):
     openai_client.chat.completions.create.assert_called_once_with(**expected_request)


+@pytest.mark.asyncio
+async def test_stream_respects_streaming_flag(openai_client, model_id, alist):
+    # Model configured to NOT stream
+    model = OpenAIModel(client_args={}, model_id=model_id, params={"max_tokens": 1}, streaming=False)
+
+    # Mock a non-streaming response object
+    mock_choice = unittest.mock.Mock()
+    mock_choice.finish_reason = "stop"
+    mock_choice.message = unittest.mock.Mock()
+    mock_choice.message.content = "non-stream result"
+    mock_response = unittest.mock.Mock()
+    mock_response.choices = [mock_choice]
+    mock_response.usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
+
+    openai_client.chat.completions.create = unittest.mock.AsyncMock(return_value=mock_response)
+
+    # Consume the generator and verify the events
+    response_gen = model.stream([{"role": "user", "content": [{"text": "hi"}]}])
+    tru_events = await alist(response_gen)
+
+    expected_request = {
+        "max_tokens": 1,
+        "model": model_id,
+        "messages": [{"role": "user", "content": [{"text": "hi", "type": "text"}]}],
+        "stream": False,
+        "stream_options": {"include_usage": True},
+        "tools": [],
+    }
+    openai_client.chat.completions.create.assert_called_once_with(**expected_request)
+
+    exp_events = [
+        {"messageStart": {"role": "assistant"}},
+        {"contentBlockStart": {"start": {}}},
+        {"contentBlockDelta": {"delta": {"text": "non-stream result"}}},
+        {"contentBlockStop": {}},
+        {"messageStop": {"stopReason": "end_turn"}},
+        {
+            "metadata": {
+                "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
+                "metrics": {"latencyMs": 0},
+            }
+        },
+    ]
+    assert tru_events == exp_events
+
+
 @pytest.mark.asyncio
 async def test_stream_empty(openai_client, model_id, model, agenerator, alist):
     mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)
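
A hedged end-to-end sketch of consuming the generator (import path inferred from the file paths in this commit; credentials and model id are placeholders): with streaming=False the calling code does not change, because stream() still yields the same formatted events, now synthesized from the single response.

import asyncio

from strands.models.openai import OpenAIModel


async def main() -> None:
    model = OpenAIModel(client_args={"api_key": "..."}, model_id="gpt-4o", streaming=False)
    # Same call shape as the test above; events arrive as formatted chunks either way.
    async for event in model.stream([{"role": "user", "content": [{"text": "hi"}]}]):
        print(event)


asyncio.run(main())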
