Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
python_requires='>=3.10',
install_requires=[
'absl-py',
'alembic==1.4.3',
'alembic==1.16.5',
'async_generator',
'attrs',
'cloud-sql-python-connector',
Expand All @@ -46,7 +46,7 @@
'immutabledict',
'kubernetes',
'pyyaml',
'sqlalchemy==1.2.19',
'sqlalchemy==2.0.43',
'sqlparse',
'termcolor',
],
Expand Down
57 changes: 57 additions & 0 deletions xmanager/xm_local/storage/alembic_migration_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import unittest
import tempfile
import os
import sqlalchemy
from sqlalchemy import inspect
from alembic.config import Config
from alembic import command
from alembic import script as alembic_script

class AlembicMigrationTest(unittest.TestCase):

def setUp(self):
self.temp_db_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
self.temp_db_file.close()
self.db_url = f"sqlite:///{self.temp_db_file.name}"

alembic_dir = os.path.join(
os.path.dirname(__file__), 'alembic'
)
self.alembic_cfg = Config(os.path.join(alembic_dir, 'alembic.ini'))
self.alembic_cfg.set_main_option("script_location", alembic_dir)
self.alembic_cfg.set_main_option("sqlalchemy.url", self.db_url)

self.engine = sqlalchemy.create_engine(self.db_url)

def tearDown(self):
os.unlink(self.temp_db_file.name)

def test_migrations_upgrade_and_downgrade(self):
with self.engine.connect() as connection:
inspector = inspect(connection)
self.assertEqual(inspector.get_table_names(), [])

print(f"Upgrading database to head: {self.db_url}")
command.upgrade(self.alembic_cfg, "head")

# Verify schema after upgrade
with self.engine.connect() as connection:
inspector = inspect(connection)
expected_tables = ['experiment', 'work_unit', 'job'] # Example tables from your existing migration
actual_tables = inspector.get_table_names()

self.assertGreaterEqual(len(actual_tables), len(expected_tables),
"Not all expected tables were created during upgrade.")
for table in expected_tables:
self.assertIn(table, actual_tables, f"Table '{table}' not found after upgrade.")

experiment_columns = [c['name'] for c in inspector.get_columns('experiment')]
self.assertIn('experiment_id', experiment_columns)
self.assertIn('experiment_title', experiment_columns)

print(f"Attempting to downgrade database to base: {self.db_url}")
with self.assertRaises(RuntimeError):
command.downgrade(self.alembic_cfg, "base")

if __name__ == '__main__':
unittest.main()
114 changes: 68 additions & 46 deletions xmanager/xm_local/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,14 @@ def create_engine(settings: SqlConnectionSettings) -> Engine:
f'+{settings.driver}' if settings.driver else ''
)

url = sqlalchemy.engine.url.URL(
url = sqlalchemy.URL.create(
drivername=driver_name,
username=settings.username,
password=settings.password,
host=settings.host,
port=settings.port,
database=settings.db_name,
query={},
)

return sqlalchemy.engine.create_engine(url)
Expand All @@ -115,7 +116,9 @@ def create_engine(settings: SqlConnectionSettings) -> Engine:
"Can't use SqliteConnector with a backendother than `sqlite`"
)

if not os.path.isdir(os.path.dirname(settings.db_name)):
if settings.db_name != ':memory:' and not os.path.isdir(
os.path.dirname(settings.db_name)
):
os.makedirs(os.path.dirname(settings.db_name))

return GenericSqlConnector.create_engine(settings)
Expand Down Expand Up @@ -154,7 +157,7 @@ def get_connection():
password=settings.password,
db=settings.db_name)

url = sqlalchemy.engine.url.URL(drivername=f'{settings.backend}+{driver}',
url = sqlalchemy.URL.create(drivername=f'{settings.backend}+{driver}',
host='localhost')
return sqlalchemy.create_engine(url, creator=get_connection)

Expand All @@ -167,9 +170,6 @@ def __init__(
):
self.settings = settings
self.engine: Engine = connector.create_engine(settings)
# https://github.com/sqlalchemy/sqlalchemy/issues/5645
# TODO: Remove this line after using sqlalchemy>=1.14.
self.engine.dialect.description_encoding = None
storage_dir = os.path.dirname(__file__)

alembic_ini_path = os.path.join(storage_dir, 'alembic.ini')
Expand Down Expand Up @@ -216,9 +216,7 @@ def maybe_migrate_database_version(self):
"""Enforces the latest version of the database to be used."""
db_version = self.database_version()
with self.engine.connect() as connection:
legacy_sqlite_db = self.engine.dialect.has_table(
connection, 'VersionHistory'
)
legacy_sqlite_db = sqlalchemy.inspect(connection).has_table('VersionHistory')

need_to_update = (
db_version != self.latest_version_available() and db_version
Expand All @@ -239,36 +237,46 @@ def insert_experiment(
'INSERT INTO experiment (experiment_id, experiment_title) '
'VALUES (:experiment_id, :experiment_title)'
)
self.engine.execute(
query, experiment_id=experiment_id, experiment_title=experiment_title
)
with self.engine.begin() as connection:
connection.execute(
query,
{
'experiment_id': experiment_id,
'experiment_title': experiment_title,
},
)

def insert_work_unit(self, experiment_id: int, work_unit_id: int) -> None:
query = text(
'INSERT INTO work_unit (experiment_id, work_unit_id) '
'VALUES (:experiment_id, :work_unit_id)'
)
self.engine.execute(
query, experiment_id=experiment_id, work_unit_id=work_unit_id
)
with self.engine.begin() as connection:
connection.execute(
query,
{'experiment_id': experiment_id, 'work_unit_id': work_unit_id},
)

def insert_vertex_job(
self, experiment_id: int, work_unit_id: int, vertex_job_id: str
) -> None:
job = data_pb2.Job(caip=data_pb2.AIPlatformJob(resource_name=vertex_job_id))
data = text_format.MessageToBytes(job)
data = text_format.MessageToString(job)
Comment on lines -259 to +264
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seemed like a bug to me. The database column for job_data is of type sa.String(255), so I think the data should be formatted as string and not bytes. MessageToString is also in line with what insert_kubernetes_job below does already.

query = text(
'INSERT INTO '
'job (experiment_id, work_unit_id, job_name, job_data) '
'VALUES (:experiment_id, :work_unit_id, :job_name, :job_data)'
)
self.engine.execute(
query,
experiment_id=experiment_id,
work_unit_id=work_unit_id,
job_name=vertex_job_id,
job_data=data,
)
with self.engine.begin() as connection:
connection.execute(
query,
{
'experiment_id': experiment_id,
'work_unit_id': work_unit_id,
'job_name': vertex_job_id,
'job_data': data,
},
)

def insert_kubernetes_job(
self, experiment_id: int, work_unit_id: int, namespace: str, job_name: str
Expand All @@ -285,31 +293,34 @@ def insert_kubernetes_job(
'job (experiment_id, work_unit_id, job_name, job_data) '
'VALUES (:experiment_id, :work_unit_id, :job_name, :job_data)'
)
self.engine.execute(
query,
experiment_id=experiment_id,
work_unit_id=work_unit_id,
job_name=job_name,
job_data=data,
)
with self.engine.begin() as connection:
connection.execute(
query,
{
'experiment_id': experiment_id,
'work_unit_id': work_unit_id,
'job_name': job_name,
'job_data': data,
},
)

def list_experiment_ids(self) -> List[int]:
"""Lists all the experiment ids from local database."""
query = text('SELECT experiment_id FROM experiment')
rows = self.engine.execute(query)
return [r['experiment_id'] for r in rows]
with self.engine.connect() as connection:
rows = connection.execute(query).scalars().all()
return rows

def get_experiment(self, experiment_id: int) -> ExperimentResult:
"""Gets an experiment from local database."""
query = text(
'SELECT experiment_title FROM experiment '
'WHERE experiment_id=:experiment_id'
)
rows = self.engine.execute(query, experiment_id=experiment_id)
title = None
for r in rows:
title = r['experiment_title']
break
with self.engine.connect() as connection:
title = connection.execute(
query, {'experiment_id': experiment_id}
).scalar_one_or_none()
if title is None:
raise ValueError(f"Experiment Id {experiment_id} doesn't exist.")
return ExperimentResult(
Expand All @@ -321,8 +332,16 @@ def list_work_units(self, experiment_id: int) -> List[WorkUnitResult]:
query = text(
'SELECT work_unit_id FROM work_unit WHERE experiment_id=:experiment_id'
)
rows = self.engine.execute(query, experiment_id=experiment_id)
return [self.get_work_unit(experiment_id, r['work_unit_id']) for r in rows]
with self.engine.connect() as connection:
work_unit_ids = (
connection.execute(query, {'experiment_id': experiment_id})
.scalars()
.all()
)
return [
self.get_work_unit(experiment_id, work_unit_id)
for work_unit_id in work_unit_ids
]

def get_work_unit(
self, experiment_id: int, work_unit_id: int
Expand All @@ -333,13 +352,16 @@ def get_work_unit(
'WHERE experiment_id=:experiment_id '
'AND work_unit_id=:work_unit_id'
)
rows = self.engine.execute(
query, experiment_id=experiment_id, work_unit_id=work_unit_id
)
jobs = {}
for r in rows:
job = data_pb2.Job()
jobs[r['job_name']] = text_format.Parse(r['job_data'], job)
with self.engine.connect() as connection:
rows = connection.execute(
query,
{'experiment_id': experiment_id, 'work_unit_id': work_unit_id},
)
jobs = {}
for r in rows:
job = data_pb2.Job()
text_format.Parse(r.job_data, job)
jobs[r.job_name] = job
return WorkUnitResult(work_unit_id, jobs)


Expand Down
47 changes: 47 additions & 0 deletions xmanager/xm_local/storage/database_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Tests for xmanager.xm_local.storage.database."""

import unittest
from unittest import mock
import os
import tempfile
from xmanager.xm_local.storage import database as db_module
from xmanager.xm_local import experiment as local_experiment

class DatabaseTest(unittest.TestCase):

def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.db_path = os.path.join(self.temp_dir.name, 'db.sqlite')

settings = db_module.SqlConnectionSettings(backend='sqlite', db_name=self.db_path)

self.database = db_module.Database(db_module.SqliteConnector, settings)

self.patcher = mock.patch('xmanager.xm_local.experiment.database.database', return_value=self.database)
self.mock_db = self.patcher.start()

def tearDown(self):
self.patcher.stop()
self.temp_dir.cleanup()

def test_create_experiment(self):
with local_experiment.create_experiment(experiment_title='test_experiment_1') as experiment:
self.assertIsNotNone(experiment.experiment_id)

with self.database.engine.connect() as connection:
result = connection.execute(db_module.text("SELECT * FROM experiment"))
rows = result.all()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0].experiment_title, 'test_experiment_1')

with local_experiment.create_experiment(experiment_title='test_experiment_2') as experiment:
self.assertIsNotNone(experiment.experiment_id)

with self.database.engine.connect() as connection:
result = connection.execute(db_module.text("SELECT * FROM experiment"))
rows = result.all()
self.assertEqual(len(rows), 2)
self.assertEqual(rows[1].experiment_title, 'test_experiment_2')

if __name__ == '__main__':
unittest.main()