24 changes: 23 additions & 1 deletion src/MaxText/maxengine.py
@@ -1146,6 +1146,9 @@ def copy(path, partial_cache, full_cache, annotations):
"cached_prefill_value_scale",
]:
full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
elif path_key in ["recurrent_state", "conv_state"]:
# Direct update for fixed-size linear attention states
full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
else:
raise ValueError(f"We don't have a strategy for inserting {path_key}")

@@ -1258,6 +1261,10 @@ def copy(path, partial_cache, full_cache, annotations):
"cached_prefill_value_scale",
]:
return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
elif path_key in ["recurrent_state", "conv_state"]:
# For linear attention, the state is fixed size. We simply copy the result
# from the prefill step (partial_cache) into the decode state (full_cache).
return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
else:
raise ValueError(f"We don't have a strategy for inserting {path_key}")

@@ -1447,6 +1454,15 @@ def copy(path, partial_cache, full_cache, annotations):
partial_cache = jax.lax.dynamic_slice(partial_cache, start_indices, slice_size)

return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
elif path_key in ["recurrent_state", "conv_state"]:
# SSM states are the "final state" after prefill, so we just overwrite the slot.
# We don't need to slice by sequence length like we do for KV cache.
if num_prompts > 1:
raise NotImplementedError(
"Packed prefill is currently incompatible with linear attention states (GDN). "
"Prompt memory will bleed into adjacent prompts. Please disable packed prefill."
)
return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
else:
raise ValueError(f"We don't have a strategy for inserting {path_key}")
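Note (illustrative, not part of the change): the three insertion hunks above reduce to the same operation for the linear-attention states. A minimal sketch of that operation with hypothetical shapes and no MaxText imports:

import jax
import jax.numpy as jnp

# Decode state holds one fixed-size recurrent state per batch slot.
full_cache = jnp.zeros((8, 16, 128, 128))    # (decode_batch, heads, d_k, d_v)
partial_cache = jnp.ones((1, 16, 128, 128))  # prefill result for a single prompt
slot, batch_axis = 3, 0

# Overwrite slot 3 along the batch axis; no sequence-length slicing is needed
# because the state does not grow with the prompt length.
full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_axis)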

@@ -1660,7 +1676,13 @@ def initialize():
def is_lp(k):
return isinstance(k, flax.linen.spmd.LogicallyPartitioned)

self.kv_cache_annotations_named = jax.tree_util.tree_map(lambda x: tuple(x.names), cache, is_leaf=is_lp)
self.kv_cache_annotations_named = jax.tree_util.tree_map(
lambda x: tuple(x.logical_axes)
if hasattr(x, "logical_axes")
else (tuple(x.names) if hasattr(x, "names") else ()),
cache,
is_leaf=is_lp,
)
zeroed = max_utils.unbox_logicallypartioned(init_state)
return zeroed

77 changes: 76 additions & 1 deletion src/maxtext/inference/kvcache.py
@@ -230,7 +230,11 @@ def kv_cache_as_linen(
)


class KVCache(nnx.Module):
class BaseCache(nnx.Module):
"""Abstract base class for Caches."""


class KVCache(BaseCache):
"""Implementation of the KVCache."""

def __init__(
@@ -290,6 +294,7 @@ def __init__(
use_chunked_prefill: Whether to use chunked prefill.
rngs: The random number generators for initialization.
"""
super().__init__()
self.max_prefill_length = max_prefill_length
self.max_target_length = max_target_length
self.batch = batch
@@ -844,6 +849,76 @@ def __call__(
raise ValueError(f"Model Mode isn't supported! {model_mode=}")


class GatedDeltaNetCache(BaseCache):
"""Cache for Linear Attention (Gated Delta Net).

Stores the fixed-size recurrent state and the sliding window state for convolution.
"""

def __init__(
self,
batch: int,
num_heads: int,
k_head_dim: int,
v_head_dim: int,
conv_kernel_size: int,
conv_dim: int,
dtype: DType,
cache_batch_axis_name: str = CACHE_BATCH,
cache_heads_axis_name: str = CACHE_HEADS,
):
super().__init__()
self.batch = batch
self.dtype = dtype

# 1. Recurrent State (S) for the Delta Rule
# Shape: [Batch, Heads, K_Dim, V_Dim]
# We maintain the running state matrix.
self.recurrent_state = nnx.Cache(
jnp.zeros((int(batch), num_heads, k_head_dim, v_head_dim), dtype=dtype),
# Sharding: Batch, Heads, None (K), None (V)
sharding=(cache_batch_axis_name, cache_heads_axis_name, None, None),
)

# 2. Convolution State for the 1D Conv
# Shape: [Batch, Kernel_Size - 1, Conv_Dim]
# We store the last (K-1) inputs to perform the sliding window conv during decoding.
self.conv_state = nnx.Cache(
jnp.zeros((int(batch), conv_kernel_size - 1, conv_dim), dtype=dtype),
# Sharding: Batch, None (Time), None (Dim)
sharding=(cache_batch_axis_name, None, None),
)

def __call__(self):
"""Returns the cache variables for the layer to use."""
return self
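Note (illustrative, not part of the change): a shape sketch of the two buffers this class holds, under assumed hyperparameters (values are hypothetical, not taken from any config):

import jax.numpy as jnp

batch, num_heads, k_head_dim, v_head_dim = 4, 16, 128, 128
conv_kernel_size, conv_dim = 4, 768  # conv_dim = 2 * key_dim + value_dim in the layer

# Running delta-rule memory: one (d_k x d_v) matrix per head, independent of sequence length.
recurrent_state = jnp.zeros((batch, num_heads, k_head_dim, v_head_dim))

# Sliding-window history for the causal 1D conv: only the last (kernel - 1) inputs are kept.
conv_state = jnp.zeros((batch, conv_kernel_size - 1, conv_dim))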


def gated_delta_net_cache_as_linen(
*,
batch: int,
num_heads: int,
k_head_dim: int,
v_head_dim: int,
conv_kernel_size: int,
conv_dim: int,
dtype: DType,
name: str | None = None,
):
"""Initializes the GatedDeltaNetCache and returns it as a Linen module."""
return nnx_wrappers.to_linen(
GatedDeltaNetCache,
batch=batch,
num_heads=num_heads,
k_head_dim=k_head_dim,
v_head_dim=v_head_dim,
conv_kernel_size=conv_kernel_size,
conv_dim=conv_dim,
dtype=dtype,
metadata_fn=variable_to_logically_partitioned,
name=name,
abstract_init=False,
)


def mla_kv_cache_as_linen(
*,
max_prefill_length: int,
19 changes: 15 additions & 4 deletions src/maxtext/layers/decoders.py
@@ -904,15 +904,22 @@ def __call__(
}
if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT:
layer_kwargs = {"layer_idx": lyr}
kv_cache = None
if kv_caches is not None and cfg.decoder_block != DecoderBlockType.QWEN3_NEXT:
kv_cache = kv_caches[lyr]
elif kv_caches is not None and cfg.decoder_block == DecoderBlockType.QWEN3_NEXT:
# For Qwen3Next, kv_caches is a dictionary of lists of caches.
if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr])

if cfg.decoder_block == DecoderBlockType.GPT_OSS:
layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)}
if cfg.decoder_block == DecoderBlockType.OLMO3:
layer_kwargs = {"attention_type": olmo3.get_attention_type(layer_id=lyr)}
layer = RemattedBlockLayer(
config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
)
kv_cache = kv_caches[lyr] if kv_caches is not None else None
y, kv_cache = layer(
y, returned_cache = layer(
y,
decoder_segment_ids,
decoder_positions,
@@ -925,8 +932,12 @@
attention_metadata=attention_metadata,
**layer_call_kwargs,
)
if kv_caches is not None and kv_cache is not None:
kv_caches[lyr] = kv_cache
if kv_caches is not None and returned_cache is not None:
if cfg.decoder_block != DecoderBlockType.QWEN3_NEXT:
kv_caches[lyr] = returned_cache
elif (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
kv_caches["key_cache"][lyr] = returned_cache[0]
kv_caches["value_cache"][lyr] = returned_cache[1]

assert isinstance(y, jax.Array)
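Note (illustrative, not part of the change): the per-layer cache routing above can be checked with a small stand-alone loop, assuming a hypothetical cycle interval of 4:

inhomogeneous_layer_cycle_interval = 4  # hypothetical value

for lyr in range(8):
    if (lyr + 1) % inhomogeneous_layer_cycle_interval == 0:
        print(f"layer {lyr}: full attention -> reads/writes key_cache[{lyr}] and value_cache[{lyr}]")
    else:
        print(f"layer {lyr}: gated delta net -> no entry taken from the kv_caches dict")
# Layers 3 and 7 are the full-attention layers; all other layers are GDN layers.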

95 changes: 73 additions & 22 deletions src/maxtext/models/qwen3.py
@@ -43,7 +43,10 @@
from maxtext.layers.moe import RoutedMoE
from maxtext.layers.initializers import nd_dense_init, variable_to_logically_partitioned
from maxtext.inference import page_manager

from maxtext.utils import max_utils
from maxtext.inference import page_manager, kvcache


# -----------------------------------------
# Qwen3-Next Layer Implementations
@@ -379,7 +382,7 @@ class Qwen3NextGatedDeltaNet(nnx.Module):
2. output = Linear_out(y)
"""

def __init__(self, config: Config, *, rngs: nnx.Rngs):
def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = MODEL_MODE_TRAIN, *, rngs: nnx.Rngs):
"""
Args:
config: MaxText configuration object.
@@ -399,6 +402,17 @@ def __init__(self, config: Config, *, rngs: nnx.Rngs):
conv_kernel_size = cfg.gdn_conv_kernel_dim
self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads

if model_mode != MODEL_MODE_TRAIN:
self.cache = kvcache.GatedDeltaNetCache(
batch=config.per_device_batch_size,
num_heads=self.num_v_heads,
k_head_dim=self.head_k_dim,
v_head_dim=self.head_v_dim,
conv_kernel_size=self.config.gdn_conv_kernel_dim,
conv_dim=conv_dim,
dtype=dtype,
)

# Submodule instantiations
self.in_proj_qkvz = DenseGeneral(
in_features_shape=in_features,
@@ -458,7 +472,9 @@ def a_log_init(key, shape, dtype=jnp.float32):
rngs=rngs,
)

def __call__(self, hidden_states: Array) -> Array:
def __call__(
self, hidden_states: Array, model_mode: str = MODEL_MODE_TRAIN, kv_cache=None, **kwargs
) -> tuple[Array, dict[str, Array]]:
# hidden_states: (B, S, E)
cfg = self.config
batch, seq_len, _ = hidden_states.shape
@@ -529,15 +545,35 @@ def __call__(self, hidden_states: Array) -> Array:
# =========================================================================
# STEP B: 1D Convolution
# =========================================================================
# conv_dim = 2 * K_dim + V_dim
# qkv: (B, S, 2 * K_dim + V_dim)
qkv = jnp.concatenate([q, k, v], axis=-1)
batch, seq_len, _ = qkv.shape
conv_kernel_size = self.config.gdn_conv_kernel_dim

conv_state = None
if model_mode != MODEL_MODE_TRAIN:
# Retrieve state from self.cache
conv_state = self.cache.conv_state.value

if conv_state.shape[0] != batch:
if conv_state.shape[0] == 1:
conv_state = jnp.broadcast_to(conv_state, (batch,) + conv_state.shape[1:])
else:
conv_state = conv_state[:batch]

# Concatenate previous state with new input
conv_input = jnp.concatenate([conv_state, qkv], axis=1)
new_conv_state = conv_input[:, -(conv_kernel_size - 1) :, :]

# Update self.cache in place
self.cache.conv_state.value = new_conv_state
else:
# Train: pad with zeros
conv_input = jnp.pad(qkv, ((0, 0), (conv_kernel_size - 1, 0), (0, 0)))

# TODO(parambole): Implement caching logic for conv_state and recurrent_state

# Input to conv_layer should be (B, S, C)
# qkv_conv shape: (B, S, conv_dim)
conv_out = self.conv1d(qkv)
# Perform the convolution.
conv_out = self.conv1d(conv_input)
# Slice the output to match the original input sequence length.
conv_out = conv_out[:, -seq_len:, :]
qkv_conv = jax.nn.silu(conv_out.astype(jnp.float32)).astype(cfg.dtype)
# q_conv shape: (B, S, key_dim), k_conv shape: (B, S, key_dim), v_conv shape: (B, S, value_dim)
q_conv, k_conv, v_conv = jnp.split(qkv_conv, [self.key_dim, 2 * self.key_dim], axis=-1)
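Note (illustrative, not part of the change): a minimal sketch of the decode-time path above, with hypothetical shapes. Prepending the cached last (kernel_size - 1) rows to the single new token gives the convolution the same receptive field it had during prefill, and the window is then rolled forward:

import jax.numpy as jnp

kernel_size, conv_dim = 4, 768
conv_state = jnp.zeros((1, kernel_size - 1, conv_dim))  # cached history (B, K-1, C)
qkv_step = jnp.ones((1, 1, conv_dim))                   # current decode token (B, 1, C)

conv_input = jnp.concatenate([conv_state, qkv_step], axis=1)  # (B, K, C), fed to the conv
new_conv_state = conv_input[:, -(kernel_size - 1):, :]        # drop the oldest row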
@@ -568,22 +604,35 @@ def __call__(self, hidden_states: Array) -> Array:
# key shape after repeat: (B, S, H_v, D_k)
key = jnp.repeat(key, repeats, axis=2)
elif self.num_k_heads > self.num_v_heads and self.num_k_heads % self.num_v_heads == 0:
# This case might occur if key/query heads are more than value heads.
pass # No repeating needed for query/key in this case
pass

# TODO(parambole): Pass and update cache state for jax_chunk_gated_delta_rule
# core_attn_out shape: (B, S, H_v, D_v)
core_attn_out, _ = jax_chunk_gated_delta_rule(
recurrent_state = None
if model_mode != MODEL_MODE_TRAIN:
# Retrieve state from self.cache
recurrent_state = self.cache.recurrent_state.value

if recurrent_state.shape[0] != batch:
if recurrent_state.shape[0] == 1:
recurrent_state = jnp.broadcast_to(recurrent_state, (batch,) + recurrent_state.shape[1:])
else:
recurrent_state = recurrent_state[:batch]

core_attn_out, recurrent_state_out = jax_chunk_gated_delta_rule(
query,
key,
value,
g,
beta,
chunk_size=cfg.gdn_chunk_size,
initial_state=recurrent_state,
use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn,
compute_dtype=cfg.dtype,
)

if model_mode != MODEL_MODE_TRAIN:
# Update self.cache in place for both prefill and decode
self.cache.recurrent_state.value = recurrent_state_out
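Note (illustrative, not part of the change): the recurrent_state handed from prefill to decode is the memory matrix of the delta rule. A conceptual single-token sketch of a gated delta-rule update, simplified relative to the chunked kernel used here:

import jax.numpy as jnp

def gated_delta_step(S, q, k, v, g, beta):
    """S: (d_k, d_v) memory; q, k: (d_k,); v: (d_v,); g (log-decay) and beta: scalars."""
    S = jnp.exp(g) * S                       # gated decay of the existing memory
    v_pred = k @ S                           # what the memory currently recalls for key k
    S = S + beta * jnp.outer(k, v - v_pred)  # delta-rule correction toward the target value
    out = q @ S                              # read-out for the query
    return out, S                            # the final S after prefill seeds decoding

d_k, d_v = 8, 8
S = jnp.zeros((d_k, d_v))
out, S = gated_delta_step(S, jnp.ones(d_k), jnp.ones(d_k), jnp.ones(d_v), -0.1, 0.5)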

# =========================================================================
# STEP D: Final Output Stage
# =========================================================================
@@ -913,7 +962,7 @@ def __init__(
rngs=rngs,
)
else:
self.attention = Qwen3NextGatedDeltaNet(config=cfg, rngs=rngs)
self.attention = Qwen3NextGatedDeltaNet(config=cfg, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs)

# Second LayerNorm, applied before the MoE block.
self.post_attention_layernorm = Qwen3NextRMSNorm(
@@ -937,7 +986,7 @@ def __call__(
previous_chunk=None,
page_state: None | page_manager.PageState = None,
slot: None | int = None,
kv_cache: None | jnp.ndarray = None,
kv_cache: None | dict[str, Array] = None,
attention_metadata: None | dict[str, Any] = None,
):
# Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
@@ -951,7 +1000,7 @@

# Conditionally apply either the Linear Attention or Full Attention block.
if isinstance(self.attention, Qwen3NextFullAttention):
attention_output, kv_cache = cast(Qwen3NextFullAttention, self.attention)(
attention_output, new_kv_cache = cast(Qwen3NextFullAttention, self.attention)(
hidden_states,
decoder_segment_ids,
decoder_positions,
Expand All @@ -960,10 +1009,13 @@ def __call__(
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)
elif isinstance(self.attention, Qwen3NextGatedDeltaNet):
attention_output = cast(Qwen3NextGatedDeltaNet, self.attention)(hidden_states)
else:
raise TypeError(f"Unexpected type for self.attention: {type(self.attention)}")
attention_output = cast(Qwen3NextGatedDeltaNet, self.attention)(
hidden_states,
model_mode=model_mode,
kv_cache=None,
)
new_kv_cache = None

# First residual connection after attention
hidden_states = residual + attention_output
@@ -990,8 +1042,7 @@ def __call__(
layer_output,
self.activation_axis_names,
)

return layer_output, kv_cache
return layer_output, new_kv_cache


# -----------------------------------------