From 55d5c6e061a93bdceb34dbb8ba080c1777af12b2 Mon Sep 17 00:00:00 2001 From: Simran Kaur Date: Thu, 4 Jun 2026 18:20:00 +0000 Subject: [PATCH 1/2] Fix mesh creation to use local devices for single-host benchmarks --- Ironwood/src/benchmark_collectives.py | 8 ++++++-- Ironwood/src/benchmark_utils.py | 6 +++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/Ironwood/src/benchmark_collectives.py b/Ironwood/src/benchmark_collectives.py index cac1214..a1e35a2 100644 --- a/Ironwood/src/benchmark_collectives.py +++ b/Ironwood/src/benchmark_collectives.py @@ -35,7 +35,11 @@ def create_mesh(ici_size: int, mesh_shape: str) -> Mesh: """Creates a mesh with the given ICI size.""" devices_needed = ici_size - devices = jax.devices() + local_devices = jax.local_devices() + if devices_needed <= len(local_devices): + devices = local_devices + else: + devices = jax.devices() if len(devices) < devices_needed: raise ValueError( @@ -52,7 +56,7 @@ def create_mesh(ici_size: int, mesh_shape: str) -> Mesh: first_device = devices[0] device_kind = first_device.device_kind print("Device kind: ", device_kind) - mesh_devices = mesh_utils.create_device_mesh(shape, devices=jax.devices()) + mesh_devices = mesh_utils.create_device_mesh(shape, devices=devices) mesh = Mesh(mesh_devices, axis_names) return mesh diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index d04e2c5..9d115b4 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -1191,7 +1191,7 @@ def create_mesh(strategy: ShardingStrategy) -> Mesh: strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M or strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N ): - num_devices = jax.device_count() + num_devices = jax.local_device_count() assert ( num_devices % 2 == 0 ), "Total devices must be divisible by 2 (chip size)" @@ -1199,10 +1199,10 @@ def create_mesh(strategy: ShardingStrategy) -> Mesh: mesh_shape = (num_chips, 2) mesh_axes = ("chip", "device") mesh = jax.sharding.Mesh( - np.array(jax.devices()).reshape(mesh_shape), mesh_axes + np.array(jax.local_devices()).reshape(mesh_shape), mesh_axes ) else: - mesh = Mesh(np.array(jax.devices()), axis_names="device") + mesh = Mesh(np.array(jax.local_devices()), axis_names="device") return mesh From c1da258732c20558a983449326e86d1f7fc8c5bc Mon Sep 17 00:00:00 2001 From: Simran Kaur Date: Mon, 8 Jun 2026 14:18:11 +0000 Subject: [PATCH 2/2] Fix pylint errors in benchmark_collectives.py --- Ironwood/src/benchmark_collectives.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/Ironwood/src/benchmark_collectives.py b/Ironwood/src/benchmark_collectives.py index a1e35a2..d034a6c 100644 --- a/Ironwood/src/benchmark_collectives.py +++ b/Ironwood/src/benchmark_collectives.py @@ -122,18 +122,20 @@ def unified_ici_collectives_metrics( else: replica_group_type = "non-parallel" - # Safe to access [0] without safeguard because JAX guarantees at least one device is - # always initialized (CPU fallback if no accelerator), and mesh creation has already - # validated that the requested number of devices exist. + # Safe to access [0] without safeguard because JAX guarantees at least + # one device is always initialized (CPU fallback if no accelerator), and + # mesh creation has already validated that the requested number of + # devices exist. device_kind = jax.devices()[0].device_kind if device_kind in V6E_DEVICE_KINDS: # For TPU v6e (Trillium), 1 physical chip = 1 JAX device. - # Ring collective communication volume per chip across N ranks is exactly (N - 1) shards. + # Ring collective communication volume per chip across N ranks is + # exactly (N - 1) shards. # There is no dual-core traffic multiplier needed. participating_ranks = rank - 1 tf_multiplier = 1 else: - # Dual-core logic for TPU v7x + # Dual-core logic for TPU v7x if replica_group_type == "parallel": participating_ranks = rank - 1 tf_multiplier = 2