From 0a3c106f66fd730e64a2f4b4e620311f74219fe9 Mon Sep 17 00:00:00 2001 From: Shutong Li Date: Thu, 5 Feb 2026 18:29:29 -0800 Subject: [PATCH] Fix two issues that block the training loop with continuous checkpointing enabled. PiperOrigin-RevId: 866204905 --- src/maxtext/common/checkpointing.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 61c4d2c5e1..9d3a347268 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -18,6 +18,7 @@ from typing import Any, Optional from absl import flags +import datetime from etils import epath from flax.training import train_state import jax @@ -248,6 +249,11 @@ def create_orbax_checkpoint_manager( else: save_decision_policy = save_decision_policy_lib.FixedIntervalPolicy(interval=save_interval_steps) preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep) + async_options = None + if enable_continuous_checkpointing: + async_options = ocp.AsyncOptions( + timeout_secs=int(datetime.timedelta(minutes=60).total_seconds()), + ) manager = CheckpointManager( p, item_names=item_names, @@ -257,6 +263,7 @@ def create_orbax_checkpoint_manager( enable_async_checkpointing=use_async, save_decision_policy=save_decision_policy, preservation_policy=preservation_policy, + async_options=async_options, ), logger=orbax_logger, ) @@ -728,6 +735,7 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator= if config and config.enable_checkpointing: if ( force + or (step % config.checkpoint_period == 0 and not config.enable_continuous_checkpointing) or (step % config.checkpoint_period == 0) or (config.enable_emergency_checkpoint and step % config.local_checkpoint_period == 0) ):