diff --git a/Ironwood/src/benchmark_collectives.py b/Ironwood/src/benchmark_collectives.py index cac1214..d034a6c 100644 --- a/Ironwood/src/benchmark_collectives.py +++ b/Ironwood/src/benchmark_collectives.py @@ -35,7 +35,11 @@ def create_mesh(ici_size: int, mesh_shape: str) -> Mesh: """Creates a mesh with the given ICI size.""" devices_needed = ici_size - devices = jax.devices() + local_devices = jax.local_devices() + if devices_needed <= len(local_devices): + devices = local_devices + else: + devices = jax.devices() if len(devices) < devices_needed: raise ValueError( @@ -52,7 +56,7 @@ def create_mesh(ici_size: int, mesh_shape: str) -> Mesh: first_device = devices[0] device_kind = first_device.device_kind print("Device kind: ", device_kind) - mesh_devices = mesh_utils.create_device_mesh(shape, devices=jax.devices()) + mesh_devices = mesh_utils.create_device_mesh(shape, devices=devices) mesh = Mesh(mesh_devices, axis_names) return mesh @@ -118,18 +122,20 @@ def unified_ici_collectives_metrics( else: replica_group_type = "non-parallel" - # Safe to access [0] without safeguard because JAX guarantees at least one device is - # always initialized (CPU fallback if no accelerator), and mesh creation has already - # validated that the requested number of devices exist. + # Safe to access [0] without safeguard because JAX guarantees at least + # one device is always initialized (CPU fallback if no accelerator), and + # mesh creation has already validated that the requested number of + # devices exist. device_kind = jax.devices()[0].device_kind if device_kind in V6E_DEVICE_KINDS: # For TPU v6e (Trillium), 1 physical chip = 1 JAX device. - # Ring collective communication volume per chip across N ranks is exactly (N - 1) shards. + # Ring collective communication volume per chip across N ranks is + # exactly (N - 1) shards. # There is no dual-core traffic multiplier needed. participating_ranks = rank - 1 tf_multiplier = 1 else: - # Dual-core logic for TPU v7x + # Dual-core logic for TPU v7x if replica_group_type == "parallel": participating_ranks = rank - 1 tf_multiplier = 2 diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index d04e2c5..9d115b4 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -1191,7 +1191,7 @@ def create_mesh(strategy: ShardingStrategy) -> Mesh: strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M or strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N ): - num_devices = jax.device_count() + num_devices = jax.local_device_count() assert ( num_devices % 2 == 0 ), "Total devices must be divisible by 2 (chip size)" @@ -1199,10 +1199,10 @@ def create_mesh(strategy: ShardingStrategy) -> Mesh: mesh_shape = (num_chips, 2) mesh_axes = ("chip", "device") mesh = jax.sharding.Mesh( - np.array(jax.devices()).reshape(mesh_shape), mesh_axes + np.array(jax.local_devices()).reshape(mesh_shape), mesh_axes ) else: - mesh = Mesh(np.array(jax.devices()), axis_names="device") + mesh = Mesh(np.array(jax.local_devices()), axis_names="device") return mesh