Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 65 additions & 55 deletions python/datafusion/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
from datafusion.plan import ExecutionPlan, LogicalPlan
from datafusion.record_batch import RecordBatchStream

from .functions import coalesce, col

if TYPE_CHECKING:
import pathlib
from typing import Callable, Sequence
Expand Down Expand Up @@ -77,6 +79,31 @@ class JoinPreparation:
drop_cols: list[str]


def _deduplicate_right(
    right: DataFrame, columns: Sequence[str]
) -> tuple[DataFrame, list[str]]:
    """Rename join columns on the right DataFrame for deduplication.

    Each column in ``columns`` is renamed to a temporary alias of the form
    ``__right_<name>``. When that alias is already taken, a numeric suffix
    (``_1``, ``_2``, ...) is appended until an unused name is found.

    Args:
        right: The right-hand DataFrame whose join columns are renamed.
        columns: Names of the join columns to rename, in order.

    Returns:
        A tuple of the DataFrame with the columns renamed and the list of
        aliases, in the same order as ``columns``.
    """
    # Names an alias must not collide with: the right side's original
    # columns plus every alias handed out so far.
    taken = set(right.schema().names)
    modified = right
    aliases: list[str] = []

    for col_name in columns:
        base_alias = f"__right_{col_name}"
        alias = base_alias
        counter = 0
        # The loop only exits once ``alias`` is unused, so no additional
        # fallback naming scheme is required after it.
        while alias in taken:
            counter += 1
            alias = f"{base_alias}_{counter}"

        modified = modified.with_column_renamed(col_name, alias)
        aliases.append(alias)
        taken.add(alias)

    return modified, aliases


# excerpt from deltalake
# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
class Compression(Enum):
Expand Down Expand Up @@ -730,10 +757,23 @@ def join(
join_preparation.join_keys.right_names,
)
)


if (
deduplicate
and how in ("right", "full")
and join_preparation.join_keys.on is not None
):
for left_name, right_alias in zip(
join_preparation.join_keys.left_names,
join_preparation.drop_cols,
):
result = result.with_column(
left_name, coalesce(col(left_name), col(right_alias))
)

if join_preparation.drop_cols:
result = result.drop(*join_preparation.drop_cols)

return result

def _prepare_join(
Expand All @@ -746,18 +786,18 @@ def _prepare_join(
deduplicate: bool,
) -> JoinPreparation:
"""Prepare join keys and handle deduplication if requested.

This method combines join key resolution and deduplication preparation
to avoid parameter handling duplication and provide a unified interface.

Args:
right: The right DataFrame to join with.
on: Column names to join on in both dataframes.
left_on: Join column of the left dataframe.
right_on: Join column of the right dataframe.
join_keys: Tuple of two lists of column names to join on. [Deprecated]
deduplicate: If True, prepare right DataFrame for column deduplication.

Returns:
JoinPreparation containing resolved join keys, modified right DataFrame,
and columns to drop after joining.
Expand Down Expand Up @@ -787,71 +827,41 @@ def _prepare_join(

if resolved_on is not None:
if left_on is not None or right_on is not None:
error_msg = (
"`left_on` or `right_on` should not be provided with `on`. "
"Note: `deduplicate` must be specified as a keyword argument."
)
error_msg = "`left_on` or `right_on` should not provided with `on`"
raise ValueError(error_msg)
left_on = resolved_on
right_on = resolved_on
elif left_on is not None or right_on is not None:
if left_on is None or right_on is None:
error_msg = (
"`left_on` and `right_on` should both be provided. "
"Note: `deduplicate` must be specified as a keyword argument."
)
error_msg = "`left_on` and `right_on` should both be provided."
raise ValueError(error_msg)
else:
error_msg = (
"Either `on` or both `left_on` and `right_on` should be provided. "
"Note: `deduplicate` must be specified as a keyword argument."
)
error_msg = "either `on` or `left_on` and `right_on` should be provided."
raise ValueError(error_msg)

# At this point, left_on and right_on are guaranteed to be non-None
assert left_on is not None and right_on is not None

if left_on is None or right_on is None: # pragma: no cover - sanity check
msg = "join keys resolved to None"
raise ValueError(msg)

left_names = [left_on] if isinstance(left_on, str) else list(left_on)
right_names = [right_on] if isinstance(right_on, str) else list(right_on)

join_keys_resolved = JoinKeys(
on=resolved_on, left_names=left_names, right_names=right_names
)

# Step 2: Handle deduplication if requested

drop_cols: list[str] = []
modified_right = right

if deduplicate and resolved_on is not None:
# Prepare deduplication by renaming columns in the right DataFrame
on_cols = [resolved_on] if isinstance(resolved_on, str) else list(resolved_on)

# Get existing column names to avoid collisions
existing_columns = set(right.schema().names)

for col_name in on_cols:
# Generate a collision-safe temporary alias
base_alias = f"__right_{col_name}"
alias = base_alias
counter = 0

# Keep trying until we find a unique name
while alias in existing_columns:
counter += 1
alias = f"{base_alias}_{counter}"

# If even that fails (very unlikely), use UUID
if alias in existing_columns:
alias = f"__temp_{uuid.uuid4().hex[:8]}_{col_name}"

modified_right = modified_right.with_column_renamed(col_name, alias)
drop_cols.append(alias)
# Add the new alias to existing columns to avoid future collisions
existing_columns.add(alias)

# Update right_names to use the new aliases
right_names = drop_cols.copy()

on_cols = (
[resolved_on] if isinstance(resolved_on, str) else list(resolved_on)
)
modified_right, aliases = _deduplicate_right(right, on_cols)
drop_cols.extend(aliases)
right_names = aliases.copy()

join_keys_resolved = JoinKeys(
on=resolved_on, left_names=left_names, right_names=right_names
)

return JoinPreparation(
join_keys=join_keys_resolved,
modified_right=modified_right,
Expand Down
32 changes: 17 additions & 15 deletions python/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def test_join_deduplicate_multi():
right = ctx.create_dataframe([[batch]], "r")

joined = left.join(right, on=["a", "b"], deduplicate=True)
joined = joined.sort([column("a"), column("b")])
joined = joined.sort(column("a"), column("b"))
table = pa.Table.from_batches(joined.collect())

expected = {"a": [1, 2], "b": [3, 4], "r": ["u", "v"], "l": ["x", "y"]}
Expand Down Expand Up @@ -2678,7 +2678,9 @@ def test_join_deduplicate_select():

# Ensure no internal alias names like "__right_id" appear in the schema
for col_name in column_names:
assert not col_name.startswith("__"), f"Internal alias '{col_name}' leaked into schema"
assert not col_name.startswith("__"), (
f"Internal alias '{col_name}' leaked into schema"
)

# Test selecting each column individually to ensure they all work
for col_name in expected_columns:
Expand All @@ -2693,13 +2695,13 @@ def test_join_deduplicate_select():
assert all_result.schema.names == expected_columns

# Verify that attempting to select a potential internal alias fails appropriately
with pytest.raises(Exception): # Should raise an error for non-existent column
with pytest.raises(Exception): # noqa: B017 - generic exception from FFI
joined_df.select(column("__right_id")).collect()


def test_join_deduplicate_all_types():
"""Test deduplication behavior across different join types (left, right, outer).

Note: This test may show linting errors due to method signature overloads,
but the functionality should work correctly at runtime.
"""
Expand All @@ -2721,8 +2723,8 @@ def test_join_deduplicate_all_types():

# Test inner join with deduplication (default behavior)
inner_joined = left_df.join(right_df, on="id", how="inner", deduplicate=True)
inner_result = inner_joined.sort([column("id")]).collect()[0]
inner_result = inner_joined.sort(column("id")).collect()[0]

# Should only have matching rows (2, 3)
expected_inner = {
"id": [2, 3],
Expand All @@ -2733,8 +2735,8 @@ def test_join_deduplicate_all_types():

# Test left join with deduplication
left_joined = left_df.join(right_df, on="id", how="left", deduplicate=True)
left_result = left_joined.sort([column("id")]).collect()[0]
left_result = left_joined.sort(column("id")).collect()[0]

# Should have all left rows, with nulls for unmatched right rows
expected_left = {
"id": [1, 2, 3, 4],
Expand All @@ -2745,8 +2747,8 @@ def test_join_deduplicate_all_types():

# Test right join with deduplication
right_joined = left_df.join(right_df, on="id", how="right", deduplicate=True)
right_result = right_joined.sort([column("id")]).collect()[0]
right_result = right_joined.sort(column("id")).collect()[0]

# Should have all right rows, with nulls for unmatched left rows
expected_right = {
"id": [2, 3, 5, 6],
Expand All @@ -2756,9 +2758,9 @@ def test_join_deduplicate_all_types():
assert right_result.to_pydict() == expected_right

# Test full outer join with deduplication
outer_joined = left_df.join(right_df, on="id", how="outer", deduplicate=True)
outer_result = outer_joined.sort([column("id")]).collect()[0]
outer_joined = left_df.join(right_df, on="id", how="full", deduplicate=True)
outer_result = outer_joined.sort(column("id")).collect()[0]

# Should have all rows from both sides, with nulls for unmatched rows
expected_outer = {
"id": [1, 2, 3, 4, 5, 6],
Expand All @@ -2768,8 +2770,8 @@ def test_join_deduplicate_all_types():
assert outer_result.to_pydict() == expected_outer

# Verify that we can still select the deduplicated column without issues
for join_type in ["inner", "left", "right", "outer"]:
for join_type in ["inner", "left", "right", "full"]:
joined = left_df.join(right_df, on="id", how=join_type, deduplicate=True)
selected = joined.select(column("id"))
# Should not raise an error and should have the same number of rows
assert len(selected.collect()[0]) == len(joined.collect()[0])
assert len(selected.collect()[0]) == len(joined.collect()[0])
Loading