Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1067,18 +1067,6 @@ context_parallel_load_balance: True
context_parallel_strategy: "all_gather" # "all_gather" or "ring"
context_parallel_reorder_strategy: "auto" # "auto", "dual_chunk_swap", or "striped"

### Paged Attention ###
# These settings take effect only when `attention=paged`.
# They should be adjusted based on the available HBM and model config.
# Note: one page group corresponds to one request/slot
pagedattn_num_pages: 64 # total number of pages to allocate
pagedattn_tokens_per_page: 32 # number of tokens each page can hold
pagedattn_pages_per_compute_block: 4 # number of pages processed together in pallas kernels
pagedattn_max_pages_per_group: -1 # defaults to number of pages needed to reach max_target_length
# Alignment of head_dim to the nearest multiple of this value, set to 0 to disable alignment. On
# TPUs, the head_dim is padded to the nearest multiple of 128.
pagedattn_head_dim_alignment: 128


# Chunked Prefill Parameters
prefill_chunk_size: 256
Expand Down
8 changes: 1 addition & 7 deletions src/maxtext/configs/pyconfig_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def validate_attention_kernel(s: str) -> None:
"flash",
"cudnn_flash_te",
"cudnn_flash_jax",
"paged",
"vllm_rpa",
)
if s not in valid_attention_kernels: # currently supported attention
Expand All @@ -119,7 +118,7 @@ def validate_attention_type(s: str) -> None:


def validate_moba_attention(moba, attention) -> None:
if moba and attention in ("autoselected", "flash", "cudnn_flash_te", "cudnn_flash_jax", "paged"):
if moba and attention in ("autoselected", "flash", "cudnn_flash_te", "cudnn_flash_jax"):
raise ValueError("MoBA is only supported dot_product attention")


Expand Down Expand Up @@ -816,11 +815,6 @@ def user_init(raw_keys):
)
raw_keys["shardy"] = False

if raw_keys["pagedattn_max_pages_per_group"] <= 0:
raw_keys["pagedattn_max_pages_per_group"] = (
raw_keys["max_target_length"] + raw_keys["pagedattn_tokens_per_page"] - 1
) // raw_keys["pagedattn_tokens_per_page"]

raw_keys["num_slices"] = max_utils.get_num_slices(raw_keys)
raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
raw_keys["context_parallel_size"] = get_context_parallel_size(raw_keys)
Expand Down
14 changes: 0 additions & 14 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,18 +698,6 @@ class SplashAttention(BaseModel):
use_splash_scheduler: bool = Field(False, description="Use experimental splash attention scheduler.")


class PagedAttention(BaseModel):
"""Tunable parameters for Paged Attention kernels."""

pagedattn_num_pages: int = Field(64, description="Total number of pages to allocate for paged attention.")
pagedattn_tokens_per_page: int = Field(32, description="Number of tokens each page can hold.")
pagedattn_pages_per_compute_block: int = Field(4, description="Number of pages processed together in pallas kernels.")
pagedattn_max_pages_per_group: int = Field(-1, description="Max pages per request; -1 defaults to max_target_length.")
# Alignment of head_dim to the nearest multiple of this value, set to 0 to disable alignment. On
# TPUs, the head_dim is padded to the nearest multiple of 128.
pagedattn_head_dim_alignment: int = Field(128, description="Alignment of head_dim to the nearest multiple.")


class MoEGeneral(BaseModel):
"""General configuration for Mixture of Experts (MoE) layers."""

Expand Down Expand Up @@ -2260,7 +2248,6 @@ class MaxTextConfig(
AttentionIndexer,
Llama4Attention,
SplashAttention,
PagedAttention,
# Mixture of Experts
MoEGeneral,
MoEKernels,
Expand Down Expand Up @@ -3208,5 +3195,4 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
f"For qwen3_custom_moe, moe_expert_input_dim ({self.moe_expert_input_dim}) "
f"must be equal to attention_output_dim ({self.attention_output_dim})"
)

return self
100 changes: 11 additions & 89 deletions src/maxtext/inference/maxengine/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from maxtext.models import models
from maxtext.layers import quantizations
from maxtext.inference import inference_utils
from maxtext.inference.page_manager import PageManager, PageState
from maxtext.inference.page_manager import PageState
from maxtext.multimodal import processor as mm_processor
from maxtext.utils import lora_utils
from maxtext.utils import max_utils
Expand Down Expand Up @@ -132,9 +132,6 @@ def __init__(self, config: Any, devices: Any | None = None):
# Initialize page manager and page state
self.page_manager = None
self.page_state = None
if self.config.attention == "paged":
self.page_manager = PageManager(self.config)
self.page_state = self.page_manager.get_initial_page_state()

def print_stats(self, label: str):
max_utils.print_mem_stats(label)
Expand Down Expand Up @@ -250,7 +247,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
)

self.prefill_kv_cache_annotations = maxtext_utils.get_prefill_kv_cache_annotations(
self.model, self.config, rng2, self._mesh, self.page_state
self.model, self.config, rng2, self._mesh, None
)
self.prefill_kv_cache_shardings = jax.tree_util.tree_map(
lambda x: jax.sharding.NamedSharding(self._mesh, x),
Expand All @@ -265,9 +262,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
)
self.prefill_kv_cache_shardings = self.prefill_kv_cache_shardings["decoder"]["layers_0"]

self.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations(
self.model, self.config, rng2, self._mesh, self.page_state
)
self.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations(self.model, self.config, rng2, self._mesh, None)
self.kv_cache_shardings = jax.tree_util.tree_map(
lambda x: jax.sharding.NamedSharding(self._mesh, x),
self.kv_cache_annotations,
Expand Down Expand Up @@ -612,13 +607,6 @@ def prefill(
temperature: float | None = None,
): # returns (new_prefix, result_tokens)
"""Public API for prefill that updates page state outside JIT."""
# Update page state before JIT call
if self.config.attention == "paged" and self.page_manager is not None and self.page_state is not None:
self.page_state = self.page_manager.update_prefill_pages( # pytype: disable=attribute-error
page_state=self.page_state,
page_group_id=slot,
true_length=true_length,
)

# Sample rng before JIT call
if rng is None:
Expand All @@ -639,8 +627,6 @@ def prefill(
audio_masks=audio_masks,
sampler=sampler,
true_length=true_length,
page_state=self.page_state, # Pass current page state
slot=slot,
rng=rng,
return_prompt_logp=return_prompt_logp,
algorithm=algorithm,
Expand Down Expand Up @@ -955,8 +941,6 @@ def generate(
"""Public API for generate that updates page state outside JIT."""

# Update page state before JIT call
if self.page_manager is not None and self.page_state is not None:
self.page_state = self.page_manager.update_decode_pages(self.page_state)

# Sample rng before JIT call
if rng is None:
Expand All @@ -969,7 +953,6 @@ def generate(
params=params,
decode_state=decode_state,
sampler=sampler,
page_state=self.page_state,
rng=rng,
algorithm=algorithm,
topk=topk,
Expand Down Expand Up @@ -1213,7 +1196,6 @@ def _insert_jit(
decode_state: DecodeState,
slot: int,
request_id: uuid.UUID | None = None, # pylint: disable=unused-argument
page_state_in: PageState | None = None,
) -> DecodeState:
"""Insert a single computed prefill cache into KV cache."""
unboxed_prefix = max_utils.unbox_logicallypartioned(prefix)
Expand Down Expand Up @@ -1269,45 +1251,12 @@ def copy(path, partial_cache, full_cache, annotations):
else:
raise ValueError(f"We don't have a strategy for inserting {path_key}")

if self.config.attention == "paged" and self.page_state is not None:

def _copy_paged(path, prefix_cache, decode_state_cache):
path_key = path[-1].key
if path_key in ["key_pages", "value_pages"]:
page_map_for_slot = page_state_in.page_map[slot] # pytype: disable=attribute-error
num_pages_to_copy = page_state_in.num_pages_used[slot] # pytype: disable=attribute-error

def _update_pages(prefix_page_idx, state):
decode_state_pages, prefix_pages, current_page_map = state
prefix_page = jax.lax.dynamic_index_in_dim(prefix_pages, prefix_page_idx, axis=1)
dest_page_idx = current_page_map[prefix_page_idx]
decode_state_pages = jax.lax.dynamic_update_slice_in_dim(
decode_state_pages, prefix_page, dest_page_idx, axis=1
)
return decode_state_pages, prefix_pages, current_page_map

decode_state_cache, _, _ = jax.lax.fori_loop(
0,
num_pages_to_copy,
_update_pages,
(decode_state_cache, prefix_cache, page_map_for_slot),
)
return decode_state_cache
else:
raise ValueError(f"We don't have a strategy for inserting {path_key} for paged attention.")

inserted_cache = jax.tree_util.tree_map_with_path(
_copy_paged,
unboxed_prefix["cache"],
decode_state["cache"],
)
else:
inserted_cache = jax.tree_util.tree_map_with_path(
copy,
unboxed_prefix["cache"],
decode_state["cache"],
self.kv_cache_annotations_named,
)
inserted_cache = jax.tree_util.tree_map_with_path(
copy,
unboxed_prefix["cache"],
decode_state["cache"],
self.kv_cache_annotations_named,
)

inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state["logits"], unboxed_prefix["logits"], slot, 0)
inserted_next_pos = jax.lax.dynamic_update_index_in_dim(
Expand Down Expand Up @@ -1349,23 +1298,11 @@ def insert(
) -> DecodeState:
"""Non-JIT wrapper for inserting prefill cache."""

current_page_state = None
if self.config.attention == "paged" and self.page_manager is not None:
if self.page_state is None:
self.page_state = self.page_manager.get_initial_page_state()
current_page_state = self.page_state

updated_decode_state = self._insert_jit(
prefix=prefix,
decode_state=decode_state,
slot=slot,
page_state_in=current_page_state,
)

# Update the PageState after the JIT call
if self.config.attention == "paged" and self.page_manager is not None and self.page_state is not None:
new_has_active_page = self.page_state.has_active_page.at[slot].set(True)
self.page_state = self.page_state.replace(has_active_page=new_has_active_page)
return updated_decode_state

@functools.partial(
Expand Down Expand Up @@ -1513,16 +1450,6 @@ def copy(path, partial_cache, full_cache, annotations):
"token_logp": inserted_token_logp,
}

def release_pages(self, slot: int):
"""Releases pages associated with a specific slot (page group) via the PageManager."""
if self.config.attention != "paged" or self.page_manager is None or self.page_state is None:
print(f"Warning: release_pages called for slot {slot} but paged attention is not configured or state is missing.")
return
new_page_state = self.page_manager.release_pages(
page_state=self.page_state, page_group_id=slot
) # pytype: disable=attribute-error
self.page_state = new_page_state

def get_prefix_destination_sharding(self) -> Any:
return {
"logits": self.replicated_sharding,
Expand Down Expand Up @@ -1592,12 +1519,9 @@ def init_decode_state(
"""Initialises any state which a generation step transforms."""
if rng is None:
rng = jax.random.PRNGKey(0)
page_state = None
if self.config.attention == "paged" and self.page_manager is not None:
page_state = self.page_manager.get_initial_page_state() # pytype: disable=attribute-error

# pylint: disable=unused-argument
def init(abstract_params, page_state):
def init(abstract_params):
x = jnp.ones(
(int(self.config.per_device_batch_size * self.mesh.size), 1),
dtype=jnp.int32,
Expand All @@ -1622,8 +1546,6 @@ def init(abstract_params, page_state):
model_mode=MODEL_MODE_AUTOREGRESSIVE,
rngs={"params": rng},
mutable=["cache"],
page_state=page_state,
slot=0,
)

next_pos = jnp.zeros(
Expand Down Expand Up @@ -1658,7 +1580,7 @@ def init(abstract_params, page_state):
}

with nn_partitioning.axis_rules(self.config.logical_axis_rules):
abstract_outputs = jax.eval_shape(init, self.abstract_params, page_state)
abstract_outputs = jax.eval_shape(init, self.abstract_params)
logical_annotations = nn.get_partition_spec(abstract_outputs)

with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
Expand Down
Loading
Loading