diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 66ca3a1..c7ba80f 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -1,20 +1,17 @@ -# This workflow will install Python dependencies, run tests and lint with a single version of Python -# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python - name: Python application on: push: - branches: [ "main" ] + branches: ["main"] pull_request: - branches: [ "main" ] + branches: ["main"] permissions: contents: read jobs: build: - runs-on: self-hosted # self-hosted # ubuntu-latest + runs-on: ${{ matrix.os }} strategy: matrix: @@ -22,30 +19,92 @@ jobs: python-version: ["3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - id: setup-python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - # - name: Cache pip packages - # id: cache-pip - # uses: actions/cache@v4 - # with: - # path: ${{ steps.setup-python.outputs.python-packages-cache-dir }} - # key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements-dev.txt') }} - # restore-keys: | - # ${{ runner.os }}-pip-${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements-dev.txt - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - pytest -v test_optimize_sql_dump.py + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + id: setup-python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + + # =============================== + # INSTALL DEPENDENCIES: LINUX/MAC + # =============================== + - name: Install dependencies (Linux/macOS) + if: runner.os != 'Windows' + shell: bash + run: | + python -m pip install --upgrade pip + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + python -m pip install ruff black isort pytest coverage + + # =============================== + # INSTALL DEPENDENCIES: WINDOWS + # =============================== + - name: Install dependencies (Windows) + if: runner.os == 'Windows' + shell: pwsh + run: | + python -m pip install --upgrade pip + if (Test-Path "requirements.txt") { pip install -r requirements.txt } + python -m pip install ruff black isort pytest coverage + + # ===== LINT ===== + - name: Lint with ruff + run: ruff check . + + # ===== BLACK FORMAT CHECK (z diff) ===== + - name: Black format verification + run: | + black --check --diff . + + # ===== ISORT CHECK ===== + - name: Isort check + run: isort . 
--check-only --diff + + # ===== TESTS + COVERAGE ===== + - name: Run tests with coverage + run: | + coverage run -m pytest -q + coverage xml + coverage report -m + + # ===== UPLOAD COVERAGE ARTIFACT ===== + - name: Upload coverage report + uses: actions/upload-artifact@v4 + with: + name: coverage-${{ matrix.os }}-${{ matrix.python-version }} + path: coverage.xml + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + id: setup-python-2 + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + # - name: Cache pip packages + # id: cache-pip + # uses: actions/cache@v4 + # with: + # path: ${{ steps.setup-python.outputs.python-packages-cache-dir }} + # key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements-dev.txt') }} + # restore-keys: | + # ${{ runner.os }}-pip-${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + # Ensure dev tools used in CI are available + python -m pip install ruff black isort pytest + + - name: Lint with ruff + run: | + ruff check . + + - name: Format check (black) + run: | + black --check . 
+ + - name: Test with pytest + run: | + pytest -q diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..9b4aafc --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,46 @@ +repos: + # ------------------------ + # BLACK — formatting + # ------------------------ + - repo: https://github.com/psf/black + rev: 23.12.1 + hooks: + - id: black + language_version: python3 + + # ------------------------ + # ISORT — importy + # ------------------------ + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + + # ------------------------ + # RUFF — lint + autofix + # ------------------------ + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.3 + hooks: + - id: ruff + args: ["--fix"] + + # ------------------------ + # MYPY — typing (opcjonalnie) + # ------------------------ + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.9.0 + hooks: + - id: mypy + additional_dependencies: [] + + # ------------------------ + # PYTEST (pre-commit test) + # ------------------------ + # - repo: local + # hooks: + # - id: pytest + # name: Run pytest before commit + # entry: pytest + # language: system + # pass_filenames: false diff --git a/optimize_sql_dump.py b/optimize_sql_dump.py index d0b6f2d..528e0ac 100755 --- a/optimize_sql_dump.py +++ b/optimize_sql_dump.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- from __future__ import annotations import argparse @@ -10,12 +9,14 @@ import locale import lzma import os + # concurrent execution was attempted earlier but left incomplete; removed. 
import re import sys import warnings import zipfile from abc import ABC, abstractmethod +from typing import Optional try: from tqdm import tqdm @@ -34,7 +35,6 @@ except (locale.Error, IndexError): pass # Keep default locale if setting fails import gettext - import logging """ @@ -124,20 +124,13 @@ def open_maybe_compressed(path, mode="rt"): raise ValueError(tl("Empty zip file")) if len(names) > 1: warnings.warn( - tl( - "ZIP archive contains multiple files, using only the first one: {name}" - ).format(name=names[0]) + tl("ZIP archive contains multiple files, using only the first one: {name}").format(name=names[0]), + stacklevel=2, ) b = z.open(names[0], "r") - return ( - io.TextIOWrapper(b, encoding="utf-8", errors="replace") if text_mode else b - ) + return io.TextIOWrapper(b, encoding="utf-8", errors="replace") if text_mode else b else: - f = ( - open(path, mode, encoding="utf-8", errors="replace") - if text_mode - else open(path, mode) - ) + f = open(path, mode, encoding="utf-8", errors="replace") if text_mode else open(path, mode) return f @@ -189,9 +182,7 @@ class DatabaseHandler(ABC): def __init__(self): self.create_re = re.compile(r"^(CREATE\s+TABLE\b).*", re.IGNORECASE) - self.insert_re = re.compile( - r"^(INSERT\s+INTO\s+)(?P[^\s(]+)", re.IGNORECASE - ) + self.insert_re = re.compile(r"^(INSERT\s+INTO\s+)(?P
[^\s(]+)", re.IGNORECASE) self.copy_re = re.compile(r"^(COPY\s+)(?P
[^\s(]+)", re.IGNORECASE) self.insert_template = "INSERT INTO {table} {cols} VALUES\n{values};\n" self.validator: TypeValidator | None = None @@ -205,7 +196,7 @@ def get_truncate_statement(self, tname: str) -> str: pass @abstractmethod - def detect_db_type(path): + def detect_db_type(self, path): """Delegate to module-level detection (kept for backward compatibility).""" try: return detect_db_type(path) @@ -292,12 +283,8 @@ def get_load_statement(self, tname: str, tsv_path: str, cols_str: str) -> str: f"{cols_str};\n" ) - def extract_columns_with_types_from_create( - self, create_stmt: str - ) -> list[tuple[str, str]]: - m = re.search( - r"\((.*)\)\s*(ENGINE|TYPE|AS|COMMENT|;)", create_stmt, re.S | re.I - ) + def extract_columns_with_types_from_create(self, create_stmt: str) -> list[tuple[str, str]]: + m = re.search(r"\((.*)\)\s*(ENGINE|TYPE|AS|COMMENT|;)", create_stmt, re.S | re.I) if not m: m = re.search(r"CREATE\s+TABLE[^\(]*\((.*)\)\s*;", create_stmt, re.S | re.I) if not m or not self.validator: @@ -306,9 +293,7 @@ def extract_columns_with_types_from_create( cols_with_types = [] for line in cols_blob.splitlines(): line = line.strip().rstrip(",") - if not line or re.match( - r"PRIMARY\s+KEY|KEY\s+|UNIQUE\s+|CONSTRAINT\s+", line, re.I - ): + if not line or re.match(r"PRIMARY\s+KEY|KEY\s+|UNIQUE\s+|CONSTRAINT\s+", line, re.I): continue parsed = self.validator.parse_column_definition(line) if parsed: @@ -316,9 +301,7 @@ def extract_columns_with_types_from_create( return cols_with_types def extract_full_column_definitions(self, create_stmt: str) -> dict[str, str]: - m = re.search( - r"\((.*)\)\s*(ENGINE|TYPE|AS|COMMENT|;)", create_stmt, re.S | re.I - ) + m = re.search(r"\((.*)\)\s*(ENGINE|TYPE|AS|COMMENT|;)", create_stmt, re.S | re.I) if not m: m = re.search(r"CREATE\s+TABLE[^\(]*\((.*)\)\s*;", create_stmt, re.S | re.I) if not m: @@ -327,9 +310,7 @@ def extract_full_column_definitions(self, create_stmt: str) -> dict[str, str]: definitions = {} for line in 
cols_blob.splitlines(): line = line.strip().rstrip(",") - if not line or re.match( - r"PRIMARY\s+KEY|KEY\s+|UNIQUE\s+|CONSTRAINT\s+", line, re.I - ): + if not line or re.match(r"PRIMARY\s+KEY|KEY\s+|UNIQUE\s+|CONSTRAINT\s+", line, re.I): continue col_name = line.split()[0].strip('`"') definitions[col_name] = line @@ -340,9 +321,7 @@ def extract_primary_key(self, create_stmt: str) -> list[str]: if not m: return [] pk_blob = m.group(1) - pk_cols = [ - re.sub(r"\s*\(\d+\)", "", c.strip()).strip('`"') for c in pk_blob.split(",") - ] + pk_cols = [re.sub(r"\s*\(\d+\)", "", c.strip()).strip('`"') for c in pk_blob.split(",")] return pk_cols def extract_columns_from_create(self, create_stmt: str) -> str: @@ -624,7 +603,13 @@ def _emit_statement(self): self.in_backtick, self.is_escaped, self.paren_level, - ) = False, False, False, False, 0 + ) = ( + False, + False, + False, + False, + 0, + ) t = text.lstrip()[:30].upper() if t.startswith("CREATE TABLE") or t.startswith("CREATE TEMPORARY TABLE"): return "create", text @@ -690,26 +675,20 @@ def _setup_handler(self): if db_type == "auto": db_type = detect_db_type(self.args["inpath"]) if self.args.get("verbose"): - logger.info( - tl("[INFO] Detected DB type: {db_type}").format(db_type=db_type) - ) + logger.info(tl("[INFO] Detected DB type: {db_type}").format(db_type=db_type)) return MySQLHandler() if db_type == "mysql" else PostgresHandler() def _setup_progress(self): global progress filesize = os.path.getsize(self.args["inpath"]) if self.args.get("verbose") and tqdm: - progress = tqdm( - total=filesize, unit="B", unit_scale=True, desc=tl("Processing") - ) + progress = tqdm(total=filesize, unit="B", unit_scale=True, desc=tl("Processing")) else: progress = None return progress def _handle_create(self, stmt): - m = re.search( - r"CREATE\s+TABLE\s+(IF\s+NOT\s+EXISTS\s+)?(?P[^\s\(;]+)", stmt, re.I - ) + m = re.search(r"CREATE\s+TABLE\s+(IF\s+NOT\s+EXISTS\s+)?(?P[^\s\(;]+)", stmt, re.I) if not m: return tname = 
self.handler.normalize_table_name(m.group("name").strip()) @@ -738,11 +717,7 @@ def _values_to_tsv_row(self, values: list[str | None]) -> str: processed_values.append( "\\n" if v is None - else str(v) - .replace("\\", "\\\\") - .replace("\t", "\\t") - .replace("\n", "\\n") - .replace("\r", "\\r") + else str(v).replace("\\", "\\\\").replace("\t", "\\t").replace("\n", "\\n").replace("\r", "\\r") ) return "\t".join(processed_values) @@ -752,9 +727,7 @@ def _handle_insert(self, stmt): if not tname or (target_table and tname != target_table): return if self.progress: - self.progress.set_description( - tl("Processing table: {tname}").format(tname=tname) - ) + self.progress.set_description(tl("Processing table: {tname}").format(tname=tname)) prefix, values_body = extract_values_from_insert(stmt) if self.load_data_mode: if values_body: @@ -766,33 +739,29 @@ def _handle_insert(self, stmt): except Exception as e: if self.args.get("verbose"): logger.warning( - tl( - "[WARN] Failed to parse VALUES in INSERT for table {tname}: {error}" - ).format(tname=tname, error=e) + tl("[WARN] Failed to parse VALUES in INSERT for table {tname}: {error}").format( + tname=tname, error=e + ) ) return if not prefix or not values_body: # Fallback for unparsable statements # Use writer API to ensure statement is written to the correct output self.writer.write_statement(tname, stmt) return - cols_match = re.search( - r"INSERT\s+INTO\s+[^\(]+(\([^\)]*\))\s*VALUES", prefix, re.I | re.S - ) + cols_match = re.search(r"INSERT\s+INTO\s+[^\(]+(\([^\)]*\))\s*VALUES", prefix, re.I | re.S) cols_text = ( cols_match.group(1).strip() if cols_match - else self.handler.extract_columns_from_create( - self.create_map.get(tname, "") - ) + else self.handler.extract_columns_from_create(self.create_map.get(tname, "")) ) try: tuples = list(SqlMultiTupleParser(values_body)) if values_body else [] except Exception as e: if self.args.get("verbose"): logger.warning( - tl( - "[WARN] Failed to parse VALUES in INSERT for 
table {tname}: {error}" - ).format(tname=tname, error=e) + tl("[WARN] Failed to parse VALUES in INSERT for table {tname}: {error}").format( + tname=tname, error=e + ) ) tuples = [] if not tuples: @@ -849,29 +818,17 @@ def finalize(self): rows=self.writer.total_rows, batches=self.writer.total_batches ) ) - if ( - not self.args.get("dry_run") - and not self.split_mode - and not self.load_data_mode - ): + if not self.args.get("dry_run") and not self.split_mode and not self.load_data_mode: logger.info(tl("Done. Saved to: {path}").format(path=self.args["outpath"])) elif self.split_mode: - logger.info( - tl("Done. Split dump into files in directory: {path}").format( - path=self.args["split_dir"] - ) - ) + logger.info(tl("Done. Split dump into files in directory: {path}").format(path=self.args["split_dir"])) elif self.load_data_mode: logger.info( - tl("Done. Generated files for import in directory: {path}").format( - path=self.args["load_data_dir"] - ) + tl("Done. Generated files for import in directory: {path}").format(path=self.args["load_data_dir"]) ) elif self.writer.insert_only_mode: logger.info( - tl("Done. Generated insert-only files in directory: {path}").format( - path=self.args["insert_only"] - ) + tl("Done. 
Generated insert-only files in directory: {path}").format(path=self.args["insert_only"]) ) if self.progress: self.progress.close() @@ -886,15 +843,11 @@ def __init__(self, handler: DatabaseHandler, **kwargs): self.split_mode = bool(self.args.get("split_dir")) self.load_data_mode = bool(self.args.get("load_data_dir")) self.insert_only_mode = bool(self.args.get("insert_only")) - self.output_dir = ( - self.args.get("split_dir") - or self.args.get("load_data_dir") - or self.args.get("insert_only") - ) + self.output_dir = self.args.get("split_dir") or self.args.get("load_data_dir") or self.args.get("insert_only") self.fout = None - self.file_map = {} - self.insert_buffers = {} + self.file_map: dict[str, str] = {} + self.insert_buffers: dict[str, list[str]] = {} self.total_rows = 0 self.total_batches = 0 @@ -911,11 +864,7 @@ def setup(self): elif not self.args.get("dry_run"): self.fout = open(self.args["outpath"], "w", encoding="utf-8") self.fout.write(tl("-- Optimized by SqlDumpOptimizer\n")) - self.fout.write( - tl("-- Source: {source}\n").format( - source=os.path.basename(self.args["inpath"]) - ) - ) + self.fout.write(tl("-- Source: {source}\n").format(source=os.path.basename(self.args["inpath"]))) self.fout.write("--\n") else: self.fout = open(os.devnull, "w") @@ -951,9 +900,7 @@ def write_statement(self, tname, stmt): writer.write(stmt) def add_insert_tuples(self, tname, cols_text, tuples): - buf = self.insert_buffers.setdefault( - tname, {"cols_text": cols_text, "tuples": []} - ) + buf = self.insert_buffers.setdefault(tname, {"cols_text": cols_text, "tuples": []}) buf["tuples"].extend(tuples) self.total_rows += len(tuples) if len(buf["tuples"]) >= self.args.get("batch_size", 1000): @@ -986,9 +933,7 @@ def flush_tsv_buffer(self, tname, force=False): return tsv_info = self.file_map[tname] tsv_buffer_size = self.args.get("tsv_buffer_size", 200) - if tsv_info["tsv_buffer"] and ( - force or len(tsv_info["tsv_buffer"]) >= tsv_buffer_size - ): + if tsv_info["tsv_buffer"] 
and (force or len(tsv_info["tsv_buffer"]) >= tsv_buffer_size): tsv_info["tsv"].write("\n".join(tsv_info["tsv_buffer"]) + "\n") tsv_info["tsv_buffer"].clear() @@ -1001,12 +946,8 @@ def finalize(self, create_map): self.flush_tsv_buffer(tname, force=True) for tname, writers in self.file_map.items(): if "sql" in writers and not writers["sql"].closed: - cols_str = self.handler.extract_columns_from_create( - create_map.get(tname, "") - ) - load_stmt = self.handler.get_load_statement( - tname, writers["tsv_path"], cols_str - ) + cols_str = self.handler.extract_columns_from_create(create_map.get(tname, "")) + load_stmt = self.handler.get_load_statement(tname, writers["tsv_path"], cols_str) writers["sql"].write(load_stmt) def close_all(self): @@ -1037,18 +978,14 @@ def _setup_handler(self): if db_type == "auto": db_type = detect_db_type(self.args["inpath"]) if self.args.get("verbose"): - logger.info( - tl("[INFO] Detected DB type: {db_type}").format(db_type=db_type) - ) + logger.info(tl("[INFO] Detected DB type: {db_type}").format(db_type=db_type)) return MySQLHandler() if db_type == "mysql" else PostgresHandler() def _setup_progress(self): global progress filesize = os.path.getsize(self.args["inpath"]) if self.args.get("verbose") and tqdm: - progress = tqdm( - total=filesize, unit="B", unit_scale=True, desc=tl("Analyzing dump") - ) + progress = tqdm(total=filesize, unit="B", unit_scale=True, desc=tl("Analyzing dump")) else: progress = None return progress @@ -1072,9 +1009,7 @@ def run(self): self.stats[tname]["inserts"] += 1 _, values_body = extract_values_from_insert(stmt) if values_body: - self.stats[tname]["rows"] += len( - list(SqlMultiTupleParser(values_body)) - ) + self.stats[tname]["rows"] += len(list(SqlMultiTupleParser(values_body))) if self.progress: self.progress.close() self.print_summary() @@ -1138,31 +1073,13 @@ def display(self, outpath): logger.info(tl("Done. 
Diff saved to: {path}").format(path=outpath)) logger.info("\n" + tl("--- Diff Summary ---")) if not self.insert_only: - logger.info( - tl("Tables to create: {count}").format( - count=self.counts["tables_created"] - ) - ) - logger.info( - tl("Tables to alter: {count}").format( - count=self.counts["tables_altered"] - ) - ) + logger.info(tl("Tables to create: {count}").format(count=self.counts["tables_created"])) + logger.info(tl("Tables to alter: {count}").format(count=self.counts["tables_altered"])) if self.diff_data: - logger.info( - tl("Rows to insert: {count}").format(count=self.counts["rows_inserted"]) - ) + logger.info(tl("Rows to insert: {count}").format(count=self.counts["rows_inserted"])) if not self.insert_only: - logger.info( - tl("Rows to update: {count}").format( - count=self.counts["rows_updated"] - ) - ) - logger.info( - tl("Rows to delete: {count}").format( - count=self.counts["rows_deleted"] - ) - ) + logger.info(tl("Rows to update: {count}").format(count=self.counts["rows_updated"])) + logger.info(tl("Rows to delete: {count}").format(count=self.counts["rows_deleted"])) logger.info("--------------------\n") @@ -1225,9 +1142,7 @@ def get_db_primary_keys(self, table_name: str, pk_cols: list[str]) -> set: pass @abstractmethod - def get_db_row_by_pk( - self, table_name: str, pk_cols: list[str], pk_values: tuple - ) -> dict | None: + def get_db_row_by_pk(self, table_name: str, pk_cols: list[str], pk_values: tuple) -> dict | None: """Fetches a single row from the database by its primary key.""" pass @@ -1252,9 +1167,7 @@ def _store_pks_in_temp_table(self, table_name: str, pk_set: set) -> None: VALUES %s """ # Convert tuples to SQL value lists - values = [ - f"({', '.join(self._format_sql_value(v) for v in pk)})" for pk in pk_set - ] + values = [f"({', '.join(self._format_sql_value(v) for v in pk)})" for pk in pk_set] # Insert in batches to avoid too-long queries batch_size = 1000 for i in range(0, len(values), batch_size): @@ -1276,8 +1189,7 @@ def 
_ensure_temp_table_for_pks(self, table_name: str, pk_cols: list[str]) -> Non def _pk_exists_in_temp_table(self, table_name: str, pk_values: tuple) -> bool: """Checks if a primary key exists in the temporary table.""" where_clause = " AND ".join( - f"`{col}` = {self._format_sql_value(val)}" - for col, val in zip(self.create_map[table_name]["pk"], pk_values) + f"`{col}` = {self._format_sql_value(val)}" for col, val in zip(self.create_map[table_name]["pk"], pk_values) ) check_stmt = f""" SELECT 1 FROM `_tmp_pks_{table_name}` @@ -1288,76 +1200,80 @@ def _pk_exists_in_temp_table(self, table_name: str, pk_values: tuple) -> bool: # --- Schema and Data Comparison --- - def compare_schemas( - self, dump_cols: dict[str, str], db_cols: dict[str, dict], table_name: str - ) -> list[str]: + def compare_schemas(self, dump_cols: dict[str, str], db_cols: dict[str, dict], table_name: str) -> list[str]: alter_statements = [] last_col = None for col_name, dump_def in dump_cols.items(): if col_name not in db_cols: position = f" AFTER `{last_col}`" if last_col else " FIRST" - alter_statements.append( - f"ALTER TABLE `{table_name}` ADD COLUMN {dump_def}{position};" - ) + alter_statements.append(f"ALTER TABLE `{table_name}` ADD COLUMN {dump_def}{position};") else: db_col_info = db_cols.get(col_name, {}) # If DB metadata is incomplete (e.g., missing COLUMN_TYPE or IS_NULLABLE), skip deep comparison - if ( - not db_col_info - or not db_col_info.get("COLUMN_TYPE") - or ("IS_NULLABLE" not in db_col_info) - ): + if not db_col_info or not db_col_info.get("COLUMN_TYPE") or ("IS_NULLABLE" not in db_col_info): last_col = col_name continue # Build a normalized DB-side column definition for comparison normalized_dump_def = " ".join(dump_def.lower().split()) try: normalized_db_def = " ".join( - str(self._build_db_column_definition(col_name, db_col_info)) - .lower() - .split() + str(self._build_db_column_definition(col_name, db_col_info)).lower().split() ) # Ignore auto_increment when comparing; dump 
definitions may omit it. normalized_db_def = normalized_db_def.replace(" auto_increment", "") except Exception: normalized_db_def = "" if normalized_dump_def != normalized_db_def: - alter_statements.append( - f"ALTER TABLE `{table_name}` MODIFY COLUMN {dump_def};" - ) + alter_statements.append(f"ALTER TABLE `{table_name}` MODIFY COLUMN {dump_def};") last_col = col_name return alter_statements - def compare_data_row( - self, dump_row: dict, db_row: dict, table_name: str, pk_cols: list[str] - ) -> str | None: - updates = [] - params = [] - pk_values = [] - for col_name, dump_val in dump_row.items(): - db_val = db_row.get(col_name) - dump_val_str = str(dump_val) if dump_val is not None else None - # Normalize DB bytes to string for fair comparison - if isinstance(db_val, (bytes, bytearray)): + def compare_data_row( + self, + dump_row: dict[str, object], + db_row: dict[str, object], + table_name: str, + pk_cols: list[str], + ) -> Optional[str]: + updates: list[str] = [] + + def sql_literal(value: object) -> str: + if value is None: + return "NULL" + if isinstance(value, (bytes, bytearray)): try: - db_val_str = db_val.decode("utf-8") + value = value.decode("utf-8") except Exception: - db_val_str = str(db_val) - else: - db_val_str = str(db_val) if db_val is not None else None + value = str(value) + if isinstance(value, str): + escaped = value.replace("'", "''") + return f"'{escaped}'" + return str(value) + + # compare column values between dump and DB + for col_name, dump_val in dump_row.items(): + db_val = db_row.get(col_name) + + dump_val_str = sql_literal(dump_val) + db_val_str = sql_literal(db_val) + if dump_val_str != db_val_str: if col_name not in pk_cols: - updates.append(f"`{col_name}` = %s") - params.append(dump_val) + updates.append(f"`{col_name}` = {dump_val_str}") + if not updates: return None + + # build WHERE clause from primary key columns + where_clause_parts = [] for col in pk_cols: - pk_values.append(dump_row[col]) - params.extend(pk_values) - where_clause = " AND ".join([f"`{col}` = %s" for col in pk_cols]) 
- update_stmt = ( - f"UPDATE `{table_name}` SET {', '.join(updates)} WHERE {where_clause};" - ) + where_clause_parts.append(f"`{col}` = {sql_literal(dump_row[col])}") + + where_clause = " AND ".join(where_clause_parts) + + update_stmt = f"UPDATE `{table_name}` SET {', '.join(updates)} WHERE {where_clause};" + + return update_stmt def format_value_for_update(v): if v is None: @@ -1368,7 +1284,7 @@ def format_value_for_update(v): safe_val = str(v).replace("'", "''") return f"'{safe_val}'" - return update_stmt % tuple(format_value_for_update(p) for p in params) + # return update_stmt % tuple(format_value_for_update(p) for p in params) # --- Main Processing Logic --- @@ -1392,17 +1308,13 @@ def _handle_create_statement(self, stmt, fout): } if progress: - progress.set_description( - tl("Diffing schema for {tname}").format(tname=tname) - ) + progress.set_description(tl("Diffing schema for {tname}").format(tname=tname)) db_cols = self.get_db_schema(tname) if not db_cols: if not self.args.get("insert_only"): self.summary.increment("tables_created") - fout.write( - f"-- Table `{tname}` does not exist in the database.\n{stmt};\n" - ) + fout.write(f"-- Table `{tname}` does not exist in the database.\n{stmt};\n") self.create_map[tname]["exists_in_db"] = False elif not self.args.get("insert_only"): alter_statements = self.compare_schemas(dump_cols, db_cols, tname) @@ -1435,23 +1347,16 @@ def _handle_insert_statement(self, stmt, fout): if not table_info.get("pk"): if self.args.get("verbose") and not table_info.get("pk_checked"): logger.warning( - tl( - "[WARN] Skipping data diff for table `{tname}`: no primary key found." 
- ).format(tname=tname) + tl("[WARN] Skipping data diff for table `{tname}`: no primary key found.").format(tname=tname) ) table_info["pk_checked"] = True return # Check if we need to switch to temp table mem_info = self.memory_usage[tname] - if ( - mem_info["pk_count"] >= self.memory_limit - and not mem_info["using_temp_table"] - ): + if mem_info["pk_count"] >= self.memory_limit and not mem_info["using_temp_table"]: if progress: - progress.set_description( - tl("Creating temp table for {tname}").format(tname=tname) - ) + progress.set_description(tl("Creating temp table for {tname}").format(tname=tname)) self._create_temp_table_for_pks(tname, table_info["pk"]) mem_info["using_temp_table"] = True if "db_pks" in table_info: @@ -1461,9 +1366,7 @@ def _handle_insert_statement(self, stmt, fout): # Initialize PKs if needed if "db_pks" not in table_info: if progress: - progress.set_description( - tl("Diffing data for {tname}").format(tname=tname) - ) + progress.set_description(tl("Diffing data for {tname}").format(tname=tname)) if mem_info["using_temp_table"]: self._ensure_temp_table_for_pks(tname, table_info["pk"]) table_info["db_pks"] = None @@ -1476,9 +1379,7 @@ def _handle_insert_statement(self, stmt, fout): if not table_info.get("pk"): if self.args.get("verbose") and not table_info.get("pk_checked"): logger.warning( - tl( - "[WARN] Skipping data diff for table `{tname}`: no primary key found." 
- ).format(tname=tname) + tl("[WARN] Skipping data diff for table `{tname}`: no primary key found.").format(tname=tname) ) table_info["pk_checked"] = True return @@ -1489,9 +1390,7 @@ def _handle_insert_statement(self, stmt, fout): and not self.memory_usage[tname]["using_temp_table"] ): if progress: - progress.set_description( - tl("Creating temp table for {tname}").format(tname=tname) - ) + progress.set_description(tl("Creating temp table for {tname}").format(tname=tname)) self._create_temp_table_for_pks(tname, table_info["pk"]) self.memory_usage[tname]["using_temp_table"] = True if "db_pks" in table_info: @@ -1501,9 +1400,7 @@ def _handle_insert_statement(self, stmt, fout): if "db_pks" not in table_info: if progress: - progress.set_description( - tl("Diffing data for {tname}").format(tname=tname) - ) + progress.set_description(tl("Diffing data for {tname}").format(tname=tname)) if self.memory_usage[tname]["using_temp_table"]: # Use temp table for PKs self._ensure_temp_table_for_pks(tname, table_info["pk"]) @@ -1549,9 +1446,7 @@ def _handle_insert_statement(self, stmt, fout): elif not self.args.get("insert_only"): db_row = self.get_db_row_by_pk(tname, table_info["pk"], pk_values) if db_row: - update_stmt = self.compare_data_row( - dump_row_dict, db_row, tname, table_info["pk"] - ) + update_stmt = self.compare_data_row(dump_row_dict, db_row, tname, table_info["pk"]) if update_stmt: self.summary.increment("rows_updated") fout.write(f"{update_stmt}\n") @@ -1563,9 +1458,7 @@ def _generate_delete_statements(self, fout): if progress: progress.set_description(tl("Generating DELETE statements")) - fout.write( - "\n-- Deleting rows that exist in the database but not in the dump\n" - ) + fout.write("\n-- Deleting rows that exist in the database but not in the dump\n") for tname, table_info in self.create_map.items(): # If either side uses a temp table (db_pks or dump_pks is None), skip # the in-memory set difference. 
Production code could implement a @@ -1583,8 +1476,7 @@ def _generate_delete_statements(self, fout): self.summary.increment("rows_deleted", len(pks_to_delete)) for pk_tuple in pks_to_delete: where_clause = " AND ".join( - f"`{col}` = {self._format_sql_value(val)}" - for col, val in zip(pk_cols, pk_tuple) + f"`{col}` = {self._format_sql_value(val)}" for col, val in zip(pk_cols, pk_tuple) ) fout.write(f"DELETE FROM `{tname}` WHERE {where_clause};\n") @@ -1612,10 +1504,7 @@ def _process_statements(self, fin, fout): statements_count += 1 # Add COMMIT/START TRANSACTION every batch_size statements - if ( - self.use_transactions - and statements_count >= self.txn_batch_size - ): + if self.use_transactions and statements_count >= self.txn_batch_size: fout.write("\nCOMMIT;\nSTART TRANSACTION;\n\n") statements_count = 0 @@ -1693,7 +1582,8 @@ def _build_db_column_definition(self, col_name: str, db_col_info: dict) -> str: if isinstance(default, str) and default.upper() == "CURRENT_TIMESTAMP": parts.append(f"default {default.lower()}") elif isinstance(default, str): - parts.append(f"default '{default.replace("'", "''")}'") + # Escape single quotes in string defaults + parts.append("default '" + default.replace("'", "''") + "'") else: parts.append(f"default {default}") else: @@ -1715,9 +1605,7 @@ def _get_handler(self) -> DatabaseHandler: def connect_db(self): if not mysql: - raise ImportError( - "The 'mysql-connector-python' library is required for diffing with MySQL." 
- ) + raise ImportError("The 'mysql-connector-python' library is required for diffing with MySQL.") try: self.connection = mysql.connector.connect( host=self.args["db_host"], @@ -1728,18 +1616,16 @@ def connect_db(self): self.cursor = self.connection.cursor(dictionary=True) if self.args.get("verbose"): logger.info( - tl( - "[INFO] Successfully connected to database '{db}' on {host}" - ).format(db=self.args["db_name"], host=self.args["db_host"]) + tl("[INFO] Successfully connected to database '{db}' on {host}").format( + db=self.args["db_name"], host=self.args["db_host"] + ) ) except AttributeError as err: # Some environments may have an incomplete ssl module which causes # mysql.connector to fail when attempting to enable SSL. Retry with # SSL disabled as a fallback. logger.warning( - tl( - "[WARN] SSL not available, retrying DB connection with SSL disabled: {err}" - ).format(err=err) + tl("[WARN] SSL not available, retrying DB connection with SSL disabled: {err}").format(err=err) ) try: self.connection = mysql.connector.connect( @@ -1751,14 +1637,10 @@ def connect_db(self): ) self.cursor = self.connection.cursor(dictionary=True) except mysql.connector.Error as err2: - logger.error( - tl("[ERROR] Database connection failed: {error}").format(error=err2) - ) + logger.error(tl("[ERROR] Database connection failed: {error}").format(error=err2)) sys.exit(1) except mysql.connector.Error as err: - logger.error( - tl("[ERROR] Database connection failed: {error}").format(error=err) - ) + logger.error(tl("[ERROR] Database connection failed: {error}").format(error=err)) sys.exit(1) def get_db_schema(self, table_name: str) -> dict[str, dict]: @@ -1786,15 +1668,11 @@ def get_db_primary_keys(self, table_name: str, pk_cols: list[str]) -> set: keys.add(pk_tuple) if self.args.get("verbose"): logger.info( - tl("[INFO] Fetched {count} primary keys for table `{tname}`.").format( - count=len(keys), tname=table_name - ) + tl("[INFO] Fetched {count} primary keys for table 
`{tname}`.").format(count=len(keys), tname=table_name) ) return keys - def get_db_row_by_pk( - self, table_name: str, pk_cols: list[str], pk_values: tuple - ) -> dict | None: + def get_db_row_by_pk(self, table_name: str, pk_cols: list[str], pk_values: tuple) -> dict | None: if not self.connection or not pk_cols or len(pk_cols) != len(pk_values): return None where_clause = " AND ".join([f"`{col}` = %s" for col in pk_cols]) @@ -1814,17 +1692,13 @@ def connect_db(self): def _load_config(config_file="optimize_sql_dump.ini"): - config = configparser.ConfigParser( - allow_no_value=True, inline_comment_prefixes=("#", ";") - ) + config = configparser.ConfigParser(allow_no_value=True, inline_comment_prefixes=("#", ";")) config_defaults = {} boolean_flags = {"verbose", "dry_run", "diff_from_db", "diff_data", "info"} boolean_like_flags = {"split", "load_data_dir", "insert_only"} if os.path.exists(config_file) and os.path.getsize(config_file) > 0: config.read(config_file) - _parse_config_sections( - config, config_defaults, boolean_flags, boolean_like_flags - ) + _parse_config_sections(config, config_defaults, boolean_flags, boolean_like_flags) return config_defaults @@ -1877,9 +1751,7 @@ def load_section(section_name, mapping): for key, dest in mapping.items(): if key in config[section_name]: if dest in boolean_flags or dest in boolean_like_flags: - if config[section_name][key] is None or config.getboolean( - section_name, key - ): + if config[section_name][key] is None or config.getboolean(section_name, key): config_defaults[dest] = True else: config_defaults[dest] = config.get(section_name, key) @@ -1916,12 +1788,8 @@ def _create_arg_parser(p: argparse.ArgumentParser) -> argparse.ArgumentParser: type=int, help=tl("Number of tuples in a single merged INSERT (default: 1000)"), ) - p.add_argument( - "--verbose", "-v", action="store_true", help=tl("Print diagnostic information") - ) - p.add_argument( - "--dry-run", action="store_true", help=tl("Dry run: does not write output") 
- ) + p.add_argument("--verbose", "-v", action="store_true", help=tl("Print diagnostic information")) + p.add_argument("--dry-run", action="store_true", help=tl("Dry run: does not write output")) # --- Mutually Exclusive Output Modes --- output_mode_group = p.add_mutually_exclusive_group() @@ -1945,17 +1813,13 @@ def _create_arg_parser(p: argparse.ArgumentParser) -> argparse.ArgumentParser: nargs="?", const=".", dest="load_data_dir", - help=tl( - "[MySQL] Generate files for LOAD DATA. Optional dir, defaults to current." - ), + help=tl("[MySQL] Generate files for LOAD DATA. Optional dir, defaults to current."), ) output_mode_group.add_argument( "--insert-only", nargs="?", const=".", - help=tl( - "Generate insert-only files (TRUNCATE + INSERTs). Optional dir, defaults to current." - ), + help=tl("Generate insert-only files (TRUNCATE + INSERTs). Optional dir, defaults to current."), ) output_mode_group.add_argument( "--info", @@ -1995,9 +1859,7 @@ def _create_arg_parser(p: argparse.ArgumentParser) -> argparse.ArgumentParser: diff_group.add_argument( "--diff-data", action="store_true", - help=tl( - "Also compare table data and generate INSERT/UPDATE/DELETE statements (requires --diff-from-db)." 
- ), + help=tl("Also compare table data and generate INSERT/UPDATE/DELETE statements (requires --diff-from-db)."), ) diff_group.add_argument("--db-host", help=tl("Database host for diffing.")) diff_group.add_argument("--db-user", help=tl("Database user for diffing.")) @@ -2024,34 +1886,26 @@ def _validate_args(p, args): sys.exit(2) # Check for at least one output mode if not using --diff-from-db - is_output_mode_set = any( - [args.output, args.split_dir, args.load_data_dir, args.insert_only, args.info] - ) + is_output_mode_set = any([args.output, args.split_dir, args.load_data_dir, args.insert_only, args.info]) if not is_output_mode_set and not args.diff_from_db: p.error( - tl( - "You must specify an output mode (e.g., `script.py in.sql out.sql` or use a flag like --split, --info)." - ) + tl("You must specify an output mode (e.g., `script.py in.sql out.sql` or use a flag like --split, --info).") ) # Append .sql to output filename if needed if args.output and not args.output.lower().endswith(".sql"): if args.verbose: logger.info( - tl( - "[INFO] Output filename does not end with .sql, appending it. New name: {name}" - ).format(name=args.output + ".sql") + tl("[INFO] Output filename does not end with .sql, appending it. New name: {name}").format( + name=args.output + ".sql" + ) ) args.output += ".sql" # Validate --diff-from-db dependencies if args.diff_from_db: if not mysql: - p.error( - tl( - "The 'mysql-connector-python' library is required for --diff-from-db. Please install it." - ) - ) + p.error(tl("The 'mysql-connector-python' library is required for --diff-from-db. Please install it.")) if not args.output: p.error(tl("--diff-from-db requires --output to be specified.")) if not args.db_user or not args.db_name: @@ -2065,9 +1919,7 @@ def _validate_args(p, args): def set_parse_arguments_and_config(): parser = argparse.ArgumentParser( - description=tl( - "SQL Dump Optimizer: merges INSERTs, detects compression, supports MySQL/Postgres." 
- ) + description=tl("SQL Dump Optimizer: merges INSERTs, detects compression, supports MySQL/Postgres.") ) config_defaults = _load_config() parser.set_defaults(**config_defaults) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..e20ccd9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,35 @@ +[tool.black] +line-length = 120 +target-version = ["py310", "py311", "py312"] + +[tool.isort] +profile = "black" +line_length = 120 +multi_line_output = 3 +include_trailing_comma = true +use_parentheses = true + +[tool.pytest.ini_options] +addopts = "-q" +testpaths = ["tests"] + +[tool.coverage.run] +branch = true +source = ["."] +omit = ["tests/*"] + +[tool.coverage.report] +skip_empty = true + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = ["E", "F", "W", "Q"] + +[tool.ruff.lint.per-file-ignores] +"test_optimize_sql_dump.py" = ["E501"] + +[tool.mypy] +check_untyped_defs = false +disable_error_code = ["annotation-unchecked"] diff --git a/pytest.ini b/pytest.ini index 15c7015..38a4e89 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,3 @@ [pytest] -testpaths = tests -norecursedirs = .git venv .venv __pycache__ build dist .eggs *.egg-info actions-runner \ No newline at end of file +# testpaths = tests +norecursedirs = .git venv .venv __pycache__ build dist .eggs *.egg-info actions-runner diff --git a/requirements-dev.txt b/requirements-dev.txt index 791956f..e10d5ba 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,22 +1,39 @@ -# Include production dependencies --r requirements.txt - -# Testing allure-pytest==2.15.0 +allure-python-commons==2.15.0 +attrs==25.3.0 coverage==7.10.6 +DateTime==5.5 +execnet==2.1.1 +gherkin-official==29.0.0 +iniconfig==2.1.0 +isort==7.0.0 +Mako==1.3.10 +MarkupSafe==3.0.2 +mysql-connector-python==8.0.33 +mysqlclient==2.2.7 +packaging==25.0 +parse==1.20.2 +parse_type==0.6.6 +pep8==1.7.1 +pluggy==1.6.0 +protobuf==3.20.3 +py==1.11.0 +pyflakes==3.4.0 +Pygments==2.19.2 pytest==8.4.2 
pytest-bdd==8.1.0 +pytest-cache==1.0 pytest-cov==7.0.0 +pytest-flakes==4.0.5 pytest-instafail==0.5.0 +pytest-pep8==1.0.6 pytest-timeout==2.4.0 pytest-xdist==3.8.0 - -# Linting -flake8==7.1.0 - -# Dependencies of testing/linting tools -allure-python-commons==2.15.0 -execnet==2.1.1 -gherkin-official==29.0.0 -iniconfig==2.1.0 -pluggy==1.6.0 +pytz==2025.2 +salt==3007.8 +setuptools==80.9.0 +six==1.17.0 +tabulate==0.9.0 +tqdm==4.67.1 +typing_extensions==4.15.0 +zope.interface==8.0 diff --git a/requirements.txt b/requirements.txt index 65de3a6..e10d5ba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,37 @@ +allure-pytest==2.15.0 +allure-python-commons==2.15.0 attrs==25.3.0 +coverage==7.10.6 DateTime==5.5 +execnet==2.1.1 +gherkin-official==29.0.0 +iniconfig==2.1.0 +isort==7.0.0 Mako==1.3.10 MarkupSafe==3.0.2 mysql-connector-python==8.0.33 +mysqlclient==2.2.7 packaging==25.0 parse==1.20.2 parse_type==0.6.6 +pep8==1.7.1 +pluggy==1.6.0 protobuf==3.20.3 -python-dotenv==1.0.0 +py==1.11.0 +pyflakes==3.4.0 +Pygments==2.19.2 +pytest==8.4.2 +pytest-bdd==8.1.0 +pytest-cache==1.0 +pytest-cov==7.0.0 +pytest-flakes==4.0.5 +pytest-instafail==0.5.0 +pytest-pep8==1.0.6 +pytest-timeout==2.4.0 +pytest-xdist==3.8.0 pytz==2025.2 +salt==3007.8 +setuptools==80.9.0 six==1.17.0 tabulate==0.9.0 tqdm==4.67.1 diff --git a/test_optimize_sql_dump.py b/test_optimize_sql_dump.py index 0896f0a..142d929 100644 --- a/test_optimize_sql_dump.py +++ b/test_optimize_sql_dump.py @@ -6,15 +6,13 @@ from unittest.mock import MagicMock, patch import pytest + import optimize_sql_dump as opt # Helper to get the main function for CLI tests +from optimize_sql_dump import escape_sql_value # Assuming the function is in this module from optimize_sql_dump import main as cli_main -from optimize_sql_dump import ( - escape_sql_value, -) # Assuming the function is in this module - @pytest.fixture def mysql_validator(): @@ -93,9 +91,7 @@ def test_validate_postgres(self, postgres_validator): class 
TestPostgresHandler: def test_normalize_table_name_postgres(self, postgres_handler): - assert ( - postgres_handler.normalize_table_name('"public"."my_table"') == "my_table" - ) + assert postgres_handler.normalize_table_name('"public"."my_table"') == "my_table" assert postgres_handler.normalize_table_name("my_table") == "my_table" assert postgres_handler.normalize_table_name('"my_table"') == "my_table" @@ -108,9 +104,7 @@ def test_get_load_statement_postgres(self, postgres_handler): tsv_path = "/path/to/my_table.tsv" stmt = postgres_handler.get_load_statement("my_table", tsv_path, cols_str) expected_stmt = "COPY \"my_table\" (\"id\", \"name\") FROM '/path/to/my_table.tsv' WITH (FORMAT csv, DELIMITER E'\\t', NULL '\\n');\n" - assert ( - stmt == expected_stmt - ) + assert stmt == expected_stmt def test_extract_columns_from_create_postgres(self, postgres_handler): create_stmt = """ @@ -123,10 +117,7 @@ def test_extract_columns_from_create_postgres(self, postgres_handler): """ columns_str = postgres_handler.extract_columns_from_create(create_stmt) # Normalize returned string to a list of column names and compare as a set for robustness - cols = [ # noqa: E501 - c.strip().strip('"') - for c in columns_str.strip().lstrip("(").rstrip(")").split(",") - ] + cols = [c.strip().strip('"') for c in columns_str.strip().lstrip("(").rstrip(")").split(",")] # noqa: E501 assert set(cols) == {"id", "name", "created_at"} def test_detect_db_type_postgres_specific(self, tmp_path): @@ -225,9 +216,7 @@ def test_dump_analyzer(tmp_path, capsys): # Simulate running from command line with --info test_args = ["optimize_sql_dump.py", "--input", str(dump_file), "--info"] with patch.object(sys, "argv", test_args): - with patch( - "optimize_sql_dump._load_config", return_value={} - ): # Mock _load_config + with patch("optimize_sql_dump._load_config", return_value={}): # Mock _load_config opt.main() captured = capsys.readouterr() @@ -265,9 +254,7 @@ def test_cli_split_mode(tmp_path): str(split_dir), 
] with patch.object(sys, "argv", test_args): - with patch( - "optimize_sql_dump._load_config", return_value={} - ): # Mock _load_config + with patch("optimize_sql_dump._load_config", return_value={}): # Mock _load_config cli_main() assert (split_dir / "t1.sql").exists() @@ -308,9 +295,7 @@ def test_cli_load_data_mode(tmp_path): str(load_data_dir), ] with patch.object(sys, "argv", test_args): - with patch( - "optimize_sql_dump._load_config", return_value={} - ): # Mock _load_config + with patch("optimize_sql_dump._load_config", return_value={}): # Mock _load_config cli_main() sql_file = load_data_dir / "users.sql" @@ -350,16 +335,15 @@ def test_cli_invalid_arguments(tmp_path, invalid_args): with pytest.raises(SystemExit) as e: with patch.object(sys, "argv", base_args + invalid_args): - with patch( - "optimize_sql_dump._load_config", return_value={} - ): # Mock _load_config + with patch("optimize_sql_dump._load_config", return_value={}): # Mock _load_config cli_main() assert e.type is SystemExit assert e.value.code != 0 # Ensure it's an error exit code @pytest.mark.parametrize( - "input_val, expected", [ + "input_val, expected", + [ # Tests for strings ("tekst", "DEFAULT 'tekst'"), ("O'Reilly", "DEFAULT 'O''Reilly'"), @@ -376,7 +360,8 @@ def test_cli_invalid_arguments(tmp_path, invalid_args): (False, "DEFAULT False"), # Test for None (None, "DEFAULT NULL"), - ]) + ], +) def test_escape_sql_value(input_val, expected): assert escape_sql_value(input_val, prefix_str="DEFAULT").strip() == expected @@ -397,7 +382,8 @@ def differ(self, tmp_path): return opt.MySQLDatabaseDiffer(inpath=str(dummy_inpath), verbose=False) @pytest.mark.parametrize( - "input_val, expected_sql", [ + "input_val, expected_sql", + [ (None, "NULL"), (123, "123"), (-45, "-45"), @@ -409,7 +395,8 @@ def differ(self, tmp_path): ('string with "quotes"', "'string with \"quotes\"'"), ("O'Reilly", "'O''Reilly'"), ("multiple 'quotes' here", "'multiple ''quotes'' here'"), - ]) + ], + ) def 
test_format_sql_value(self, differ, input_val, expected_sql): assert differ._format_sql_value(input_val) == expected_sql @@ -599,9 +586,7 @@ def test_build_db_column_definition(self, differ, col_name, db_col_info, expecte ), ], ) - def test_compare_data_row_edge_cases( - self, differ, dump_row, db_row, pk_cols, expected_fragment - ): + def test_compare_data_row_edge_cases(self, differ, dump_row, db_row, pk_cols, expected_fragment): """Tests various edge cases for data row comparison.""" update_stmt = differ.compare_data_row(dump_row, db_row, "users", pk_cols) if expected_fragment: @@ -616,10 +601,7 @@ def test_run_generates_delete(self, tmp_path): """ in_file = tmp_path / "dump.sql" out_file = tmp_path / "diff.sql" - in_file.write_text( - "CREATE TABLE `users` (`id` int, PRIMARY KEY (`id`));\n" - "INSERT INTO `users` VALUES (1);" - ) + in_file.write_text("CREATE TABLE `users` (`id` int, PRIMARY KEY (`id`));\n" "INSERT INTO `users` VALUES (1);") args = { "inpath": str(in_file), @@ -641,7 +623,7 @@ def test_run_generates_delete(self, tmp_path): differ.connect_db = MagicMock() differ.get_db_schema = MagicMock(return_value={"id": {}}) # Table exists # DB has PKs (1,) and (2,). Dump only has (1,). So (2,) should be deleted. 
- differ.get_db_primary_keys = MagicMock(return_value={('1',), ('2',)}) + differ.get_db_primary_keys = MagicMock(return_value={("1",), ("2",)}) differ.get_db_row_by_pk = MagicMock(return_value={"id": 1}) differ.run() @@ -657,10 +639,7 @@ def test_handle_insert_uses_temp_table(tmp_path): """Ensure that when memory limit is exceeded, PKs are stored via temp-table helper.""" in_file = tmp_path / "dump.sql" out_file = tmp_path / "diff.sql" - in_file.write_text( - "CREATE TABLE `users` (`id` int, PRIMARY KEY (`id`));\n" - "INSERT INTO `users` VALUES (3);" - ) + in_file.write_text("CREATE TABLE `users` (`id` int, PRIMARY KEY (`id`));\n" "INSERT INTO `users` VALUES (3);") args = { "inpath": str(in_file), @@ -682,7 +661,7 @@ def test_handle_insert_uses_temp_table(tmp_path): differ.connect_db = MagicMock() differ.get_db_schema = MagicMock(return_value={"id": {}}) # Simulate that fetching DB PKs would return something small - differ.get_db_primary_keys = MagicMock(return_value={('1',)}) + differ.get_db_primary_keys = MagicMock(return_value={("1",)}) # Force memory limit to zero so the code chooses temp-table path differ.memory_limit = 0 @@ -707,7 +686,7 @@ class TestDumpWriter: @pytest.fixture def mock_handler(self): handler = MagicMock(spec=opt.MySQLHandler) - handler.normalize_table_name.side_effect = lambda x: x.strip('`') + handler.normalize_table_name.side_effect = lambda x: x.strip("`") handler.insert_template = "INSERT INTO {table} {cols} VALUES\n{values};\n" handler.get_truncate_statement.side_effect = lambda t: f"TRUNCATE TABLE `{t}`;\n" handler.extract_columns_from_create.return_value = "(`id`, `name`)" @@ -727,6 +706,7 @@ def test_setup_normal_mode(self, mock_handler, tmp_path): def test_setup_dry_run(self, mock_handler): import os + args = {"outpath": "out.sql", "inpath": "dummy.sql", "dry_run": True} with opt.DumpWriter(mock_handler, **args) as writer: assert writer.fout.name == os.devnull @@ -762,7 +742,7 @@ def test_insert_only_mode(self, mock_handler, 
tmp_path): "inpath": str(in_file), "verbose": False, "db_type": "mysql", - "outpath": None + "outpath": None, } optimizer = opt.DumpOptimizer(**args) @@ -797,18 +777,20 @@ def test_tsv_buffering_and_flushing(self, tmp_path): """Tests TSV buffering and flushing logic within DumpOptimizer.""" load_dir = tmp_path / "load_data" in_file = tmp_path / "in.sql" - in_file.write_text(""" + in_file.write_text( + """ CREATE TABLE `t1` (`id` int, `name` varchar(10)); INSERT INTO `t1` VALUES (1, 'a'), (2, 'b'); INSERT INTO `t1` VALUES (3, 'c'); - """) + """ + ) args = { "load_data_dir": str(load_dir), "inpath": str(in_file), "tsv_buffer_size": 2, "verbose": False, - "db_type": "mysql" + "db_type": "mysql", } optimizer = opt.DumpOptimizer(**args)