diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index b507cf317..0da96d44b 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -167,7 +167,7 @@ def from_messages( current_sequence = starting_sequence for message in messages: - role: ChatMessageRole = "assistant" if message.api_role == "assistant" else "user" + role: ChatMessageRole = message.api_role for piece in message.message_pieces: seed_prompt = SeedPrompt( diff --git a/tests/unit/models/test_seed.py b/tests/unit/models/test_seed.py index 32414e0f7..a00d1747e 100644 --- a/tests/unit/models/test_seed.py +++ b/tests/unit/models/test_seed.py @@ -1226,6 +1226,25 @@ def test_from_messages_multiple_messages(): assert result[2].sequence == 2 +@pytest.mark.parametrize( + ("role", "expected_role"), + [ + ("system", "system"), + ("developer", "developer"), + ("tool", "tool"), + ("simulated_assistant", "assistant"), + ], +) +def test_from_messages_preserves_supported_roles(role, expected_role): + """Test from_messages preserves supported API roles instead of collapsing to user.""" + message = Message(message_pieces=[MessagePiece(role=role, original_value=f"{role} message")]) + + result = SeedPrompt.from_messages([message]) + + assert len(result) == 1 + assert result[0].role == expected_role + + def test_from_messages_multipart_message(): """Test from_messages with a multipart message (e.g., text + image).""" conv_id = str(uuid.uuid4())