Skip to content

Commit f25cc42

Browse files
committed
fix integration test for UpdateSpec
1 parent a95c870 commit f25cc42

File tree

2 files changed

+49
-7
lines changed

2 files changed

+49
-7
lines changed

tests/integration/test_catalog.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from pathlib import Path, PosixPath
2121

2222
import pytest
23+
from unittest.mock import patch
2324

2425
from pyiceberg.catalog import Catalog, MetastoreCatalog, load_catalog
2526
from pyiceberg.catalog.hive import HiveCatalog
@@ -37,8 +38,10 @@
3738
from pyiceberg.io import WAREHOUSE
3839
from pyiceberg.partitioning import PartitionField, PartitionSpec
3940
from pyiceberg.schema import INITIAL_SCHEMA_ID, Schema
41+
from pyiceberg.table import CommitTableResponse, Table, TableProperties
4042
from pyiceberg.table.metadata import INITIAL_SPEC_ID
4143
from pyiceberg.table.sorting import INITIAL_SORT_ORDER_ID, SortField, SortOrder
44+
from pyiceberg.table.update import TableRequirement, TableUpdate
4245
from pyiceberg.transforms import BucketTransform, DayTransform, IdentityTransform
4346
from pyiceberg.types import IntegerType, LongType, NestedField, TimestampType, UUIDType
4447
from tests.conftest import clean_up
@@ -527,7 +530,12 @@ def test_update_table_spec_conflict(test_catalog: Catalog, test_schema: Schema,
527530
identifier = (database_name, table_name)
528531
test_catalog.create_namespace(database_name)
529532
spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="id_bucket"))
530-
table = test_catalog.create_table(identifier, test_schema, partition_spec=spec)
533+
table = test_catalog.create_table(
534+
identifier,
535+
test_schema,
536+
partition_spec=spec,
537+
properties={TableProperties.COMMIT_NUM_RETRIES: "1"}
538+
)
531539

532540
update = table.update_spec()
533541
update.add_field(source_column_name="tpep_pickup_datetime", transform=BucketTransform(16), partition_field_name="shard")
@@ -546,6 +554,44 @@ def test_update_table_spec_conflict(test_catalog: Catalog, test_schema: Schema,
546554
assert loaded.spec() == PartitionSpec(spec_id=1)
547555

548556

557+
@pytest.mark.integration
@pytest.mark.parametrize("test_catalog", CATALOGS)
def test_update_table_spec_conflict_with_retry(
    test_catalog: Catalog, test_schema: Schema, table_name: str, database_name: str
) -> None:
    """Commit a stale spec update on a table created with ``COMMIT_NUM_RETRIES`` set to "2".

    A spec change is staged on one table handle, a second handle then applies
    a different spec change, and only afterwards is the staged change
    committed. The test expects that final commit to call ``commit_table``
    exactly twice and the reloaded table to report spec id 2.
    """
    test_catalog.create_namespace(database_name)
    identifier = (database_name, table_name)
    initial_spec = PartitionSpec(
        PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="id_bucket")
    )
    table = test_catalog.create_table(
        identifier,
        test_schema,
        partition_spec=initial_spec,
        properties={TableProperties.COMMIT_NUM_RETRIES: "2"},
    )

    # Stage a change on the first handle; it is committed only after the
    # competing change below has been applied.
    stale_update = table.update_spec()
    stale_update.add_field(
        source_column_name="tpep_pickup_datetime",
        transform=BucketTransform(16),
        partition_field_name="shard",
    )

    # Competing spec change through a separately loaded handle
    # (presumably committed when the ``with`` block exits).
    concurrent_table = test_catalog.load_table(identifier)
    with concurrent_table.update_spec() as concurrent_update:
        concurrent_update.remove_field("id_bucket")

    # Keep a reference to the real method before patching so the wrapper can delegate to it.
    real_commit = test_catalog.commit_table
    attempts = 0

    def counting_commit(
        tbl: Table, requirements: tuple[TableRequirement, ...], updates: tuple[TableUpdate, ...]
    ) -> CommitTableResponse:
        # Count every commit attempt, then forward to the unpatched catalog method.
        nonlocal attempts
        attempts += 1
        return real_commit(tbl, requirements, updates)

    with patch.object(test_catalog, "commit_table", side_effect=counting_commit):
        stale_update.commit()

    reloaded = test_catalog.load_table(identifier)
    assert reloaded.spec().spec_id == 2
    assert attempts == 2


593+
594+
549595
@pytest.mark.integration
550596
@pytest.mark.parametrize("test_catalog", CATALOGS)
551597
def test_update_table_spec_then_revert(test_catalog: Catalog, test_schema: Schema, table_name: str, database_name: str) -> None:

tests/table/test_commit_retry.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,7 @@
2727
from pyiceberg.exceptions import CommitFailedException, CommitStateUnknownException
2828
from pyiceberg.schema import Schema
2929
from pyiceberg.table import CommitTableResponse, Table, TableProperties
30-
from pyiceberg.table.update import (
31-
TableRequirement,
32-
TableUpdate,
33-
)
30+
from pyiceberg.table.update import TableRequirement, TableUpdate
3431
from pyiceberg.types import LongType, NestedField
3532

3633

@@ -80,15 +77,14 @@ def mock_commit(
8077
) -> CommitTableResponse:
8178
nonlocal commit_count
8279
commit_count += 1
83-
if commit_count < 2:
80+
if commit_count == 2:
8481
raise CommitFailedException("Simulated conflict")
8582
return original_commit(tbl, requirements, updates)
8683

8784
with patch.object(catalog, "commit_table", side_effect=mock_commit):
8885
table.append(arrow_table)
8986

9087
assert commit_count == 2
91-
# Verify data was written
9288
assert len(table.scan().to_arrow()) == 3
9389

9490
def test_max_retries_exceeded(self, catalog: SqlCatalog, schema: Schema, arrow_table: pa.Table) -> None:

0 commit comments

Comments
 (0)