 from flax import nnx

 from MaxText import max_utils
-from MaxText.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED
+from MaxText.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN, MODEL_MODE_PREFILL
 from MaxText.layers import attentions
 from MaxText.layers import initializers as max_initializers
 from MaxText.layers import linears
 from MaxText.layers.attentions import Attention
 from MaxText.layers.linears import DenseGeneral, MlpBlock
 from MaxText.layers.moe import RoutedMoE
+from MaxText.inference import kvcache


 # -----------------------------------------
@@ -279,7 +280,7 @@ def scan_body(prev_state, x):
   # Transpose back to (B, S, H, D_v)
   core_attn_out = jnp.transpose(core_attn_out, (0, 2, 1, 3)).astype(initial_dtype)

-  return core_attn_out, final_state if output_final_state else None
+  return core_attn_out, final_state


 class Qwen3NextGatedDeltaNet(nnx.Module):
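The final_state returned above is the running recurrent matrix of the gated delta rule, which later hunks cache between decode steps. For reference, a single-head sketch of the per-token recurrence it accumulates (a sketch only; conventions may differ slightly from the chunked kernel in this file):

import jax.numpy as jnp

def gated_delta_step(S, q, k, v, g, beta):
  # S: (D_k, D_v) running state; q, k: (D_k,); v: (D_v,); g: log-decay scalar; beta: gate scalar.
  S = jnp.exp(g) * S            # decay the previous state
  delta = beta * (v - k @ S)    # delta-rule prediction error for this token
  S = S + jnp.outer(k, delta)   # rank-1 update toward the new value
  out = q @ S                   # read out with the query: (D_v,)
  return out, S                 # S is what final_state carries across steps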
@@ -310,7 +311,7 @@ class Qwen3NextGatedDeltaNet(nnx.Module):
     dtype: The datatype of the computation.
   """

-  def __init__(self, config: Config, dtype: DType = jnp.float32, *, rngs: nnx.Rngs):
+  def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = MODEL_MODE_TRAIN, *, rngs: nnx.Rngs):
     self.config = config
     self.dtype = dtype
     cfg = self.config
@@ -326,6 +327,17 @@ def __init__(self, config: Config, dtype: DType = jnp.float32, *, rngs: nnx.Rngs
     conv_kernel_size = cfg.gdn_conv_kernel_dim
     self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads

+    if model_mode != MODEL_MODE_TRAIN:
+      self.cache = kvcache.GatedDeltaNetCache(
+          batch=config.per_device_batch_size,  # Or appropriate batch dim
+          num_heads=self.num_v_heads,
+          k_head_dim=self.head_k_dim,
+          v_head_dim=self.head_v_dim,
+          conv_kernel_size=self.config.gdn_conv_kernel_dim,
+          conv_dim=conv_dim,  # Make sure conv_dim is calculated before this
+          dtype=dtype,
+      )
+
     # Submodule instantiations
     self.in_proj_qkvz = linears.DenseGeneral(
         in_features_shape=in_features,
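The definition of kvcache.GatedDeltaNetCache is not part of this diff; judging from how it is used below (mutable .conv_state and .recurrent_state values), a minimal container along these lines would satisfy that interface. This is a sketch under that assumption, not the actual MaxText.inference.kvcache implementation:

import jax.numpy as jnp
from flax import nnx

class GatedDeltaNetCacheSketch(nnx.Module):
  """Hypothetical stand-in for kvcache.GatedDeltaNetCache; shapes inferred from this layer."""

  def __init__(self, batch, num_heads, k_head_dim, v_head_dim, conv_kernel_size, conv_dim, dtype=jnp.float32):
    # Rolling window of the last (kernel_size - 1) conv inputs: (B, K-1, conv_dim).
    self.conv_state = nnx.Cache(jnp.zeros((batch, conv_kernel_size - 1, conv_dim), dtype=dtype))
    # Recurrent delta-rule state, one (D_k, D_v) matrix per value head: (B, H_v, D_k, D_v).
    self.recurrent_state = nnx.Cache(jnp.zeros((batch, num_heads, k_head_dim, v_head_dim), dtype=dtype))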
@@ -381,7 +393,7 @@ def a_log_init(key, shape, dtype=jnp.float32):
         rngs=rngs,
     )

-  def __call__(self, hidden_states: Array) -> Array:
+  def __call__(self, hidden_states: Array, model_mode: str = MODEL_MODE_TRAIN, kv_cache=None, **kwargs) -> tuple[Array, dict[str, Array]]:
     # hidden_states: (B, S, E)
     cfg = self.config
     batch, seq_len, _ = hidden_states.shape
@@ -455,12 +467,36 @@ def __call__(self, hidden_states: Array) -> Array:
     # conv_dim = 2 * K_dim + V_dim
     # qkv: (B, S, 2 * K_dim + V_dim)
     qkv = jnp.concatenate([q, k, v], axis=-1)
-
-    # TODO(parambole): Implement caching logic for conv_state and recurrent_state
-
-    # Input to conv_layer should be (B, S, C)
-    # qkv_conv shape: (B, S, conv_dim)
-    conv_out = self.conv1d(qkv)
+    conv_kernel_size = self.config.gdn_conv_kernel_dim
+
+    conv_state = None
+    if model_mode == MODEL_MODE_AUTOREGRESSIVE:
+      # Retrieve the state from self.cache instead of an input argument.
+      conv_state = self.cache.conv_state.value
+
+      # Concatenate the previous state with the new input.
+      conv_input = jnp.concatenate([conv_state, qkv], axis=1)
+      new_conv_state = conv_input[:, -(conv_kernel_size - 1):, :]
+
+      # Update self.cache in place.
+      self.cache.conv_state.value = new_conv_state
+    else:
+      # Prefill/Train: left-pad with zeros for the causal convolution.
+      conv_input = jnp.pad(qkv, ((0, 0), (conv_kernel_size - 1, 0), (0, 0)))
+
+      # For prefill, initialize the cache for the subsequent decode steps.
+      if model_mode == MODEL_MODE_PREFILL:
+        # Store the last K-1 tokens as the initial state for decoding.
+        new_conv_state = conv_input[:, -(conv_kernel_size - 1):, :]
+        self.cache.conv_state.value = new_conv_state
+      else:
+        # Training: no cache to carry; placeholder for the return value.
+        new_conv_state = None
+
+    # Perform the convolution.
+    conv_out = self.conv1d(conv_input)
+    # Slice the output to match the original input sequence length.
+    conv_out = conv_out[:, -seq_len:, :]
     qkv_conv = jax.nn.silu(conv_out.astype(jnp.float32)).astype(cfg.dtype)
     # q_conv shape: (B, S, key_dim), k_conv shape: (B, S, key_dim), v_conv shape: (B, S, value_dim)
     q_conv, k_conv, v_conv = jnp.split(qkv_conv, [self.key_dim, 2 * self.key_dim], axis=-1)
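The branch above follows the usual causal-conv caching pattern: only the last conv_kernel_size - 1 inputs need to be kept, because together with the next token they are exactly what the kernel needs for one decode step. A toy, self-contained illustration of that pattern, with hypothetical sizes independent of this module's conv layer:

import jax.numpy as jnp

K = 4                     # conv kernel size, so the cache holds K - 1 = 3 positions
batch, conv_dim = 2, 8

# Prefill: left-pad the prompt with K - 1 zeros, then keep the last K - 1
# prompt positions as the initial decode-time conv state.
prompt = jnp.ones((batch, 5, conv_dim))                        # (B, S, C)
prefill_input = jnp.pad(prompt, ((0, 0), (K - 1, 0), (0, 0)))  # (B, S + K - 1, C)
conv_state = prefill_input[:, -(K - 1):, :]                    # (B, K - 1, C)

# Decode: one new token per step; cached inputs + new token give the K
# positions the kernel needs, and the window then rolls forward by one.
new_token = jnp.ones((batch, 1, conv_dim))
decode_input = jnp.concatenate([conv_state, new_token], axis=1)  # (B, K, C)
conv_state = decode_input[:, -(K - 1):, :]                       # updated cache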
@@ -496,10 +532,31 @@ def __call__(self, hidden_states: Array) -> Array:
       pass  # No repeating needed for query/key in this case

     # TODO(parambole): Pass and update cache state for jax_chunk_gated_delta_rule
-    # core_attn_out shape: (B, S, H_v, D_v)
-    core_attn_out, _ = jax_chunk_gated_delta_rule(
-        query, key, value, g, beta, chunk_size=cfg.gdn_chunk_size, use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn
+    recurrent_state = None
+    if model_mode == MODEL_MODE_AUTOREGRESSIVE:
+      # Retrieve the state from self.cache.
+      recurrent_state = self.cache.recurrent_state.value
+
+    core_attn_out, recurrent_state_out = jax_chunk_gated_delta_rule(
+        query,
+        key,
+        value,
+        g,
+        beta,
+        chunk_size=cfg.gdn_chunk_size,
+        initial_state=recurrent_state,
+        use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn,
     )
+
+    if model_mode != MODEL_MODE_TRAIN:
+      # Update self.cache in place for both prefill and decode.
+      self.cache.recurrent_state.value = recurrent_state_out
+
+    # Construct a return dictionary for compatibility with decoders.py.
+    new_cache = {
+        "conv_state": self.cache.conv_state.value if model_mode != MODEL_MODE_TRAIN else None,
+        "recurrent_state": self.cache.recurrent_state.value if model_mode != MODEL_MODE_TRAIN else None,
+    }

     # =========================================================================
     # STEP D: Final Output Stage
@@ -517,7 +574,7 @@ def __call__(self, hidden_states: Array) -> Array:
     # Final output shape: (B, S, E)
     output = self.out_proj(gated_output)

-    return output
+    return output, new_cache


 class Qwen3NextFullAttention(nnx.Module):
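With this change the layer reads and writes self.cache in place and also returns the same values in a dict. A rough sketch of the intended prefill-then-decode flow, using hypothetical names and assuming the module was built with a non-training model_mode so that self.cache exists:

# Hypothetical driver loop; `gdn` is a Qwen3NextGatedDeltaNet instance and
# `emb` stands for cfg.emb_dim (both assumed, not defined in this diff).
prompt_hidden = jnp.zeros((1, 16, emb), dtype=jnp.bfloat16)  # (B, S, E) prompt activations

# Prefill: seeds cache.conv_state and cache.recurrent_state from the prompt.
out, cache = gdn(prompt_hidden, model_mode=MODEL_MODE_PREFILL)

# Decode: one position at a time; each call consumes and overwrites the cache.
for _ in range(4):
  step_hidden = jnp.zeros((1, 1, emb), dtype=jnp.bfloat16)   # (B, 1, E) current-token activations
  out, cache = gdn(step_hidden, model_mode=MODEL_MODE_AUTOREGRESSIVE)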
@@ -829,7 +886,7 @@ def __init__(
           rngs=rngs,
       )
     else:
-      self.attention = Qwen3NextGatedDeltaNet(config=cfg, dtype=cfg.dtype, rngs=rngs)
+      self.attention = Qwen3NextGatedDeltaNet(config=cfg, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs)

     # Second LayerNorm, applied before the MoE block.
     self.post_attention_layernorm = Qwen3NextRMSNorm(
@@ -853,7 +910,7 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
-      kv_cache: None | jnp.ndarray = None,
+      kv_cache: None | dict[str, Array] = None,
       attention_metadata: None | dict[str, Any] = None,
   ):
     # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
@@ -867,7 +924,7 @@

     # Conditionally apply either the Linear Attention or Full Attention block.
     if isinstance(self.attention, Qwen3NextFullAttention):
-      attention_output, kv_cache = cast(Qwen3NextFullAttention, self.attention)(
+      attention_output, new_kv_cache = cast(Qwen3NextFullAttention, self.attention)(
          hidden_states,
          decoder_segment_ids,
          decoder_positions,
@@ -877,9 +934,13 @@
          attention_metadata=attention_metadata,
       )
     elif isinstance(self.attention, Qwen3NextGatedDeltaNet):
-      attention_output = cast(Qwen3NextGatedDeltaNet, self.attention)(hidden_states)
-    else:
-      raise TypeError(f"Unexpected type for self.attention: {type(self.attention)}")
+      gdn_cache = kv_cache.get("gdn_cache") if kv_cache is not None else None
+      attention_output, _ = cast(Qwen3NextGatedDeltaNet, self.attention)(
+          hidden_states,
+          model_mode=model_mode,
+          kv_cache=None,
+      )
+      new_kv_cache = None

     # First residual connection after attention
     hidden_states = residual + attention_output
@@ -906,8 +967,7 @@
          layer_output,
          self.activation_axis_names,
       )
-
-    return layer_output, kv_cache
+    return layer_output, new_kv_cache


 # -----------------------------------------
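Each decoder layer now returns (layer_output, new_kv_cache) and tolerates a tuple input, so a layer stack can carry the pair uniformly. A rough sketch of how a caller might thread it through several layers (a hypothetical wrapper with an assumed positional order, not the actual MaxText decoder loop):

def run_decoder_layers(layers, hidden_states, decoder_segment_ids, decoder_positions, model_mode):
  # Hypothetical helper: threads (hidden_states, kv_cache) through every layer.
  kv_caches = []
  for layer in layers:
    hidden_states, kv_cache = layer(
        hidden_states,
        decoder_segment_ids,
        decoder_positions,
        deterministic=True,
        model_mode=model_mode,
    )
    kv_caches.append(kv_cache)
  return hidden_states, kv_caches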
@@ -1758,4 +1818,4 @@ def qwen3omni_visionprojector_as_linen(config: Config, mesh: Mesh) -> nn.Module:
 Qwen3NextScannableBlockToLinen = nnx_wrappers.to_linen_class(
     Qwen3NextScannableBlock,
     base_metadata_fn=max_initializers.variable_to_logically_partitioned,
-)
+)