2222from __future__ import annotations
2323
2424import warnings
25+ from dataclasses import dataclass
2526from typing import (
2627 TYPE_CHECKING ,
2728 Any ,
5758from enum import Enum
5859
5960
61+ @dataclass
62+ class JoinKeys :
63+ """Represents the resolved join keys for a DataFrame join operation."""
64+
65+ on : str | Sequence [str ] | None
66+ left_names : list [str ]
67+ right_names : list [str ]
68+
69+
6070# excerpt from deltalake
6171# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
6272class Compression (Enum ):
@@ -698,15 +708,20 @@ def join(
698708 Returns:
699709 DataFrame after join.
700710 """
701- on , left_on , right_on = self ._resolve_join_keys (
711+ join_keys_resolved = self ._resolve_join_keys (
702712 on , left_on , right_on , join_keys
703713 )
704714
705715 drop_cols : list [str ] | None = None
706- if deduplicate and on is not None :
707- right , drop_cols , left_on , right_on = self ._prepare_deduplicate (right , on )
716+ if deduplicate and join_keys_resolved .on is not None :
717+ right , drop_cols , left_on_final , right_on_final = self ._prepare_deduplicate (
718+ right , join_keys_resolved .on
719+ )
720+ else :
721+ left_on_final = join_keys_resolved .left_names
722+ right_on_final = join_keys_resolved .right_names
708723
709- result = DataFrame (self .df .join (right .df , how , left_on , right_on ))
724+ result = DataFrame (self .df .join (right .df , how , left_on_final , right_on_final ))
710725 if drop_cols :
711726 result = result .drop (* drop_cols )
712727 return result
@@ -717,16 +732,20 @@ def _resolve_join_keys(
717732 left_on : str | Sequence [str ] | None ,
718733 right_on : str | Sequence [str ] | None ,
719734 join_keys : tuple [list [str ], list [str ]] | None ,
720- ) -> tuple [str | Sequence [str ] | None , list [str ], list [str ]]:
721- """Normalize join key arguments and validate them."""
735+ ) -> JoinKeys :
736+ """Normalize join key arguments and validate them."""
737+ # Handle the special case where on is a tuple of lists (legacy format)
738+ resolved_on : str | Sequence [str ] | None
722739 if (
723740 isinstance (on , tuple )
724741 and len (on ) == 2
725742 and isinstance (on [0 ], list )
726743 and isinstance (on [1 ], list )
727744 ):
728745 join_keys = on # type: ignore[assignment]
729- on = None
746+ resolved_on = None
747+ else :
748+ resolved_on = on # type: ignore[assignment]
730749
731750 if join_keys is not None :
732751 warnings .warn (
@@ -737,12 +756,12 @@ def _resolve_join_keys(
737756 left_on = join_keys [0 ]
738757 right_on = join_keys [1 ]
739758
740- if on is not None :
759+ if resolved_on is not None :
741760 if left_on is not None or right_on is not None :
742761 error_msg = "`left_on` or `right_on` should not provided with `on`"
743762 raise ValueError (error_msg )
744- left_on = on
745- right_on = on
763+ left_on = resolved_on
764+ right_on = resolved_on
746765 elif left_on is not None or right_on is not None :
747766 if left_on is None or right_on is None :
748767 error_msg = "`left_on` and `right_on` should both be provided."
@@ -751,10 +770,13 @@ def _resolve_join_keys(
751770 error_msg = "either `on` or `left_on` and `right_on` should be provided."
752771 raise ValueError (error_msg )
753772
773+ # At this point, left_on and right_on are guaranteed to be non-None
774+ assert left_on is not None and right_on is not None
775+
754776 left_names = [left_on ] if isinstance (left_on , str ) else list (left_on )
755777 right_names = [right_on ] if isinstance (right_on , str ) else list (right_on )
756778
757- return on , left_names , right_names
779+ return JoinKeys ( on = resolved_on , left_names = left_names , right_names = right_names )
758780
759781 def _prepare_deduplicate (
760782 self , right : DataFrame , on : str | Sequence [str ]
0 commit comments