diff --git a/pyrit/message_normalizer/conversation_context_normalizer.py b/pyrit/message_normalizer/conversation_context_normalizer.py index 46e3c96308..54f4bf0166 100644 --- a/pyrit/message_normalizer/conversation_context_normalizer.py +++ b/pyrit/message_normalizer/conversation_context_normalizer.py @@ -23,6 +23,13 @@ class ConversationContextNormalizer(MessageStringNormalizer): ... """ + _ROLE_LABELS = { + "user": "User", + "assistant": "Assistant", + "tool": "Tool", + "developer": "Developer", + } + async def normalize_string_async(self, messages: list[Message]) -> str: """ Normalize a list of messages into a turn-based context string. @@ -55,7 +62,7 @@ async def normalize_string_async(self, messages: list[Message]) -> str: # Format the piece content content = self._format_piece_content(piece) - role_label = "User" if piece.api_role == "user" else "Assistant" + role_label = self._ROLE_LABELS.get(piece.api_role, piece.api_role.capitalize()) context_parts.append(f"{role_label}: {content}") return "\n".join(context_parts) diff --git a/tests/unit/message_normalizer/test_conversation_context_normalizer.py b/tests/unit/message_normalizer/test_conversation_context_normalizer.py index 855668e72d..4054f790cb 100644 --- a/tests/unit/message_normalizer/test_conversation_context_normalizer.py +++ b/tests/unit/message_normalizer/test_conversation_context_normalizer.py @@ -108,3 +108,31 @@ async def test_shows_original_if_different_from_converted(self): assert "converted text" in result assert "(original: original text)" in result + + @pytest.mark.asyncio + async def test_preserves_tool_role_label(self): + """Test that tool messages keep the Tool label in context output.""" + normalizer = ConversationContextNormalizer() + messages = [ + _make_message("user", "Call the weather tool"), + _make_message("tool", "72F and sunny"), + ] + + result = await normalizer.normalize_string_async(messages) + + assert "Tool: 72F and sunny" in result + assert "Assistant: 72F and sunny" not in result + + @pytest.mark.asyncio + async def test_preserves_developer_role_label(self): + """Test that developer messages keep the Developer label in context output.""" + normalizer = ConversationContextNormalizer() + messages = [ + _make_message("user", "Use concise units"), + _make_message("developer", "Prefer metric units"), + ] + + result = await normalizer.normalize_string_async(messages) + + assert "Developer: Prefer metric units" in result + assert "Assistant: Prefer metric units" not in result