From 15c1aa14705198e0f566646d48d61cb4e705c305 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Tue, 23 Dec 2025 10:11:10 +0000 Subject: [PATCH] Add class for GatedDeltaNetCache to kvcache.py Added support for GDN to maxengine but NNX linen incompatible Merged code from other branch Qwen3-next Modified to accept dynamic model mode and work with maxengine changes Fix GDN init with model_mode Do same cache update during packed prefill as normal prefill Convert batch to int in init for state remove new_cache and resolve comments from pr fix merge conflicts use maxtext instead of MaxText typo in import removed testcases remove circular import Add support for decoding with pdb > 1 Fix slicing bug when using batch_size > 1 Fix linter issues Fix linter issues and flatten conditionals for pylint uncommit pre-commit check --- src/MaxText/maxengine.py | 24 +++++++- src/maxtext/inference/kvcache.py | 77 +++++++++++++++++++++++++- src/maxtext/layers/decoders.py | 19 +++++-- src/maxtext/models/qwen3.py | 95 ++++++++++++++++++++++++-------- 4 files changed, 187 insertions(+), 28 deletions(-) diff --git a/src/MaxText/maxengine.py b/src/MaxText/maxengine.py index d67cddf806..75f040750d 100644 --- a/src/MaxText/maxengine.py +++ b/src/MaxText/maxengine.py @@ -1146,6 +1146,9 @@ def copy(path, partial_cache, full_cache, annotations): "cached_prefill_value_scale", ]: full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx) + elif path_key in ["recurrent_state", "conv_state"]: + # Direct update for fixed-size linear attention states + full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx) else: raise ValueError(f"We don't have a strategy for inserting {path_key}") @@ -1258,6 +1261,10 @@ def copy(path, partial_cache, full_cache, annotations): "cached_prefill_value_scale", ]: return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx) + elif path_key in ["recurrent_state", "conv_state"]: + # For linear attention, the state is fixed size. We simply copy the result + # from the prefill step (partial_cache) into the decode state (full_cache). + return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx) else: raise ValueError(f"We don't have a strategy for inserting {path_key}") @@ -1447,6 +1454,15 @@ def copy(path, partial_cache, full_cache, annotations): partial_cache = jax.lax.dynamic_slice(partial_cache, start_indices, slice_size) return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx) + elif path_key in ["recurrent_state", "conv_state"]: + # SSM states are the "final state" after prefill, so we just overwrite the slot. + # We don't need to slice by sequence length like we do for KV cache. + if num_prompts > 1: + raise NotImplementedError( + "Packed prefill is currently incompatible with linear attention states (GDN). " + "Prompt memory will bleed into adjacent prompts. Please disable packed prefill." 
+ ) + return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx) else: raise ValueError(f"We don't have a strategy for inserting {path_key}") @@ -1660,7 +1676,13 @@ def initialize(): def is_lp(k): return isinstance(k, flax.linen.spmd.LogicallyPartitioned) - self.kv_cache_annotations_named = jax.tree_util.tree_map(lambda x: tuple(x.names), cache, is_leaf=is_lp) + self.kv_cache_annotations_named = jax.tree_util.tree_map( + lambda x: tuple(x.logical_axes) + if hasattr(x, "logical_axes") + else (tuple(x.names) if hasattr(x, "names") else ()), + cache, + is_leaf=is_lp, + ) zeroed = max_utils.unbox_logicallypartioned(init_state) return zeroed diff --git a/src/maxtext/inference/kvcache.py b/src/maxtext/inference/kvcache.py index 0ac2fe1098..c4222adf4d 100644 --- a/src/maxtext/inference/kvcache.py +++ b/src/maxtext/inference/kvcache.py @@ -230,7 +230,11 @@ def kv_cache_as_linen( ) -class KVCache(nnx.Module): +class BaseCache(nnx.Module): + """Abstract base class for Caches.""" + + +class KVCache(BaseCache): """Implementation of the KVCache.""" def __init__( @@ -290,6 +294,7 @@ def __init__( use_chunked_prefill: Whether to use chunked prefill. rngs: The random number generators for initialization. """ + super().__init__() self.max_prefill_length = max_prefill_length self.max_target_length = max_target_length self.batch = batch @@ -844,6 +849,76 @@ def __call__( raise ValueError(f"Model Mode isn't supported! {model_mode=}") +class GatedDeltaNetCache(BaseCache): + """Cache for Linear Attention (Gated Delta Net). + + Stores the fixed-size recurrent state and the sliding window state for convolution. + """ + + def __init__( + self, + batch: int, + num_heads: int, + k_head_dim: int, + v_head_dim: int, + conv_kernel_size: int, + conv_dim: int, + dtype: DType, + cache_batch_axis_name: str = CACHE_BATCH, + cache_heads_axis_name: str = CACHE_HEADS, + ): + super().__init__() + self.batch = batch + self.dtype = dtype + + # 1. Recurrent State (S) for the Delta Rule + # Shape: [Batch, Heads, K_Dim, V_Dim] + # We maintain the running state matrix. + self.recurrent_state = nnx.Cache( + jnp.zeros((int(batch), num_heads, k_head_dim, v_head_dim), dtype=dtype), + # Sharding: Batch, Heads, None (K), None (V) + sharding=(cache_batch_axis_name, cache_heads_axis_name, None, None), + ) + + # 2. Convolution State for the 1D Conv + # Shape: [Batch, Kernel_Size - 1, Conv_Dim] + # We store the last (K-1) inputs to perform the sliding window conv during decoding. 
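+    # At prefill/decode time the layer concatenates this buffer with the incoming
+    # qkv projections, runs the causal conv over the combined window, and keeps the
+    # trailing (K-1) timesteps as the next conv_state (see Qwen3NextGatedDeltaNet.__call__).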
+ self.conv_state = nnx.Cache( + jnp.zeros((int(batch), conv_kernel_size - 1, conv_dim), dtype=dtype), + # Sharding: Batch, None (Time), None (Dim) + sharding=(cache_batch_axis_name, None, None), + ) + + def __call__(self): + """Returns the cache variables for the layer to use.""" + return self + + +def gated_delta_net_cache_as_linen( + *, + batch: int, + num_heads: int, + head_dim: int, + conv_kernel_size: int, + conv_dim: int, + dtype: DType, + name: str | None = None, +): + """Initializes the GatedDeltaNetCache and returns it as a Linen module.""" + return nnx_wrappers.to_linen( + GatedDeltaNetCache, + batch=batch, + num_heads=num_heads, + head_dim=head_dim, + conv_kernel_size=conv_kernel_size, + conv_dim=conv_dim, + dtype=dtype, + metadata_fn=variable_to_logically_partitioned, + name=name, + abstract_init=False, + ) + + def mla_kv_cache_as_linen( *, max_prefill_length: int, diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 8433fa633f..393cb04a67 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -904,6 +904,14 @@ def __call__( } if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT: layer_kwargs = {"layer_idx": lyr} + kv_cache = None + if kv_caches is not None and cfg.decoder_block != DecoderBlockType.QWEN3_NEXT: + kv_cache = kv_caches[lyr] + elif kv_caches is not None and cfg.decoder_block == DecoderBlockType.QWEN3_NEXT: + # For Qwen3Next, kv_caches is a dictionary of lists of caches. + if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: + kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr]) + if cfg.decoder_block == DecoderBlockType.GPT_OSS: layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} if cfg.decoder_block == DecoderBlockType.OLMO3: @@ -911,8 +919,7 @@ def __call__( layer = RemattedBlockLayer( config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs ) - kv_cache = kv_caches[lyr] if kv_caches is not None else None - y, kv_cache = layer( + y, returned_cache = layer( y, decoder_segment_ids, decoder_positions, @@ -925,8 +932,12 @@ def __call__( attention_metadata=attention_metadata, **layer_call_kwargs, ) - if kv_caches is not None and kv_cache is not None: - kv_caches[lyr] = kv_cache + if kv_caches is not None and returned_cache is not None: + if cfg.decoder_block != DecoderBlockType.QWEN3_NEXT: + kv_caches[lyr] = returned_cache + elif (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: + kv_caches["key_cache"][lyr] = returned_cache[0] + kv_caches["value_cache"][lyr] = returned_cache[1] assert isinstance(y, jax.Array) diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index 8da1162542..5da96d5bf0 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -43,7 +43,10 @@ from maxtext.layers.moe import RoutedMoE from maxtext.layers.initializers import nd_dense_init, variable_to_logically_partitioned from maxtext.inference import page_manager + from maxtext.utils import max_utils +from maxtext.inference import page_manager, kvcache + # ----------------------------------------- # Qwen3-Next Layer Implementations @@ -379,7 +382,7 @@ class Qwen3NextGatedDeltaNet(nnx.Module): 2. output = Linear_out(y) """ - def __init__(self, config: Config, *, rngs: nnx.Rngs): + def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = MODEL_MODE_TRAIN, *, rngs: nnx.Rngs): """ Args: config: MaxText configuration object. 
@@ -399,6 +402,17 @@ def __init__(self, config: Config, *, rngs: nnx.Rngs): conv_kernel_size = cfg.gdn_conv_kernel_dim self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads + if model_mode != MODEL_MODE_TRAIN: + self.cache = kvcache.GatedDeltaNetCache( + batch=config.per_device_batch_size, + num_heads=self.num_v_heads, + k_head_dim=self.head_k_dim, + v_head_dim=self.head_v_dim, + conv_kernel_size=self.config.gdn_conv_kernel_dim, + conv_dim=conv_dim, + dtype=dtype, + ) + # Submodule instantiations self.in_proj_qkvz = DenseGeneral( in_features_shape=in_features, @@ -458,7 +472,9 @@ def a_log_init(key, shape, dtype=jnp.float32): rngs=rngs, ) - def __call__(self, hidden_states: Array) -> Array: + def __call__( + self, hidden_states: Array, model_mode: str = MODEL_MODE_TRAIN, kv_cache=None, **kwargs + ) -> tuple[Array, dict[str, Array]]: # hidden_states: (B, S, E) cfg = self.config batch, seq_len, _ = hidden_states.shape @@ -529,15 +545,35 @@ def __call__(self, hidden_states: Array) -> Array: # ========================================================================= # STEP B: 1D Convolution # ========================================================================= - # conv_dim = 2 * K_dim + V_dim - # qkv: (B, S, 2 * K_dim + V_dim) qkv = jnp.concatenate([q, k, v], axis=-1) + batch, seq_len, _ = qkv.shape + conv_kernel_size = self.config.gdn_conv_kernel_dim + + conv_state = None + if model_mode != MODEL_MODE_TRAIN: + # Retrieve state from self.cache + conv_state = self.cache.conv_state.value + + if conv_state.shape[0] != batch: + if conv_state.shape[0] == 1: + conv_state = jnp.broadcast_to(conv_state, (batch,) + conv_state.shape[1:]) + else: + conv_state = conv_state[:batch] + + # Concatenate previous state with new input + conv_input = jnp.concatenate([conv_state, qkv], axis=1) + new_conv_state = conv_input[:, -(conv_kernel_size - 1) :, :] + + # Update self.cache in place + self.cache.conv_state.value = new_conv_state + else: + # Train: pad with zeros + conv_input = jnp.pad(qkv, ((0, 0), (conv_kernel_size - 1, 0), (0, 0))) - # TODO(parambole): Implement caching logic for conv_state and recurrent_state - - # Input to conv_layer should be (B, S, C) - # qkv_conv shape: (B, S, conv_dim) - conv_out = self.conv1d(qkv) + # Perform the convolution. + conv_out = self.conv1d(conv_input) + # Slice the output to match the original input sequence length. + conv_out = conv_out[:, -seq_len:, :] qkv_conv = jax.nn.silu(conv_out.astype(jnp.float32)).astype(cfg.dtype) # q_conv shape: (B, S, key_dim), k_conv shape: (B, S, key_dim), v_conv shape: (B, S, value_dim) q_conv, k_conv, v_conv = jnp.split(qkv_conv, [self.key_dim, 2 * self.key_dim], axis=-1) @@ -568,22 +604,35 @@ def __call__(self, hidden_states: Array) -> Array: # key shape after repeat: (B, S, H_v, D_k) key = jnp.repeat(key, repeats, axis=2) elif self.num_k_heads > self.num_v_heads and self.num_k_heads % self.num_v_heads == 0: - # This case might occur if key/query heads are more than value heads. 
- pass # No repeating needed for query/key in this case + pass - # TODO(parambole): Pass and update cache state for jax_chunk_gated_delta_rule - # core_attn_out shape: (B, S, H_v, D_v) - core_attn_out, _ = jax_chunk_gated_delta_rule( + recurrent_state = None + if model_mode != MODEL_MODE_TRAIN: + # Retrieve state from self.cache + recurrent_state = self.cache.recurrent_state.value + + if recurrent_state.shape[0] != batch: + if recurrent_state.shape[0] == 1: + recurrent_state = jnp.broadcast_to(recurrent_state, (batch,) + recurrent_state.shape[1:]) + else: + recurrent_state = recurrent_state[:batch] + + core_attn_out, recurrent_state_out = jax_chunk_gated_delta_rule( query, key, value, g, beta, chunk_size=cfg.gdn_chunk_size, + initial_state=recurrent_state, use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn, compute_dtype=cfg.dtype, ) + if model_mode != MODEL_MODE_TRAIN: + # Update self.cache in place for both prefill and decode + self.cache.recurrent_state.value = recurrent_state_out + # ========================================================================= # STEP D: Final Output Stage # ========================================================================= @@ -913,7 +962,7 @@ def __init__( rngs=rngs, ) else: - self.attention = Qwen3NextGatedDeltaNet(config=cfg, rngs=rngs) + self.attention = Qwen3NextGatedDeltaNet(config=cfg, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs) # Second LayerNorm, applied before the MoE block. self.post_attention_layernorm = Qwen3NextRMSNorm( @@ -937,7 +986,7 @@ def __call__( previous_chunk=None, page_state: None | page_manager.PageState = None, slot: None | int = None, - kv_cache: None | jnp.ndarray = None, + kv_cache: None | dict[str, Array] = None, attention_metadata: None | dict[str, Any] = None, ): # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) @@ -951,7 +1000,7 @@ def __call__( # Conditionally apply either the Linear Attention or Full Attention block. if isinstance(self.attention, Qwen3NextFullAttention): - attention_output, kv_cache = cast(Qwen3NextFullAttention, self.attention)( + attention_output, new_kv_cache = cast(Qwen3NextFullAttention, self.attention)( hidden_states, decoder_segment_ids, decoder_positions, @@ -960,10 +1009,13 @@ def __call__( kv_cache=kv_cache, attention_metadata=attention_metadata, ) - elif isinstance(self.attention, Qwen3NextGatedDeltaNet): - attention_output = cast(Qwen3NextGatedDeltaNet, self.attention)(hidden_states) else: - raise TypeError(f"Unexpected type for self.attention: {type(self.attention)}") + attention_output = cast(Qwen3NextGatedDeltaNet, self.attention)( + hidden_states, + model_mode=model_mode, + kv_cache=None, + ) + new_kv_cache = None # First residual connection after attention hidden_states = residual + attention_output @@ -990,8 +1042,7 @@ def __call__( layer_output, self.activation_axis_names, ) - - return layer_output, kv_cache + return layer_output, new_kv_cache # -----------------------------------------
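
Illustration (not part of the patch): the conv_state handling in
Qwen3NextGatedDeltaNet.__call__ reduces to the sliding-window bookkeeping
sketched below. The constants and helper names (prefill_window, decode_window)
are assumptions for illustration; the real layer uses self.conv1d and
cfg.gdn_conv_kernel_dim, and recurrent_state follows the same
read / compute / write-back pattern through the initial_state argument of
jax_chunk_gated_delta_rule.

# Minimal JAX sketch of the conv_state sliding window (illustrative only).
import jax.numpy as jnp

K = 4                    # conv kernel size (cfg.gdn_conv_kernel_dim in the patch)
B, C = 2, 16             # toy batch size and conv channel dim

def prefill_window(qkv):                       # qkv: (B, S, C)
  # Prefill/train: left-pad with zeros so the conv stays causal, then keep the
  # last (K - 1) timesteps as the state handed to decode.
  conv_input = jnp.pad(qkv, ((0, 0), (K - 1, 0), (0, 0)))
  conv_state = conv_input[:, -(K - 1):, :]
  return conv_input, conv_state

def decode_window(conv_state, qkv_t):          # conv_state: (B, K-1, C), qkv_t: (B, 1, C)
  # Decode: prepend the cached window, convolve over all K timesteps, and roll
  # the window forward by keeping the trailing (K - 1) rows.
  conv_input = jnp.concatenate([conv_state, qkv_t], axis=1)
  new_conv_state = conv_input[:, -(K - 1):, :]
  return conv_input, new_conv_state

# conv_input is what self.conv1d consumes; only conv_out[:, -seq_len:, :] is kept,
# matching the slice in the patch.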
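
The three insert paths in maxengine.py treat recurrent_state and conv_state
identically: because both states are fixed-size summaries of the whole prompt,
the prefill result is copied into the chosen decode slot without any
sequence-length slicing. A minimal sketch of that update, assuming a
(global_batch, heads, k_dim, v_dim) decode-state layout with the batch on axis 0:

# Illustrative only: inserting a fixed-size GDN state into one decode slot.
import jax
import jax.numpy as jnp

GLOBAL_BATCH, HEADS, DK, DV = 4, 2, 8, 8                 # assumed decode-state layout
full_state = jnp.zeros((GLOBAL_BATCH, HEADS, DK, DV))    # one row per live request
partial_state = jnp.ones((1, HEADS, DK, DV))             # final state from a batch-1 prefill

slot, batch_axis = 2, 0
full_state = jax.lax.dynamic_update_index_in_dim(full_state, partial_state, slot, batch_axis)
# Unlike cached_prefill_key/value, no dynamic_slice over the sequence axis is
# needed: the state already summarizes the entire prompt.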