[DRAFT] Paged Stashing for ring-of-experts MoE Activation Memory #3493
abhinavgoel95 wants to merge 1 commit into AI-Hypercomputer:main
Conversation
Problem
-------
With ring_of_experts=True and EP=N, the GMM activation buffer is inflated to:

    worst_case = batch * EP * seq * top_k

e.g. 262,144 tokens for bs=2, EP=4, seq=4096, top_k=8. Checkpointing moe_mlpwo at this size costs ~210 GB host memory for a 60-layer model, exceeding typical 180 GB CPU RAM limits.

Approach (inspired by Megatron-LM PR #2690 "Paged Stashing")
------------------------------------------------------------
Decouple compute buffer size from storage buffer size:

Forward:
- run GMM on the full worst_case buffer (no token dropping)
- stash only actual tokens (~65k) into a shared static buffer at a dynamically-tracked offset

Backward:
- restore compact tokens back into the worst_case-shaped buffer
- run GMM backward correctly on all tokens

The shared buffer is sized for the *expected cumulative* token count across all layers (~60 * 65k tokens => ~52 GB) rather than per-layer worst case (210 GB). Transient per-layer imbalance is absorbed by the shared budget.

Implementation
--------------
- layers/paged_stash.py: core stash/restore primitives using jax.custom_vjp and lax.dynamic_update_slice (static shape, dynamic start index).
- layers/moe.py: the ring_of_experts path uses stash_fn/restore_fn when ring_paged_stash=True, falling back to checkpoint_name otherwise.
- configs/types.py: ring_paged_stash (bool) and ring_paged_stash_safety_margin (float, default 1.5) config fields.

Memory comparison (bs=2, EP=4, seq=4096, top_k=8, hidden=7168, 60 layers):

    Baseline (no cap):   ~210 GB host (moe_mlpwo alone)
    Static 50% cap:      ~105 GB host
    Paged stash (1.5x):   ~78 GB host (safety_margin=1.5 => max_chunk=98k)
    Paged stash (1.0x):   ~52 GB host (safety_margin=1.0 => max_chunk=65k)

Status / TODOs
--------------
This is a draft for upstream feedback. The following items are incomplete:

1. Decoder scan carry wiring: stash_buf, write_ptr, and layer_sizes must be threaded through the decoder __call__ signatures in decoders.py and deepseek.py. Currently the moe.py code references stash_buf/write_ptr as free variables to show the intended interface.
2. restore_fn backward: the d_buf accumulation across layers needs careful handling in the scan backward -- the current implementation is a sketch.
3. moe_mlpwi_0 / moe_mlpwi_1: the same technique can be applied to the wi GMM outputs, saving an additional ~60 GB each if they are offloaded.
4. Tests: unit tests for stash_fn/restore_fn round-trip correctness and a gradient check via jax.test_util.check_grads.
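The stash/restore mechanics described above can be sketched as follows. This is a minimal illustration, not the PR's actual paged_stash.py: the constants `MAX_CHUNK`/`HIDDEN` and the function names `stash`/`restore` are made up for the example, and the `jax.custom_vjp` wrapper the PR adds is omitted. It shows only the core trick: `lax.dynamic_update_slice` with a static chunk shape and a traced start offset.

```python
# Hedged sketch of paged stashing: compact the first `n` valid rows of a
# worst-case-shaped activation buffer into a shared static buffer at a
# dynamic offset, then re-expand on restore. Shapes/names are illustrative.
import jax
import jax.numpy as jnp

MAX_CHUNK = 8  # assumed per-layer stash capacity (tokens)
HIDDEN = 4     # assumed hidden dim

def stash(buf, acts, n, offset):
    """Copy up to MAX_CHUNK rows of `acts` into `buf` starting at `offset`.
    All shapes are static; `offset` may be a traced index, which
    lax.dynamic_update_slice supports."""
    # Zero rows past `n` so padding never pollutes the shared buffer.
    chunk = jnp.where(jnp.arange(MAX_CHUNK)[:, None] < n, acts[:MAX_CHUNK], 0.0)
    return jax.lax.dynamic_update_slice(buf, chunk, (offset, 0))

def restore(buf, offset, worst_case_rows):
    """Read MAX_CHUNK rows back from `buf` at `offset` and re-expand to the
    worst-case compute shape (padding rows come back as zeros)."""
    chunk = jax.lax.dynamic_slice(buf, (offset, 0), (MAX_CHUNK, HIDDEN))
    out = jnp.zeros((worst_case_rows, HIDDEN), chunk.dtype)
    return out.at[:MAX_CHUNK].set(chunk)
```

In the real design the offset is the dynamically-tracked write pointer carried across layers, and custom_vjp makes the backward pass restore from the compact buffer instead of saving the worst-case residual.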
RissyRan left a comment
Thank you for the change! Have a few initial comments. Also, it seems unit tests are quite red for this PR. Could you please check?
    description=(
        "Enable paged stashing for ring-of-experts MoE layers. "
        "Instead of checkpointing GMM activations at worst-case buffer size "
        "(batch*EP*seq*top_k), compactly stores only the actual routed tokens "
For the formula, could you align it with `maxtext/src/maxtext/layers/moe.py` (line 1170 in 16b6848)?
        False,
        description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
    )
    ring_paged_stash: bool = Field(
Shall we also add those into base.yml & this doc for alignment? Similar comments for other occurrences.
Shall we call it ring_of_experts_paged_stash to distinguish it from ring attention?
    description=(
        "Safety margin multiplier on the expected-per-layer token count used to "
        "size each layer's stash chunk (max_chunk = expected * margin). "
        "1.0 = no slack; 1.5 = tolerate 50%% per-layer imbalance without dropping."
Will this strategy introduce dropping? If I understand correctly, no matter how we set the margin value, it will iterate until the last chunk, right? If so, shall we avoid the word "dropping" here?
    # See layers/paged_stash.py for full documentation.
    # -------------------------------------------------------------------
    actual_tokens = jnp.sum(group_sizes)
    expected = ps.expected_tokens_per_layer(
We found this math gives a smaller range; see `maxtext/src/maxtext/layers/moe.py` (line 1170 in 16b6848).
    @@ -0,0 +1,222 @@
    # Copyright 2025 Google LLC
    Naively checkpointing the GMM outputs (moe_mlpwi_0, moe_mlpwi_1, moe_mlpwo) at
    worst-case size balloons host-offload memory to ~210 GB for a 60-layer model.
    Idea (inspired by Megatron-LM PR #2690 "Paged Stashing")
Shall we include a link?
    Strategy                   | tokens/layer | Host memory (moe_mlpwo only)
    ---------------------------|--------------|-----------------------------
    No cap (baseline)          | 262,144      | ~210 GB ❌
    50% static cap             | 131,072      | ~105 GB ✅
We don't see the dropping strategy being used any more due to model quality concerns. Shall we remove this option if paged stash is clearly better here?
    # Core stash / restore primitives
    # ---------------------------------------------------------------------------


    def make_stash_fns(max_chunk: int, hidden: int):
Shall we write a unit test in moe_test.py to ensure correctness with expert sharding? One example: `maxtext/tests/unit/moe_test.py` (line 542 in 16b6848).
Problem

When `use_ring_of_experts=True` with expert parallelism EP=4, the ring-all-gather inflates the token buffer on each device to `batch × EP × seq × top_k` tokens. For DeepSeek V3 671B with `batch=2, seq=4096, top_k=8, EP=4`, this is 262,144 tokens/device.

The `moe_mlpwo` GMM output has shape `(262144, 7168)` = 3.5 GB per layer. With `scan_layers=True` over 60 MoE layers and host offloading, this totals ~210 GB of host memory — exceeding typical limits (180 GB).

Solution: Paged Stashing
Inspired by Megatron-LM PR #2690, this PR introduces a paged stash mechanism that decouples the compute buffer (static worst-case shape, required by XLA) from the storage buffer (actual tokens only).

Key idea

- The shared stash buffer is sized as `total_capacity = num_moe_layers × expected_tokens_per_layer + max_chunk_slack`.
- Tokens are written with `lax.dynamic_update_slice` (static size, dynamic start index — XLA-compatible).
- Expected load per layer is `worst_case / EP = 262144 / 4 = 65536`, so the shared buffer is ~4× smaller than per-layer worst-case storage.

Memory comparison (DeepSeek V3 671B, EP=4, 60 MoE layers)
| Strategy           | tokens/layer stored | Host memory (`moe_mlpwo` only) |
|--------------------|---------------------|--------------------------------|
| No cap (baseline)  | 262,144             | ~210 GB                        |
| 50% static cap     | 131,072             | ~105 GB                        |
| Paged stash (1.5×) | 98,304              | ~78 GB                         |
| Paged stash (1.0×) | 65,536              | ~52 GB                         |
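The memory figures above follow from simple arithmetic. A back-of-envelope script, assuming bf16 (2-byte) activations; the variable names are illustrative, not config fields:

```python
# Back-of-envelope sizing for the paged stash buffer vs. uncapped
# per-layer worst-case storage, assuming bf16 activations (2 bytes).
batch, ep, seq, top_k, hidden = 2, 4, 4096, 8, 7168
num_moe_layers, margin = 60, 1.0

worst_case = batch * ep * seq * top_k        # tokens/device after ring-all-gather
expected_per_layer = worst_case // ep        # expected routed tokens per layer
max_chunk = int(expected_per_layer * margin) # per-layer stash chunk

bytes_per_token = hidden * 2                              # bf16 row
per_layer_worst = worst_case * bytes_per_token            # uncapped, per layer
shared_buffer = num_moe_layers * max_chunk * bytes_per_token

print(f"worst-case tokens/layer: {worst_case}")
print(f"uncapped host memory   : {num_moe_layers * per_layer_worst / 2**30:.0f} GiB")
print(f"shared stash (x{margin}) : {shared_buffer / 2**30:.1f} GiB")
```

With `margin = 1.5` the same arithmetic gives `max_chunk = 98304` and a ~78 GiB shared buffer, matching the table.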
Implementation

New file: `src/MaxText/layers/paged_stash.py`

- `make_stash_fns(max_chunk, hidden)` — returns a `(stash_fn, restore_fn)` pair with `jax.custom_vjp` so gradients flow correctly through the compact/expand operations
- `stash_buffer_size(num_moe_layers, expected_per_layer, max_chunk)` — buffer sizing helper
- `expected_tokens_per_layer(batch, ep, seq, top_k)` — computes expected load (= `worst_case / EP`)
src/MaxText/layers/moe.pyWhen
config.ring_paged_stash=True, replaces thecheckpoint_name("moe_mlpwo")call withstash_fn/restore_fncalls that pack actual tokens into the shared buffer.Modified:
src/maxtext/configs/types.pyAdds two new config fields:
ring_paged_stash: bool = Falsering_paged_stash_safety_margin: float = 1.5TODOs / Open Questions
stash_bufandwrite_ptrneed to be threaded through thelax.scancarry indecoders.py/deepseek.py. Currently the per-layer integration is sketched but the cross-layer buffer threading is not yet wired up. Looking for guidance on the best pattern here.restore_fnbackward: The gradient accumulation inrestore_fn_bwdneeds review — specifically the scatter from compact → full shape.moe_mlpwi_0/moe_mlpwi_1: Same technique could apply to wi outputs if they are offloaded.