Skip to content

Commit 74fde2d

Browse files
Merge branch 'main' into chore/add-poetry-to-pyproject.toml
2 parents 4e67312 + f5e8182 commit 74fde2d

File tree

2 files changed

+97
-3
lines changed

2 files changed

+97
-3
lines changed

awswrangler/sqlserver.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -436,11 +436,12 @@ def to_sql(
436436
con: "pyodbc.Connection",
437437
table: str,
438438
schema: str,
439-
mode: Literal["append", "overwrite"] = "append",
439+
mode: Literal["append", "overwrite", "upsert"] = "append",
440440
index: bool = False,
441441
dtype: dict[str, str] | None = None,
442442
varchar_lengths: dict[str, int] | None = None,
443443
use_column_names: bool = False,
444+
upsert_conflict_columns: list[str] | None = None,
444445
chunksize: int = 200,
445446
fast_executemany: bool = False,
446447
) -> None:
@@ -457,7 +458,12 @@ def to_sql(
457458
schema : str
458459
Schema name
459460
mode : str
460-
Append or overwrite.
461+
Append, overwrite or upsert.
462+
463+
- append: Inserts new records into table.
464+
- overwrite: Drops the table and recreates it.
465+
- upsert: Performs an upsert that checks for conflicts on the columns given by ``upsert_conflict_columns`` and sets the new values on conflicts. Note that the column names of the DataFrame will be used for this operation, as if ``use_column_names`` were set to True.
466+
461467
index : bool
462468
True to store the DataFrame index as a column in the table,
463469
otherwise False to ignore it.
@@ -471,6 +477,8 @@ def to_sql(
471477
If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
472478
E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
473479
inserted into the database columns `col1` and `col3`.
480+
upsert_conflict_columns: List[str], optional
481+
List of columns to be used as conflict columns in the upsert operation.
474482
chunksize: int
475483
Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
476484
fast_executemany: bool
@@ -506,6 +514,8 @@ def to_sql(
506514
if df.empty is True:
507515
raise exceptions.EmptyDataFrame("DataFrame cannot be empty.")
508516
_validate_connection(con=con)
517+
if mode == "upsert" and not upsert_conflict_columns:
518+
raise exceptions.InvalidArgumentValue("<upsert_conflict_columns> need to be set when using upsert mode.")
509519
try:
510520
with con.cursor() as cursor:
511521
if fast_executemany:
@@ -524,15 +534,28 @@ def to_sql(
524534
df.reset_index(level=df.index.names, inplace=True)
525535
column_placeholders: str = ", ".join(["?"] * len(df.columns))
526536
table_identifier = _get_table_identifier(schema, table)
537+
column_names = [identifier(col, sql_mode="mssql") for col in df.columns]
538+
quoted_columns = ", ".join(column_names)
527539
insertion_columns = ""
528540
if use_column_names:
529-
quoted_columns = ", ".join(f"{identifier(col, sql_mode='mssql')}" for col in df.columns)
530541
insertion_columns = f"({quoted_columns})"
531542
placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs(
532543
df=df, column_placeholders=column_placeholders, chunksize=chunksize
533544
)
534545
for placeholders, parameters in placeholder_parameter_pair_generator:
535546
sql: str = f"INSERT INTO {table_identifier} {insertion_columns} VALUES {placeholders}"
547+
if mode == "upsert" and upsert_conflict_columns:
548+
merge_on_columns = [identifier(col, sql_mode="mssql") for col in upsert_conflict_columns]
549+
sql = f"MERGE INTO {table_identifier}\nUSING (VALUES {placeholders}) AS source ({quoted_columns})\n"
550+
sql += f"ON {' AND '.join(f'{table_identifier}.{col}=source.{col}' for col in merge_on_columns)}\n"
551+
sql += (
552+
f"WHEN MATCHED THEN\n UPDATE "
553+
f"SET {', '.join(f'{col}=source.{col}' for col in column_names)}\n"
554+
)
555+
sql += (
556+
f"WHEN NOT MATCHED THEN\n INSERT "
557+
f"({quoted_columns}) VALUES ({', '.join([f'source.{col}' for col in column_names])});"
558+
)
536559
_logger.debug("sql: %s", sql)
537560
cursor.executemany(sql, (parameters,))
538561
con.commit()

tests/unit/test_sqlserver.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,74 @@ def test_dfs_are_equal_for_different_chunksizes(sqlserver_table, sqlserver_con,
260260
df["c1"] = df["c1"].astype("string")
261261

262262
assert df.equals(df2)
263+
264+
265+
def test_upsert(sqlserver_table, sqlserver_con):
266+
df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})
267+
268+
with pytest.raises(wr.exceptions.InvalidArgumentValue):
269+
wr.sqlserver.to_sql(
270+
df=df,
271+
con=sqlserver_con,
272+
schema="dbo",
273+
table=sqlserver_table,
274+
mode="upsert",
275+
upsert_conflict_columns=None,
276+
use_column_names=True,
277+
)
278+
279+
wr.sqlserver.to_sql(
280+
df=df,
281+
con=sqlserver_con,
282+
schema="dbo",
283+
table=sqlserver_table,
284+
mode="upsert",
285+
upsert_conflict_columns=["c0"],
286+
)
287+
wr.sqlserver.to_sql(
288+
df=df,
289+
con=sqlserver_con,
290+
schema="dbo",
291+
table=sqlserver_table,
292+
mode="upsert",
293+
upsert_conflict_columns=["c0"],
294+
)
295+
df2 = wr.sqlserver.read_sql_table(con=sqlserver_con, schema="dbo", table=sqlserver_table)
296+
assert bool(len(df2) == 2)
297+
298+
wr.sqlserver.to_sql(
299+
df=df,
300+
con=sqlserver_con,
301+
schema="dbo",
302+
table=sqlserver_table,
303+
mode="upsert",
304+
upsert_conflict_columns=["c0"],
305+
)
306+
df3 = pd.DataFrame({"c0": ["baz", "bar"], "c2": [3, 2]})
307+
wr.sqlserver.to_sql(
308+
df=df3,
309+
con=sqlserver_con,
310+
schema="dbo",
311+
table=sqlserver_table,
312+
mode="upsert",
313+
upsert_conflict_columns=["c0"],
314+
use_column_names=True,
315+
)
316+
df4 = wr.sqlserver.read_sql_table(con=sqlserver_con, schema="dbo", table=sqlserver_table)
317+
assert bool(len(df4) == 3)
318+
319+
df5 = pd.DataFrame({"c0": ["foo", "bar"], "c2": [4, 5]})
320+
wr.sqlserver.to_sql(
321+
df=df5,
322+
con=sqlserver_con,
323+
schema="dbo",
324+
table=sqlserver_table,
325+
mode="upsert",
326+
upsert_conflict_columns=["c0"],
327+
use_column_names=True,
328+
)
329+
330+
df6 = wr.sqlserver.read_sql_table(con=sqlserver_con, schema="dbo", table=sqlserver_table)
331+
assert bool(len(df6) == 3)
332+
assert bool(len(df6.loc[(df6["c0"] == "foo") & (df6["c2"] == 4)]) == 1)
333+
assert bool(len(df6.loc[(df6["c0"] == "bar") & (df6["c2"] == 5)]) == 1)

0 commit comments

Comments
 (0)