From 989e88c480355777aab9a735ceb2b6bb033b966f Mon Sep 17 00:00:00 2001 From: Dora Hsieh Date: Wed, 25 Mar 2026 17:58:49 +0800 Subject: [PATCH] change quantization_local_shard_count to number of slices --- src/maxtext/configs/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index ea65ea5fed..ee510123e3 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2190,7 +2190,7 @@ def get_num_target_devices(): # Default quantization sharding count to number of local devices if not set. if self.quantization_local_shard_count == -1: try: - self.quantization_local_shard_count = jax.local_device_count() + self.quantization_local_shard_count = 1 + max(d.slice_index for d in jax.devices()) except RuntimeError: self.quantization_local_shard_count = 1