[JAX] Add an MoE Block (Layer) that compound router, permutation, groupedGEMM and communication#2912

Draft
tdophung wants to merge 28 commits into
NVIDIA:mainfrom
tdophung:teddy/moe_block

Conversation

@tdophung
Collaborator

@tdophung tdophung commented Apr 21, 2026

Description

Most of the integration work on MoE building blocks has been tightly coupled with MaxText development. This PR adds an MoE block that isolates that work from MaxText and leaves more room for experimentation. MoEBlock is a self-contained Flax-Linen module that wires together TE's fused router, pluggable token-dispatch backends (pure-JAX argsort or Triton sort_chunks_by_index), a grouped_dense-based expert FFN, and ragged-all-to-all (A2Av) expert parallelism via shard_map.

This first iteration starts with ring-of-experts EP, sharding on the batch dimension for FSDP, cuBLASLt grouped GEMM, and two permutation backends: pure JAX or Triton kernels. The block also exposes pluggable knobs for: weight layout (wi_kernel_axes / wo_kernel_axes), permutation backend, A2A vs. no-EP (single-GPU) path, data-parallelism axes for true FSDP (batch sharded across (ep, fsdp) simultaneously), top-K with optional grouped/sigmoid scoring (for DSv3 workloads), and an optional auxiliary load-balancing loss.

Fixes #2895

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • New transformer_engine/jax/flax/moe.py -- MoEBlock Linen module:
    gate -> fused topk -> global permute -> A2A EP shard_map (ragged_a2a fwd, local permute, 3x grouped GEMM SwiGLU FFN, local unpermute, ragged_a2a rev) -> global combine.
  • Extended transformer_engine/jax/permutation.py with A2A param helpers (compute_ragged_all_to_all_params, compute_reverse_ragged_all_to_all_params, local_permute_after_a2a, local_unpermute_before_a2a) and the pure-JAX unfused_token_dispatch / unfused_token_combine paths
    with custom VJPs.
  • tests/jax/test_moe_block.py -- single-device shape, backward,
    cross-backend equivalence, aux-loss, group-topk, JIT determinism.
  • tests/jax/test_distributed_moe_block.py -- EP=2 x FSDP=2 mesh test using the canonical Flax-Linen sharded-init pattern (eval_shape -> get_partition_spec -> logical_to_mesh_sharding -> jit(init, out_shardings=...)) and data_parallelism_axes=("fsdp",) to exercise true FSDP (batch sharded across both axes).
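
For readers unfamiliar with the argsort-based dispatch/combine idea behind the pure-JAX backend, here is a minimal NumPy sketch. It is not the code in this PR: the function names `dispatch` and `combine`, the shapes, and the identity "expert" in the demo are assumptions chosen for illustration, and NumPy stands in for JAX to keep the example dependency-light.

```python
import numpy as np


def dispatch(tokens, expert_ids, num_experts):
    """Sort (token, slot) pairs so rows routed to the same expert are contiguous.

    tokens:     [T, H] activations
    expert_ids: [T, K] expert chosen for each token and top-k slot
    """
    topk = expert_ids.shape[1]
    flat_experts = expert_ids.reshape(-1)            # [T*K], flat index = t*K + k
    order = np.argsort(flat_experts, kind="stable")  # sorted position -> flat index
    sorted_inputs = tokens[order // topk]            # each token appears K times
    group_sizes = np.bincount(flat_experts, minlength=num_experts)
    return sorted_inputs, group_sizes, order


def combine(expert_outputs, order, probs):
    """Undo the sort, then take the probability-weighted sum over the K slots."""
    num_tokens, topk = probs.shape
    unsorted = np.empty_like(expert_outputs)
    unsorted[order] = expert_outputs                 # apply the inverse permutation
    unsorted = unsorted.reshape(num_tokens, topk, -1)
    return (unsorted * probs[..., None]).sum(axis=1)


# Illustrative inputs: 4 tokens, hidden size 2, top-2 routing over 3 experts.
tokens = np.arange(8, dtype=np.float64).reshape(4, 2)
expert_ids = np.array([[0, 2], [1, 0], [2, 1], [0, 1]])
probs = np.full((4, 2), 0.5)

sorted_inputs, group_sizes, order = dispatch(tokens, expert_ids, num_experts=3)
# An identity "expert FFN" makes the round trip easy to check.
output = combine(sorted_inputs, order, probs)
```

With identity experts and probabilities that sum to one per token, `output` reproduces `tokens`, which is a quick sanity check that the permutation and its inverse agree. `group_sizes` is the per-expert row count that a grouped GEMM would consume.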

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@tdophung tdophung marked this pull request as ready for review May 5, 2026 21:47
@greptile-apps
Contributor

greptile-apps Bot commented May 5, 2026

Greptile Summary

This PR introduces _MoEBlock, an experimental self-contained Flax-Linen Mixture-of-Experts block for TransformerEngine JAX. It wires together the fused router, pluggable token-dispatch backends (pure-JAX argsort or Triton), cuBLASLt grouped_dense-based expert FFN, and ragged-all-to-all expert parallelism via shard_map.

  • transformer_engine/jax/flax/moe.py: New _MoEBlock module with gate, router, global permute, A2A EP path, expert FFN, and global combine stages; two top-level forward variants (no-EP and shard_map-wrapped A2A-EP).
  • transformer_engine/jax/permutation.py: Adds pure-JAX argsort-based dispatch/combine with custom VJPs, PureJaxPermState, routing_map_to_selected_experts, and ragged-all-to-all EP helpers.
  • transformer_engine/jax/sharding.py and gemm.py: Minor additions — ep_resource field on MeshResource, get_active_resource_axis helper, and removal of @cache on _should_enforce_v2_grouped_gemm to support monkeypatch.setenv in tests.

Confidence Score: 4/5

Safe to merge after addressing the missing ep_axis exclusion guard in the data_parallelism_axes validation loop.

The A2A-EP forward correctly addresses the recv_buffer_rows alignment fix. One gap remains: if a caller passes the EP axis name in data_parallelism_axes, the batch PartitionSpec gets a duplicate axis and dp_size is double-counted, producing an undersized ragged_all_to_all receive buffer with no useful error message.

transformer_engine/jax/flax/moe.py — specifically the data_parallelism_axes validation block in _forward_a2a_ep.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/jax/flax/moe.py | New 1174-line _MoEBlock Linen module; contains a missing validation allowing ep_axis to appear in data_parallelism_axes, which produces a duplicate-axis PartitionSpec and an undersized recv buffer. |
| transformer_engine/jax/permutation.py | Adds pure-JAX dispatch/combine with custom VJPs and ragged-A2A helpers; logic is correct, recv_buffer_rows alignment fix is in place. |
| transformer_engine/jax/sharding.py | Adds ep_resource to MeshResource and get_active_resource_axis helper; consistent with existing axis-resolution patterns. |
| transformer_engine/jax/cpp_extensions/gemm.py | Removes @cache from _should_enforce_v2_grouped_gemm so monkeypatch.setenv works in tests; negligible performance impact. |
| tests/jax/test_moe_block.py | Comprehensive single-device tests covering shape, backward, cross-backend equivalence, aux loss, group-topk, align_size, and JIT determinism. |
| tests/jax/test_distributed_moe_block.py | EP=2 x FSDP=2 distributed test using canonical Flax-Linen sharded-init pattern; validates output, loss, aux_loss, and per-parameter gradients. |

Sequence Diagram

sequenceDiagram
    participant Caller
    participant _MoEBlock
    participant Router
    participant GlobalPermute
    participant A2A as ragged_all_to_all (EP)
    participant LocalPerm as local_permute_after_a2a
    participant ExpertFFN as _expert_ffn (grouped_dense x3)
    participant GlobalCombine

    Caller->>_MoEBlock: inputs [B, S, H]
    _MoEBlock->>Router: "gate_logits -> fused_topk_with_score_function"
    Router-->>_MoEBlock: sparse_probs, routing_map
    _MoEBlock->>GlobalPermute: _global_permute (pure_jax or triton)
    GlobalPermute-->>_MoEBlock: sorted_inputs, group_sizes [E]

    alt No-EP path
        _MoEBlock->>ExpertFFN: "sorted_inputs, group_sizes, n_groups=E"
        ExpertFFN-->>_MoEBlock: expert_outputs
    else A2A-EP path via shard_map
        _MoEBlock->>A2A: all_gather(group_sizes)
        A2A->>A2A: forward ragged_all_to_all over ep axis
        A2A->>LocalPerm: reorder recv buffer
        LocalPerm-->>A2A: sorted_x, local_group_sizes
        A2A->>ExpertFFN: sorted_x, local_group_sizes
        ExpertFFN-->>A2A: expert_outputs
        A2A->>LocalPerm: local_unpermute_before_a2a
        A2A->>A2A: reverse ragged_all_to_all
        A2A-->>_MoEBlock: y_back
    end

    _MoEBlock->>GlobalCombine: _global_combine
    GlobalCombine-->>_MoEBlock: output [B, S, H]
    _MoEBlock-->>Caller: output [B, S, H], aux_loss

Reviews (6): Last reviewed commit: "change naming and add message for experi..."

Comment thread transformer_engine/pytorch/permutation.py
Comment thread transformer_engine/jax/flax/moe.py Outdated
tdophung added 6 commits May 5, 2026 16:35
Signed-off-by: tdophung <tdophung@nvidia.com>
…ody single GPU vs. multi GPU

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…e and single device initial params in the MoEBlock. Tests should pass now

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung force-pushed the teddy/moe_block branch from 8a838f3 to 6aeb491 Compare May 5, 2026 23:44
pre-commit-ci Bot and others added 2 commits May 5, 2026 23:45
Signed-off-by: tdophung <tdophung@nvidia.com>
Comment thread transformer_engine/jax/flax/moe.py Outdated
Comment thread tests/jax/test_distributed_moe_block.py Outdated
Comment thread tests/jax/test_moe_block.py Outdated
Comment thread tests/jax/test_moe_block.py Outdated
Comment thread tests/jax/test_moe_block.py Outdated
Comment thread transformer_engine/jax/permutation.py Outdated
Comment thread transformer_engine/jax/flax/moe.py Outdated
Comment thread transformer_engine/jax/flax/moe.py Outdated
Comment thread transformer_engine/jax/flax/moe.py Outdated
Comment thread transformer_engine/jax/flax/moe.py Outdated
Comment thread transformer_engine/jax/flax/moe.py Outdated
nvjax and others added 2 commits May 7, 2026 15:18
…int in C++ files, make FP8 works. Tested with current scaling

Signed-off-by: JAX Toolbox <jax@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 7, 2026

Comment thread transformer_engine/common/util/multi_stream.cpp Outdated
Comment thread transformer_engine/jax/csrc/extensions/gemm.cpp Outdated
Comment thread tests/jax/test_moe_block.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/gemm.cpp Outdated
Comment thread transformer_engine/jax/flax/moe.py Outdated
Comment thread tests/jax/test_moe_block.py Outdated
Comment thread transformer_engine/common/util/multi_stream.cpp Outdated
… grad tol to 5e-2, move arch/align_size docs into MoEBlock class

Signed-off-by: tdophung <tdophung@nvidia.com>
Comment thread transformer_engine/jax/flax/moe.py Outdated
Comment on lines +909 to +914
batch_divisor = num_ep * dp_size
if global_batch_size % batch_divisor != 0:
    raise ValueError(
        f"batch={global_batch_size} not divisible by prod(data_parallelism_axes)={dp_size}"
    )
recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk
Contributor

P1 Receive buffer undersized when align_size > 0 and EP are combined

recv_buffer_rows is computed assuming unpadded token counts, but when align_size > 0 the per-expert group_sizes are the aligned counts, so the send_sizes in compute_ragged_all_to_all_params include padding tokens. The worst-case receive per shard is num_ep * ((B/(num_ep*dp_size))*S*K + num_experts_per_shard*(align_size-1)), which exceeds the current recv_buffer_rows = (B/dp_size)*S*K by up to num_experts*(align_size-1) rows. ragged_all_to_all writing beyond the buffer produces incorrect results or a crash. The correct worst-case size is:

recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk + num_experts * (self.align_size - 1 if self.align_size > 0 else 0)

This combination (EP + align_size > 0) is not exercised by the current distributed test (which defaults to align_size=0), so the bug is latent.
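
To make the size difference concrete, the arithmetic in this comment can be written as a small standalone function. This is only a sketch of the reviewer's formula, not code from the PR: the function name `recv_buffer_rows` and the example sizes are hypothetical.

```python
def recv_buffer_rows(global_batch_size, dp_size, sequence_length, topk,
                     num_experts, align_size=0):
    """Worst-case rows one EP shard can receive, per the formula in the review."""
    # Rows needed when no expert group is padded.
    unpadded = (global_batch_size // dp_size) * sequence_length * topk
    # With alignment, each expert group can grow by up to (align_size - 1) rows.
    padding = num_experts * (align_size - 1) if align_size > 0 else 0
    return unpadded + padding


# Hypothetical sizes: batch 8, dp_size 2, sequence length 4, top-2, 8 experts.
unpadded_rows = recv_buffer_rows(8, 2, 4, 2, num_experts=8, align_size=0)
padded_rows = recv_buffer_rows(8, 2, 4, 2, num_experts=8, align_size=16)
shortfall = padded_rows - unpadded_rows
```

For these sizes the unpadded buffer holds 32 rows while the aligned worst case needs 152, so a buffer sized without the padding term would be 120 rows short.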

Collaborator

@phu0ngng phu0ngng left a comment

I think we should go with exposing GroupMLP VJP first before the MoE module to enable future possible fusions.

tdophung added 3 commits May 12, 2026 15:53
…ing None as group_topk, align_size rename,

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Comment thread transformer_engine/jax/flax/moe.py Outdated
Comment on lines +923 to +928
for ax in self.data_parallelism_axes:
    if ax not in mesh.shape:
        raise ValueError(
            f"data_parallelism_axes contains {ax!r} but mesh has"
            f" axes {tuple(mesh.shape.keys())}"
        )
Contributor

P1 The validation loop checks that every axis in data_parallelism_axes exists in the mesh but does not check that the axis differs from ep_axis. If a caller passes data_parallelism_axes=("ep",) when ep_axis="ep", batch_pspec_axis becomes ("ep", "ep") — a duplicate-axis PartitionSpec that JAX rejects with a cryptic error. Independently, dp_size accumulates mesh.shape["ep"] a second time, so recv_buffer_rows is undersized by a factor of num_ep and batch_divisor becomes num_ep², both causing wrong runtime behaviour before JAX ever sees the bad spec.

Suggested change
-for ax in self.data_parallelism_axes:
-    if ax not in mesh.shape:
-        raise ValueError(
-            f"data_parallelism_axes contains {ax!r} but mesh has"
-            f" axes {tuple(mesh.shape.keys())}"
-        )
+for ax in self.data_parallelism_axes:
+    if ax not in mesh.shape:
+        raise ValueError(
+            f"data_parallelism_axes contains {ax!r} but mesh has"
+            f" axes {tuple(mesh.shape.keys())}"
+        )
+    if ax == ep_axis:
+        raise ValueError(
+            f"data_parallelism_axes contains {ax!r}, which is the same as the"
+            f" EP axis {ep_axis!r}. The EP axis is already included in the batch"
+            " sharding spec; listing it again produces a duplicate-axis"
+            " PartitionSpec and an undersized ragged_all_to_all receive buffer."
+        )
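
The same guard can be exercised in isolation. The sketch below is a standalone restatement, not the module's code: it assumes the mesh is represented as a plain dict from axis name to size, and the names `validate_data_parallelism_axes` and `is_rejected` are hypothetical.

```python
import math


def validate_data_parallelism_axes(mesh_shape, ep_axis, data_parallelism_axes):
    """Check the axes and return dp_size, the product of their mesh sizes."""
    for ax in data_parallelism_axes:
        if ax not in mesh_shape:
            raise ValueError(
                f"data_parallelism_axes contains {ax!r} but mesh has"
                f" axes {tuple(mesh_shape.keys())}"
            )
        if ax == ep_axis:
            # Without this check the EP axis size would be counted twice.
            raise ValueError(
                f"data_parallelism_axes contains {ax!r}, which is the EP axis"
            )
    return math.prod(mesh_shape[ax] for ax in data_parallelism_axes)


def is_rejected(axes, mesh_shape=None, ep_axis="ep"):
    """Return True if validation raises ValueError for the given axes."""
    mesh_shape = mesh_shape or {"ep": 2, "fsdp": 2}
    try:
        validate_data_parallelism_axes(mesh_shape, ep_axis, axes)
    except ValueError:
        return True
    return False


dp_size = validate_data_parallelism_axes({"ep": 2, "fsdp": 2}, "ep", ("fsdp",))
```

Failing early here turns the duplicate-axis case into a readable error at the call site instead of a later failure inside the sharded computation.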

…er functions to group permute -> a2a -> local permute to dispatch and combine

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung marked this pull request as draft May 15, 2026 16:46
@tdophung
Collaborator Author

Changing back to draft to not spam people's email while I push commits to this branch for the full unrolling of ops in a big VJP.

pre-commit-ci Bot and others added 13 commits May 15, 2026 21:47
Signed-off-by: tdophung <tdophung@nvidia.com>
The triton backend of test_distributed_moe_vjp hangs in bwd:
MainThread parks in _pjit_call_impl_python, one GPU pinned at
100%, no NCCL ops enqueued. Root cause is async-dispatch race
between our Triton kernels (which use input_output_aliases on a
pre-zeroed output buffer) and the downstream NCCL
ragged_all_to_all -- XLA mis-tracks the dependency edge and the
collective launches before the kernel finishes writing
sorted_inputs; different ranks then read different versions of
the per-expert token counts, deadlocking NCCL.

Workaround: set CUDA_LAUNCH_BLOCKING=1 in the test runner. Smoke
suite now passes in <1 min across 3 consecutive runs. Slowdown
on these correctness shapes is negligible.

Also flip the faulthandler watchdog to all_threads=True so the
next investigator can see worker-thread frames, not just MainThread.

Signed-off-by: tdophung <tdophung@nvidia.com>
The 'CRITICAL: -p no:typeguard' section was based on an early
incorrect bisection. The actual root cause is an async-dispatch
race between our Triton custom_calls (with input_output_aliases
on pre-zeroed output buffers) and the downstream NCCL collective
in the same shard_map body -- XLA mis-handles the cross-stream
sync edge from the aliased custom_call to the NCCL op.

Why the old _MoEBlock path didn't hit this: each primitive
(token_dispatch, permute, ragged_all_to_all, sort_chunks) sat
behind its own custom_vjp boundary, which acted as an implicit
sync barrier. The new unified moe() custom_vjp removes those
boundaries (so ScaledTensor can survive across them), exposing
the bug.

Document CUDA_LAUNCH_BLOCKING=1 as the current workaround and
flag the proper fix (stream sync in triton_call_lowering, or
file an upstream JAX FFI bug) for follow-up.

Signed-off-by: tdophung <tdophung@nvidia.com>
faulthandler.dump_traceback_later() takes only (timeout, repeat,
file, exit) -- there is no all_threads parameter. It already
dumps every Python thread by default. Don't confuse with
faulthandler.register() which does take all_threads.

The bad kwarg caused pytest collection to crash with TypeError
before any test could run.

Signed-off-by: tdophung <tdophung@nvidia.com>
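
As a quick reference for the signature mentioned in the commit message above, here is a minimal standard-library sketch of arming and cancelling the watchdog. The helper names and the temporary log file are assumptions for illustration, not the test file's actual setup.

```python
import faulthandler
import tempfile


def arm_watchdog(timeout_s, log_file):
    """Dump the tracebacks of all threads to log_file after timeout_s seconds."""
    # The only accepted parameters are timeout, repeat, file and exit.
    faulthandler.dump_traceback_later(
        timeout_s, repeat=False, file=log_file, exit=False
    )


def rejects_all_threads_kwarg():
    """Return True if passing all_threads raises TypeError, as the commit notes."""
    try:
        faulthandler.dump_traceback_later(60, all_threads=True)
    except TypeError:
        return True
    # Only reached if the call was unexpectedly accepted.
    faulthandler.cancel_dump_traceback_later()
    return False


# The file must stay open for as long as the watchdog is armed.
with tempfile.TemporaryFile() as log_file:
    arm_watchdog(600, log_file)
    faulthandler.cancel_dump_traceback_later()

kwarg_rejected = rejects_all_threads_kwarg()
```

Passing an explicit file object avoids depending on `sys.stderr` having a real file descriptor, which is not always the case under test runners that capture output.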
Experiment C of the bwd-hang investigation. When
NVTE_MOE_OPT_BARRIER=1, insert jax.lax.optimization_barrier on
the Triton-kernel output before every immediately-following NCCL
ragged_all_to_all in _dispatch, _combine_bwd, _dispatch_bwd.

If this fixes the hang WITHOUT CUDA_LAUNCH_BLOCKING=1, the bug
is fixable at the lowering layer by forcing materialization
between the aliased Triton custom_call and the NCCL collective,
which is much cheaper than serializing every CUDA launch.

Off by default so the existing CUDA_LAUNCH_BLOCKING workaround
continues to work without this opt-in.

Signed-off-by: tdophung <tdophung@nvidia.com>
Experiment C results from dlcluster job 1045311:
  C.0 (no barrier, no blocking):   PASS in 32s (won the race)
  C.1 (with barrier, single test): TIMEOUT_HANG at 180s
  C.2 (with barrier, full smoke):  TIMEOUT_HANG at 480s

The optimization_barrier doesn't fix the hang and in fact makes
it MORE reliably reproducible. This rules out the HLO-dataflow
fix hypothesis: the bug is not a missing XLA stream-sync edge.

C.0 passing demonstrates the hang is RACY not deterministic --
matches Olli Lupton's October 2025 memo on JAX/XLA multi-GPU
deadlocks from lazy CUDA module loading interleaved with active
NCCL collectives (nvbug/5564750). With multi-GPU-per-process
the lazy load of a Triton kernel on GPU0 can take the global
driver lock and block on cuiStreamSynchronize for an active
NCCL kernel that itself depends on GPU1's progress, which is
blocked on the same lock.

Triton kernels can't be pre-loaded via FFI 'initialize' (JAX core
owns the primitive). The proper fix is multiprocess launch
(one JAX process per GPU) -- see follow-up commit.

Signed-off-by: tdophung <tdophung@nvidia.com>
Companion to test_distributed_moe_vjp.py that avoids the multi-GPU
lazy-load + active-NCCL deadlock entirely by giving each GPU its
own Python process / CUDA driver context. With one device per
process there is no global module-load lock shared across the
threads driving different GPUs, so the failure mode described in
past_JAX_XLA_deadlock.txt (nvbug/5564750) cannot occur and no
CUDA_LAUNCH_BLOCKING=1 workaround is needed.

Pattern mirrors examples/jax/encoder/test_multiprocessing_encoder.py:
- pytest --num-process=N --process-id=i CLI options (added to
  tests/jax/conftest.py, defaults to 0 = single-process so harmless
  for other tests)
- jax.distributed.initialize(... local_device_ids=process_id ...)
  at module top-level
- module-level skip when not launched via the runner so direct
  pytest collection on tests/jax/ is harmless
- run_multiprocess_moe_vjp.sh forks N=nvidia-smi pytest processes
  and waits for all of them

Tests themselves are 1:1 with TestMoeVjpDistributedSmoke from the
single-process file (fwd_and_bwd_smoke, aux_loss_smoke, parity).

Keeping BOTH files in tree: single-process is simpler for dev-loop
iteration (with CUDA_LAUNCH_BLOCKING=1 workaround), multiprocess
is what CI should run for guaranteed correctness without
workarounds.

Signed-off-by: tdophung <tdophung@nvidia.com>
First multiprocess attempt hung at jnp.any(g != 0.0).item() in the
post-grad assertion loop. In a single-process test that's a local
reduction; in multi-host JAX it implicitly triggers a cross-process
collective (all-gather + reduce) under the hood, and any small
divergence in graph build order across processes can deadlock.

Replace the host-side reductions with:
  multihost_utils.process_allgather(x, tiled=True) -> np.asarray
Then run the finite / non-zero / parity asserts entirely in numpy.
Every process gathers in lockstep, no surprise JAX collectives.

Also: launcher now respects MOE_VJP_MP_LOG_DIR so per-process logs
survive on a host-mounted volume after the container exits.

Signed-off-by: tdophung <tdophung@nvidia.com>
process_allgather hung on procs 0+3 while procs 1+2 finished
(divergence detected from per-process logs in dlcluster job
1046001). A multi-host collective inside the post-grad assertion
loop is too easy to deadlock when even one assertion fires on
some procs first.

Use the local addressable shard on each process via
arr.addressable_data(0) -> np.asarray. Same correctness coverage
(if any rank has NaN, that rank's assertion fires) without
needing to emit a cross-process collective for the test
machinery itself.

Signed-off-by: tdophung <tdophung@nvidia.com>
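
The local-shard check described in the commit message above can be sketched as follows. The helper names are hypothetical, and the example runs in a single process, where the first addressable shard is simply the whole array.

```python
import jax.numpy as jnp
import numpy as np


def local_shard_as_numpy(arr):
    """Copy this process's first addressable shard to host memory."""
    return np.asarray(arr.addressable_data(0))


def check_grad_shard(grad):
    """Run the finite / non-zero checks in NumPy, with no cross-process collective."""
    local = local_shard_as_numpy(grad)
    assert np.isfinite(local).all(), "gradient shard contains NaN or Inf"
    assert np.any(local != 0.0), "gradient shard is all zeros"
    return local


# Illustrative gradient values.
local_grad = check_grad_shard(jnp.array([1.0, 0.0, -2.0]))
```

Because each process inspects only data it already holds, an assertion firing on one rank cannot leave the other ranks blocked inside a collective issued by the test itself.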

Development

Successfully merging this pull request may close these issues.

[JAX] Create initial MoE Block

4 participants