
[DRAFT] Paged Stashing for ring-of-experts MoE Activation Memory#3493

Draft
abhinavgoel95 wants to merge 1 commit into AI-Hypercomputer:main from abhinavgoel95:abgoel/paged-stash-clean

Conversation

@abhinavgoel95 (Contributor):

Problem

When use_ring_of_experts=True with expert parallelism EP=4, the ring-all-gather inflates the token buffer on each device to batch × EP × seq × top_k tokens. For DeepSeek V3 671B with batch=2, seq=4096, top_k=8, EP=4, this is 262,144 tokens/device.

The moe_mlpwo GMM output has shape (262144, 7168) = 3.5 GB per layer. With scan_layers=True over 60 MoE layers and host offloading, this totals ~210 GB of host memory — exceeding typical limits (180 GB).
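As a quick sanity check, these figures can be reproduced with a few lines of arithmetic (a sketch, assuming bf16 activations at 2 bytes per element and binary GB):

```python
# Reproducing the memory figures above; assumes bf16 (2 bytes/element)
# and binary GB (2**30 bytes).
batch, ep, seq, top_k = 2, 4, 4096, 8
hidden, num_moe_layers = 7168, 60

worst_case_tokens = batch * ep * seq * top_k             # ring-all-gathered buffer
gib_per_layer = worst_case_tokens * hidden * 2 / 2**30   # one moe_mlpwo output
total_gib = gib_per_layer * num_moe_layers

print(worst_case_tokens)  # 262144 tokens/device
print(gib_per_layer)      # 3.5 GB per layer
print(total_gib)          # 210.0 GB across 60 MoE layers
```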

Solution: Paged Stashing

Inspired by Megatron-LM PR #2690, this PR introduces a paged stash mechanism that decouples the compute buffer (static worst-case shape, required by XLA) from the storage buffer (actual tokens only).

Key idea

  • A single shared static buffer is allocated once for all MoE layers: total_capacity = num_moe_layers × expected_tokens_per_layer + max_chunk_slack
  • Each layer writes only its actual routed tokens into this buffer at a tracked offset using lax.dynamic_update_slice (static size, dynamic start index — XLA-compatible)
  • On backward pass, each layer reads back its compact slice and expands to the full compute shape
  • Expected actual tokens ≈ worst_case / EP = 262144 / 4 = 65536, so the shared buffer is ~4× smaller than per-layer worst-case storage
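The write path above can be sketched as follows (illustrative names and toy sizes, not the PR's exact API): each layer masks its fixed-size slab down to the actual token count and writes it at a tracked offset with `lax.dynamic_update_slice`, so every shape XLA sees is static.

```python
import jax.numpy as jnp
from jax import lax

MAX_CHUNK, HIDDEN = 4, 3  # tiny sizes for illustration

def stash_chunk(stash_buf, write_ptr, tokens, num_actual):
  """Write a fixed-size (MAX_CHUNK, HIDDEN) slab into the shared buffer at a
  dynamic offset. Rows past num_actual are zeroed so stale data never lands
  in the buffer; the slab size itself stays static, as XLA requires."""
  mask = (jnp.arange(MAX_CHUNK) < num_actual)[:, None]
  slab = jnp.where(mask, tokens, 0.0)
  new_buf = lax.dynamic_update_slice(stash_buf, slab, (write_ptr, 0))
  return new_buf, write_ptr + num_actual  # advance by actual tokens only

buf = jnp.zeros((10, HIDDEN))  # shared static buffer across layers
buf, ptr = stash_chunk(buf, 0, jnp.ones((MAX_CHUNK, HIDDEN)), 3)
buf, ptr = stash_chunk(buf, ptr, 2.0 * jnp.ones((MAX_CHUNK, HIDDEN)), 2)
print(ptr)  # 5: layer 0 stashed 3 real tokens, layer 1 stashed 2 more
```

Note that the masked tail of each slab zeroes a few rows past the write pointer, but the next layer's write starts exactly at `ptr` and overwrites them.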

Memory comparison (DeepSeek V3 671B, EP=4, 60 MoE layers)

| Approach | Storage for moe_mlpwo | Notes |
| --- | --- | --- |
| Baseline (offload per layer) | ~210 GB | 262k × 7168 × 60 layers |
| Paged stash (this PR) | ~52 GB | 65k × 7168 × 60 layers |

Implementation

New file: src/MaxText/layers/paged_stash.py

  • make_stash_fns(max_chunk, hidden) — returns (stash_fn, restore_fn) pair with jax.custom_vjp so gradients flow correctly through the compact/expand operations
  • stash_buffer_size(num_moe_layers, expected_per_layer, max_chunk) — buffer sizing helper
  • expected_tokens_per_layer(batch, ep, seq, top_k) — computes expected load (= worst_case / EP)
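A hypothetical sketch of how such a `(stash_fn, restore_fn)` pair can be made gradient-transparent with `jax.custom_vjp`. To stay short it compacts by slicing leading rows rather than packing at a dynamic buffer offset; all names, signatures, and sizes here are assumptions, not the PR's actual code.

```python
import jax
import jax.numpy as jnp

def make_stash_fns(max_chunk: int, full_size: int, hidden: int):
  """Return (stash_fn, restore_fn): compact to max_chunk rows and expand back
  to the full compute shape, with a custom VJP so gradients flow through."""

  @jax.custom_vjp
  def restore_fn(compact):
    # Expand (max_chunk, hidden) back to the (full_size, hidden) compute shape.
    pad = jnp.zeros((full_size - max_chunk, hidden), compact.dtype)
    return jnp.concatenate([compact, pad], axis=0)

  def restore_fwd(compact):
    return restore_fn(compact), None          # no residuals needed

  def restore_bwd(_, g_full):
    # Only the first max_chunk rows of the cotangent correspond to real
    # tokens, so the backward pass simply re-slices the compact region.
    return (g_full[:max_chunk],)

  restore_fn.defvjp(restore_fwd, restore_bwd)

  def stash_fn(full):
    return full[:max_chunk]                   # keep only the real rows

  return stash_fn, restore_fn

stash_fn, restore_fn = make_stash_fns(max_chunk=2, full_size=4, hidden=3)
full = jnp.arange(12.0).reshape(4, 3)
compact = stash_fn(full)
grads = jax.grad(lambda c: jnp.sum(restore_fn(c) ** 2))(compact)
print(grads.shape)  # (2, 3): the gradient arrives in the compact shape
```

The real PR additionally threads a write offset into the shared buffer; the plain slicing above only stands in for that packing step.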

Modified: src/MaxText/layers/moe.py

When config.ring_paged_stash=True, replaces the checkpoint_name("moe_mlpwo") call with stash_fn/restore_fn calls that pack actual tokens into the shared buffer.

Modified: src/MaxText/configs/types.py

Adds two new config fields:

  • ring_paged_stash: bool = False
  • ring_paged_stash_safety_margin: float = 1.5

TODOs / Open Questions

  • Decoder scan carry wiring: The shared stash_buf and write_ptr need to be threaded through the lax.scan carry in decoders.py / deepseek.py. Currently the per-layer integration is sketched but the cross-layer buffer threading is not yet wired up. Looking for guidance on the best pattern here.
  • restore_fn backward: The gradient accumulation in restore_fn_bwd needs review — specifically the scatter from compact → full shape.
  • Tests: Unit tests for stash/restore round-trip and gradient correctness.
  • moe_mlpwi_0/moe_mlpwi_1: Same technique could apply to wi outputs if they are offloaded.
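For the test TODO, a round-trip and gradient check along these lines might work. It uses a stand-in pad/slice pair rather than the real `stash_fn`/`restore_fn`, and `jax.test_util.check_grads` to compare gradients against finite differences:

```python
import jax.numpy as jnp
from jax.test_util import check_grads

MAX_CHUNK, FULL, HIDDEN = 3, 6, 4  # toy sizes

def stash(full):                   # stand-in for stash_fn
  return full[:MAX_CHUNK]

def restore(compact):              # stand-in for restore_fn
  return jnp.pad(compact, ((0, FULL - MAX_CHUNK), (0, 0)))

full = jnp.zeros((FULL, HIDDEN)).at[:MAX_CHUNK].set(1.0)

# Round trip must be exact for the rows holding real tokens.
assert bool(jnp.allclose(restore(stash(full)), full))

# Gradients through the pair should agree with finite differences.
check_grads(lambda c: jnp.sum(restore(c) ** 2), (stash(full),), order=1)
print("round-trip and gradient checks passed")
```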

Problem
-------
With ring_of_experts=True and EP=N, the GMM activation buffer is inflated to:
  worst_case = batch * EP * seq * top_k
e.g. 262,144 tokens for bs=2, EP=4, seq=4096, top_k=8.

Checkpointing moe_mlpwo at this size costs ~210 GB host memory for a 60-layer
model, exceeding typical 180 GB CPU RAM limits.

Approach (inspired by Megatron-LM PR #2690 "Paged Stashing")
------------------------------------------------------------
Decouple compute buffer size from storage buffer size:

  Forward:  run GMM on full worst_case buffer (no token dropping)
            stash only actual tokens (~65k) into a shared static buffer
            at a dynamically-tracked offset

  Backward: restore compact tokens back into worst_case-shaped buffer
            run GMM backward correctly on all tokens

The shared buffer is sized for the *expected cumulative* token count across
all layers (~60 * 65k = ~52 GB) rather than per-layer worst-case (210 GB).
Transient per-layer imbalance is absorbed by the shared budget.

Implementation
--------------
- layers/paged_stash.py: core stash/restore primitives using jax.custom_vjp
  and lax.dynamic_update_slice (static shape, dynamic start index).
- layers/moe.py: ring_of_experts path uses stash_fn/restore_fn when
  ring_paged_stash=True, falling back to checkpoint_name otherwise.
- configs/types.py: ring_paged_stash (bool) and
  ring_paged_stash_safety_margin (float, default 1.5) config fields.

Memory comparison (bs=2, EP=4, seq=4096, top_k=8, hidden=7168, 60 layers):
  Baseline (no cap):   ~210 GB host  (moe_mlpwo alone)
  Static 50% cap:      ~105 GB host
  Paged stash (1.5x):   ~78 GB host  (safety_margin=1.5 => max_chunk=98k)
  Paged stash (1.0x):   ~52 GB host  (safety_margin=1.0 => max_chunk=65k)
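The four rows above can be re-derived directly (assuming bf16 activations, hidden=7168, 60 layers, and binary GB truncated to whole numbers):

```python
# Re-deriving the comparison table; assumes bf16 (2 bytes), hidden=7168,
# 60 MoE layers, binary GB, values truncated to whole GB.
worst_case = 2 * 4 * 4096 * 8        # 262,144 tokens
expected = worst_case // 4           # actual per-layer load = worst_case / EP

def host_gb(tokens_per_layer):
  return int(tokens_per_layer * 7168 * 2 * 60 / 2**30)

print(host_gb(worst_case))           # 210  baseline (no cap)
print(host_gb(worst_case // 2))      # 105  static 50% cap
print(host_gb(int(expected * 1.5)))  # 78   paged stash, safety_margin=1.5
print(host_gb(expected))             # 52   paged stash, safety_margin=1.0
```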

Status / TODOs
--------------
This is a draft for upstream feedback.  The following items are incomplete:

1. Decoder scan carry wiring: stash_buf, write_ptr, and layer_sizes must be
   threaded through the decoder __call__ signatures in decoders.py and
   deepseek.py.  Currently the moe.py code references stash_buf/write_ptr
   as free variables to show the intended interface.

2. restore_fn backward: the d_buf accumulation across layers needs careful
   handling in the scan backward -- the current implementation is a sketch.

3. moe_mlpwi_0 / moe_mlpwi_1: the same technique can be applied to the wi
   GMM outputs, saving an additional ~60 GB each if they are offloaded.

4. Tests: unit tests for stash_fn/restore_fn round-trip correctness and
   gradient check via jax.test_util.check_grads.
@RissyRan (Collaborator) left a comment:

Thank you for the change! Have a few initial comments. Also, it seems unit tests are quite red for this PR. Could you please check?

```
description=(
"Enable paged stashing for ring-of-experts MoE layers. "
"Instead of checkpointing GMM activations at worst-case buffer size "
"(batch*EP*seq*top_k), compactly stores only the actual routed tokens "
```

For the formula, could you align it with the following?

```
buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
```

```
False,
description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
)
ring_paged_stash: bool = Field(
```

Shall we also add those into base.yml & this doc for alignment? Similar comments for other occurrences.


Shall we call it ring_of_experts_paged_stash to distinguish it from ring attention?

```
description=(
"Safety margin multiplier on the expected-per-layer token count used to "
"size each layer's stash chunk (max_chunk = expected * margin). "
"1.0 = no slack; 1.5 = tolerate 50%% per-layer imbalance without dropping."
```

Will this strategy introduce dropping? If I understand correctly, no matter how we set the margin value, it will iterate until the last chunk, right? If so, shall we avoid the word "dropping" here?

```
# See layers/paged_stash.py for full documentation.
# -------------------------------------------------------------------
actual_tokens = jnp.sum(group_sizes)
expected = ps.expected_tokens_per_layer(
```

We found this math will give a smaller range:

```
buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
```

```
@@ -0,0 +1,222 @@
# Copyright 2025 Google LLC
```

nit: 2026

Naively checkpointing the GMM outputs (moe_mlpwi_0, moe_mlpwi_1, moe_mlpwo) at
worst-case size balloons host-offload memory to ~210 GB for a 60-layer model.
Idea (inspired by Megatron-LM PR #2690 "Paged Stashing")

Shall we include a link?

```
Strategy                  | tokens/layer | Host memory (moe_mlpwo only)
--------------------------|--------------|-----------------------------
No cap (baseline)         | 262,144      | ~210 GB ❌
50% static cap            | 131,072      | ~105 GB ✅
```

We moved away from the dropping strategy due to model quality. Shall we remove this option if paged stash is clearly better here?

```
# Core stash / restore primitives
# ---------------------------------------------------------------------------

def make_stash_fns(max_chunk: int, hidden: int):
```

Shall we write a unit test in moe_test.py to ensure correctness with expert sharding? One example:

```
def test_megablox_expert_parallelism(self):
```