diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py index ce57424..8bd712e 100644 --- a/pathwaysutils/elastic/manager.py +++ b/pathwaysutils/elastic/manager.py @@ -293,7 +293,7 @@ def wait_for_slices( timeout. Returns: - The good slice indices + The active slice indices Raises: TimeoutError: If the timeout is reached before the slices become @@ -388,7 +388,9 @@ def wrapper(*args, **kwargs): "Elastic attempt %d out of %d", retry_index + 1, max_retries ) - self.wait_for_slices(poll_interval=poll_interval, timeout=timeout) + self.active_slice_indices = self.wait_for_slices( + poll_interval=poll_interval, timeout=timeout + ) return func(*args, **kwargs) except jax.errors.JaxRuntimeError as error: