Skip to content
39 changes: 23 additions & 16 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,32 +1169,39 @@ async def _process_message_history(
def _clean_message_history(messages: list[_messages.ModelMessage]) -> list[_messages.ModelMessage]:
"""Clean the message history by merging consecutive messages of the same type."""
clean_messages: list[_messages.ModelMessage] = []
# Add parts to a set to ensure no duplication
for message in messages:
last_message = clean_messages[-1] if len(clean_messages) > 0 else None

if isinstance(message, _messages.ModelRequest):
if (
last_message
and isinstance(last_message, _messages.ModelRequest)
# Requests can only be merged if they have the same instructions
and (
if last_message and isinstance(last_message, _messages.ModelRequest):
same_instructions = (
not last_message.instructions
or not message.instructions
or last_message.instructions == message.instructions
)
):
parts = [*last_message.parts, *message.parts]
parts.sort(
# Tool return parts always need to be at the start
key=lambda x: 0 if isinstance(x, _messages.ToolReturnPart | _messages.RetryPromptPart) else 1
last_is_stub = all(
isinstance(part, _messages.ToolReturnPart | _messages.RetryPromptPart)
for part in last_message.parts
)
merged_message = _messages.ModelRequest(
parts=parts,
instructions=last_message.instructions or message.instructions,
message_is_stub = all(
isinstance(part, _messages.ToolReturnPart | _messages.RetryPromptPart) for part in message.parts
)
clean_messages[-1] = merged_message
else:
clean_messages.append(message)

if same_instructions and (not last_is_stub or message_is_stub):
parts = [*last_message.parts, *message.parts]
parts.sort(
# Tool return parts always need to be at the start
key=lambda x: 0 if isinstance(x, _messages.ToolReturnPart | _messages.RetryPromptPart) else 1
)
merged_message = _messages.ModelRequest(
parts=parts,
instructions=last_message.instructions or message.instructions,
)
clean_messages[-1] = merged_message
continue

clean_messages.append(message)
elif isinstance(message, _messages.ModelResponse): # pragma: no branch
if (
last_message
Expand Down
6 changes: 4 additions & 2 deletions tests/test_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,10 +623,12 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon
content='Final result processed.',
tool_call_id=IsStr(),
timestamp=IsDatetime(),
),
UserPromptPart(content='Second message', timestamp=IsDatetime()),
)
],
),
ModelRequest(
parts=[UserPromptPart(content='Second message', timestamp=IsDatetime())],
),
]
)

Expand Down
26 changes: 26 additions & 0 deletions tests/test_history_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,3 +801,29 @@ def __call__(self, _: RunContext, messages: list[ModelMessage]) -> list[ModelMes
]
)
assert result.new_messages() == result.all_messages()[-2:]


def test_clean_message_history_keeps_tool_stub_separate():
    """Regression guard for b26a6872f that merged tool-return stubs into the next user request.

    A ``ModelRequest`` consisting solely of tool-return (or retry) parts — a "stub" —
    must NOT be merged with a following ``ModelRequest`` that carries a user prompt:
    the two requests must survive cleaning as separate messages.
    """

    # TODO: imports should get moved to the top whenever we open P/R
    from pydantic_ai._agent_graph import _clean_message_history  # pyright: ignore
    from pydantic_ai.messages import ToolReturnPart

    # A request made up only of tool-return parts (the "stub" case).
    tool_stub = ModelRequest(
        parts=[
            ToolReturnPart(
                tool_name='summarize',
                content='summaries galore',
                tool_call_id='call-1',
            )
        ]
    )
    user_request = ModelRequest(parts=[UserPromptPart(content='fresh prompt')])

    cleaned = _clean_message_history([tool_stub, user_request])

    # Assert the overall structure first: if the regression reappears the history
    # collapses to a single merged message, and this check reports the real
    # failure instead of a misleading part-count mismatch on cleaned[0].
    assert len(cleaned) == 2, 'tool-return stubs must remain separate from subsequent user prompts'
    assert len(cleaned[0].parts) == 1, 'tool-return part started as unique and should remain unique'
    assert isinstance(cleaned[0].parts[0], ToolReturnPart)
    assert len(cleaned[1].parts) == 1, 'user request should be untouched by cleaning'
    assert isinstance(cleaned[1].parts[0], UserPromptPart)