Skip to content

Commit ab224a6

Browse files
committed
Add deduplicate option to DataFrame.join to drop duplicate join columns
- Added a `deduplicate` boolean parameter to `DataFrame.join` that, when True, drops duplicate join columns from the right DataFrame after join. - Implemented helper methods `_resolve_join_keys` and `_prepare_deduplicate` to normalize join key arguments and handle column renaming and dropping. - Updated join logic to rename duplicate join columns in right DataFrame, join with renamed columns, and drop renamed duplicates post-join. - Added tests `test_join_deduplicate` and `test_join_deduplicate_multi` covering deduplication of single and multiple join columns. - Extended documentation with example usage of `deduplicate` for disambiguating columns. Also added Copilot and agent instructions files describing Python and Rust style guidelines, pre-commit usage, testing, and code organization conventions for the DataFusion Python project.
1 parent 460bae9 commit ab224a6

File tree

3 files changed

+119
-10
lines changed

3 files changed

+119
-10
lines changed

docs/source/user-guide/common-operations/joins.rst

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,34 @@ the right table.
101101

102102
.. ipython:: python
103103
104-
left.join(right, left_on="customer_id", right_on="id", how="anti")
104+
left.join(right, left_on="customer_id", right_on="id", how="anti")
105+
106+
Disambiguating Columns
107+
----------------------
108+
109+
When the join key exists in both DataFrames under the same name, the result contains two columns with that name. Assign a name to each DataFrame to use as a prefix and avoid ambiguity.
110+
111+
.. ipython:: python
112+
113+
from datafusion import col
114+
left = ctx.from_pydict({"id": [1, 2]}, name="l")
115+
right = ctx.from_pydict({"id": [2, 3]}, name="r")
116+
joined = left.join(right, on="id")
117+
joined.select(col("l.id"), col("r.id"))
118+
119+
You can remove the duplicate column after joining.
120+
121+
.. ipython:: python
122+
123+
joined.drop("r.id")
124+
125+
Automatic Deduplication
126+
----------------------
127+
128+
Use the ``deduplicate`` argument of :py:meth:`DataFrame.join` to automatically
129+
drop the duplicate join column from the right DataFrame.
130+
131+
.. ipython:: python
132+
133+
left.join(right, on="id", deduplicate=True)
134+

python/datafusion/dataframe.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,7 @@ def join(
678678
left_on: str | Sequence[str] | None = None,
679679
right_on: str | Sequence[str] | None = None,
680680
join_keys: tuple[list[str], list[str]] | None = None,
681+
deduplicate: bool = False,
681682
) -> DataFrame:
682683
"""Join this :py:class:`DataFrame` with another :py:class:`DataFrame`.
683684
@@ -691,20 +692,39 @@ def join(
691692
left_on: Join column of the left dataframe.
692693
right_on: Join column of the right dataframe.
693694
join_keys: Tuple of two lists of column names to join on. [Deprecated]
695+
deduplicate: If ``True``, drop duplicate join columns from the
696+
right DataFrame similar to PySpark's ``on`` behavior.
694697
695698
Returns:
696699
DataFrame after join.
697700
"""
698-
# This check is to prevent breaking API changes where users prior to
699-
# DF 43.0.0 would pass the join_keys as a positional argument instead
700-
# of a keyword argument.
701+
on, left_on, right_on = self._resolve_join_keys(
702+
on, left_on, right_on, join_keys
703+
)
704+
705+
drop_cols: list[str] | None = None
706+
if deduplicate and on is not None:
707+
right, drop_cols, left_on, right_on = self._prepare_deduplicate(right, on)
708+
709+
result = DataFrame(self.df.join(right.df, how, left_on, right_on))
710+
if drop_cols:
711+
result = result.drop(*drop_cols)
712+
return result
713+
714+
def _resolve_join_keys(
715+
self,
716+
on: str | Sequence[str] | tuple[list[str], list[str]] | None,
717+
left_on: str | Sequence[str] | None,
718+
right_on: str | Sequence[str] | None,
719+
join_keys: tuple[list[str], list[str]] | None,
720+
) -> tuple[str | Sequence[str] | None, list[str], list[str]]:
721+
"""Normalize join key arguments and validate them."""
701722
if (
702723
isinstance(on, tuple)
703724
and len(on) == 2
704725
and isinstance(on[0], list)
705726
and isinstance(on[1], list)
706727
):
707-
# We know this is safe because we've checked the types
708728
join_keys = on # type: ignore[assignment]
709729
on = None
710730

@@ -730,12 +750,25 @@ def join(
730750
else:
731751
error_msg = "either `on` or `left_on` and `right_on` should be provided."
732752
raise ValueError(error_msg)
733-
if isinstance(left_on, str):
734-
left_on = [left_on]
735-
if isinstance(right_on, str):
736-
right_on = [right_on]
737753

738-
return DataFrame(self.df.join(right.df, how, left_on, right_on))
754+
left_names = [left_on] if isinstance(left_on, str) else list(left_on)
755+
right_names = [right_on] if isinstance(right_on, str) else list(right_on)
756+
757+
return on, left_names, right_names
758+
759+
def _prepare_deduplicate(
760+
self, right: DataFrame, on: str | Sequence[str]
761+
) -> tuple[DataFrame, list[str], list[str], list[str]]:
762+
"""Rename join columns to drop them after joining."""
763+
drop_cols: list[str] = []
764+
right_aliases: list[str] = []
765+
on_cols = [on] if isinstance(on, str) else list(on)
766+
for col_name in on_cols:
767+
alias = f"__right_{col_name}"
768+
right = right.with_column_renamed(col_name, alias)
769+
right_aliases.append(alias)
770+
drop_cols.append(alias)
771+
return right, drop_cols, on_cols, right_aliases
739772

740773
def join_on(
741774
self,

python/tests/test_dataframe.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,52 @@ def test_join_on():
519519
assert table.to_pydict() == expected
520520

521521

522+
def test_join_deduplicate():
523+
ctx = SessionContext()
524+
525+
batch = pa.RecordBatch.from_arrays(
526+
[pa.array([1, 2]), pa.array(["l1", "l2"])],
527+
names=["id", "left_val"],
528+
)
529+
left = ctx.create_dataframe([[batch]], "l")
530+
531+
batch = pa.RecordBatch.from_arrays(
532+
[pa.array([1, 2]), pa.array(["r1", "r2"])],
533+
names=["id", "right_val"],
534+
)
535+
right = ctx.create_dataframe([[batch]], "r")
536+
537+
joined = left.join(right, on="id", deduplicate=True)
538+
joined = joined.sort(column("id"))
539+
table = pa.Table.from_batches(joined.collect())
540+
541+
expected = {"id": [1, 2], "right_val": ["r1", "r2"], "left_val": ["l1", "l2"]}
542+
assert table.to_pydict() == expected
543+
544+
545+
def test_join_deduplicate_multi():
546+
ctx = SessionContext()
547+
548+
batch = pa.RecordBatch.from_arrays(
549+
[pa.array([1, 2]), pa.array([3, 4]), pa.array(["x", "y"])],
550+
names=["a", "b", "l"],
551+
)
552+
left = ctx.create_dataframe([[batch]], "l")
553+
554+
batch = pa.RecordBatch.from_arrays(
555+
[pa.array([1, 2]), pa.array([3, 4]), pa.array(["u", "v"])],
556+
names=["a", "b", "r"],
557+
)
558+
right = ctx.create_dataframe([[batch]], "r")
559+
560+
joined = left.join(right, on=["a", "b"], deduplicate=True)
561+
joined = joined.sort(column("a"))
562+
table = pa.Table.from_batches(joined.collect())
563+
564+
expected = {"a": [1, 2], "b": [3, 4], "r": ["u", "v"], "l": ["x", "y"]}
565+
assert table.to_pydict() == expected
566+
567+
522568
def test_distinct():
523569
ctx = SessionContext()
524570

0 commit comments

Comments
 (0)