diff --git a/litemigration-cli.py b/litemigration-cli.py new file mode 100644 index 0000000..41a247a --- /dev/null +++ b/litemigration-cli.py @@ -0,0 +1,83 @@ +#!/usr/bin/python3 + +import argparse +import importlib + +from litemigration.database import Database + +def check_settings() -> dict: + """ + Looks for database.py file with the list of migrations + List of migrations variable should be migration_changes + Returns the module file + """ + try: + mod = importlib.import_module('database') + return { + 'database': mod.db, + 'changes': mod.MIGRATION_CHANGES + } + + except ModuleNotFoundError as error: + print(f'Unable to find database file: {error}') + exit() + except AttributeError as error: + print(f'Unable to find migration_changes: {error}') + exit() + + +def show_migrations(params): + """ + Show the status of the current migrations + """ + settings = check_settings() + db = settings['database'] + changes = settings['changes'] + table = db.show_migrations(changes) + print(table.table) + + +def migration(params): + """ + * Add new migrations + * Reverse existing migrations + """ + settings = check_settings() + db: Database = settings['database'] + changes = settings['changes'] + if params.direction == 'up': + db.add_migrations(changes) + elif params.direction == 'down' and params.dry: + if params.version == 0: + print("migration version needed") + exit() + else: + table = db.dry_run_reverse(params.version, changes) + print(table.table) + elif params.direction == 'down': + if params.version == 0: + print("migration version needed") + exit() + else: + table = db.reverse_migrations(params.version, changes) + print(table.table) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Manage migrations') + subparsers = parser.add_subparsers(help='sub-command help') + + show_migration = subparsers.add_parser('showmigrations', help='show current migrations') + show_migration.set_defaults(func=show_migrations) + + migrate = subparsers.add_parser('migrate', help='Run forward or reverse migrations') + migrate.add_argument('direction', choices=['up', 'down'], help='forward [up] or reverse migration [down]') + migrate.add_argument('version', help='Version number at which to stop the migration', nargs='?', type=int, default=0) + migrate.add_argument('--dry', help='Show migrations to be reversed or applied', action='store_const', const=True) + migrate.set_defaults(func=migration) + + args = parser.parse_args() + args.func(args) + + + diff --git a/litemigration/database.py b/litemigration/database.py index 39ec139..78a79bf 100755 --- a/litemigration/database.py +++ b/litemigration/database.py @@ -1,135 +1,270 @@ -#!/usr/bin/python3 - -import datetime as dt import logging -import sys -from typing import List, Tuple +import sqlite3 + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime +from typing import List + +from terminaltables import AsciiTable +from colorclass import Color + +try: + import psycopg2 +except ImportError: + pass + log = logging.getLogger(__name__) -class Database: - "Create migration control" - def __init__(self, db_type, host=None, port=None, user=None, - password=None, database=None): - self.db_type = db_type - self.host = host - self.port = port - self.user = user - self.password = password - self.database = database - self.details = "" - self.connect = self._get_connector() - self.cursor = self.connect.cursor() - - def _get_connector(self): - """' - Return database connection from specified database from user - """ - supported_databases = {'postgresql': self._postgresql, - 'sqlite': self._sqlite} - try: - connect = supported_databases[self.db_type]() - return connect - except KeyError: - log.critical("Unknown database or not supported") - exit() +@dataclass +class Migration: + version: int + up: str + down: str + + +class MigrationException(Exception): + pass + + +class Database(ABC): + def __init__(self): + self.connect = self.connection() + + @abstractmethod + def connection(self): + pass - def _get_initail_sql_migration(self) -> Tuple[str, str]: + @abstractmethod + def initialize(self): + pass + + @abstractmethod + def add_migrations(self, change: List[Migration]): + pass + + @abstractmethod + def reverse_migrations(self, version: int, change: List[Migration]): + pass + + def show_migrations(self, change: List[Migration]): + version = 0 + data = [] + data.append(['Applied', 'Version', 'Date']) + + cur = self.connect.cursor() + cur.execute('SELECT version, date FROM migration') + applied = cur.fetchall() + for m in applied: + data.append([Color('{autogreen}Yes{/autogreen}'), m[0], m[1]]) + version = m[0] + + for m in change: + if version < m.version: + data.append([Color('{autored}No{/autored}'), m.version]) + + table = AsciiTable(data) + return table + + def dry_run_reverse(self, version: int, change: List[Migration]): """ - Return 2 sql commands: - 1) Create migration - 2) Insert into migration table + Show which changes will be reversed (--dry) """ - sqlite_create = ("CREATE TABLE migration(" - 'id INTEGER PRIMARY KEY NOT NULL,' - 'version INTEGER UNIQUE NOT NULL,' - 'date TIMESTAMP NOT NULL)', - "INSERT INTO migration(version,date) VALUES(0,?)") - - pg_create = ("CREATE TABLE migration(" - 'id SERIAL PRIMARY KEY,' - 'version INTEGER NOT NULL,' - 'date DATE NOT NULL)', - "INSERT INTO migration(version,date) VALUES(0,%s)") - - all_sql = {'postgresql': pg_create, - 'sqlite': sqlite_create} - - return all_sql[self.db_type] - - def initialise(self): - "Create new database and add initial migration" - create_table, initial_insert = self._get_initail_sql_migration() + data = [] + data.append(['Reversed (Dummy)', 'Version']) + + cur = self.connect.cursor() + cur.execute('SELECT max(version) from migration') + (max_id,) = cur.fetchone() + if version > max_id: + raise MigrationException('version greater than max version unable to reverse') + + for migration in reversed(change): + if migration.version == version: + break + elif migration.version > version: + data.append([Color('{autogreen}Yes{/autogreen}'), migration.version]) + + table = AsciiTable(data) + return table + + +class SqliteDatabase(Database): + def __init__(self, name): + self.name = name + super().__init__() + + def connection(self): + connect = sqlite3.connect(self.name) + return connect + + def initialize(self): + query = ( + 'CREATE TABLE migration(' + 'id INTEGER PRIMARY KEY NOT NULL,' + 'version INTEGER UNIQUE NOT NULL,' + 'date TIMESTAMP NOT NULL)' + ) + + cur = self.connect.cursor() try: - self.cursor.execute(create_table) - self.cursor.execute(initial_insert, - (dt.datetime.now(),)) + cur.execute(query) + cur.execute("INSERT INTO migration(version,date) VALUES(1,?)", (datetime.now(),)) self.connect.commit() - log.info("Database has been created") - except Exception as e: - log.error("Unable to add migration table") - log.exception(e) - sys.exit() + except sqlite3.OperationalError as error: + log.info(f'Error creating migration table: {error}') + raise MigrationException(error) - def add_schema(self, change_list: List[Tuple[int, str]]): + def add_migrations(self, change: List[Migration]): + cur = self.connect.cursor() + cur.execute('SELECT max(version) from migration') + (max_version,) = cur.fetchone() + for migration in change: + if max_version >= migration.version: + log.info(f'migration {migration.version} already applied') + continue + + if migration.version - max_version != 1: + log.error(f'missing migration version before {migration.version}') + raise MigrationException('missing migration version before {}'.format(migration.version)) + try: + cur.execute(migration.up) + cur.execute("INSERT INTO migration(version,date) VALUES(?,?)", (migration.version, datetime.now())) + self.connect.commit() + print(f'Migration {migration.version} applied....' + Color('{autogreen}Ok{/autogreen}')) + max_version = migration.version + except sqlite3.OperationalError as error: + print(f'Migration {migration.version} applied....' + Color('{autored}Error{/autored}')) + log.error(f'unable to apply migration {migration.version}') + raise MigrationException(error) + + self.connect.close() + + def reverse_migrations(self, version: int, change: List[Migration]): """ - The first migration change should be version 1 + Migration version to revert to from max version. + So if max version is 10 and version choosen is 5. + Version 10 - 6 will be reverted + """ + cur = self.connect.cursor() + cur.execute('SELECT max(version) from migration') + (max_id,) = cur.fetchone() + if version > max_id: + raise MigrationException('version greater than max version unbale to reverse') + + for migration in reversed(change): + if migration.version == version: + break + elif migration.version > version: + cur.execute(migration.down) + cur.execute('DELETE FROM migration where version=?', (migration.version,)) + self.connect.commit() + self.connect.close() + + +class Postgresql(Database): + + def __init__(self, name, user, password, host, port=5432): + self.name = name + self.user = user + self.password = password + self.host = host + self.port = port + super().__init__() + + def connection(self): + """ + Check if psycopg2 is installed if using postgres database """ - if self.db_type == 'postgresql': - insert_sql = "INSERT INTO migration(version,date) VALUES(%s,%s)" - elif self.db_type == 'sqlite': - insert_sql = "INSERT INTO migration(version,date) VALUES(?,?)" - - for change_id, sql_statement in change_list: - self.cursor.execute('SELECT max(version) from migration') - (max_id,) = self.cursor.fetchone() - if max_id >= change_id: - log.info("schema change id {} is smaller than the latest" - "change".format(change_id)) - log.info("or schema change id has already been applied ") - else: - try: - self.cursor.execute(sql_statement) - self.cursor.execute(insert_sql, - (change_id, dt.datetime.now(),)) - self.connect.commit() - log.info("new schema added") - except Exception: - log.error("Unable to add schema {}".format(change_id), - exc_info=True) - sys.exit() - - def _postgresql(self): - "create postgresql connection and return the connection object" try: import psycopg2 - connect = psycopg2.connect(database=self.database, - host=self.host, - user=self.user, - password=self.password, - port=self.port) - return connect except ImportError: - log.error("Unable to find python3 postgresql module") - sys.exit() - except psycopg2.Error as e: - log.error("Unable to connect to postgresql") - log.exception(e) - sys.exit() - except psycopg2.OperationalError as e: - log.exception(e) - sys.exit() - - def _sqlite(self): + raise MigrationException('postgresql driver not installed') + + try: + conn = psycopg2.connect( + host=self.host, + database=self.name, + user=self.user, + password=self.password, + port=self.port + ) + except psycopg2.OperationalError as error: + log.error(f'Unable to connect to database: {error}') + raise MigrationException(f'Unable to connect to database: {error}') + else: + return conn + + def initialize(self): + query = ( + 'CREATE TABLE migration(' + 'id SERIAL PRIMARY KEY NOT NULL,' + 'version INTEGER UNIQUE NOT NULL,' + 'date TIMESTAMP NOT NULL)' + ) + + cur = self.connect.cursor() + try: + cur.execute(query) + cur.execute("INSERT INTO migration(version,date) VALUES(1, %s)", (datetime.now(),)) + self.connect.commit() + except (psycopg2.OperationalError, psycopg2.DatabaseError) as error: + log.info(f'Error creating migration table: {error}') + raise MigrationException(error) + + def add_migrations(self, change: List[Migration]): + cur = self.connect.cursor() + cur.execute('SELECT max(version) from migration') + (max_version,) = cur.fetchone() + for migration in change: + if max_version >= migration.version: + log.info(f'migration {migration.version} already applied') + continue + + if migration.version - max_version != 1: + log.error(f'missing migration version before {migration.version}') + raise MigrationException('missing migration version before {}'.format(migration.version)) + try: + cur.execute(migration.up) + cur.execute("INSERT INTO migration(version,date) VALUES(%s,%s)", (migration.version, datetime.now())) + self.connect.commit() + print(f'Migration {migration.version} applied....' + Color('{autogreen}Ok{/autogreen}')) + max_version = migration.version + except (psycopg2.OperationalError, psycopg2.DatabaseError) as error: + print(f'Migration {migration.version} applied....' + Color('{autored}Error{/autored}')) + log.error(f'unable to apply migration {migration.version}') + raise MigrationException(error) + + self.connect.close() + + def reverse_migrations(self, version: int, change: List[Migration]): """ - Create an sqlite connection and return the connection object + Migration version to revert to from max version. + So if max version is 10 and version chosen is 5. + Version 10 - 6 will be reverted """ - import sqlite3 - try: - connect = sqlite3.connect(self.database) - return connect - except sqlite3.OperationalError: - log.error("unable to connect to sqlite database", - exc_info=True) - sys.exit() + data = [] + data.append(['Reversed', 'Version']) + + cur = self.connect.cursor() + cur.execute('SELECT max(version) from migration') + (max_id,) = cur.fetchone() + if version > max_id: + raise MigrationException('version greater than max version unable to reverse') + + for migration in reversed(change): + if migration.version == version: + break + elif migration.version > version: + try: + cur.execute(migration.down) + cur.execute('DELETE FROM migration where version=%s', (migration.version,)) + self.connect.commit() + data.append([Color('{autogreen}Yes{/autogreen}'), migration.version]) + except (psycopg2.DatabaseError, psycopg2.OperationalError): + data.append([Color('{autored}Error{/autored}'), migration.version]) + self.connect.close() + table = AsciiTable(data) + return table diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..d1b0be2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +colorclass==2.2.2 +freezegun==1.2.2 +python-dateutil==2.8.2 +six==1.16.0 +terminaltables==3.1.10 diff --git a/setup.py b/setup.py index 44430d5..3c6dc01 100755 --- a/setup.py +++ b/setup.py @@ -4,11 +4,11 @@ if __name__ == '__main__': setup(name='litemigration', - version='1.1.1', - description='Simple simple module to help modify database changes in sqlite', + version='2.0.0', + description='Super simple module to help modify database changes in sqlite', author='Lunga Mthembu', - author_email='stumenz.complex@gmail.com', - url='https://github.com/stumenz/python3-litemigration', + author_email='midnight.complex@protonmail.com', + url='https://github.com/knightebsuku/python3-litemigration', license='GPL', packages=['litemigration'], ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test.py b/tests/test.py new file mode 100644 index 0000000..7bcb714 --- /dev/null +++ b/tests/test.py @@ -0,0 +1,80 @@ +import unittest +import os +from datetime import datetime +from freezegun import freeze_time + +from colorclass import Color +from litemigration.database import SqliteDatabase, Migration + + +@freeze_time('2022-01-01 00:00:00') +class TestMigration(unittest.TestCase): + + def setUp(self) -> None: + self.db = SqliteDatabase('test.db') + self.db.initialize() + self.migration_changes = [ + Migration( + version=2, + up='CREATE TABLE player(name VARCHAR NOT NULL,score INTEGER)', + down='DROP TABLE player' + ), + Migration( + version=3, + up='INSERT INTO player(name,score) VALUES("User", 10)', + down='DELETE FROM PLAYER where name="User"' + ) + ] + + def test_add_migration(self): + self.db.add_migrations(self.migration_changes) + + cur = self.db.connection().cursor() + cur.execute('SELECT max(version) FROM migration') + (max_id,) = cur.fetchone() + self.assertEqual(max_id, 3) + + def test_show_migrations(self): + data = [ + ['Applied', 'Version', 'Date'], + [Color('{autogreen}Yes{/autogreen}'), 1, str(datetime.now())], + [Color('{autored}No{/autored}'), 2], + [Color('{autored}No{/autored}'), 3] + ] + table = self.db.show_migrations(self.migration_changes) + self.assertEqual(data, table.table_data) + + def test_reverse_migrations(self): + self.db.add_migrations(self.migration_changes) + + self.db.connect = self.db.connection() + cur = self.db.connect.cursor() + cur.execute('SELECT max(version) FROM migration') + (max_id,) = cur.fetchone() + self.assertEqual(max_id, 3) + + self.db.connect = self.db.connection() + self.db.reverse_migrations(1, self.migration_changes) + + cur = self.db.connection().cursor() + cur.execute('SELECT max(version) FROM migration') + (max_id,) = cur.fetchone() + self.assertEqual(max_id, 1) + + def test_dry_reverse_migrations(self): + self.db.add_migrations(self.migration_changes) + data = [ + ['Reverse', 'Version'], + [Color('{autogreen}Yes{/autogreen}'), 3] + ] + # Reverse the last migration Version 3 + self.db.connect = self.db.connection() + table = self.db.dry_run_reverse(2, self.migration_changes) + self.assertEqual(data, table.table_data) + + def tearDown(self) -> None: + os.remove('test.db') + + +if __name__ == '__main__': + unittest.main()