2525from concurrent .futures import Future
2626from datetime import datetime
2727from functools import cached_property
28- from typing import TYPE_CHECKING , Generic
28+ from typing import TYPE_CHECKING , Any , Generic
2929
3030from sortedcontainers import SortedList
3131
@@ -440,31 +440,19 @@ def commit(self) -> None:
440440
441441 properties = self ._transaction .table_metadata .properties
442442
443+ # Use explicit None checks to honor zero-valued properties
444+ max_attempts = property_as_int (properties , TableProperties .COMMIT_NUM_RETRIES )
445+ min_wait_ms = property_as_int (properties , TableProperties .COMMIT_MIN_RETRY_WAIT_MS )
446+ max_wait_ms = property_as_int (properties , TableProperties .COMMIT_MAX_RETRY_WAIT_MS )
447+ total_timeout_ms = property_as_int (properties , TableProperties .COMMIT_TOTAL_RETRY_TIME_MS )
448+
443449 retry_config = RetryConfig (
444- max_attempts = property_as_int (
445- properties ,
446- TableProperties .COMMIT_NUM_RETRIES ,
447- TableProperties .COMMIT_NUM_RETRIES_DEFAULT ,
448- )
449- or TableProperties .COMMIT_NUM_RETRIES_DEFAULT ,
450- min_wait_ms = property_as_int (
451- properties ,
452- TableProperties .COMMIT_MIN_RETRY_WAIT_MS ,
453- TableProperties .COMMIT_MIN_RETRY_WAIT_MS_DEFAULT ,
454- )
455- or TableProperties .COMMIT_MIN_RETRY_WAIT_MS_DEFAULT ,
456- max_wait_ms = property_as_int (
457- properties ,
458- TableProperties .COMMIT_MAX_RETRY_WAIT_MS ,
459- TableProperties .COMMIT_MAX_RETRY_WAIT_MS_DEFAULT ,
460- )
461- or TableProperties .COMMIT_MAX_RETRY_WAIT_MS_DEFAULT ,
462- total_timeout_ms = property_as_int (
463- properties ,
464- TableProperties .COMMIT_TOTAL_RETRY_TIME_MS ,
465- TableProperties .COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT ,
466- )
467- or TableProperties .COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT ,
450+ max_attempts = max_attempts if max_attempts is not None else TableProperties .COMMIT_NUM_RETRIES_DEFAULT ,
451+ min_wait_ms = min_wait_ms if min_wait_ms is not None else TableProperties .COMMIT_MIN_RETRY_WAIT_MS_DEFAULT ,
452+ max_wait_ms = max_wait_ms if max_wait_ms is not None else TableProperties .COMMIT_MAX_RETRY_WAIT_MS_DEFAULT ,
453+ total_timeout_ms = total_timeout_ms
454+ if total_timeout_ms is not None
455+ else TableProperties .COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT ,
468456 )
469457
470458 first_attempt = True
@@ -966,28 +954,38 @@ class ManageSnapshots(UpdateTableMetadata["ManageSnapshots"]):
966954
967955 _updates : tuple [TableUpdate , ...]
968956 _requirements : tuple [TableRequirement , ...]
957+ # Store operations for retry support
958+ _operations : list [tuple [Any , ...]]
969959
970960 def __init__ (self , transaction : Transaction ) -> None :
971961 super ().__init__ (transaction )
972962 self ._updates = ()
973963 self ._requirements = ()
964+ self ._operations = []
974965
975966 def _reset_state (self ) -> None :
976- """No-op: updates contain user-provided snapshot IDs that don't need refresh."""
967+ """Reset state for retry, rebuilding updates and requirements from refreshed metadata."""
968+ self ._updates = ()
969+ self ._requirements = ()
970+
971+ for operation in self ._operations :
972+ op_type = operation [0 ]
973+ if op_type == "remove_ref" :
974+ _ , ref_name = operation
975+ self ._do_remove_ref_snapshot (ref_name )
976+ elif op_type == "create_tag" :
977+ _ , snapshot_id , tag_name , max_ref_age_ms = operation
978+ self ._do_create_tag (snapshot_id , tag_name , max_ref_age_ms )
979+ elif op_type == "create_branch" :
980+ _ , snapshot_id , branch_name , max_ref_age_ms , max_snapshot_age_ms , min_snapshots_to_keep = operation
981+ self ._do_create_branch (snapshot_id , branch_name , max_ref_age_ms , max_snapshot_age_ms , min_snapshots_to_keep )
977982
978983 def _commit (self ) -> UpdatesAndRequirements :
979984 """Apply the pending changes and commit."""
980985 return self ._updates , self ._requirements
981986
982- def _remove_ref_snapshot (self , ref_name : str ) -> ManageSnapshots :
983- """Remove a snapshot ref.
984-
985- Args:
986- ref_name: branch / tag name to remove
987- Stages the updates and requirements for the remove-snapshot-ref.
988- Returns
989- This method for chaining
990- """
987+ def _do_remove_ref_snapshot (self , ref_name : str ) -> None :
988+ """Remove a snapshot ref (internal implementation for retry support)."""
991989 updates = (RemoveSnapshotRefUpdate (ref_name = ref_name ),)
992990 requirements = (
993991 AssertRefSnapshotId (
@@ -999,20 +997,14 @@ def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots:
999997 )
1000998 self ._updates += updates
1001999 self ._requirements += requirements
1002- return self
10031000
1004- def create_tag (self , snapshot_id : int , tag_name : str , max_ref_age_ms : int | None = None ) -> ManageSnapshots :
1005- """
1006- Create a new tag pointing to the given snapshot id.
1007-
1008- Args:
1009- snapshot_id (int): snapshot id of the existing snapshot to tag
1010- tag_name (str): name of the tag
1011- max_ref_age_ms (Optional[int]): max ref age in milliseconds
1001+ def _remove_ref_snapshot (self , ref_name : str ) -> ManageSnapshots :
1002+ self ._operations .append (("remove_ref" , ref_name ))
1003+ self ._do_remove_ref_snapshot (ref_name )
1004+ return self
10121005
1013- Returns:
1014- This for method chaining
1015- """
1006+ def _do_create_tag (self , snapshot_id : int , tag_name : str , max_ref_age_ms : int | None ) -> None :
1007+ """Create a tag (internal implementation for retry support)."""
10161008 update , requirement = self ._transaction ._set_ref_snapshot (
10171009 snapshot_id = snapshot_id ,
10181010 ref_name = tag_name ,
@@ -1021,6 +1013,10 @@ def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: int | None
10211013 )
10221014 self ._updates += update
10231015 self ._requirements += requirement
1016+
1017+ def create_tag (self , snapshot_id : int , tag_name : str , max_ref_age_ms : int | None = None ) -> ManageSnapshots :
1018+ self ._operations .append (("create_tag" , snapshot_id , tag_name , max_ref_age_ms ))
1019+ self ._do_create_tag (snapshot_id , tag_name , max_ref_age_ms )
10241020 return self
10251021
10261022 def remove_tag (self , tag_name : str ) -> ManageSnapshots :
@@ -1034,6 +1030,25 @@ def remove_tag(self, tag_name: str) -> ManageSnapshots:
10341030 """
10351031 return self ._remove_ref_snapshot (ref_name = tag_name )
10361032
1033+ def _do_create_branch (
1034+ self ,
1035+ snapshot_id : int ,
1036+ branch_name : str ,
1037+ max_ref_age_ms : int | None ,
1038+ max_snapshot_age_ms : int | None ,
1039+ min_snapshots_to_keep : int | None ,
1040+ ) -> None :
1041+ update , requirement = self ._transaction ._set_ref_snapshot (
1042+ snapshot_id = snapshot_id ,
1043+ ref_name = branch_name ,
1044+ type = "branch" ,
1045+ max_ref_age_ms = max_ref_age_ms ,
1046+ max_snapshot_age_ms = max_snapshot_age_ms ,
1047+ min_snapshots_to_keep = min_snapshots_to_keep ,
1048+ )
1049+ self ._updates += update
1050+ self ._requirements += requirement
1051+
10371052 def create_branch (
10381053 self ,
10391054 snapshot_id : int ,
@@ -1054,16 +1069,10 @@ def create_branch(
10541069 Returns:
10551070 This for method chaining
10561071 """
1057- update , requirement = self ._transaction ._set_ref_snapshot (
1058- snapshot_id = snapshot_id ,
1059- ref_name = branch_name ,
1060- type = "branch" ,
1061- max_ref_age_ms = max_ref_age_ms ,
1062- max_snapshot_age_ms = max_snapshot_age_ms ,
1063- min_snapshots_to_keep = min_snapshots_to_keep ,
1072+ self ._operations .append (
1073+ ("create_branch" , snapshot_id , branch_name , max_ref_age_ms , max_snapshot_age_ms , min_snapshots_to_keep )
10641074 )
1065- self ._updates += update
1066- self ._requirements += requirement
1075+ self ._do_create_branch (snapshot_id , branch_name , max_ref_age_ms , max_snapshot_age_ms , min_snapshots_to_keep )
10671076 return self
10681077
10691078 def remove_branch (self , branch_name : str ) -> ManageSnapshots :
0 commit comments