Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions Ironwood/src/benchmark_collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
def create_mesh(ici_size: int, mesh_shape: str) -> Mesh:
"""Creates a mesh with the given ICI size."""
devices_needed = ici_size
devices = jax.devices()
local_devices = jax.local_devices()
if devices_needed <= len(local_devices):
devices = local_devices
else:
devices = jax.devices()

if len(devices) < devices_needed:
raise ValueError(
Expand All @@ -52,7 +56,7 @@ def create_mesh(ici_size: int, mesh_shape: str) -> Mesh:
first_device = devices[0]
device_kind = first_device.device_kind
print("Device kind: ", device_kind)
mesh_devices = mesh_utils.create_device_mesh(shape, devices=jax.devices())
mesh_devices = mesh_utils.create_device_mesh(shape, devices=devices)
mesh = Mesh(mesh_devices, axis_names)
return mesh

Expand Down Expand Up @@ -118,18 +122,20 @@ def unified_ici_collectives_metrics(
else:
replica_group_type = "non-parallel"

# Safe to access [0] without safeguard because JAX guarantees at least one device is
# always initialized (CPU fallback if no accelerator), and mesh creation has already
# validated that the requested number of devices exist.
# Safe to access [0] without safeguard because JAX guarantees at least
# one device is always initialized (CPU fallback if no accelerator), and
# mesh creation has already validated that the requested number of
# devices exist.
device_kind = jax.devices()[0].device_kind
if device_kind in V6E_DEVICE_KINDS:
# For TPU v6e (Trillium), 1 physical chip = 1 JAX device.
# Ring collective communication volume per chip across N ranks is exactly (N - 1) shards.
# Ring collective communication volume per chip across N ranks is
# exactly (N - 1) shards.
# There is no dual-core traffic multiplier needed.
participating_ranks = rank - 1
tf_multiplier = 1
else:
# Dual-core logic for TPU v7x
# Dual-core logic for TPU v7x
if replica_group_type == "parallel":
participating_ranks = rank - 1
tf_multiplier = 2
Expand Down
6 changes: 3 additions & 3 deletions Ironwood/src/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,18 +1191,18 @@ def create_mesh(strategy: ShardingStrategy) -> Mesh:
strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M
or strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N
):
num_devices = jax.device_count()
num_devices = jax.local_device_count()
assert (
num_devices % 2 == 0
), "Total devices must be divisible by 2 (chip size)"
num_chips = num_devices // 2
mesh_shape = (num_chips, 2)
mesh_axes = ("chip", "device")
mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(mesh_shape), mesh_axes
np.array(jax.local_devices()).reshape(mesh_shape), mesh_axes
)
else:
mesh = Mesh(np.array(jax.devices()), axis_names="device")
mesh = Mesh(np.array(jax.local_devices()), axis_names="device")
return mesh


Expand Down