2020from pathlib import Path , PosixPath
2121
2222import pytest
23+ from unittest .mock import patch
2324
2425from pyiceberg .catalog import Catalog , MetastoreCatalog , load_catalog
2526from pyiceberg .catalog .hive import HiveCatalog
3738from pyiceberg .io import WAREHOUSE
3839from pyiceberg .partitioning import PartitionField , PartitionSpec
3940from pyiceberg .schema import INITIAL_SCHEMA_ID , Schema
41+ from pyiceberg .table import CommitTableResponse , Table , TableProperties
4042from pyiceberg .table .metadata import INITIAL_SPEC_ID
4143from pyiceberg .table .sorting import INITIAL_SORT_ORDER_ID , SortField , SortOrder
44+ from pyiceberg .table .update import TableRequirement , TableUpdate
4245from pyiceberg .transforms import BucketTransform , DayTransform , IdentityTransform
4346from pyiceberg .types import IntegerType , LongType , NestedField , TimestampType , UUIDType
4447from tests .conftest import clean_up
@@ -527,7 +530,12 @@ def test_update_table_spec_conflict(test_catalog: Catalog, test_schema: Schema,
527530 identifier = (database_name , table_name )
528531 test_catalog .create_namespace (database_name )
529532 spec = PartitionSpec (PartitionField (source_id = 1 , field_id = 1000 , transform = BucketTransform (16 ), name = "id_bucket" ))
530- table = test_catalog .create_table (identifier , test_schema , partition_spec = spec )
533+ table = test_catalog .create_table (
534+ identifier ,
535+ test_schema ,
536+ partition_spec = spec ,
537+ properties = {TableProperties .COMMIT_NUM_RETRIES : "1" }
538+ )
531539
532540 update = table .update_spec ()
533541 update .add_field (source_column_name = "tpep_pickup_datetime" , transform = BucketTransform (16 ), partition_field_name = "shard" )
@@ -546,6 +554,44 @@ def test_update_table_spec_conflict(test_catalog: Catalog, test_schema: Schema,
546554 assert loaded .spec () == PartitionSpec (spec_id = 1 )
547555
548556
557+ @pytest .mark .integration
558+ @pytest .mark .parametrize ("test_catalog" , CATALOGS )
559+ def test_update_table_spec_conflict_with_retry (test_catalog : Catalog , test_schema : Schema , table_name : str , database_name : str ) -> None :
560+ identifier = (database_name , table_name )
561+ test_catalog .create_namespace (database_name )
562+ spec = PartitionSpec (PartitionField (source_id = 1 , field_id = 1000 , transform = BucketTransform (16 ), name = "id_bucket" ))
563+ table = test_catalog .create_table (
564+ identifier ,
565+ test_schema ,
566+ partition_spec = spec ,
567+ properties = {TableProperties .COMMIT_NUM_RETRIES : "2" }
568+ )
569+ update = table .update_spec ()
570+ update .add_field (source_column_name = "tpep_pickup_datetime" , transform = BucketTransform (16 ), partition_field_name = "shard" )
571+
572+ # update with conflict
573+ conflict_table = test_catalog .load_table (identifier )
574+ with conflict_table .update_spec () as conflict_update :
575+ conflict_update .remove_field ("id_bucket" )
576+
577+ original_commit = test_catalog .commit_table
578+ commit_count = 0
579+
580+ def mock_commit (
581+ tbl : Table , requirements : tuple [TableRequirement , ...], updates : tuple [TableUpdate , ...]
582+ ) -> CommitTableResponse :
583+ nonlocal commit_count
584+ commit_count += 1
585+ return original_commit (tbl , requirements , updates )
586+
587+ with patch .object (test_catalog , "commit_table" , side_effect = mock_commit ):
588+ update .commit ()
589+
590+ loaded = test_catalog .load_table (identifier )
591+ assert loaded .spec ().spec_id == 2
592+ assert commit_count == 2
593+
594+
549595@pytest .mark .integration
550596@pytest .mark .parametrize ("test_catalog" , CATALOGS )
551597def test_update_table_spec_then_revert (test_catalog : Catalog , test_schema : Schema , table_name : str , database_name : str ) -> None :
0 commit comments