@@ -115,6 +115,153 @@ def chunked_cross_entropy_loss(gathered_params, hidden_states, labels, segmentat
  (total_loss, total_z_loss), _ = _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation)
  return total_loss, total_z_loss

def _b_v_chunked_cross_entropy_loss_fwd(
    gathered_params, hidden_states, labels, segmentation
):
  batch_size, seq_len, emb_dim = hidden_states.shape
  v_dim = config.vocab_size

  b_dim = batch_size * seq_len
  b_block_sz = b_dim // config.num_of_batch_tiling
  v_block_sz = v_dim // config.num_vocab_tiling

  if b_dim % b_block_sz != 0 or v_dim % v_block_sz != 0:
    raise ValueError(
        "Batch/sequence dimension and vocab dimension must be divisible by"
        " their block sizes."
    )

  num_b_blocks = b_dim // b_block_sz
  num_v_blocks = v_dim // v_block_sz

  flat_hidden = _reshape(
      hidden_states,
      (b_dim, emb_dim),
      create_sharding(
          model.mesh,
          ("activation_embed_and_logits_batch_sequence", "activation_embed"),
      ),
  )
  flat_labels = _reshape(
      labels,
      (b_dim,),
      create_sharding(
          model.mesh, ("activation_embed_and_logits_batch_sequence",)
      ),
  )
  flat_segmentation = _reshape(
      segmentation,
      (b_dim,),
      create_sharding(
          model.mesh, ("activation_embed_and_logits_batch_sequence",)
      ),
  )

  if config.logits_via_embedding:
    w = gathered_params["params"]["shared_embedding"]["embedding"]
  else:
    w = gathered_params["params"]["decoder"]["logits_dense"]["kernel"]

  def b_loop_body(i, carry):
    total_loss, total_z_loss = carry
    b_start = i * b_block_sz

    def v_loop_body(j, v_carry):
      lse_b_, b_loss_sum_neg_logits_ = v_carry
      v_start = j * v_block_sz
      labels_b = jax.lax.dynamic_slice(flat_labels, (b_start,), (b_block_sz,))
      x_b = jax.lax.dynamic_slice(
          flat_hidden, (b_start, 0), (b_block_sz, emb_dim)
      )

      # Apply normalization to the batch block
      x_b_norm = model.apply(
          {"params": gathered_params["params"]},
          x_b,
          deterministic=deterministic,
          method="normalize_hidden_states",
      )
      x_b_norm = _maybe_shard_with_name(x_b_norm, chunked_hidden_spec)

      # Extract w_j
      if config.logits_via_embedding:
        # Project onto the embedding table. The table is (vocab_size, emb_dim),
        # so transpose to (emb_dim, vocab_size) before slicing.
        w_j = jax.lax.dynamic_slice(w.T, (0, v_start), (emb_dim, v_block_sz))
      else:
        w_j = jax.lax.dynamic_slice(w, (0, v_start), (emb_dim, v_block_sz))

      # Compute logits for the block
      logits_bv = jnp.dot(x_b_norm, w_j)

      if config.logits_via_embedding and config.normalize_embedding_logits:
        logits_bv = logits_bv / jnp.sqrt(emb_dim)
      if config.final_logits_soft_cap:
        logits_bv = logits_bv / config.final_logits_soft_cap
        logits_bv = jnp.tanh(logits_bv) * config.final_logits_soft_cap

      if config.cast_logits_to_fp32:
        logits_bv = logits_bv.astype(jnp.float32)

      # Streaming log-sum-exp: fold this vocab block into the running
      # per-token normalizer.
      lse_b__ = jnp.logaddexp(lse_b_, jax.nn.logsumexp(logits_bv, axis=-1))

      # Labels outside [v_start, v_start + v_block_sz) one-hot to all zeros,
      # so only the block containing the target contributes its logit.
      labels_one_hot = jax.nn.one_hot(
          labels_b - v_start, v_block_sz, dtype=logits_bv.dtype
      )
      b_loss_sum_neg_logits__ = b_loss_sum_neg_logits_ - jnp.sum(
          logits_bv * labels_one_hot, axis=-1
      )
      return lse_b__, b_loss_sum_neg_logits__

    lse_b, b_loss_sum_neg_logits = jax.lax.fori_loop(
        0,
        num_v_blocks,
        v_loop_body,
        (
            jnp.full((b_block_sz,), -jnp.inf, dtype=jnp.float32),
            jnp.zeros((b_block_sz,), dtype=jnp.float32),
        ),
    )

    segmentation_b = jax.lax.dynamic_slice(
        flat_segmentation, (b_start,), (b_block_sz,)
    )
    mask = (segmentation_b != 0).astype(jnp.float32)

    # Z-loss
    z_loss_b = config.z_loss_multiplier * jnp.square(lse_b) * mask
    total_z_loss += jnp.sum(z_loss_b)

    b_loss_sum_neg_logits = b_loss_sum_neg_logits * mask
    lse_b_masked = lse_b * mask

    # Per-token cross entropy is lse - target_logit; sum it over the
    # non-padding tokens of this batch block.
    total_loss += jnp.sum(b_loss_sum_neg_logits) + jnp.sum(lse_b_masked)

    return total_loss, total_z_loss

  initial_acc = (0.0, 0.0)
  total_loss, total_z_loss = jax.lax.fori_loop(
      0,
      num_b_blocks,
      b_loop_body,
      initial_acc,
  )

  # For a drop-in replacement, we return residuals the same way the current
  # method does, packing the values the backward pass will need. Note that the
  # backward pass would also have to be implemented for this method to be
  # fully compatible with jax.custom_vjp (see the sketch after this function).
  residuals = (
      gathered_params,
      flat_hidden,
      flat_labels,
      flat_segmentation,
      batch_size,
      seq_len,
      emb_dim,
  )
  return (total_loss, total_z_loss), residuals

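# A minimal sketch (an assumption, not part of this change) of how the tiled
# forward above could be wired into jax.custom_vjp once a matching backward
# pass exists. `_b_v_chunked_cross_entropy_loss_bwd` is hypothetical here: it
# would replay the same batch/vocab tiling over the packed residuals to build
# gradients for every primal argument.
#
# def b_v_chunked_cross_entropy_loss(gathered_params, hidden_states, labels, segmentation):
#   (total_loss, total_z_loss), _ = _b_v_chunked_cross_entropy_loss_fwd(
#       gathered_params, hidden_states, labels, segmentation
#   )
#   return total_loss, total_z_loss
#
# b_v_chunked_cross_entropy_loss = jax.custom_vjp(b_v_chunked_cross_entropy_loss)
# b_v_chunked_cross_entropy_loss.defvjp(
#     _b_v_chunked_cross_entropy_loss_fwd, _b_v_chunked_cross_entropy_loss_bwd
# )
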
def _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation):
  batch_size, seq_len, emb_dim = hidden_states.shape
  vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling