From 51cd192d4b3979beee65ff2a2d4b0903457c6961 Mon Sep 17 00:00:00 2001 From: kosiew Date: Tue, 8 Jul 2025 21:25:55 +0800 Subject: [PATCH] fix join deduplicate and tests --- python/datafusion/dataframe.py | 120 ++++++++++++++++++--------------- python/tests/test_dataframe.py | 32 ++++----- 2 files changed, 82 insertions(+), 70 deletions(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 24109b247..89f5cd137 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -46,6 +46,8 @@ from datafusion.plan import ExecutionPlan, LogicalPlan from datafusion.record_batch import RecordBatchStream +from .functions import coalesce, col + if TYPE_CHECKING: import pathlib from typing import Callable, Sequence @@ -77,6 +79,31 @@ class JoinPreparation: drop_cols: list[str] +def _deduplicate_right( + right: DataFrame, columns: Sequence[str] +) -> tuple[DataFrame, list[str]]: + """Rename join columns on the right DataFrame for deduplication.""" + existing_columns = set(right.schema().names) + modified = right + aliases: list[str] = [] + + for col_name in columns: + base_alias = f"__right_{col_name}" + alias = base_alias + counter = 0 + while alias in existing_columns: + counter += 1 + alias = f"{base_alias}_{counter}" + if alias in existing_columns: + alias = f"__temp_{uuid.uuid4().hex[:8]}_{col_name}" + + modified = modified.with_column_renamed(col_name, alias) + aliases.append(alias) + existing_columns.add(alias) + + return modified, aliases + + # excerpt from deltalake # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163 class Compression(Enum): @@ -730,10 +757,23 @@ def join( join_preparation.join_keys.right_names, ) ) - + + if ( + deduplicate + and how in ("right", "full") + and join_preparation.join_keys.on is not None + ): + for left_name, right_alias in zip( + join_preparation.join_keys.left_names, + join_preparation.drop_cols, + ): + result = result.with_column( + left_name, coalesce(col(left_name), col(right_alias)) + ) + if join_preparation.drop_cols: result = result.drop(*join_preparation.drop_cols) - + return result def _prepare_join( @@ -746,10 +786,10 @@ def _prepare_join( deduplicate: bool, ) -> JoinPreparation: """Prepare join keys and handle deduplication if requested. - + This method combines join key resolution and deduplication preparation to avoid parameter handling duplication and provide a unified interface. - + Args: right: The right DataFrame to join with. on: Column names to join on in both dataframes. @@ -757,7 +797,7 @@ def _prepare_join( right_on: Join column of the right dataframe. join_keys: Tuple of two lists of column names to join on. [Deprecated] deduplicate: If True, prepare right DataFrame for column deduplication. - + Returns: JoinPreparation containing resolved join keys, modified right DataFrame, and columns to drop after joining. @@ -787,71 +827,41 @@ def _prepare_join( if resolved_on is not None: if left_on is not None or right_on is not None: - error_msg = ( - "`left_on` or `right_on` should not be provided with `on`. " - "Note: `deduplicate` must be specified as a keyword argument." - ) + error_msg = "`left_on` or `right_on` should not provided with `on`" raise ValueError(error_msg) left_on = resolved_on right_on = resolved_on elif left_on is not None or right_on is not None: if left_on is None or right_on is None: - error_msg = ( - "`left_on` and `right_on` should both be provided. " - "Note: `deduplicate` must be specified as a keyword argument." - ) + error_msg = "`left_on` and `right_on` should both be provided." raise ValueError(error_msg) else: - error_msg = ( - "Either `on` or both `left_on` and `right_on` should be provided. " - "Note: `deduplicate` must be specified as a keyword argument." - ) + error_msg = "either `on` or `left_on` and `right_on` should be provided." raise ValueError(error_msg) # At this point, left_on and right_on are guaranteed to be non-None - assert left_on is not None and right_on is not None - + if left_on is None or right_on is None: # pragma: no cover - sanity check + msg = "join keys resolved to None" + raise ValueError(msg) + left_names = [left_on] if isinstance(left_on, str) else list(left_on) right_names = [right_on] if isinstance(right_on, str) else list(right_on) - - join_keys_resolved = JoinKeys( - on=resolved_on, left_names=left_names, right_names=right_names - ) - - # Step 2: Handle deduplication if requested + drop_cols: list[str] = [] modified_right = right - + if deduplicate and resolved_on is not None: - # Prepare deduplication by renaming columns in the right DataFrame - on_cols = [resolved_on] if isinstance(resolved_on, str) else list(resolved_on) - - # Get existing column names to avoid collisions - existing_columns = set(right.schema().names) - - for col_name in on_cols: - # Generate a collision-safe temporary alias - base_alias = f"__right_{col_name}" - alias = base_alias - counter = 0 - - # Keep trying until we find a unique name - while alias in existing_columns: - counter += 1 - alias = f"{base_alias}_{counter}" - - # If even that fails (very unlikely), use UUID - if alias in existing_columns: - alias = f"__temp_{uuid.uuid4().hex[:8]}_{col_name}" - - modified_right = modified_right.with_column_renamed(col_name, alias) - drop_cols.append(alias) - # Add the new alias to existing columns to avoid future collisions - existing_columns.add(alias) - - # Update right_names to use the new aliases - right_names = drop_cols.copy() - + on_cols = ( + [resolved_on] if isinstance(resolved_on, str) else list(resolved_on) + ) + modified_right, aliases = _deduplicate_right(right, on_cols) + drop_cols.extend(aliases) + right_names = aliases.copy() + + join_keys_resolved = JoinKeys( + on=resolved_on, left_names=left_names, right_names=right_names + ) + return JoinPreparation( join_keys=join_keys_resolved, modified_right=modified_right, diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index b2b05f105..6c03893d6 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -558,7 +558,7 @@ def test_join_deduplicate_multi(): right = ctx.create_dataframe([[batch]], "r") joined = left.join(right, on=["a", "b"], deduplicate=True) - joined = joined.sort([column("a"), column("b")]) + joined = joined.sort(column("a"), column("b")) table = pa.Table.from_batches(joined.collect()) expected = {"a": [1, 2], "b": [3, 4], "r": ["u", "v"], "l": ["x", "y"]} @@ -2678,7 +2678,9 @@ def test_join_deduplicate_select(): # Ensure no internal alias names like "__right_id" appear in the schema for col_name in column_names: - assert not col_name.startswith("__"), f"Internal alias '{col_name}' leaked into schema" + assert not col_name.startswith("__"), ( + f"Internal alias '{col_name}' leaked into schema" + ) # Test selecting each column individually to ensure they all work for col_name in expected_columns: @@ -2693,13 +2695,13 @@ def test_join_deduplicate_select(): assert all_result.schema.names == expected_columns # Verify that attempting to select a potential internal alias fails appropriately - with pytest.raises(Exception): # Should raise an error for non-existent column + with pytest.raises(Exception): # noqa: B017 - generic exception from FFI joined_df.select(column("__right_id")).collect() def test_join_deduplicate_all_types(): """Test deduplication behavior across different join types (left, right, outer). - + Note: This test may show linting errors due to method signature overloads, but the functionality should work correctly at runtime. """ @@ -2721,8 +2723,8 @@ def test_join_deduplicate_all_types(): # Test inner join with deduplication (default behavior) inner_joined = left_df.join(right_df, on="id", how="inner", deduplicate=True) - inner_result = inner_joined.sort([column("id")]).collect()[0] - + inner_result = inner_joined.sort(column("id")).collect()[0] + # Should only have matching rows (2, 3) expected_inner = { "id": [2, 3], @@ -2733,8 +2735,8 @@ def test_join_deduplicate_all_types(): # Test left join with deduplication left_joined = left_df.join(right_df, on="id", how="left", deduplicate=True) - left_result = left_joined.sort([column("id")]).collect()[0] - + left_result = left_joined.sort(column("id")).collect()[0] + # Should have all left rows, with nulls for unmatched right rows expected_left = { "id": [1, 2, 3, 4], @@ -2745,8 +2747,8 @@ def test_join_deduplicate_all_types(): # Test right join with deduplication right_joined = left_df.join(right_df, on="id", how="right", deduplicate=True) - right_result = right_joined.sort([column("id")]).collect()[0] - + right_result = right_joined.sort(column("id")).collect()[0] + # Should have all right rows, with nulls for unmatched left rows expected_right = { "id": [2, 3, 5, 6], @@ -2756,9 +2758,9 @@ def test_join_deduplicate_all_types(): assert right_result.to_pydict() == expected_right # Test full outer join with deduplication - outer_joined = left_df.join(right_df, on="id", how="outer", deduplicate=True) - outer_result = outer_joined.sort([column("id")]).collect()[0] - + outer_joined = left_df.join(right_df, on="id", how="full", deduplicate=True) + outer_result = outer_joined.sort(column("id")).collect()[0] + # Should have all rows from both sides, with nulls for unmatched rows expected_outer = { "id": [1, 2, 3, 4, 5, 6], @@ -2768,8 +2770,8 @@ def test_join_deduplicate_all_types(): assert outer_result.to_pydict() == expected_outer # Verify that we can still select the deduplicated column without issues - for join_type in ["inner", "left", "right", "outer"]: + for join_type in ["inner", "left", "right", "full"]: joined = left_df.join(right_df, on="id", how=join_type, deduplicate=True) selected = joined.select(column("id")) # Should not raise an error and should have the same number of rows - assert len(selected.collect()[0]) == len(joined.collect()[0]) \ No newline at end of file + assert len(selected.collect()[0]) == len(joined.collect()[0])