pydantic_ai_slim/pydantic_ai/_parts_manager.py (245 changes: 218 additions & 27 deletions)
@@ -13,7 +13,7 @@

from __future__ import annotations as _annotations

from collections.abc import Hashable
from collections.abc import Generator, Hashable
from dataclasses import dataclass, field, replace
from typing import Any

@@ -58,6 +58,8 @@ class ModelResponsePartsManager:
"""A list of parts (text or tool calls) that make up the current state of the model's response."""
_vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
"""Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""
_thinking_tag_buffer: dict[VendorId, str] = field(default_factory=dict, init=False)
"""Buffers partial content when thinking tags might be split across chunks."""

def get_parts(self) -> list[ModelResponsePart]:
"""Return only model response parts that are complete (i.e., not ToolCallPartDelta's).
@@ -67,6 +69,28 @@ def get_parts(self) -> list[ModelResponsePart]:
"""
return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]

def finalize(self) -> Generator[ModelResponseStreamEvent, None, None]:
"""Flush any buffered content as text parts.

This should be called when streaming is complete to ensure no content is lost.
Any content buffered in _thinking_tag_buffer that hasn't been processed will be
treated as regular text and emitted.

Yields:
ModelResponseStreamEvent for any buffered content that gets flushed.
"""
for vendor_part_id, buffered_content in list(self._thinking_tag_buffer.items()):
if buffered_content:
yield from self._handle_text_delta_simple(
vendor_part_id=vendor_part_id,
content=buffered_content,
id=None,
thinking_tags=None,
ignore_leading_whitespace=False,
)

self._thinking_tag_buffer.clear()
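
A minimal sketch of the buffer-and-flush contract introduced here, assuming the manager is importable from the private `pydantic_ai._parts_manager` module and that `('<think>', '</think>')` are the configured thinking tags; the vendor id and content are illustrative:

```python
from pydantic_ai._parts_manager import ModelResponsePartsManager

manager = ModelResponsePartsManager()
tags = ('<think>', '</think>')

# Plain text creates a TextPart and yields one event.
events = list(manager.handle_text_delta(vendor_part_id='c', content='almost done ', thinking_tags=tags))

# A chunk that is only a prefix of the start tag is buffered: no event is emitted yet.
events += list(manager.handle_text_delta(vendor_part_id='c', content='<thi', thinking_tags=tags))

# The stream ends before the tag completes; finalize() flushes the buffered
# '<thi' back out as ordinary text so nothing is lost.
events += list(manager.finalize())

print([type(e).__name__ for e in events])  # expected: ['PartStartEvent', 'PartDeltaEvent']
```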

def handle_text_delta(
self,
*,
@@ -75,82 +99,249 @@ def handle_text_delta(
id: str | None = None,
thinking_tags: tuple[str, str] | None = None,
ignore_leading_whitespace: bool = False,
) -> ModelResponseStreamEvent | None:
) -> Generator[ModelResponseStreamEvent, None, None]:
"""Handle incoming text content, creating or updating a TextPart in the manager as appropriate.

When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
otherwise, a new TextPart is created. When a non-None ID is specified, the TextPart corresponding
to that vendor ID is either created or updated.

Thinking tags may be split across multiple chunks. When `thinking_tags` is provided and
`vendor_part_id` is not None, this method buffers content that could be the start of a
thinking tag appearing at the beginning of the current chunk.

Args:
vendor_part_id: The ID the vendor uses to identify this piece
of text. If None, a new part will be created unless the latest part is already
a TextPart.
content: The text content to append to the appropriate TextPart.
id: An optional id for the text part.
thinking_tags: If provided, will handle content between the thinking tags as thinking parts.
Buffering for split tags requires a non-None vendor_part_id.
ignore_leading_whitespace: If True, will ignore leading whitespace in the content.

Returns:
- A `PartStartEvent` if a new part was created.
- A `PartDeltaEvent` if an existing part was updated.
- `None` if no new event is emitted (e.g., the first text part was all whitespace).
Yields:
- `PartStartEvent` if a new part was created.
- `PartDeltaEvent` if an existing part was updated.
May yield multiple events from a single call if buffered content is flushed.

Raises:
UnexpectedModelBehavior: If attempting to apply text content to a part that is not a TextPart.
"""
if thinking_tags and vendor_part_id is not None:
yield from self._handle_text_delta_with_thinking_tags(
vendor_part_id=vendor_part_id,
content=content,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)
else:
yield from self._handle_text_delta_simple(
vendor_part_id=vendor_part_id,
content=content,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)
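
Since `handle_text_delta` is now a generator rather than returning `ModelResponseStreamEvent | None`, callers drain it with a loop; a single chunk may legitimately yield zero, one, or several events. A hedged sketch of the new calling pattern (vendor ids and content are invented, import path assumed):

```python
from pydantic_ai._parts_manager import ModelResponsePartsManager

manager = ModelResponsePartsManager()

# Without thinking tags this behaves as before: one event per text chunk.
one = list(manager.handle_text_delta(vendor_part_id='a', content='hello'))

# With thinking tags, one chunk containing text plus a start tag yields two
# events: a TextPart start followed by an (initially empty) ThinkingPart start.
two = list(
    manager.handle_text_delta(
        vendor_part_id='b',
        content='Hello <think>',
        thinking_tags=('<think>', '</think>'),
    )
)
assert len(one) == 1 and len(two) == 2
```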

def _handle_text_delta_simple( # noqa: C901
self,
*,
vendor_part_id: VendorId | None,
content: str,
id: str | None,
thinking_tags: tuple[str, str] | None,
ignore_leading_whitespace: bool,
) -> Generator[ModelResponseStreamEvent, None, None]:
"""Handle text delta without split tag buffering (original logic)."""
existing_text_part_and_index: tuple[TextPart, int] | None = None

if vendor_part_id is None:
# If the vendor_part_id is None, check if the latest part is a TextPart to update
if self._parts:
part_index = len(self._parts) - 1
latest_part = self._parts[part_index]
if isinstance(latest_part, TextPart):
if isinstance(latest_part, ThinkingPart):
# If there's an existing ThinkingPart and no thinking tags, add content to it
# This handles the case where vendor_part_id=None with trailing content after start tag
yield self.handle_thinking_delta(vendor_part_id=None, content=content)
return
elif isinstance(latest_part, TextPart):
existing_text_part_and_index = latest_part, part_index
else:
# Otherwise, attempt to look up an existing TextPart by vendor_part_id
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
if part_index is not None:
existing_part = self._parts[part_index]

if thinking_tags and isinstance(existing_part, ThinkingPart):
# We may be building a thinking part instead of a text part if we had previously seen a thinking tag
if content == thinking_tags[1]:
# When we see the thinking end tag, we're done with the thinking part and the next text delta will need a new part
self._vendor_id_to_part_index.pop(vendor_part_id)
return None
else:
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content)
if thinking_tags and isinstance(existing_part, ThinkingPart): # pragma: no cover
end_tag = thinking_tags[1] # pragma: no cover
if end_tag in content: # pragma: no cover
before_end, after_end = content.split(end_tag, 1) # pragma: no cover

if before_end: # pragma: no cover
yield self.handle_thinking_delta( # pragma: no cover
vendor_part_id=vendor_part_id, content=before_end
)

self._vendor_id_to_part_index.pop(vendor_part_id) # pragma: no cover

if after_end: # pragma: no cover
yield from self._handle_text_delta_simple( # pragma: no cover
vendor_part_id=vendor_part_id,
content=after_end,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)
return # pragma: no cover

if content == end_tag: # pragma: no cover
self._vendor_id_to_part_index.pop(vendor_part_id) # pragma: no cover
return # pragma: no cover

yield self.handle_thinking_delta( # pragma: no cover
vendor_part_id=vendor_part_id, content=content
)
return # pragma: no cover
elif isinstance(existing_part, TextPart):
existing_text_part_and_index = existing_part, part_index
else:
raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')

if thinking_tags and content == thinking_tags[0]:
# When we see a thinking start tag (which is a single token), we'll build a new thinking part instead
if thinking_tags and thinking_tags[0] in content:
start_tag = thinking_tags[0]
before_start, after_start = content.split(start_tag, 1)

if before_start:
yield from self._handle_text_delta_simple(
vendor_part_id=vendor_part_id,
content=before_start,
id=id,
thinking_tags=None,
ignore_leading_whitespace=ignore_leading_whitespace,
)

self._vendor_id_to_part_index.pop(vendor_part_id, None)
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')
yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')

if after_start:
yield from self._handle_text_delta_simple(
vendor_part_id=vendor_part_id,
content=after_start,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)
return

if existing_text_part_and_index is None:
# This is a workaround for models that emit `<think>\n</think>\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3),
# which we don't want to end up treating as a final result when using `run_stream` with `str` a valid `output_type`.
if ignore_leading_whitespace and (len(content) == 0 or content.isspace()):
return None
return

# There is no existing text part that should be updated, so create a new one
new_part_index = len(self._parts)
part = TextPart(content=content, id=id)
if vendor_part_id is not None:
self._vendor_id_to_part_index[vendor_part_id] = new_part_index
self._parts.append(part)
return PartStartEvent(index=new_part_index, part=part)
yield PartStartEvent(index=new_part_index, part=part)
else:
# Update the existing TextPart with the new content delta
existing_text_part, part_index = existing_text_part_and_index
part_delta = TextPartDelta(content_delta=content)
self._parts[part_index] = part_delta.apply(existing_text_part)
return PartDeltaEvent(index=part_index, delta=part_delta)
yield PartDeltaEvent(index=part_index, delta=part_delta)
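
The `vendor_part_id=None` branch above now also routes trailing content after a start tag into the new `ThinkingPart`; a sketch under the assumption that `handle_thinking_delta(vendor_part_id=None, ...)` appends to the latest thinking part (that method is not shown in this hunk):

```python
from pydantic_ai._parts_manager import ModelResponsePartsManager

manager = ModelResponsePartsManager()

# A single chunk with no vendor id: text before the start tag becomes a
# TextPart, and the content after the tag flows into a new ThinkingPart.
events = list(
    manager.handle_text_delta(
        vendor_part_id=None,
        content='Hello <think>pondering',
        thinking_tags=('<think>', '</think>'),
    )
)
print(manager.get_parts())  # expected shape: TextPart('Hello '), ThinkingPart('pondering')
```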

def _handle_text_delta_with_thinking_tags(
self,
*,
vendor_part_id: VendorId,
content: str,
id: str | None,
thinking_tags: tuple[str, str],
ignore_leading_whitespace: bool,
) -> Generator[ModelResponseStreamEvent, None, None]:
"""Handle text delta with thinking tag detection and buffering for split tags."""
start_tag, end_tag = thinking_tags
buffered = self._thinking_tag_buffer.get(vendor_part_id, '')
combined_content = buffered + content

part_index = self._vendor_id_to_part_index.get(vendor_part_id)
existing_part = self._parts[part_index] if part_index is not None else None

if existing_part is not None and isinstance(existing_part, ThinkingPart):
if end_tag in combined_content:
before_end, after_end = combined_content.split(end_tag, 1)

if before_end:
yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=before_end)

self._vendor_id_to_part_index.pop(vendor_part_id)
self._thinking_tag_buffer.pop(vendor_part_id, None)

if after_end:
yield from self._handle_text_delta_with_thinking_tags(
vendor_part_id=vendor_part_id,
content=after_end,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)
return

if self._could_be_tag_start(combined_content, end_tag):
self._thinking_tag_buffer[vendor_part_id] = combined_content
return

self._thinking_tag_buffer.pop(vendor_part_id, None)
yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=combined_content)
return

if start_tag in combined_content:
before_start, after_start = combined_content.split(start_tag, 1)

if before_start:
yield from self._handle_text_delta_simple(
vendor_part_id=vendor_part_id,
content=before_start,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)

self._thinking_tag_buffer.pop(vendor_part_id, None)
self._vendor_id_to_part_index.pop(vendor_part_id, None)
yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')

if after_start:
yield from self._handle_text_delta_with_thinking_tags(
vendor_part_id=vendor_part_id,
content=after_start,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)
return

if content.startswith(start_tag[0]) and self._could_be_tag_start(combined_content, start_tag):
self._thinking_tag_buffer[vendor_part_id] = combined_content
return

self._thinking_tag_buffer.pop(vendor_part_id, None)
yield from self._handle_text_delta_simple(
vendor_part_id=vendor_part_id,
content=combined_content,
id=id,
thinking_tags=thinking_tags,
ignore_leading_whitespace=ignore_leading_whitespace,
)
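
A sketch of the split-tag case this method exists for: the start tag arrives across two chunks, the end tag arrives with trailing text, and no content is dropped (tags, vendor id, and chunk contents are illustrative):

```python
from pydantic_ai._parts_manager import ModelResponsePartsManager

manager = ModelResponsePartsManager()
tags = ('<think>', '</think>')

events = []
for chunk in ['<thi', 'nk>let me check', '</think>The answer is 4.']:
    events += list(
        manager.handle_text_delta(vendor_part_id='content', content=chunk, thinking_tags=tags)
    )

# '<thi' alone is buffered and yields nothing; the later chunks complete the
# thinking part and then start a regular text part.
print([type(e).__name__ for e in events])
print(manager.get_parts())  # expected: ThinkingPart('let me check'), TextPart('The answer is 4.')
```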

def _could_be_tag_start(self, content: str, tag: str) -> bool:
"""Check if content could be the start of a tag."""
# Defensive check for content that's already complete or longer than tag
# This occurs when buffered content + new chunk exceeds tag length
# Example: buffer='<think' + new='<' = '<think<' (7 chars) >= '<think>' (7 chars)
if len(content) >= len(tag):
return False
return tag.startswith(content)
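
For intuition, calling the private helper directly (purely illustrative):

```python
from pydantic_ai._parts_manager import ModelResponsePartsManager

m = ModelResponsePartsManager()

# True only for strict prefixes that are shorter than the tag itself.
assert m._could_be_tag_start('<', '<think>') is True
assert m._could_be_tag_start('<thi', '<think>') is True
assert m._could_be_tag_start('<think>', '<think>') is False  # already the full tag
assert m._could_be_tag_start('<think<', '<think>') is False  # same length, not a prefix
assert m._could_be_tag_start('h', '<think>') is False        # not a prefix at all
```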

def handle_thinking_delta(
self,
pydantic_ai_slim/pydantic_ai/models/__init__.py (4 changes: 4 additions & 0 deletions)
@@ -553,6 +553,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:

def get(self) -> ModelResponse:
"""Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
# Flush any buffered content before building response
for _ in self._parts_manager.finalize():
pass

return ModelResponse(
parts=self._parts_manager.get_parts(),
model_name=self.model_name,
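
The events yielded by `finalize()` are deliberately discarded in `get()` above: the flushed content is already recorded on the manager, so only `get_parts()` matters for the snapshot. The same effect, sketched against the manager directly (import path and tags assumed):

```python
from pydantic_ai._parts_manager import ModelResponsePartsManager

manager = ModelResponsePartsManager()
tags = ('<think>', '</think>')

list(manager.handle_text_delta(vendor_part_id='c', content='partial ', thinking_tags=tags))
list(manager.handle_text_delta(vendor_part_id='c', content='<th', thinking_tags=tags))  # buffered, no event

for _ in manager.finalize():  # events discarded; the parts are already stored on the manager
    pass

print(manager.get_parts())  # expected: one TextPart whose content is 'partial <th'
```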
pydantic_ai_slim/pydantic_ai/models/anthropic.py (14 changes: 6 additions & 8 deletions)
@@ -669,11 +669,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
elif isinstance(event, BetaRawContentBlockStartEvent):
current_block = event.content_block
if isinstance(current_block, BetaTextBlock) and current_block.text:
maybe_event = self._parts_manager.handle_text_delta(
for event_item in self._parts_manager.handle_text_delta(
vendor_part_id=event.index, content=current_block.text
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
):
yield event_item
elif isinstance(current_block, BetaThinkingBlock):
yield self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
@@ -715,11 +714,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:

elif isinstance(event, BetaRawContentBlockDeltaEvent):
if isinstance(event.delta, BetaTextDelta):
maybe_event = self._parts_manager.handle_text_delta(
for event_item in self._parts_manager.handle_text_delta(
vendor_part_id=event.index, content=event.delta.text
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
):
yield event_item
elif isinstance(event.delta, BetaThinkingDelta):
yield self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
pydantic_ai_slim/pydantic_ai/models/bedrock.py (5 changes: 2 additions & 3 deletions)
@@ -702,9 +702,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
provider_name=self.provider_name if signature else None,
)
if text := delta.get('text'):
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=text)
if maybe_event is not None: # pragma: no branch
yield maybe_event
for event in self._parts_manager.handle_text_delta(vendor_part_id=index, content=text):
yield event
if 'toolUse' in delta:
tool_use = delta['toolUse']
maybe_event = self._parts_manager.handle_tool_call_delta(
pydantic_ai_slim/pydantic_ai/models/function.py (5 changes: 2 additions & 3 deletions)
@@ -289,9 +289,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
if isinstance(item, str):
response_tokens = _estimate_string_tokens(item)
self._usage += usage.RequestUsage(output_tokens=response_tokens)
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
if maybe_event is not None: # pragma: no branch
yield maybe_event
for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=item):
yield event
elif isinstance(item, dict) and item:
for dtc_index, delta in item.items():
if isinstance(delta, DeltaThinkingPart):
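
The anthropic, bedrock, and function call sites above all converge on the same pattern; a hypothetical, generic adapter loop showing the shape (the chunk source, tags, and vendor id are invented, the `ModelResponseStreamEvent` import path is assumed, and `finalize()` is shown for completeness even though these hunks do not call it directly):

```python
from collections.abc import Iterable, Iterator

from pydantic_ai._parts_manager import ModelResponsePartsManager
from pydantic_ai.messages import ModelResponseStreamEvent


def stream_events(chunks: Iterable[str]) -> Iterator[ModelResponseStreamEvent]:
    """Hypothetical adapter: turn raw text chunks into part events."""
    manager = ModelResponsePartsManager()
    for chunk in chunks:
        # handle_text_delta() is a generator, so each call is simply drained; it may
        # yield nothing (a buffered tag prefix) or several events (split content).
        yield from manager.handle_text_delta(
            vendor_part_id='content',
            content=chunk,
            thinking_tags=('<think>', '</think>'),
        )
    # Flush anything still buffered once the provider stream is exhausted.
    yield from manager.finalize()
```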