diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index d7a54c5c71..13dc4d6b15 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -1169,32 +1169,39 @@ async def _process_message_history( def _clean_message_history(messages: list[_messages.ModelMessage]) -> list[_messages.ModelMessage]: """Clean the message history by merging consecutive messages of the same type.""" clean_messages: list[_messages.ModelMessage] = [] + # Add parts to a set to ensure no duplication for message in messages: last_message = clean_messages[-1] if len(clean_messages) > 0 else None if isinstance(message, _messages.ModelRequest): - if ( - last_message - and isinstance(last_message, _messages.ModelRequest) - # Requests can only be merged if they have the same instructions - and ( + if last_message and isinstance(last_message, _messages.ModelRequest): + same_instructions = ( not last_message.instructions or not message.instructions or last_message.instructions == message.instructions ) - ): - parts = [*last_message.parts, *message.parts] - parts.sort( - # Tool return parts always need to be at the start - key=lambda x: 0 if isinstance(x, _messages.ToolReturnPart | _messages.RetryPromptPart) else 1 + last_is_stub = all( + isinstance(part, _messages.ToolReturnPart | _messages.RetryPromptPart) + for part in last_message.parts ) - merged_message = _messages.ModelRequest( - parts=parts, - instructions=last_message.instructions or message.instructions, + message_is_stub = all( + isinstance(part, _messages.ToolReturnPart | _messages.RetryPromptPart) for part in message.parts ) - clean_messages[-1] = merged_message - else: - clean_messages.append(message) + + if same_instructions and (not last_is_stub or message_is_stub): + parts = [*last_message.parts, *message.parts] + parts.sort( + # Tool return parts always need to be at the start + key=lambda x: 0 if isinstance(x, _messages.ToolReturnPart | _messages.RetryPromptPart) else 1 + ) + merged_message = _messages.ModelRequest( + parts=parts, + instructions=last_message.instructions or message.instructions, + ) + clean_messages[-1] = merged_message + continue + + clean_messages.append(message) elif isinstance(message, _messages.ModelResponse): # pragma: no branch if ( last_message diff --git a/tests/test_a2a.py b/tests/test_a2a.py index 93e56f12c0..048f8a1e8f 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -623,10 +623,12 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), - ), - UserPromptPart(content='Second message', timestamp=IsDatetime()), + ) ], ), + ModelRequest( + parts=[UserPromptPart(content='Second message', timestamp=IsDatetime())], + ), ] ) diff --git a/tests/test_history_processor.py b/tests/test_history_processor.py index 54d38935d2..78660ef517 100644 --- a/tests/test_history_processor.py +++ b/tests/test_history_processor.py @@ -801,3 +801,29 @@ def __call__(self, _: RunContext, messages: list[ModelMessage]) -> list[ModelMes ] ) assert result.new_messages() == result.all_messages()[-2:] + + +def test_clean_message_history_keeps_tool_stub_separate(): + """Regression guard for b26a6872f that merged tool-return stubs into the next user request.""" + + # TODO: imports should get moved to the top whenever we open P/R + from pydantic_ai._agent_graph import _clean_message_history # pyright: ignore + from pydantic_ai.messages import ToolReturnPart + + tool_stub = ModelRequest( + parts=[ + ToolReturnPart( + tool_name='summarize', + content='summaries galore', + tool_call_id='call-1', + ) + ] + ) + user_request = ModelRequest(parts=[UserPromptPart(content='fresh prompt')]) + + cleaned = _clean_message_history([tool_stub, user_request]) + + assert len(cleaned[0].parts) == 1, 'tool-return part started as unique and should remain unique' + assert len(cleaned) == 2, 'tool-return stubs must remain separate from subsequent user prompts' + assert isinstance(cleaned[0].parts[0], ToolReturnPart) + assert isinstance(cleaned[1].parts[0], UserPromptPart)