From b3bef488c17eb655414548a6f573e129f597c07d Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Sun, 7 Sep 2025 09:42:06 -0400 Subject: [PATCH 1/5] Fix `job_data` typing in `insert_vertex_job` --- xmanager/xm_local/storage/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xmanager/xm_local/storage/database.py b/xmanager/xm_local/storage/database.py index a103afc..1168fdf 100644 --- a/xmanager/xm_local/storage/database.py +++ b/xmanager/xm_local/storage/database.py @@ -256,7 +256,7 @@ def insert_vertex_job( self, experiment_id: int, work_unit_id: int, vertex_job_id: str ) -> None: job = data_pb2.Job(caip=data_pb2.AIPlatformJob(resource_name=vertex_job_id)) - data = text_format.MessageToBytes(job) + data = text_format.MessageToString(job) query = text( 'INSERT INTO ' 'job (experiment_id, work_unit_id, job_name, job_data) ' From 5bfdec25fcdc5d307e5b7a2e499d908fb4f63c18 Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Sun, 7 Sep 2025 09:41:50 -0400 Subject: [PATCH 2/5] Upgrade `alembic==1.16.5` and `sqlalchemy==1.4.54` --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 5052130..a9bc3f0 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ python_requires='>=3.10', install_requires=[ 'absl-py', - 'alembic==1.4.3', + 'alembic==1.16.5', 'async_generator', 'attrs', 'cloud-sql-python-connector', @@ -46,7 +46,7 @@ 'immutabledict', 'kubernetes', 'pyyaml', - 'sqlalchemy==1.2.19', + 'sqlalchemy==1.4.54', 'sqlparse', 'termcolor', ], From e7d9652c6c3846a02e1563c19aa0b3175b172ff7 Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Sun, 7 Sep 2025 09:44:56 -0400 Subject: [PATCH 3/5] Add simple test for database --- xmanager/xm_local/storage/database.py | 4 +- xmanager/xm_local/storage/database_test.py | 43 ++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 xmanager/xm_local/storage/database_test.py diff --git a/xmanager/xm_local/storage/database.py b/xmanager/xm_local/storage/database.py index 1168fdf..e0aaadf 100644 --- a/xmanager/xm_local/storage/database.py +++ b/xmanager/xm_local/storage/database.py @@ -115,7 +115,9 @@ def create_engine(settings: SqlConnectionSettings) -> Engine: "Can't use SqliteConnector with a backendother than `sqlite`" ) - if not os.path.isdir(os.path.dirname(settings.db_name)): + if settings.db_name != ':memory:' and not os.path.isdir( + os.path.dirname(settings.db_name) + ): os.makedirs(os.path.dirname(settings.db_name)) return GenericSqlConnector.create_engine(settings) diff --git a/xmanager/xm_local/storage/database_test.py b/xmanager/xm_local/storage/database_test.py new file mode 100644 index 0000000..059ffda --- /dev/null +++ b/xmanager/xm_local/storage/database_test.py @@ -0,0 +1,43 @@ +"""Tests for xmanager.xm_local.storage.database.""" + +import unittest + +from google.protobuf import text_format +from xmanager.generated import data_pb2 +from xmanager.xm_local.storage import database + + +class DatabaseTest(unittest.TestCase): + + def setUp(self): + super().setUp() + # Use an in-memory SQLite database for testing. + # The Database class will automatically create the schema. + settings = database.sqlite_settings(db_file=':memory:') + self.db = database.Database(database.SqliteConnector, settings) + + def test_insert_and_get_vertex_job(self): + experiment_id = 1 + work_unit_id = 1 + vertex_job_id = 'projects/p/locations/l/customJobs/123' + + self.db.insert_experiment(experiment_id, 'test_experiment') + self.db.insert_work_unit(experiment_id, work_unit_id) + + self.db.insert_vertex_job(experiment_id, work_unit_id, vertex_job_id) + + work_unit = self.db.get_work_unit(experiment_id, work_unit_id) + self.assertIn(vertex_job_id, work_unit.jobs) + + retrieved_job = work_unit.jobs[vertex_job_id] + self.assertEqual(retrieved_job.caip.resource_name, vertex_job_id) + + expected_job = data_pb2.Job() + text_format.Parse( + f'caip: {{ resource_name: "{vertex_job_id}" }}', expected_job + ) + self.assertEqual(retrieved_job, expected_job) + + +if __name__ == '__main__': + unittest.main() From 45b0e8b4fc5cbb8ad7c3cec8a36a1f0a3a356cec Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Sun, 7 Sep 2025 09:49:11 -0400 Subject: [PATCH 4/5] Fix code to work with `sqlalchemy==1.4.54` --- .../storage/alembic_migration_test.py | 57 ++++++++++ xmanager/xm_local/storage/database.py | 104 +++++++++++------- xmanager/xm_local/storage/database_test.py | 60 +++++----- 3 files changed, 151 insertions(+), 70 deletions(-) create mode 100644 xmanager/xm_local/storage/alembic_migration_test.py diff --git a/xmanager/xm_local/storage/alembic_migration_test.py b/xmanager/xm_local/storage/alembic_migration_test.py new file mode 100644 index 0000000..4325fd2 --- /dev/null +++ b/xmanager/xm_local/storage/alembic_migration_test.py @@ -0,0 +1,57 @@ +import unittest +import tempfile +import os +import sqlalchemy +from sqlalchemy import inspect +from alembic.config import Config +from alembic import command +from alembic import script as alembic_script + +class AlembicMigrationTest(unittest.TestCase): + + def setUp(self): + self.temp_db_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + self.temp_db_file.close() + self.db_url = f"sqlite:///{self.temp_db_file.name}" + + alembic_dir = os.path.join( + os.path.dirname(__file__), 'alembic' + ) + self.alembic_cfg = Config(os.path.join(alembic_dir, 'alembic.ini')) + self.alembic_cfg.set_main_option("script_location", alembic_dir) + self.alembic_cfg.set_main_option("sqlalchemy.url", self.db_url) + + self.engine = sqlalchemy.create_engine(self.db_url) + + def tearDown(self): + os.unlink(self.temp_db_file.name) + + def test_migrations_upgrade_and_downgrade(self): + with self.engine.connect() as connection: + inspector = inspect(connection) + self.assertEqual(inspector.get_table_names(), []) + + print(f"Upgrading database to head: {self.db_url}") + command.upgrade(self.alembic_cfg, "head") + + # Verify schema after upgrade + with self.engine.connect() as connection: + inspector = inspect(connection) + expected_tables = ['experiment', 'work_unit', 'job'] # Example tables from your existing migration + actual_tables = inspector.get_table_names() + + self.assertGreaterEqual(len(actual_tables), len(expected_tables), + "Not all expected tables were created during upgrade.") + for table in expected_tables: + self.assertIn(table, actual_tables, f"Table '{table}' not found after upgrade.") + + experiment_columns = [c['name'] for c in inspector.get_columns('experiment')] + self.assertIn('experiment_id', experiment_columns) + self.assertIn('experiment_title', experiment_columns) + + print(f"Attempting to downgrade database to base: {self.db_url}") + with self.assertRaises(RuntimeError): + command.downgrade(self.alembic_cfg, "base") + +if __name__ == '__main__': + unittest.main() diff --git a/xmanager/xm_local/storage/database.py b/xmanager/xm_local/storage/database.py index e0aaadf..6195a02 100644 --- a/xmanager/xm_local/storage/database.py +++ b/xmanager/xm_local/storage/database.py @@ -96,6 +96,7 @@ def create_engine(settings: SqlConnectionSettings) -> Engine: host=settings.host, port=settings.port, database=settings.db_name, + query={}, ) return sqlalchemy.engine.create_engine(url) @@ -169,9 +170,6 @@ def __init__( ): self.settings = settings self.engine: Engine = connector.create_engine(settings) - # https://github.com/sqlalchemy/sqlalchemy/issues/5645 - # TODO: Remove this line after using sqlalchemy>=1.14. - self.engine.dialect.description_encoding = None storage_dir = os.path.dirname(__file__) alembic_ini_path = os.path.join(storage_dir, 'alembic.ini') @@ -218,9 +216,7 @@ def maybe_migrate_database_version(self): """Enforces the latest version of the database to be used.""" db_version = self.database_version() with self.engine.connect() as connection: - legacy_sqlite_db = self.engine.dialect.has_table( - connection, 'VersionHistory' - ) + legacy_sqlite_db = sqlalchemy.inspect(connection).has_table('VersionHistory') need_to_update = ( db_version != self.latest_version_available() and db_version @@ -241,18 +237,25 @@ def insert_experiment( 'INSERT INTO experiment (experiment_id, experiment_title) ' 'VALUES (:experiment_id, :experiment_title)' ) - self.engine.execute( - query, experiment_id=experiment_id, experiment_title=experiment_title - ) + with self.engine.begin() as connection: + connection.execute( + query, + { + 'experiment_id': experiment_id, + 'experiment_title': experiment_title, + }, + ) def insert_work_unit(self, experiment_id: int, work_unit_id: int) -> None: query = text( 'INSERT INTO work_unit (experiment_id, work_unit_id) ' 'VALUES (:experiment_id, :work_unit_id)' ) - self.engine.execute( - query, experiment_id=experiment_id, work_unit_id=work_unit_id - ) + with self.engine.begin() as connection: + connection.execute( + query, + {'experiment_id': experiment_id, 'work_unit_id': work_unit_id}, + ) def insert_vertex_job( self, experiment_id: int, work_unit_id: int, vertex_job_id: str @@ -264,13 +267,16 @@ def insert_vertex_job( 'job (experiment_id, work_unit_id, job_name, job_data) ' 'VALUES (:experiment_id, :work_unit_id, :job_name, :job_data)' ) - self.engine.execute( - query, - experiment_id=experiment_id, - work_unit_id=work_unit_id, - job_name=vertex_job_id, - job_data=data, - ) + with self.engine.begin() as connection: + connection.execute( + query, + { + 'experiment_id': experiment_id, + 'work_unit_id': work_unit_id, + 'job_name': vertex_job_id, + 'job_data': data, + }, + ) def insert_kubernetes_job( self, experiment_id: int, work_unit_id: int, namespace: str, job_name: str @@ -287,19 +293,23 @@ def insert_kubernetes_job( 'job (experiment_id, work_unit_id, job_name, job_data) ' 'VALUES (:experiment_id, :work_unit_id, :job_name, :job_data)' ) - self.engine.execute( - query, - experiment_id=experiment_id, - work_unit_id=work_unit_id, - job_name=job_name, - job_data=data, - ) + with self.engine.begin() as connection: + connection.execute( + query, + { + 'experiment_id': experiment_id, + 'work_unit_id': work_unit_id, + 'job_name': job_name, + 'job_data': data, + }, + ) def list_experiment_ids(self) -> List[int]: """Lists all the experiment ids from local database.""" query = text('SELECT experiment_id FROM experiment') - rows = self.engine.execute(query) - return [r['experiment_id'] for r in rows] + with self.engine.connect() as connection: + rows = connection.execute(query).scalars().all() + return rows def get_experiment(self, experiment_id: int) -> ExperimentResult: """Gets an experiment from local database.""" @@ -307,11 +317,10 @@ def get_experiment(self, experiment_id: int) -> ExperimentResult: 'SELECT experiment_title FROM experiment ' 'WHERE experiment_id=:experiment_id' ) - rows = self.engine.execute(query, experiment_id=experiment_id) - title = None - for r in rows: - title = r['experiment_title'] - break + with self.engine.connect() as connection: + title = connection.execute( + query, {'experiment_id': experiment_id} + ).scalar_one_or_none() if title is None: raise ValueError(f"Experiment Id {experiment_id} doesn't exist.") return ExperimentResult( @@ -323,8 +332,16 @@ def list_work_units(self, experiment_id: int) -> List[WorkUnitResult]: query = text( 'SELECT work_unit_id FROM work_unit WHERE experiment_id=:experiment_id' ) - rows = self.engine.execute(query, experiment_id=experiment_id) - return [self.get_work_unit(experiment_id, r['work_unit_id']) for r in rows] + with self.engine.connect() as connection: + work_unit_ids = ( + connection.execute(query, {'experiment_id': experiment_id}) + .scalars() + .all() + ) + return [ + self.get_work_unit(experiment_id, work_unit_id) + for work_unit_id in work_unit_ids + ] def get_work_unit( self, experiment_id: int, work_unit_id: int @@ -335,13 +352,16 @@ def get_work_unit( 'WHERE experiment_id=:experiment_id ' 'AND work_unit_id=:work_unit_id' ) - rows = self.engine.execute( - query, experiment_id=experiment_id, work_unit_id=work_unit_id - ) - jobs = {} - for r in rows: - job = data_pb2.Job() - jobs[r['job_name']] = text_format.Parse(r['job_data'], job) + with self.engine.connect() as connection: + rows = connection.execute( + query, + {'experiment_id': experiment_id, 'work_unit_id': work_unit_id}, + ) + jobs = {} + for r in rows: + job = data_pb2.Job() + text_format.Parse(r.job_data, job) + jobs[r.job_name] = job return WorkUnitResult(work_unit_id, jobs) diff --git a/xmanager/xm_local/storage/database_test.py b/xmanager/xm_local/storage/database_test.py index 059ffda..1df0c87 100644 --- a/xmanager/xm_local/storage/database_test.py +++ b/xmanager/xm_local/storage/database_test.py @@ -1,43 +1,47 @@ """Tests for xmanager.xm_local.storage.database.""" import unittest - -from google.protobuf import text_format -from xmanager.generated import data_pb2 -from xmanager.xm_local.storage import database - +from unittest import mock +import os +import tempfile +from xmanager.xm_local.storage import database as db_module +from xmanager.xm_local import experiment as local_experiment class DatabaseTest(unittest.TestCase): - def setUp(self): - super().setUp() - # Use an in-memory SQLite database for testing. - # The Database class will automatically create the schema. - settings = database.sqlite_settings(db_file=':memory:') - self.db = database.Database(database.SqliteConnector, settings) + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.db_path = os.path.join(self.temp_dir.name, 'db.sqlite') + + settings = db_module.SqlConnectionSettings(backend='sqlite', db_name=self.db_path) - def test_insert_and_get_vertex_job(self): - experiment_id = 1 - work_unit_id = 1 - vertex_job_id = 'projects/p/locations/l/customJobs/123' + self.database = db_module.Database(db_module.SqliteConnector, settings) - self.db.insert_experiment(experiment_id, 'test_experiment') - self.db.insert_work_unit(experiment_id, work_unit_id) + self.patcher = mock.patch('xmanager.xm_local.experiment.database.database', return_value=self.database) + self.mock_db = self.patcher.start() - self.db.insert_vertex_job(experiment_id, work_unit_id, vertex_job_id) + def tearDown(self): + self.patcher.stop() + self.temp_dir.cleanup() - work_unit = self.db.get_work_unit(experiment_id, work_unit_id) - self.assertIn(vertex_job_id, work_unit.jobs) + def test_create_experiment(self): + with local_experiment.create_experiment(experiment_title='test_experiment_1') as experiment: + self.assertIsNotNone(experiment.experiment_id) - retrieved_job = work_unit.jobs[vertex_job_id] - self.assertEqual(retrieved_job.caip.resource_name, vertex_job_id) + with self.database.engine.connect() as connection: + result = connection.execute(db_module.text("SELECT * FROM experiment")) + rows = result.all() + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0].experiment_title, 'test_experiment_1') - expected_job = data_pb2.Job() - text_format.Parse( - f'caip: {{ resource_name: "{vertex_job_id}" }}', expected_job - ) - self.assertEqual(retrieved_job, expected_job) + with local_experiment.create_experiment(experiment_title='test_experiment_2') as experiment: + self.assertIsNotNone(experiment.experiment_id) + with self.database.engine.connect() as connection: + result = connection.execute(db_module.text("SELECT * FROM experiment")) + rows = result.all() + self.assertEqual(len(rows), 2) + self.assertEqual(rows[1].experiment_title, 'test_experiment_2') if __name__ == '__main__': - unittest.main() + unittest.main() From 09291e821a7e9a518542e0a568495e8117abc3c9 Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Sat, 6 Sep 2025 21:03:23 -0400 Subject: [PATCH 5/5] Upgrade `sqlalchemy==2.0.43` --- setup.py | 2 +- xmanager/xm_local/storage/database.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index a9bc3f0..d045ec6 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ 'immutabledict', 'kubernetes', 'pyyaml', - 'sqlalchemy==1.4.54', + 'sqlalchemy==2.0.43', 'sqlparse', 'termcolor', ], diff --git a/xmanager/xm_local/storage/database.py b/xmanager/xm_local/storage/database.py index 6195a02..652a3a6 100644 --- a/xmanager/xm_local/storage/database.py +++ b/xmanager/xm_local/storage/database.py @@ -89,7 +89,7 @@ def create_engine(settings: SqlConnectionSettings) -> Engine: f'+{settings.driver}' if settings.driver else '' ) - url = sqlalchemy.engine.url.URL( + url = sqlalchemy.URL.create( drivername=driver_name, username=settings.username, password=settings.password, @@ -157,7 +157,7 @@ def get_connection(): password=settings.password, db=settings.db_name) - url = sqlalchemy.engine.url.URL(drivername=f'{settings.backend}+{driver}', + url = sqlalchemy.URL.create(drivername=f'{settings.backend}+{driver}', host='localhost') return sqlalchemy.create_engine(url, creator=get_connection)