diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index 4950c8f57b..35e9fbc2db 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -1721,7 +1721,7 @@ model_type="gemma3-12b", tuning_params={ "per_device_batch_size": 1, - "num_vocab_tiling": 16, + "num_batch_seq_tiling": 16, "ici_fsdp_parallelism": -1, "remat_policy": "custom", "decoder_layer_input": "device", @@ -1739,7 +1739,9 @@ "profiler": "xplane", "skip_first_n_steps_for_profiler": 10, "profiler_steps": 2, - "tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"), + "tokenizer_path": os.path.join( + "assets", "tokenizers", "tokenizer.gemma3" + ), "sa_block_q": 1024, "sa_block_kv": 1024, "sa_block_kv_compute": 1024, @@ -1760,7 +1762,7 @@ model_type="gemma3-12b", tuning_params={ "per_device_batch_size": 1, - "num_vocab_tiling": 16, + "num_batch_seq_tiling": 16, "ici_fsdp_parallelism": 1, "ici_fsdp_transpose_parallelism": -1, "remat_policy": "custom", @@ -1779,7 +1781,9 @@ "profiler": "xplane", "skip_first_n_steps_for_profiler": 10, "profiler_steps": 2, - "tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"), + "tokenizer_path": os.path.join( + "assets", "tokenizers", "tokenizer.gemma3" + ), "sa_block_q": 1024, "sa_block_kv": 1024, "sa_block_kv_compute": 1024, @@ -1800,7 +1804,7 @@ model_type="gemma3-12b", tuning_params={ "per_device_batch_size": 1, - "num_vocab_tiling": 16, + "num_batch_seq_tiling": 16, "ici_fsdp_parallelism": 1, "ici_fsdp_transpose_parallelism": -1, "remat_policy": "custom", @@ -1819,7 +1823,9 @@ "profiler": "xplane", "skip_first_n_steps_for_profiler": 10, "profiler_steps": 2, - "tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"), + "tokenizer_path": os.path.join( + "assets", "tokenizers", "tokenizer.gemma3" + ), "sa_block_q": 1024, "sa_block_kv": 1024, "sa_block_kv_compute": 1024, diff --git 
a/benchmarks/maxtext_v5p_model_configs.py b/benchmarks/maxtext_v5p_model_configs.py index f228b0f7fc..3df783022b 100644 --- a/benchmarks/maxtext_v5p_model_configs.py +++ b/benchmarks/maxtext_v5p_model_configs.py @@ -38,7 +38,7 @@ "remat_policy": "custom", "context": "offload", "mlpwo": "offload", - "num_vocab_tiling": 4, + "num_batch_seq_tiling": 4, "sa_block_q": 2048, "sa_block_kv": 2048, "sa_block_kv_compute": 2048, diff --git a/docs/reference/core_concepts/tiling.md b/docs/reference/core_concepts/tiling.md index 5d203e31aa..ef1efbb298 100644 --- a/docs/reference/core_concepts/tiling.md +++ b/docs/reference/core_concepts/tiling.md @@ -67,9 +67,9 @@ The final output unembedding layer of a language model maps hidden states to log Vocabulary tiling avoids materializing the full logits tensor. Instead, it tiles the input hidden states and computes the logits, loss, and gradients one tile at a time. Unlike GA, which is applied at the start of the model, vocabulary tiling is applied only to the input of the final layer. -In MaxText, the `num_vocab_tiling` configuration controls the number of tiles. This technique is especially advantageous for models with large vocabularies (e.g., Gemma and Llama), particularly when training with long sequence lengths. By preventing out-of-memory errors, vocabulary tiling can enable simpler sharding strategies (like FSDP) and unlock better computational performance. +In MaxText, the `num_batch_seq_tiling` configuration controls the number of tiles along the batch and sequence axes. This technique is especially advantageous for models with large vocabularies (e.g., Gemma and Llama), particularly when training with long sequence lengths. One may also tile the vocabulary dimension using the `num_vocab_tiling` configuration. By preventing out-of-memory errors, vocabulary tiling can enable simpler sharding strategies (like FSDP) and unlock better computational performance. 
-![Illustration of vocabulary tiling.](../../_static/vocab_tiling.png) +![Illustration of batch_sequence tiling.](../../_static/vocab_tiling.png) *Figure 2: Vocabulary tiling processes hidden states in tiles to avoid generating the full logits tensor.* ### Other Tiling Methods diff --git a/docs/release_notes.md b/docs/release_notes.md index 7192da4d77..3835f67033 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -47,7 +47,8 @@ MaxText is [available in PyPI](https://pypi.org/project/maxtext/) and can be ins - [Optimized models tiering documentation](https://maxtext.readthedocs.io/en/latest/reference/models/tiering.html) has been refreshed. - Added Versioning. Check out our [first set of release notes](https://maxtext.readthedocs.io/en/latest/release_notes.html)! - Post-Training (SFT, RL) via [Tunix](https://github.com/google/tunix) is now available. -- Vocabulary tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_vocab_tiling` to unlock more efficient memory usage. +- Batch-Sequence tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_batch_seq_tiling` to unlock more efficient memory usage. +- Vocabulary tiling: additionally, the vocabulary dimension can also be tiled by adjusting `num_vocab_tiling`. - The GPT-OSS family of models (20B, 120B) is now supported. # Deprecations diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 3ff1c33153..83adff99f9 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -586,10 +586,12 @@ num_slices: -1 # Vocab Tiling Configs # Enables a memory-saving optimization by computing the cross-entropy loss in chunks. -# The logits are tiled into `num_vocab_tiling` parts along the batch-sequence axis, -# reducing peak memory usage. This is highly recommended for models with large -# vocabularies (e.g., Gemma). Set to a value greater than 1 to enable. 
+# The logits are tiled into `num_vocab_tiling` parts along the vocabulary axis, +# and `num_batch_seq_tiling` parts along the batch-sequence axis, reducing peak memory usage. +# This is highly recommended for models with large vocabularies (e.g., Gemma). +# Set to a value greater than 1 to enable. num_vocab_tiling: 1 +num_batch_seq_tiling: 1 # Tokenizer vocab_size: 32_000 # powers of 2 for sharding diff --git a/src/maxtext/configs/pyconfig_deprecated.py b/src/maxtext/configs/pyconfig_deprecated.py index 888a23b199..ffc401c2df 100644 --- a/src/maxtext/configs/pyconfig_deprecated.py +++ b/src/maxtext/configs/pyconfig_deprecated.py @@ -194,9 +194,34 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) - ) -def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool): - if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0: - raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") +def validate_batch_seq_tiling( + num_batch_seq_tiling: int, + per_device_batch_size: int, + max_target_length: int, + enable_nnx: bool, +): + if (per_device_batch_size * max_target_length) % num_batch_seq_tiling != 0: + raise ValueError( + "Per device batch size times sequence length should be divisible by the" + " number of batch seq tiles." + ) + if ( + num_batch_seq_tiling > 1 and enable_nnx + ): # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration + raise ValueError( + "We currently don't support batch seq tiling on NNX module." + ) + + +def validate_vocab_tiling( + num_vocab_tiling: int, + vocab_size: int, + enable_nnx: bool, +): + if vocab_size % num_vocab_tiling != 0: + raise ValueError( + "vocab_size should be divisible by the number of vocab tiles." 
+ ) if num_vocab_tiling > 1 and enable_nnx: # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration raise ValueError("We currently don't support vocab tiling on NNX module.") @@ -240,8 +265,16 @@ def validate_keys(keys): validate_model_call_mode(keys["model_call_mode"]) validate_prefill_and_target_lengths(keys["max_prefill_predict_length"], keys["max_target_length"]) validate_rope_type(keys["rope_type"]) + validate_batch_seq_tiling( + keys["num_batch_seq_tiling"], + keys["per_device_batch_size"], + keys["max_target_length"], + keys["enable_nnx"], + ) validate_vocab_tiling( - keys["num_vocab_tiling"], keys["per_device_batch_size"], keys["max_target_length"], keys["enable_nnx"] + keys["num_vocab_tiling"], + keys["vocab_size"], + keys["enable_nnx"], ) if keys["enable_rampup_batch_size"]: validate_rampup_batch_size( diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 5c97ac2c1e..59c3199e02 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -983,7 +983,17 @@ class Tokenizer(BaseModel): ) num_vocab_tiling: int = Field( 1, - description="Enables memory-saving optimization by tiling cross-entropy loss computation. >1 to enable.", + description=( + "Enables memory-saving optimization by tiling cross-entropy loss" + " computation along the vocabulary axis. >1 to enable." + ), + ) + num_batch_seq_tiling: int = Field( + 1, + description=( + "Enables memory-saving optimization by tiling cross-entropy loss" + " computation along the batch-sequence axis. >1 to enable." + ), ) @@ -2503,12 +2513,23 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de ) if self.quantization: raise ValueError("Quantization is not supported with 'explicit' sharding.") + if self.vocab_size % self.num_vocab_tiling != 0: + raise ValueError( + "vocab_size should be divisible by the number of vocab tiles." 
+ ) if ( self.per_device_batch_size > 0 - and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0 + and (self.per_device_batch_size * self.max_target_length) + % self.num_batch_seq_tiling + != 0 ): - raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if self.num_vocab_tiling > 1 and self.enable_nnx: + raise ValueError( + "Per device batch size times sequence length should be divisible by" + " the number of batch tiles." + ) + if ( + self.num_vocab_tiling > 1 or self.num_batch_seq_tiling > 1 + ) and self.enable_nnx: raise ValueError("We currently don't support vocab tiling on NNX module.") if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring": if "gpu" not in self.hardware: diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 0ab392ecc2..0b30d64217 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -678,9 +678,8 @@ def _apply_embedding( return y @nn.compact - def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, deterministic, model_mode): - """Applies final normalization and projects hidden states to logits.""" - + def normalize_hidden_states(self, y, deterministic, model_mode): + """Applies final normalization and dropout to hidden states.""" cfg = self.config if cfg.shard_mode == ShardMode.EXPLICIT: norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) @@ -696,6 +695,20 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi parameter_memory_host_offload=cfg.parameter_memory_host_offload, )(y, out_sharding=norm_out_sharding) y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) + return y + + @nn.compact + def apply_output_head( + self, + shared_embedding: nn.Module | nnx.Module, + y, + deterministic, + model_mode, + ): + 
"""Applies final normalization and projects hidden states to logits.""" + + cfg = self.config + y = self.normalize_hidden_states(y, deterministic, model_mode) if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) @@ -1085,9 +1098,9 @@ def __call__( # for efficiency, as the main model is frozen and the LM loss is not needed. elif (cfg.use_indexer and not cfg.indexer_sparse_training) and self.model_mode == MODEL_MODE_TRAIN: logits = None - # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory + # When batch-sequence tiling is enabled in training mode, full logits won't generate to reduce memory # Instead, we keep track on the hidden states, which has smaller size compared to full logits - elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + elif cfg.num_batch_seq_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: logits = None self.sow("intermediates", "hidden_states", hidden_state) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index c96ec08c8d..98658e83d0 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -1057,9 +1057,9 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): if cfg.attention == "vllm_rpa": logits = None - # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory + # When batch-sequence tiling is enabled in training mode, full logits won't generate to reduce memory # Instead, we keep track on the hidden states, which has smaller size compared to full logits - if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + if cfg.num_batch_seq_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: logits = None self.sow(nnx.Intermediate, "hidden_states", hidden_state) diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 0d1fcab700..1be303e737 100644 --- 
a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -118,6 +118,14 @@ def logits_from_hidden_states(self, hidden_states, deterministic, model_mode): ) return logits + def normalize_hidden_states(self, hidden_states, deterministic, model_mode): + """Normalize hidden states (wrapping decoder.normalize_hidden_states).""" + return self.decoder.normalize_hidden_states( + y=hidden_states, + deterministic=deterministic, + model_mode=model_mode, + ) + def __call__( self, decoder_input_tokens: jnp.ndarray, @@ -531,8 +539,8 @@ def __call__( mutable=mutable_collections, ) # pytype: disable=wrong-keyword-args - # Materialize hidden state when vocab tiling is enabled - if self.config.num_vocab_tiling > 1: + # Materialize hidden state when batch-sequence tiling is enabled. + if self.config.num_batch_seq_tiling > 1: self.hidden_states = hidden_state # If we are initializing the model AND MTP is enabled, we must create diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index fc973990ec..2abac19d52 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -141,7 +141,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): # The main model parameters are frozen and only the indexer is trained via KL divergence. 
total_loss = 0.0 total_z_loss = 0.0 - elif config.num_vocab_tiling > 1: + elif config.num_batch_seq_tiling > 1: hidden_state_key = ("intermediates", "decoder", "hidden_states") hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0] total_loss, total_z_loss = vocab_tiling_linen_loss(hidden_states, data, config, model, params, is_train) diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py index ec68e9bc78..fe03136de9 100644 --- a/src/maxtext/utils/vocabulary_tiling.py +++ b/src/maxtext/utils/vocabulary_tiling.py @@ -108,22 +108,41 @@ def _reshape(inputs, out_shape, out_sharding): # Customized forward and backward maps for the embedding tiling @jax.custom_vjp - def chunked_cross_entropy_loss(gathered_params, hidden_states, labels, segmentation): - """ - Calculates the total cross-entropy loss using vocab tiling. - """ - (total_loss, total_z_loss), _ = _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation) + def chunked_cross_entropy_loss( + gathered_params, hidden_states, labels, segmentation + ): + """Calculates the total cross-entropy loss using vocab tiling.""" + # if both batch-sequence tiling and vocab tiling are enabled, call + # _b_v_chunked_cross_entropy_loss_fwd + if config.num_vocab_tiling > 1: + (total_loss, total_z_loss), _ = _b_v_chunked_cross_entropy_loss_fwd( + gathered_params, hidden_states, labels, segmentation + ) + else: + (total_loss, total_z_loss), _ = _chunked_cross_entropy_loss_fwd( + gathered_params, hidden_states, labels, segmentation + ) return total_loss, total_z_loss def _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation): batch_size, seq_len, emb_dim = hidden_states.shape - vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling + batch_seq_tile_size = (batch_size * seq_len) // config.num_batch_seq_tiling reshaped_hidden_states = _reshape( - hidden_states, (config.num_vocab_tiling, 
vocab_tile_size, emb_dim), reshaped_hidden_spec + hidden_states, + (config.num_batch_seq_tiling, batch_seq_tile_size, emb_dim), + reshaped_hidden_spec, + ) + reshaped_labels = _reshape( + labels, + (config.num_batch_seq_tiling, batch_seq_tile_size), + reshaped_data_spec, + ) + reshaped_segmentation = _reshape( + segmentation, + (config.num_batch_seq_tiling, batch_seq_tile_size), + reshaped_data_spec, ) - reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) - reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) # Scan body accumulates loss from each tile given chunked hidden states and labels def _fwd_scan_body(accumulators, chunk_data): @@ -169,6 +188,171 @@ def _fwd_scan_body(accumulators, chunk_data): return (total_loss, total_z_loss), residuals + # Chunked cross entropy loss forward pass, chunk along batch-sequence and + # vocab dimensions. + def _b_v_chunked_cross_entropy_loss_fwd( + gathered_params, hidden_states, labels, segmentation + ): + batch_size, seq_len, emb_dim = hidden_states.shape + v_dim = config.vocab_size + + b_dim = batch_size * seq_len + b_block_sz = b_dim // config.num_batch_seq_tiling + v_block_sz = v_dim // config.num_vocab_tiling + + if b_dim % b_block_sz != 0 or v_dim % v_block_sz != 0: + raise ValueError( + "Batch/sequence dimension and vocab dimension must be divisible by" + " their block sizes." 
+ ) + + num_b_blocks = b_dim // b_block_sz + num_v_blocks = v_dim // v_block_sz + + flat_hidden = _reshape( + hidden_states, + (b_dim, emb_dim), + create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence", "activation_embed"), + ), + ) + flat_labels = _reshape( + labels, + (b_dim,), + create_sharding( + model.mesh, ("activation_embed_and_logits_batch_sequence",) + ), + ) + flat_segmentation = _reshape( + segmentation, + (b_dim,), + create_sharding( + model.mesh, ("activation_embed_and_logits_batch_sequence",) + ), + ) + + if config.logits_via_embedding: + w = gathered_params["params"]["shared_embedding"]["embedding"] + else: + w = gathered_params["params"]["decoder"]["logits_dense"]["kernel"] + + if hasattr(w, "unbox"): + w = w.unbox() + elif hasattr(w, "value"): + w = w.value + + def b_loop_body(i, carry): + total_loss, total_z_loss = carry + b_start = i * b_block_sz + + def v_loop_body(j, v_carry): + lse_b_, b_loss_sum_neg_logits_ = v_carry + v_start = j * v_block_sz + labels_b = jax.lax.dynamic_slice(flat_labels, (b_start,), (b_block_sz,)) + x_b = jax.lax.dynamic_slice( + flat_hidden, (b_start, 0), (b_block_sz, emb_dim) + ) + + # Apply normalization to the batch block + x_b_norm = model.apply( + {"params": gathered_params["params"]}, + x_b, + deterministic=deterministic, + method="normalize_hidden_states", + ) + x_b_norm = _maybe_shard_with_name(x_b_norm, chunked_hidden_spec) + + # Extract w_j + if config.logits_via_embedding: + # Attend on embedding table. 
Table is (vocab_size, emb_dim) + # Transpose to (emb_dim, vocab_size) + w_j = jax.lax.dynamic_slice(w.T, (0, v_start), (emb_dim, v_block_sz)) + else: + w_j = jax.lax.dynamic_slice(w, (0, v_start), (emb_dim, v_block_sz)) + + # Compute logits for the block + logits_bv = jnp.dot(x_b_norm, w_j) + + if config.logits_via_embedding and config.normalize_embedding_logits: + logits_bv = logits_bv / jnp.sqrt(emb_dim) + if config.final_logits_soft_cap: + logits_bv = logits_bv / config.final_logits_soft_cap + logits_bv = jnp.tanh(logits_bv) * config.final_logits_soft_cap + + if config.cast_logits_to_fp32: + logits_bv = logits_bv.astype(jnp.float32) + + lse_b__ = jnp.logaddexp(lse_b_, jax.nn.logsumexp(logits_bv, axis=-1)) + + labels_one_hot = jax.nn.one_hot( + labels_b - v_start, v_block_sz, dtype=logits_bv.dtype + ) + b_loss_sum_neg_logits__ = b_loss_sum_neg_logits_ - jnp.sum( + logits_bv * labels_one_hot, axis=-1 + ) + return lse_b__, b_loss_sum_neg_logits__ + + lse_b, b_loss_sum_neg_logits = jax.lax.fori_loop( + 0, + num_v_blocks, + v_loop_body, + ( + jnp.full((b_block_sz,), -jnp.inf, dtype=jnp.float32), + jnp.zeros((b_block_sz,), dtype=jnp.float32), + ), + ) + + segmentation_b = jax.lax.dynamic_slice( + flat_segmentation, (b_start,), (b_block_sz,) + ) + mask = (segmentation_b != 0).astype(jnp.float32) + + # Z-loss + z_loss_b = config.z_loss_multiplier * jnp.square(lse_b) * mask + total_z_loss += jnp.sum(z_loss_b) + + b_loss_sum_neg_logits = b_loss_sum_neg_logits * mask + lse_b_masked = lse_b * mask + + total_loss += jnp.sum(b_loss_sum_neg_logits) + jnp.sum(lse_b_masked) + + return total_loss, total_z_loss + + initial_acc = (0.0, 0.0) + total_loss, total_z_loss = jax.lax.fori_loop( + 0, + num_b_blocks, + b_loop_body, + initial_acc, + ) + + # Reshape the flattened 2D tensors `(b_dim, ...)` into 3D chunked tensors + # `(num_b_blocks, b_block_sz, ...)` so we can process them sequentially + # over the batch dimension using `jax.lax.scan` in the backward pass. 
+ # TODO(b/486111493): When we replace the bwd pass, perhaps we can think + # about what to do with these reshape operations. + reshaped_hidden_states = _reshape( + flat_hidden, (num_b_blocks, b_block_sz, emb_dim), reshaped_hidden_spec + ) + reshaped_labels = _reshape( + flat_labels, (num_b_blocks, b_block_sz), reshaped_data_spec + ) + reshaped_segmentation = _reshape( + flat_segmentation, (num_b_blocks, b_block_sz), reshaped_data_spec + ) + + residuals = ( + gathered_params, + reshaped_hidden_states, + reshaped_labels, + reshaped_segmentation, + batch_size, + seq_len, + emb_dim, + ) + return (total_loss, total_z_loss), residuals + def _chunked_cross_entropy_loss_bwd(residuals, cotangents): # Unpack the cotangents tuple. We ignore the z_loss cotangent since the gradients # of the z_loss term are already factored into the loss_cotangent. @@ -237,7 +421,14 @@ def _bwd_scan_body(grad_params_acc, chunk_data): None, # grad for reshaped_segmentation ) - chunked_cross_entropy_loss.defvjp(_chunked_cross_entropy_loss_fwd, _chunked_cross_entropy_loss_bwd) + if config.num_vocab_tiling > 1: + chunked_cross_entropy_loss.defvjp( + _b_v_chunked_cross_entropy_loss_fwd, _chunked_cross_entropy_loss_bwd + ) + else: + chunked_cross_entropy_loss.defvjp( + _chunked_cross_entropy_loss_fwd, _chunked_cross_entropy_loss_bwd + ) total_loss, total_z_loss = chunked_cross_entropy_loss( gathered_params, diff --git a/tests/unit/tiling_test.py b/tests/unit/tiling_test.py index 58b688634d..9e36f48138 100644 --- a/tests/unit/tiling_test.py +++ b/tests/unit/tiling_test.py @@ -42,7 +42,7 @@ def compute_loss_linen(intermediate_outputs, logits, data, config, model, params """ A loss function wrapper that deals with both vocab tiling or non-vocab tiling cases """ - if config.num_vocab_tiling > 1: + if config.num_batch_seq_tiling > 1: hidden_state_key = ("intermediates", "decoder", "hidden_states") hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0] total_loss, _ = 
vocab_tiling_linen_loss(hidden_states, data, config, model, params, is_train) @@ -206,7 +206,7 @@ def test_vocab_tiling_gradient_with_z_loss(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=1, + num_batch_seq_tiling=1, z_loss_multiplier=1e-4, # Enable z-loss ) quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) @@ -242,7 +242,7 @@ def test_vocab_tiling_gradient_with_z_loss(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=4, + num_batch_seq_tiling=4, z_loss_multiplier=1e-4, # Enable z-loss ) loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data) @@ -273,7 +273,7 @@ def test_vocab_tiling_gradient_non_tied_embedding(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=1, + num_batch_seq_tiling=1, ) quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) @@ -308,7 +308,7 @@ def test_vocab_tiling_gradient_non_tied_embedding(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=4, + num_batch_seq_tiling=4, ) loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data) # Loss correctness test @@ -337,7 +337,7 @@ def test_vocab_tiling_gradient_tied_embedding(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=1, + num_batch_seq_tiling=1, ) quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) @@ -372,7 +372,7 @@ def test_vocab_tiling_gradient_tied_embedding(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=4, + num_batch_seq_tiling=4, ) loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data) @@ -399,7 +399,7 @@ def test_vocab_tiling_gradient_data_parallelism(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - 
num_vocab_tiling=1, + num_batch_seq_tiling=1, ) quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) @@ -435,7 +435,7 @@ def test_vocab_tiling_gradient_data_parallelism(self): dtype="float32", matmul_precision="high", ici_data_parallelism=4, - num_vocab_tiling=4, + num_batch_seq_tiling=4, ) loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data) # Loss correctness test @@ -463,7 +463,7 @@ def test_vocab_tiling_gradient_tensor_parallelism(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=1, + num_batch_seq_tiling=1, ) quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) @@ -499,7 +499,7 @@ def test_vocab_tiling_gradient_tensor_parallelism(self): dtype="float32", matmul_precision="high", ici_tensor_parallelism=4, - num_vocab_tiling=4, + num_batch_seq_tiling=4, ) loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data) # Loss correctness test @@ -529,7 +529,7 @@ def test_vocab_tiling_gradient_context_parallelism(self): packing=False, dtype="float32", matmul_precision="high", - num_vocab_tiling=1, + num_batch_seq_tiling=1, ) quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) @@ -567,7 +567,7 @@ def test_vocab_tiling_gradient_context_parallelism(self): packing=False, dtype="float32", matmul_precision="high", - num_vocab_tiling=4, + num_batch_seq_tiling=4, ) loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data)