diff --git a/df_to_azure/db.py b/df_to_azure/db.py index e3f9ad4..e9c7777 100644 --- a/df_to_azure/db.py +++ b/df_to_azure/db.py @@ -11,11 +11,87 @@ class SqlUpsert: - def __init__(self, table_name, schema, id_cols, columns): + def __init__(self, table_name, schema, id_cols, columns, preserve_identity=False): self.table_name = table_name self.schema = schema self.id_cols = id_cols self.columns = [col.strip() for col in columns] + self.preserve_identity = preserve_identity + self.identity_columns = [] + + def get_identity_columns(self): + """ + Query SQL Server to detect IDENTITY (auto-increment) columns in the target table. + + Returns + ------- + list + List of column names that have IDENTITY property in the target table. + """ + query = text(f""" + SELECT c.name + FROM sys.identity_columns ic + INNER JOIN sys.columns c ON ic.object_id = c.object_id AND ic.column_id = c.column_id + INNER JOIN sys.tables t ON ic.object_id = t.object_id + INNER JOIN sys.schemas s ON t.schema_id = s.schema_id + WHERE t.name = '{self.table_name}' AND s.name = '{self.schema}' + """) + + with auth_azure() as con: + result = con.execute(query) + return [row[0] for row in result] + + def validate_identity_usage(self): + """ + Validate that user isn't trying to upsert on IDENTITY columns without explicit permission. + + This method checks if any of the id_cols (columns used for matching in upsert) are IDENTITY + columns. If so, and preserve_identity=False, it raises an informative error with alternatives. + + If preserve_identity=True and IDENTITY columns are in id_cols, logs a warning about the risks. + + Raises + ------ + UpsertError + If id_cols contain IDENTITY columns and preserve_identity=False. + """ + self.identity_columns = self.get_identity_columns() + + # Check if any id_cols are IDENTITY columns + identity_in_id_cols = [col for col in self.id_cols if col in self.identity_columns] + + if identity_in_id_cols and not self.preserve_identity: + # Build helpful error message based on the scenario + if len(self.id_cols) == 1 and len(identity_in_id_cols) == 1: + # Scenario A: IDENTITY is the only id_field + raise UpsertError( + f"Column '{identity_in_id_cols[0]}' is an auto-increment (IDENTITY) column " + f"and cannot be used for upsert matching.\n\n" + f"Suggested alternatives:\n" + f"1. Use method='append' instead if you want to insert new records with auto-generated IDs\n" + f"2. Add a business key column (e.g., 'user_email', 'external_id') and use that for id_field\n" + f"3. If you must preserve existing ID values (e.g., data migration), set preserve_identity=True\n" + f" WARNING: Using preserve_identity=True is not recommended as it can break ID sequence generation" + ) + else: + # Scenario B: IDENTITY is part of composite key + other_cols = [col for col in self.id_cols if col not in identity_in_id_cols] + raise UpsertError( + f"Column(s) {identity_in_id_cols} are auto-increment (IDENTITY) columns " + f"and are part of your id_field {self.id_cols}.\n\n" + f"Suggested alternatives:\n" + f"1. Remove IDENTITY column(s) from id_field and use only: {other_cols}\n" + f"2. If you must preserve existing ID values (e.g., data migration), set preserve_identity=True\n" + f" WARNING: Using preserve_identity=True is not recommended as it can break ID sequence generation" + ) + + # If preserve_identity=True and IDENTITY columns are in id_cols, log warning + if self.preserve_identity and identity_in_id_cols: + logging.warning( + f"preserve_identity=True: IDENTITY_INSERT will be enabled for {self.schema}.{self.table_name}. " + f"This is not recommended and may cause ID sequence issues. " + f"Consider using non-IDENTITY columns for id_field instead." + ) def create_on_statement(self): on = " AND ".join([f"s.[{id_col}] = t.[{id_col}]" for id_col in self.id_cols]) @@ -34,11 +110,20 @@ def create_insert_statement(self): return insert, values def create_merge_query(self): + """ + Generate MERGE statement with optional IDENTITY_INSERT handling. + + If preserve_identity=True, wraps the MERGE statement with + SET IDENTITY_INSERT ON/OFF to allow explicit insertion of IDENTITY values. + + Returns + ------- + text + SQLAlchemy text object containing the CREATE PROCEDURE statement. + """ insert = self.create_insert_statement() - query = f""" - CREATE PROCEDURE [UPSERT_{self.table_name}] - AS - MERGE {self.schema}.{self.table_name} t + + merge_stmt = f"""MERGE {self.schema}.{self.table_name} t USING staging.{self.table_name} s ON {self.create_on_statement()} WHEN MATCHED @@ -46,10 +131,25 @@ def create_merge_query(self): {self.create_update_statement()} WHEN NOT MATCHED BY TARGET THEN INSERT {insert[0]} - VALUES {insert[1]}; + VALUES {insert[1]};""" + + if self.preserve_identity: + # Wrap with IDENTITY_INSERT ON/OFF + query = f""" + CREATE PROCEDURE [UPSERT_{self.table_name}] + AS + SET IDENTITY_INSERT {self.schema}.{self.table_name} ON; + {merge_stmt} + SET IDENTITY_INSERT {self.schema}.{self.table_name} OFF; + """ + else: + query = f""" + CREATE PROCEDURE [UPSERT_{self.table_name}] + AS + {merge_stmt} """ - logging.debug(query) + logging.debug(query) return text(query) def drop_procedure(self): @@ -57,6 +157,20 @@ def drop_procedure(self): return text(query) def create_stored_procedure(self): + """ + Create the stored procedure for upsert operation. + + This method first validates that IDENTITY columns are being used correctly, + then creates the stored procedure with the MERGE statement. + + Raises + ------ + UpsertError + If IDENTITY columns are used incorrectly or if procedure creation fails. + """ + # Validate IDENTITY usage BEFORE creating procedure + self.validate_identity_usage() + with auth_azure() as con: t = con.begin() query_drop_procedure = self.drop_procedure() diff --git a/df_to_azure/export.py b/df_to_azure/export.py index 8de94fe..5f18c19 100644 --- a/df_to_azure/export.py +++ b/df_to_azure/export.py @@ -32,6 +32,7 @@ def df_to_azure( parquet=False, clean_staging=True, container_name="parquet", + preserve_identity=False, ): if parquet: DfToParquet( @@ -57,6 +58,7 @@ def df_to_azure( create=create, dtypes=dtypes, clean_staging=clean_staging, + preserve_identity=preserve_identity, ).run() return adf_client, run_response @@ -77,6 +79,7 @@ def __init__( create: bool = False, dtypes: dict = None, clean_staging: bool = True, + preserve_identity: bool = False, ): super().__init__( df=df, @@ -92,6 +95,7 @@ def __init__( self.decimal_precision = decimal_precision self.dtypes = dtypes self.clean_staging = clean_staging + self.preserve_identity = preserve_identity def run(self): if self.df.empty: @@ -144,6 +148,7 @@ def upload_dataset(self): schema=self.schema, id_cols=self.id_field, columns=self.df.columns, + preserve_identity=self.preserve_identity, ) upsert.create_stored_procedure() self.schema = "staging" diff --git a/df_to_azure/tests/test_identity_insert.py b/df_to_azure/tests/test_identity_insert.py new file mode 100644 index 0000000..ab5d17e --- /dev/null +++ b/df_to_azure/tests/test_identity_insert.py @@ -0,0 +1,72 @@ +import pytest +from pandas import DataFrame, read_sql_table +from pandas._testing import assert_frame_equal + +from df_to_azure import df_to_azure +from df_to_azure.db import auth_azure, execute_stmt +from df_to_azure.exceptions import UpsertError + +SCHEMA = "test" + + +def reset_identity_table(table_name: str) -> None: + execute_stmt( + f""" +IF OBJECT_ID('{SCHEMA}.{table_name}', 'U') IS NOT NULL + DROP TABLE [{SCHEMA}].[{table_name}]; +CREATE TABLE [{SCHEMA}].[{table_name}]( + [id] INT IDENTITY(1,1) NOT NULL PRIMARY KEY, + [value] NVARCHAR(255) NOT NULL +); +""" + ) + + +def insert_values(table_name: str, values: list[str]) -> None: + for value in values: + escaped = value.replace("'", "''") + execute_stmt(f"INSERT INTO [{SCHEMA}].[{table_name}] ([value]) VALUES ('{escaped}')") + + +def test_upsert_identity_column_requires_preserve(): + table_name = "identity_no_preserve" + reset_identity_table(table_name) + insert_values(table_name, ["original value"]) + + df = DataFrame({"id": [1], "value": ["updated value"]}) + + with pytest.raises(UpsertError) as excinfo: + df_to_azure( + df=df, + tablename=table_name, + schema=SCHEMA, + method="upsert", + id_field="id", + wait_till_finished=True, + ) + + assert "Column 'id' is an auto-increment (IDENTITY) column" in str(excinfo.value) + + +def test_upsert_identity_column_with_preserve_identity(): + table_name = "identity_with_preserve" + reset_identity_table(table_name) + insert_values(table_name, ["original value"]) + + df = DataFrame({"id": [1, 10], "value": ["updated value", "migrated value"]}) + + df_to_azure( + df=df, + tablename=table_name, + schema=SCHEMA, + method="upsert", + id_field="id", + wait_till_finished=True, + preserve_identity=True, + ) + + with auth_azure() as con: + result = read_sql_table(table_name=table_name, con=con, schema=SCHEMA).sort_values("id") + + expected = DataFrame({"id": [1, 10], "value": ["updated value", "migrated value"]}) + assert_frame_equal(expected, result.reset_index(drop=True))