Commit f9a08e4
Integrating b/v blocking for linear softmax cross entropy loss in maxtext
PiperOrigin-RevId: 886387786
1 parent e1a2ba7 commit f9a08e4

6 files changed

Lines changed: 209 additions & 8 deletions

src/maxtext/configs/base.yml

Lines changed: 5 additions & 3 deletions
@@ -567,10 +567,12 @@ num_slices: -1
 
 # Vocab Tiling Configs
 # Enables a memory-saving optimization by computing the cross-entropy loss in chunks.
-# The logits are tiled into `num_vocab_tiling` parts along the batch-sequence axis,
-# reducing peak memory usage. This is highly recommended for models with large
-# vocabularies (e.g., Gemma). Set to a value greater than 1 to enable.
+# The logits are tiled into `num_vocab_tiling` parts along the vocabulary axis,
+# and `num_of_batch_tiling` parts along the batch-sequence axis, reducing peak memory usage.
+# This is highly recommended for models with large vocabularies (e.g., Gemma).
+# Set to a value greater than 1 to enable.
 num_vocab_tiling: 1
+num_of_batch_tiling: 1
 
 # Tokenizer
 vocab_size: 32_000 # powers of 2 for sharding
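
For a sense of the arithmetic these two knobs imply, here is a minimal sketch (plain Python, not part of the commit; the values are hypothetical and simply mirror the block sizing used by the new forward pass below):

# Hypothetical illustration of b/v block sizing from the two tiling knobs.
per_device_batch_size = 4
max_target_length = 2048
vocab_size = 32_000
num_vocab_tiling = 8       # tiles along the vocabulary axis
num_of_batch_tiling = 4    # tiles along the flattened batch-sequence axis

b_dim = per_device_batch_size * max_target_length   # 8192 token positions
b_block_sz = b_dim // num_of_batch_tiling            # 2048 positions per batch block
v_block_sz = vocab_size // num_vocab_tiling          # 4000 vocab entries per vocab block

# Both tilings must divide evenly, matching the validation added in this commit.
assert b_dim % num_of_batch_tiling == 0 and b_dim % num_vocab_tiling == 0

# Peak logits footprint per step shrinks from roughly b_dim * vocab_size entries
# to roughly b_block_sz * v_block_sz entries.
print(b_block_sz, v_block_sz)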

src/maxtext/configs/pyconfig_deprecated.py

Lines changed: 12 additions & 1 deletion
@@ -194,9 +194,20 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) -
     )
 
 
-def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool):
+def validate_vocab_tiling(
+    num_vocab_tiling: int,
+    num_of_batch_tiling: int,
+    per_device_batch_size: int,
+    max_target_length: int,
+    enable_nnx: bool,
+):
   if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0:
     raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
+  if (per_device_batch_size * max_target_length) % num_of_batch_tiling != 0:
+    raise ValueError(
+        "Per device batch size times sequence length should be divisible by the"
+        " number of batch tiles."
+    )
   if num_vocab_tiling > 1 and enable_nnx:  # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration
     raise ValueError("We currently don't support vocab tiling on NNX module.")
 
src/maxtext/configs/types.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,17 @@ class Tokenizer(BaseModel):
953953
)
954954
num_vocab_tiling: int = Field(
955955
1,
956-
description="Enables memory-saving optimization by tiling cross-entropy loss computation. >1 to enable.",
956+
description=(
957+
"Enables memory-saving optimization by tiling cross-entropy loss"
958+
" computation along the vocabulary axis. >1 to enable."
959+
),
960+
)
961+
num_of_batch_tiling: int = Field(
962+
1,
963+
description=(
964+
"Enables memory-saving optimization by tiling cross-entropy loss"
965+
" computation along the batch-sequence axis. >1 to enable."
966+
),
957967
)
958968

959969

@@ -2461,6 +2471,16 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24612471
and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0
24622472
):
24632473
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
2474+
if (
2475+
self.per_device_batch_size > 0
2476+
and (self.per_device_batch_size * self.max_target_length)
2477+
% self.num_of_batch_tiling
2478+
!= 0
2479+
):
2480+
raise ValueError(
2481+
"Per device batch size times sequence length should be divisible by"
2482+
" the number of batch tiles."
2483+
)
24642484
if self.num_vocab_tiling > 1 and self.enable_nnx:
24652485
raise ValueError("We currently don't support vocab tiling on NNX module.")
24662486
if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring":

src/maxtext/layers/decoders.py

Lines changed: 16 additions & 3 deletions
@@ -678,9 +678,8 @@ def _apply_embedding(
     return y
 
   @nn.compact
-  def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, deterministic, model_mode):
-    """Applies final normalization and projects hidden states to logits."""
-
+  def normalize_hidden_states(self, y, deterministic, model_mode):
+    """Applies final normalization and dropout to hidden states."""
     cfg = self.config
     if cfg.shard_mode == ShardMode.EXPLICIT:
       norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed"))
@@ -696,6 +695,20 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi
         parameter_memory_host_offload=cfg.parameter_memory_host_offload,
     )(y, out_sharding=norm_out_sharding)
     y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
+    return y
+
+  @nn.compact
+  def apply_output_head(
+      self,
+      shared_embedding: nn.Module | nnx.Module,
+      y,
+      deterministic,
+      model_mode,
+  ):
+    """Applies final normalization and projects hidden states to logits."""
+
+    cfg = self.config
+    y = self.normalize_hidden_states(y, deterministic, model_mode)
 
     if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
       out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab"))

src/maxtext/models/models.py

Lines changed: 8 additions & 0 deletions
@@ -118,6 +118,14 @@ def logits_from_hidden_states(self, hidden_states, deterministic, model_mode):
     )
     return logits
 
+  def normalize_hidden_states(self, hidden_states, deterministic, model_mode):
+    """Normalize hidden states (wrapping decoder.normalize_hidden_states)."""
+    return self.decoder.normalize_hidden_states(
+        y=hidden_states,
+        deterministic=deterministic,
+        model_mode=model_mode,
+    )
+
   def __call__(
       self,
       decoder_input_tokens: jnp.ndarray,
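
Together with the decoder refactor above, this wrapper is what lets the blocked loss run normalization on a single batch slice without building the full logits tensor. A minimal sketch of the calling pattern (it mirrors the model.apply(..., method="normalize_hidden_states") call added in vocabulary_tiling.py below; params, x_block, and w_slice are placeholder names, not identifiers from this commit):

# Hypothetical usage: normalize one batch block of hidden states, then project it
# against only a slice of the output weights.
x_block_norm = model.apply(
    {"params": params},          # gathered parameters
    x_block,                     # a (b_block_sz, emb_dim) slice of the hidden states
    deterministic=True,
    method="normalize_hidden_states",
)
logits_block = jnp.dot(x_block_norm, w_slice)  # w_slice: (emb_dim, v_block_sz)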

src/maxtext/utils/vocabulary_tiling.py

Lines changed: 147 additions & 0 deletions
@@ -115,6 +115,153 @@ def chunked_cross_entropy_loss(gathered_params, hidden_states, labels, segmentat
     (total_loss, total_z_loss), _ = _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation)
     return total_loss, total_z_loss
 
+  def _b_v_chunked_cross_entropy_loss_fwd(
+      gathered_params, hidden_states, labels, segmentation
+  ):
+    batch_size, seq_len, emb_dim = hidden_states.shape
+    v_dim = config.vocab_size
+
+    b_dim = batch_size * seq_len
+    b_block_sz = b_dim // config.num_of_batch_tiling
+    v_block_sz = v_dim // config.num_vocab_tiling
+
+    if b_dim % b_block_sz != 0 or v_dim % v_block_sz != 0:
+      raise ValueError(
+          "Batch/sequence dimension and vocab dimension must be divisible by"
+          " their block sizes."
+      )
+
+    num_b_blocks = b_dim // b_block_sz
+    num_v_blocks = v_dim // v_block_sz
+
+    flat_hidden = _reshape(
+        hidden_states,
+        (b_dim, emb_dim),
+        create_sharding(
+            model.mesh,
+            ("activation_embed_and_logits_batch_sequence", "activation_embed"),
+        ),
+    )
+    flat_labels = _reshape(
+        labels,
+        (b_dim,),
+        create_sharding(
+            model.mesh, ("activation_embed_and_logits_batch_sequence",)
+        ),
+    )
+    flat_segmentation = _reshape(
+        segmentation,
+        (b_dim,),
+        create_sharding(
+            model.mesh, ("activation_embed_and_logits_batch_sequence",)
+        ),
+    )
+
+    if config.logits_via_embedding:
+      w = gathered_params["params"]["shared_embedding"]["embedding"]
+    else:
+      w = gathered_params["params"]["decoder"]["logits_dense"]["kernel"]
+
+    def b_loop_body(i, carry):
+      total_loss, total_z_loss = carry
+      b_start = i * b_block_sz
+
+      def v_loop_body(j, v_carry):
+        lse_b_, b_loss_sum_neg_logits_ = v_carry
+        v_start = j * v_block_sz
+        labels_b = jax.lax.dynamic_slice(flat_labels, (b_start,), (b_block_sz,))
+        x_b = jax.lax.dynamic_slice(
+            flat_hidden, (b_start, 0), (b_block_sz, emb_dim)
+        )
+
+        # Apply normalization to the batch block
+        x_b_norm = model.apply(
+            {"params": gathered_params["params"]},
+            x_b,
+            deterministic=deterministic,
+            method="normalize_hidden_states",
+        )
+        x_b_norm = _maybe_shard_with_name(x_b_norm, chunked_hidden_spec)
+
+        # Extract w_j
+        if config.logits_via_embedding:
+          # Attend on embedding table. Table is (vocab_size, emb_dim)
+          # Transpose to (emb_dim, vocab_size)
+          w_j = jax.lax.dynamic_slice(w.T, (0, v_start), (emb_dim, v_block_sz))
+        else:
+          w_j = jax.lax.dynamic_slice(w, (0, v_start), (emb_dim, v_block_sz))
+
+        # Compute logits for the block
+        logits_bv = jnp.dot(x_b_norm, w_j)
+
+        if config.logits_via_embedding and config.normalize_embedding_logits:
+          logits_bv = logits_bv / jnp.sqrt(emb_dim)
+        if config.final_logits_soft_cap:
+          logits_bv = logits_bv / config.final_logits_soft_cap
+          logits_bv = jnp.tanh(logits_bv) * config.final_logits_soft_cap
+
+        if config.cast_logits_to_fp32:
+          logits_bv = logits_bv.astype(jnp.float32)
+
+        lse_b__ = jnp.logaddexp(lse_b_, jax.nn.logsumexp(logits_bv, axis=-1))
+
+        labels_one_hot = jax.nn.one_hot(
+            labels_b - v_start, v_block_sz, dtype=logits_bv.dtype
+        )
+        b_loss_sum_neg_logits__ = b_loss_sum_neg_logits_ - jnp.sum(
+            logits_bv * labels_one_hot, axis=-1
+        )
+        return lse_b__, b_loss_sum_neg_logits__
+
+      lse_b, b_loss_sum_neg_logits = jax.lax.fori_loop(
+          0,
+          num_v_blocks,
+          v_loop_body,
+          (
+              jnp.full((b_block_sz,), -jnp.inf, dtype=jnp.float32),
+              jnp.zeros((b_block_sz,), dtype=jnp.float32),
+          ),
+      )
+
+      segmentation_b = jax.lax.dynamic_slice(
+          flat_segmentation, (b_start,), (b_block_sz,)
+      )
+      mask = (segmentation_b != 0).astype(jnp.float32)
+
+      # Z-loss
+      z_loss_b = config.z_loss_multiplier * jnp.square(lse_b) * mask
+      total_z_loss += jnp.sum(z_loss_b)
+
+      b_loss_sum_neg_logits = b_loss_sum_neg_logits * mask
+      lse_b_masked = lse_b * mask
+
+      total_loss += jnp.sum(b_loss_sum_neg_logits) + jnp.sum(lse_b_masked)
+
+      return total_loss, total_z_loss
+
+    initial_acc = (0.0, 0.0)
+    total_loss, total_z_loss = jax.lax.fori_loop(
+        0,
+        num_b_blocks,
+        b_loop_body,
+        initial_acc,
+    )
+
+    # For drop-in replacement, we return residuals as the current method does.
+    # We pack necessary values for the backward pass.
+    # Note that the backward pass would also need to be implemented for this method
+    # to be fully compatible with jax.custom_vjp.
+    residuals = (
+        gathered_params,
+        flat_hidden,
+        flat_labels,
+        flat_segmentation,
+        batch_size,
+        seq_len,
+        emb_dim,
+    )
+    return (total_loss, total_z_loss), residuals
+
   def _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation):
     batch_size, seq_len, emb_dim = hidden_states.shape
     vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling
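
The key numerical idea in v_loop_body is a running logsumexp: each vocab block contributes jax.nn.logsumexp of its own logits, folded into the accumulator with jnp.logaddexp, while the label's logit is picked out with a block-local one-hot that is all zeros whenever the label falls outside the block. A self-contained sketch of that trick (toy shapes, no sharding, no MaxText config; written here only to illustrate the equivalence, not taken from the commit):

import jax
import jax.numpy as jnp

# Toy check: blocked softmax cross entropy with a running logsumexp matches the
# unblocked computation. Shapes and names are illustrative only.
b, v, v_block = 8, 32, 8  # 4 vocab blocks
logits = jax.random.normal(jax.random.PRNGKey(0), (b, v))
labels = jax.random.randint(jax.random.PRNGKey(1), (b,), 0, v)

# Reference: loss = logsumexp(logits) - logit at the label.
ref = jax.nn.logsumexp(logits, axis=-1) - jnp.take_along_axis(logits, labels[:, None], axis=-1)[:, 0]

# Blocked accumulation over vocab slices.
lse = jnp.full((b,), -jnp.inf)
neg_label_logit = jnp.zeros((b,))
for j in range(v // v_block):
  block = logits[:, j * v_block : (j + 1) * v_block]
  lse = jnp.logaddexp(lse, jax.nn.logsumexp(block, axis=-1))
  # one_hot is all zeros when the label lies outside this block, so each label
  # logit is subtracted exactly once across the loop.
  one_hot = jax.nn.one_hot(labels - j * v_block, v_block, dtype=block.dtype)
  neg_label_logit = neg_label_logit - jnp.sum(block * one_hot, axis=-1)

assert jnp.allclose(ref, lse + neg_label_logit, atol=1e-5)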
