Skip to content

Commit 7d7146c

Browse files
committed
feat: introduce JoinKeys dataclass for improved join key handling in DataFrame
1 parent fa80aa6 commit 7d7146c

File tree

1 file changed

+33
-11
lines changed

1 file changed

+33
-11
lines changed

python/datafusion/dataframe.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from __future__ import annotations
2323

2424
import warnings
25+
from dataclasses import dataclass
2526
from typing import (
2627
TYPE_CHECKING,
2728
Any,
@@ -57,6 +58,15 @@
5758
from enum import Enum
5859

5960

61+
@dataclass
62+
class JoinKeys:
63+
"""Represents the resolved join keys for a DataFrame join operation."""
64+
65+
on: str | Sequence[str] | None
66+
left_names: list[str]
67+
right_names: list[str]
68+
69+
6070
# excerpt from deltalake
6171
# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
6272
class Compression(Enum):
@@ -698,15 +708,20 @@ def join(
698708
Returns:
699709
DataFrame after join.
700710
"""
701-
on, left_on, right_on = self._resolve_join_keys(
711+
join_keys_resolved = self._resolve_join_keys(
702712
on, left_on, right_on, join_keys
703713
)
704714

705715
drop_cols: list[str] | None = None
706-
if deduplicate and on is not None:
707-
right, drop_cols, left_on, right_on = self._prepare_deduplicate(right, on)
716+
if deduplicate and join_keys_resolved.on is not None:
717+
right, drop_cols, left_on_final, right_on_final = self._prepare_deduplicate(
718+
right, join_keys_resolved.on
719+
)
720+
else:
721+
left_on_final = join_keys_resolved.left_names
722+
right_on_final = join_keys_resolved.right_names
708723

709-
result = DataFrame(self.df.join(right.df, how, left_on, right_on))
724+
result = DataFrame(self.df.join(right.df, how, left_on_final, right_on_final))
710725
if drop_cols:
711726
result = result.drop(*drop_cols)
712727
return result
@@ -717,16 +732,20 @@ def _resolve_join_keys(
717732
left_on: str | Sequence[str] | None,
718733
right_on: str | Sequence[str] | None,
719734
join_keys: tuple[list[str], list[str]] | None,
720-
) -> tuple[str | Sequence[str] | None, list[str], list[str]]:
721-
"""Normalize join key arguments and validate them."""
735+
) -> JoinKeys:
736+
"""Normalize join key arguments and validate them."""
737+
# Handle the special case where on is a tuple of lists (legacy format)
738+
resolved_on: str | Sequence[str] | None
722739
if (
723740
isinstance(on, tuple)
724741
and len(on) == 2
725742
and isinstance(on[0], list)
726743
and isinstance(on[1], list)
727744
):
728745
join_keys = on # type: ignore[assignment]
729-
on = None
746+
resolved_on = None
747+
else:
748+
resolved_on = on # type: ignore[assignment]
730749

731750
if join_keys is not None:
732751
warnings.warn(
@@ -737,12 +756,12 @@ def _resolve_join_keys(
737756
left_on = join_keys[0]
738757
right_on = join_keys[1]
739758

740-
if on is not None:
759+
if resolved_on is not None:
741760
if left_on is not None or right_on is not None:
742761
error_msg = "`left_on` or `right_on` should not provided with `on`"
743762
raise ValueError(error_msg)
744-
left_on = on
745-
right_on = on
763+
left_on = resolved_on
764+
right_on = resolved_on
746765
elif left_on is not None or right_on is not None:
747766
if left_on is None or right_on is None:
748767
error_msg = "`left_on` and `right_on` should both be provided."
@@ -751,10 +770,13 @@ def _resolve_join_keys(
751770
error_msg = "either `on` or `left_on` and `right_on` should be provided."
752771
raise ValueError(error_msg)
753772

773+
# At this point, left_on and right_on are guaranteed to be non-None
774+
assert left_on is not None and right_on is not None
775+
754776
left_names = [left_on] if isinstance(left_on, str) else list(left_on)
755777
right_names = [right_on] if isinstance(right_on, str) else list(right_on)
756778

757-
return on, left_names, right_names
779+
return JoinKeys(on=resolved_on, left_names=left_names, right_names=right_names)
758780

759781
def _prepare_deduplicate(
760782
self, right: DataFrame, on: str | Sequence[str]

0 commit comments

Comments
 (0)