Skip to content

Commit af2dd4a

Browse files
committed
fix ManageSnapshots retry bug
1 parent 84cc527 commit af2dd4a

File tree

2 files changed

+550
-56
lines changed

2 files changed

+550
-56
lines changed

pyiceberg/table/update/snapshot.py

Lines changed: 65 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from concurrent.futures import Future
2626
from datetime import datetime
2727
from functools import cached_property
28-
from typing import TYPE_CHECKING, Generic
28+
from typing import TYPE_CHECKING, Any, Generic
2929

3030
from sortedcontainers import SortedList
3131

@@ -440,31 +440,19 @@ def commit(self) -> None:
440440

441441
properties = self._transaction.table_metadata.properties
442442

443+
# Use explicit None checks to honor zero-valued properties
444+
max_attempts = property_as_int(properties, TableProperties.COMMIT_NUM_RETRIES)
445+
min_wait_ms = property_as_int(properties, TableProperties.COMMIT_MIN_RETRY_WAIT_MS)
446+
max_wait_ms = property_as_int(properties, TableProperties.COMMIT_MAX_RETRY_WAIT_MS)
447+
total_timeout_ms = property_as_int(properties, TableProperties.COMMIT_TOTAL_RETRY_TIME_MS)
448+
443449
retry_config = RetryConfig(
444-
max_attempts=property_as_int(
445-
properties,
446-
TableProperties.COMMIT_NUM_RETRIES,
447-
TableProperties.COMMIT_NUM_RETRIES_DEFAULT,
448-
)
449-
or TableProperties.COMMIT_NUM_RETRIES_DEFAULT,
450-
min_wait_ms=property_as_int(
451-
properties,
452-
TableProperties.COMMIT_MIN_RETRY_WAIT_MS,
453-
TableProperties.COMMIT_MIN_RETRY_WAIT_MS_DEFAULT,
454-
)
455-
or TableProperties.COMMIT_MIN_RETRY_WAIT_MS_DEFAULT,
456-
max_wait_ms=property_as_int(
457-
properties,
458-
TableProperties.COMMIT_MAX_RETRY_WAIT_MS,
459-
TableProperties.COMMIT_MAX_RETRY_WAIT_MS_DEFAULT,
460-
)
461-
or TableProperties.COMMIT_MAX_RETRY_WAIT_MS_DEFAULT,
462-
total_timeout_ms=property_as_int(
463-
properties,
464-
TableProperties.COMMIT_TOTAL_RETRY_TIME_MS,
465-
TableProperties.COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT,
466-
)
467-
or TableProperties.COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT,
450+
max_attempts=max_attempts if max_attempts is not None else TableProperties.COMMIT_NUM_RETRIES_DEFAULT,
451+
min_wait_ms=min_wait_ms if min_wait_ms is not None else TableProperties.COMMIT_MIN_RETRY_WAIT_MS_DEFAULT,
452+
max_wait_ms=max_wait_ms if max_wait_ms is not None else TableProperties.COMMIT_MAX_RETRY_WAIT_MS_DEFAULT,
453+
total_timeout_ms=total_timeout_ms
454+
if total_timeout_ms is not None
455+
else TableProperties.COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT,
468456
)
469457

470458
first_attempt = True
@@ -966,28 +954,38 @@ class ManageSnapshots(UpdateTableMetadata["ManageSnapshots"]):
966954

967955
_updates: tuple[TableUpdate, ...]
968956
_requirements: tuple[TableRequirement, ...]
957+
# Store operations for retry support
958+
_operations: list[tuple[Any, ...]]
969959

970960
def __init__(self, transaction: Transaction) -> None:
971961
super().__init__(transaction)
972962
self._updates = ()
973963
self._requirements = ()
964+
self._operations = []
974965

975966
def _reset_state(self) -> None:
976-
"""No-op: updates contain user-provided snapshot IDs that don't need refresh."""
967+
"""Reset state for retry, rebuilding updates and requirements from refreshed metadata."""
968+
self._updates = ()
969+
self._requirements = ()
970+
971+
for operation in self._operations:
972+
op_type = operation[0]
973+
if op_type == "remove_ref":
974+
_, ref_name = operation
975+
self._do_remove_ref_snapshot(ref_name)
976+
elif op_type == "create_tag":
977+
_, snapshot_id, tag_name, max_ref_age_ms = operation
978+
self._do_create_tag(snapshot_id, tag_name, max_ref_age_ms)
979+
elif op_type == "create_branch":
980+
_, snapshot_id, branch_name, max_ref_age_ms, max_snapshot_age_ms, min_snapshots_to_keep = operation
981+
self._do_create_branch(snapshot_id, branch_name, max_ref_age_ms, max_snapshot_age_ms, min_snapshots_to_keep)
977982

978983
def _commit(self) -> UpdatesAndRequirements:
979984
"""Apply the pending changes and commit."""
980985
return self._updates, self._requirements
981986

982-
def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots:
983-
"""Remove a snapshot ref.
984-
985-
Args:
986-
ref_name: branch / tag name to remove
987-
Stages the updates and requirements for the remove-snapshot-ref.
988-
Returns
989-
This method for chaining
990-
"""
987+
def _do_remove_ref_snapshot(self, ref_name: str) -> None:
988+
"""Remove a snapshot ref (internal implementation for retry support)."""
991989
updates = (RemoveSnapshotRefUpdate(ref_name=ref_name),)
992990
requirements = (
993991
AssertRefSnapshotId(
@@ -999,20 +997,14 @@ def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots:
999997
)
1000998
self._updates += updates
1001999
self._requirements += requirements
1002-
return self
10031000

1004-
def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: int | None = None) -> ManageSnapshots:
1005-
"""
1006-
Create a new tag pointing to the given snapshot id.
1007-
1008-
Args:
1009-
snapshot_id (int): snapshot id of the existing snapshot to tag
1010-
tag_name (str): name of the tag
1011-
max_ref_age_ms (Optional[int]): max ref age in milliseconds
1001+
def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots:
1002+
self._operations.append(("remove_ref", ref_name))
1003+
self._do_remove_ref_snapshot(ref_name)
1004+
return self
10121005

1013-
Returns:
1014-
This for method chaining
1015-
"""
1006+
def _do_create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: int | None) -> None:
1007+
"""Create a tag (internal implementation for retry support)."""
10161008
update, requirement = self._transaction._set_ref_snapshot(
10171009
snapshot_id=snapshot_id,
10181010
ref_name=tag_name,
@@ -1021,6 +1013,10 @@ def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: int | None
10211013
)
10221014
self._updates += update
10231015
self._requirements += requirement
1016+
1017+
def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: int | None = None) -> ManageSnapshots:
1018+
self._operations.append(("create_tag", snapshot_id, tag_name, max_ref_age_ms))
1019+
self._do_create_tag(snapshot_id, tag_name, max_ref_age_ms)
10241020
return self
10251021

10261022
def remove_tag(self, tag_name: str) -> ManageSnapshots:
@@ -1034,6 +1030,25 @@ def remove_tag(self, tag_name: str) -> ManageSnapshots:
10341030
"""
10351031
return self._remove_ref_snapshot(ref_name=tag_name)
10361032

1033+
def _do_create_branch(
1034+
self,
1035+
snapshot_id: int,
1036+
branch_name: str,
1037+
max_ref_age_ms: int | None,
1038+
max_snapshot_age_ms: int | None,
1039+
min_snapshots_to_keep: int | None,
1040+
) -> None:
1041+
update, requirement = self._transaction._set_ref_snapshot(
1042+
snapshot_id=snapshot_id,
1043+
ref_name=branch_name,
1044+
type="branch",
1045+
max_ref_age_ms=max_ref_age_ms,
1046+
max_snapshot_age_ms=max_snapshot_age_ms,
1047+
min_snapshots_to_keep=min_snapshots_to_keep,
1048+
)
1049+
self._updates += update
1050+
self._requirements += requirement
1051+
10371052
def create_branch(
10381053
self,
10391054
snapshot_id: int,
@@ -1054,16 +1069,10 @@ def create_branch(
10541069
Returns:
10551070
This for method chaining
10561071
"""
1057-
update, requirement = self._transaction._set_ref_snapshot(
1058-
snapshot_id=snapshot_id,
1059-
ref_name=branch_name,
1060-
type="branch",
1061-
max_ref_age_ms=max_ref_age_ms,
1062-
max_snapshot_age_ms=max_snapshot_age_ms,
1063-
min_snapshots_to_keep=min_snapshots_to_keep,
1072+
self._operations.append(
1073+
("create_branch", snapshot_id, branch_name, max_ref_age_ms, max_snapshot_age_ms, min_snapshots_to_keep)
10641074
)
1065-
self._updates += update
1066-
self._requirements += requirement
1075+
self._do_create_branch(snapshot_id, branch_name, max_ref_age_ms, max_snapshot_age_ms, min_snapshots_to_keep)
10671076
return self
10681077

10691078
def remove_branch(self, branch_name: str) -> ManageSnapshots:

0 commit comments

Comments
 (0)