From 5de800659c14fb1c53f303b3b203363e4dc92983 Mon Sep 17 00:00:00 2001 From: biefan Date: Tue, 17 Mar 2026 15:40:22 +0800 Subject: [PATCH] Preserve empty responses in prompt normalizer batches --- pyrit/prompt_normalizer/prompt_normalizer.py | 15 ++-- .../test_prompt_normalizer.py | 81 +++++++++++++------ 2 files changed, 68 insertions(+), 28 deletions(-) diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index b730a58669..740ce8f35e 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -137,7 +137,15 @@ async def send_prompt_async( # handling empty responses message list and None responses if not responses or not any(responses): - return None + empty_response = construct_response_from_request( + request=request.message_pieces[0], + response_text_pieces=[""], + response_type="text", + error="empty", + ) + await self._calc_hash(request=empty_response) + self._memory.add_message_to_memory(request=empty_response) + return empty_response # Process all response messages (targets return list[Message]) # Only apply response converters to the last message (final response) @@ -191,7 +199,7 @@ async def send_prompt_batch_to_target_async( "conversation_id", ] - responses = await batch_task_async( + return await batch_task_async( prompt_target=target, batch_size=batch_size, items_to_batch=batch_items, @@ -202,9 +210,6 @@ async def send_prompt_batch_to_target_async( attack_identifier=attack_identifier, ) - # Filter out None responses (e.g., from empty responses) - return [response for response in responses if response is not None] - async def convert_values( self, converter_configurations: list[PromptConverterConfiguration], diff --git a/tests/unit/prompt_normalizer/test_prompt_normalizer.py b/tests/unit/prompt_normalizer/test_prompt_normalizer.py index 6386a1024a..56058557f6 100644 --- a/tests/unit/prompt_normalizer/test_prompt_normalizer.py +++ b/tests/unit/prompt_normalizer/test_prompt_normalizer.py @@ -118,17 +118,22 @@ async def test_send_prompt_async_multiple_converters(mock_memory_instance, seed_ @pytest.mark.asyncio async def test_send_prompt_async_no_response_adds_memory(mock_memory_instance, seed_group): - prompt_target = AsyncMock() + prompt_target = MagicMock() prompt_target.send_prompt_async = AsyncMock(return_value=None) + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") normalizer = PromptNormalizer() message = Message.from_prompt(prompt=seed_group.prompts[0].value, role="user") - await normalizer.send_prompt_async(message=message, target=prompt_target) - assert mock_memory_instance.add_message_to_memory.call_count == 1 + response = await normalizer.send_prompt_async(message=message, target=prompt_target) + assert mock_memory_instance.add_message_to_memory.call_count == 2 request = mock_memory_instance.add_message_to_memory.call_args[1]["request"] assert_message_piece_hashes_set(request) + assert response.message_pieces[0].response_error == "empty" + assert response.message_pieces[0].original_value == "" + assert response.message_pieces[0].original_value_data_type == "text" + assert_message_piece_hashes_set(response) @pytest.mark.asyncio @@ -184,34 +189,29 @@ async def test_send_prompt_async_request_response_added_to_memory(mock_memory_in @pytest.mark.asyncio async def test_send_prompt_async_exception(mock_memory_instance, seed_group): - prompt_target = AsyncMock() + prompt_target = MagicMock() + prompt_target.send_prompt_async = AsyncMock(side_effect=ValueError("test_exception")) + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") seed_prompt_value = seed_group.prompts[0].value normalizer = PromptNormalizer() message = Message.from_prompt(prompt=seed_prompt_value, role="user") - with patch("pyrit.models.construct_response_from_request") as mock_construct: - mock_construct.return_value = "test" + with pytest.raises(Exception, match="Error sending prompt with conversation ID"): + await normalizer.send_prompt_async(message=message, target=prompt_target) - try: - await normalizer.send_prompt_async(message=message, target=prompt_target) - except ValueError: - assert mock_memory_instance.add_message_to_memory.call_count == 2 + assert mock_memory_instance.add_message_to_memory.call_count == 2 - # Validate that first request is added to memory, then exception is added to memory - assert ( - seed_prompt_value - == mock_memory_instance.add_message_to_memory.call_args_list[0][1]["request"] - .message_pieces[0] - .original_value - ) - assert ( - mock_memory_instance.add_message_to_memory.call_args_list[1][1]["request"] - .message_pieces[0] - .original_value - == "test_exception" - ) + # Validate that first request is added to memory, then exception is added to memory + assert ( + seed_prompt_value + == mock_memory_instance.add_message_to_memory.call_args_list[0][1]["request"].message_pieces[0].original_value + ) + assert ( + "test_exception" + in mock_memory_instance.add_message_to_memory.call_args_list[1][1]["request"].message_pieces[0].original_value + ) @pytest.mark.asyncio @@ -383,6 +383,41 @@ async def test_prompt_normalizer_send_prompt_batch_async_throws( assert len(results) == 1 +@pytest.mark.asyncio +async def test_prompt_normalizer_send_prompt_batch_async_preserves_empty_response_alignment( + mock_memory_instance, +): + prompt_target = MagicMock() + prompt_target._max_requests_per_minute = None + prompt_target.get_identifier.return_value = get_mock_target_identifier("MockTarget") + prompt_target.send_prompt_async = AsyncMock( + side_effect=[ + [MessagePiece(role="assistant", original_value="response 1", conversation_id="conv-1").to_message()], + None, + ] + ) + + normalizer = PromptNormalizer() + requests = [ + NormalizerRequest( + message=Message.from_prompt(prompt="prompt 1", role="user"), + conversation_id="conv-1", + ), + NormalizerRequest( + message=Message.from_prompt(prompt="prompt 2", role="user"), + conversation_id="conv-2", + ), + ] + + results = await normalizer.send_prompt_batch_to_target_async(requests=requests, target=prompt_target, batch_size=2) + + assert len(results) == 2 + assert results[0].message_pieces[0].original_value == "response 1" + assert results[1].message_pieces[0].response_error == "empty" + assert results[1].message_pieces[0].original_value == "" + assert results[1].message_pieces[0].conversation_id == "conv-2" + + @pytest.mark.asyncio async def test_build_message(mock_memory_instance, seed_group): # This test is obsolete since _build_message was removed and message preparation