diff --git a/pydantic_ai_slim/pydantic_ai/_parts_manager.py b/pydantic_ai_slim/pydantic_ai/_parts_manager.py index 41d6357994..15e27231ac 100644 --- a/pydantic_ai_slim/pydantic_ai/_parts_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_parts_manager.py @@ -13,7 +13,7 @@ from __future__ import annotations as _annotations -from collections.abc import Hashable +from collections.abc import Generator, Hashable from dataclasses import dataclass, field, replace from typing import Any @@ -58,6 +58,8 @@ class ModelResponsePartsManager: """A list of parts (text or tool calls) that make up the current state of the model's response.""" _vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False) """Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides.""" + _thinking_tag_buffer: dict[VendorId, str] = field(default_factory=dict, init=False) + """Buffers partial content when thinking tags might be split across chunks.""" def get_parts(self) -> list[ModelResponsePart]: """Return only model response parts that are complete (i.e., not ToolCallPartDelta's). @@ -67,6 +69,28 @@ def get_parts(self) -> list[ModelResponsePart]: """ return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)] + def finalize(self) -> Generator[ModelResponseStreamEvent, None, None]: + """Flush any buffered content as text parts. + + This should be called when streaming is complete to ensure no content is lost. + Any content buffered in _thinking_tag_buffer that hasn't been processed will be + treated as regular text and emitted. + + Yields: + ModelResponseStreamEvent for any buffered content that gets flushed. + """ + for vendor_part_id, buffered_content in list(self._thinking_tag_buffer.items()): + if buffered_content: + yield from self._handle_text_delta_simple( + vendor_part_id=vendor_part_id, + content=buffered_content, + id=None, + thinking_tags=None, + ignore_leading_whitespace=False, + ) + + self._thinking_tag_buffer.clear() + def handle_text_delta( self, *, @@ -75,13 +99,17 @@ def handle_text_delta( id: str | None = None, thinking_tags: tuple[str, str] | None = None, ignore_leading_whitespace: bool = False, - ) -> ModelResponseStreamEvent | None: + ) -> Generator[ModelResponseStreamEvent, None, None]: """Handle incoming text content, creating or updating a TextPart in the manager as appropriate. When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart; otherwise, a new TextPart is created. When a non-None ID is specified, the TextPart corresponding to that vendor ID is either created or updated. + Thinking tags may be split across multiple chunks. When `thinking_tags` is provided and + `vendor_part_id` is not None, this method buffers content that could be the start of a + thinking tag appearing at the beginning of the current chunk. + Args: vendor_part_id: The ID the vendor uses to identify this piece of text. If None, a new part will be created unless the latest part is already @@ -89,68 +117,231 @@ def handle_text_delta( content: The text content to append to the appropriate TextPart. id: An optional id for the text part. thinking_tags: If provided, will handle content between the thinking tags as thinking parts. + Buffering for split tags requires a non-None vendor_part_id. ignore_leading_whitespace: If True, will ignore leading whitespace in the content. - Returns: - - A `PartStartEvent` if a new part was created. - - A `PartDeltaEvent` if an existing part was updated. - - `None` if no new event is emitted (e.g., the first text part was all whitespace). + Yields: + - `PartStartEvent` if a new part was created. + - `PartDeltaEvent` if an existing part was updated. + May yield multiple events from a single call if buffered content is flushed. Raises: UnexpectedModelBehavior: If attempting to apply text content to a part that is not a TextPart. """ + if thinking_tags and vendor_part_id is not None: + yield from self._handle_text_delta_with_thinking_tags( + vendor_part_id=vendor_part_id, + content=content, + id=id, + thinking_tags=thinking_tags, + ignore_leading_whitespace=ignore_leading_whitespace, + ) + else: + yield from self._handle_text_delta_simple( + vendor_part_id=vendor_part_id, + content=content, + id=id, + thinking_tags=thinking_tags, + ignore_leading_whitespace=ignore_leading_whitespace, + ) + + def _handle_text_delta_simple( # noqa: C901 + self, + *, + vendor_part_id: VendorId | None, + content: str, + id: str | None, + thinking_tags: tuple[str, str] | None, + ignore_leading_whitespace: bool, + ) -> Generator[ModelResponseStreamEvent, None, None]: + """Handle text delta without split tag buffering (original logic).""" existing_text_part_and_index: tuple[TextPart, int] | None = None if vendor_part_id is None: - # If the vendor_part_id is None, check if the latest part is a TextPart to update if self._parts: part_index = len(self._parts) - 1 latest_part = self._parts[part_index] - if isinstance(latest_part, TextPart): + if isinstance(latest_part, ThinkingPart): + # If there's an existing ThinkingPart and no thinking tags, add content to it + # This handles the case where vendor_part_id=None with trailing content after start tag + yield self.handle_thinking_delta(vendor_part_id=None, content=content) + return + elif isinstance(latest_part, TextPart): existing_text_part_and_index = latest_part, part_index else: - # Otherwise, attempt to look up an existing TextPart by vendor_part_id part_index = self._vendor_id_to_part_index.get(vendor_part_id) if part_index is not None: existing_part = self._parts[part_index] - if thinking_tags and isinstance(existing_part, ThinkingPart): - # We may be building a thinking part instead of a text part if we had previously seen a thinking tag - if content == thinking_tags[1]: - # When we see the thinking end tag, we're done with the thinking part and the next text delta will need a new part - self._vendor_id_to_part_index.pop(vendor_part_id) - return None - else: - return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content) + if thinking_tags and isinstance(existing_part, ThinkingPart): # pragma: no cover + end_tag = thinking_tags[1] # pragma: no cover + if end_tag in content: # pragma: no cover + before_end, after_end = content.split(end_tag, 1) # pragma: no cover + + if before_end: # pragma: no cover + yield self.handle_thinking_delta( # pragma: no cover + vendor_part_id=vendor_part_id, content=before_end + ) + + self._vendor_id_to_part_index.pop(vendor_part_id) # pragma: no cover + + if after_end: # pragma: no cover + yield from self._handle_text_delta_simple( # pragma: no cover + vendor_part_id=vendor_part_id, + content=after_end, + id=id, + thinking_tags=thinking_tags, + ignore_leading_whitespace=ignore_leading_whitespace, + ) + return # pragma: no cover + + if content == end_tag: # pragma: no cover + self._vendor_id_to_part_index.pop(vendor_part_id) # pragma: no cover + return # pragma: no cover + + yield self.handle_thinking_delta( # pragma: no cover + vendor_part_id=vendor_part_id, content=content + ) + return # pragma: no cover elif isinstance(existing_part, TextPart): existing_text_part_and_index = existing_part, part_index else: raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}') - if thinking_tags and content == thinking_tags[0]: - # When we see a thinking start tag (which is a single token), we'll build a new thinking part instead + if thinking_tags and thinking_tags[0] in content: + start_tag = thinking_tags[0] + before_start, after_start = content.split(start_tag, 1) + + if before_start: + yield from self._handle_text_delta_simple( + vendor_part_id=vendor_part_id, + content=before_start, + id=id, + thinking_tags=None, + ignore_leading_whitespace=ignore_leading_whitespace, + ) + self._vendor_id_to_part_index.pop(vendor_part_id, None) - return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='') + yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='') + + if after_start: + yield from self._handle_text_delta_simple( + vendor_part_id=vendor_part_id, + content=after_start, + id=id, + thinking_tags=thinking_tags, + ignore_leading_whitespace=ignore_leading_whitespace, + ) + return if existing_text_part_and_index is None: - # This is a workaround for models that emit `\n\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3), - # which we don't want to end up treating as a final result when using `run_stream` with `str` a valid `output_type`. if ignore_leading_whitespace and (len(content) == 0 or content.isspace()): - return None + return - # There is no existing text part that should be updated, so create a new one new_part_index = len(self._parts) part = TextPart(content=content, id=id) if vendor_part_id is not None: self._vendor_id_to_part_index[vendor_part_id] = new_part_index self._parts.append(part) - return PartStartEvent(index=new_part_index, part=part) + yield PartStartEvent(index=new_part_index, part=part) else: - # Update the existing TextPart with the new content delta existing_text_part, part_index = existing_text_part_and_index part_delta = TextPartDelta(content_delta=content) self._parts[part_index] = part_delta.apply(existing_text_part) - return PartDeltaEvent(index=part_index, delta=part_delta) + yield PartDeltaEvent(index=part_index, delta=part_delta) + + def _handle_text_delta_with_thinking_tags( + self, + *, + vendor_part_id: VendorId, + content: str, + id: str | None, + thinking_tags: tuple[str, str], + ignore_leading_whitespace: bool, + ) -> Generator[ModelResponseStreamEvent, None, None]: + """Handle text delta with thinking tag detection and buffering for split tags.""" + start_tag, end_tag = thinking_tags + buffered = self._thinking_tag_buffer.get(vendor_part_id, '') + combined_content = buffered + content + + part_index = self._vendor_id_to_part_index.get(vendor_part_id) + existing_part = self._parts[part_index] if part_index is not None else None + + if existing_part is not None and isinstance(existing_part, ThinkingPart): + if end_tag in combined_content: + before_end, after_end = combined_content.split(end_tag, 1) + + if before_end: + yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=before_end) + + self._vendor_id_to_part_index.pop(vendor_part_id) + self._thinking_tag_buffer.pop(vendor_part_id, None) + + if after_end: + yield from self._handle_text_delta_with_thinking_tags( + vendor_part_id=vendor_part_id, + content=after_end, + id=id, + thinking_tags=thinking_tags, + ignore_leading_whitespace=ignore_leading_whitespace, + ) + return + + if self._could_be_tag_start(combined_content, end_tag): + self._thinking_tag_buffer[vendor_part_id] = combined_content + return + + self._thinking_tag_buffer.pop(vendor_part_id, None) + yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=combined_content) + return + + if start_tag in combined_content: + before_start, after_start = combined_content.split(start_tag, 1) + + if before_start: + yield from self._handle_text_delta_simple( + vendor_part_id=vendor_part_id, + content=before_start, + id=id, + thinking_tags=thinking_tags, + ignore_leading_whitespace=ignore_leading_whitespace, + ) + + self._thinking_tag_buffer.pop(vendor_part_id, None) + self._vendor_id_to_part_index.pop(vendor_part_id, None) + yield self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='') + + if after_start: + yield from self._handle_text_delta_with_thinking_tags( + vendor_part_id=vendor_part_id, + content=after_start, + id=id, + thinking_tags=thinking_tags, + ignore_leading_whitespace=ignore_leading_whitespace, + ) + return + + if content.startswith(start_tag[0]) and self._could_be_tag_start(combined_content, start_tag): + self._thinking_tag_buffer[vendor_part_id] = combined_content + return + + self._thinking_tag_buffer.pop(vendor_part_id, None) + yield from self._handle_text_delta_simple( + vendor_part_id=vendor_part_id, + content=combined_content, + id=id, + thinking_tags=thinking_tags, + ignore_leading_whitespace=ignore_leading_whitespace, + ) + + def _could_be_tag_start(self, content: str, tag: str) -> bool: + """Check if content could be the start of a tag.""" + # Defensive check for content that's already complete or longer than tag + # This occurs when buffered content + new chunk exceeds tag length + # Example: buffer='= '' (7 chars) + if len(content) >= len(tag): + return False + return tag.startswith(content) def handle_thinking_delta( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index c9c7f6bc40..6bcfe821d1 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -553,6 +553,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: def get(self) -> ModelResponse: """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far.""" + # Flush any buffered content before building response + for _ in self._parts_manager.finalize(): + pass + return ModelResponse( parts=self._parts_manager.get_parts(), model_name=self.model_name, diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 1511f724bc..1c014d3833 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -669,11 +669,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: elif isinstance(event, BetaRawContentBlockStartEvent): current_block = event.content_block if isinstance(current_block, BetaTextBlock) and current_block.text: - maybe_event = self._parts_manager.handle_text_delta( + for event_item in self._parts_manager.handle_text_delta( vendor_part_id=event.index, content=current_block.text - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event_item elif isinstance(current_block, BetaThinkingBlock): yield self._parts_manager.handle_thinking_delta( vendor_part_id=event.index, @@ -715,11 +714,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: elif isinstance(event, BetaRawContentBlockDeltaEvent): if isinstance(event.delta, BetaTextDelta): - maybe_event = self._parts_manager.handle_text_delta( + for event_item in self._parts_manager.handle_text_delta( vendor_part_id=event.index, content=event.delta.text - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event_item elif isinstance(event.delta, BetaThinkingDelta): yield self._parts_manager.handle_thinking_delta( vendor_part_id=event.index, diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 0584158f1d..ecbe94c12f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -702,9 +702,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: provider_name=self.provider_name if signature else None, ) if text := delta.get('text'): - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=text) - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id=index, content=text): + yield event if 'toolUse' in delta: tool_use = delta['toolUse'] maybe_event = self._parts_manager.handle_tool_call_delta( diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 405c088f7d..5db948db31 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -289,9 +289,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if isinstance(item, str): response_tokens = _estimate_string_tokens(item) self._usage += usage.RequestUsage(output_tokens=response_tokens) - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item) - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=item): + yield event elif isinstance(item, dict) and item: for dtc_index, delta in item.items(): if isinstance(delta, DeltaThinkingPart): diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 981ef29ef6..500e6c76e3 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -461,11 +461,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if 'text' in gemini_part: # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled # amongst the tool call deltas - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id=None, content=gemini_part['text'] - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event elif 'function_call' in gemini_part: # Here, we assume all function_call parts are complete and don't have deltas. diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index e0a9197449..42b4c9d5be 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -666,9 +666,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if part.thought: yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text) else: - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text) - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text): + yield event elif part.function_call: maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=uuid4(), diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index ad43b69110..b6d6f343ba 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -564,14 +564,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content: - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id='content', content=content, thinking_tags=self._model_profile.thinking_tags, ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event # Handle the tool calls for dtc in choice.delta.tool_calls or []: diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index a71edf7026..48c3785ddc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -483,14 +483,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content: - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id='content', content=content, thinking_tags=self._model_profile.thinking_tags, ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event for dtc in choice.delta.tool_calls or []: maybe_event = self._parts_manager.handle_tool_call_delta( diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 90265bbe53..daacac985a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -653,9 +653,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: tool_call_id=maybe_tool_call_part.tool_call_id, ) else: - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=text) - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=text): + yield event # Handle the explicit tool calls for index, dtc in enumerate(choice.delta.tool_calls or []): diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 84c566c23f..2251df065c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -1657,17 +1657,16 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content: - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id='content', content=content, thinking_tags=self._model_profile.thinking_tags, ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace, - ) - if maybe_event is not None: # pragma: no branch - if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart): - maybe_event.part.id = 'content' - maybe_event.part.provider_name = self.provider_name - yield maybe_event + ): + if isinstance(event, PartStartEvent) and isinstance(event.part, ThinkingPart): + event.part.id = 'content' + event.part.provider_name = self.provider_name + yield event for dtc in choice.delta.tool_calls or []: maybe_event = self._parts_manager.handle_tool_call_delta( @@ -1852,11 +1851,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: pass # there's nothing we need to do here elif isinstance(chunk, responses.ResponseTextDeltaEvent): - maybe_event = self._parts_manager.handle_text_delta( + for event in self._parts_manager.handle_text_delta( vendor_part_id=chunk.item_id, content=chunk.delta, id=chunk.item_id - ) - if maybe_event is not None: # pragma: no branch - yield maybe_event + ): + yield event elif isinstance(chunk, responses.ResponseTextDoneEvent): pass # there's nothing we need to do here diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 6b772365ba..5b9dbaa26a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -310,14 +310,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: mid = len(text) // 2 words = [text[:mid], text[mid:]] self._usage += _get_string_usage('') - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content='') - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id=i, content=''): + yield event for word in words: self._usage += _get_string_usage(word) - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content=word) - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id=i, content=word): + yield event elif isinstance(part, ToolCallPart): yield self._parts_manager.handle_tool_call_part( vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id diff --git a/pyproject.toml b/pyproject.toml index b09a172045..f264c80a76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -311,4 +311,4 @@ skip = '.git*,*.svg,*.lock,*.css,*.yaml' check-hidden = true # Ignore "formatting" like **L**anguage ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b' -ignore-words-list = 'asend,aci' +ignore-words-list = 'asend,aci,thi' diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index 0183634e59..f933a4aa12 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -116,12 +116,10 @@ async def request_stream( class MyResponseStream(StreamedResponse): async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: self._usage = RequestUsage(input_tokens=300, output_tokens=400) - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1') - if maybe_event is not None: # pragma: no branch - yield maybe_event - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=0, content='text2') - if maybe_event is not None: # pragma: no branch - yield maybe_event + for event in self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1'): + yield event + for event in self._parts_manager.handle_text_delta(vendor_part_id=0, content='text2'): + yield event @property def model_name(self) -> str: diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index d73e8579c3..bf756ddd9f 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -342,3 +342,24 @@ def test_different_content_input(content: AudioUrl | VideoUrl | ImageUrl | Binar result = agent.run_sync(['x', content], model=TestModel(custom_output_text='custom')) assert result.output == snapshot('custom') assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=51, output_tokens=1)) + + +@pytest.mark.anyio +async def test_finalize_integration_buffered_content(): + """Integration test: StreamedResponse.get() calls finalize() without breaking. + + Note: TestModel doesn't pass thinking_tags during streaming, so this doesn't actually + test buffering behavior - it just verifies that calling get() works correctly. + The actual buffering logic is thoroughly tested in test_parts_manager_split_tags.py, + and normal streaming is tested extensively in test_streaming.py. + """ + test_model = TestModel(custom_output_text='Hello ', '') - event = manager.handle_text_delta(vendor_part_id='content', content='pre-', thinking_tags=thinking_tags) - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='content', content='pre-', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( PartStartEvent(index=0, part=TextPart(content='pre-', part_kind='text'), event_kind='part_start') ) assert manager.get_parts() == snapshot([TextPart(content='pre-', part_kind='text')]) - event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags) - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( PartDeltaEvent( index=0, delta=TextPartDelta(content_delta='thinking', part_delta_kind='text'), event_kind='part_delta' ) ) assert manager.get_parts() == snapshot([TextPart(content='pre-thinking', part_kind='text')]) - event = manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags) - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( PartStartEvent(index=1, part=ThinkingPart(content='', part_kind='thinking'), event_kind='part_start') ) assert manager.get_parts() == snapshot( [TextPart(content='pre-thinking', part_kind='text'), ThinkingPart(content='', part_kind='thinking')] ) - event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags) - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( PartDeltaEvent( index=1, delta=ThinkingPartDelta(content_delta='thinking', part_delta_kind='thinking'), @@ -119,8 +129,9 @@ def test_handle_text_deltas_with_think_tags(): [TextPart(content='pre-thinking', part_kind='text'), ThinkingPart(content='thinking', part_kind='thinking')] ) - event = manager.handle_text_delta(vendor_part_id='content', content=' more', thinking_tags=thinking_tags) - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='content', content=' more', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( PartDeltaEvent( index=1, delta=ThinkingPartDelta(content_delta=' more', part_delta_kind='thinking'), event_kind='part_delta' ) @@ -132,11 +143,12 @@ def test_handle_text_deltas_with_think_tags(): ] ) - event = manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags) - assert event is None + events = list(manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags)) + assert len(events) == 0 - event = manager.handle_text_delta(vendor_part_id='content', content='post-', thinking_tags=thinking_tags) - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='content', content='post-', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( PartStartEvent(index=2, part=TextPart(content='post-', part_kind='text'), event_kind='part_start') ) assert manager.get_parts() == snapshot( @@ -147,8 +159,9 @@ def test_handle_text_deltas_with_think_tags(): ] ) - event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags) - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( PartDeltaEvent( index=2, delta=TextPartDelta(content_delta='thinking', part_delta_kind='text'), event_kind='part_delta' ) @@ -376,8 +389,9 @@ def test_handle_tool_call_part(): def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | None, tool_vendor_part_id: str | None): manager = ModelResponsePartsManager() - event = manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='hello ') - assert event == snapshot( + events = list(manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='hello ')) + assert len(events) == 1 + assert events[0] == snapshot( PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') ) assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) @@ -393,9 +407,10 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non ) ) - event = manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='world') + events = list(manager.handle_text_delta(vendor_part_id=text_vendor_part_id, content='world')) + assert len(events) == 1 if text_vendor_part_id is None: - assert event == snapshot( + assert events[0] == snapshot( PartStartEvent( index=2, part=TextPart(content='world', part_kind='text'), @@ -410,7 +425,7 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non ] ) else: - assert event == snapshot( + assert events[0] == snapshot( PartDeltaEvent( index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' ) @@ -425,7 +440,7 @@ def test_handle_mixed_deltas_without_text_part_id(text_vendor_part_id: str | Non def test_cannot_convert_from_text_to_tool_call(): manager = ModelResponsePartsManager() - manager.handle_text_delta(vendor_part_id=1, content='hello') + list(manager.handle_text_delta(vendor_part_id=1, content='hello')) with pytest.raises( UnexpectedModelBehavior, match=re.escape('Cannot apply a tool call delta to existing_part=TextPart(') ): @@ -438,7 +453,7 @@ def test_cannot_convert_from_tool_call_to_text(): with pytest.raises( UnexpectedModelBehavior, match=re.escape('Cannot apply a text delta to existing_part=ToolCallPart(') ): - manager.handle_text_delta(vendor_part_id=1, content='hello') + list(manager.handle_text_delta(vendor_part_id=1, content='hello')) def test_tool_call_id_delta(): @@ -546,7 +561,7 @@ def test_handle_thinking_delta_wrong_part_type(): manager = ModelResponsePartsManager() # Add a text part first - manager.handle_text_delta(vendor_part_id='text', content='hello') + list(manager.handle_text_delta(vendor_part_id='text', content='hello')) # Try to apply thinking delta to the text part - should raise error with pytest.raises(UnexpectedModelBehavior, match=r'Cannot apply a thinking delta to existing_part='): diff --git a/tests/test_parts_manager_split_tags.py b/tests/test_parts_manager_split_tags.py new file mode 100644 index 0000000000..01a425f104 --- /dev/null +++ b/tests/test_parts_manager_split_tags.py @@ -0,0 +1,573 @@ +"""Tests for split thinking tag handling in ModelResponsePartsManager.""" + +from inline_snapshot import snapshot + +from pydantic_ai._parts_manager import ModelResponsePartsManager +from pydantic_ai.messages import ( + PartDeltaEvent, + PartStartEvent, + TextPart, + TextPartDelta, + ThinkingPart, + ThinkingPartDelta, +) + + +def test_handle_text_deltas_with_split_think_tags_at_chunk_start(): + """Test split thinking tags when tag starts at position 0 of chunk.""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Chunk 1: "" - completes the tag + events = list(manager.handle_text_delta(vendor_part_id='content', content='nk>', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( + PartStartEvent(index=0, part=ThinkingPart(content='', part_kind='thinking'), event_kind='part_start') + ) + assert manager.get_parts() == snapshot([ThinkingPart(content='', part_kind='thinking')]) + + # Chunk 3: "reasoning content" + events = list( + manager.handle_text_delta(vendor_part_id='content', content='reasoning content', thinking_tags=thinking_tags) + ) + assert len(events) == 1 + assert events[0] == snapshot( + PartDeltaEvent( + index=0, + delta=ThinkingPartDelta(content_delta='reasoning content', part_delta_kind='thinking'), + event_kind='part_delta', + ) + ) + + # Chunk 4: "" - end tag + events = list(manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags)) + assert len(events) == 0 + + # Chunk 5: "after" - text after thinking + events = list(manager.handle_text_delta(vendor_part_id='content', content='after', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( + PartStartEvent(index=1, part=TextPart(content='after', part_kind='text'), event_kind='part_start') + ) + + +def test_handle_text_deltas_split_tags_after_text(): + """Test split thinking tags at chunk position 0 after text in previous chunk.""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Chunk 1: "pre-" - creates TextPart + events = list(manager.handle_text_delta(vendor_part_id='content', content='pre-', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( + PartStartEvent(index=0, part=TextPart(content='pre-', part_kind='text'), event_kind='part_start') + ) + + # Chunk 2: "" - completes the tag + events = list(manager.handle_text_delta(vendor_part_id='content', content='nk>', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( + PartStartEvent(index=1, part=ThinkingPart(content='', part_kind='thinking'), event_kind='part_start') + ) + assert manager.get_parts() == snapshot( + [TextPart(content='pre-', part_kind='text'), ThinkingPart(content='', part_kind='thinking')] + ) + + +def test_handle_text_deltas_split_tags_mid_chunk_treated_as_text(): + """Test that split tags mid-chunk (after other content in same chunk) are treated as text.""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Chunk 1: "pre-" - appends to text (not recognized as completing a tag) + events = list(manager.handle_text_delta(vendor_part_id='content', content='nk>', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( + PartDeltaEvent( + index=0, delta=TextPartDelta(content_delta='nk>', part_delta_kind='text'), event_kind='part_delta' + ) + ) + assert manager.get_parts() == snapshot([TextPart(content='pre-', part_kind='text')]) + + +def test_handle_text_deltas_split_tags_no_vendor_id(): + """Test that split tags don't work with vendor_part_id=None (no buffering).""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Chunk 1: "" - appends to text + events = list(manager.handle_text_delta(vendor_part_id=None, content='nk>', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( + PartDeltaEvent( + index=0, delta=TextPartDelta(content_delta='nk>', part_delta_kind='text'), event_kind='part_delta' + ) + ) + assert manager.get_parts() == snapshot([TextPart(content='', part_kind='text')]) + + +def test_handle_text_deltas_false_start_then_real_tag(): + """Test buffering a false start, then processing real content.""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Chunk 1: "', '') + + # To hit line 231, we need: + # 1. Buffer some content + # 2. Next chunk starts with '<' (to pass first check) + # 3. Combined length >= tag length + + # First chunk: exactly 6 chars + events = list(manager.handle_text_delta(vendor_part_id='content', content='' (7 chars) + events = list(manager.handle_text_delta(vendor_part_id='content', content='<', thinking_tags=thinking_tags)) + # 7 >= 7 is True, so line 231 returns False + assert len(events) == 1 + assert events[0] == snapshot( + PartStartEvent(index=0, part=TextPart(content='', '') + + # Complete start tag with vendor_part_id=None goes through simple path + # This covers lines 161-164 in _handle_text_delta_simple + events = list(manager.handle_text_delta(vendor_part_id=None, content='', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( + PartStartEvent(index=0, part=ThinkingPart(content='', part_kind='thinking'), event_kind='part_start') + ) + assert manager.get_parts() == snapshot([ThinkingPart(content='', part_kind='thinking')]) + + +def test_exact_tag_length_boundary(): + """Test when buffered content exactly equals tag length.""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Send content in one chunk that's exactly tag length + events = list(manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags)) + # Exact match creates ThinkingPart + assert len(events) == 1 + assert events[0] == snapshot( + PartStartEvent(index=0, part=ThinkingPart(content='', part_kind='thinking'), event_kind='part_start') + ) + + +def test_buffered_content_flushed_on_finalize(): + """Test that buffered content is flushed when finalize is called.""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Buffer partial tag + events = list(manager.handle_text_delta(vendor_part_id='content', content='', '') + + # Buffer for vendor_id_1 + list(manager.handle_text_delta(vendor_part_id='id1', content='82 branch).""" + manager = ModelResponsePartsManager() + # Add both empty and non-empty content to test the branch where buffered_content is falsy + # This ensures the loop continues after skipping the empty content + manager._thinking_tag_buffer['id1'] = '' # Will be skipped # pyright: ignore[reportPrivateUsage] + manager._thinking_tag_buffer['id2'] = 'content' # Will be flushed # pyright: ignore[reportPrivateUsage] + events = list(manager.finalize()) + assert len(events) == 1 # Only non-empty content produces events + assert isinstance(events[0], PartStartEvent) + assert events[0].part == TextPart(content='content') + assert manager._thinking_tag_buffer == {} # Buffer should be cleared # pyright: ignore[reportPrivateUsage] + + +def test_get_parts_after_finalize(): + """Test that get_parts returns flushed content after finalize (unit test).""" + # NOTE: This is a unit test of the manager. Real integration testing with + # StreamedResponse is done in test_finalize_integration(). + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + list(manager.handle_text_delta(vendor_part_id='content', content='', '') + + # Start thinking + events = list(manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert isinstance(events[0], PartStartEvent) + assert isinstance(events[0].part, ThinkingPart) + + # Add thinking content + events = list(manager.handle_text_delta(vendor_part_id='content', content='reasoning', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert events[0] == snapshot( + PartDeltaEvent( + index=0, + delta=ThinkingPartDelta(content_delta='reasoning', part_delta_kind='thinking'), + event_kind='part_delta', + ) + ) + + # End tag with trailing text in same chunk + events = list( + manager.handle_text_delta(vendor_part_id='content', content='post-text', thinking_tags=thinking_tags) + ) + + # Should emit event for new TextPart + assert len(events) == 1 + assert isinstance(events[0], PartStartEvent) + assert events[0].part == TextPart(content='post-text') + + # Final state + assert manager.get_parts() == snapshot( + [ThinkingPart(content='reasoning', part_kind='thinking'), TextPart(content='post-text', part_kind='text')] + ) + + +def test_split_end_tag_with_trailing_text(): + """Test split end tag with text after it.""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Start thinking (tag at position 0) + events = list(manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert isinstance(events[0], PartStartEvent) + assert isinstance(events[0].part, ThinkingPart) + + # Add thinking content + events = list(manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert isinstance(events[0], PartDeltaEvent) + + # Split end tag: "post" + events = list(manager.handle_text_delta(vendor_part_id='content', content='nk>post', thinking_tags=thinking_tags)) + + # Should close thinking and start text part + assert len(events) == 1 + assert isinstance(events[0], PartStartEvent) + assert events[0].part == TextPart(content='post') + + assert manager.get_parts() == snapshot( + [ThinkingPart(content='thinking', part_kind='thinking'), TextPart(content='post', part_kind='text')] + ) + + +def test_thinking_content_before_end_tag_with_trailing(): + """Test thinking content before end tag, with trailing text in same chunk.""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Start thinking + events = list(manager.handle_text_delta(vendor_part_id='content', content='', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert isinstance(events[0], PartStartEvent) + assert isinstance(events[0].part, ThinkingPart) + + # Send content + end tag + trailing all in one chunk + events = list( + manager.handle_text_delta( + vendor_part_id='content', content='reasoningafter', thinking_tags=thinking_tags + ) + ) + + # Should emit thinking delta event, then text start event + assert len(events) == 2 + assert isinstance(events[0], PartDeltaEvent) + assert events[0].delta == ThinkingPartDelta(content_delta='reasoning') + assert isinstance(events[1], PartStartEvent) + assert events[1].part == TextPart(content='after') + + assert manager.get_parts() == snapshot( + [ThinkingPart(content='reasoning', part_kind='thinking'), TextPart(content='after', part_kind='text')] + ) + + +# Issue 3b: START tags with trailing content +# These tests document the broken behavior where start tags with trailing content +# in the same chunk are not handled correctly. + + +def test_start_tag_with_trailing_content_same_chunk(): + """Test that content after start tag in same chunk is handled correctly.""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Start tag with trailing content in same chunk + events = list( + manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags) + ) + + # Should emit event for new ThinkingPart, then delta for content + assert len(events) == 2 + assert isinstance(events[0], PartStartEvent) + assert isinstance(events[0].part, ThinkingPart) + assert isinstance(events[1], PartDeltaEvent) + assert events[1].delta == ThinkingPartDelta(content_delta='thinking') + + # Final state + assert manager.get_parts() == snapshot([ThinkingPart(content='thinking', part_kind='thinking')]) + + +def test_split_start_tag_with_trailing_content(): + """Test split start tag with content after it.""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Split start tag: "content" + events = list( + manager.handle_text_delta(vendor_part_id='content', content='nk>content', thinking_tags=thinking_tags) + ) + + # Should create ThinkingPart and add content + assert len(events) == 2 + assert isinstance(events[0], PartStartEvent) + assert isinstance(events[0].part, ThinkingPart) + assert isinstance(events[1], PartDeltaEvent) + assert events[1].delta == ThinkingPartDelta(content_delta='content') + + assert manager.get_parts() == snapshot([ThinkingPart(content='content', part_kind='thinking')]) + + +def test_complete_sequence_start_tag_with_inline_content(): + """Test complete sequence: start tag with inline content and end tag.""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # All in one chunk: "contentafter" + events = list( + manager.handle_text_delta( + vendor_part_id='content', content='contentafter', thinking_tags=thinking_tags + ) + ) + + # Should create ThinkingPart with content, then TextPart + # Exact event count may vary based on implementation + assert len(events) >= 2 + + # Final state should have both parts + assert manager.get_parts() == snapshot( + [ThinkingPart(content='content', part_kind='thinking'), TextPart(content='after', part_kind='text')] + ) + + +def test_text_then_start_tag_with_content(): + """Test text part followed by start tag with content.""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Chunk 1: "Hello " + events = list(manager.handle_text_delta(vendor_part_id='content', content='Hello ', thinking_tags=thinking_tags)) + assert len(events) == 1 + assert isinstance(events[0], PartStartEvent) + assert events[0].part == TextPart(content='Hello ') + + # Chunk 2: "reasoning" + events = list( + manager.handle_text_delta(vendor_part_id='content', content='reasoning', thinking_tags=thinking_tags) + ) + + # Should create ThinkingPart and add reasoning content + assert len(events) == 2 + assert isinstance(events[0], PartStartEvent) + assert isinstance(events[0].part, ThinkingPart) + assert isinstance(events[1], PartDeltaEvent) + assert events[1].delta == ThinkingPartDelta(content_delta='reasoning') + + # Final state + assert manager.get_parts() == snapshot( + [TextPart(content='Hello ', part_kind='text'), ThinkingPart(content='reasoning', part_kind='thinking')] + ) + + +def test_text_and_start_tag_same_chunk(): + """Test text followed by start tag in the same chunk (covers line 297).""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Single chunk with text then start tag: "prefix" + events = list( + manager.handle_text_delta(vendor_part_id='content', content='prefix', thinking_tags=thinking_tags) + ) + + # Should create TextPart for "prefix", then ThinkingPart + assert len(events) == 2 + assert isinstance(events[0], PartStartEvent) + assert events[0].part == TextPart(content='prefix') + assert isinstance(events[1], PartStartEvent) + assert isinstance(events[1].part, ThinkingPart) + + # Final state + assert manager.get_parts() == snapshot( + [TextPart(content='prefix', part_kind='text'), ThinkingPart(content='', part_kind='thinking')] + ) + + +def test_text_and_start_tag_with_content_same_chunk(): + """Test text + start tag + content in the same chunk (covers lines 211, 223, 297).""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # Single chunk: "prefixthinking" + events = list( + manager.handle_text_delta( + vendor_part_id='content', content='prefixthinking', thinking_tags=thinking_tags + ) + ) + + # Should create TextPart, ThinkingPart, and add thinking content + assert len(events) >= 2 + + # Final state + assert manager.get_parts() == snapshot( + [TextPart(content='prefix', part_kind='text'), ThinkingPart(content='thinking', part_kind='thinking')] + ) + + +def test_start_tag_with_content_no_vendor_id(): + """Test start tag with trailing content when vendor_part_id=None. + + The content after the start tag should be added to the ThinkingPart, not create a separate TextPart. + """ + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # With vendor_part_id=None and start tag with content + events = list( + manager.handle_text_delta(vendor_part_id=None, content='thinking', thinking_tags=thinking_tags) + ) + + # Should create ThinkingPart and add content + assert len(events) >= 1 + assert isinstance(events[0], PartStartEvent) + assert isinstance(events[0].part, ThinkingPart) + + # Content should be in the ThinkingPart, not a separate TextPart + assert manager.get_parts() == snapshot([ThinkingPart(content='thinking')]) + + +def test_text_then_start_tag_no_vendor_id(): + """Test text before start tag when vendor_part_id=None (covers line 211 in _handle_text_delta_simple).""" + manager = ModelResponsePartsManager() + thinking_tags = ('', '') + + # With vendor_part_id=None and text before start tag + events = list(manager.handle_text_delta(vendor_part_id=None, content='text', thinking_tags=thinking_tags)) + + # Should create TextPart for "text", then ThinkingPart + assert len(events) == 2 + assert isinstance(events[0], PartStartEvent) + assert events[0].part == TextPart(content='text') + assert isinstance(events[1], PartStartEvent) + assert isinstance(events[1].part, ThinkingPart) + + # Final state + assert manager.get_parts() == snapshot([TextPart(content='text'), ThinkingPart(content='')])