From b74fb381c1f16e79d27c89310c716a1b25754d9c Mon Sep 17 00:00:00 2001 From: biefan <70761325+biefan@users.noreply.github.com> Date: Tue, 17 Mar 2026 03:54:10 +0000 Subject: [PATCH] Respect export type in SQLite conversation exports --- pyrit/memory/sqlite_memory.py | 23 +++++- .../memory_interface/test_interface_export.py | 80 +++++++++++++++++++ 2 files changed, 100 insertions(+), 3 deletions(-) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 58ae9098e..72f7818e6 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -34,6 +34,14 @@ Model = TypeVar("Model") +class _ExportableConversationPiece: + def __init__(self, data: dict[str, Any]) -> None: + self._data = data + + def to_dict(self) -> dict[str, Any]: + return self._data + + class SQLiteMemory(MemoryInterface, metaclass=Singleton): """ A memory interface that uses SQLite as the backend database. @@ -418,9 +426,18 @@ def export_conversations( piece_data["scores"] = [score.to_dict() for score in piece_scores] merged_data.append(piece_data) - # Export to JSON manually since the exporter expects objects but we have dicts - with open(file_path, "w") as f: - json.dump(merged_data, f, indent=4) + if not merged_data: + if export_type == "json": + with open(file_path, "w", encoding="utf-8") as f: + json.dump(merged_data, f, indent=4) + elif export_type in self.exporter.export_strategies: + file_path.write_text("", encoding="utf-8") + else: + raise ValueError(f"Unsupported export format: {export_type}") + return file_path + + exportable_pieces = [_ExportableConversationPiece(data=piece_data) for piece_data in merged_data] + self.exporter.export_data(exportable_pieces, file_path=file_path, export_type=export_type) return file_path def print_schema(self) -> None: diff --git a/tests/unit/memory/memory_interface/test_interface_export.py b/tests/unit/memory/memory_interface/test_interface_export.py index 14064b718..aaf4dc57b 100644 --- a/tests/unit/memory/memory_interface/test_interface_export.py +++ b/tests/unit/memory/memory_interface/test_interface_export.py @@ -3,6 +3,7 @@ import os import tempfile +import csv from collections.abc import Sequence from pathlib import Path from unittest.mock import MagicMock, patch @@ -151,3 +152,82 @@ def test_export_all_conversations_with_scores_empty_data(sqlite_instance: Memory # Clean up the temp file if file_path.exists(): os.remove(file_path) + + +def test_export_all_conversations_with_scores_csv_format(sqlite_instance: MemoryInterface): + sqlite_instance.exporter = MemoryExporter() + + with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp_file: + file_path = Path(temp_file.name) + temp_file.close() + + try: + with ( + patch.object(sqlite_instance, "get_message_pieces") as mock_get_pieces, + patch.object(sqlite_instance, "get_prompt_scores") as mock_get_scores, + ): + mock_piece = MagicMock() + mock_piece.id = "piece_id_1234" + mock_piece.to_dict.return_value = { + "id": "piece_id_1234", + "converted_value": "sample piece", + } + + mock_score = MagicMock() + mock_score.message_piece_id = "piece_id_1234" + mock_score.to_dict.return_value = {"message_piece_id": "piece_id_1234", "score_value": 10} + + mock_get_pieces.return_value = [mock_piece] + mock_get_scores.return_value = [mock_score] + + sqlite_instance.export_conversations(file_path=file_path, export_type="csv") + + with open(file_path, newline="") as exported_file: + reader = csv.DictReader(exported_file) + assert reader.fieldnames == ["id", "converted_value", "scores"] + rows = list(reader) + + assert len(rows) == 1 + assert rows[0]["id"] == "piece_id_1234" + assert rows[0]["converted_value"] == "sample piece" + finally: + if file_path.exists(): + os.remove(file_path) + + +def test_export_all_conversations_with_scores_markdown_format(sqlite_instance: MemoryInterface): + sqlite_instance.exporter = MemoryExporter() + + with tempfile.NamedTemporaryFile(delete=False, suffix=".md") as temp_file: + file_path = Path(temp_file.name) + temp_file.close() + + try: + with ( + patch.object(sqlite_instance, "get_message_pieces") as mock_get_pieces, + patch.object(sqlite_instance, "get_prompt_scores") as mock_get_scores, + ): + mock_piece = MagicMock() + mock_piece.id = "piece_id_1234" + mock_piece.to_dict.return_value = { + "id": "piece_id_1234", + "converted_value": "sample piece", + } + + mock_score = MagicMock() + mock_score.message_piece_id = "piece_id_1234" + mock_score.to_dict.return_value = {"message_piece_id": "piece_id_1234", "score_value": 10} + + mock_get_pieces.return_value = [mock_piece] + mock_get_scores.return_value = [mock_score] + + sqlite_instance.export_conversations(file_path=file_path, export_type="md") + + exported_content = file_path.read_text(encoding="utf-8") + + assert exported_content.startswith("| id | converted_value | scores |") + assert "piece_id_1234" in exported_content + assert "sample piece" in exported_content + finally: + if file_path.exists(): + os.remove(file_path)