diff --git a/setup.py b/setup.py index 5052130..d045ec6 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ python_requires='>=3.10', install_requires=[ 'absl-py', - 'alembic==1.4.3', + 'alembic==1.16.5', 'async_generator', 'attrs', 'cloud-sql-python-connector', @@ -46,7 +46,7 @@ 'immutabledict', 'kubernetes', 'pyyaml', - 'sqlalchemy==1.2.19', + 'sqlalchemy==2.0.43', 'sqlparse', 'termcolor', ], diff --git a/xmanager/xm_local/storage/alembic_migration_test.py b/xmanager/xm_local/storage/alembic_migration_test.py new file mode 100644 index 0000000..4325fd2 --- /dev/null +++ b/xmanager/xm_local/storage/alembic_migration_test.py @@ -0,0 +1,57 @@ +import unittest +import tempfile +import os +import sqlalchemy +from sqlalchemy import inspect +from alembic.config import Config +from alembic import command +from alembic import script as alembic_script + +class AlembicMigrationTest(unittest.TestCase): + + def setUp(self): + self.temp_db_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + self.temp_db_file.close() + self.db_url = f"sqlite:///{self.temp_db_file.name}" + + alembic_dir = os.path.join( + os.path.dirname(__file__), 'alembic' + ) + self.alembic_cfg = Config(os.path.join(alembic_dir, 'alembic.ini')) + self.alembic_cfg.set_main_option("script_location", alembic_dir) + self.alembic_cfg.set_main_option("sqlalchemy.url", self.db_url) + + self.engine = sqlalchemy.create_engine(self.db_url) + + def tearDown(self): + os.unlink(self.temp_db_file.name) + + def test_migrations_upgrade_and_downgrade(self): + with self.engine.connect() as connection: + inspector = inspect(connection) + self.assertEqual(inspector.get_table_names(), []) + + print(f"Upgrading database to head: {self.db_url}") + command.upgrade(self.alembic_cfg, "head") + + # Verify schema after upgrade + with self.engine.connect() as connection: + inspector = inspect(connection) + expected_tables = ['experiment', 'work_unit', 'job'] # Example tables from your existing migration + actual_tables = inspector.get_table_names() + + self.assertGreaterEqual(len(actual_tables), len(expected_tables), + "Not all expected tables were created during upgrade.") + for table in expected_tables: + self.assertIn(table, actual_tables, f"Table '{table}' not found after upgrade.") + + experiment_columns = [c['name'] for c in inspector.get_columns('experiment')] + self.assertIn('experiment_id', experiment_columns) + self.assertIn('experiment_title', experiment_columns) + + print(f"Attempting to downgrade database to base: {self.db_url}") + with self.assertRaises(RuntimeError): + command.downgrade(self.alembic_cfg, "base") + +if __name__ == '__main__': + unittest.main() diff --git a/xmanager/xm_local/storage/database.py b/xmanager/xm_local/storage/database.py index a103afc..652a3a6 100644 --- a/xmanager/xm_local/storage/database.py +++ b/xmanager/xm_local/storage/database.py @@ -89,13 +89,14 @@ def create_engine(settings: SqlConnectionSettings) -> Engine: f'+{settings.driver}' if settings.driver else '' ) - url = sqlalchemy.engine.url.URL( + url = sqlalchemy.URL.create( drivername=driver_name, username=settings.username, password=settings.password, host=settings.host, port=settings.port, database=settings.db_name, + query={}, ) return sqlalchemy.engine.create_engine(url) @@ -115,7 +116,9 @@ def create_engine(settings: SqlConnectionSettings) -> Engine: "Can't use SqliteConnector with a backendother than `sqlite`" ) - if not os.path.isdir(os.path.dirname(settings.db_name)): + if settings.db_name != ':memory:' and not os.path.isdir( + os.path.dirname(settings.db_name) + ): os.makedirs(os.path.dirname(settings.db_name)) return GenericSqlConnector.create_engine(settings) @@ -154,7 +157,7 @@ def get_connection(): password=settings.password, db=settings.db_name) - url = sqlalchemy.engine.url.URL(drivername=f'{settings.backend}+{driver}', + url = sqlalchemy.URL.create(drivername=f'{settings.backend}+{driver}', host='localhost') return sqlalchemy.create_engine(url, creator=get_connection) @@ -167,9 +170,6 @@ def __init__( ): self.settings = settings self.engine: Engine = connector.create_engine(settings) - # https://github.com/sqlalchemy/sqlalchemy/issues/5645 - # TODO: Remove this line after using sqlalchemy>=1.14. - self.engine.dialect.description_encoding = None storage_dir = os.path.dirname(__file__) alembic_ini_path = os.path.join(storage_dir, 'alembic.ini') @@ -216,9 +216,7 @@ def maybe_migrate_database_version(self): """Enforces the latest version of the database to be used.""" db_version = self.database_version() with self.engine.connect() as connection: - legacy_sqlite_db = self.engine.dialect.has_table( - connection, 'VersionHistory' - ) + legacy_sqlite_db = sqlalchemy.inspect(connection).has_table('VersionHistory') need_to_update = ( db_version != self.latest_version_available() and db_version @@ -239,36 +237,46 @@ def insert_experiment( 'INSERT INTO experiment (experiment_id, experiment_title) ' 'VALUES (:experiment_id, :experiment_title)' ) - self.engine.execute( - query, experiment_id=experiment_id, experiment_title=experiment_title - ) + with self.engine.begin() as connection: + connection.execute( + query, + { + 'experiment_id': experiment_id, + 'experiment_title': experiment_title, + }, + ) def insert_work_unit(self, experiment_id: int, work_unit_id: int) -> None: query = text( 'INSERT INTO work_unit (experiment_id, work_unit_id) ' 'VALUES (:experiment_id, :work_unit_id)' ) - self.engine.execute( - query, experiment_id=experiment_id, work_unit_id=work_unit_id - ) + with self.engine.begin() as connection: + connection.execute( + query, + {'experiment_id': experiment_id, 'work_unit_id': work_unit_id}, + ) def insert_vertex_job( self, experiment_id: int, work_unit_id: int, vertex_job_id: str ) -> None: job = data_pb2.Job(caip=data_pb2.AIPlatformJob(resource_name=vertex_job_id)) - data = text_format.MessageToBytes(job) + data = text_format.MessageToString(job) query = text( 'INSERT INTO ' 'job (experiment_id, work_unit_id, job_name, job_data) ' 'VALUES (:experiment_id, :work_unit_id, :job_name, :job_data)' ) - self.engine.execute( - query, - experiment_id=experiment_id, - work_unit_id=work_unit_id, - job_name=vertex_job_id, - job_data=data, - ) + with self.engine.begin() as connection: + connection.execute( + query, + { + 'experiment_id': experiment_id, + 'work_unit_id': work_unit_id, + 'job_name': vertex_job_id, + 'job_data': data, + }, + ) def insert_kubernetes_job( self, experiment_id: int, work_unit_id: int, namespace: str, job_name: str @@ -285,19 +293,23 @@ def insert_kubernetes_job( 'job (experiment_id, work_unit_id, job_name, job_data) ' 'VALUES (:experiment_id, :work_unit_id, :job_name, :job_data)' ) - self.engine.execute( - query, - experiment_id=experiment_id, - work_unit_id=work_unit_id, - job_name=job_name, - job_data=data, - ) + with self.engine.begin() as connection: + connection.execute( + query, + { + 'experiment_id': experiment_id, + 'work_unit_id': work_unit_id, + 'job_name': job_name, + 'job_data': data, + }, + ) def list_experiment_ids(self) -> List[int]: """Lists all the experiment ids from local database.""" query = text('SELECT experiment_id FROM experiment') - rows = self.engine.execute(query) - return [r['experiment_id'] for r in rows] + with self.engine.connect() as connection: + rows = connection.execute(query).scalars().all() + return rows def get_experiment(self, experiment_id: int) -> ExperimentResult: """Gets an experiment from local database.""" @@ -305,11 +317,10 @@ def get_experiment(self, experiment_id: int) -> ExperimentResult: 'SELECT experiment_title FROM experiment ' 'WHERE experiment_id=:experiment_id' ) - rows = self.engine.execute(query, experiment_id=experiment_id) - title = None - for r in rows: - title = r['experiment_title'] - break + with self.engine.connect() as connection: + title = connection.execute( + query, {'experiment_id': experiment_id} + ).scalar_one_or_none() if title is None: raise ValueError(f"Experiment Id {experiment_id} doesn't exist.") return ExperimentResult( @@ -321,8 +332,16 @@ def list_work_units(self, experiment_id: int) -> List[WorkUnitResult]: query = text( 'SELECT work_unit_id FROM work_unit WHERE experiment_id=:experiment_id' ) - rows = self.engine.execute(query, experiment_id=experiment_id) - return [self.get_work_unit(experiment_id, r['work_unit_id']) for r in rows] + with self.engine.connect() as connection: + work_unit_ids = ( + connection.execute(query, {'experiment_id': experiment_id}) + .scalars() + .all() + ) + return [ + self.get_work_unit(experiment_id, work_unit_id) + for work_unit_id in work_unit_ids + ] def get_work_unit( self, experiment_id: int, work_unit_id: int @@ -333,13 +352,16 @@ def get_work_unit( 'WHERE experiment_id=:experiment_id ' 'AND work_unit_id=:work_unit_id' ) - rows = self.engine.execute( - query, experiment_id=experiment_id, work_unit_id=work_unit_id - ) - jobs = {} - for r in rows: - job = data_pb2.Job() - jobs[r['job_name']] = text_format.Parse(r['job_data'], job) + with self.engine.connect() as connection: + rows = connection.execute( + query, + {'experiment_id': experiment_id, 'work_unit_id': work_unit_id}, + ) + jobs = {} + for r in rows: + job = data_pb2.Job() + text_format.Parse(r.job_data, job) + jobs[r.job_name] = job return WorkUnitResult(work_unit_id, jobs) diff --git a/xmanager/xm_local/storage/database_test.py b/xmanager/xm_local/storage/database_test.py new file mode 100644 index 0000000..1df0c87 --- /dev/null +++ b/xmanager/xm_local/storage/database_test.py @@ -0,0 +1,47 @@ +"""Tests for xmanager.xm_local.storage.database.""" + +import unittest +from unittest import mock +import os +import tempfile +from xmanager.xm_local.storage import database as db_module +from xmanager.xm_local import experiment as local_experiment + +class DatabaseTest(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.db_path = os.path.join(self.temp_dir.name, 'db.sqlite') + + settings = db_module.SqlConnectionSettings(backend='sqlite', db_name=self.db_path) + + self.database = db_module.Database(db_module.SqliteConnector, settings) + + self.patcher = mock.patch('xmanager.xm_local.experiment.database.database', return_value=self.database) + self.mock_db = self.patcher.start() + + def tearDown(self): + self.patcher.stop() + self.temp_dir.cleanup() + + def test_create_experiment(self): + with local_experiment.create_experiment(experiment_title='test_experiment_1') as experiment: + self.assertIsNotNone(experiment.experiment_id) + + with self.database.engine.connect() as connection: + result = connection.execute(db_module.text("SELECT * FROM experiment")) + rows = result.all() + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0].experiment_title, 'test_experiment_1') + + with local_experiment.create_experiment(experiment_title='test_experiment_2') as experiment: + self.assertIsNotNone(experiment.experiment_id) + + with self.database.engine.connect() as connection: + result = connection.execute(db_module.text("SELECT * FROM experiment")) + rows = result.all() + self.assertEqual(len(rows), 2) + self.assertEqual(rows[1].experiment_title, 'test_experiment_2') + +if __name__ == '__main__': + unittest.main()