
Commit 7ec8a31

Add class for GatedDeltaNetCache to kvcache.py
- Added support for GDN to maxengine, but NNX/Linen incompatible
- Merged code from other branch Qwen3-next
- Modified to accept dynamic model mode and work with maxengine changes
- Fix GDN init with model_mode
- Do same cache update during packed prefill as normal prefill
- Convert batch to int in init for state
1 parent d4a259d commit 7ec8a31

6 files changed: 474 additions & 30 deletions


src/MaxText/inference/kvcache.py

Lines changed: 74 additions & 1 deletion
@@ -230,7 +230,11 @@ def kv_cache_as_linen(
   )


-class KVCache(nnx.Module):
+class BaseCache(nnx.Module):
+  """Abstract base class for Caches."""
+  pass
+
+class KVCache(BaseCache):
   """Implementation of the KVCache."""

   def __init__(
@@ -842,6 +846,75 @@ def __call__(
       return self.kv_cache_autoregressive(key, value, use_ragged_attention)
     else:
       raise ValueError(f"Model Mode isn't supported! {model_mode=}")
+
+
+class GatedDeltaNetCache(BaseCache):
+  """Cache for Linear Attention (Gated Delta Net).
+
+  Stores the fixed-size recurrent state and the sliding window state for convolution.
+  """
+
+  def __init__(
+      self,
+      batch: int,
+      num_heads: int,
+      k_head_dim: int,
+      v_head_dim: int,
+      conv_kernel_size: int,
+      conv_dim: int,
+      dtype: DType,
+      cache_batch_axis_name: str = CACHE_BATCH,
+      cache_heads_axis_name: str = CACHE_HEADS,
+  ):
+    self.batch = batch
+    self.dtype = dtype
+
+    # 1. Recurrent State (S) for the Delta Rule
+    # Shape: [Batch, Heads, K_Dim, V_Dim]
+    # We maintain the running state matrix.
+    self.recurrent_state = nnx.Cache(
+        jnp.zeros((int(batch), num_heads, k_head_dim, v_head_dim), dtype=dtype),
+        # Sharding: Batch, Heads, None (K), None (V)
+        sharding=(cache_batch_axis_name, cache_heads_axis_name, None, None)
+    )
+
+    # 2. Convolution State for the 1D Conv
+    # Shape: [Batch, Kernel_Size - 1, Conv_Dim]
+    # We store the last (K-1) inputs to perform the sliding window conv during decoding.
+    self.conv_state = nnx.Cache(
+        jnp.zeros((int(batch), conv_kernel_size - 1, conv_dim), dtype=dtype),
+        # Sharding: Batch, None (Time), None (Dim)
+        sharding=(cache_batch_axis_name, None, None)
+    )
+
+  def __call__(self):
+    """Returns the cache variables for the layer to use."""
+    return self
+
+
+def gated_delta_net_cache_as_linen(
+    *,
+    batch: int,
+    num_heads: int,
+    head_dim: int,
+    conv_kernel_size: int,
+    conv_dim: int,
+    dtype: DType,
+    name: str | None = None,
+):
+  """Initializes the GatedDeltaNetCache and returns it as a Linen module."""
+  return nnx_wrappers.to_linen(
+      GatedDeltaNetCache,
+      batch=batch,
+      num_heads=num_heads,
+      head_dim=head_dim,
+      conv_kernel_size=conv_kernel_size,
+      conv_dim=conv_dim,
+      dtype=dtype,
+      metadata_fn=variable_to_logically_partitioned,
+      name=name,
+      abstract_init=False,
+  )


 def mla_kv_cache_as_linen(
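
For orientation, here is a minimal, standalone sketch (not MaxText code) of the two state buffers this cache holds and how a layer would fill and roll them between prefill and decode. The dimensions and dtype below are made-up examples; the real class additionally attaches sharding metadata and is wrapped into Linen as shown in the diff.

import jax.numpy as jnp
from flax import nnx

# Made-up sizes, for illustration only.
batch, num_heads, k_dim, v_dim = 2, 8, 64, 64
kernel_size, conv_dim = 4, 1024

# Recurrent state S for the delta rule: one running (K x V) matrix per head.
recurrent_state = nnx.Cache(jnp.zeros((batch, num_heads, k_dim, v_dim), jnp.float32))
# Convolution state: the last (kernel_size - 1) inputs for the causal 1D conv.
conv_state = nnx.Cache(jnp.zeros((batch, kernel_size - 1, conv_dim), jnp.float32))

# Prefill: store the trailing (kernel_size - 1) inputs for the later decode steps.
prefill_qkv = jnp.ones((batch, 16, conv_dim))
conv_state.value = prefill_qkv[:, -(kernel_size - 1):, :]

# Decode: prepend the cached window to the new token and roll the window forward.
new_qkv = jnp.ones((batch, 1, conv_dim))
conv_input = jnp.concatenate([conv_state.value, new_qkv], axis=1)
conv_state.value = conv_input[:, -(kernel_size - 1):, :]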

src/MaxText/layers/decoders.py

Lines changed: 37 additions & 5 deletions
@@ -874,15 +874,35 @@ def __call__(
               "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval),
               "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step),
           }
+        kv_cache = None
+        if kv_caches is not None:
+          # For models other than Qwen3-Next, kv_caches is a list of caches.
+          if cfg.decoder_block != DecoderBlockType.QWEN3_NEXT:
+            kv_cache = kv_caches[lyr]
+
         if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT:
           layer_kwargs = {"layer_idx": lyr}
+          # For Qwen3Next, kv_caches is a dictionary of lists of caches.
+          is_full_attention_layer = (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0
+          if kv_caches is not None:
+            if is_full_attention_layer:
+              k_cache = kv_caches["key_cache"][lyr]
+              v_cache = kv_caches["value_cache"][lyr]
+              kv_cache = (k_cache, v_cache)
+            else:
+              # For GDN layers, the cache is a dictionary.
+              # conv_state = kv_caches["conv_states"][lyr]
+              # recurrent_state = kv_caches["recurrent_states"][lyr]
+              # gdn_cache = {"conv_state": conv_state, "recurrent_state": recurrent_state}
+              # kv_cache = {"gdn_cache": gdn_cache}
+              kv_cache = None
+
         if cfg.decoder_block == DecoderBlockType.GPT_OSS:
           layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)}
         layer = RemattedBlockLayer(
             config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
         )
-        kv_cache = kv_caches[lyr] if kv_caches is not None else None
-        y, kv_cache = layer(
+        y, returned_cache = layer(
             y,
             decoder_segment_ids,
             decoder_positions,
@@ -895,8 +915,20 @@ def __call__(
             attention_metadata=attention_metadata,
             **layer_call_kwargs,
         )
-        if kv_caches is not None and kv_cache is not None:
-          kv_caches[lyr] = kv_cache
+        if kv_caches is not None and returned_cache is not None:
+          if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT:
+            is_full_attention_layer = (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0
+            if is_full_attention_layer:
+              kv_caches["key_cache"][lyr] = returned_cache[0]
+              kv_caches["value_cache"][lyr] = returned_cache[1]
+            else:
+              # gdn_cache = returned_cache.get("gdn_cache")
+              # if gdn_cache:
+              #   kv_caches["conv_states"][lyr] = gdn_cache["conv_state"]
+              #   kv_caches["recurrent_states"][lyr] = gdn_cache["recurrent_state"]
+              pass
+          else:
+            kv_caches[lyr] = returned_cache

         assert isinstance(y, jax.Array)

@@ -990,4 +1022,4 @@ def _apply_gemma3_scanned_blocks(
         slot=slot,
         **layer_call_kwargs,
     )
-    return y
+    return y
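
To make the routing above easier to follow, here is an illustrative sketch of the two kv_caches layouts the loop indexes into. The layer count, cycle length, and placeholder cache entries are invented for this example; the GDN slots stay unused, matching the code path that currently passes None for GDN layers.

# Illustrative container layouts only; not MaxText code.
num_layers = 8
cycle = 4  # stand-in for cfg.inhomogeneous_layer_cycle_interval

# Default models: kv_caches is a flat list with one cache object per layer.
kv_caches_default = [f"cache_{i}" for i in range(num_layers)]

# Qwen3-Next: kv_caches is a dict of per-layer lists; every `cycle`-th layer is
# full attention, the remaining layers are GDN layers.
kv_caches_qwen3_next = {
    "key_cache": [None] * num_layers,
    "value_cache": [None] * num_layers,
    "conv_states": [None] * num_layers,
    "recurrent_states": [None] * num_layers,
}

for lyr in range(num_layers):
  is_full_attention_layer = (lyr + 1) % cycle == 0
  if is_full_attention_layer:
    kv_cache = (kv_caches_qwen3_next["key_cache"][lyr], kv_caches_qwen3_next["value_cache"][lyr])
  else:
    kv_cache = None  # GDN layers currently keep their state inside the layer module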

src/MaxText/layers/qwen3.py

Lines changed: 83 additions & 23 deletions
@@ -28,7 +28,7 @@
 from flax import nnx

 from MaxText import max_utils
-from MaxText.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED
+from MaxText.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN, MODEL_MODE_PREFILL
 from MaxText.layers import attentions
 from MaxText.layers import initializers as max_initializers
 from MaxText.layers import linears
@@ -42,6 +42,7 @@
 from MaxText.layers.attentions import Attention
 from MaxText.layers.linears import DenseGeneral, MlpBlock
 from MaxText.layers.moe import RoutedMoE
+from MaxText.inference import kvcache


 # -----------------------------------------
@@ -279,7 +280,7 @@ def scan_body(prev_state, x):
   # Transpose back to (B, S, H, D_v)
   core_attn_out = jnp.transpose(core_attn_out, (0, 2, 1, 3)).astype(initial_dtype)

-  return core_attn_out, final_state if output_final_state else None
+  return core_attn_out, final_state


 class Qwen3NextGatedDeltaNet(nnx.Module):
@@ -310,7 +311,7 @@ class Qwen3NextGatedDeltaNet(nnx.Module):
     dtype: The datatype of the computation.
   """

-  def __init__(self, config: Config, dtype: DType = jnp.float32, *, rngs: nnx.Rngs):
+  def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = MODEL_MODE_TRAIN, *, rngs: nnx.Rngs):
     self.config = config
     self.dtype = dtype
     cfg = self.config
@@ -326,6 +327,17 @@ def __init__(self, config: Config, dtype: DType = jnp.float32, *, rngs: nnx.Rngs
     conv_kernel_size = cfg.gdn_conv_kernel_dim
     self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads

+    if model_mode != MODEL_MODE_TRAIN:
+      self.cache = kvcache.GatedDeltaNetCache(
+          batch=config.per_device_batch_size,  # Or appropriate batch dim
+          num_heads=self.num_v_heads,
+          k_head_dim=self.head_k_dim,
+          v_head_dim=self.head_v_dim,
+          conv_kernel_size=self.config.gdn_conv_kernel_dim,
+          conv_dim=conv_dim,  # Make sure conv_dim is calculated before this
+          dtype=dtype,
+      )
+
     # Submodule instantiations
     self.in_proj_qkvz = linears.DenseGeneral(
         in_features_shape=in_features,
@@ -381,7 +393,7 @@ def a_log_init(key, shape, dtype=jnp.float32):
         rngs=rngs,
     )

-  def __call__(self, hidden_states: Array) -> Array:
+  def __call__(self, hidden_states: Array, model_mode: str = MODEL_MODE_TRAIN, kv_cache=None, **kwargs) -> tuple[Array, dict[str, Array]]:
     # hidden_states: (B, S, E)
     cfg = self.config
     batch, seq_len, _ = hidden_states.shape
@@ -455,12 +467,36 @@ def __call__(self, hidden_states: Array) -> Array:
     # conv_dim = 2 * K_dim + V_dim
     # qkv: (B, S, 2 * K_dim + V_dim)
     qkv = jnp.concatenate([q, k, v], axis=-1)
-
-    # TODO(parambole): Implement caching logic for conv_state and recurrent_state
-
-    # Input to conv_layer should be (B, S, C)
-    # qkv_conv shape: (B, S, conv_dim)
-    conv_out = self.conv1d(qkv)
+    conv_kernel_size = self.config.gdn_conv_kernel_dim
+
+    conv_state = None
+    if model_mode == MODEL_MODE_AUTOREGRESSIVE:
+      # Retrieve state from self.cache instead of input arg
+      conv_state = self.cache.conv_state.value
+
+      # Concatenate previous state with new input
+      conv_input = jnp.concatenate([conv_state, qkv], axis=1)
+      new_conv_state = conv_input[:, -(conv_kernel_size - 1) :, :]
+
+      # Update self.cache in place
+      self.cache.conv_state.value = new_conv_state
+    else:
+      # Prefill/Train: pad with zeros
+      conv_input = jnp.pad(qkv, ((0, 0), (conv_kernel_size - 1, 0), (0, 0)))
+
+      # For prefill, we must initialize the cache for the subsequent decode steps
+      if model_mode == MODEL_MODE_PREFILL:
+        # Store the last K-1 tokens as the initial state for decoding
+        new_conv_state = conv_input[:, -(conv_kernel_size - 1) :, :]
+        self.cache.conv_state.value = new_conv_state
+      else:
+        # Just a placeholder for return
+        new_conv_state = None
+
+    # Perform the convolution.
+    conv_out = self.conv1d(conv_input)
+    # Slice the output to match the original input sequence length.
+    conv_out = conv_out[:, -seq_len:, :]
     qkv_conv = jax.nn.silu(conv_out.astype(jnp.float32)).astype(cfg.dtype)
     # q_conv shape: (B, S, key_dim), k_conv shape: (B, S, key_dim), v_conv shape: (B, S, value_dim)
     q_conv, k_conv, v_conv = jnp.split(qkv_conv, [self.key_dim, 2 * self.key_dim], axis=-1)
@@ -496,10 +532,31 @@ def __call__(self, hidden_states: Array) -> Array:
       pass  # No repeating needed for query/key in this case

     # TODO(parambole): Pass and update cache state for jax_chunk_gated_delta_rule
-    # core_attn_out shape: (B, S, H_v, D_v)
-    core_attn_out, _ = jax_chunk_gated_delta_rule(
-        query, key, value, g, beta, chunk_size=cfg.gdn_chunk_size, use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn
+    recurrent_state = None
+    if model_mode == MODEL_MODE_AUTOREGRESSIVE:
+      # Retrieve state from self.cache
+      recurrent_state = self.cache.recurrent_state.value
+
+    core_attn_out, recurrent_state_out = jax_chunk_gated_delta_rule(
+        query,
+        key,
+        value,
+        g,
+        beta,
+        chunk_size=cfg.gdn_chunk_size,
+        initial_state=recurrent_state,
+        use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn,
     )
+
+    if model_mode != MODEL_MODE_TRAIN:
+      # Update self.cache in place for both prefill and decode
+      self.cache.recurrent_state.value = recurrent_state_out
+
+    # Construct return dictionary for compatibility with decoders.py (optional but safer)
+    new_cache = {
+        "conv_state": self.cache.conv_state.value if model_mode != MODEL_MODE_TRAIN else None,
+        "recurrent_state": self.cache.recurrent_state.value if model_mode != MODEL_MODE_TRAIN else None
+    }

     # =========================================================================
     # STEP D: Final Output Stage
@@ -517,7 +574,7 @@ def __call__(self, hidden_states: Array) -> Array:
     # Final output shape: (B, S, E)
     output = self.out_proj(gated_output)

-    return output
+    return output, new_cache


 class Qwen3NextFullAttention(nnx.Module):
@@ -829,7 +886,7 @@ def __init__(
           rngs=rngs,
       )
     else:
-      self.attention = Qwen3NextGatedDeltaNet(config=cfg, dtype=cfg.dtype, rngs=rngs)
+      self.attention = Qwen3NextGatedDeltaNet(config=cfg, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs)

     # Second LayerNorm, applied before the MoE block.
     self.post_attention_layernorm = Qwen3NextRMSNorm(
@@ -853,7 +910,7 @@ def __call__(
       previous_chunk=None,
      page_state: None | page_manager.PageState = None,
      slot: None | int = None,
-      kv_cache: None | jnp.ndarray = None,
+      kv_cache: None | dict[str, Array] = None,
      attention_metadata: None | dict[str, Any] = None,
  ):
     # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
@@ -867,7 +924,7 @@ def __call__(

     # Conditionally apply either the Linear Attention or Full Attention block.
     if isinstance(self.attention, Qwen3NextFullAttention):
-      attention_output, kv_cache = cast(Qwen3NextFullAttention, self.attention)(
+      attention_output, new_kv_cache = cast(Qwen3NextFullAttention, self.attention)(
           hidden_states,
           decoder_segment_ids,
           decoder_positions,
@@ -877,9 +934,13 @@ def __call__(
           attention_metadata=attention_metadata,
       )
     elif isinstance(self.attention, Qwen3NextGatedDeltaNet):
-      attention_output = cast(Qwen3NextGatedDeltaNet, self.attention)(hidden_states)
-    else:
-      raise TypeError(f"Unexpected type for self.attention: {type(self.attention)}")
+      gdn_cache = kv_cache.get("gdn_cache") if kv_cache is not None else None
+      attention_output, _ = cast(Qwen3NextGatedDeltaNet, self.attention)(
+          hidden_states,
+          model_mode=model_mode,
+          kv_cache=None,
+      )
+      new_kv_cache = None

     # First residual connection after attention
     hidden_states = residual + attention_output
@@ -906,8 +967,7 @@ def __call__(
         layer_output,
         self.activation_axis_names,
     )
-
-    return layer_output, kv_cache
+    return layer_output, new_kv_cache


 # -----------------------------------------
@@ -1758,4 +1818,4 @@ def qwen3omni_visionprojector_as_linen(config: Config, mesh: Mesh) -> nn.Module:
 Qwen3NextScannableBlockToLinen = nnx_wrappers.to_linen_class(
     Qwen3NextScannableBlock,
     base_metadata_fn=max_initializers.variable_to_logically_partitioned,
-)
+)
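
As a summary of the new state handling, the sketch below mirrors a single autoregressive step as wired above: roll the causal-conv window, convolve over the reconstructed window, and thread the recurrent state through the delta-rule kernel. `run_gated_delta_rule` is only a stand-in for jax_chunk_gated_delta_rule (the real code also splits the conv output into query/key/value and computes g and beta first), and all shapes are illustrative.

import jax.numpy as jnp

def gdn_decode_step(qkv, conv_state, recurrent_state, conv1d, run_gated_delta_rule, conv_kernel_size):
  """One decode step, mirroring the cache handling in the diff above (illustrative only)."""
  seq_len = qkv.shape[1]  # typically 1 during decode

  # Causal-conv window: prepend the cached last (K-1) inputs, then keep the
  # trailing (K-1) inputs as the next window.
  conv_input = jnp.concatenate([conv_state, qkv], axis=1)
  new_conv_state = conv_input[:, -(conv_kernel_size - 1):, :]

  # Convolve over the reconstructed window and keep outputs for the new token(s).
  conv_out = conv1d(conv_input)[:, -seq_len:, :]

  # The delta rule starts from the cached recurrent state and returns the updated one.
  core_out, new_recurrent_state = run_gated_delta_rule(conv_out, initial_state=recurrent_state)
  return core_out, new_conv_state, new_recurrent_state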
