From 2fed9037b5293080fb5129ffc34331c6c9337cee Mon Sep 17 00:00:00 2001 From: Rissy Ran Date: Thu, 19 Feb 2026 00:28:49 +0000 Subject: [PATCH] Enable Indexer cache for DS v3.2 decoding --- src/maxtext/layers/attention_mla.py | 114 +++++- src/maxtext/layers/attention_op.py | 559 +++++++++------------------- tests/unit/attention_test.py | 69 +++- tests/utils/attention_test_util.py | 3 + 4 files changed, 339 insertions(+), 406 deletions(-) diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index 58fe48de92..84b873c251 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -38,6 +38,11 @@ AxisNames, BATCH, BATCH_NO_EXP, + CACHE_BATCH, + CACHE_BATCH_PREFILL, + CACHE_SEQUENCE, + CACHE_HEADS_NONE, + CACHE_KV, Config, DECODE_BATCH, DECODE_LENGTH, @@ -75,6 +80,9 @@ from maxtext.utils.sharding import create_sharding +PLACEHOLDER_SEQ_LEN = 1 + + class Indexer(nnx.Module): """Indexer for DeepSeek Sparse Attention (DSA). @@ -108,6 +116,7 @@ def __init__( self.rngs = rngs self.dtype = config.dtype self.weight_dtype = config.weight_dtype + self.max_target_length = config.max_target_length self.n_heads = config.index_n_heads self.head_dim = config.index_head_dim @@ -167,6 +176,31 @@ def __init__( rngs=self.rngs, ) + def update_indexer_cache(self, kv_cache, k, decoder_segment_ids, model_mode, previous_chunk): + """Updates Indexer buffers by processing KV cache results.""" + k_expanded = k[:, :, jnp.newaxis, :] + p_res, a_res = kv_cache( + key=k_expanded, + value=k_expanded, + decoder_segment_ids=decoder_segment_ids, + model_mode=model_mode, + use_ragged_attention=self.config.use_ragged_attention, + previous_chunk=previous_chunk, + ) + + # Filter out None values to handle PREFILL vs AR modes uniformly + active_results = [res for res in [p_res, a_res] if res is not None] + + if not active_results: + return None, None + + # Extract keys (index 0) and segment IDs (index 2) + keys = jnp.concatenate([res[0] for res in active_results], axis=1) + segs = jnp.concatenate([res[2] for res in active_results], axis=1) + + # squeeze(2) removes the jnp.newaxis added above + return keys.squeeze(2), segs + def apply_partial_rope( self, inputs: Array, @@ -220,6 +254,10 @@ def __call__( inputs_kv: Array, inputs_positions: Optional[Array | None] = None, attention_mask: Optional[Array | None] = None, + decoder_segment_ids: Optional[Array | None] = None, + previous_chunk: Any = None, + kv_cache: Any = None, + model_mode: str = MODEL_MODE_TRAIN, ): """Computes the index score to determine the top-k relevant tokens. @@ -244,6 +282,10 @@ def __call__( `DEFAULT_MASK_VALUE` (a large negative number) prevent it. Returns `None` if no masking is determined to be necessary based on the inputs and configuration. + decoder_segment_ids: Segment IDs for decoder masking. + previous_chunk: Previous chunk info for prefill. + kv_cache: Key-value cache used when serving models. + model_mode: "train", "prefill", or "autoregressive". Returns: index_mask: A sparse mask [b, t, s] with 0.0 for top-k selected tokens @@ -258,10 +300,6 @@ def __call__( h: Number of Indexer Heads (index_n_heads) d: Indexer Head Dimension (index_head_dim) """ - # NOTE: If sequence length <= topk, indexer always selects all tokens. - if self.config.max_target_length <= self.index_topk: - return None, None, None - bsz, seqlen, _ = inputs_q.shape # s = t = seqlen # Query Processing: Project from Latent low_rank_q @@ -276,6 +314,16 @@ def __call__( k = self.apply_partial_rope(k, inputs_positions=inputs_positions) k = k.squeeze(2) # [b, s, 1, d] -> [b, s, d] + # Update and retrieve from cache if not training + cached_s = None + if model_mode != MODEL_MODE_TRAIN: + k_cached, cached_s = self.update_indexer_cache(kv_cache, k, decoder_segment_ids, model_mode, previous_chunk) + k = k_cached if k_cached is not None else k + + # NOTE: If the total available sequence length <= topk, indexer always selects all tokens. + if k.shape[1] <= self.index_topk: + return None, None, None + # Compute Index Scores # QK product: relu(q @ k.T), [b, t, s, h] # Similar to MQA, each key is shared by h query head @@ -289,6 +337,12 @@ def __call__( # Aggregate head-wise logits: logits @ weights index_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision) # [b, t, s] + internal_padding_mask = None + if cached_s is not None: + # cached_s marks valid tokens from the original prefill step and all subsequent AR steps + internal_padding_mask = jnp.where(cached_s > 0, 0.0, DEFAULT_MASK_VALUE) + index_score += internal_padding_mask[:, None, :] + # Apply attention mask before TopK if attention_mask is not None: index_score += attention_mask @@ -297,12 +351,15 @@ def __call__( _, topk_indices = jax.lax.top_k(index_score, k=self.index_topk) # topk_indices [b, t, k] # Create Sparse Index Mask: 0 and large negatives - index_mask = self.generate_mask(topk_indices, seqlen) # [b, t, s] + index_mask = self.generate_mask(topk_indices, k.shape[1]) # [b, t, s] # Re-apply attention mask after TopK: in case number of unmasked tokens < TopK if attention_mask is not None: index_mask += attention_mask + if internal_padding_mask is not None: + index_mask += internal_padding_mask[:, None, :] + return index_mask, topk_indices, index_score @@ -615,16 +672,47 @@ def __init__( indexer_rope.interleave = False self.indexer = Indexer( config, - rngs=rngs, rotary_embedding=indexer_rope, kernel_init=kernel_init, quant=quant, model_mode=model_mode, + rngs=rngs, ) + self.IndexerKVCache_0 = self.init_indexer_cache(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None + else: + self.indexer = None + self.IndexerKVCache_0 = None # Module attribute names must match names previously passed to Linen for checkpointing self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None + def init_indexer_cache(self, inputs_kv_shape: Tuple): + """Initializes Indexer Cache.""" + batch_size, _, _ = inputs_kv_shape + # Use standard KVCache to store keys. Values are unused but required by KVCache API. + # KVCache expects key_heads and value_heads. Since k is shared (MQA-like for Indexer), + # we use key_heads=1, value_heads=1. + return kvcache.KVCache( + max_prefill_length=self.max_prefill_predict_length, + max_target_length=self.max_target_length, + batch=batch_size, + key_seq_len=PLACEHOLDER_SEQ_LEN, + value_seq_len=PLACEHOLDER_SEQ_LEN, + key_heads=1, + value_heads=1, + key_head_size=self.config.index_head_dim, + value_head_size=self.config.index_head_dim, + dtype=self.dtype, + kv_quant=None, # Quantization is not yet supported by the indexer. + prefill_cache_logical_axis_names=(CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS_NONE, CACHE_KV), + cache_logical_axis_names=(CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS_NONE, CACHE_KV), + prefill_cache_axis_order=(1, 2, 0, 3), + ar_cache_axis_order=(1, 2, 0, 3), + use_chunked_prefill=self.config.use_chunked_prefill, + model_mode=self.model_mode, + rngs=self.rngs, + ) + def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None: """Initializes the MLA-specific projections.""" # Assert required configuration parameters for MLA attention. @@ -856,14 +944,13 @@ def init_mla_kv_caches(self, inputs_kv_shape: Tuple): # and max_target_length, not the passed seq_len. # We can use a placeholder value. The correct fix might involve refactoring # MlaKVCache. - placeholder_seq_len = 1 return kvcache.MlaKVCache( max_prefill_length=self.max_prefill_predict_length, max_target_length=self.max_target_length, batch=batch_size, - key_seq_len=placeholder_seq_len, - value_seq_len=placeholder_seq_len, + key_seq_len=PLACEHOLDER_SEQ_LEN, + value_seq_len=PLACEHOLDER_SEQ_LEN, key_head_size=self.kv_lora_rank, value_head_size=self.qk_rope_head_dim, dtype=self.dtype, @@ -1002,6 +1089,9 @@ def __call__( inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names) out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV) + if model_mode != MODEL_MODE_TRAIN and decoder_segment_ids is None: + decoder_segment_ids = jnp.ones(inputs_q.shape[:2], dtype=jnp.int32) + query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode) if self.config.force_q_layout: query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1))) @@ -1015,8 +1105,6 @@ def __call__( # Indexer Logic index_mask = None if self.use_sparse_indexer: - if model_mode != MODEL_MODE_TRAIN: - raise NotImplementedError("Sparse indexer has not implemented for inference yet.") # generate mask: with 0 and large negative, [b, 1, 1, q_len, kv_len] -> [b, q_len, kv_len] attention_mask = self.attention_op.generate_attention_mask( query, key, decoder_segment_ids, model_mode, previous_chunk, bidirectional_mask @@ -1028,6 +1116,10 @@ def __call__( inputs_kv=inputs_kv, inputs_positions=inputs_positions, attention_mask=attention_mask, + decoder_segment_ids=decoder_segment_ids, + previous_chunk=previous_chunk, + kv_cache=self.IndexerKVCache_0, + model_mode=model_mode, ) # Check if we need QK Clip stats diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index ced98b5de2..3cdb6330b5 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -99,16 +99,12 @@ global_k_layout = "" global_v_layout = "" -dynamic_vector_slice_in_dim = jax.vmap( - lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None) -) +dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) def validate_compute_axis_order(s: AxisIdxes) -> None: valid_compute_axis_order = ((0, 1, 2, 3), (0, 2, 1, 3)) - if ( - s not in valid_compute_axis_order - ): # currently supported compute_axis_order + if s not in valid_compute_axis_order: # currently supported compute_axis_order raise ValueError( "Invalid compute_axis_order was passed. Valid options ", valid_compute_axis_order, @@ -141,21 +137,15 @@ def apply_mask_to_logits(logits: Array, mask: Array): Returns: Masked logits. """ - return jnp.where( - (mask >= DEFAULT_MASK_VALUE * 0.5), logits, DEFAULT_MASK_VALUE - ) + return jnp.where((mask >= DEFAULT_MASK_VALUE * 0.5), logits, DEFAULT_MASK_VALUE) def validate_gpu_flash_attention(sinks: Array | None, record_max_logits: bool) -> None: """Helper function to check for unsupported features with flash attention on GPU.""" if sinks is not None: - raise ValueError( - "The flash attention with sinks is not supported on GPU yet." - ) + raise ValueError("The flash attention with sinks is not supported on GPU yet.") if record_max_logits: - raise NotImplementedError( - "record_max_logits (QK-Clip) is not supported for GPU flash attention kernels yet." - ) + raise NotImplementedError("record_max_logits (QK-Clip) is not supported for GPU flash attention kernels yet.") # TODO(agagik): change splash_attention_mask._ComputableMask to be non protected @@ -222,17 +212,17 @@ def __eq__(self, other: object): ) def __hash__(self): - return hash(( - type(self), - self.shape, - self.chunk_size, - self.q_sequence.tobytes() if self.q_sequence is not None else None, - )) + return hash( + ( + type(self), + self.shape, + self.chunk_size, + self.q_sequence.tobytes() if self.q_sequence is not None else None, + ) + ) -def _generate_chunk_attention_mask( - mask_shape: tuple[int, int], chunk_size: int, q_offset: int = 0 -) -> jax.Array: +def _generate_chunk_attention_mask(mask_shape: tuple[int, int], chunk_size: int, q_offset: int = 0) -> jax.Array: """Generates an explicit boolean mask for chunked causal attention. This function computes the full boolean mask array where True indicates @@ -295,9 +285,9 @@ def _make_bidirectional_block_mask(bidirectional_mask): """ q_block_indices = _make_block_mask_indices(bidirectional_mask) kv_block_indices = q_block_indices - bidirectional_block_mask = ( - kv_block_indices[:, None, :] == q_block_indices[..., None] - ) & (q_block_indices[..., None] > 0) + bidirectional_block_mask = (kv_block_indices[:, None, :] == q_block_indices[..., None]) & ( + q_block_indices[..., None] > 0 + ) return bidirectional_block_mask @@ -546,20 +536,14 @@ def maybe_create_nnx(einsum, *args): s_ar = self.max_target_length # Dummy query/key/value shapes as before... - dummy_query_prefill = jnp.zeros( - (b, t_prefill, n_kv, g, d), dtype=self.dtype - ) + dummy_query_prefill = jnp.zeros((b, t_prefill, n_kv, g, d), dtype=self.dtype) dummy_key_prefill = jnp.zeros((b, s_prefill, n_kv, d), dtype=self.dtype) dummy_query_ar = jnp.zeros((b, t_ar, n_kv, g, d), dtype=self.dtype) dummy_key_ar = jnp.zeros((b, s_ar, n_kv, d), dtype=self.dtype) - dummy_attn_weights_prefill = jnp.zeros( - (b, n_kv, g, t_prefill, s_prefill), dtype=jnp.float32 - ) + dummy_attn_weights_prefill = jnp.zeros((b, n_kv, g, t_prefill, s_prefill), dtype=jnp.float32) dummy_value_prefill = jnp.zeros((b, s_prefill, n_kv, d), dtype=self.dtype) - dummy_attn_weights_ar = jnp.zeros( - (b, n_kv, g, t_ar, s_ar), dtype=jnp.float32 - ) + dummy_attn_weights_ar = jnp.zeros((b, n_kv, g, t_ar, s_ar), dtype=jnp.float32) dummy_value_ar = jnp.zeros((b, s_ar, n_kv, d), dtype=self.dtype) # Prefill AqtEinsum instances @@ -595,21 +579,14 @@ def maybe_create_nnx(einsum, *args): self.AqtEinsum_3 = jnp.einsum def _logical_to_mesh_axes(self, logical_name): - return logical_to_mesh_axes( - logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules - ) + return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules) - def check_attention_inputs( - self, query: Array, key: Array | KVTensor, value: Array | KVTensor - ) -> None: + def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Array | KVTensor) -> None: """Check attention inputs.""" - assert ( - key.ndim == value.ndim - ), f"k (dim {key.ndim}), v (dim {value.ndim}) must have same rank." + assert key.ndim == value.ndim, f"k (dim {key.ndim}), v (dim {value.ndim}) must have same rank." assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( - f"{query.shape[:-3]=}, {key.shape[:-3]=}, {value.shape[:-3]=} batch" - " dims must match." + f"{query.shape[:-3]=}, {key.shape[:-3]=}, {value.shape[:-3]=} batch" " dims must match." ) assert key.shape[-2] == value.shape[-2], "k, v num_kv_heads must match." assert key.shape[-3] == value.shape[-3], "k, v lengths must match." @@ -699,11 +676,8 @@ def generate_attention_mask( Chunked Prefills - ArXiv:2308.16369 (https://arxiv.org/abs/2308.16369) """ mask = None - if model_mode == MODEL_MODE_AUTOREGRESSIVE: - mask = ( - decoder_segment_ids[:, None, None, None, :] - == DECODING_ACTIVE_SEQUENCE_INDICATOR - ) + if model_mode == MODEL_MODE_AUTOREGRESSIVE and decoder_segment_ids is not None: + mask = decoder_segment_ids[:, None, None, None, :] == DECODING_ACTIVE_SEQUENCE_INDICATOR elif decoder_segment_ids is not None: mask = decoder_segment_ids[:, :, None] == decoder_segment_ids[:, None, :] mask = mask[:, None, None, :, :] @@ -721,10 +695,7 @@ def generate_attention_mask( causal_mask = None # We enforce causality except for AUTOREGRESSION - if ( - model_mode != MODEL_MODE_AUTOREGRESSIVE - and self.attention_type != AttentionType.FULL - ): + if model_mode != MODEL_MODE_AUTOREGRESSIVE and self.attention_type != AttentionType.FULL: mask_shape = (q_seq_len, kv_seq_len) # row_ids indicates the position of query # col_ids indicates the position of kv @@ -742,22 +713,15 @@ def generate_attention_mask( elif causal_mask is not None: output_mask = causal_mask - if ( - self.attention_type == AttentionType.LOCAL_SLIDING - and output_mask is not None - ): + if self.attention_type == AttentionType.LOCAL_SLIDING and output_mask is not None: if self.sliding_window_size is None: - raise ValueError( - "Sliding_window_size must be set if Local Sliding attention type" - ) + raise ValueError("Sliding_window_size must be set if Local Sliding attention type") - row_ids_sliding = ( - jax.lax.broadcasted_iota(jnp.int32, (q_seq_len, 1), 0) + next_pos - ) + row_ids_sliding = jax.lax.broadcasted_iota(jnp.int32, (q_seq_len, 1), 0) + next_pos col_ids_sliding = jax.lax.broadcasted_iota(jnp.int32, (1, kv_seq_len), 1) - sliding_mask = ( - col_ids_sliding > (row_ids_sliding - self.sliding_window_size) - ) & (col_ids_sliding <= row_ids_sliding) + sliding_mask = (col_ids_sliding > (row_ids_sliding - self.sliding_window_size)) & ( + col_ids_sliding <= row_ids_sliding + ) output_mask = sliding_mask * output_mask elif self.attention_type == AttentionType.CHUNK and output_mask is not None: mask_shape = (q_seq_len, kv_seq_len) @@ -772,11 +736,7 @@ def generate_attention_mask( image_mask = _make_bidirectional_block_mask(bidirectional_mask) output_mask = output_mask | image_mask[:, None, None, ...] - return ( - jnp.where(output_mask, 0.0, DEFAULT_MASK_VALUE) - if output_mask is not None - else None - ) + return jnp.where(output_mask, 0.0, DEFAULT_MASK_VALUE) if output_mask is not None else None def calculate_moba_gate_logic(self, q_item, k_item, q_pos_item): """Computes the block-level MoBA gating intermediates for one batch item. @@ -800,18 +760,14 @@ def calculate_moba_gate_logic(self, q_item, k_item, q_pos_item): kv_len, n_kv_heads, _ = k_item.shape g = n_q_heads // n_kv_heads - q_item_f32 = q_item.astype(jnp.float32).reshape( - q_len, n_kv_heads, g, head_dim - ) # grouped-query attention (GQA) + q_item_f32 = q_item.astype(jnp.float32).reshape(q_len, n_kv_heads, g, head_dim) # grouped-query attention (GQA) moba_chunk_size = self.config.moba_chunk_size moba_topk = self.config.moba_topk num_block = math.ceil(kv_len / moba_chunk_size) - block_ids = ( - jnp.arange(kv_len, dtype=jnp.int32) // moba_chunk_size - ) # chunk index for each key position + block_ids = jnp.arange(kv_len, dtype=jnp.int32) // moba_chunk_size # chunk index for each key position # Sum key vectors per chunk so we can later average within each block. key_gate_weight_sum = jax.ops.segment_sum( k_item.astype(jnp.float32), block_ids, num_segments=num_block @@ -828,9 +784,7 @@ def calculate_moba_gate_logic(self, q_item, k_item, q_pos_item): ) # [num_block, n_kv_heads, head_dim] # Take the dot product between each query and every key chunk to get a score. - gate = jnp.einsum( - "skgd,Nkd->kgsN", q_item_f32, key_gate_weight - ) # [n_kv_heads, g, q_len, num_block] + gate = jnp.einsum("skgd,Nkd->kgsN", q_item_f32, key_gate_weight) # [n_kv_heads, g, q_len, num_block] gate_before_masking = gate q_block_idx = q_pos_item // moba_chunk_size # chunk id for each query @@ -848,24 +802,16 @@ def calculate_moba_gate_logic(self, q_item, k_item, q_pos_item): gate_after_masking = gate k_for_topk = min(moba_topk, num_block) - gate_top_k_val, gate_top_k_idx = jax.lax.top_k( - gate, k=k_for_topk - ) # [n_kv_heads, g, q_len, k_for_topk] - gate_top_k_val_min = jnp.min( - gate_top_k_val, axis=-1, keepdims=True - ) # [n_kv_heads, g, q_len, 1] - need_attend_threshold_mask = ( - gate >= gate_top_k_val_min - ) # [n_kv_heads, g, q_len, num_block] + gate_top_k_val, gate_top_k_idx = jax.lax.top_k(gate, k=k_for_topk) # [n_kv_heads, g, q_len, k_for_topk] + gate_top_k_val_min = jnp.min(gate_top_k_val, axis=-1, keepdims=True) # [n_kv_heads, g, q_len, 1] + need_attend_threshold_mask = gate >= gate_top_k_val_min # [n_kv_heads, g, q_len, num_block] # Tie-breaking: if multiple blocks have the same gate value as the k-th # block, we only select the ones that appear in the top-k indices. gate_idx_mask = jnp.sum( jax.nn.one_hot(gate_top_k_idx, num_block, dtype=jnp.bool_), axis=-2 ) # [n_kv_heads, g, q_len, num_block] - need_attend = jnp.logical_and( - need_attend_threshold_mask, gate_idx_mask - ) # [n_kv_heads, g, q_len, num_block] + need_attend = jnp.logical_and(need_attend_threshold_mask, gate_idx_mask) # [n_kv_heads, g, q_len, num_block] return ( key_gate_weight, @@ -886,9 +832,7 @@ def generate_moba_mask_single_item(self, q_item, k_item, q_positions): moba_chunk_size = self.config.moba_chunk_size # Run the gating logic to find which key blocks this query cares about. - *_, need_attend = self.calculate_moba_gate_logic( - q_item, k_item, q_positions - ) + *_, need_attend = self.calculate_moba_gate_logic(q_item, k_item, q_positions) # Expand the block-level `need_attend` mask to a token-level mask. k_block_indices = jnp.arange(kv_len, dtype=jnp.int32) // moba_chunk_size @@ -906,9 +850,7 @@ def generate_moba_mask_single_item(self, q_item, k_item, q_positions): # Return the additive mask for this batch item. return gate - def _generate_moba_mask( - self, query: Array, key: Array, q_positions: Array - ) -> Array: + def _generate_moba_mask(self, query: Array, key: Array, q_positions: Array) -> Array: """Builds the token-level MoBA additive mask for the whole batch. Args: @@ -926,9 +868,7 @@ def _generate_moba_mask( `0.` for permitted positions and `-inf` for masked ones. """ # vmap over the batch dimension of query and key. q_positions is constant across the batch. - moba_mask = jax.vmap( - self.generate_moba_mask_single_item, in_axes=(0, 0, None) - )(query, key, q_positions) + moba_mask = jax.vmap(self.generate_moba_mask_single_item, in_axes=(0, 0, None))(query, key, q_positions) return moba_mask def apply_attention( @@ -974,10 +914,7 @@ def apply_attention( # ragged paged attention kernel in `Attention.__call__`. elif ( self.attention_kernel == "dot_product" - or ( - self.attention_kernel == "autoselected" - and model_mode == MODEL_MODE_AUTOREGRESSIVE - ) + or (self.attention_kernel == "autoselected" and model_mode == MODEL_MODE_AUTOREGRESSIVE) or (self.attention_kernel == "autoselected" and length < 128) or (self.attention_kernel == "paged") or (self.attention_kernel == "vllm_rpa") @@ -1004,8 +941,10 @@ def apply_attention( value = value.dequant() if model_mode == MODEL_MODE_AUTOREGRESSIVE: - raise ValueError("""Decode not supported with flash attention. - Use `dot_product` instead.""") + raise ValueError( + """Decode not supported with flash attention. + Use `dot_product` instead.""" + ) out, max_logits = self.tpu_flash_attention( query, @@ -1056,9 +995,7 @@ def apply_attention( value, q_heads_per_kv_head, axis=head_axis ) # value shape [batch_size, kv_seq_len, num_kv_heads, head_dim] - out = gpu_pallas_attention.mha( - query, key, value, decoder_segment_ids, sm_scale=1.0, causal=True - ) + out = gpu_pallas_attention.mha(query, key, value, decoder_segment_ids, sm_scale=1.0, causal=True) return out, None, None elif self.attention_kernel == "cudnn_flash_te": validate_gpu_flash_attention(sinks, record_max_logits) @@ -1067,12 +1004,12 @@ def apply_attention( if isinstance(value, KVTensor): value = value.dequant() if model_mode == MODEL_MODE_AUTOREGRESSIVE: - raise ValueError("""Decode not supported with flash attention. - Use `dot_product` instead.""") + raise ValueError( + """Decode not supported with flash attention. + Use `dot_product` instead.""" + ) return ( - self.cudnn_flash_attention( - query, key, value, decoder_segment_ids, model_mode - ), + self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, None, ) @@ -1083,9 +1020,7 @@ def apply_attention( if isinstance(value, KVTensor): value = value.dequant() return ( - *self.cudnn_jax_flash_attention( - query, key, value, decoder_segment_ids, model_mode - ), + *self.cudnn_jax_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, ) else: @@ -1147,9 +1082,7 @@ def wrap_ragged_attention( ) return local_out, local_max, local_sum - local_out, local_max, local_sum = wrap_ragged_attention( - q_for_gqa, k, v, lengths, block_size - ) + local_out, local_max, local_sum = wrap_ragged_attention(q_for_gqa, k, v, lengths, block_size) # Reshape local_out, local_max and local_sum to match Maxtext requirements local_out = local_out.reshape(batch_size, q_length, q_heads, head_dim) @@ -1167,9 +1100,7 @@ def tpu_ragged_attention( ) -> tuple[Array, Array, Array]: """Ragged Attention.""" if isinstance(query, KVTensor): - raise TypeError( - "Ragged attention does not currently support quantized tensors." - ) + raise TypeError("Ragged attention does not currently support quantized tensors.") b = self._logical_to_mesh_axes(self.ragged_lengths_names) bsnd = self._logical_to_mesh_axes(self.cache_logical_axis_names) @@ -1219,36 +1150,22 @@ def tpu_flash_attention( sink_axis_names = self._logical_to_mesh_axes((HEAD,)) if decoder_segment_ids is not None: if self.config.expert_shard_attention_option == EP_AS_CONTEXT: - segment_axis_names_q = self._logical_to_mesh_axes( - (BATCH_NO_EXP, Q_LENGTH) - ) - segment_axis_names_kv = self._logical_to_mesh_axes( - (BATCH_NO_EXP, KV_LENGTH) - ) + segment_axis_names_q = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH)) + segment_axis_names_kv = self._logical_to_mesh_axes((BATCH_NO_EXP, KV_LENGTH)) else: - segment_axis_names_q = self._logical_to_mesh_axes( - (BATCH, Q_LENGTH_NO_EXP) - ) + segment_axis_names_q = self._logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP)) segment_axis_names_kv = self._logical_to_mesh_axes((BATCH, KV_LENGTH)) if self.config.expert_shard_attention_option == EP_AS_CONTEXT: - axis_names_splash_kernel = self._logical_to_mesh_axes( - self.flash_axis_names_splash_kernel_ep - ) + axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel_ep) axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q_ep) axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv_ep) - index_mask_axis_names = self._logical_to_mesh_axes( - (BATCH_NO_EXP, Q_LENGTH, KV_LENGTH) - ) + index_mask_axis_names = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH, KV_LENGTH)) else: - axis_names_splash_kernel = self._logical_to_mesh_axes( - self.flash_axis_names_splash_kernel - ) + axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel) axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q) axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv) - index_mask_axis_names = self._logical_to_mesh_axes( - (BATCH, Q_LENGTH, KV_LENGTH) - ) + index_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH)) global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel @@ -1282,9 +1199,7 @@ def create_sa_config(config, query, key, attn_logits_soft_cap): block_kv_compute=min(global_block_kv_compute, key.shape[2]), block_q_dkv=min(global_block_q_dkv, query.shape[2]), block_kv_dkv=min(global_block_kv_dkv, key.shape[2]), - block_kv_dkv_compute=min( - global_block_kv_dkv_compute, query.shape[2] - ), + block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]), use_fused_bwd_kernel=True, # tokamax only supports fused bwd kernel q_layout=tokamax_splash_kernel.QKVLayout[global_q_layout], k_layout=tokamax_splash_kernel.QKVLayout[global_k_layout], @@ -1305,9 +1220,7 @@ def create_sa_config(config, query, key, attn_logits_soft_cap): ) if config.cost_estimate_flops_bwd >= 0 else None, - dq_reduction_steps=config.dq_reduction_steps - if config.dq_reduction_steps > 0 - else None, + dq_reduction_steps=config.dq_reduction_steps if config.dq_reduction_steps > 0 else None, use_experimental_scheduler=config.use_splash_scheduler, ) else: @@ -1317,15 +1230,9 @@ def create_sa_config(config, query, key, attn_logits_soft_cap): block_kv_compute=min(global_block_kv_compute, key.shape[2]), block_q_dkv=min(global_block_q_dkv, query.shape[2]), block_kv_dkv=min(global_block_kv_dkv, key.shape[2]), - block_kv_dkv_compute=min( - global_block_kv_dkv_compute, query.shape[2] - ), - block_q_dq=None - if global_use_fused_bwd_kernel - else min(global_block_q_dq, query.shape[2]), - block_kv_dq=None - if global_use_fused_bwd_kernel - else min(global_block_kv_dq, query.shape[2]), + block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]), + block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]), + block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]), use_fused_bwd_kernel=global_use_fused_bwd_kernel, q_layout=splash_attention_kernel.QKVLayout[global_q_layout], k_layout=splash_attention_kernel.QKVLayout[global_k_layout], @@ -1335,11 +1242,7 @@ def create_sa_config(config, query, key, attn_logits_soft_cap): sa_config = create_sa_config(self.config, query, key, attn_logits_soft_cap) mask_shape = (query.shape[2], key.shape[2]) # (q_seq_len, kv_seq_len) - mask_module = ( - tokamax_splash_mask - if self.config.use_tokamax_splash - else splash_attention_mask - ) + mask_module = tokamax_splash_mask if self.config.use_tokamax_splash else splash_attention_mask if self.attention_type == AttentionType.FULL: mask = mask_module.FullMask(mask_shape) else: @@ -1353,9 +1256,7 @@ def create_sa_config(config, query, key, attn_logits_soft_cap): # Apply local masking if local sliding attention is enabled. if self.attention_type == AttentionType.LOCAL_SLIDING: if self.sliding_window_size is None: - raise ValueError( - "Sliding_window_size must be set if Local Sliding attention type" - ) + raise ValueError("Sliding_window_size must be set if Local Sliding attention type") mask &= mask_module.LocalMask( shape=(query.shape[2], key.shape[2]), window_size=(self.sliding_window_size, self.sliding_window_size), @@ -1363,9 +1264,7 @@ def create_sa_config(config, query, key, attn_logits_soft_cap): ) elif self.attention_type == AttentionType.CHUNK: if self.chunk_attn_window_size is None: - raise ValueError( - "chunk_attn_window_size must be set for chunk attention type" - ) + raise ValueError("chunk_attn_window_size must be set for chunk attention type") mask &= ChunkedCausalMask( shape=(query.shape[2], key.shape[2]), @@ -1377,9 +1276,7 @@ def create_sa_config(config, query, key, attn_logits_soft_cap): # Create mask single_head_mask = mask # tokamax now just uses a single mask and assumes broadcast to all heads if self.config.use_max_logit_estimate > 0: - sa_config = dataclasses.replace( - sa_config, max_logit_const=self.config.use_max_logit_estimate - ) + sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate) # Create the splash attention kernel object separately, jit it for performance @partial( @@ -1397,36 +1294,22 @@ def wrap_splash_kernel(single_head_mask, shard_head_size=1): ) return splash_kernel - logical_axis_rules_head = np.array([ - self.mesh.shape[physical_axes] - for physical_axes in dict(self.config.logical_axis_rules)[HEAD] - ]) + logical_axis_rules_head = np.array( + [self.mesh.shape[physical_axes] for physical_axes in dict(self.config.logical_axis_rules)[HEAD]] + ) shard_head_size = np.prod(logical_axis_rules_head) splash_kernel = wrap_splash_kernel(single_head_mask, int(shard_head_size)) if self.config.expert_shard_attention_option == EP_AS_CONTEXT: - segment_axis_names_splash_kernel = self._logical_to_mesh_axes(( - Q_LENGTH, - )) + segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,)) else: - segment_axis_names_splash_kernel = self._logical_to_mesh_axes(( - Q_LENGTH_NO_EXP, - )) - elif ( - self.config.use_jax_splash - and self.config.expert_shard_attention_option == EP_AS_FSDP - ): + segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH_NO_EXP,)) + elif self.config.use_jax_splash and self.config.expert_shard_attention_option == EP_AS_FSDP: if self.config.use_max_logit_estimate > 0: - sa_config = dataclasses.replace( - sa_config, max_logit_const=self.config.use_max_logit_estimate - ) - segment_axis_names_splash_kernel = nn.logical_to_mesh_axes(( - Q_LENGTH_NO_EXP, - )) + sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate) + segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH_NO_EXP,)) else: # Create multi-head mask - multi_head_mask = splash_attention_mask.MultiHeadMask( - masks=(mask,) * query.shape[1] - ) + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) # Create the splash attention kernel object separately, jit it for performance @partial( @@ -1447,18 +1330,13 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1): ) return splash_kernel - logical_axis_rules_head = np.array([ - self.mesh.shape[physical_axes] - for physical_axes in dict(self.config.logical_axis_rules)[HEAD] - ]) + logical_axis_rules_head = np.array( + [self.mesh.shape[physical_axes] for physical_axes in dict(self.config.logical_axis_rules)[HEAD]] + ) shard_head_size = np.prod(logical_axis_rules_head) splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size)) - named_sharding = jax.sharding.NamedSharding( - self.mesh, axis_names_splash_kernel - ) - segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec( - named_sharding - ) + named_sharding = jax.sharding.NamedSharding(self.mesh, axis_names_splash_kernel) + segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding) # Now call the function wrap_flash_attention which does the actual computation. # The splash kernel is passed as a parameter to the function. Since we have the shard map @@ -1521,12 +1399,8 @@ def wrap_flash_attention( # the K and V to be contiguous. Note that K and V are not sharded over the sequence aka context axis # This was we get the unsharded unpermuted key and value tensors if cp_size > 1 and load_balanced_context_parallel: - key = max_utils.reorder_sequence( - tensor=key, cp_size=cp_size, seq_dim=2, to_contiguous=True - ) - value = max_utils.reorder_sequence( - tensor=value, cp_size=cp_size, seq_dim=2, to_contiguous=True - ) + key = max_utils.reorder_sequence(tensor=key, cp_size=cp_size, seq_dim=2, to_contiguous=True) + value = max_utils.reorder_sequence(tensor=value, cp_size=cp_size, seq_dim=2, to_contiguous=True) decoder_segment_ids_unpermuted = max_utils.reorder_sequence( tensor=decoder_segment_ids_kv, cp_size=cp_size, @@ -1541,9 +1415,7 @@ def wrap_flash_attention( ) else: # if cp=1, decoder_segment_ids_q is the same as decoder_segment_ids_kv - decoder_segment_ids_tuple = splash_attention_kernel.SegmentIds( - decoder_segment_ids_q, decoder_segment_ids_kv - ) + decoder_segment_ids_tuple = splash_attention_kernel.SegmentIds(decoder_segment_ids_q, decoder_segment_ids_kv) else: decoder_segment_ids_tuple = None @@ -1568,14 +1440,10 @@ def dynamic_mask_splash_kernel(q, k, v, segment, sinks, index_mask): index_mask = jnp.isclose(index_mask, 0.0) if record_max_logits: - attention_output, max_logits = attn_fn( - query, key, value, decoder_segment_ids_tuple, sinks, index_mask - ) + attention_output, max_logits = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, index_mask) return attention_output, max_logits else: - attention_output, _ = attn_fn( - query, key, value, decoder_segment_ids_tuple, sinks, index_mask - ) + attention_output, _ = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, index_mask) return attention_output, None else: kernel = partial(splash_kernel, max_logit_value=max_logit_value) @@ -1621,9 +1489,7 @@ def kernel_fn(q, k, v, d, s): return attention_output, None - def _maybe_shard_with_pspec( - inputs, pspec: jax.sharding.PartitionSpec | None - ): + def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None): # decoder_segment_ids can be None if pspec is None: return None @@ -1639,12 +1505,8 @@ def _maybe_shard_with_pspec( query = _maybe_shard_with_pspec(query, axis_names_q) key = _maybe_shard_with_pspec(key, axis_names_kv) value = _maybe_shard_with_pspec(value, axis_names_kv) - decoder_segment_ids_q = _maybe_shard_with_pspec( - decoder_segment_ids, segment_axis_names_q - ) - decoder_segment_ids_kv = _maybe_shard_with_pspec( - decoder_segment_ids, segment_axis_names_kv - ) + decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q) + decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv) sinks = _maybe_shard_with_pspec(sinks, sink_axis_names) index_mask = _maybe_shard_with_pspec(index_mask, index_mask_axis_names) @@ -1701,12 +1563,8 @@ def cudnn_flash_attention( # Initialize default attention configuration sliding_window_size = None mask_type = "padding_causal" - qkv_layout = ( # Non-packed format: 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' - "BSHD_BSHD_BSHD" - ) - max_segments_per_seq = ( - 1 # max number of segments per sequence; for non-packed its 1 - ) + qkv_layout = "BSHD_BSHD_BSHD" # Non-packed format: 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + max_segments_per_seq = 1 # max number of segments per sequence; for non-packed its 1 # Handle local sliding window attention if configured if self.attention_type == AttentionType.LOCAL_SLIDING: @@ -1714,25 +1572,17 @@ def cudnn_flash_attention( # Handle packing configurations if self.config.packing and self.config.dataset_type != "synthetic": - qkv_layout = ( # Packed format: 'T3HD', 'THD_T2HD' or 'THD_THD_THD' - "THD_THD_THD" - ) + qkv_layout = "THD_THD_THD" # Packed format: 'T3HD', 'THD_T2HD' or 'THD_THD_THD' if decoder_segment_ids is None: decoder_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) - attn_mask = SequenceDescriptor.from_segment_ids_and_pos( - segment_ids=decoder_segment_ids, segment_pos=None - ) + attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=decoder_segment_ids, segment_pos=None) # Create dummy SequenceDescriptor for lazy_init dummy_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) - dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos( - segment_ids=dummy_segment_ids, segment_pos=None - ) + dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=dummy_segment_ids, segment_pos=None) max_segments_per_seq = self.config.max_segments_per_seq elif using_context_parallelism: if self.attention_type == AttentionType.LOCAL_SLIDING: - raise AssertionError( - "Sliding window attention is not supported for context parallelism" - ) + raise AssertionError("Sliding window attention is not supported for context parallelism") # Context parallelism without packing: only supports causal masking attn_mask = None dummy_attn_mask = None @@ -1743,12 +1593,8 @@ def cudnn_flash_attention( (1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8, ) - attn_mask = self.generate_attention_mask( - query, key, decoder_segment_ids, model_mode - ) - attn_mask = jnp.where( - (attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1 - ).astype(jnp.uint8) + attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) + attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8) dpa_layer = DotProductAttention( head_dim=head_dim, @@ -1873,9 +1719,7 @@ def compute_local_attention( if sinks is not None: # broadcast sinks to match the attn weights dimension and combine sinks_param = sinks.astype(attn_weights.dtype) # (n_q,) - sinks_logits = sinks_param[ - jnp.newaxis, :, jnp.newaxis, jnp.newaxis - ] # (1, n_q, 1, 1) + sinks_logits = sinks_param[jnp.newaxis, :, jnp.newaxis, jnp.newaxis] # (1, n_q, 1, 1) sinks_logits = jnp.broadcast_to(sinks_logits, (b, n_q, t, 1)) logits = jnp.concatenate([logits, sinks_logits], axis=-1) @@ -1890,44 +1734,26 @@ def compute_local_attention( local_max = jnp.transpose(local_max, (0, 2, 1, 3)) # (b, t, n_q, 1) local_sum = jnp.transpose(local_sum, (0, 2, 1, 3)) # (b, t, n_q, 1) - local_out = self.wv_product( - local_exps, value, model_mode, wv_product_einsum - ) - if model_mode == MODEL_MODE_AUTOREGRESSIVE and self.is_partition_in_decode( - q_seq_len - ): - local_out = partitioning.with_sharding_constraint( - local_out, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV) - ) + local_out = self.wv_product(local_exps, value, model_mode, wv_product_einsum) + if model_mode == MODEL_MODE_AUTOREGRESSIVE and self.is_partition_in_decode(q_seq_len): + local_out = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)) elif model_mode == MODEL_MODE_PREFILL: - local_out = partitioning.with_sharding_constraint( - local_out, (BATCH, KV_LENGTH, HEAD, D_KV) - ) + local_out = partitioning.with_sharding_constraint(local_out, (BATCH, KV_LENGTH, HEAD, D_KV)) if self.reshape_q and q_seq_len == 1: local_max = local_max[:, 0:1, :, :] local_sum = local_sum[:, 0:1, :, :] local_out = local_out[:, 0:1, :, :] - if model_mode == MODEL_MODE_AUTOREGRESSIVE and self.is_partition_in_decode( - q_seq_len - ): - local_max = partitioning.with_sharding_constraint( - local_max, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV) - ) - local_sum = partitioning.with_sharding_constraint( - local_sum, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV) - ) - local_out = partitioning.with_sharding_constraint( - local_out, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV) - ) + if model_mode == MODEL_MODE_AUTOREGRESSIVE and self.is_partition_in_decode(q_seq_len): + local_max = partitioning.with_sharding_constraint(local_max, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)) + local_sum = partitioning.with_sharding_constraint(local_sum, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)) + local_out = partitioning.with_sharding_constraint(local_out, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)) return local_out, local_max, local_sum def is_partition_in_decode(self, seq_len): - return ( - self.config.ici_context_autoregressive_parallelism > 0 and seq_len == 1 - ) + return self.config.ici_context_autoregressive_parallelism > 0 and seq_len == 1 def apply_attention_dot( self, @@ -1961,53 +1787,27 @@ def apply_attention_dot( if self.is_partition_in_decode(q_seq_len): query = partitioning.with_sharding_constraint(query, decode_qkv_sharding) # avoid sharding scale tensor when using kv cache quantization - if ( - self.kv_quant - and isinstance(key, KVTensor) - and isinstance(value, KVTensor) - ): - key.qvalue = partitioning.with_sharding_constraint( - key.qvalue, decode_qkv_sharding - ) - value.qvalue = partitioning.with_sharding_constraint( - value.qvalue, decode_qkv_sharding - ) + if self.kv_quant and isinstance(key, KVTensor) and isinstance(value, KVTensor): + key.qvalue = partitioning.with_sharding_constraint(key.qvalue, decode_qkv_sharding) + value.qvalue = partitioning.with_sharding_constraint(value.qvalue, decode_qkv_sharding) else: key = partitioning.with_sharding_constraint(key, decode_qkv_sharding) - value = partitioning.with_sharding_constraint( - value, decode_qkv_sharding - ) + value = partitioning.with_sharding_constraint(value, decode_qkv_sharding) elif model_mode == MODEL_MODE_PREFILL: query = partitioning.with_sharding_constraint(query, prefill_qkv_sharding) # avoid sharding scale tensor when using kv cache quantization - if ( - self.kv_quant - and isinstance(key, KVTensor) - and isinstance(value, KVTensor) - ): - key.qvalue = partitioning.with_sharding_constraint( - key.qvalue, prefill_qkv_sharding - ) - value.qvalue = partitioning.with_sharding_constraint( - value.qvalue, prefill_qkv_sharding - ) + if self.kv_quant and isinstance(key, KVTensor) and isinstance(value, KVTensor): + key.qvalue = partitioning.with_sharding_constraint(key.qvalue, prefill_qkv_sharding) + value.qvalue = partitioning.with_sharding_constraint(value.qvalue, prefill_qkv_sharding) else: key = partitioning.with_sharding_constraint(key, prefill_qkv_sharding) - value = partitioning.with_sharding_constraint( - value, prefill_qkv_sharding - ) + value = partitioning.with_sharding_constraint(value, prefill_qkv_sharding) - attn_weights = self.qk_product( - query, key, q_seq_len, model_mode, qk_product_einsum - ) + attn_weights = self.qk_product(query, key, q_seq_len, model_mode, qk_product_einsum) if self.is_partition_in_decode(q_seq_len): - attn_weights = partitioning.with_sharding_constraint( - attn_weights, (KV_LENGTH, HEAD, None, None, None) - ) + attn_weights = partitioning.with_sharding_constraint(attn_weights, (KV_LENGTH, HEAD, None, None, None)) elif model_mode == MODEL_MODE_PREFILL: - attn_weights = partitioning.with_sharding_constraint( - attn_weights, (BATCH, HEAD, None, PREFILL_LENGTH, KV_LENGTH) - ) + attn_weights = partitioning.with_sharding_constraint(attn_weights, (BATCH, HEAD, None, PREFILL_LENGTH, KV_LENGTH)) if self.attn_logits_soft_cap: attn_weights = jnp.tanh(attn_weights / self.attn_logits_soft_cap) @@ -2052,13 +1852,9 @@ def apply_attention_dot( attn_weights = apply_mask_to_logits(attn_weights, index_mask) if self.is_partition_in_decode(q_seq_len): - attn_mask = partitioning.with_sharding_constraint( - attn_mask, (KV_LENGTH, HEAD, None, None, None) - ) + attn_mask = partitioning.with_sharding_constraint(attn_mask, (KV_LENGTH, HEAD, None, None, None)) elif model_mode == MODEL_MODE_PREFILL: - attn_mask = partitioning.with_sharding_constraint( - attn_mask, (BATCH, HEAD, None, PREFILL_LENGTH, KV_LENGTH) - ) + attn_mask = partitioning.with_sharding_constraint(attn_mask, (BATCH, HEAD, None, PREFILL_LENGTH, KV_LENGTH)) if attn_mask is not None: attn_weights = apply_mask_to_logits(attn_weights, attn_mask) @@ -2073,9 +1869,7 @@ def apply_attention_dot( max_logits = max_logits_per_group.reshape(b, n_kv * g) self.sow("intermediates", "max_logits", max_logits) - return self.compute_local_attention( - attn_weights, value, q_seq_len, model_mode, wv_product_einsum, sinks - ) + return self.compute_local_attention(attn_weights, value, q_seq_len, model_mode, wv_product_einsum, sinks) def qk_product( self, @@ -2106,11 +1900,7 @@ def qk_product( b, t, n, d = query.shape n_kv = key.shape[-2] assert n_kv == self.num_kv_heads - precision_kwargs = ( - {"precision": self.config.matmul_precision} - if einsum is jnp.einsum - else {} - ) + precision_kwargs = {"precision": self.config.matmul_precision} if einsum is jnp.einsum else {} if model_mode == MODEL_MODE_TRAIN or self.compute_axis_order == ( 0, 1, @@ -2123,9 +1913,7 @@ def qk_product( result = einsum("btkgd,bskd->bkgts", query, key, **precision_kwargs) elif self.compute_axis_order == (0, 2, 1, 3): query = jnp.transpose(query, axes=self.compute_axis_order) - key = jax.tree.map( - lambda x: jnp.transpose(x, axes=self.compute_axis_order), key - ) + key = jax.tree.map(lambda x: jnp.transpose(x, axes=self.compute_axis_order), key) query = jnp.reshape(query, (b, n_kv, n // n_kv, t, d)) if self.reshape_q and q_seq_len == 1: query = jnp.broadcast_to(query, (b, n_kv, n // n_kv, 2, d)) @@ -2161,17 +1949,10 @@ def wv_product( n // n_kv: number of group for query, sometimes annotated with g """ - precision_kwargs = ( - {"precision": self.config.matmul_precision} - if einsum is jnp.einsum - else {} - ) + precision_kwargs = {"precision": self.config.matmul_precision} if einsum is jnp.einsum else {} if self.kv_quant: # manually cast to bf16 to avoid the fp32 XLA ops for speedup - if ( - isinstance(value, KVTensor) - and self.kv_quant.dtype == jnp.float8_e4m3fn - ): + if isinstance(value, KVTensor) and self.kv_quant.dtype == jnp.float8_e4m3fn: value.qvalue = value.qvalue.astype(jnp.bfloat16) if model_mode == MODEL_MODE_TRAIN or self.compute_axis_order == ( 0, @@ -2183,9 +1964,7 @@ def wv_product( b, t, n_kv, g, d = out.shape result = jnp.reshape(out, (b, t, n_kv * g, d)) elif self.compute_axis_order == (0, 2, 1, 3): - value = jax.tree.map( - lambda x: jnp.transpose(x, axes=self.compute_axis_order), value - ) + value = jax.tree.map(lambda x: jnp.transpose(x, axes=self.compute_axis_order), value) out = einsum("bkgts,bksd->bkgtd", attn_weights, value, **precision_kwargs) b, n_kv, g, t, d = out.shape result = jnp.reshape(out, (b, n_kv * g, t, d)) @@ -2193,9 +1972,7 @@ def wv_product( return result def reverse_transepose(self, transposed_array, transpose_axis_order): - return jax.numpy.moveaxis( - transposed_array, (0, 1, 2, 3), transpose_axis_order - ) + return jax.numpy.moveaxis(transposed_array, (0, 1, 2, 3), transpose_axis_order) def normalize_cudnn_attention(self, local_outs, local_stats): """Normalize across two cuDNN attentions @@ -2215,13 +1992,9 @@ def normalize_cudnn_attention(self, local_outs, local_stats): stat1 = local_stats[1].reshape((*local_stats[1].shape, 1)) global_stat = jnp.log(jnp.exp(stat0) + jnp.exp(stat1)) # # transpose stat to have shape [b, t, n, 1] for elemenwise multiplication - attn_out = local_outs[0].astype(jnp.float32) * jnp.exp( - stat0 - global_stat - ).transpose((0, 2, 1, 3)) + local_outs[1].astype(jnp.float32) * jnp.exp( - stat1 - global_stat - ).transpose( - (0, 2, 1, 3) - ) + attn_out = local_outs[0].astype(jnp.float32) * jnp.exp(stat0 - global_stat).transpose((0, 2, 1, 3)) + local_outs[ + 1 + ].astype(jnp.float32) * jnp.exp(stat1 - global_stat).transpose((0, 2, 1, 3)) return attn_out.astype(local_stats[0].dtype) def normalize_attention(self, local_outs, local_maxes, local_sums): @@ -2240,10 +2013,9 @@ def normalize_attention(self, local_outs, local_maxes, local_sums): """ # Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py global_max = functools.reduce(jnp.maximum, local_maxes) - global_sum = sum(( - jnp.exp(local_max - global_max) * local_sum - for (local_sum, local_max) in zip(local_sums, local_maxes) - )) + global_sum = sum( + (jnp.exp(local_max - global_max) * local_sum for (local_sum, local_max) in zip(local_sums, local_maxes)) + ) attn_out = 0 for local_max, local_out in zip(local_maxes, local_outs): @@ -2275,7 +2047,19 @@ def __call__( assert prefill_kv_cache key, value, decoder_segment_ids = prefill_kv_cache - prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention( + index_mask_prefill = None + index_mask_ar = None + if index_mask is not None: + prefill_len = key.shape[1] + index_mask_prefill = index_mask[:, :, :prefill_len] + if ar_kv_cache is not None: + index_mask_ar = index_mask[:, :, prefill_len:] + + ( + prefill_unnormalized_output, + prefill_exponentials_max, + prefill_exponentials_sum, + ) = self.apply_attention( query=query, key=key, value=value, @@ -2286,8 +2070,8 @@ def __call__( previous_chunk=previous_chunk, bidirectional_mask=bidirectional_mask, sinks=sinks, - index_mask=index_mask, record_max_logits=record_max_logits, + index_mask=index_mask_prefill, qk_product_einsum=self.AqtEinsum_0, wv_product_einsum=self.AqtEinsum_1, ) @@ -2299,7 +2083,6 @@ def __call__( return prefill_unnormalized_output key, value, decoder_segment_ids, lengths = ar_kv_cache - ar_unnormalized_output, ar_exponentials_max, ar_exponentials_sum = self.apply_attention( query=query, key=key, @@ -2309,6 +2092,7 @@ def __call__( model_mode=model_mode, use_ragged_attention=self.use_ragged_attention, bidirectional_mask=bidirectional_mask, + index_mask=index_mask_ar, qk_product_einsum=self.AqtEinsum_2, wv_product_einsum=self.AqtEinsum_3, ) @@ -2320,18 +2104,13 @@ def __call__( ] exponentials_maxes = [prefill_exponentials_max, ar_exponentials_max] exponentials_sums = [prefill_exponentials_sum, ar_exponentials_sum] - if ( - prefill_exponentials_max is not None - and prefill_exponentials_sum is None - ): + if prefill_exponentials_max is not None and prefill_exponentials_sum is None: prefill_stat = prefill_exponentials_max ar_stat = ar_exponentials_max stats = [prefill_stat, ar_stat] return self.normalize_cudnn_attention(unnormalized_outputs, stats) else: - return self.normalize_attention( - unnormalized_outputs, exponentials_maxes, exponentials_sums - ) + return self.normalize_attention(unnormalized_outputs, exponentials_maxes, exponentials_sums) else: return prefill_unnormalized_output / prefill_exponentials_sum @@ -2369,9 +2148,7 @@ def causal_mask_function(q_ids, kv_ids): arr = np.arange(shape[0]) # we reorder the mask to be load balanced following the same approach as # used to reorder the input tokens - out = max_utils.reorder_mask_load_balancing( - arr[None, :, None, None], cp_size, 1 - ) + out = max_utils.reorder_mask_load_balancing(arr[None, :, None, None], cp_size, 1) q_sequence = out[0, :, 0, 0] mask_function = causal_mask_function @@ -2387,16 +2164,14 @@ def __eq__(self, other: object): if not isinstance(other, type(self)): return NotImplemented - return ( - self.shape == other.shape - and self.offset == other.offset - and np.array_equal(self.q_sequence, other.q_sequence) - ) + return self.shape == other.shape and self.offset == other.offset and np.array_equal(self.q_sequence, other.q_sequence) def __hash__(self): - return hash(( - type(self), - self.shape, - self.offset, - self.q_sequence.tobytes() if self.q_sequence is not None else None, - )) + return hash( + ( + type(self), + self.shape, + self.offset, + self.q_sequence.tobytes() if self.q_sequence is not None else None, + ) + ) diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 608f1d4d14..f9cc66e812 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -1192,7 +1192,7 @@ class MLATest(attention_test_util.MLATestBase): {"testcase_name": "Default_Autoregression", "rope_type": "default"}, ) @pytest.mark.tpu_only - def test_autoregression(self, rope_type): + def test_mla_autoregression(self, rope_type): cfg, mla = self.init_mla(self.config_arguments, rope_type) prefill_length = cfg.max_prefill_predict_length decode_total_length = cfg.max_target_length @@ -1237,8 +1237,71 @@ def test_autoregression(self, rope_type): mla_full_this_idx = mla_full[:, idx : idx + 1, :] self.assertEqual(mla_full_this_idx.shape, mla_idx.shape) - # TODO (b/394626702) uncomment last check when decode and kv_cache are implemented for MLA - # self.assertTrue(jax.numpy.allclose(mla_full_this_idx, mla_idx, rtol=1e-02, atol=1e-02, equal_nan=False)) + self.assertTrue(jax.numpy.allclose(mla_full_this_idx, mla_idx, rtol=2e-02, atol=2e-02, equal_nan=False)) + + @parameterized.named_parameters( + {"testcase_name": "prefill_less_than_topk", "prefill_len": 4, "target_len": 12}, + {"testcase_name": "prefill_greater_than_topk", "prefill_len": 12, "target_len": 16}, + ) + @pytest.mark.tpu_only + def test_indexer_autoregression(self, prefill_len, target_len): + config_arguments = self.config_arguments.copy() + config_arguments.update( + { + "use_sparse_indexer": True, + "index_n_heads": 4, + "index_head_dim": 64, + "index_topk": 8, + "attention": "dot_product", + "max_target_length": target_len, + "max_prefill_predict_length": prefill_len, + "per_device_batch_size": 1, + } + ) + cfg, mla = self.init_mla(config_arguments, "yarn") + prefill_length = cfg.max_prefill_predict_length + decode_total_length = cfg.max_target_length + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(cfg, cfg.dtype) + mla_full, _ = mla( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + + lnx_prefill = lnx[:, 0:prefill_length, :] + decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] + decoder_positions_prefill = decoder_positions[:, 0:prefill_length] + + mla_prefill, _ = mla( + lnx_prefill, + lnx_prefill, + decoder_segment_ids=decoder_segment_ids_prefill, + inputs_positions=decoder_positions_prefill, + deterministic=True, + model_mode=MODEL_MODE_PREFILL, + ) + + self.assertTrue( + jax.numpy.allclose(mla_prefill, mla_full[:, :prefill_length, :], rtol=1e-02, atol=1e-02, equal_nan=False) + ) + + for idx in range(prefill_length, decode_total_length): + lnx_idx = lnx[:, idx : idx + 1, :] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] + mla_idx, _ = mla( + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + ) + + mla_full_this_idx = mla_full[:, idx : idx + 1, :] + self.assertEqual(mla_full_this_idx.shape, mla_idx.shape) + self.assertTrue(jax.numpy.allclose(mla_full_this_idx, mla_idx, rtol=2e-02, atol=2e-02, equal_nan=False)) def test_projection_initialization(self): """Tests that MLA and Attention layers initialize the correct projection weights.""" diff --git a/tests/utils/attention_test_util.py b/tests/utils/attention_test_util.py index d47086863a..5542be9bfe 100644 --- a/tests/utils/attention_test_util.py +++ b/tests/utils/attention_test_util.py @@ -30,6 +30,7 @@ from maxtext.utils.sharding import maybe_shard_with_name from tests.utils.test_helpers import get_test_config_path + class MLATestBase(parameterized.TestCase): """Test base for MLATest.""" @@ -46,6 +47,8 @@ class MLATestBase(parameterized.TestCase): "qk_nope_head_dim": 128, "qk_rope_head_dim": 64, "v_head_dim": 192, + "dtype": "float32", + "mla_naive_kvcache": False, # TODO: Test both naive/non-naive modes once b/485997160 is resolved. } def setUp(self):