diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml
index 66ca3a1..c7ba80f 100644
--- a/.github/workflows/python-app.yml
+++ b/.github/workflows/python-app.yml
@@ -1,20 +1,17 @@
-# This workflow will install Python dependencies, run tests and lint with a single version of Python
-# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
-
name: Python application
on:
push:
- branches: [ "main" ]
+ branches: ["main"]
pull_request:
- branches: [ "main" ]
+ branches: ["main"]
permissions:
contents: read
jobs:
build:
- runs-on: self-hosted # self-hosted # ubuntu-latest
+ runs-on: ${{ matrix.os }}
strategy:
matrix:
@@ -22,30 +19,60 @@ jobs:
         python-version: ["3.10", "3.11", "3.12"]
     steps:
-      - uses: actions/checkout@v4
-      - name: Set up Python ${{ matrix.python-version }}
-        id: setup-python
-        uses: actions/setup-python@v5
-        with:
-          python-version: ${{ matrix.python-version }}
-      # - name: Cache pip packages
-      #   id: cache-pip
-      #   uses: actions/cache@v4
-      #   with:
-      #     path: ${{ steps.setup-python.outputs.python-packages-cache-dir }}
-      #     key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements-dev.txt') }}
-      #     restore-keys: |
-      #       ${{ runner.os }}-pip-${{ matrix.python-version }}
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install -r requirements-dev.txt
-      - name: Lint with flake8
-        run: |
-          # stop the build if there are Python syntax errors or undefined names
-          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
-      - name: Test with pytest
-        run: |
-          pytest -v test_optimize_sql_dump.py
+      - uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        id: setup-python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: "pip"
+
+      # ===============================
+      # INSTALL DEPENDENCIES: LINUX/MAC
+      # ===============================
+      - name: Install dependencies (Linux/macOS)
+        if: runner.os != 'Windows'
+        shell: bash
+        run: |
+          python -m pip install --upgrade pip
+          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+          python -m pip install ruff black isort pytest coverage
+
+      # ===============================
+      # INSTALL DEPENDENCIES: WINDOWS
+      # ===============================
+      - name: Install dependencies (Windows)
+        if: runner.os == 'Windows'
+        shell: pwsh
+        run: |
+          python -m pip install --upgrade pip
+          if (Test-Path "requirements.txt") { pip install -r requirements.txt }
+          python -m pip install ruff black isort pytest coverage
+
+      # ===== LINT =====
+      - name: Lint with ruff
+        run: ruff check .
+
+      # ===== BLACK FORMAT CHECK (with diff) =====
+      - name: Black format verification
+        run: |
+          black --check --diff .
+
+      # ===== ISORT CHECK =====
+      - name: Isort check
+        run: isort . --check-only --diff
+
+      # ===== TESTS + COVERAGE =====
+      - name: Run tests with coverage
+        run: |
+          coverage run -m pytest -q
+          coverage xml
+          coverage report -m
+
+      # ===== UPLOAD COVERAGE ARTIFACT =====
+      - name: Upload coverage report
+        uses: actions/upload-artifact@v4
+        with:
+          name: coverage-${{ matrix.os }}-${{ matrix.python-version }}
+          path: coverage.xml
 
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..9b4aafc
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,46 @@
+repos:
+ # ------------------------
+ # BLACK — formatting
+ # ------------------------
+ - repo: https://github.com/psf/black
+ rev: 23.12.1
+ hooks:
+ - id: black
+ language_version: python3
+
+ # ------------------------
+  # ISORT — imports
+ # ------------------------
+ - repo: https://github.com/pycqa/isort
+ rev: 5.13.2
+ hooks:
+ - id: isort
+
+ # ------------------------
+ # RUFF — lint + autofix
+ # ------------------------
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.4.3
+ hooks:
+ - id: ruff
+ args: ["--fix"]
+
+ # ------------------------
+  # MYPY — typing (optional)
+ # ------------------------
+ - repo: https://github.com/pre-commit/mirrors-mypy
+ rev: v1.9.0
+ hooks:
+ - id: mypy
+ additional_dependencies: []
+
+ # ------------------------
+ # PYTEST (pre-commit test)
+ # ------------------------
+ # - repo: local
+ # hooks:
+ # - id: pytest
+ # name: Run pytest before commit
+ # entry: pytest
+ # language: system
+ # pass_filenames: false
diff --git a/optimize_sql_dump.py b/optimize_sql_dump.py
index d0b6f2d..528e0ac 100755
--- a/optimize_sql_dump.py
+++ b/optimize_sql_dump.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
@@ -10,12 +9,14 @@
import locale
import lzma
import os
+
# concurrent execution was attempted earlier but left incomplete; removed.
import re
import sys
import warnings
import zipfile
from abc import ABC, abstractmethod
+from typing import Optional
try:
from tqdm import tqdm
@@ -34,7 +35,6 @@
except (locale.Error, IndexError):
pass # Keep default locale if setting fails
import gettext
-
import logging
"""
@@ -124,20 +124,13 @@ def open_maybe_compressed(path, mode="rt"):
raise ValueError(tl("Empty zip file"))
if len(names) > 1:
warnings.warn(
- tl(
- "ZIP archive contains multiple files, using only the first one: {name}"
- ).format(name=names[0])
+ tl("ZIP archive contains multiple files, using only the first one: {name}").format(name=names[0]),
+ stacklevel=2,
)
b = z.open(names[0], "r")
- return (
- io.TextIOWrapper(b, encoding="utf-8", errors="replace") if text_mode else b
- )
+ return io.TextIOWrapper(b, encoding="utf-8", errors="replace") if text_mode else b
else:
- f = (
- open(path, mode, encoding="utf-8", errors="replace")
- if text_mode
- else open(path, mode)
- )
+ f = open(path, mode, encoding="utf-8", errors="replace") if text_mode else open(path, mode)
return f
@@ -189,9 +182,7 @@ class DatabaseHandler(ABC):
def __init__(self):
self.create_re = re.compile(r"^(CREATE\s+TABLE\b).*", re.IGNORECASE)
-        self.insert_re = re.compile(
-            r"^(INSERT\s+INTO\s+)(?P<tname>[^\s(]+)", re.IGNORECASE
-        )
+        self.insert_re = re.compile(r"^(INSERT\s+INTO\s+)(?P<tname>[^\s(]+)", re.IGNORECASE)
         self.copy_re = re.compile(r"^(COPY\s+)(?P<tname>[^\s(]+)", re.IGNORECASE)
self.insert_template = "INSERT INTO {table} {cols} VALUES\n{values};\n"
self.validator: TypeValidator | None = None
@@ -205,7 +196,7 @@ def get_truncate_statement(self, tname: str) -> str:
pass
@abstractmethod
- def detect_db_type(path):
+ def detect_db_type(self, path):
"""Delegate to module-level detection (kept for backward compatibility)."""
try:
return detect_db_type(path)
@@ -292,12 +283,8 @@ def get_load_statement(self, tname: str, tsv_path: str, cols_str: str) -> str:
f"{cols_str};\n"
)
- def extract_columns_with_types_from_create(
- self, create_stmt: str
- ) -> list[tuple[str, str]]:
- m = re.search(
- r"\((.*)\)\s*(ENGINE|TYPE|AS|COMMENT|;)", create_stmt, re.S | re.I
- )
+ def extract_columns_with_types_from_create(self, create_stmt: str) -> list[tuple[str, str]]:
+ m = re.search(r"\((.*)\)\s*(ENGINE|TYPE|AS|COMMENT|;)", create_stmt, re.S | re.I)
if not m:
m = re.search(r"CREATE\s+TABLE[^\(]*\((.*)\)\s*;", create_stmt, re.S | re.I)
if not m or not self.validator:
@@ -306,9 +293,7 @@ def extract_columns_with_types_from_create(
cols_with_types = []
for line in cols_blob.splitlines():
line = line.strip().rstrip(",")
- if not line or re.match(
- r"PRIMARY\s+KEY|KEY\s+|UNIQUE\s+|CONSTRAINT\s+", line, re.I
- ):
+ if not line or re.match(r"PRIMARY\s+KEY|KEY\s+|UNIQUE\s+|CONSTRAINT\s+", line, re.I):
continue
parsed = self.validator.parse_column_definition(line)
if parsed:
@@ -316,9 +301,7 @@ def extract_columns_with_types_from_create(
return cols_with_types
def extract_full_column_definitions(self, create_stmt: str) -> dict[str, str]:
- m = re.search(
- r"\((.*)\)\s*(ENGINE|TYPE|AS|COMMENT|;)", create_stmt, re.S | re.I
- )
+ m = re.search(r"\((.*)\)\s*(ENGINE|TYPE|AS|COMMENT|;)", create_stmt, re.S | re.I)
if not m:
m = re.search(r"CREATE\s+TABLE[^\(]*\((.*)\)\s*;", create_stmt, re.S | re.I)
if not m:
@@ -327,9 +310,7 @@ def extract_full_column_definitions(self, create_stmt: str) -> dict[str, str]:
definitions = {}
for line in cols_blob.splitlines():
line = line.strip().rstrip(",")
- if not line or re.match(
- r"PRIMARY\s+KEY|KEY\s+|UNIQUE\s+|CONSTRAINT\s+", line, re.I
- ):
+ if not line or re.match(r"PRIMARY\s+KEY|KEY\s+|UNIQUE\s+|CONSTRAINT\s+", line, re.I):
continue
col_name = line.split()[0].strip('`"')
definitions[col_name] = line
@@ -340,9 +321,7 @@ def extract_primary_key(self, create_stmt: str) -> list[str]:
if not m:
return []
pk_blob = m.group(1)
- pk_cols = [
- re.sub(r"\s*\(\d+\)", "", c.strip()).strip('`"') for c in pk_blob.split(",")
- ]
+ pk_cols = [re.sub(r"\s*\(\d+\)", "", c.strip()).strip('`"') for c in pk_blob.split(",")]
return pk_cols
def extract_columns_from_create(self, create_stmt: str) -> str:
@@ -624,7 +603,13 @@ def _emit_statement(self):
self.in_backtick,
self.is_escaped,
self.paren_level,
- ) = False, False, False, False, 0
+ ) = (
+ False,
+ False,
+ False,
+ False,
+ 0,
+ )
t = text.lstrip()[:30].upper()
if t.startswith("CREATE TABLE") or t.startswith("CREATE TEMPORARY TABLE"):
return "create", text
@@ -690,26 +675,20 @@ def _setup_handler(self):
if db_type == "auto":
db_type = detect_db_type(self.args["inpath"])
if self.args.get("verbose"):
- logger.info(
- tl("[INFO] Detected DB type: {db_type}").format(db_type=db_type)
- )
+ logger.info(tl("[INFO] Detected DB type: {db_type}").format(db_type=db_type))
return MySQLHandler() if db_type == "mysql" else PostgresHandler()
def _setup_progress(self):
global progress
filesize = os.path.getsize(self.args["inpath"])
if self.args.get("verbose") and tqdm:
- progress = tqdm(
- total=filesize, unit="B", unit_scale=True, desc=tl("Processing")
- )
+ progress = tqdm(total=filesize, unit="B", unit_scale=True, desc=tl("Processing"))
else:
progress = None
return progress
def _handle_create(self, stmt):
-        m = re.search(
-            r"CREATE\s+TABLE\s+(IF\s+NOT\s+EXISTS\s+)?(?P<name>[^\s\(;]+)", stmt, re.I
-        )
+        m = re.search(r"CREATE\s+TABLE\s+(IF\s+NOT\s+EXISTS\s+)?(?P<name>[^\s\(;]+)", stmt, re.I)
if not m:
return
tname = self.handler.normalize_table_name(m.group("name").strip())
@@ -738,11 +717,7 @@ def _values_to_tsv_row(self, values: list[str | None]) -> str:
processed_values.append(
"\\n"
if v is None
- else str(v)
- .replace("\\", "\\\\")
- .replace("\t", "\\t")
- .replace("\n", "\\n")
- .replace("\r", "\\r")
+ else str(v).replace("\\", "\\\\").replace("\t", "\\t").replace("\n", "\\n").replace("\r", "\\r")
)
return "\t".join(processed_values)
@@ -752,9 +727,7 @@ def _handle_insert(self, stmt):
if not tname or (target_table and tname != target_table):
return
if self.progress:
- self.progress.set_description(
- tl("Processing table: {tname}").format(tname=tname)
- )
+ self.progress.set_description(tl("Processing table: {tname}").format(tname=tname))
prefix, values_body = extract_values_from_insert(stmt)
if self.load_data_mode:
if values_body:
@@ -766,33 +739,29 @@ def _handle_insert(self, stmt):
except Exception as e:
if self.args.get("verbose"):
logger.warning(
- tl(
- "[WARN] Failed to parse VALUES in INSERT for table {tname}: {error}"
- ).format(tname=tname, error=e)
+ tl("[WARN] Failed to parse VALUES in INSERT for table {tname}: {error}").format(
+ tname=tname, error=e
+ )
)
return
if not prefix or not values_body: # Fallback for unparsable statements
# Use writer API to ensure statement is written to the correct output
self.writer.write_statement(tname, stmt)
return
- cols_match = re.search(
- r"INSERT\s+INTO\s+[^\(]+(\([^\)]*\))\s*VALUES", prefix, re.I | re.S
- )
+ cols_match = re.search(r"INSERT\s+INTO\s+[^\(]+(\([^\)]*\))\s*VALUES", prefix, re.I | re.S)
cols_text = (
cols_match.group(1).strip()
if cols_match
- else self.handler.extract_columns_from_create(
- self.create_map.get(tname, "")
- )
+ else self.handler.extract_columns_from_create(self.create_map.get(tname, ""))
)
try:
tuples = list(SqlMultiTupleParser(values_body)) if values_body else []
except Exception as e:
if self.args.get("verbose"):
logger.warning(
- tl(
- "[WARN] Failed to parse VALUES in INSERT for table {tname}: {error}"
- ).format(tname=tname, error=e)
+ tl("[WARN] Failed to parse VALUES in INSERT for table {tname}: {error}").format(
+ tname=tname, error=e
+ )
)
tuples = []
if not tuples:
@@ -849,29 +818,17 @@ def finalize(self):
rows=self.writer.total_rows, batches=self.writer.total_batches
)
)
- if (
- not self.args.get("dry_run")
- and not self.split_mode
- and not self.load_data_mode
- ):
+ if not self.args.get("dry_run") and not self.split_mode and not self.load_data_mode:
logger.info(tl("Done. Saved to: {path}").format(path=self.args["outpath"]))
elif self.split_mode:
- logger.info(
- tl("Done. Split dump into files in directory: {path}").format(
- path=self.args["split_dir"]
- )
- )
+ logger.info(tl("Done. Split dump into files in directory: {path}").format(path=self.args["split_dir"]))
elif self.load_data_mode:
logger.info(
- tl("Done. Generated files for import in directory: {path}").format(
- path=self.args["load_data_dir"]
- )
+ tl("Done. Generated files for import in directory: {path}").format(path=self.args["load_data_dir"])
)
elif self.writer.insert_only_mode:
logger.info(
- tl("Done. Generated insert-only files in directory: {path}").format(
- path=self.args["insert_only"]
- )
+ tl("Done. Generated insert-only files in directory: {path}").format(path=self.args["insert_only"])
)
if self.progress:
self.progress.close()
@@ -886,15 +843,11 @@ def __init__(self, handler: DatabaseHandler, **kwargs):
self.split_mode = bool(self.args.get("split_dir"))
self.load_data_mode = bool(self.args.get("load_data_dir"))
self.insert_only_mode = bool(self.args.get("insert_only"))
- self.output_dir = (
- self.args.get("split_dir")
- or self.args.get("load_data_dir")
- or self.args.get("insert_only")
- )
+ self.output_dir = self.args.get("split_dir") or self.args.get("load_data_dir") or self.args.get("insert_only")
self.fout = None
- self.file_map = {}
- self.insert_buffers = {}
+ self.file_map: dict[str, str] = {}
+ self.insert_buffers: dict[str, list[str]] = {}
self.total_rows = 0
self.total_batches = 0
@@ -911,11 +864,7 @@ def setup(self):
elif not self.args.get("dry_run"):
self.fout = open(self.args["outpath"], "w", encoding="utf-8")
self.fout.write(tl("-- Optimized by SqlDumpOptimizer\n"))
- self.fout.write(
- tl("-- Source: {source}\n").format(
- source=os.path.basename(self.args["inpath"])
- )
- )
+ self.fout.write(tl("-- Source: {source}\n").format(source=os.path.basename(self.args["inpath"])))
self.fout.write("--\n")
else:
self.fout = open(os.devnull, "w")
@@ -951,9 +900,7 @@ def write_statement(self, tname, stmt):
writer.write(stmt)
def add_insert_tuples(self, tname, cols_text, tuples):
- buf = self.insert_buffers.setdefault(
- tname, {"cols_text": cols_text, "tuples": []}
- )
+ buf = self.insert_buffers.setdefault(tname, {"cols_text": cols_text, "tuples": []})
buf["tuples"].extend(tuples)
self.total_rows += len(tuples)
if len(buf["tuples"]) >= self.args.get("batch_size", 1000):
@@ -986,9 +933,7 @@ def flush_tsv_buffer(self, tname, force=False):
return
tsv_info = self.file_map[tname]
tsv_buffer_size = self.args.get("tsv_buffer_size", 200)
- if tsv_info["tsv_buffer"] and (
- force or len(tsv_info["tsv_buffer"]) >= tsv_buffer_size
- ):
+ if tsv_info["tsv_buffer"] and (force or len(tsv_info["tsv_buffer"]) >= tsv_buffer_size):
tsv_info["tsv"].write("\n".join(tsv_info["tsv_buffer"]) + "\n")
tsv_info["tsv_buffer"].clear()
@@ -1001,12 +946,8 @@ def finalize(self, create_map):
self.flush_tsv_buffer(tname, force=True)
for tname, writers in self.file_map.items():
if "sql" in writers and not writers["sql"].closed:
- cols_str = self.handler.extract_columns_from_create(
- create_map.get(tname, "")
- )
- load_stmt = self.handler.get_load_statement(
- tname, writers["tsv_path"], cols_str
- )
+ cols_str = self.handler.extract_columns_from_create(create_map.get(tname, ""))
+ load_stmt = self.handler.get_load_statement(tname, writers["tsv_path"], cols_str)
writers["sql"].write(load_stmt)
def close_all(self):
@@ -1037,18 +978,14 @@ def _setup_handler(self):
if db_type == "auto":
db_type = detect_db_type(self.args["inpath"])
if self.args.get("verbose"):
- logger.info(
- tl("[INFO] Detected DB type: {db_type}").format(db_type=db_type)
- )
+ logger.info(tl("[INFO] Detected DB type: {db_type}").format(db_type=db_type))
return MySQLHandler() if db_type == "mysql" else PostgresHandler()
def _setup_progress(self):
global progress
filesize = os.path.getsize(self.args["inpath"])
if self.args.get("verbose") and tqdm:
- progress = tqdm(
- total=filesize, unit="B", unit_scale=True, desc=tl("Analyzing dump")
- )
+ progress = tqdm(total=filesize, unit="B", unit_scale=True, desc=tl("Analyzing dump"))
else:
progress = None
return progress
@@ -1072,9 +1009,7 @@ def run(self):
self.stats[tname]["inserts"] += 1
_, values_body = extract_values_from_insert(stmt)
if values_body:
- self.stats[tname]["rows"] += len(
- list(SqlMultiTupleParser(values_body))
- )
+ self.stats[tname]["rows"] += len(list(SqlMultiTupleParser(values_body)))
if self.progress:
self.progress.close()
self.print_summary()
@@ -1138,31 +1073,13 @@ def display(self, outpath):
logger.info(tl("Done. Diff saved to: {path}").format(path=outpath))
logger.info("\n" + tl("--- Diff Summary ---"))
if not self.insert_only:
- logger.info(
- tl("Tables to create: {count}").format(
- count=self.counts["tables_created"]
- )
- )
- logger.info(
- tl("Tables to alter: {count}").format(
- count=self.counts["tables_altered"]
- )
- )
+ logger.info(tl("Tables to create: {count}").format(count=self.counts["tables_created"]))
+ logger.info(tl("Tables to alter: {count}").format(count=self.counts["tables_altered"]))
if self.diff_data:
- logger.info(
- tl("Rows to insert: {count}").format(count=self.counts["rows_inserted"])
- )
+ logger.info(tl("Rows to insert: {count}").format(count=self.counts["rows_inserted"]))
if not self.insert_only:
- logger.info(
- tl("Rows to update: {count}").format(
- count=self.counts["rows_updated"]
- )
- )
- logger.info(
- tl("Rows to delete: {count}").format(
- count=self.counts["rows_deleted"]
- )
- )
+ logger.info(tl("Rows to update: {count}").format(count=self.counts["rows_updated"]))
+ logger.info(tl("Rows to delete: {count}").format(count=self.counts["rows_deleted"]))
logger.info("--------------------\n")
@@ -1225,9 +1142,7 @@ def get_db_primary_keys(self, table_name: str, pk_cols: list[str]) -> set:
pass
@abstractmethod
- def get_db_row_by_pk(
- self, table_name: str, pk_cols: list[str], pk_values: tuple
- ) -> dict | None:
+ def get_db_row_by_pk(self, table_name: str, pk_cols: list[str], pk_values: tuple) -> dict | None:
"""Fetches a single row from the database by its primary key."""
pass
@@ -1252,9 +1167,7 @@ def _store_pks_in_temp_table(self, table_name: str, pk_set: set) -> None:
VALUES %s
"""
# Convert tuples to SQL value lists
- values = [
- f"({', '.join(self._format_sql_value(v) for v in pk)})" for pk in pk_set
- ]
+ values = [f"({', '.join(self._format_sql_value(v) for v in pk)})" for pk in pk_set]
# Insert in batches to avoid too-long queries
batch_size = 1000
for i in range(0, len(values), batch_size):
@@ -1276,8 +1189,7 @@ def _ensure_temp_table_for_pks(self, table_name: str, pk_cols: list[str]) -> Non
def _pk_exists_in_temp_table(self, table_name: str, pk_values: tuple) -> bool:
"""Checks if a primary key exists in the temporary table."""
where_clause = " AND ".join(
- f"`{col}` = {self._format_sql_value(val)}"
- for col, val in zip(self.create_map[table_name]["pk"], pk_values)
+ f"`{col}` = {self._format_sql_value(val)}" for col, val in zip(self.create_map[table_name]["pk"], pk_values)
)
check_stmt = f"""
SELECT 1 FROM `_tmp_pks_{table_name}`
@@ -1288,76 +1200,80 @@ def _pk_exists_in_temp_table(self, table_name: str, pk_values: tuple) -> bool:
# --- Schema and Data Comparison ---
- def compare_schemas(
- self, dump_cols: dict[str, str], db_cols: dict[str, dict], table_name: str
- ) -> list[str]:
+ def compare_schemas(self, dump_cols: dict[str, str], db_cols: dict[str, dict], table_name: str) -> list[str]:
alter_statements = []
last_col = None
for col_name, dump_def in dump_cols.items():
if col_name not in db_cols:
position = f" AFTER `{last_col}`" if last_col else " FIRST"
- alter_statements.append(
- f"ALTER TABLE `{table_name}` ADD COLUMN {dump_def}{position};"
- )
+ alter_statements.append(f"ALTER TABLE `{table_name}` ADD COLUMN {dump_def}{position};")
else:
db_col_info = db_cols.get(col_name, {})
# If DB metadata is incomplete (e.g., missing COLUMN_TYPE or IS_NULLABLE), skip deep comparison
- if (
- not db_col_info
- or not db_col_info.get("COLUMN_TYPE")
- or ("IS_NULLABLE" not in db_col_info)
- ):
+ if not db_col_info or not db_col_info.get("COLUMN_TYPE") or ("IS_NULLABLE" not in db_col_info):
last_col = col_name
continue
# Build a normalized DB-side column definition for comparison
normalized_dump_def = " ".join(dump_def.lower().split())
try:
normalized_db_def = " ".join(
- str(self._build_db_column_definition(col_name, db_col_info))
- .lower()
- .split()
+ str(self._build_db_column_definition(col_name, db_col_info)).lower().split()
)
# Ignore auto_increment when comparing; dump definitions may omit it.
normalized_db_def = normalized_db_def.replace(" auto_increment", "")
except Exception:
normalized_db_def = ""
if normalized_dump_def != normalized_db_def:
- alter_statements.append(
- f"ALTER TABLE `{table_name}` MODIFY COLUMN {dump_def};"
- )
+ alter_statements.append(f"ALTER TABLE `{table_name}` MODIFY COLUMN {dump_def};")
last_col = col_name
return alter_statements
- def compare_data_row(
- self, dump_row: dict, db_row: dict, table_name: str, pk_cols: list[str]
- ) -> str | None:
- updates = []
- params = []
- pk_values = []
- for col_name, dump_val in dump_row.items():
- db_val = db_row.get(col_name)
- dump_val_str = str(dump_val) if dump_val is not None else None
- # Normalize DB bytes to string for fair comparison
- if isinstance(db_val, (bytes, bytearray)):
+    def compare_data_row(
+ self,
+ dump_row: dict[str, object],
+ db_row: dict[str, object],
+ table_name: str,
+ pk_cols: list[str],
+ ) -> Optional[str]:
+ updates: list[str] = []
+
+ def sql_literal(value: object) -> str:
+ if value is None:
+ return "NULL"
+ if isinstance(value, (bytes, bytearray)):
try:
- db_val_str = db_val.decode("utf-8")
+ value = value.decode("utf-8")
except Exception:
- db_val_str = str(db_val)
- else:
- db_val_str = str(db_val) if db_val is not None else None
+ value = str(value)
+ if isinstance(value, str):
+ escaped = value.replace("'", "''")
+ return f"'{escaped}'"
+ return str(value)
+
+        # column-by-column comparison against the DB row
+ for col_name, dump_val in dump_row.items():
+ db_val = db_row.get(col_name)
+
+ dump_val_str = sql_literal(dump_val)
+ db_val_str = sql_literal(db_val)
+
if dump_val_str != db_val_str:
if col_name not in pk_cols:
- updates.append(f"`{col_name}` = %s")
- params.append(dump_val)
+ updates.append(f"`{col_name}` = {dump_val_str}")
+
if not updates:
return None
+
+ # WHERE PK
+ where_clause_parts = []
for col in pk_cols:
- pk_values.append(dump_row[col])
- params.extend(pk_values)
- where_clause = " AND ".join([f"`{col}` = %s" for col in pk_cols])
- update_stmt = (
- f"UPDATE `{table_name}` SET {', '.join(updates)} WHERE {where_clause};"
- )
+ where_clause_parts.append(f"`{col}` = {sql_literal(dump_row[col])}")
+
+ where_clause = " AND ".join(where_clause_parts)
+
+ update_stmt = f"UPDATE `{table_name}` SET {', '.join(updates)} WHERE {where_clause};"
+
+ return update_stmt
def format_value_for_update(v):
if v is None:
@@ -1368,7 +1284,7 @@ def format_value_for_update(v):
safe_val = str(v).replace("'", "''")
return f"'{safe_val}'"
- return update_stmt % tuple(format_value_for_update(p) for p in params)
+ # return update_stmt % tuple(format_value_for_update(p) for p in params)
# --- Main Processing Logic ---
@@ -1392,17 +1308,13 @@ def _handle_create_statement(self, stmt, fout):
}
if progress:
- progress.set_description(
- tl("Diffing schema for {tname}").format(tname=tname)
- )
+ progress.set_description(tl("Diffing schema for {tname}").format(tname=tname))
db_cols = self.get_db_schema(tname)
if not db_cols:
if not self.args.get("insert_only"):
self.summary.increment("tables_created")
- fout.write(
- f"-- Table `{tname}` does not exist in the database.\n{stmt};\n"
- )
+ fout.write(f"-- Table `{tname}` does not exist in the database.\n{stmt};\n")
self.create_map[tname]["exists_in_db"] = False
elif not self.args.get("insert_only"):
alter_statements = self.compare_schemas(dump_cols, db_cols, tname)
@@ -1435,23 +1347,16 @@ def _handle_insert_statement(self, stmt, fout):
if not table_info.get("pk"):
if self.args.get("verbose") and not table_info.get("pk_checked"):
logger.warning(
- tl(
- "[WARN] Skipping data diff for table `{tname}`: no primary key found."
- ).format(tname=tname)
+ tl("[WARN] Skipping data diff for table `{tname}`: no primary key found.").format(tname=tname)
)
table_info["pk_checked"] = True
return
# Check if we need to switch to temp table
mem_info = self.memory_usage[tname]
- if (
- mem_info["pk_count"] >= self.memory_limit
- and not mem_info["using_temp_table"]
- ):
+ if mem_info["pk_count"] >= self.memory_limit and not mem_info["using_temp_table"]:
if progress:
- progress.set_description(
- tl("Creating temp table for {tname}").format(tname=tname)
- )
+ progress.set_description(tl("Creating temp table for {tname}").format(tname=tname))
self._create_temp_table_for_pks(tname, table_info["pk"])
mem_info["using_temp_table"] = True
if "db_pks" in table_info:
@@ -1461,9 +1366,7 @@ def _handle_insert_statement(self, stmt, fout):
# Initialize PKs if needed
if "db_pks" not in table_info:
if progress:
- progress.set_description(
- tl("Diffing data for {tname}").format(tname=tname)
- )
+ progress.set_description(tl("Diffing data for {tname}").format(tname=tname))
if mem_info["using_temp_table"]:
self._ensure_temp_table_for_pks(tname, table_info["pk"])
table_info["db_pks"] = None
@@ -1476,9 +1379,7 @@ def _handle_insert_statement(self, stmt, fout):
if not table_info.get("pk"):
if self.args.get("verbose") and not table_info.get("pk_checked"):
logger.warning(
- tl(
- "[WARN] Skipping data diff for table `{tname}`: no primary key found."
- ).format(tname=tname)
+ tl("[WARN] Skipping data diff for table `{tname}`: no primary key found.").format(tname=tname)
)
table_info["pk_checked"] = True
return
@@ -1489,9 +1390,7 @@ def _handle_insert_statement(self, stmt, fout):
and not self.memory_usage[tname]["using_temp_table"]
):
if progress:
- progress.set_description(
- tl("Creating temp table for {tname}").format(tname=tname)
- )
+ progress.set_description(tl("Creating temp table for {tname}").format(tname=tname))
self._create_temp_table_for_pks(tname, table_info["pk"])
self.memory_usage[tname]["using_temp_table"] = True
if "db_pks" in table_info:
@@ -1501,9 +1400,7 @@ def _handle_insert_statement(self, stmt, fout):
if "db_pks" not in table_info:
if progress:
- progress.set_description(
- tl("Diffing data for {tname}").format(tname=tname)
- )
+ progress.set_description(tl("Diffing data for {tname}").format(tname=tname))
if self.memory_usage[tname]["using_temp_table"]:
# Use temp table for PKs
self._ensure_temp_table_for_pks(tname, table_info["pk"])
@@ -1549,9 +1446,7 @@ def _handle_insert_statement(self, stmt, fout):
elif not self.args.get("insert_only"):
db_row = self.get_db_row_by_pk(tname, table_info["pk"], pk_values)
if db_row:
- update_stmt = self.compare_data_row(
- dump_row_dict, db_row, tname, table_info["pk"]
- )
+ update_stmt = self.compare_data_row(dump_row_dict, db_row, tname, table_info["pk"])
if update_stmt:
self.summary.increment("rows_updated")
fout.write(f"{update_stmt}\n")
@@ -1563,9 +1458,7 @@ def _generate_delete_statements(self, fout):
if progress:
progress.set_description(tl("Generating DELETE statements"))
- fout.write(
- "\n-- Deleting rows that exist in the database but not in the dump\n"
- )
+ fout.write("\n-- Deleting rows that exist in the database but not in the dump\n")
for tname, table_info in self.create_map.items():
# If either side uses a temp table (db_pks or dump_pks is None), skip
# the in-memory set difference. Production code could implement a
@@ -1583,8 +1476,7 @@ def _generate_delete_statements(self, fout):
self.summary.increment("rows_deleted", len(pks_to_delete))
for pk_tuple in pks_to_delete:
where_clause = " AND ".join(
- f"`{col}` = {self._format_sql_value(val)}"
- for col, val in zip(pk_cols, pk_tuple)
+ f"`{col}` = {self._format_sql_value(val)}" for col, val in zip(pk_cols, pk_tuple)
)
fout.write(f"DELETE FROM `{tname}` WHERE {where_clause};\n")
@@ -1612,10 +1504,7 @@ def _process_statements(self, fin, fout):
statements_count += 1
# Add COMMIT/START TRANSACTION every batch_size statements
- if (
- self.use_transactions
- and statements_count >= self.txn_batch_size
- ):
+ if self.use_transactions and statements_count >= self.txn_batch_size:
fout.write("\nCOMMIT;\nSTART TRANSACTION;\n\n")
statements_count = 0
@@ -1693,7 +1582,8 @@ def _build_db_column_definition(self, col_name: str, db_col_info: dict) -> str:
if isinstance(default, str) and default.upper() == "CURRENT_TIMESTAMP":
parts.append(f"default {default.lower()}")
elif isinstance(default, str):
- parts.append(f"default '{default.replace("'", "''")}'")
+ # Escape single quotes in string defaults
+ parts.append("default '" + default.replace("'", "''") + "'")
else:
parts.append(f"default {default}")
else:
@@ -1715,9 +1605,7 @@ def _get_handler(self) -> DatabaseHandler:
def connect_db(self):
if not mysql:
- raise ImportError(
- "The 'mysql-connector-python' library is required for diffing with MySQL."
- )
+ raise ImportError("The 'mysql-connector-python' library is required for diffing with MySQL.")
try:
self.connection = mysql.connector.connect(
host=self.args["db_host"],
@@ -1728,18 +1616,16 @@ def connect_db(self):
self.cursor = self.connection.cursor(dictionary=True)
if self.args.get("verbose"):
logger.info(
- tl(
- "[INFO] Successfully connected to database '{db}' on {host}"
- ).format(db=self.args["db_name"], host=self.args["db_host"])
+ tl("[INFO] Successfully connected to database '{db}' on {host}").format(
+ db=self.args["db_name"], host=self.args["db_host"]
+ )
)
except AttributeError as err:
# Some environments may have an incomplete ssl module which causes
# mysql.connector to fail when attempting to enable SSL. Retry with
# SSL disabled as a fallback.
logger.warning(
- tl(
- "[WARN] SSL not available, retrying DB connection with SSL disabled: {err}"
- ).format(err=err)
+ tl("[WARN] SSL not available, retrying DB connection with SSL disabled: {err}").format(err=err)
)
try:
self.connection = mysql.connector.connect(
@@ -1751,14 +1637,10 @@ def connect_db(self):
)
self.cursor = self.connection.cursor(dictionary=True)
except mysql.connector.Error as err2:
- logger.error(
- tl("[ERROR] Database connection failed: {error}").format(error=err2)
- )
+ logger.error(tl("[ERROR] Database connection failed: {error}").format(error=err2))
sys.exit(1)
except mysql.connector.Error as err:
- logger.error(
- tl("[ERROR] Database connection failed: {error}").format(error=err)
- )
+ logger.error(tl("[ERROR] Database connection failed: {error}").format(error=err))
sys.exit(1)
def get_db_schema(self, table_name: str) -> dict[str, dict]:
@@ -1786,15 +1668,11 @@ def get_db_primary_keys(self, table_name: str, pk_cols: list[str]) -> set:
keys.add(pk_tuple)
if self.args.get("verbose"):
logger.info(
- tl("[INFO] Fetched {count} primary keys for table `{tname}`.").format(
- count=len(keys), tname=table_name
- )
+ tl("[INFO] Fetched {count} primary keys for table `{tname}`.").format(count=len(keys), tname=table_name)
)
return keys
- def get_db_row_by_pk(
- self, table_name: str, pk_cols: list[str], pk_values: tuple
- ) -> dict | None:
+ def get_db_row_by_pk(self, table_name: str, pk_cols: list[str], pk_values: tuple) -> dict | None:
if not self.connection or not pk_cols or len(pk_cols) != len(pk_values):
return None
where_clause = " AND ".join([f"`{col}` = %s" for col in pk_cols])
@@ -1814,17 +1692,13 @@ def connect_db(self):
def _load_config(config_file="optimize_sql_dump.ini"):
- config = configparser.ConfigParser(
- allow_no_value=True, inline_comment_prefixes=("#", ";")
- )
+ config = configparser.ConfigParser(allow_no_value=True, inline_comment_prefixes=("#", ";"))
config_defaults = {}
boolean_flags = {"verbose", "dry_run", "diff_from_db", "diff_data", "info"}
boolean_like_flags = {"split", "load_data_dir", "insert_only"}
if os.path.exists(config_file) and os.path.getsize(config_file) > 0:
config.read(config_file)
- _parse_config_sections(
- config, config_defaults, boolean_flags, boolean_like_flags
- )
+ _parse_config_sections(config, config_defaults, boolean_flags, boolean_like_flags)
return config_defaults
@@ -1877,9 +1751,7 @@ def load_section(section_name, mapping):
for key, dest in mapping.items():
if key in config[section_name]:
if dest in boolean_flags or dest in boolean_like_flags:
- if config[section_name][key] is None or config.getboolean(
- section_name, key
- ):
+ if config[section_name][key] is None or config.getboolean(section_name, key):
config_defaults[dest] = True
else:
config_defaults[dest] = config.get(section_name, key)
@@ -1916,12 +1788,8 @@ def _create_arg_parser(p: argparse.ArgumentParser) -> argparse.ArgumentParser:
type=int,
help=tl("Number of tuples in a single merged INSERT (default: 1000)"),
)
- p.add_argument(
- "--verbose", "-v", action="store_true", help=tl("Print diagnostic information")
- )
- p.add_argument(
- "--dry-run", action="store_true", help=tl("Dry run: does not write output")
- )
+ p.add_argument("--verbose", "-v", action="store_true", help=tl("Print diagnostic information"))
+ p.add_argument("--dry-run", action="store_true", help=tl("Dry run: does not write output"))
# --- Mutually Exclusive Output Modes ---
output_mode_group = p.add_mutually_exclusive_group()
@@ -1945,17 +1813,13 @@ def _create_arg_parser(p: argparse.ArgumentParser) -> argparse.ArgumentParser:
nargs="?",
const=".",
dest="load_data_dir",
- help=tl(
- "[MySQL] Generate files for LOAD DATA. Optional dir, defaults to current."
- ),
+ help=tl("[MySQL] Generate files for LOAD DATA. Optional dir, defaults to current."),
)
output_mode_group.add_argument(
"--insert-only",
nargs="?",
const=".",
- help=tl(
- "Generate insert-only files (TRUNCATE + INSERTs). Optional dir, defaults to current."
- ),
+ help=tl("Generate insert-only files (TRUNCATE + INSERTs). Optional dir, defaults to current."),
)
output_mode_group.add_argument(
"--info",
@@ -1995,9 +1859,7 @@ def _create_arg_parser(p: argparse.ArgumentParser) -> argparse.ArgumentParser:
diff_group.add_argument(
"--diff-data",
action="store_true",
- help=tl(
- "Also compare table data and generate INSERT/UPDATE/DELETE statements (requires --diff-from-db)."
- ),
+ help=tl("Also compare table data and generate INSERT/UPDATE/DELETE statements (requires --diff-from-db)."),
)
diff_group.add_argument("--db-host", help=tl("Database host for diffing."))
diff_group.add_argument("--db-user", help=tl("Database user for diffing."))
@@ -2024,34 +1886,26 @@ def _validate_args(p, args):
sys.exit(2)
# Check for at least one output mode if not using --diff-from-db
- is_output_mode_set = any(
- [args.output, args.split_dir, args.load_data_dir, args.insert_only, args.info]
- )
+ is_output_mode_set = any([args.output, args.split_dir, args.load_data_dir, args.insert_only, args.info])
if not is_output_mode_set and not args.diff_from_db:
p.error(
- tl(
- "You must specify an output mode (e.g., `script.py in.sql out.sql` or use a flag like --split, --info)."
- )
+ tl("You must specify an output mode (e.g., `script.py in.sql out.sql` or use a flag like --split, --info).")
)
# Append .sql to output filename if needed
if args.output and not args.output.lower().endswith(".sql"):
if args.verbose:
logger.info(
- tl(
- "[INFO] Output filename does not end with .sql, appending it. New name: {name}"
- ).format(name=args.output + ".sql")
+ tl("[INFO] Output filename does not end with .sql, appending it. New name: {name}").format(
+ name=args.output + ".sql"
+ )
)
args.output += ".sql"
# Validate --diff-from-db dependencies
if args.diff_from_db:
if not mysql:
- p.error(
- tl(
- "The 'mysql-connector-python' library is required for --diff-from-db. Please install it."
- )
- )
+ p.error(tl("The 'mysql-connector-python' library is required for --diff-from-db. Please install it."))
if not args.output:
p.error(tl("--diff-from-db requires --output to be specified."))
if not args.db_user or not args.db_name:
@@ -2065,9 +1919,7 @@ def _validate_args(p, args):
def set_parse_arguments_and_config():
parser = argparse.ArgumentParser(
- description=tl(
- "SQL Dump Optimizer: merges INSERTs, detects compression, supports MySQL/Postgres."
- )
+ description=tl("SQL Dump Optimizer: merges INSERTs, detects compression, supports MySQL/Postgres.")
)
config_defaults = _load_config()
parser.set_defaults(**config_defaults)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..e20ccd9
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,35 @@
+[tool.black]
+line-length = 120
+target-version = ["py310", "py311", "py312"]
+
+[tool.isort]
+profile = "black"
+line_length = 120
+multi_line_output = 3
+include_trailing_comma = true
+use_parentheses = true
+
+[tool.pytest.ini_options]
+addopts = "-q"
+# testpaths = ["tests"]  # disabled: tests live at the repo root (see pytest.ini), not in tests/
+
+[tool.coverage.run]
+branch = true
+source = ["."]
+omit = ["tests/*"]
+
+[tool.coverage.report]
+skip_empty = true
+
+[tool.ruff]
+line-length = 120
+
+[tool.ruff.lint]
+select = ["E", "F", "W", "Q"]
+
+[tool.ruff.lint.per-file-ignores]
+"test_optimize_sql_dump.py" = ["E501"]
+
+[tool.mypy]
+check_untyped_defs = false
+disable_error_code = ["annotation-unchecked"]
diff --git a/pytest.ini b/pytest.ini
index 15c7015..38a4e89 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,3 +1,3 @@
[pytest]
-testpaths = tests
-norecursedirs = .git venv .venv __pycache__ build dist .eggs *.egg-info actions-runner
\ No newline at end of file
+# testpaths = tests
+norecursedirs = .git venv .venv __pycache__ build dist .eggs *.egg-info actions-runner
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 791956f..e10d5ba 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,22 +1,39 @@
-# Include production dependencies
--r requirements.txt
-
-# Testing
allure-pytest==2.15.0
+allure-python-commons==2.15.0
+attrs==25.3.0
coverage==7.10.6
+DateTime==5.5
+execnet==2.1.1
+gherkin-official==29.0.0
+iniconfig==2.1.0
+isort==7.0.0
+Mako==1.3.10
+MarkupSafe==3.0.2
+mysql-connector-python==8.0.33
+mysqlclient==2.2.7
+packaging==25.0
+parse==1.20.2
+parse_type==0.6.6
+pep8==1.7.1
+pluggy==1.6.0
+protobuf==3.20.3
+py==1.11.0
+pyflakes==3.4.0
+Pygments==2.19.2
pytest==8.4.2
pytest-bdd==8.1.0
+pytest-cache==1.0
pytest-cov==7.0.0
+pytest-flakes==4.0.5
pytest-instafail==0.5.0
+pytest-pep8==1.0.6  # FIXME(review): unmaintained, incompatible with pytest 8.x — remove or use ruff/flake8 instead
pytest-timeout==2.4.0
pytest-xdist==3.8.0
-
-# Linting
-flake8==7.1.0
-
-# Dependencies of testing/linting tools
-allure-python-commons==2.15.0
-execnet==2.1.1
-gherkin-official==29.0.0
-iniconfig==2.1.0
-pluggy==1.6.0
+pytz==2025.2
+salt==3007.8  # FIXME(review): SaltStack — likely an accidental pip-freeze capture; confirm and remove
+setuptools==80.9.0
+six==1.17.0
+tabulate==0.9.0
+tqdm==4.67.1
+typing_extensions==4.15.0
+zope.interface==8.0
diff --git a/requirements.txt b/requirements.txt
index 65de3a6..e10d5ba 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,37 @@
+allure-pytest==2.15.0
+allure-python-commons==2.15.0
attrs==25.3.0
+coverage==7.10.6
DateTime==5.5
+execnet==2.1.1
+gherkin-official==29.0.0
+iniconfig==2.1.0
+isort==7.0.0
Mako==1.3.10
MarkupSafe==3.0.2
mysql-connector-python==8.0.33
+mysqlclient==2.2.7
packaging==25.0
parse==1.20.2
parse_type==0.6.6
+pep8==1.7.1
+pluggy==1.6.0
protobuf==3.20.3
-python-dotenv==1.0.0
+py==1.11.0
+pyflakes==3.4.0
+Pygments==2.19.2
+pytest==8.4.2  # FIXME(review): test-only deps duplicated into production requirements (file is now identical to requirements-dev.txt); restore the split via `-r requirements.txt`
+pytest-bdd==8.1.0
+pytest-cache==1.0
+pytest-cov==7.0.0
+pytest-flakes==4.0.5
+pytest-instafail==0.5.0
+pytest-pep8==1.0.6
+pytest-timeout==2.4.0
+pytest-xdist==3.8.0
pytz==2025.2
+salt==3007.8
+setuptools==80.9.0
six==1.17.0
tabulate==0.9.0
tqdm==4.67.1
diff --git a/test_optimize_sql_dump.py b/test_optimize_sql_dump.py
index 0896f0a..142d929 100644
--- a/test_optimize_sql_dump.py
+++ b/test_optimize_sql_dump.py
@@ -6,15 +6,13 @@
from unittest.mock import MagicMock, patch
import pytest
+
import optimize_sql_dump as opt
# Helper to get the main function for CLI tests
+from optimize_sql_dump import escape_sql_value # Assuming the function is in this module
from optimize_sql_dump import main as cli_main
-from optimize_sql_dump import (
- escape_sql_value,
-) # Assuming the function is in this module
-
@pytest.fixture
def mysql_validator():
@@ -93,9 +91,7 @@ def test_validate_postgres(self, postgres_validator):
class TestPostgresHandler:
def test_normalize_table_name_postgres(self, postgres_handler):
- assert (
- postgres_handler.normalize_table_name('"public"."my_table"') == "my_table"
- )
+ assert postgres_handler.normalize_table_name('"public"."my_table"') == "my_table"
assert postgres_handler.normalize_table_name("my_table") == "my_table"
assert postgres_handler.normalize_table_name('"my_table"') == "my_table"
@@ -108,9 +104,7 @@ def test_get_load_statement_postgres(self, postgres_handler):
tsv_path = "/path/to/my_table.tsv"
stmt = postgres_handler.get_load_statement("my_table", tsv_path, cols_str)
expected_stmt = "COPY \"my_table\" (\"id\", \"name\") FROM '/path/to/my_table.tsv' WITH (FORMAT csv, DELIMITER E'\\t', NULL '\\n');\n"
- assert (
- stmt == expected_stmt
- )
+ assert stmt == expected_stmt
def test_extract_columns_from_create_postgres(self, postgres_handler):
create_stmt = """
@@ -123,10 +117,7 @@ def test_extract_columns_from_create_postgres(self, postgres_handler):
"""
columns_str = postgres_handler.extract_columns_from_create(create_stmt)
# Normalize returned string to a list of column names and compare as a set for robustness
- cols = [ # noqa: E501
- c.strip().strip('"')
- for c in columns_str.strip().lstrip("(").rstrip(")").split(",")
- ]
+ cols = [c.strip().strip('"') for c in columns_str.strip().lstrip("(").rstrip(")").split(",")] # noqa: E501
assert set(cols) == {"id", "name", "created_at"}
def test_detect_db_type_postgres_specific(self, tmp_path):
@@ -225,9 +216,7 @@ def test_dump_analyzer(tmp_path, capsys):
# Simulate running from command line with --info
test_args = ["optimize_sql_dump.py", "--input", str(dump_file), "--info"]
with patch.object(sys, "argv", test_args):
- with patch(
- "optimize_sql_dump._load_config", return_value={}
- ): # Mock _load_config
+ with patch("optimize_sql_dump._load_config", return_value={}): # Mock _load_config
opt.main()
captured = capsys.readouterr()
@@ -265,9 +254,7 @@ def test_cli_split_mode(tmp_path):
str(split_dir),
]
with patch.object(sys, "argv", test_args):
- with patch(
- "optimize_sql_dump._load_config", return_value={}
- ): # Mock _load_config
+ with patch("optimize_sql_dump._load_config", return_value={}): # Mock _load_config
cli_main()
assert (split_dir / "t1.sql").exists()
@@ -308,9 +295,7 @@ def test_cli_load_data_mode(tmp_path):
str(load_data_dir),
]
with patch.object(sys, "argv", test_args):
- with patch(
- "optimize_sql_dump._load_config", return_value={}
- ): # Mock _load_config
+ with patch("optimize_sql_dump._load_config", return_value={}): # Mock _load_config
cli_main()
sql_file = load_data_dir / "users.sql"
@@ -350,16 +335,15 @@ def test_cli_invalid_arguments(tmp_path, invalid_args):
with pytest.raises(SystemExit) as e:
with patch.object(sys, "argv", base_args + invalid_args):
- with patch(
- "optimize_sql_dump._load_config", return_value={}
- ): # Mock _load_config
+ with patch("optimize_sql_dump._load_config", return_value={}): # Mock _load_config
cli_main()
assert e.type is SystemExit
assert e.value.code != 0 # Ensure it's an error exit code
@pytest.mark.parametrize(
- "input_val, expected", [
+ "input_val, expected",
+ [
# Tests for strings
("tekst", "DEFAULT 'tekst'"),
("O'Reilly", "DEFAULT 'O''Reilly'"),
@@ -376,7 +360,8 @@ def test_cli_invalid_arguments(tmp_path, invalid_args):
(False, "DEFAULT False"),
# Test for None
(None, "DEFAULT NULL"),
- ])
+ ],
+)
def test_escape_sql_value(input_val, expected):
assert escape_sql_value(input_val, prefix_str="DEFAULT").strip() == expected
@@ -397,7 +382,8 @@ def differ(self, tmp_path):
return opt.MySQLDatabaseDiffer(inpath=str(dummy_inpath), verbose=False)
@pytest.mark.parametrize(
- "input_val, expected_sql", [
+ "input_val, expected_sql",
+ [
(None, "NULL"),
(123, "123"),
(-45, "-45"),
@@ -409,7 +395,8 @@ def differ(self, tmp_path):
('string with "quotes"', "'string with \"quotes\"'"),
("O'Reilly", "'O''Reilly'"),
("multiple 'quotes' here", "'multiple ''quotes'' here'"),
- ])
+ ],
+ )
def test_format_sql_value(self, differ, input_val, expected_sql):
assert differ._format_sql_value(input_val) == expected_sql
@@ -599,9 +586,7 @@ def test_build_db_column_definition(self, differ, col_name, db_col_info, expecte
),
],
)
- def test_compare_data_row_edge_cases(
- self, differ, dump_row, db_row, pk_cols, expected_fragment
- ):
+ def test_compare_data_row_edge_cases(self, differ, dump_row, db_row, pk_cols, expected_fragment):
"""Tests various edge cases for data row comparison."""
update_stmt = differ.compare_data_row(dump_row, db_row, "users", pk_cols)
if expected_fragment:
@@ -616,10 +601,7 @@ def test_run_generates_delete(self, tmp_path):
"""
in_file = tmp_path / "dump.sql"
out_file = tmp_path / "diff.sql"
- in_file.write_text(
- "CREATE TABLE `users` (`id` int, PRIMARY KEY (`id`));\n"
- "INSERT INTO `users` VALUES (1);"
- )
+ in_file.write_text("CREATE TABLE `users` (`id` int, PRIMARY KEY (`id`));\n" "INSERT INTO `users` VALUES (1);")
args = {
"inpath": str(in_file),
@@ -641,7 +623,7 @@ def test_run_generates_delete(self, tmp_path):
differ.connect_db = MagicMock()
differ.get_db_schema = MagicMock(return_value={"id": {}}) # Table exists
# DB has PKs (1,) and (2,). Dump only has (1,). So (2,) should be deleted.
- differ.get_db_primary_keys = MagicMock(return_value={('1',), ('2',)})
+ differ.get_db_primary_keys = MagicMock(return_value={("1",), ("2",)})
differ.get_db_row_by_pk = MagicMock(return_value={"id": 1})
differ.run()
@@ -657,10 +639,7 @@ def test_handle_insert_uses_temp_table(tmp_path):
"""Ensure that when memory limit is exceeded, PKs are stored via temp-table helper."""
in_file = tmp_path / "dump.sql"
out_file = tmp_path / "diff.sql"
- in_file.write_text(
- "CREATE TABLE `users` (`id` int, PRIMARY KEY (`id`));\n"
- "INSERT INTO `users` VALUES (3);"
- )
+ in_file.write_text("CREATE TABLE `users` (`id` int, PRIMARY KEY (`id`));\n" "INSERT INTO `users` VALUES (3);")
args = {
"inpath": str(in_file),
@@ -682,7 +661,7 @@ def test_handle_insert_uses_temp_table(tmp_path):
differ.connect_db = MagicMock()
differ.get_db_schema = MagicMock(return_value={"id": {}})
# Simulate that fetching DB PKs would return something small
- differ.get_db_primary_keys = MagicMock(return_value={('1',)})
+ differ.get_db_primary_keys = MagicMock(return_value={("1",)})
# Force memory limit to zero so the code chooses temp-table path
differ.memory_limit = 0
@@ -707,7 +686,7 @@ class TestDumpWriter:
@pytest.fixture
def mock_handler(self):
handler = MagicMock(spec=opt.MySQLHandler)
- handler.normalize_table_name.side_effect = lambda x: x.strip('`')
+ handler.normalize_table_name.side_effect = lambda x: x.strip("`")
handler.insert_template = "INSERT INTO {table} {cols} VALUES\n{values};\n"
handler.get_truncate_statement.side_effect = lambda t: f"TRUNCATE TABLE `{t}`;\n"
handler.extract_columns_from_create.return_value = "(`id`, `name`)"
@@ -727,6 +706,7 @@ def test_setup_normal_mode(self, mock_handler, tmp_path):
def test_setup_dry_run(self, mock_handler):
import os
+
args = {"outpath": "out.sql", "inpath": "dummy.sql", "dry_run": True}
with opt.DumpWriter(mock_handler, **args) as writer:
assert writer.fout.name == os.devnull
@@ -762,7 +742,7 @@ def test_insert_only_mode(self, mock_handler, tmp_path):
"inpath": str(in_file),
"verbose": False,
"db_type": "mysql",
- "outpath": None
+ "outpath": None,
}
optimizer = opt.DumpOptimizer(**args)
@@ -797,18 +777,20 @@ def test_tsv_buffering_and_flushing(self, tmp_path):
"""Tests TSV buffering and flushing logic within DumpOptimizer."""
load_dir = tmp_path / "load_data"
in_file = tmp_path / "in.sql"
- in_file.write_text("""
+ in_file.write_text(
+ """
CREATE TABLE `t1` (`id` int, `name` varchar(10));
INSERT INTO `t1` VALUES (1, 'a'), (2, 'b');
INSERT INTO `t1` VALUES (3, 'c');
- """)
+ """
+ )
args = {
"load_data_dir": str(load_dir),
"inpath": str(in_file),
"tsv_buffer_size": 2,
"verbose": False,
- "db_type": "mysql"
+ "db_type": "mysql",
}
optimizer = opt.DumpOptimizer(**args)