diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 90322ebec..2faad11d3 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -610,6 +610,9 @@ def get_message_pieces( Exception: If there is an error retrieving the prompts, an exception is logged and an empty list is returned. """ + if prompt_ids is not None and len(prompt_ids) == 0: + return [] + conditions = [] if attack_id: conditions.append(self._get_message_pieces_attack_conditions(attack_id=str(attack_id))) diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 67a4292f8..53e2d0f18 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -108,6 +108,18 @@ def test_get_message_pieces_uuid_and_string_ids(sqlite_instance: MemoryInterface assert str(single_str_result[0].id) == str(uuid3) +def test_get_message_pieces_empty_prompt_ids_returns_empty(sqlite_instance: MemoryInterface): + piece = MessagePiece( + id=uuid.uuid4(), + role="user", + original_value="Test prompt", + converted_value="Test prompt", + ) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[piece]) + + assert sqlite_instance.get_message_pieces(prompt_ids=[]) == [] + + def test_duplicate_memory(sqlite_instance: MemoryInterface): attack1 = PromptSendingAttack(objective_target=get_mock_target()) attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2")) diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 6087af141..34358d5e9 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -131,6 +131,31 @@ def test_add_score_get_score( assert db_score[0].message_piece_id == prompt_id +def test_get_prompt_scores_empty_prompt_ids_returns_empty(sqlite_instance: MemoryInterface): + prompt_id = uuid4() + piece = MessagePiece( + id=prompt_id, + role="user", + original_value="original prompt text", + converted_value="Hello, how are you?", + ) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[piece]) + + score = Score( + score_value=str(0.8), + score_value_description="High score", + score_type="float_scale", + score_category=["test"], + score_rationale="Test score", + score_metadata={"test": "metadata"}, + scorer_class_identifier=_test_scorer_id("TestScorer"), + message_piece_id=prompt_id, + ) + sqlite_instance.add_scores_to_memory(scores=[score]) + + assert sqlite_instance.get_prompt_scores(prompt_ids=[]) == [] + + def test_add_score_duplicate_prompt(sqlite_instance: MemoryInterface): # Ensure that scores of duplicate prompts are linked back to the original original_id = uuid4()