@@ -62,7 +62,8 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
6262 """
6363 all_columns = set (source_table .column_names )
6464 join_cols_set = set (join_cols )
65- non_key_cols = all_columns - join_cols_set
65+
66+ non_key_cols = list (all_columns - join_cols_set )
6667
6768 if has_duplicate_rows (target_table , join_cols ):
6869 raise ValueError ("Target table has duplicate rows, aborting upsert" )
@@ -71,25 +72,51 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
         # When the target table is empty, there is nothing to update :)
         return source_table.schema.empty_table()
 
-    diff_expr = functools.reduce(
-        operator.or_,
-        [
-            pc.or_kleene(
-                pc.not_equal(pc.field(f"{col}-lhs"), pc.field(f"{col}-rhs")),
-                pc.is_null(pc.not_equal(pc.field(f"{col}-lhs"), pc.field(f"{col}-rhs"))),
-            )
-            for col in non_key_cols
-        ],
+    # We need to compare non_key_cols in Python, as PyArrow
+    # 1. Cannot do a join when non-join columns have complex types
+    # 2. Cannot compare columns with complex types
+    # See: https://github.com/apache/arrow/issues/35785
+    SOURCE_INDEX_COLUMN_NAME = "__source_index"
+    TARGET_INDEX_COLUMN_NAME = "__target_index"
+
+    if SOURCE_INDEX_COLUMN_NAME in join_cols or TARGET_INDEX_COLUMN_NAME in join_cols:
+        raise ValueError(
+            f"{SOURCE_INDEX_COLUMN_NAME} and {TARGET_INDEX_COLUMN_NAME} are reserved for joining "
+            "DataFrames, and cannot be used as column names"
+        )
+
+    # Step 1: Prepare source index with join keys and a marker index
+    # Cast to target table schema, so we can do the join
+    # See: https://github.com/apache/arrow/issues/37542
+    source_index = (
+        source_table.cast(target_table.schema)
+        .select(join_cols_set)
+        .append_column(SOURCE_INDEX_COLUMN_NAME, pa.array(range(len(source_table))))
     )
 
-    return (
-        source_table
-        # We already know that the schema is compatible, this is to fix large_ types
-        .cast(target_table.schema)
-        .join(target_table, keys=list(join_cols_set), join_type="inner", left_suffix="-lhs", right_suffix="-rhs")
-        .filter(diff_expr)
-        .drop_columns([f"{col}-rhs" for col in non_key_cols])
-        .rename_columns({f"{col}-lhs" if col not in join_cols else col: col for col in source_table.column_names})
-        # Finally cast to the original schema since it doesn't carry nullability:
-        # https://github.com/apache/arrow/issues/45557
-    ).cast(target_table.schema)
+    # Step 2: Prepare target index with join keys and a marker
+    target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))
+
+    # Step 3: Perform an inner join to find which rows from source exist in target
+    matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
+
+    # Step 4: Compare all rows using Python
+    to_update_indices = []
+    for source_idx, target_idx in zip(
+        matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(), matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist()
+    ):
+        source_row = source_table.slice(source_idx, 1)
+        target_row = target_table.slice(target_idx, 1)
+
+        for key in non_key_cols:
+            source_val = source_row.column(key)[0].as_py()
+            target_val = target_row.column(key)[0].as_py()
+            if source_val != target_val:
+                to_update_indices.append(source_idx)
+                break
+
+    # Step 5: Take rows from source table using the indices and cast to target schema
+    if to_update_indices:
+        return source_table.take(to_update_indices).cast(target_table.schema)
+    else:
+        return source_table.schema.empty_table()
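For context, the limitation cited in the new comment block can be reproduced in a few lines. This is a minimal sketch with made-up tables, and it assumes (per apache/arrow#35785) that Acero's hash join rejects nested types in non-key columns; the exact exception message varies by pyarrow version:

```python
import pyarrow as pa

# Made-up tables where the non-key column "props" has a nested (struct) type.
source = pa.table({
    "id": pa.array([1, 2], type=pa.int64()),
    "props": pa.array([{"a": 1}, {"a": 2}], type=pa.struct([("a", pa.int64())])),
})
target = pa.table({
    "id": pa.array([1, 2], type=pa.int64()),
    "props": pa.array([{"a": 1}, {"a": 99}], type=pa.struct([("a", pa.int64())])),
})

# The hash join cannot carry the struct column as a non-key (payload) field,
# so the old join-then-filter approach fails before diff_expr is ever evaluated.
try:
    source.join(target, keys=["id"], join_type="inner", left_suffix="-lhs", right_suffix="-rhs")
except pa.ArrowInvalid as exc:
    print(exc)
```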
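And a usage sketch of the reworked function, assuming the `get_rows_to_update` above and its `has_duplicate_rows` helper are importable, again with hypothetical data: because the join now only carries the key columns plus integer row markers, and the non-key comparison happens in Python, nested types work:

```python
import pyarrow as pa

source = pa.table({
    "id": [1, 2, 3],
    "props": pa.array([{"a": 1}, {"a": 2}, {"a": 3}], type=pa.struct([("a", pa.int64())])),
})
target = pa.table({
    "id": [1, 2],
    "props": pa.array([{"a": 1}, {"a": 99}], type=pa.struct([("a", pa.int64())])),
})

# Only id=2 matches on the key and differs in a non-key column;
# id=1 is unchanged and id=3 has no match in the target.
rows = get_rows_to_update(source, target, join_cols=["id"])
assert rows.to_pylist() == [{"id": 2, "props": {"a": 2}}]
```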