Add caching for GatedDeltaNetCache (linear attention models) #2971

Open

Rohan-Bierneni wants to merge 1 commit into main from rbierneni-qwen3-next-caching
Conversation

@Rohan-Bierneni (Collaborator) commented on Jan 20, 2026

PR: Add Caching Support for Qwen3-Next (Gated Delta Network)

Problem

The Qwen3-Next architecture utilizes a hybrid approach consisting of standard Self-Attention layers and Gated Delta Network (GDN) layers (a form of Linear Attention).
While MaxText’s existing KVCache handles standard attention (growing Key/Value history), it does not support the state management required for GDN. GDN layers require:

  • A Fixed-size Recurrent State: Compresses history into a constant-size matrix, rather than growing linearly with sequence length.

  • A Convolution State: A sliding window buffer for the short 1D convolution preceding the delta rule.

Without these, Qwen3-Next could not perform efficient autoregressive decoding; it would either need to recompute the entire history at every step or fail to capture temporal dependencies correctly.
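
For context, the sketch below shows the two fixed-size states and how a single autoregressive step would update them with a gated delta-rule write and a sliding convolution window. It is illustrative only; the variable names, shapes, and the exact gating are assumptions, not the MaxText implementation.

```python
# Illustrative sketch of GDN decode-step state handling (not MaxText code).
import jax.numpy as jnp

B, H, K, V = 1, 4, 64, 64          # batch, heads, key dim, value dim (illustrative)
KERNEL, DIM = 4, 512               # short 1D conv kernel size and channel dim (illustrative)

recurrent_state = jnp.zeros((B, H, K, V))     # constant-size summary of all past tokens
conv_state = jnp.zeros((B, KERNEL - 1, DIM))  # sliding window for the short 1D conv

def decode_step(recurrent_state, conv_state, x_t, k_t, v_t, alpha_t, beta_t):
  """One token step: slide the conv window, then apply a gated delta-rule update."""
  # 1) Convolution state: drop the oldest position, append the newest input.
  conv_state = jnp.concatenate([conv_state[:, 1:, :], x_t[:, None, :]], axis=1)
  # 2) Gated delta rule (schematic): S_t = alpha_t * S_{t-1} + beta_t * k_t ⊗ (v_t - k_t · S_{t-1})
  pred_v = jnp.einsum("bhk,bhkv->bhv", k_t, recurrent_state)  # value currently stored for k_t
  delta = jnp.einsum("bhk,bhv->bhkv", k_t, v_t - pred_v)      # rank-1 correction
  recurrent_state = alpha_t[..., None, None] * recurrent_state + beta_t[..., None, None] * delta
  return recurrent_state, conv_state

# Both states keep the same shape no matter how many tokens have been decoded,
# which is exactly what the standard growing KV cache cannot represent.
x_t = jnp.ones((B, DIM))
k_t, v_t = jnp.ones((B, H, K)), jnp.ones((B, H, V))
alpha_t, beta_t = jnp.full((B, H), 0.9), jnp.full((B, H), 0.1)
recurrent_state, conv_state = decode_step(recurrent_state, conv_state, x_t, k_t, v_t, alpha_t, beta_t)
```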

This PR adds caching support to the Gated Delta Net implementation for Qwen3-Next. This is the first non-standard attention caching support in MaxText and required changes to both the MaxEngine code and the Qwen3-Next implementation.

Note: Decoding is currently ONLY supported in unscanned checkpoint format (scan_layers=False and stack_prefill_result_cache=False).

Solution

This PR implements specific caching infrastructure for Qwen3-Next's linear attention mechanism and integrates it natively into the MaxEngine workflow.

1. New Cache Implementation (src/MaxText/inference/kvcache.py)

  • Introduced GatedDeltaNetCache. Unlike the standard KV cache, it stores fixed-size states (see the sketch after this list):
    • recurrent_state: shape (B, H, K, V)
    • conv_state: shape (B, Kernel-1, Dim)
  • Both variables are wrapped in nn.with_logical_constraint so that MaxText can automatically infer the proper shardings.
  • Created a BaseCache class to abstract commonalities between standard KV and GDN caches.
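
As a rough illustration of the shapes and the logical-constraint wrapping described above, here is a minimal Flax linen sketch. The class name, logical axis names, and default dimensions are hypothetical, not the actual kvcache.py code.

```python
# Hypothetical sketch of a fixed-size GDN cache in Flax linen (not the MaxText code).
import jax.numpy as jnp
from flax import linen as nn


class GatedDeltaNetCacheSketch(nn.Module):
  batch: int = 1
  num_heads: int = 4
  key_dim: int = 64
  value_dim: int = 64
  conv_kernel: int = 4
  conv_dim: int = 512

  @nn.compact
  def __call__(self, new_recurrent_state, new_conv_state):
    # Fixed-size cache variables: neither grows with sequence length.
    recurrent = self.variable(
        "cache", "recurrent_state", jnp.zeros,
        (self.batch, self.num_heads, self.key_dim, self.value_dim))
    conv = self.variable(
        "cache", "conv_state", jnp.zeros,
        (self.batch, self.conv_kernel - 1, self.conv_dim))

    # Logical-axis constraints let the partitioner infer shardings from the mesh rules.
    recurrent.value = nn.with_logical_constraint(
        new_recurrent_state, ("cache_batch", "cache_heads", "cache_kv", "cache_v"))
    conv.value = nn.with_logical_constraint(
        new_conv_state, ("cache_batch", "cache_sequence", "cache_kv"))
    return recurrent.value, conv.value
```

During decoding these variables are simply overwritten in place each step, in contrast to a standard KV cache that is written into along a growing sequence axis.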

2. Layer Logic & State Management (src/MaxText/layers/qwen3.py)

  • Conv & Recurrent Updates: Modified Qwen3NextGatedDeltaNet to properly slide the convolution window and update the recurrent state during MODEL_MODE_AUTOREGRESSIVE.

  • Chunked Prefill Support: Re-wrote MODEL_MODE_PREFILL logic to read from self.cache rather than defaulting to zeros. This allows the GDN state to seamlessly carry over between chunks for long prompts.

  • Dynamic Batch Alignment: Added jnp.broadcast_to and [:batch] slicing during cache retrieval. This dynamically bridges the gap between MaxEngine's global cache allocation (e.g., batch=64) and the active compute shape (e.g., batch=1 during prefill), preventing concatenation shape errors during XLA dummy-tracing passes.
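
A minimal sketch of the batch-alignment idea in the last bullet; the function and variable names are assumptions for illustration, not the exact qwen3.py code.

```python
# Illustrative helper for aligning a globally allocated cache with the active compute batch.
import jax.numpy as jnp

def align_cached_state(cached_state, batch):
  """Bring a cached GDN state to the active compute batch size."""
  if cached_state.shape[0] < batch:
    # e.g. a (1, ...) dummy-init state traced against a larger compute batch.
    cached_state = jnp.broadcast_to(cached_state, (batch,) + cached_state.shape[1:])
  # e.g. a (64, ...) engine-allocated cache used for a batch=1 prefill.
  return cached_state[:batch]
```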

3. MaxEngine Integration (src/MaxText/maxengine.py)

  • Non-sequence-growing Insertions: Updated _insert_jit and bulk_insert to handle fixed-size states. For GDN variables (recurrent_state, conv_state), the prefill step produces a "final state" that is copied directly into the decode cache slot, bypassing the sequence-length slicing logic used for standard KV caches (see the sketch after this list).

  • Packed Prefill Guard: Added a NotImplementedError in insert_partial to block packed prefill for GDN states, preventing sequential memory from bleeding across independent packed prompts.

  • Logical Axes Fix: Updated init_decode_state to properly check for both .logical_axes and .names when unpacking LogicallyPartitioned variables, fixing a bug where empty tuples caused batch index crashes.
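
The sketch below illustrates the first two points. The helper names, cache layout, and GDN variable names are assumptions for illustration, not the actual maxengine.py API.

```python
# Hypothetical sketch of the two insertion paths (not the maxengine.py implementation).
import jax.numpy as jnp

GDN_STATE_NAMES = ("recurrent_state", "conv_state")  # assumed naming

def insert_prefill_cache(name, decode_cache, prefill_cache, slot, true_length):
  """Copy one prefill cache variable into the decode cache at `slot`."""
  if name in GDN_STATE_NAMES:
    # Fixed-size GDN state: the prefill "final state" is written directly into
    # the slot; there is no sequence axis to slice.
    return decode_cache.at[slot].set(prefill_cache[0])
  # Standard KV cache (simplified): keep only the first `true_length` positions
  # of the growing sequence axis before writing.
  return decode_cache.at[slot, :true_length].set(prefill_cache[0, :true_length])

def guard_packed_prefill(name):
  """Packed prefill would let GDN recurrent state bleed across independent prompts."""
  if name in GDN_STATE_NAMES:
    raise NotImplementedError("Packed prefill is not supported for GDN states.")
```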

FIXES: b/448407748

Tests

Ran the MaxText.decode command and got meaningful output (v5p-64 & pdb=1): https://paste.googleplex.com/6411536458448896

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

Rohan-Bierneni changed the title from "Add class for GatedDeltaNetCache to kvcache.py" to "Add caching for GatedDeltaNetCache (linear attention models)" on Jan 20, 2026
Rohan-Bierneni force-pushed the rbierneni-qwen3-next-caching branch from 15f871b to 7ec8a31 on January 20, 2026 18:39
codecov bot commented on Jan 20, 2026

Codecov Report

❌ Patch coverage is 32.75862% with 39 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/maxtext/models/qwen3.py | 44.44% | 16 Missing and 4 partials ⚠️ |
| src/maxtext/layers/decoders.py | 15.38% | 8 Missing and 3 partials ⚠️ |
| src/MaxText/maxengine.py | 11.11% | 8 Missing ⚠️ |


@parambole (Collaborator) left a comment

Thanks for making these changes. I have left a few comments. Additionally, could you add a note on which scan-layer settings are supported?

Rohan-Bierneni force-pushed the rbierneni-qwen3-next-caching branch 4 times, most recently from 6f9ef90 to b49d574 on January 30, 2026 18:50
Rohan-Bierneni force-pushed the rbierneni-qwen3-next-caching branch 4 times, most recently from 32514bb to 5d03505 on February 19, 2026 21:57
Rohan-Bierneni force-pushed the rbierneni-qwen3-next-caching branch from 5d03505 to b4f5fdf on February 20, 2026 21:28
Added support for GDN to maxengine but NNX linen incompatible

Merged code from other branch Qwen3-next

Modified to accept dynamic model mode and work with maxengine changes

Fix GDN init with model_mode

Do same cache update during packed prefill as normal prefill

Convert batch to int in init for state

remove new_cache and resolve comments from pr

fix merge conflicts

use maxtext instead of MaxText

typo in import

removed testcases

remove circular import

Add support for decoding with pdb > 1

Fix slicing bug when using batch_size > 1

Fix linter issues

Fix linter issues and flatten conditionals for pylint

uncommit pre-commit check
Rohan-Bierneni force-pushed the rbierneni-qwen3-next-caching branch from b4f5fdf to 15c1aa1 on February 20, 2026 23:50
Rohan-Bierneni self-assigned this on Feb 20, 2026
@RissyRan (Collaborator) left a comment
For this one, let's add one unit test like the following, to protect your feature from being broken by future changes:

```python
def test_autoregression(self, rope_type):
  cfg, mla = self.init_mla(self.config_arguments, rope_type)
  prefill_length = cfg.max_prefill_predict_length
  decode_total_length = cfg.max_target_length
  lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(cfg, cfg.dtype)

  mla_full, _ = mla(
      lnx,
      lnx,
      decoder_segment_ids=decoder_segment_ids,
      inputs_positions=decoder_positions,
      deterministic=True,
      model_mode=MODEL_MODE_TRAIN,
  )

  lnx_prefill = lnx[:, 0:prefill_length, :]
  decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
  decoder_positions_prefill = decoder_positions[:, 0:prefill_length]

  mla_prefill, _ = mla(
      lnx_prefill,
      lnx_prefill,
      decoder_segment_ids=decoder_segment_ids_prefill,
      inputs_positions=decoder_positions_prefill,
      deterministic=True,
      model_mode=MODEL_MODE_PREFILL,
  )
  self.assertTrue(
      jax.numpy.allclose(mla_prefill, mla_full[:, :prefill_length, :], rtol=1e-02, atol=1e-02, equal_nan=False)
  )

  for idx in range(prefill_length, decode_total_length):
    lnx_idx = lnx[:, idx : idx + 1, :]
    decoder_positions_idx = decoder_positions[:, idx : idx + 1]
    mla_idx, _ = mla(
        lnx_idx,
        lnx_idx,
        inputs_positions=decoder_positions_idx,
        deterministic=True,
        model_mode=MODEL_MODE_AUTOREGRESSIVE,
    )
    mla_full_this_idx = mla_full[:, idx : idx + 1, :]
    self.assertEqual(mla_full_this_idx.shape, mla_idx.shape)
    # TODO (b/394626702) uncomment last check when decode and kv_cache are implemented for MLA
    # self.assertTrue(jax.numpy.allclose(mla_full_this_idx, mla_idx, rtol=1e-02, atol=1e-02, equal_nan=False))
```
