Skip to content

Commit 80c4e1f

Browse files
committed
feat: add null_safe_eq parameter to upsert
1 parent 8c87df2 commit 80c4e1f

File tree

3 files changed

+19
-10
lines changed

3 files changed

+19
-10
lines changed

pyiceberg/table/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,7 @@ def upsert(
718718
when_matched_update_all: bool = True,
719719
when_not_matched_insert_all: bool = True,
720720
case_sensitive: bool = True,
721+
null_safe_eq: bool = False,
721722
branch: str | None = MAIN_BRANCH,
722723
snapshot_properties: dict[str, str] = EMPTY_DICT,
723724
) -> UpsertResult:
@@ -732,6 +733,7 @@ def upsert(
732733
when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any
733734
existing rows in the table
734735
case_sensitive: Bool indicating if the match should be case-sensitive
736+
null_safe_eq: Bool indicating if the equality operator should be null-safe (<=> instead of =)
735737
branch: Branch Reference to run the upsert operation
736738
snapshot_properties: Custom properties to be added to the snapshot summary
737739
@@ -824,7 +826,7 @@ def upsert(
824826
# values have actually changed. We don't want to do just a blanket overwrite for matched
825827
# rows if the actual non-key column data hasn't changed.
826828
# this extra step avoids unnecessary IO and writes
827-
rows_to_update = upsert_util.get_rows_to_update(df, rows, join_cols)
829+
rows_to_update = upsert_util.get_rows_to_update(df, rows, join_cols, null_safe_eq=null_safe_eq)
828830

829831
if len(rows_to_update) > 0:
830832
# build the match predicate filter
@@ -1320,6 +1322,7 @@ def upsert(
13201322
when_matched_update_all: bool = True,
13211323
when_not_matched_insert_all: bool = True,
13221324
case_sensitive: bool = True,
1325+
null_safe_eq: bool = False,
13231326
branch: str | None = MAIN_BRANCH,
13241327
snapshot_properties: dict[str, str] = EMPTY_DICT,
13251328
) -> UpsertResult:
@@ -1334,6 +1337,7 @@ def upsert(
13341337
when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any
13351338
existing rows in the table
13361339
case_sensitive: Bool indicating if the match should be case-sensitive
1340+
null_safe_eq: Bool indicating if the equality operator should be null-safe (<=> instead of =)
13371341
branch: Branch Reference to run the upsert operation
13381342
snapshot_properties: Custom properties to be added to the snapshot summary
13391343
@@ -1368,6 +1372,7 @@ def upsert(
13681372
when_matched_update_all=when_matched_update_all,
13691373
when_not_matched_insert_all=when_not_matched_insert_all,
13701374
case_sensitive=case_sensitive,
1375+
null_safe_eq=null_safe_eq,
13711376
branch=branch,
13721377
snapshot_properties=snapshot_properties,
13731378
)

pyiceberg/table/upsert_util.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
7676
return len(df.select(join_cols).group_by(join_cols).aggregate([([], "count_all")]).filter(pc.field("count_all") > 1)) > 0
7777

7878

79-
def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols: list[str]) -> pa.Table:
79+
def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols: list[str], null_safe_eq: bool) -> pa.Table:
8080
"""
8181
Return a table with rows that need to be updated in the target table based on the join columns.
8282
@@ -121,16 +121,20 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
121121
target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))
122122

123123
# Step 3: Perform an inner join to find which rows from source exist in target
124-
# PyArrow joins ignore null values, and we want null==null to hold, so we compute the join in Python.
125-
# This is equivalent to:
126-
# matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
127-
source_indices = {tuple(row[col] for col in join_cols): row[SOURCE_INDEX_COLUMN_NAME] for row in source_index.to_pylist()}
128-
target_indices = {tuple(row[col] for col in join_cols): row[TARGET_INDEX_COLUMN_NAME] for row in target_index.to_pylist()}
129-
matching_indices = [(s, t) for key, s in source_indices.items() if (t := target_indices.get(key)) is not None]
124+
if null_safe_eq:
125+
# PyArrow joins ignore null values, and we want null==null to hold, so we compute the join in Python.
126+
source_indices = {tuple(row[col] for col in join_cols): row[SOURCE_INDEX_COLUMN_NAME] for row in source_index.to_pylist()}
127+
target_indices = {tuple(row[col] for col in join_cols): row[TARGET_INDEX_COLUMN_NAME] for row in target_index.to_pylist()}
128+
paired_indices = [(s, t) for key, s in source_indices.items() if (t := target_indices.get(key)) is not None]
129+
else:
130+
matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
131+
source_indices = matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist()
132+
target_indices = matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist()
133+
paired_indices = list(zip(source_indices, target_indices, strict=True))
130134

131135
# Step 4: Compare all rows using Python
132136
to_update_indices = []
133-
for source_idx, target_idx in matching_indices:
137+
for source_idx, target_idx in paired_indices:
134138
source_row = source_table.slice(source_idx, 1)
135139
target_row = target_table.slice(target_idx, 1)
136140

tests/table/test_upsert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,7 @@ def test_upsert_with_nulls_in_join_columns(catalog: Catalog) -> None:
828828
],
829829
schema=schema,
830830
)
831-
upd = table.upsert(data_with_null, join_cols=["foo", "bar"])
831+
upd = table.upsert(data_with_null, join_cols=["foo", "bar"], null_safe_eq=True)
832832
assert upd.rows_updated == 1
833833
assert upd.rows_inserted == 1
834834
assert table.scan().to_arrow() == pa.Table.from_pylist(

0 commit comments

Comments (0)