15 changes: 15 additions & 0 deletions exla/lib/exla.ex
@@ -215,6 +215,21 @@ defmodule EXLA do
The metadata is:

* `:key` - the compilation key for debugging

## Sharding

EXLA supports sharding, which partitions a computation across multiple devices.
Several collective operations are supported within sharded computations.

### [`all_gather`](https://openxla.org/stablehlo/spec#all_gather)

#### Options

* `:all_gather_dim` - the dimension along which to gather
* `:replica_groups` - a 2D list defining how replicas are grouped
* `:use_global_device_ids` - whether to use global device IDs (defaults to `false`)
* `:channel_id` - the channel ID used for communication (optional)
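
For example, a sharded computation might add two tensors and gather the result along
dimension 0. This is a minimal sketch (it assumes `import Nx.Defn`; the function name
and the replica group layout are illustrative):

    defn add_and_gather(x, y) do
      x
      |> Nx.add(y)
      |> all_gather(all_gather_dim: 0, replica_groups: [[0, 1]])
    end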

"""

@behaviour Nx.Defn.Compiler
21 changes: 21 additions & 0 deletions exla/lib/exla/defn.ex
@@ -1474,6 +1474,27 @@ defmodule EXLA.Defn do
EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type)
end

## to_operator collective ops

defp to_operator(:all_gather, [%Value{} = tensor, opts], ans, _state) do
all_gather_dim = Keyword.fetch!(opts, :all_gather_dim)
replica_groups = Keyword.fetch!(opts, :replica_groups)
use_global_device_ids = Keyword.get(opts, :use_global_device_ids, false)

# We might want to surface all_gather as an operation that takes a container of operands instead of a single one.
[result] =
Value.all_gather(
[tensor],
expr_to_typespec(ans),
all_gather_dim,
replica_groups,
use_global_device_ids,
opts[:channel_id]
)

result
end

defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do
n = opts[:length]
axis = opts[:axis]
23 changes: 23 additions & 0 deletions exla/lib/exla/mlir/value.ex
@@ -64,6 +64,29 @@ defmodule EXLA.MLIR.Value do
end
end

def all_gather([%Value{function: func} | _] = operands, typespec, all_gather_dim, replica_groups, use_global_device_ids, channel_id \\ nil) do
result_types = typespecs_to_mlir_types([typespec])

num_groups = length(replica_groups)
group_size = if num_groups > 0, do: length(hd(replica_groups)), else: 0
flat_groups = List.flatten(replica_groups)

attributes = [
all_gather_dim: attr_i64(all_gather_dim),
replica_groups: attr_dense_elements(flat_groups, {:s, 64}, {num_groups, group_size}),
use_global_device_ids: attr_boolean(use_global_device_ids)
]

attributes =
if channel_id do
Keyword.put(attributes, :channel_id, attr_i64(channel_id))
else
attributes
end

op(func, "stablehlo.all_gather", operands, result_types, attributes: attributes)
end

defp compare_and_return_bool(func, lhs, rhs, typespec, direction, total_order? \\ false) do
%{type: lhs_type} = get_typespec(lhs)
%{type: rhs_type} = get_typespec(rhs)
107 changes: 107 additions & 0 deletions exla/test/exla/defn/sharding_test.exs
@@ -890,5 +890,112 @@ defmodule EXLA.Defn.ShardingTest do
end)
end
end

@tag :multi_device
test "generates correct MLIR with all_gather" do
fun = fn x, y ->
x
|> Nx.add(y)
|> Nx.Defn.Kernel.all_gather(all_gather_dim: 0, replica_groups: [[0]])
|> Nx.Defn.Kernel.all_gather(all_gather_dim: 1, replica_groups: [[0]])
end

mesh = %Mesh{name: "mesh", shape: {2, 2}}
# First arg: 0..15 (8x2), shard dim 0 on mesh axis 0, dim 1 on mesh axis 1
# Second arg: 100..115 (8x2), same sharding — makes sharded results easy to read
input_shardings = [%{0 => [0], 1 => [1]}, %{0 => [0], 1 => [1]}]

# For mesh {2, 2}, 4 partitions. Each gets {4, 1}. Full 8x2 row-major: [[0,1],[2,3],...,[14,15]].
# Partition (axis_0, axis_1): (0,0)=rows 0-3 col 0, (0,1)=rows 0-3 col 1, (1,0)=rows 4-7 col 0, (1,1)=rows 4-7 col 1.
# So partition 0 gets (0,0),(1,0),(2,0),(3,0) = 0,2,4,6; partition 1 gets (0,1),(1,1),... = 1,3,5,7; etc.
args = [
# partition 0: rows 0–3 col 0 -> 0,2,4,6 and 100,102,104,106
[Nx.tensor([[0], [2], [4], [6]]), Nx.tensor([[100], [102], [104], [106]])],
# partition 1: rows 0–3 col 1 -> 1,3,5,7 and 101,103,105,107
[Nx.tensor([[1], [3], [5], [7]]), Nx.tensor([[101], [103], [105], [107]])],
# partition 2: rows 4–7 col 0 -> 8,10,12,14 and 108,110,112,114
[Nx.tensor([[8], [10], [12], [14]]), Nx.tensor([[108], [110], [112], [114]])],
# partition 3: rows 4–7 col 1 -> 9,11,13,15 and 109,111,113,115
[Nx.tensor([[9], [11], [13], [15]]), Nx.tensor([[109], [111], [113], [115]])]
]

result = EXLA.to_mlir_module(fun, args, mesh: mesh, input_shardings: input_shardings)

expected_mlir = """
module {
sdy.mesh @mesh = <["axis_0"=2, "axis_1"=2]>
func.func public @main(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", ?}p0, {"axis_1", ?}p0]>}, %arg1: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", ?}p0, {"axis_1", ?}p0]>}) -> tensor<8x2xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<8x2xi32>
%1 = "stablehlo.all_gather"(%0) <{all_gather_dim = 0 : i64, replica_groups = dense<0> : tensor<1x1xi64>}> : (tensor<8x2xi32>) -> tensor<8x2xi32>
%2 = "stablehlo.all_gather"(%1) <{all_gather_dim = 1 : i64, replica_groups = dense<0> : tensor<1x1xi64>}> : (tensor<8x2xi32>) -> tensor<8x2xi32>
return %2 : tensor<8x2xi32>
}
}
"""

assert expected_mlir == result.mlir_module

results = EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args)

assert length(results) == 4

# After all_gather: full first arg 0..15 + full second 100..115 -> 100,102,...,130
expected_result =
Nx.tensor([
[100, 102],
[104, 106],
[108, 110],
[112, 114],
[116, 118],
[120, 122],
[124, 126],
[128, 130]
])

Enum.zip_with([results, 0..3], fn [result, i] ->
assert_equal(result, expected_result)
assert result.data.buffer.device_id == i
end)
end

@tag :multi_device
test "can return partially sharded results" do
fun = fn x, y -> Nx.add(x, y) end

mesh = %Mesh{name: "mesh", shape: {2, 2}}
# Inputs sharded on both axes
input_shardings = [%{0 => [0], 1 => [1]}, %{0 => [0], 1 => [1]}]
# Output spec: only dim 0 is constrained (sharded on mesh axis 0); dim 1 is left open
output_shardings = [%{0 => [0]}]

# Logical x: 8x2, y: 8x2. Each partition gets {4, 1} of each
args = [
[Nx.tensor([[0], [1], [2], [3]]), Nx.tensor([[100], [101], [102], [103]])],
[Nx.tensor([[10], [11], [12], [13]]), Nx.tensor([[110], [111], [112], [113]])],
[Nx.tensor([[4], [5], [6], [7]]), Nx.tensor([[104], [105], [106], [107]])],
[Nx.tensor([[14], [15], [16], [17]]), Nx.tensor([[114], [115], [116], [117]])]
]

results =
EXLA.shard_jit(fun, mesh,
input_shardings: input_shardings,
output_shardings: output_shardings
).(args)

assert length(results) == 4

# Partially sharded output: dim 0 sharded on axis 0, dim 1 not in output spec
# Each device returns its local shard {4, 1} (x+y computed locally)
# Dev0: col0 rows 0-3, Dev1: col1 rows 0-3, Dev2: col0 rows 4-7, Dev3: col1 rows 4-7
expected_results = [
Nx.tensor([[100], [102], [104], [106]]),
Nx.tensor([[120], [122], [124], [126]]),
Nx.tensor([[108], [110], [112], [114]]),
Nx.tensor([[128], [130], [132], [134]])
]

Enum.zip_with([results, expected_results, 0..3], fn [result, expected, i] ->
assert_equal(result, expected)
assert result.data.buffer.device_id == i
end)
end
end
end
27 changes: 27 additions & 0 deletions nx/lib/nx/defn/expr.ex
@@ -1166,6 +1166,33 @@ defmodule Nx.Defn.Expr do
expr(out, context, :gather, [tensor, indices, opts])
end

def all_gather(tensor, opts) do
{[tensor], context} = to_exprs([tensor])

_all_gather_dim = opts[:all_gather_dim]
replica_groups = opts[:replica_groups]

# Calculate group size (number of replicas per group)
_group_size =
case replica_groups do
[first_group | _] -> length(first_group)
[] -> 1
end

# Calculate output shape by multiplying the gather dimension by group_size
input_shape = tensor.shape
output_shape =
input_shape
# |> Tuple.to_list()
# |> List.update_at(all_gather_dim, &(&1 * group_size))
# |> List.to_tuple()

# Create output tensor with the new shape
Comment on lines +1172 to +1190 (Contributor):

There are a few unused values here due to the stray comments that should all be removed. Also, just pass `tensor` as `out` directly.
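
One possible reading of that suggestion, sketched here for illustration (not part of the committed diff):

```elixir
def all_gather(tensor, opts) do
  {[tensor], context} = to_exprs([tensor])

  # The expression keeps the logical (global) shape, so the input
  # expression can serve directly as the output template.
  expr(tensor, context, :all_gather, [tensor, opts])
end
```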

out = %{tensor | shape: output_shape}

expr(out, context, :all_gather, [tensor, opts])
end

@impl true
def reverse(out, tensor, axes) do
tensor = to_expr(tensor)
18 changes: 18 additions & 0 deletions nx/lib/nx/defn/kernel.ex
@@ -1669,6 +1669,24 @@ defmodule Nx.Defn.Kernel do
end
end

@doc """
Gathers tensors from all replicas along a specified dimension.

This operation concatenates tensors from multiple replicas/devices along
the specified dimension. It requires a compiler or backend that supports multi-device execution.

## Parameters

* `tensor` - The input tensor to gather

* `opts` - Optional keyword list. These are backend- and compiler-specific;
see your backend or compiler docs for supported options.
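
## Examples

A minimal sketch of gathering along dimension 0 inside a `defn`. The options
shown here follow the EXLA compiler's conventions and are illustrative only:

    defn gather_rows(x) do
      all_gather(x, all_gather_dim: 0, replica_groups: [[0, 1]])
    end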

"""
def all_gather(tensor, opts \\ []) do
Nx.Defn.Expr.all_gather(tensor, opts)
end

@definitions (Module.definitions_in(__MODULE__, :def) ++
Module.definitions_in(__MODULE__, :defmacro)) --
[
26 changes: 26 additions & 0 deletions nx/test/nx/defn_test.exs
@@ -2952,4 +2952,30 @@ defmodule Nx.DefnTest do
assert vectorized_metadata_tuple(x, z) == vec_nonvec_result
end
end

describe "sharding" do
defn all_gather_test(tensor) do
Nx.Defn.Kernel.all_gather(tensor, all_gather_dim: 0, replica_groups: [[0]])
end

test "all_gather produces correct expr format for compiler" do
# Uses debug_expr to inspect the expression without compiling.
# Guarantees the format passed to compilers (e.g. EXLA) stays stable.
assert %T{data: %Expr{op: :all_gather, args: [tensor, opts]}} =
Nx.Defn.debug_expr(&all_gather_test/1).(Nx.tensor([1, 2, 3, 4]))

assert %T{data: %Expr{op: :parameter, args: [0]}} = tensor

# Compilers expect opts with :all_gather_dim and :replica_groups
assert opts[:all_gather_dim] == 0
assert opts[:replica_groups] == [[0]]
end

@tag compiler: Evaluator
test "all_gather works" do
assert_raise UndefinedFunctionError, fn ->
all_gather_test(Nx.tensor([1, 2, 3, 4]))
end
end
end
end