Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions benchmarks/maxtext_trillium_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1721,7 +1721,7 @@
model_type="gemma3-12b",
tuning_params={
"per_device_batch_size": 1,
"num_vocab_tiling": 16,
"num_batch_seq_tiling": 16,
"ici_fsdp_parallelism": -1,
"remat_policy": "custom",
"decoder_layer_input": "device",
Expand All @@ -1739,7 +1739,9 @@
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 2,
"tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
"tokenizer_path": os.path.join(
"assets", "tokenizers", "tokenizer.gemma3"
),
"sa_block_q": 1024,
"sa_block_kv": 1024,
"sa_block_kv_compute": 1024,
Expand All @@ -1760,7 +1762,7 @@
model_type="gemma3-12b",
tuning_params={
"per_device_batch_size": 1,
"num_vocab_tiling": 16,
"num_batch_seq_tiling": 16,
"ici_fsdp_parallelism": 1,
"ici_fsdp_transpose_parallelism": -1,
"remat_policy": "custom",
Expand All @@ -1779,7 +1781,9 @@
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 2,
"tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
"tokenizer_path": os.path.join(
"assets", "tokenizers", "tokenizer.gemma3"
),
"sa_block_q": 1024,
"sa_block_kv": 1024,
"sa_block_kv_compute": 1024,
Expand All @@ -1800,7 +1804,7 @@
model_type="gemma3-12b",
tuning_params={
"per_device_batch_size": 1,
"num_vocab_tiling": 16,
"num_batch_seq_tiling": 16,
"ici_fsdp_parallelism": 1,
"ici_fsdp_transpose_parallelism": -1,
"remat_policy": "custom",
Expand All @@ -1819,7 +1823,9 @@
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 2,
"tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
"tokenizer_path": os.path.join(
"assets", "tokenizers", "tokenizer.gemma3"
),
"sa_block_q": 1024,
"sa_block_kv": 1024,
"sa_block_kv_compute": 1024,
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/maxtext_v5p_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"remat_policy": "custom",
"context": "offload",
"mlpwo": "offload",
"num_vocab_tiling": 4,
"num_batch_seq_tiling": 4,
"sa_block_q": 2048,
"sa_block_kv": 2048,
"sa_block_kv_compute": 2048,
Expand Down
4 changes: 2 additions & 2 deletions docs/reference/core_concepts/tiling.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ The final output unembedding layer of a language model maps hidden states to log

Vocabulary tiling avoids materializing the full logits tensor. Instead, it tiles the input hidden states and computes the logits, loss, and gradients one tile at a time. Unlike GA, which is applied at the start of the model, vocabulary tiling is applied only to the input of the final layer.

In MaxText, the `num_vocab_tiling` configuration controls the number of tiles. This technique is especially advantageous for models with large vocabularies (e.g., Gemma and Llama), particularly when training with long sequence lengths. By preventing out-of-memory errors, vocabulary tiling can enable simpler sharding strategies (like FSDP) and unlock better computational performance.
In MaxText, the `num_batch_seq_tiling` configuration controls the number of tiles along the batch and sequence axes. This technique is especially advantageous for models with large vocabularies (e.g., Gemma and Llama), particularly when training with long sequence lengths. The vocabulary dimension can also be tiled via the `num_vocab_tiling` configuration. By preventing out-of-memory errors, this tiling can enable simpler sharding strategies (like FSDP) and unlock better computational performance.

![Illustration of vocabulary tiling.](../../_static/vocab_tiling.png)
![Illustration of batch_sequence tiling.](../../_static/vocab_tiling.png)
*Figure 2: Batch-sequence tiling processes hidden states in tiles to avoid generating the full logits tensor.*

### Other Tiling Methods
Expand Down
3 changes: 2 additions & 1 deletion docs/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ MaxText is [available in PyPI](https://pypi.org/project/maxtext/) and can be ins
- [Optimized models tiering documentation](https://maxtext.readthedocs.io/en/latest/reference/models/tiering.html) has been refreshed.
- Added Versioning. Check out our [first set of release notes](https://maxtext.readthedocs.io/en/latest/release_notes.html)!
- Post-Training (SFT, RL) via [Tunix](https://github.com/google/tunix) is now available.
- Vocabulary tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_vocab_tiling` to unlock more efficient memory usage.
- Batch-Sequence tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_batch_seq_tiling` to unlock more efficient memory usage.
- Additionally, the vocabulary dimension can also be tiled by adjusting `num_vocab_tiling`.
- The GPT-OSS family of models (20B, 120B) is now supported.

# Deprecations
Expand Down
8 changes: 5 additions & 3 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -586,10 +586,12 @@ num_slices: -1

# Tiling Configs (vocabulary and batch-sequence)
# Enables a memory-saving optimization by computing the cross-entropy loss in chunks.
# The logits are tiled into `num_vocab_tiling` parts along the batch-sequence axis,
# reducing peak memory usage. This is highly recommended for models with large
# vocabularies (e.g., Gemma). Set to a value greater than 1 to enable.
# The logits are tiled into `num_vocab_tiling` parts along the vocabulary axis,
# and `num_batch_seq_tiling` parts along the batch-sequence axis, reducing peak memory usage.
# This is highly recommended for models with large vocabularies (e.g., Gemma).
# Set to a value greater than 1 to enable.
num_vocab_tiling: 1
num_batch_seq_tiling: 1

# Tokenizer
vocab_size: 32_000 # powers of 2 for sharding
Expand Down
41 changes: 37 additions & 4 deletions src/maxtext/configs/pyconfig_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,34 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) -
)


def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool):
if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0:
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
def validate_batch_seq_tiling(
    num_batch_seq_tiling: int,
    per_device_batch_size: int,
    max_target_length: int,
    enable_nnx: bool,
):
  """Validates the batch-sequence tiling configuration.

  Args:
    num_batch_seq_tiling: Number of tiles along the flattened batch-sequence
      axis. Must be a positive integer; 1 disables tiling.
    per_device_batch_size: Per-device batch size.
    max_target_length: Maximum target sequence length.
    enable_nnx: Whether the NNX module path is enabled.

  Raises:
    ValueError: If the tiling configuration is invalid.
  """
  # Guard first: a non-positive tile count would otherwise surface as an
  # opaque ZeroDivisionError from the modulo check below.
  if num_batch_seq_tiling < 1:
    raise ValueError("num_batch_seq_tiling must be a positive integer (>= 1).")
  # Each tile must hold an equal share of the flattened batch*sequence tokens.
  if (per_device_batch_size * max_target_length) % num_batch_seq_tiling != 0:
    raise ValueError(
        "Per device batch size times sequence length should be divisible by the"
        " number of batch seq tiles."
    )
  if (
      num_batch_seq_tiling > 1 and enable_nnx
  ):  # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration
    raise ValueError(
        "We currently don't support batch seq tiling on NNX module."
    )


def validate_vocab_tiling(
    num_vocab_tiling: int,
    vocab_size: int,
    enable_nnx: bool,
):
  """Validates the vocabulary tiling configuration.

  Args:
    num_vocab_tiling: Number of tiles along the vocabulary axis. Must be a
      positive integer; 1 disables tiling.
    vocab_size: Size of the model vocabulary.
    enable_nnx: Whether the NNX module path is enabled.

  Raises:
    ValueError: If the tiling configuration is invalid.
  """
  # Guard first: a non-positive tile count would otherwise surface as an
  # opaque ZeroDivisionError from the modulo check below.
  if num_vocab_tiling < 1:
    raise ValueError("num_vocab_tiling must be a positive integer (>= 1).")
  # Each tile must hold an equal share of the vocabulary.
  if vocab_size % num_vocab_tiling != 0:
    raise ValueError(
        "vocab_size should be divisible by the number of vocab tiles."
    )
  if num_vocab_tiling > 1 and enable_nnx:  # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration
    raise ValueError("We currently don't support vocab tiling on NNX module.")

Expand Down Expand Up @@ -240,8 +265,16 @@ def validate_keys(keys):
validate_model_call_mode(keys["model_call_mode"])
validate_prefill_and_target_lengths(keys["max_prefill_predict_length"], keys["max_target_length"])
validate_rope_type(keys["rope_type"])
validate_batch_seq_tiling(
keys["num_batch_seq_tiling"],
keys["per_device_batch_size"],
keys["max_target_length"],
keys["enable_nnx"],
)
validate_vocab_tiling(
keys["num_vocab_tiling"], keys["per_device_batch_size"], keys["max_target_length"], keys["enable_nnx"]
keys["num_vocab_tiling"],
keys["vocab_size"],
keys["enable_nnx"],
)
if keys["enable_rampup_batch_size"]:
validate_rampup_batch_size(
Expand Down
29 changes: 25 additions & 4 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,17 @@ class Tokenizer(BaseModel):
)
num_vocab_tiling: int = Field(
1,
description="Enables memory-saving optimization by tiling cross-entropy loss computation. >1 to enable.",
description=(
"Enables memory-saving optimization by tiling cross-entropy loss"
" computation along the vocabulary axis. >1 to enable."
),
)
num_batch_seq_tiling: int = Field(
1,
description=(
"Enables memory-saving optimization by tiling cross-entropy loss"
" computation along the batch-sequence axis. >1 to enable."
),
)


Expand Down Expand Up @@ -2503,12 +2513,23 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
)
if self.quantization:
raise ValueError("Quantization is not supported with 'explicit' sharding.")
if self.vocab_size % self.num_vocab_tiling != 0:
raise ValueError(
"vocab_size should be divisible by the number of vocab tiles."
)
if (
self.per_device_batch_size > 0
and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0
and (self.per_device_batch_size * self.max_target_length)
% self.num_batch_seq_tiling
!= 0
):
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
if self.num_vocab_tiling > 1 and self.enable_nnx:
raise ValueError(
"Per device batch size times sequence length should be divisible by"
" the number of batch tiles."
)
if (
self.num_vocab_tiling > 1 or self.num_batch_seq_tiling > 1
) and self.enable_nnx:
raise ValueError("We currently don't support vocab tiling on NNX module.")
if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring":
if "gpu" not in self.hardware:
Expand Down
23 changes: 18 additions & 5 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,9 +678,8 @@ def _apply_embedding(
return y

@nn.compact
def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, deterministic, model_mode):
"""Applies final normalization and projects hidden states to logits."""

def normalize_hidden_states(self, y, deterministic, model_mode):
"""Applies final normalization and dropout to hidden states."""
cfg = self.config
if cfg.shard_mode == ShardMode.EXPLICIT:
norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed"))
Expand All @@ -696,6 +695,20 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi
parameter_memory_host_offload=cfg.parameter_memory_host_offload,
)(y, out_sharding=norm_out_sharding)
y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
return y

@nn.compact
def apply_output_head(
self,
shared_embedding: nn.Module | nnx.Module,
y,
deterministic,
model_mode,
):
"""Applies final normalization and projects hidden states to logits."""

cfg = self.config
y = self.normalize_hidden_states(y, deterministic, model_mode)

if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab"))
Expand Down Expand Up @@ -1085,9 +1098,9 @@ def __call__(
# for efficiency, as the main model is frozen and the LM loss is not needed.
elif (cfg.use_indexer and not cfg.indexer_sparse_training) and self.model_mode == MODEL_MODE_TRAIN:
logits = None
# When vocab tiling is enabled in training mode, full logits won't generate to reduce memory
# When batch-sequence tiling is enabled in training mode, full logits won't generate to reduce memory
# Instead, we keep track on the hidden states, which has smaller size compared to full logits
elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
elif cfg.num_batch_seq_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
logits = None
self.sow("intermediates", "hidden_states", hidden_state)

Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/layers/nnx_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,9 +1057,9 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in):
if cfg.attention == "vllm_rpa":
logits = None

# When vocab tiling is enabled in training mode, full logits won't generate to reduce memory
# When batch-sequence tiling is enabled in training mode, full logits won't generate to reduce memory
# Instead, we keep track on the hidden states, which has smaller size compared to full logits
if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
if cfg.num_batch_seq_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
logits = None
self.sow(nnx.Intermediate, "hidden_states", hidden_state)

Expand Down
12 changes: 10 additions & 2 deletions src/maxtext/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ def logits_from_hidden_states(self, hidden_states, deterministic, model_mode):
)
return logits

  def normalize_hidden_states(self, hidden_states, deterministic, model_mode):
    """Applies final-layer normalization (and dropout) to decoder hidden states.

    Thin delegating wrapper around `self.decoder.normalize_hidden_states`,
    exposed at the model level so callers (e.g. tiled loss computation) can
    normalize hidden states without materializing full logits.

    Args:
      hidden_states: Decoder output hidden states to normalize.
      deterministic: If True, disables dropout inside the decoder norm step.
      model_mode: Execution mode (e.g. train / prefill / autoregressive),
        forwarded unchanged to the decoder.

    Returns:
      The normalized hidden states, as produced by the decoder.
    """
    return self.decoder.normalize_hidden_states(
        y=hidden_states,
        deterministic=deterministic,
        model_mode=model_mode,
    )

def __call__(
self,
decoder_input_tokens: jnp.ndarray,
Expand Down Expand Up @@ -531,8 +539,8 @@ def __call__(
mutable=mutable_collections,
) # pytype: disable=wrong-keyword-args

# Materialize hidden state when vocab tiling is enabled
if self.config.num_vocab_tiling > 1:
# Materialize hidden state when batch-sequence tiling is enabled.
if self.config.num_batch_seq_tiling > 1:
self.hidden_states = hidden_state

# If we are initializing the model AND MTP is enabled, we must create
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
# The main model parameters are frozen and only the indexer is trained via KL divergence.
total_loss = 0.0
total_z_loss = 0.0
elif config.num_vocab_tiling > 1:
elif config.num_batch_seq_tiling > 1:
hidden_state_key = ("intermediates", "decoder", "hidden_states")
hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0]
total_loss, total_z_loss = vocab_tiling_linen_loss(hidden_states, data, config, model, params, is_train)
Expand Down
Loading
Loading