Skip to content

Commit f2628ff

Browse files
authored
Add migration script for Azure Cosmos DB, old container to new container (#2442)
* Add cosmos migration files * Migrate to CosmosDB * Make mypy happy * Mock env the right way
1 parent a2a03cd commit f2628ff

File tree

2 files changed

+354
-0
lines changed

2 files changed

+354
-0
lines changed

scripts/cosmosdb_migration.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
"""
2+
A migration script to migrate data from CosmosDB to a new format.
3+
The old schema:
4+
id: str
5+
entra_oid: str
6+
title: str
7+
timestamp: int
8+
answers: list of 2-item list of str, dict
9+
10+
The new schema has two item types in the same container:
11+
For session items:
12+
id: str
13+
session_id: str
14+
entra_oid: str
15+
title: str
16+
timestamp: int
17+
type: str (always "session")
18+
version: str (always "cosmosdb-v2")
19+
20+
For message_pair items:
21+
id: str
22+
session_id: str
23+
entra_oid: str
24+
type: str (always "message_pair")
25+
version: str (always "cosmosdb-v2")
26+
question: str
27+
response: dict
28+
"""
29+
30+
import os
31+
32+
from azure.cosmos.aio import CosmosClient
33+
from azure.identity.aio import AzureDeveloperCliCredential
34+
35+
from load_azd_env import load_azd_env
36+
37+
38+
class CosmosDBMigrator:
39+
"""
40+
Migrator class for CosmosDB data migration.
41+
"""
42+
43+
def __init__(self, cosmos_account, database_name, credential=None):
44+
"""
45+
Initialize the migrator with CosmosDB account and database.
46+
47+
Args:
48+
cosmos_account: CosmosDB account name
49+
database_name: Database name
50+
credential: Azure credential, defaults to AzureDeveloperCliCredential
51+
"""
52+
self.cosmos_account = cosmos_account
53+
self.database_name = database_name
54+
self.credential = credential or AzureDeveloperCliCredential()
55+
self.client = None
56+
self.database = None
57+
self.old_container = None
58+
self.new_container = None
59+
60+
async def connect(self):
61+
"""
62+
Connect to CosmosDB and initialize containers.
63+
"""
64+
self.client = CosmosClient(
65+
url=f"https://{self.cosmos_account}.documents.azure.com:443/", credential=self.credential
66+
)
67+
self.database = self.client.get_database_client(self.database_name)
68+
self.old_container = self.database.get_container_client("chat-history")
69+
self.new_container = self.database.get_container_client("chat-history-v2")
70+
try:
71+
await self.old_container.read()
72+
except Exception:
73+
raise ValueError(f"Old container {self.old_container.id} does not exist")
74+
try:
75+
await self.new_container.read()
76+
except Exception:
77+
raise ValueError(f"New container {self.new_container.id} does not exist")
78+
79+
async def migrate(self):
80+
"""
81+
Migrate data from old schema to new schema.
82+
"""
83+
if not self.client:
84+
await self.connect()
85+
if not self.old_container or not self.new_container:
86+
raise ValueError("Containers do not exist")
87+
88+
query_results = self.old_container.query_items(query="SELECT * FROM c")
89+
90+
item_migration_count = 0
91+
async for page in query_results.by_page():
92+
async for old_item in page:
93+
batch_operations = []
94+
# Build session item
95+
session_item = {
96+
"id": old_item["id"],
97+
"version": "cosmosdb-v2",
98+
"session_id": old_item["id"],
99+
"entra_oid": old_item["entra_oid"],
100+
"title": old_item.get("title"),
101+
"timestamp": old_item.get("timestamp"),
102+
"type": "session",
103+
}
104+
batch_operations.append(("upsert", (session_item,)))
105+
106+
# Build message_pair
107+
answers = old_item.get("answers", [])
108+
for idx, answer in enumerate(answers):
109+
question = answer[0]
110+
response = answer[1]
111+
message_pair = {
112+
"id": f"{old_item['id']}-{idx}",
113+
"version": "cosmosdb-v2",
114+
"session_id": old_item["id"],
115+
"entra_oid": old_item["entra_oid"],
116+
"type": "message_pair",
117+
"question": question,
118+
"response": response,
119+
"order": idx,
120+
"timestamp": None,
121+
}
122+
batch_operations.append(("upsert", (message_pair,)))
123+
124+
# Execute the batch using partition key [entra_oid, session_id]
125+
await self.new_container.execute_item_batch(
126+
batch_operations=batch_operations, partition_key=[old_item["entra_oid"], old_item["id"]]
127+
)
128+
item_migration_count += 1
129+
print(f"Total items migrated: {item_migration_count}")
130+
131+
async def close(self):
132+
"""
133+
Close the CosmosDB client.
134+
"""
135+
if self.client:
136+
await self.client.close()
137+
138+
139+
async def migrate_cosmosdb_data():
140+
"""
141+
Legacy function for backward compatibility.
142+
Migrate data from CosmosDB to a new format.
143+
"""
144+
USE_CHAT_HISTORY_COSMOS = os.getenv("USE_CHAT_HISTORY_COSMOS", "").lower() == "true"
145+
if not USE_CHAT_HISTORY_COSMOS:
146+
raise ValueError("USE_CHAT_HISTORY_COSMOS must be set to true")
147+
AZURE_COSMOSDB_ACCOUNT = os.environ["AZURE_COSMOSDB_ACCOUNT"]
148+
AZURE_CHAT_HISTORY_DATABASE = os.environ["AZURE_CHAT_HISTORY_DATABASE"]
149+
150+
migrator = CosmosDBMigrator(AZURE_COSMOSDB_ACCOUNT, AZURE_CHAT_HISTORY_DATABASE)
151+
try:
152+
await migrator.migrate()
153+
finally:
154+
await migrator.close()
155+
156+
157+
if __name__ == "__main__":
158+
load_azd_env()
159+
160+
import asyncio
161+
162+
asyncio.run(migrate_cosmosdb_data())

tests/test_cosmosdb_migration.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import os
2+
from unittest.mock import AsyncMock, MagicMock, patch
3+
4+
import pytest
5+
6+
from scripts.cosmosdb_migration import CosmosDBMigrator, migrate_cosmosdb_data
7+
8+
# Sample old format item
9+
TEST_OLD_ITEM = {
10+
"id": "123",
11+
"entra_oid": "OID_X",
12+
"title": "This is a test message",
13+
"timestamp": 123456789,
14+
"answers": [
15+
[
16+
"What does a Product Manager do?",
17+
{
18+
"delta": {"role": "assistant"},
19+
"session_state": "143c0240-b2ee-4090-8e90-2a1c58124894",
20+
"message": {
21+
"content": "A Product Manager is responsible for product strategy and execution.",
22+
"role": "assistant",
23+
},
24+
},
25+
],
26+
[
27+
"What about a Software Engineer?",
28+
{
29+
"delta": {"role": "assistant"},
30+
"session_state": "243c0240-b2ee-4090-8e90-2a1c58124894",
31+
"message": {
32+
"content": "A Software Engineer writes code to create applications.",
33+
"role": "assistant",
34+
},
35+
},
36+
],
37+
],
38+
}
39+
40+
41+
class MockAsyncPageIterator:
42+
"""Helper class to mock an async page from CosmosDB"""
43+
44+
def __init__(self, items):
45+
self.items = items
46+
47+
def __aiter__(self):
48+
return self
49+
50+
async def __anext__(self):
51+
if not self.items:
52+
raise StopAsyncIteration
53+
return self.items.pop(0)
54+
55+
56+
class MockCosmosDBResultsIterator:
57+
"""Helper class to mock a paginated query result from CosmosDB"""
58+
59+
def __init__(self, data=[]):
60+
self.data = data
61+
self.continuation_token = None
62+
63+
def by_page(self, continuation_token=None):
64+
"""Return a paged iterator"""
65+
self.continuation_token = "next_token" if not continuation_token else continuation_token + "_next"
66+
# Return an async iterator that contains pages
67+
return MockPagesAsyncIterator(self.data)
68+
69+
70+
class MockPagesAsyncIterator:
71+
"""Helper class to mock an iterator of pages"""
72+
73+
def __init__(self, data):
74+
self.data = data
75+
self.continuation_token = "next_token"
76+
77+
def __aiter__(self):
78+
return self
79+
80+
async def __anext__(self):
81+
if not self.data:
82+
raise StopAsyncIteration
83+
# Return a page, which is an async iterator of items
84+
return MockAsyncPageIterator([self.data.pop(0)])
85+
86+
87+
@pytest.mark.asyncio
88+
async def test_migrate_method():
89+
"""Test the migrate method of CosmosDBMigrator"""
90+
# Create mock objects
91+
mock_container = MagicMock()
92+
mock_database = MagicMock()
93+
mock_client = MagicMock()
94+
95+
# Set up the query_items mock to return our test item
96+
mock_container.query_items.return_value = MockCosmosDBResultsIterator([TEST_OLD_ITEM])
97+
98+
# Set up execute_item_batch as a spy to capture calls
99+
execute_batch_mock = AsyncMock()
100+
mock_container.execute_item_batch = execute_batch_mock
101+
102+
# Set up the database mock to return our container mocks
103+
mock_database.get_container_client.side_effect = lambda container_name: mock_container
104+
105+
# Set up the client mock
106+
mock_client.get_database_client.return_value = mock_database
107+
108+
# Create the migrator with our mocks
109+
migrator = CosmosDBMigrator("dummy_account", "dummy_db")
110+
migrator.client = mock_client
111+
migrator.database = mock_database
112+
migrator.old_container = mock_container
113+
migrator.new_container = mock_container
114+
115+
# Call the migrate method
116+
await migrator.migrate()
117+
118+
# Verify query_items was called with the right parameters
119+
mock_container.query_items.assert_called_once_with(query="SELECT * FROM c")
120+
121+
# Verify execute_item_batch was called
122+
execute_batch_mock.assert_called_once()
123+
124+
# Extract the arguments from the call
125+
call_args = execute_batch_mock.call_args[1]
126+
batch_operations = call_args["batch_operations"]
127+
partition_key = call_args["partition_key"]
128+
129+
# Verify the partition key
130+
assert partition_key == ["OID_X", "123"]
131+
132+
# We should have 3 operations: 1 for session and 2 for message pairs
133+
assert len(batch_operations) == 3
134+
135+
# Verify session item
136+
session_operation = batch_operations[0]
137+
assert session_operation[0] == "upsert"
138+
session_item = session_operation[1][0]
139+
assert session_item["id"] == "123"
140+
assert session_item["session_id"] == "123"
141+
assert session_item["entra_oid"] == "OID_X"
142+
assert session_item["title"] == "This is a test message"
143+
assert session_item["timestamp"] == 123456789
144+
assert session_item["type"] == "session"
145+
assert session_item["version"] == "cosmosdb-v2"
146+
147+
# Verify first message pair
148+
message1_operation = batch_operations[1]
149+
assert message1_operation[0] == "upsert"
150+
message1_item = message1_operation[1][0]
151+
assert message1_item["id"] == "123-0"
152+
assert message1_item["session_id"] == "123"
153+
assert message1_item["entra_oid"] == "OID_X"
154+
assert message1_item["question"] == "What does a Product Manager do?"
155+
assert message1_item["type"] == "message_pair"
156+
assert message1_item["order"] == 0
157+
158+
# Verify second message pair
159+
message2_operation = batch_operations[2]
160+
assert message2_operation[0] == "upsert"
161+
message2_item = message2_operation[1][0]
162+
assert message2_item["id"] == "123-1"
163+
assert message2_item["session_id"] == "123"
164+
assert message2_item["entra_oid"] == "OID_X"
165+
assert message2_item["question"] == "What about a Software Engineer?"
166+
assert message2_item["type"] == "message_pair"
167+
assert message2_item["order"] == 1
168+
169+
170+
@pytest.mark.asyncio
171+
async def test_migrate_cosmosdb_data(monkeypatch):
172+
"""Test the main migrate_cosmosdb_data function"""
173+
with patch.dict(os.environ, clear=True):
174+
monkeypatch.setenv("USE_CHAT_HISTORY_COSMOS", "true")
175+
monkeypatch.setenv("AZURE_COSMOSDB_ACCOUNT", "dummy_account")
176+
monkeypatch.setenv("AZURE_CHAT_HISTORY_DATABASE", "dummy_db")
177+
178+
# Create a mock for the CosmosDBMigrator
179+
with patch("scripts.cosmosdb_migration.CosmosDBMigrator") as mock_migrator_class:
180+
# Set up the mock for the migrator instance
181+
mock_migrator = AsyncMock()
182+
mock_migrator_class.return_value = mock_migrator
183+
184+
# Call the function
185+
await migrate_cosmosdb_data()
186+
187+
# Verify the migrator was created with the right parameters
188+
mock_migrator_class.assert_called_once_with("dummy_account", "dummy_db")
189+
190+
# Verify migrate and close were called
191+
mock_migrator.migrate.assert_called_once()
192+
mock_migrator.close.assert_called_once()

0 commit comments

Comments
 (0)