Skip to content

Commit 89d8930

Browse files
timsaucer and kosiew authored
feat: reduce duplicate fields on join (#1184)
* Add field to dataframe join to indicate if we should keep duplicate keys * Suppress expected warning * Minor: small tables rendered way too large * Rename from keep_duplicate_keys to drop_duplicate_keys * Add unit tests for dropping duplicate keys or not * Update online docs * Update docs/source/user-guide/common-operations/joins.rst Co-authored-by: kosiew <kosiew@gmail.com> --------- Co-authored-by: kosiew <kosiew@gmail.com>
1 parent c4e7486 commit 89d8930

File tree

6 files changed

+113
-22
lines changed

6 files changed

+113
-22
lines changed

docs/source/user-guide/common-operations/joins.rst

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,36 @@ the right table.
101101

102102
.. ipython:: python
103103
104-
left.join(right, left_on="customer_id", right_on="id", how="anti")
104+
left.join(right, left_on="customer_id", right_on="id", how="anti")
105+
106+
Duplicate Keys
107+
--------------
108+
109+
It is common to join two DataFrames on a common column name. Starting in
110+
version 51.0.0, ``datafusion-python`` drops duplicate column names by
111+
default. This reduces problems with ambiguous column selection after joins.
112+
You can disable this feature by setting the parameter ``drop_duplicate_keys``
113+
to ``False``.
114+
115+
.. ipython:: python
116+
117+
left = ctx.from_pydict(
118+
{
119+
"id": [1, 2, 3],
120+
"customer": ["Alice", "Bob", "Charlie"],
121+
}
122+
)
123+
124+
right = ctx.from_pylist([
125+
{"id": 1, "name": "CityCabs"},
126+
{"id": 2, "name": "MetroRide"},
127+
{"id": 5, "name": "UrbanGo"},
128+
])
129+
130+
left.join(right, "id", how="inner")
131+
132+
In contrast to the above example, if we wish to keep both ``id`` columns, we can set ``drop_duplicate_keys=False``:
133+
134+
.. ipython:: python
135+
136+
left.join(right, "id", how="inner", drop_duplicate_keys=False)

python/datafusion/dataframe.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,7 @@ def join(
774774
left_on: None = None,
775775
right_on: None = None,
776776
join_keys: None = None,
777+
drop_duplicate_keys: bool = True,
777778
) -> DataFrame: ...
778779

779780
@overload
@@ -786,6 +787,7 @@ def join(
786787
left_on: str | Sequence[str],
787788
right_on: str | Sequence[str],
788789
join_keys: tuple[list[str], list[str]] | None = None,
790+
drop_duplicate_keys: bool = True,
789791
) -> DataFrame: ...
790792

791793
@overload
@@ -798,6 +800,7 @@ def join(
798800
join_keys: tuple[list[str], list[str]],
799801
left_on: None = None,
800802
right_on: None = None,
803+
drop_duplicate_keys: bool = True,
801804
) -> DataFrame: ...
802805

803806
def join(
@@ -809,6 +812,7 @@ def join(
809812
left_on: str | Sequence[str] | None = None,
810813
right_on: str | Sequence[str] | None = None,
811814
join_keys: tuple[list[str], list[str]] | None = None,
815+
drop_duplicate_keys: bool = True,
812816
) -> DataFrame:
813817
"""Join this :py:class:`DataFrame` with another :py:class:`DataFrame`.
814818
@@ -821,11 +825,23 @@ def join(
821825
"right", "full", "semi", "anti".
822826
left_on: Join column of the left dataframe.
823827
right_on: Join column of the right dataframe.
828+
drop_duplicate_keys: When True, the columns from the right DataFrame
829+
that have identical names in the ``on`` fields to the left DataFrame
830+
will be dropped.
824831
join_keys: Tuple of two lists of column names to join on. [Deprecated]
825832
826833
Returns:
827834
DataFrame after join.
828835
"""
836+
if join_keys is not None:
837+
warnings.warn(
838+
"`join_keys` is deprecated, use `on` or `left_on` with `right_on`",
839+
category=DeprecationWarning,
840+
stacklevel=2,
841+
)
842+
left_on = join_keys[0]
843+
right_on = join_keys[1]
844+
829845
# This check is to prevent breaking API changes where users prior to
830846
# DF 43.0.0 would pass the join_keys as a positional argument instead
831847
# of a keyword argument.
@@ -836,18 +852,10 @@ def join(
836852
and isinstance(on[1], list)
837853
):
838854
# We know this is safe because we've checked the types
839-
join_keys = on # type: ignore[assignment]
855+
left_on = on[0]
856+
right_on = on[1]
840857
on = None
841858

842-
if join_keys is not None:
843-
warnings.warn(
844-
"`join_keys` is deprecated, use `on` or `left_on` with `right_on`",
845-
category=DeprecationWarning,
846-
stacklevel=2,
847-
)
848-
left_on = join_keys[0]
849-
right_on = join_keys[1]
850-
851859
if on is not None:
852860
if left_on is not None or right_on is not None:
853861
error_msg = "`left_on` or `right_on` should not provided with `on`"
@@ -866,7 +874,9 @@ def join(
866874
if isinstance(right_on, str):
867875
right_on = [right_on]
868876

869-
return DataFrame(self.df.join(right.df, how, left_on, right_on))
877+
return DataFrame(
878+
self.df.join(right.df, how, left_on, right_on, drop_duplicate_keys)
879+
)
870880

871881
def join_on(
872882
self,

python/datafusion/dataframe_formatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def _build_table_container_start(self) -> list[str]:
370370
f"max-height: {self.max_height}px; overflow: auto; border: "
371371
'1px solid #ccc;">'
372372
)
373-
html.append('<table style="border-collapse: collapse; min-width: 100%">')
373+
html.append('<table style="border-collapse: collapse">')
374374
return html
375375

376376
def _build_table_header(self, schema: Any) -> list[str]:

python/tests/test_dataframe.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,6 @@ def test_unnest_without_nulls(nested_df):
647647
assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9])
648648

649649

650-
@pytest.mark.filterwarnings("ignore:`join_keys`:DeprecationWarning")
651650
def test_join():
652651
ctx = SessionContext()
653652

@@ -664,25 +663,38 @@ def test_join():
664663
df1 = ctx.create_dataframe([[batch]], "r")
665664

666665
df2 = df.join(df1, on="a", how="inner")
667-
df2.show()
668666
df2 = df2.sort(column("l.a"))
669667
table = pa.Table.from_batches(df2.collect())
670668

671669
expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
672670
assert table.to_pydict() == expected
673671

674-
df2 = df.join(df1, left_on="a", right_on="a", how="inner")
675-
df2.show()
672+
# Test the default behavior for dropping duplicate keys
673+
# Since we may have a duplicate column name and pa.Table()
674+
# hides the fact, instead we need to explicitly check the
675+
# resultant arrays.
676+
df2 = df.join(df1, left_on="a", right_on="a", how="inner", drop_duplicate_keys=True)
676677
df2 = df2.sort(column("l.a"))
677-
table = pa.Table.from_batches(df2.collect())
678+
result = df2.collect()[0]
679+
assert result.num_columns == 3
680+
assert result.column(0) == pa.array([1, 2], pa.int64())
681+
assert result.column(1) == pa.array([4, 5], pa.int64())
682+
assert result.column(2) == pa.array([8, 10], pa.int64())
678683

679-
expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
680-
assert table.to_pydict() == expected
684+
df2 = df.join(
685+
df1, left_on="a", right_on="a", how="inner", drop_duplicate_keys=False
686+
)
687+
df2 = df2.sort(column("l.a"))
688+
result = df2.collect()[0]
689+
assert result.num_columns == 4
690+
assert result.column(0) == pa.array([1, 2], pa.int64())
691+
assert result.column(1) == pa.array([4, 5], pa.int64())
692+
assert result.column(2) == pa.array([1, 2], pa.int64())
693+
assert result.column(3) == pa.array([8, 10], pa.int64())
681694

682695
# Verify we don't make a breaking change to pre-43.0.0
683696
# where users would pass join_keys as a positional argument
684697
df2 = df.join(df1, (["a"], ["a"]), how="inner")
685-
df2.show()
686698
df2 = df2.sort(column("l.a"))
687699
table = pa.Table.from_batches(df2.collect())
688700

python/tests/test_sql.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ def test_register_parquet(ctx, tmp_path):
157157
assert result.to_pydict() == {"cnt": [100]}
158158

159159

160+
@pytest.mark.filterwarnings(
161+
"ignore:using literals for table_partition_cols data types:DeprecationWarning"
162+
)
160163
@pytest.mark.parametrize(
161164
("path_to_str", "legacy_data_type"), [(True, False), (False, False), (False, True)]
162165
)

src/dataframe.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ impl PyDataFrame {
629629
how: &str,
630630
left_on: Vec<PyBackedStr>,
631631
right_on: Vec<PyBackedStr>,
632+
drop_duplicate_keys: bool,
632633
) -> PyDataFusionResult<Self> {
633634
let join_type = match how {
634635
"inner" => JoinType::Inner,
@@ -647,13 +648,46 @@ impl PyDataFrame {
647648
let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
648649
let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
649650

650-
let df = self.df.as_ref().clone().join(
651+
let mut df = self.df.as_ref().clone().join(
651652
right.df.as_ref().clone(),
652653
join_type,
653654
&left_keys,
654655
&right_keys,
655656
None,
656657
)?;
658+
659+
if drop_duplicate_keys {
660+
let mutual_keys = left_keys
661+
.iter()
662+
.zip(right_keys.iter())
663+
.filter(|(l, r)| l == r)
664+
.map(|(key, _)| *key)
665+
.collect::<Vec<_>>();
666+
667+
let fields_to_drop = mutual_keys
668+
.iter()
669+
.map(|name| {
670+
df.logical_plan()
671+
.schema()
672+
.qualified_fields_with_unqualified_name(name)
673+
})
674+
.filter(|r| r.len() == 2)
675+
.map(|r| r[1])
676+
.collect::<Vec<_>>();
677+
678+
let expr: Vec<Expr> = df
679+
.logical_plan()
680+
.schema()
681+
.fields()
682+
.into_iter()
683+
.enumerate()
684+
.map(|(idx, _)| df.logical_plan().schema().qualified_field(idx))
685+
.filter(|(qualifier, f)| !fields_to_drop.contains(&(*qualifier, f)))
686+
.map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field))))
687+
.collect();
688+
df = df.select(expr)?;
689+
}
690+
657691
Ok(Self::new(df))
658692
}
659693

0 commit comments

Comments
 (0)