Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 103 additions & 11 deletions src/maxtext/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@
AxisNames,
BATCH,
BATCH_NO_EXP,
CACHE_BATCH,
CACHE_BATCH_PREFILL,
CACHE_SEQUENCE,
CACHE_HEADS_NONE,
CACHE_KV,
Config,
DECODE_BATCH,
DECODE_LENGTH,
Expand Down Expand Up @@ -75,6 +80,9 @@
from maxtext.utils.sharding import create_sharding


# Dummy sequence length passed as `key_seq_len` / `value_seq_len` when the KV
# caches in this file are constructed.
# NOTE(review): the caches appear to size themselves from `max_prefill_length`
# and `max_target_length` rather than from this value -- confirm against kvcache.
PLACEHOLDER_SEQ_LEN = 1


class Indexer(nnx.Module):
"""Indexer for DeepSeek Sparse Attention (DSA).

Expand Down Expand Up @@ -108,6 +116,7 @@ def __init__(
self.rngs = rngs
self.dtype = config.dtype
self.weight_dtype = config.weight_dtype
self.max_target_length = config.max_target_length

self.n_heads = config.index_n_heads
self.head_dim = config.index_head_dim
Expand Down Expand Up @@ -167,6 +176,31 @@ def __init__(
rngs=self.rngs,
)

def update_indexer_cache(self, kv_cache, k, decoder_segment_ids, model_mode, previous_chunk):
  """Pushes the indexer keys through the KV cache and returns what the cache holds.

  The keys are given a singleton axis at position 2 and handed to `kv_cache` as
  both `key` and `value`. `kv_cache` returns a pair of results, either of which
  may be `None`; the non-`None` ones are joined along axis 1.

  Args:
    kv_cache: Callable cache, invoked with keyword arguments `key`, `value`,
      `decoder_segment_ids`, `model_mode`, `use_ragged_attention` and
      `previous_chunk`.
    k: Indexer keys, [b, s, d] as produced by the caller in `__call__`.
    decoder_segment_ids: Forwarded unchanged to `kv_cache`.
    model_mode: Forwarded unchanged to `kv_cache`.
    previous_chunk: Forwarded unchanged to `kv_cache`.

  Returns:
    A `(keys, segment_ids)` tuple. `keys` is element 0 of every non-`None`
    cache result concatenated along axis 1, with axis 2 squeezed away again;
    `segment_ids` is element 2 of those results concatenated along axis 1.
    `(None, None)` is returned when both cache results are `None`.
  """
  # The cache API works on per-head tensors, so insert a length-1 axis at position 2.
  single_head_k = k[:, :, None, :]

  # The cache needs a value tensor as well; the keys are passed for both.
  prefill_out, ar_out = kv_cache(
      key=single_head_k,
      value=single_head_k,
      decoder_segment_ids=decoder_segment_ids,
      model_mode=model_mode,
      use_ragged_attention=self.config.use_ragged_attention,
      previous_chunk=previous_chunk,
  )

  # Either result may be None depending on the mode; collect whatever is present
  # so that PREFILL and AR are handled by the same code path.
  key_parts = []
  segment_parts = []
  for cache_out in (prefill_out, ar_out):
    if cache_out is None:
      continue
    key_parts.append(cache_out[0])  # cached keys
    segment_parts.append(cache_out[2])  # cached segment IDs

  if len(key_parts) == 0:
    return None, None

  cached_keys = jnp.concatenate(key_parts, axis=1)
  cached_segments = jnp.concatenate(segment_parts, axis=1)

  # Drop the singleton axis that was inserted before calling the cache.
  return jnp.squeeze(cached_keys, axis=2), cached_segments

def apply_partial_rope(
self,
inputs: Array,
Expand Down Expand Up @@ -220,6 +254,10 @@ def __call__(
inputs_kv: Array,
inputs_positions: Optional[Array | None] = None,
attention_mask: Optional[Array | None] = None,
decoder_segment_ids: Optional[Array | None] = None,
previous_chunk: Any = None,
kv_cache: Any = None,
model_mode: str = MODEL_MODE_TRAIN,
):
"""Computes the index score to determine the top-k relevant tokens.

Expand All @@ -244,6 +282,10 @@ def __call__(
`DEFAULT_MASK_VALUE` (a large negative number) prevent it.
Returns `None` if no masking is determined to be necessary based on
the inputs and configuration.
decoder_segment_ids: Segment IDs for decoder masking.
previous_chunk: Previous chunk info for prefill.
kv_cache: Key-value cache used when serving models.
model_mode: "train", "prefill", or "autoregressive".

Returns:
index_mask: A sparse mask [b, t, s] with 0.0 for top-k selected tokens
Expand All @@ -258,10 +300,6 @@ def __call__(
h: Number of Indexer Heads (index_n_heads)
d: Indexer Head Dimension (index_head_dim)
"""
# NOTE: If sequence length <= topk, indexer always selects all tokens.
if self.config.max_target_length <= self.index_topk:
return None, None, None

bsz, seqlen, _ = inputs_q.shape # s = t = seqlen

# Query Processing: Project from Latent low_rank_q
Expand All @@ -276,6 +314,16 @@ def __call__(
k = self.apply_partial_rope(k, inputs_positions=inputs_positions)
k = k.squeeze(2) # [b, s, 1, d] -> [b, s, d]

# Update and retrieve from cache if not training
cached_s = None
if model_mode != MODEL_MODE_TRAIN:
k_cached, cached_s = self.update_indexer_cache(kv_cache, k, decoder_segment_ids, model_mode, previous_chunk)
k = k_cached if k_cached is not None else k

# NOTE: If the total available sequence length <= topk, indexer always selects all tokens.
if k.shape[1] <= self.index_topk:
return None, None, None

# Compute Index Scores
# QK product: relu(q @ k.T), [b, t, s, h]
# Similar to MQA, each key is shared by h query head
Expand All @@ -289,6 +337,12 @@ def __call__(
# Aggregate head-wise logits: logits @ weights
index_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision) # [b, t, s]

internal_padding_mask = None
if cached_s is not None:
# cached_s marks valid tokens from the original prefill step and all subsequent AR steps
internal_padding_mask = jnp.where(cached_s > 0, 0.0, DEFAULT_MASK_VALUE)
index_score += internal_padding_mask[:, None, :]

# Apply attention mask before TopK
if attention_mask is not None:
index_score += attention_mask
Expand All @@ -297,12 +351,15 @@ def __call__(
_, topk_indices = jax.lax.top_k(index_score, k=self.index_topk) # topk_indices [b, t, k]

# Create Sparse Index Mask: 0 and large negatives
index_mask = self.generate_mask(topk_indices, seqlen) # [b, t, s]
index_mask = self.generate_mask(topk_indices, k.shape[1]) # [b, t, s]

# Re-apply attention mask after TopK: in case number of unmasked tokens < TopK
if attention_mask is not None:
index_mask += attention_mask

if internal_padding_mask is not None:
index_mask += internal_padding_mask[:, None, :]

return index_mask, topk_indices, index_score


Expand Down Expand Up @@ -615,16 +672,47 @@ def __init__(
indexer_rope.interleave = False
self.indexer = Indexer(
config,
rngs=rngs,
rotary_embedding=indexer_rope,
kernel_init=kernel_init,
quant=quant,
model_mode=model_mode,
rngs=rngs,
)
self.IndexerKVCache_0 = self.init_indexer_cache(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None
else:
self.indexer = None
self.IndexerKVCache_0 = None

# Module attribute names must match names previously passed to Linen for checkpointing
self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None

def init_indexer_cache(self, inputs_kv_shape: Tuple):
  """Builds the KV cache that stores the Indexer's keys.

  A standard `kvcache.KVCache` is reused for this. Only the key slot carries
  information; the value slot exists because the KVCache API requires one, and
  it is configured identically to the key slot.

  Args:
    inputs_kv_shape: Three-element shape tuple; only its first entry (the batch
      size) is used.

  Returns:
    The newly constructed `kvcache.KVCache`.
  """
  batch_size, _, _ = inputs_kv_shape

  # The indexer key is shared across query heads (MQA-like), so one head suffices.
  single_head = 1
  # The same axis permutation is used for the prefill and autoregressive caches.
  axis_order = (1, 2, 0, 3)
  prefill_axis_names = (CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS_NONE, CACHE_KV)
  ar_axis_names = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS_NONE, CACHE_KV)

  indexer_cache = kvcache.KVCache(
      max_prefill_length=self.max_prefill_predict_length,
      max_target_length=self.max_target_length,
      batch=batch_size,
      key_seq_len=PLACEHOLDER_SEQ_LEN,
      value_seq_len=PLACEHOLDER_SEQ_LEN,
      key_heads=single_head,
      value_heads=single_head,
      key_head_size=self.config.index_head_dim,
      value_head_size=self.config.index_head_dim,
      dtype=self.dtype,
      # Quantization is not yet supported by the indexer.
      kv_quant=None,
      prefill_cache_logical_axis_names=prefill_axis_names,
      cache_logical_axis_names=ar_axis_names,
      prefill_cache_axis_order=axis_order,
      ar_cache_axis_order=axis_order,
      use_chunked_prefill=self.config.use_chunked_prefill,
      model_mode=self.model_mode,
      rngs=self.rngs,
  )
  return indexer_cache

def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
"""Initializes the MLA-specific projections."""
# Assert required configuration parameters for MLA attention.
Expand Down Expand Up @@ -856,14 +944,13 @@ def init_mla_kv_caches(self, inputs_kv_shape: Tuple):
# and max_target_length, not the passed seq_len.
# We can use a placeholder value. The correct fix might involve refactoring
# MlaKVCache.
placeholder_seq_len = 1

return kvcache.MlaKVCache(
max_prefill_length=self.max_prefill_predict_length,
max_target_length=self.max_target_length,
batch=batch_size,
key_seq_len=placeholder_seq_len,
value_seq_len=placeholder_seq_len,
key_seq_len=PLACEHOLDER_SEQ_LEN,
value_seq_len=PLACEHOLDER_SEQ_LEN,
key_head_size=self.kv_lora_rank,
value_head_size=self.qk_rope_head_dim,
dtype=self.dtype,
Expand Down Expand Up @@ -1002,6 +1089,9 @@ def __call__(
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)

if model_mode != MODEL_MODE_TRAIN and decoder_segment_ids is None:
decoder_segment_ids = jnp.ones(inputs_q.shape[:2], dtype=jnp.int32)

query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
if self.config.force_q_layout:
query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
Expand All @@ -1015,8 +1105,6 @@ def __call__(
# Indexer Logic
index_mask = None
if self.use_sparse_indexer:
if model_mode != MODEL_MODE_TRAIN:
raise NotImplementedError("Sparse indexer has not implemented for inference yet.")
# generate mask: with 0 and large negative, [b, 1, 1, q_len, kv_len] -> [b, q_len, kv_len]
attention_mask = self.attention_op.generate_attention_mask(
query, key, decoder_segment_ids, model_mode, previous_chunk, bidirectional_mask
Expand All @@ -1028,6 +1116,10 @@ def __call__(
inputs_kv=inputs_kv,
inputs_positions=inputs_positions,
attention_mask=attention_mask,
decoder_segment_ids=decoder_segment_ids,
previous_chunk=previous_chunk,
kv_cache=self.IndexerKVCache_0,
model_mode=model_mode,
)

# Check if we need QK Clip stats
Expand Down
Loading
Loading