From b7e12bf65679d8799d8fb8cb80c1f428b9129b1b Mon Sep 17 00:00:00 2001 From: biefan Date: Tue, 17 Mar 2026 15:30:35 +0800 Subject: [PATCH] Preserve user message structure in generic system squash --- .../generic_system_squash.py | 34 ++++++++++-- .../test_generic_system_squash_normalizer.py | 53 +++++++++++++++++++ 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/pyrit/message_normalizer/generic_system_squash.py b/pyrit/message_normalizer/generic_system_squash.py index 56850289da..86ab06178b 100644 --- a/pyrit/message_normalizer/generic_system_squash.py +++ b/pyrit/message_normalizer/generic_system_squash.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import copy from pyrit.message_normalizer.message_normalizer import MessageListNormalizer -from pyrit.models import Message +from pyrit.models import Message, MessagePiece class GenericSystemSquashNormalizer(MessageListNormalizer[Message]): @@ -43,12 +44,35 @@ async def normalize_async(self, messages: list[Message]) -> list[Message]: # Only system message, convert to user message return [Message.from_prompt(prompt=first_piece.converted_value, role="user")] - # Combine system with first user message + user_message_index = next( + (i for i, message in enumerate(messages[1:], start=1) if message.api_role == "user"), + -1, + ) + if user_message_index == -1: + # Preserve the instruction content without rewriting non-user messages. + return [Message.from_prompt(prompt=first_piece.converted_value, role="user")] + list(messages[1:]) + + # Combine system with the first user message system_content = first_piece.converted_value - user_piece = messages[1].get_piece() + user_piece = messages[user_message_index].get_piece() user_content = user_piece.converted_value combined_content = f"### Instructions ###\n\n{system_content}\n\n######\n\n{user_content}" - squashed_message = Message.from_prompt(prompt=combined_content, role="user") + squashed_message = copy.deepcopy(messages[user_message_index]) + + if squashed_message.message_pieces[0].converted_value_data_type == "text": + squashed_message.message_pieces[0].original_value = combined_content + squashed_message.message_pieces[0].converted_value = combined_content + else: + squashed_message.message_pieces.insert( + 0, + MessagePiece( + role="user", + original_value=combined_content, + conversation_id=user_piece.conversation_id, + sequence=user_piece.sequence, + ), + ) + # Return the squashed message followed by remaining messages (skip first two) - return [squashed_message] + list(messages[2:]) + return list(messages[1:user_message_index]) + [squashed_message] + list(messages[user_message_index + 1 :]) diff --git a/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py b/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py index 40445f9771..45272c9def 100644 --- a/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py +++ b/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py @@ -70,3 +70,56 @@ async def test_generic_squash_normalize_to_dicts_async(): assert "### Instructions ###" in result[0]["converted_value"] assert "System message" in result[0]["converted_value"] assert "User message" in result[0]["converted_value"] + + +@pytest.mark.asyncio +async def test_generic_squash_preserves_multipart_user_message(): + """Test that squashing keeps non-text user pieces instead of collapsing to plain text.""" + conversation_id = "conv-1" + messages = [ + _make_message("system", "System message"), + Message( + message_pieces=[ + MessagePiece( + role="user", + original_value="User message", + conversation_id=conversation_id, + sequence=0, + ), + MessagePiece( + role="user", + original_value="/tmp/example.png", + original_value_data_type="image_path", + conversation_id=conversation_id, + sequence=0, + ), + ] + ), + ] + + result = await GenericSystemSquashNormalizer().normalize_async(messages) + + assert len(result) == 1 + assert result[0].api_role == "user" + assert len(result[0].message_pieces) == 2 + assert result[0].get_value() == "### Instructions ###\n\nSystem message\n\n######\n\nUser message" + assert result[0].message_pieces[1].converted_value == "/tmp/example.png" + assert result[0].message_pieces[1].converted_value_data_type == "image_path" + + +@pytest.mark.asyncio +async def test_generic_squash_uses_first_user_message_instead_of_rewriting_assistant(): + """Test that squash targets the first user message even if assistant messages appear first.""" + messages = [ + _make_message("system", "System message"), + _make_message("assistant", "Assistant message"), + _make_message("user", "User message"), + ] + + result = await GenericSystemSquashNormalizer().normalize_async(messages) + + assert len(result) == 2 + assert result[0].api_role == "assistant" + assert result[0].get_value() == "Assistant message" + assert result[1].api_role == "user" + assert result[1].get_value() == "### Instructions ###\n\nSystem message\n\n######\n\nUser message"