diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 2570f5c915..ca7cb828d0 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -186,7 +186,7 @@ logits_dot_in_fp32: false # whether to use fp32 in logits_dense or shared_embed cast_logits_to_fp32: true # whether to cast the logits to fp32. the higher precision is generally beneficial, but it can vary slightly. float32_qk_product: false # in dot_product attention, whether to cast to fp32 the inputs to qk product float32_logits: false # in dot_product attention, whether to cast to fp32 the inputs to softmax -float32_weight_sum: true # whether to use full fp32 precision to sum expert weights for numerical stability +float32_weight_sum: false # whether to use fp32 for MoE expert weight summation; true adds ~2 GB f32 temp per device float32_gate_logits: false # whether to cast inputs to fp32 to compute MoE gate logits for numerical stability # multi-token prediction configs @@ -319,6 +319,7 @@ scan_pipeline_repeats: false scan_layers_per_stage: false set_remat_policy_on_pipeline_iterations: true set_remat_policy_on_layers_per_stage: false +pipeline_save_decoder_layer_input: true # set to false to reduce pipeline tmem at cost of recomputing decoder layer inputs in backward pass # Choose 'remat_policy' between 'minimal_with_context', 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp', diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 76d1f224b0..931bde4101 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -746,8 +746,8 @@ class MoEGeneral(BaseModel): description="Enable top-k probability normalization for router weights (Qwen3-specific).", ) float32_weight_sum: bool = Field( - True, - description="Whether to use full fp32 precision to sum expert weights for numerical stability.", + False, + description="Whether to use fp32 for MoE expert weight summation; true adds ~2 GB f32 temp per device.", ) float32_gate_logits: bool = Field( False, @@ -990,6 +990,14 @@ class PipelineParallelism(BaseModel): scan_layers_per_stage: bool = Field(False, description="Use jax.lax.scan over layers within a stage.") set_remat_policy_on_pipeline_iterations: bool = Field(True, description="Set remat policy on the pipeline scan.") set_remat_policy_on_layers_per_stage: bool = Field(False, description="Set remat policy on the inner layer scan.") + pipeline_save_decoder_layer_input: bool = Field( + True, + description=( + "Whether to save 'decoder_layer_input' activations in the pipeline remat policy. " + "Setting to False reduces temporary memory (tmem) during pipeline execution at the cost " + "of recomputing decoder layer inputs in the backward pass." + ), + ) class RematAndOffload(BaseModel): @@ -2764,7 +2772,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de # For AOT compilation and correctness, always prioritize the 'stage' axis for sharding when pipelining. for rule in self.logical_axis_rules: if rule and rule[0] == "activation_embed_and_logits_batch": - rule[1] = ["stage", "data", "fsdp", "fsdp_transpose", "expert"] + rule[1] = [ax for ax in ["stage", "data", "fsdp", "fsdp_transpose", "expert"] if ax in self.mesh_axes] break if "stage" in self.mesh_axes: diff --git a/src/maxtext/kernels/gather_reduce_sc.py b/src/maxtext/kernels/gather_reduce_sc.py index 5b3b8e7597..c858b45bf5 100644 --- a/src/maxtext/kernels/gather_reduce_sc.py +++ b/src/maxtext/kernels/gather_reduce_sc.py @@ -55,6 +55,7 @@ def __getitem__(self, shape): _BF16 = VectorTypeHelper(ir.BF16Type.get) +# fmt: off @jax.jit( static_argnames=[ "reduce_group_size", @@ -69,6 +70,7 @@ def __getitem__(self, shape): "topk_wgt_zero_nan", ], ) +# fmt: on def sc_gather_reduce( op: jax.Array, idx: jax.Array, diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index 2f46179dd9..b0cd0d6b11 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -1588,13 +1588,22 @@ def _sequence_descriptor(segment_ids): dummy_attn_mask = None mask_type = "causal" else: - # Default case: no packing, no context parallelism - dummy_attn_mask = jnp.zeros( - (1, 1, 1, self.max_target_length, self.max_target_length), - dtype=jnp.uint8, - ) - attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) - attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8) + # Default case: no packing, no context parallelism. + # For synthetic data, segment IDs are always all-ones (one segment per sequence), so + # the segment mask is all-True and the combined mask reduces to pure causal masking. + # Use mask_type="causal" directly to avoid materializing f32/s32[seq,seq] tensors that + # XLA loop_broadcast_fusion hoists into the pipeline scan carry (+5 GiB temp memory). + if self.config.dataset_type == "synthetic": + attn_mask = None + dummy_attn_mask = None + mask_type = "causal" + else: + dummy_attn_mask = jnp.zeros( + (1, 1, 1, self.max_target_length, self.max_target_length), + dtype=jnp.uint8, + ) + attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) + attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8) dpa_layer = DotProductAttention( head_dim=head_dim, @@ -1607,12 +1616,10 @@ def _sequence_descriptor(segment_ids): dtype=self.dtype, float32_logits=self.float32_logits, qkv_layout=qkv_layout, - scale_factor=1.0, transpose_batch_sequence=False, window_size=sliding_window_size, context_parallel_causal_load_balanced=self.config.context_parallel_load_balance, context_parallel_axis=self.config.context_sharding, - context_parallel_strategy=self.config.context_parallel_strategy, max_segments_per_seq=max_segments_per_seq, ) diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 635506291d..662955cd5a 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -553,6 +553,7 @@ def __init__( mesh=mesh, shard_mode=config.shard_mode, debug_sharding=config.debug_sharding, + skip_trivial_specs=True, ) def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None: diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index 525fff1ed5..5ee06f73f0 100644 --- a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -22,6 +22,7 @@ import jax.numpy as jnp from jax.sharding import Mesh, NamedSharding +from flax import linen as nn from flax import nnx from maxtext.common.common_types import ShardMode, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, Array, Config, DType @@ -156,30 +157,36 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array: self.dtype, ) - output_axis_names = ( - ( - "activation_embed_and_logits_batch", - "prefill_activation_length", - "activation_embed", - ) - if model_mode == MODEL_MODE_PREFILL - else ( - "activation_embed_and_logits_batch", - "activation_length", - "activation_embed", - ) - ) - out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh, rules=getattr(self.config, "logical_axis_rules", None)) + output_prefill_axis_names = ("activation_embed_and_logits_batch", "prefill_activation_length", "activation_embed") + output_default_axis_names = ("activation_embed_and_logits_batch", "activation_length", "activation_embed") - out_sharding = NamedSharding(self.mesh, out_pspec) if self.config.shard_mode == ShardMode.EXPLICIT else None + if self.config.shard_mode == ShardMode.EXPLICIT: + output_axis_names = output_prefill_axis_names if model_mode == MODEL_MODE_PREFILL else output_default_axis_names + out_pspec = logical_to_mesh_axes( + output_axis_names, self.mesh, rules=getattr(self.config, "logical_axis_rules", None) + ) + out_sharding = NamedSharding(self.mesh, out_pspec) + else: + out_sharding = None - if cfg.use_iota_embed: + one_hot_elements = 1 + for d in inputs.shape: + one_hot_elements *= d + one_hot_elements *= self.num_embeddings + one_hot_bytes = one_hot_elements * jnp.dtype(self.dtype).itemsize + use_iota = cfg.use_iota_embed and one_hot_bytes <= 2 * 1024**3 + + if use_iota: iota = lax.iota(jnp.int32, self.num_embeddings) one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) output = jnp.dot(one_hot, embedding, out_sharding=out_sharding) else: output = embedding.at[inputs].get(out_sharding=out_sharding) + if model_mode == MODEL_MODE_PREFILL: + output = nn.with_logical_constraint(output, output_prefill_axis_names) + else: + output = nn.with_logical_constraint(output, output_default_axis_names) return output def attend(self, query: Array, out_sharding: NamedSharding | None = None) -> Array: diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index 645eb05e09..0bc7371f51 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -22,7 +22,7 @@ import jax from jax import lax import jax.numpy as jnp -from jax.sharding import NamedSharding +from jax.sharding import NamedSharding, reshard from maxtext.common.common_types import Array, DType, ShardMode from maxtext.layers import nnx_wrappers from maxtext.layers.initializers import Initializer, variable_to_logically_partitioned @@ -78,7 +78,10 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> if not self.with_scale: if out_sharding is not None: - y = jax.lax.with_sharding_constraint(y, out_sharding) + if self.shard_mode == ShardMode.EXPLICIT: + y = reshard(y, out_sharding) + else: + y = jax.lax.with_sharding_constraint(y, out_sharding) return y scale = self.scale.get_value() @@ -88,8 +91,14 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> scale = jax.device_put(scale, max_utils.device_space()) scale = jnp.asarray(scale, self.dtype) - effective_scale = scale + self.scale_offset - return jnp.einsum("...k,k->...k", y, effective_scale, out_sharding=out_sharding) + effective_scale = scale + self.scale_offset if self.scale_offset != 0.0 else scale + y = y * effective_scale + if out_sharding is not None: + if self.shard_mode == ShardMode.EXPLICIT: + y = reshard(y, out_sharding) + else: + y = jax.lax.with_sharding_constraint(y, out_sharding) + return y class GlobalRMSNorm(RMSNorm): diff --git a/src/maxtext/layers/pipeline.py b/src/maxtext/layers/pipeline.py index 62ea52782b..1196291547 100644 --- a/src/maxtext/layers/pipeline.py +++ b/src/maxtext/layers/pipeline.py @@ -118,6 +118,7 @@ def _maybe_shard_with_logical(self, inputs, logical_axes): rules=self.config.logical_axis_rules, debug_sharding=self.config.debug_sharding, extra_stack_level=1, + skip_trivial_specs=True, ) def _maybe_shard_with_name(self, inputs, sharding_name): @@ -139,7 +140,6 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): # Setup potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage) state_io_batch_idx = loop_iteration % self.microbatches_per_stage state_io_slice = state_io[:, state_io_batch_idx] - shift = self._maybe_shard_with_logical(shift, self.stages_in_logical) if self.use_circ_storage: # Setup potential input from circ_storage, which also has a rotating index for microbatch, @@ -154,7 +154,6 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): # state_io we instead grab from the last stage's output (possibly buffered when num_microbatches > num_stages, e.g. # from circ_storage). first_stage_in = jnp.where(loop_iteration < self.config.num_pipeline_microbatches, state_io_slice, circular_stage_in) - first_stage_in = self._maybe_shard_with_logical(first_stage_in, self.stages_in_logical) # Note that first_stage_in may correspond to bubble computation during the last few iterations. # However, these bubble computation results remain in the shift buffer (do not make it back to state_io) and are @@ -164,11 +163,7 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): def select_state_or_input(first_stage_in, shift): # Selects input for stage 0, shift for other stages - return jnp.where( - jax.lax.broadcasted_iota("int32", shift.shape, 0, out_sharding=self.stages_in_sharding) == 0, - first_stage_in, - shift, - ) + return jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_stage_in, shift) # Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output) stages_in = select_state_or_input(first_stage_in, shift) @@ -180,7 +175,6 @@ def get_microbatch_and_repeat_ids(self, loop_iteration): non-circular""" # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is 1 behind due to bubble, etc for other stages microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) - microbatches_processed = self._maybe_shard_with_name(microbatches_processed, NamedSharding(self.mesh, P("stage"))) microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches return microbatch_ids, repeat_ids @@ -190,7 +184,10 @@ def get_pipeline_remat_policy(self): if self.config.remat_policy == "custom": return self.remat_policy - save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input") + names_to_save = ["iteration_input"] + if self.config.pipeline_save_decoder_layer_input: + names_to_save.append("decoder_layer_input") + save_input_policy = jax.checkpoint_policies.save_only_these_names(*names_to_save) if self.remat_policy is not None: remat_policy = jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy) else: @@ -247,16 +244,6 @@ def get_main_vmap_func_for_iterations(self): def func_to_vmap( body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode ): - weights = meta.remove_axis( - weights, - 0, - { - nn.PARTITION_NAME: "layers", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), - "x_times": self.num_stages, - }, - ) return body_instance.apply(weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) vmap_func = nn.vmap( @@ -490,9 +477,7 @@ def vmap_gather(self, xs, ids, ids_dim): """ def _gather_one(x, i): - idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim)) - replicated_sharding = NamedSharding(self.mesh, P()) - return x.at[idx].get(out_sharding=replicated_sharding) + return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, i, 1, ids_dim), ids_dim) ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None) outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) @@ -516,21 +501,16 @@ def get_new_loop_state(self, output, loop_state): loop_iteration = loop_state["loop_iteration"] old_prev_outputs = loop_state["prev_outputs"] - @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) def _rotate_right(arr): - # we use +1 for right shifting - stage_size = jax.lax.axis_size("stage") - perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] - arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) - return arr + # Use lax.slice to avoid generating a gather. + last = jax.lax.slice_in_dim(arr, self.num_stages - 1, self.num_stages, axis=0) + except_last = jax.lax.slice_in_dim(arr, 0, self.num_stages - 1, axis=0) + return jnp.concatenate([last, except_last], axis=0) - @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) def _shift_right(arr): - stage_idx = jax.lax.axis_index("stage") - stage_size = jax.lax.axis_size("stage") - perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] - arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) - return jnp.where(stage_idx == 0, jnp.zeros_like(arr), arr) + padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1) + # Use lax.slice to guarantee the gradient is a pad. + return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape) # Shift either rotates or shifts depending on if the last stage immediately must send to first or not # For non-circular pipelines, the last stage does not need to send to first @@ -574,29 +554,17 @@ def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): stream_buf_idx = loop_iteration % self.microbatches_per_stage stream_slice = old_state_io[:, stream_buf_idx] - def _rotate_left(arr, stage_size): - # we use -1 for left shifting - perm = [(i, (i - 1) % stage_size) for i in range(stage_size)] - return jax.lax.ppermute(arr, axis_name="stage", perm=perm) - - def _shift_left(arr, stage_size, output): - stage_idx = jax.lax.axis_index("stage") - arr = _rotate_left(arr, stage_size) - return jnp.where(stage_idx == stage_size - 1, output, arr) - - @jax.shard_map( - mesh=self.mesh, - in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()), - out_specs=self.state_io_spec, - ) - def _update_state_io(state_in, stream_slice, output, stream_buf_idx): + def _update_state_io(state_in, stream_slice, output): # Shift the current slice to the left, then fill the last stage with the final output. - stage_size = jax.lax.axis_size("stage") - stream_slice = _shift_left(stream_slice, stage_size, output) + padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) + stream_slice = jax.lax.slice_in_dim(jnp.pad(stream_slice, padding), 1, stream_slice.shape[0] + 1, axis=0) + stream_slice = jnp.where( + jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice + ) stream_slice = jnp.expand_dims(stream_slice, 1) return jax.lax.dynamic_update_slice_in_dim(state_in, stream_slice, stream_buf_idx, axis=1) - new_state = _update_state_io(old_state_io, stream_slice, output, stream_buf_idx) + new_state = _update_state_io(old_state_io, stream_slice, output) new_loop_state = { "state_io": new_state, diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 27e1a6f7ad..c17be0c7c5 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -42,6 +42,7 @@ from maxtext.utils import max_utils from maxtext.utils.sharding import create_sharding from maxtext.utils.sharding import maybe_shard_with_logical +from maxtext.utils.sharding import remove_size_one_mesh_axis import transformers @@ -483,15 +484,14 @@ def __call__( return outputs, None # bf16 and fp8 code path for pure-JAX batch-split. - # fp8 code path supports both manual quantization and qwix - # quantization. - input_sharding = jax.typeof(inputs).sharding - activation_pspec = jax.sharding.PartitionSpec( - ("data", "fsdp", "expert"), - None, - None, + activation_pspec = remove_size_one_mesh_axis( + jax.sharding.PartitionSpec( + ("data", "fsdp", "fsdp_transpose", "expert", "context"), + None, + None, + ), + self.mesh, ) - inputs = jax.reshard(inputs, jax.sharding.NamedSharding(self.mesh, activation_pspec)) yarn_freqs = deepseek_batchsplit.initialize_yarn_freqs( decoder_positions, embedding_dims=self.config.qk_rope_head_dim, @@ -563,7 +563,6 @@ def extract_fn(x): in_specs=([activation_pspec] * self.config.batch_split_factor,), out_specs=activation_pspec, )(outputs) - outputs = jax.reshard(outputs, input_sharding) return outputs, None x = self.with_logical_constraint(inputs) diff --git a/src/maxtext/models/mixtral.py b/src/maxtext/models/mixtral.py index faf69273c6..dd6cc8e611 100644 --- a/src/maxtext/models/mixtral.py +++ b/src/maxtext/models/mixtral.py @@ -18,111 +18,33 @@ from flax import linen as nn -from flax import nnx from jax.ad_checkpoint import checkpoint_name import jax.numpy as jnp from jax.sharding import Mesh from maxtext.common.common_types import Config -from maxtext.layers import initializers, nnx_wrappers +from maxtext.layers import initializers from maxtext.layers import moe from maxtext.layers import quantizations -from maxtext.layers.attentions import Attention -from maxtext.layers.linears import Dropout -from maxtext.layers.normalizations import RMSNorm +from maxtext.layers.attentions import attention_as_linen +from maxtext.layers.normalizations import rms_norm from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.utils import max_utils +from maxtext.utils.sharding import maybe_shard_with_logical # ----------------------------------------- # The Decoder Layer for Mixtral # ----------------------------------------- -class MixtralDecoderLayer(nnx.Module): +class MixtralDecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" - @nn.compact - def __init__( - self, - config: Config, - mesh: Mesh, - model_mode: str, - quant: None | Quant = None, - *, - rngs: nnx.Rngs, - ): - self.config = config - self.mesh = mesh - self.model_mode = model_mode - self.quant = quant - self.rngs = rngs - - batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) - dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) - - self.pre_self_attention_layer_norm = RMSNorm( - num_features=config.emb_dim, - dtype=config.dtype, - weight_dtype=config.weight_dtype, - kernel_axes=("norm",), - epsilon=config.normalization_layer_epsilon, - rngs=self.rngs, - ) - - self.self_attention = Attention( - config=config, - num_query_heads=config.num_query_heads, - num_kv_heads=config.num_kv_heads, - head_dim=config.head_dim, - max_target_length=config.max_target_length, - max_prefill_predict_length=config.max_prefill_predict_length, - attention_kernel=config.attention, - inputs_q_shape=dummy_inputs_shape, - inputs_kv_shape=dummy_inputs_shape, - mesh=mesh, - dtype=config.dtype, - weight_dtype=config.weight_dtype, - dropout_rate=config.dropout_rate, - float32_qk_product=config.float32_qk_product, - float32_logits=config.float32_logits, - quant=self.quant, - kv_quant=quantizations.configure_kv_quant(config), - prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))), - ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))), - compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))), - reshape_q=config.reshape_q, - use_ragged_attention=config.use_ragged_attention, - ragged_block_size=config.ragged_block_size, - model_mode=model_mode, - rngs=self.rngs, - ) - - self.post_self_attention_layer_norm = RMSNorm( - num_features=config.emb_dim, - dtype=config.dtype, - weight_dtype=config.weight_dtype, - kernel_axes=("norm",), - epsilon=config.normalization_layer_epsilon, - rngs=self.rngs, - ) - - self.MoeBlock_0 = moe.RoutedMoE( - config=config, - num_experts=config.num_experts, - num_experts_per_tok=config.num_experts_per_tok, - mesh=mesh, - kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"), - kernel_axes=("embed", None), - intermediate_dim=config.mlp_dim, - dtype=config.dtype, - weight_dtype=config.weight_dtype, - quant=self.quant, - rngs=self.rngs, - ) - - self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) - - self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") + config: Config + mesh: Mesh + model_mode: str + quant: None | Quant = None + @nn.compact def __call__( self, inputs, @@ -139,13 +61,65 @@ def __call__( # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) if isinstance(inputs, tuple): inputs = inputs[0] - inputs = nn.with_logical_constraint(inputs, self.activation_axis_names) - inputs = checkpoint_name(inputs, "decoder_layer_input") - lnx = self.pre_self_attention_layer_norm(inputs) - lnx = nn.with_logical_constraint(lnx, self.activation_axis_names) + cfg = self.config + mesh = self.mesh + + activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") - attention_lnx, kv_cache = self.self_attention( + def shard(x): + return maybe_shard_with_logical( + x, + activation_axis_names, + mesh=mesh, + shard_mode=cfg.shard_mode, + rules=cfg.logical_axis_rules, + skip_trivial_specs=True, + ) + + inputs = shard(inputs) + inputs = checkpoint_name(inputs, "decoder_layer_input") + + lnx = rms_norm( + num_features=cfg.emb_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="pre_self_attention_layer_norm", + kernel_axes=("norm",), + epsilon=cfg.normalization_layer_epsilon, + )(inputs) + lnx = shard(lnx) + + batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(cfg, model_mode) + dummy_inputs_shape = (batch_size, seq_len, cfg.emb_dim) + + attention_lnx, kv_cache = attention_as_linen( + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + inputs_q_shape=dummy_inputs_shape, + inputs_kv_shape=dummy_inputs_shape, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + float32_qk_product=cfg.float32_qk_product, + float32_logits=cfg.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(cfg), + prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))), + compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), + reshape_q=cfg.reshape_q, + use_ragged_attention=cfg.use_ragged_attention, + ragged_block_size=cfg.ragged_block_size, + model_mode=model_mode, + name="self_attention", + )( lnx, lnx, decoder_positions, @@ -157,28 +131,47 @@ def __call__( attention_metadata=attention_metadata, ) - attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names) + attention_lnx = shard(attention_lnx) intermediate_inputs = inputs + attention_lnx # Fully Connected - hidden_states = self.post_self_attention_layer_norm(intermediate_inputs) - hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names) + hidden_states = rms_norm( + num_features=cfg.emb_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="post_self_attention_layer_norm", + kernel_axes=("norm",), + epsilon=cfg.normalization_layer_epsilon, + )(intermediate_inputs) + hidden_states = shard(hidden_states) load_balance_loss = None # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints. # The `name` represents the weight name in JAX/checkpoints and so the class name # is just for readability. - mlp_lnx, load_balance_loss, _ = self.MoeBlock_0(hidden_states) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names) + mlp_lnx, load_balance_loss, _ = moe.get_routed_moe( + config=cfg, + num_experts=cfg.num_experts, + num_experts_per_tok=cfg.num_experts_per_tok, + mesh=mesh, + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=("embed", None), + intermediate_dim=cfg.mlp_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + quant=self.quant, + name="MoeBlock_0", + )(hidden_states) + mlp_lnx = shard(mlp_lnx) layer_output = mlp_lnx + intermediate_inputs - layer_output = self.dropout(layer_output, deterministic=deterministic) - layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names) + layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = shard(layer_output) - if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None: + if cfg.load_balance_loss_weight > 0.0 and load_balance_loss is not None: self.sow("intermediates", "moe_lb_loss", load_balance_loss) - if self.config.record_internal_nn_metrics: + if cfg.record_internal_nn_metrics: self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( @@ -187,13 +180,10 @@ def __call__( jnp.sum(layer_output == 0) / jnp.size(layer_output), ) - if self.config.scan_layers: + if cfg.scan_layers: return layer_output, None else: return layer_output, kv_cache -MixtralDecoderLayerToLinen = nnx_wrappers.to_linen_class( - MixtralDecoderLayer, - base_metadata_fn=initializers.variable_to_logically_partitioned, -) +MixtralDecoderLayerToLinen = MixtralDecoderLayer diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 82460cb020..1d2c3063fe 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -36,6 +36,13 @@ import jax.numpy as jnp from jax.sharding import NamedSharding + +import flax + +try: + flax.config.update("flax_always_shard_variable", False) +except LookupError: + pass from flax import linen as nn, nnx from flax.linen import partitioning as nn_partitioning @@ -381,10 +388,11 @@ def diff_wrapper(param, rest, config, data): (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data) nnx.update(state.model, new_rest) - raw_grads = jax.tree_util.tree_map( - lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, - raw_grads, - ) + if config.grad_dtype != jnp.float32: + raw_grads = jax.tree_util.tree_map( + lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, + raw_grads, + ) if config.parameter_memory_host_offload: raw_grads = jax.device_put( raw_grads, diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 471abac3f0..9897da5165 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -61,12 +61,9 @@ def validate_config(config): """Validates the config is is setup correctly to compile, returning a useful error message if not.""" assert config.compile_topology != "", ( - "You must pass your desired target hardware in compile_topology, e.g." - " compile_topology=v5e-256" + "You must pass your desired target hardware in compile_topology, e.g." " compile_topology=v5e-256" ) - assert ( - config.compile_topology_num_slices > 0 - ), "You must set compile_topology_num_slices to a positive integer" + assert config.compile_topology_num_slices > 0, "You must set compile_topology_num_slices to a positive integer" def get_topology_mesh(config): @@ -78,18 +75,12 @@ def get_topology_mesh(config): num_slices=config.compile_topology_num_slices, ).devices else: - target_hardware = accelerator_to_spec_map.get_system_characteristics( - config.compile_topology - ) + target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology) if target_hardware.platform == "gpu": # Disable sharded autotuning. This is an optimization to distribute # autotuning across the fleet, but can cause hangs with AoT compilation. - os.environ["XLA_FLAGS"] = ( - os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false" - ) - jax.config.update( - "mock_num_gpu_processes", config.compile_topology_num_slices - ) + os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false" + jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices) topology_devices = jax.devices() else: topology_devices = get_topology_desc( @@ -104,14 +95,8 @@ def get_topology_mesh(config): "jax_remove_size_one_mesh_axis_from_type", config.remove_size_one_mesh_axis_from_type, ) - topology_device_mesh = maxtext_utils.create_device_mesh( - config, topology_devices - ) - mesh_axis_type = ( - AxisType.Explicit - if config.shard_mode == ShardMode.EXPLICIT - else AxisType.Auto - ) + topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices) + mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto topology_mesh = Mesh( topology_device_mesh, config.mesh_axes, @@ -129,9 +114,7 @@ def _collect_nnx_activation_shardings(create_model_fn, config, mesh): input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) abstract_input = jax.ShapeDtypeStruct(input_shape, jnp.int32) - def _nnx_forward( - decoder_input_tokens, decoder_positions, decoder_segment_ids - ): + def _nnx_forward(decoder_input_tokens, decoder_positions, decoder_segment_ids): model_instance = create_model_fn() return model_instance( decoder_input_tokens=decoder_input_tokens, @@ -140,9 +123,7 @@ def _nnx_forward( enable_dropout=False, ) - with jax.set_mesh(mesh), nn_partitioning.axis_rules( - config.logical_axis_rules - ): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): jax.eval_shape(_nnx_forward, abstract_input, abstract_input, abstract_input) @@ -151,13 +132,9 @@ def get_shaped_inputs(topology_mesh, config): # Construct the model and optimizer to get shaped versions of the state quant = quantizations.configure_quantization(config) if config.pure_nnx: - _create_model_partial, model = ( - model_creation_utils.create_nnx_abstract_model(config, topology_mesh) - ) + _create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, topology_mesh) else: - model = Transformer( - config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN - ) + model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) # The learning_rate_schedule is baked into the compiled object. learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) # pass in model for muon @@ -176,20 +153,14 @@ def create_train_state_fn(): init_state_fn = create_train_state_fn else: - init_state_fn = functools.partial( - maxtext_utils.init_initial_state, model, tx, config, True, example_rng - ) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng) # Shaped state - abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( - config, topology_mesh, init_state_fn, True - ) + abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(config, topology_mesh, init_state_fn, True) if config.pure_nnx: # NNX doesn't use Linen logical annotations; derive PartitionSpecs from the physical shardings. - logical_annotations = maxtext_utils_nnx.get_partition_spec_nnx( - state_mesh_shardings - ) + logical_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings) # For NNX, get_functional_train_with_signature expects the graphdef (static structure), # not the raw model — mirroring how the training loop does nnx.split(train_state). with nn_partitioning.axis_rules(config.logical_axis_rules): @@ -198,9 +169,7 @@ def create_train_state_fn(): model = graphdef else: # unsharded logical annotations - logical_annotations = maxtext_utils.get_logical_annotations( - config, topology_mesh, init_state_fn - ) + logical_annotations = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn) # Shaped batch shaped_batch = maxtext_utils.get_shaped_batch(config) @@ -217,9 +186,7 @@ def create_train_state_fn(): # Collect NNX activation shardings via an abstract forward pass (must run # after get_abstract_state, which only traces __init__). if config.debug_sharding and config.pure_nnx: - _collect_nnx_activation_shardings( - _create_model_partial, config, topology_mesh - ) + _collect_nnx_activation_shardings(_create_model_partial, config, topology_mesh) return ( shaped_train_args, @@ -256,16 +223,18 @@ def jit_and_compile( maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args) lowered = jitted.lower(*func_input_args, **func_input_kwargs) # Import libtpu flags as compiler options. Defaults to empty dict if string is empty. - compiler_options = max_utils.parse_libtpu_flags_to_dict( - config.compile_xla_flags - ) + compiler_options = max_utils.parse_libtpu_flags_to_dict(config.compile_xla_flags) compiled = lowered.compile(compiler_options=compiler_options) return compiled def save_compiled(compiled, save_name): """Serialize and save the compiled function.""" - serialized, _, _ = serialize(compiled) + result = serialize(compiled) + # jax.experimental.serialize_executable.serialize() changed its return type: + # older JAX: (bytes, in_tree, out_tree) + # newer JAX: bytes + serialized = result[0] if isinstance(result, tuple) else result with open(save_name, "wb") as f: f.write(serialized) @@ -293,18 +262,12 @@ def is_oom(argv: Sequence[str]) -> bool: ) = get_shaped_inputs(topology_mesh, config) # Update params_shardings when shard_optimizer_over_data is enabled (Zero-1) - params_shardings, state_mesh_shardings = ( - sharding.maybe_update_params_sharding_with_opt( - config, state_mesh_shardings - ) - ) + params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) # When ZeRO-1 is enabled, we need to use the original params_shardings for input shardings # but keep the updated state_mesh_shardings for the optimizer state if config.shard_optimizer_over_data: - input_state_mesh_shardings = state_mesh_shardings.replace( - params=params_shardings - ) + input_state_mesh_shardings = state_mesh_shardings.replace(params=params_shardings) else: input_state_mesh_shardings = state_mesh_shardings @@ -355,8 +318,7 @@ def is_oom(argv: Sequence[str]) -> bool: def main(argv: Sequence[str]) -> None: jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["LIBTPU_INIT_ARGS"] = ( - os.environ.get("LIBTPU_INIT_ARGS", "") - + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" ) print("Starting train_compile.py...", flush=True) @@ -381,18 +343,12 @@ def main(argv: Sequence[str]) -> None: ) = get_shaped_inputs(topology_mesh, config) # Update params_shardings when shard_optimizer_over_data is enabled (Zero-1) - params_shardings, state_mesh_shardings = ( - sharding.maybe_update_params_sharding_with_opt( - config, state_mesh_shardings - ) - ) + params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) # When ZeRO-1 is enabled, we need to use the original params_shardings for input shardings # but keep the updated state_mesh_shardings for the optimizer state if config.shard_optimizer_over_data: - input_state_mesh_shardings = state_mesh_shardings.replace( - params=params_shardings - ) + input_state_mesh_shardings = state_mesh_shardings.replace(params=params_shardings) else: input_state_mesh_shardings = state_mesh_shardings @@ -401,21 +357,15 @@ def main(argv: Sequence[str]) -> None: if config.enable_diloco: # Build abstract DiLoCo state and shardings for AOT compilation abstract_state = shaped_train_args[0] - diloco_state, state_mesh_shardings, inner_state_shardings = ( - diloco.build_abstract_diloco_state( - config, abstract_state, state_mesh_shardings, topology_mesh - ) + diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state( + config, abstract_state, state_mesh_shardings, topology_mesh ) # For NNX, shaped_train_args has 2 elements (state, batch) — no rng; pass None for prng. - shaped_rng_arg = ( - shaped_train_args[2] if len(shaped_train_args) > 2 else None - ) + shaped_rng_arg = shaped_train_args[2] if len(shaped_train_args) > 2 else None shaped_train_args = (diloco_state, shaped_train_args[1], shaped_rng_arg) # Wrap train_step with diloco - train_step_partial = functools.partial( - train.train_step, model, config, inner_state_shardings, params_shardings - ) + train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, params_shardings) train_step_fn = diloco.build_diloco_train_step(config, train_step_partial) # For DiLoCo, the train_step_fn is already fully wrapped and takes (state, batch, prng) @@ -480,10 +430,7 @@ def main(argv: Sequence[str]) -> None: if config.compiled_trainstep_file != "": print("Saving compiled object...") save_compiled(compiled, config.compiled_trainstep_file) - print( - "Successfully saved compiled object as" - f" {config.compiled_trainstep_file}" - ) + print("Successfully saved compiled object as" f" {config.compiled_trainstep_file}") print("Finished train_compile.py successfully!", flush=True) print(f"Cost analysis: {compiled.cost_analysis()}") print(f"Memory analysis: {compiled.memory_analysis()}") diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index 4a500e2fe1..0902717928 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -132,7 +132,15 @@ def maybe_shard_with_pspec( def maybe_shard_with_logical( - inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc="" + inputs, + logical_axes, + mesh, + shard_mode, + rules=None, + debug_sharding=False, + extra_stack_level=0, + sharding_desc="", + skip_trivial_specs=False, ): """ A wrapper of maybe_shard_with_name when logical axes are inputs @@ -147,6 +155,9 @@ def maybe_shard_with_logical( named_sharding = create_sharding(mesh, logical_axes, rules=rules) + if skip_trivial_specs and all(ax is None or ax == () for ax in named_sharding.spec): + return inputs + return maybe_shard_with_name( inputs, named_sharding, diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index a469a6fa70..7466333d1c 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -950,6 +950,7 @@ def test_circular_pipeline_ag_per_repeat_ep_ds(self): "use_random_routing=true", "max_target_length=128", "pipeline_fsdp_ag_per_repeat=true", + "pipeline_save_decoder_layer_input=false", ) ) diff --git a/tests/utils/reference_hlo_deepseek3.txt b/tests/utils/reference_hlo_deepseek3.txt index ffff9103ad..1541bedc76 100644 --- a/tests/utils/reference_hlo_deepseek3.txt +++ b/tests/utils/reference_hlo_deepseek3.txt @@ -10,21 +10,21 @@ StackFrames %region_46.56 (top_k.25: bf16[], top_k.26: bf16[], top_k.27: s32[], top_k.28: s32[]) -> pred[] { - %constant.1408 = s32[]{:T(128)} constant(0) - %constant.1409 = s32[]{:T(128)} constant(2147483647) + %constant.1358 = s32[]{:T(128)} constant(0) + %constant.1359 = s32[]{:T(128)} constant(2147483647) %top_k.25 = bf16[]{:T(256)} parameter(0), metadata={op_name="top_k"} %top_k.26 = bf16[]{:T(256)} parameter(1), metadata={op_name="top_k"} %top_k.27 = s32[]{:T(128)} parameter(2), metadata={op_name="top_k"} %top_k.28 = s32[]{:T(128)} parameter(3), metadata={op_name="top_k"} - %convert.393 = f32[]{:T(128)S(6)} convert(%top_k.25), metadata={op_name="convert.18"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %bitcast-convert.39 = s32[]{:T(128)S(6)} bitcast-convert(%convert.393), metadata={op_name="bitcast-convert.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %compare.144 = pred[]{:T(512)S(6)} compare(%bitcast-convert.39, %constant.1408), direction=LT, metadata={op_name="compare.38"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %xor.40 = s32[]{:T(128)S(6)} xor(%constant.1409, %bitcast-convert.39), metadata={op_name="xor.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %convert.269 = f32[]{:T(128)S(6)} convert(%top_k.25), metadata={op_name="convert.18"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.39 = s32[]{:T(128)S(6)} bitcast-convert(%convert.269), metadata={op_name="bitcast-convert.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.144 = pred[]{:T(512)S(6)} compare(%bitcast-convert.39, %constant.1358), direction=LT, metadata={op_name="compare.38"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.40 = s32[]{:T(128)S(6)} xor(%constant.1359, %bitcast-convert.39), metadata={op_name="xor.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %select.127 = s32[]{:T(128)S(6)} select(%compare.144, %xor.40, %bitcast-convert.39), metadata={op_name="select.16"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} - %convert.394 = f32[]{:T(128)S(6)} convert(%top_k.26), metadata={op_name="convert.19"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %bitcast-convert.40 = s32[]{:T(128)S(6)} bitcast-convert(%convert.394), metadata={op_name="bitcast-convert.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %compare.145 = pred[]{:T(512)S(6)} compare(%bitcast-convert.40, %constant.1408), direction=LT, metadata={op_name="compare.39"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %xor.41 = s32[]{:T(128)S(6)} xor(%constant.1409, %bitcast-convert.40), metadata={op_name="xor.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %convert.270 = f32[]{:T(128)S(6)} convert(%top_k.26), metadata={op_name="convert.19"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.40 = s32[]{:T(128)S(6)} bitcast-convert(%convert.270), metadata={op_name="bitcast-convert.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.145 = pred[]{:T(512)S(6)} compare(%bitcast-convert.40, %constant.1358), direction=LT, metadata={op_name="compare.39"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.41 = s32[]{:T(128)S(6)} xor(%constant.1359, %bitcast-convert.40), metadata={op_name="xor.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %select.128 = s32[]{:T(128)S(6)} select(%compare.145, %xor.41, %bitcast-convert.40), metadata={op_name="select.17"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} %compare.146 = pred[]{:T(512)S(6)} compare(%select.127, %select.128), direction=GT, metadata={op_name="compare.0"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %compare.147 = pred[]{:T(512)S(6)} compare(%select.128, %select.127), direction=GT, metadata={op_name="compare.117"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} @@ -69,28 +69,28 @@ StackFrames ROOT %select.134 = pred[]{:T(512)} select(%compare.156, %compare.157, %lt_to.37), metadata={op_name="select.116"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_119.141 (reduce_sum.157: bf16[], reduce_sum.158: bf16[]) -> bf16[] { - %reduce_sum.157 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/reduce_sum"} - %reduce_sum.158 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/reduce_sum"} - ROOT %reduce_sum.159 = bf16[]{:T(256)} add(%reduce_sum.157, %reduce_sum.158), metadata={op_name="checkpoint/moe_layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_119.141 (reduce_sum.151: bf16[], reduce_sum.152: bf16[]) -> bf16[] { + %reduce_sum.151 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/reduce_sum"} + %reduce_sum.152 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/reduce_sum"} + ROOT %reduce_sum.153 = bf16[]{:T(256)} add(%reduce_sum.151, %reduce_sum.152), metadata={op_name="checkpoint/moe_layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %region_107.126 (psum.6: bf16[], psum.9: bf16[]) -> bf16[] { %psum.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} %psum.9 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} - ROOT %add.1407 = bf16[]{:T(256)} add(%psum.6, %psum.9), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.1373 = bf16[]{:T(256)} add(%psum.6, %psum.9), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %region_108.127 (psum.10: bf16[], psum.11: bf16[]) -> bf16[] { %psum.10 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} %psum.11 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} - ROOT %add.1408 = bf16[]{:T(256)} add(%psum.10, %psum.11), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.1374 = bf16[]{:T(256)} add(%psum.10, %psum.11), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %region_109.128 (psum.14: bf16[], psum.15: bf16[]) -> bf16[] { %psum.14 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} %psum.15 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} - ROOT %add.1409 = bf16[]{:T(256)} add(%psum.14, %psum.15), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.1375 = bf16[]{:T(256)} add(%psum.14, %psum.15), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %region_62.73 (reduce-window.111: s32[], reduce-window.112: s32[]) -> s32[] { @@ -212,11 +212,11 @@ StackFrames %param_1.108 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.13 = s32[1024]{0:T(1024)} custom-call(%param_1.108), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} %slice.920 = s32[512]{0:T(512)} slice(%custom-call.13), slice={[0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %reshape.3318 = s32[4,128]{1,0:T(4,128)} reshape(%slice.920), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %transpose.847 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3318), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %gather.187 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} gather(%param_0.17, %transpose.847), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %transpose.846 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%gather.187), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - ROOT %reshape.3317 = bf16[512,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.846), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %reshape.3445 = s32[4,128]{1,0:T(4,128)} reshape(%slice.920), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.604 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3445), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %gather.187 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} gather(%param_0.17, %transpose.604), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %transpose.603 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%gather.187), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %reshape.3444 = bf16[512,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.603), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} } %fused_computation.6 (param_0.20: f32[163840,32], param_1.110: s32[1024]) -> f32[512,32] { @@ -224,11 +224,11 @@ StackFrames %param_1.110 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.15 = s32[1024]{0:T(1024)} custom-call(%param_1.110), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} %slice.922 = s32[512]{0:T(512)} slice(%custom-call.15), slice={[0:512]}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - %reshape.3326 = s32[4,128]{1,0:T(4,128)} reshape(%slice.922), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} - %transpose.853 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3326), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} - %gather.189 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.20, %transpose.853), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - %transpose.852 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.189), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - ROOT %reshape.3325 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.852), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %reshape.3453 = s32[4,128]{1,0:T(4,128)} reshape(%slice.922), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %transpose.610 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3453), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %gather.189 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.20, %transpose.610), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %transpose.609 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.189), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + ROOT %reshape.3452 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.609), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} } %fused_computation.7 (param_0.23: f32[163840,32], param_1.112: s32[1024]) -> f32[512,32] { @@ -236,11 +236,11 @@ StackFrames %param_1.112 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.17 = s32[1024]{0:T(1024)} custom-call(%param_1.112), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} %slice.924 = s32[512]{0:T(512)} slice(%custom-call.17), slice={[0:512]}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - %reshape.3334 = s32[4,128]{1,0:T(4,128)} reshape(%slice.924), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} - %transpose.859 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3334), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} - %gather.191 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.23, %transpose.859), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - %transpose.858 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.191), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - ROOT %reshape.3333 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.858), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %reshape.3461 = s32[4,128]{1,0:T(4,128)} reshape(%slice.924), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %transpose.616 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3461), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %gather.191 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.23, %transpose.616), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %transpose.615 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.191), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + ROOT %reshape.3460 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.615), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} } %fused_computation.8 (param_0.26: f32[163840,32], param_1.120: s32[1024]) -> f32[512,32] { @@ -248,11 +248,11 @@ StackFrames %param_1.120 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.25 = s32[1024]{0:T(1024)} custom-call(%param_1.120), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} %slice.932 = s32[512]{0:T(512)} slice(%custom-call.25), slice={[0:512]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - %reshape.3342 = s32[4,128]{1,0:T(4,128)} reshape(%slice.932), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} - %transpose.865 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3342), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} - %gather.193 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.26, %transpose.865), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - %transpose.864 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.193), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - ROOT %reshape.3341 = f32[512,32]{1,0:T(8,128)S(1)} reshape(%transpose.864), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %reshape.3469 = s32[4,128]{1,0:T(4,128)} reshape(%slice.932), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %transpose.622 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3469), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %gather.193 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.26, %transpose.622), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %transpose.621 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.193), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + ROOT %reshape.3468 = f32[512,32]{1,0:T(8,128)S(1)} reshape(%transpose.621), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} } %fused_computation.9 (param_0.29: f32[163840,32], param_1.122: s32[1024]) -> f32[512,32] { @@ -260,11 +260,11 @@ StackFrames %param_1.122 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.27 = s32[1024]{0:T(1024)} custom-call(%param_1.122), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} %slice.934 = s32[512]{0:T(512)} slice(%custom-call.27), slice={[0:512]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - %reshape.3350 = s32[4,128]{1,0:T(4,128)} reshape(%slice.934), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} - %transpose.871 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3350), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} - %gather.195 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.29, %transpose.871), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - %transpose.870 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.195), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - ROOT %reshape.3349 = f32[512,32]{1,0:T(8,128)S(1)} reshape(%transpose.870), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %reshape.3477 = s32[4,128]{1,0:T(4,128)} reshape(%slice.934), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %transpose.628 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3477), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %gather.195 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.29, %transpose.628), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %transpose.627 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.195), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + ROOT %reshape.3476 = f32[512,32]{1,0:T(8,128)S(1)} reshape(%transpose.627), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} } %fused_computation.10 (param_0.32: bf16[4096,512], param_1.126: s32[4096]) -> bf16[4096,512] { @@ -272,11 +272,11 @@ StackFrames %param_1.126 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.31 = s32[4096]{0:T(1024)} custom-call(%param_1.126), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} %slice.938 = s32[4096]{0:T(1024)} slice(%custom-call.31), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3358 = s32[4096]{0:T(1024)} reshape(%slice.938), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.877 = s32[4096]{0:T(1024)} transpose(%reshape.3358), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.197 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.32, %transpose.877), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.876 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.197), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3357 = bf16[4096,512]{1,0:T(8,128)(2,1)} reshape(%transpose.876), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3485 = s32[4096]{0:T(1024)} reshape(%slice.938), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.634 = s32[4096]{0:T(1024)} transpose(%reshape.3485), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.197 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.32, %transpose.634), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.633 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.197), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3484 = bf16[4096,512]{1,0:T(8,128)(2,1)} reshape(%transpose.633), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %fused_computation.11 (param_0.35: bf16[4096,512], param_1.128: s32[4096]) -> bf16[4096,512] { @@ -284,11 +284,11 @@ StackFrames %param_1.128 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.33 = s32[4096]{0:T(1024)} custom-call(%param_1.128), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} %slice.940 = s32[4096]{0:T(1024)} slice(%custom-call.33), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3366 = s32[4096]{0:T(1024)} reshape(%slice.940), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.883 = s32[4096]{0:T(1024)} transpose(%reshape.3366), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.199 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.35, %transpose.883), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.882 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.199), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3365 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.882), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3493 = s32[4096]{0:T(1024)} reshape(%slice.940), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.640 = s32[4096]{0:T(1024)} transpose(%reshape.3493), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.199 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.35, %transpose.640), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.639 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.199), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3492 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.639), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %fused_computation.12 (param_0.38: bf16[4096,512], param_1.130: s32[4096]) -> bf16[4096,512] { @@ -296,11 +296,11 @@ StackFrames %param_1.130 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.35 = s32[4096]{0:T(1024)} custom-call(%param_1.130), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} %slice.942 = s32[4096]{0:T(1024)} slice(%custom-call.35), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3374 = s32[4096]{0:T(1024)} reshape(%slice.942), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.889 = s32[4096]{0:T(1024)} transpose(%reshape.3374), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.201 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.38, %transpose.889), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.888 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.201), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3373 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.888), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3501 = s32[4096]{0:T(1024)} reshape(%slice.942), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.646 = s32[4096]{0:T(1024)} transpose(%reshape.3501), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.201 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.38, %transpose.646), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.645 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.201), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3500 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.645), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %fused_computation.13 (param_0.41: bf16[4096,512], param_1.132: s32[4096]) -> bf16[4096,512] { @@ -308,11 +308,11 @@ StackFrames %param_1.132 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.37 = s32[4096]{0:T(1024)} custom-call(%param_1.132), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} %slice.944 = s32[4096]{0:T(1024)} slice(%custom-call.37), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3382 = s32[4096]{0:T(1024)} reshape(%slice.944), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.895 = s32[4096]{0:T(1024)} transpose(%reshape.3382), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.203 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.41, %transpose.895), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.894 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.203), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3381 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.894), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3509 = s32[4096]{0:T(1024)} reshape(%slice.944), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.652 = s32[4096]{0:T(1024)} transpose(%reshape.3509), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.203 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.41, %transpose.652), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.651 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.203), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3508 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.651), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %fused_computation.15 (param_0.47: s32[256], param_1.124: s32[1024]) -> s32[263] { @@ -320,11 +320,11 @@ StackFrames %param_1.124 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.29 = s32[1024]{0:T(1024)} custom-call(%param_1.124), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} %slice.936 = s32[263]{0:T(512)} slice(%custom-call.29), slice={[0:263]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - %reshape.3413 = s32[263]{0:T(512)} reshape(%slice.936), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %transpose.911 = s32[263]{0:T(512)} transpose(%reshape.3413), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %gather.208 = s32[263]{0:T(512)} gather(%param_0.47, %transpose.911), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - %transpose.910 = s32[263]{0:T(512)} transpose(%gather.208), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - ROOT %reshape.3412 = s32[263]{0:T(512)S(1)} reshape(%transpose.910), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %reshape.3540 = s32[263]{0:T(512)} reshape(%slice.936), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %transpose.668 = s32[263]{0:T(512)} transpose(%reshape.3540), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %gather.208 = s32[263]{0:T(512)} gather(%param_0.47, %transpose.668), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %transpose.667 = s32[263]{0:T(512)} transpose(%gather.208), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + ROOT %reshape.3539 = s32[263]{0:T(512)S(1)} reshape(%transpose.667), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} } %fused_computation.16 (param_0.50: s32[256], param_1.134: s32[1024]) -> s32[263] { @@ -332,46 +332,46 @@ StackFrames %param_1.134 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.39 = s32[1024]{0:T(1024)} custom-call(%param_1.134), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} %slice.946 = s32[263]{0:T(512)} slice(%custom-call.39), slice={[0:263]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} - %reshape.3436 = s32[263]{0:T(512)} reshape(%slice.946), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %transpose.921 = s32[263]{0:T(512)} transpose(%reshape.3436), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %gather.211 = s32[263]{0:T(512)} gather(%param_0.50, %transpose.921), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} - %transpose.920 = s32[263]{0:T(512)} transpose(%gather.211), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} - ROOT %reshape.3435 = s32[263]{0:T(512)S(1)} reshape(%transpose.920), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} + %reshape.3563 = s32[263]{0:T(512)} reshape(%slice.946), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %transpose.678 = s32[263]{0:T(512)} transpose(%reshape.3563), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %gather.211 = s32[263]{0:T(512)} gather(%param_0.50, %transpose.678), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} + %transpose.677 = s32[263]{0:T(512)} transpose(%gather.211), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} + ROOT %reshape.3562 = s32[263]{0:T(512)S(1)} reshape(%transpose.677), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} } %region_173.198.clone (scatter-add.94: bf16[], scatter-add.96: bf16[]) -> bf16[] { %scatter-add.94 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} %scatter-add.96 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} - ROOT %add.1875 = bf16[]{:T(256)} add(%scatter-add.94, %scatter-add.96), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.1833 = bf16[]{:T(256)} add(%scatter-add.94, %scatter-add.96), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %fused_computation.21 (param_0.55: bf16[129280,512], param_1.65: s32[512], param_2.24: bf16[512,512]) -> bf16[129280,512] { %param_0.55 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) %param_1.65 = s32[512]{0:T(512)S(1)} parameter(1) - %reshape.3490 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.65), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %transpose.954 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3490), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %reshape.3617 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.65), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.711 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3617), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} %param_2.24 = bf16[512,512]{1,0:T(8,128)(2,1)S(1)} parameter(2) - %reshape.3491 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} reshape(%param_2.24), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=0} - %transpose.955 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%reshape.3491), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=0} - ROOT %scatter.77 = bf16[129280,512]{1,0:T(8,128)(2,1)} scatter(%param_0.55, %transpose.954, %transpose.955), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_173.198.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} + %reshape.3618 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} reshape(%param_2.24), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=0} + %transpose.712 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%reshape.3618), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=0} + ROOT %scatter.77 = bf16[129280,512]{1,0:T(8,128)(2,1)} scatter(%param_0.55, %transpose.711, %transpose.712), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_173.198.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} } -%region_12.18 (top_k.0: bf16[], top_k.6: bf16[], top_k.7: s32[], top_k.8: s32[]) -> pred[] { - %constant.1369 = s32[]{:T(128)} constant(0) - %constant.1370 = s32[]{:T(128)} constant(2147483647) +%region_11.17 (top_k.0: bf16[], top_k.6: bf16[], top_k.7: s32[], top_k.8: s32[]) -> pred[] { + %constant.1320 = s32[]{:T(128)} constant(0) + %constant.1321 = s32[]{:T(128)} constant(2147483647) %top_k.0 = bf16[]{:T(256)} parameter(0), metadata={op_name="top_k"} %top_k.6 = bf16[]{:T(256)} parameter(1), metadata={op_name="top_k"} %top_k.7 = s32[]{:T(128)} parameter(2), metadata={op_name="top_k"} %top_k.8 = s32[]{:T(128)} parameter(3), metadata={op_name="top_k"} - %convert.385 = f32[]{:T(128)S(6)} convert(%top_k.0), metadata={op_name="convert.16"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %bitcast-convert.35 = s32[]{:T(128)S(6)} bitcast-convert(%convert.385), metadata={op_name="bitcast-convert.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %compare.128 = pred[]{:T(512)S(6)} compare(%bitcast-convert.35, %constant.1369), direction=LT, metadata={op_name="compare.35"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %xor.36 = s32[]{:T(128)S(6)} xor(%constant.1370, %bitcast-convert.35), metadata={op_name="xor.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %convert.263 = f32[]{:T(128)S(6)} convert(%top_k.0), metadata={op_name="convert.16"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.35 = s32[]{:T(128)S(6)} bitcast-convert(%convert.263), metadata={op_name="bitcast-convert.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.128 = pred[]{:T(512)S(6)} compare(%bitcast-convert.35, %constant.1320), direction=LT, metadata={op_name="compare.35"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.36 = s32[]{:T(128)S(6)} xor(%constant.1321, %bitcast-convert.35), metadata={op_name="xor.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %select.118 = s32[]{:T(128)S(6)} select(%compare.128, %xor.36, %bitcast-convert.35), metadata={op_name="select.14"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} - %convert.386 = f32[]{:T(128)S(6)} convert(%top_k.6), metadata={op_name="convert.17"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %bitcast-convert.36 = s32[]{:T(128)S(6)} bitcast-convert(%convert.386), metadata={op_name="bitcast-convert.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %compare.129 = pred[]{:T(512)S(6)} compare(%bitcast-convert.36, %constant.1369), direction=LT, metadata={op_name="compare.36"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %xor.37 = s32[]{:T(128)S(6)} xor(%constant.1370, %bitcast-convert.36), metadata={op_name="xor.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %convert.264 = f32[]{:T(128)S(6)} convert(%top_k.6), metadata={op_name="convert.17"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.36 = s32[]{:T(128)S(6)} bitcast-convert(%convert.264), metadata={op_name="bitcast-convert.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.129 = pred[]{:T(512)S(6)} compare(%bitcast-convert.36, %constant.1320), direction=LT, metadata={op_name="compare.36"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.37 = s32[]{:T(128)S(6)} xor(%constant.1321, %bitcast-convert.36), metadata={op_name="xor.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %select.119 = s32[]{:T(128)S(6)} select(%compare.129, %xor.37, %bitcast-convert.36), metadata={op_name="select.15"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} %compare.130 = pred[]{:T(512)S(6)} compare(%select.118, %select.119), direction=GT, metadata={op_name="compare.1"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %compare.131 = pred[]{:T(512)S(6)} compare(%select.119, %select.118), direction=GT, metadata={op_name="compare.108"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} @@ -380,37 +380,37 @@ StackFrames ROOT %select.120 = pred[]{:T(512)} select(%compare.132, %compare.133, %compare.130), metadata={op_name="select.108"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_15.21.clone.1 (reduce-window.326: s32[], reduce-window.327: s32[]) -> s32[] { +%region_14.20.clone.1 (reduce-window.326: s32[], reduce-window.327: s32[]) -> s32[] { %reduce-window.326 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.20"} %reduce-window.327 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.20"} ROOT %reduce_window_sum.282 = s32[]{:T(128)} add(%reduce-window.326, %reduce-window.327), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_16.22.clone.1 (reduce-window.330: s32[], reduce-window.331: s32[]) -> s32[] { +%region_15.21.clone.1 (reduce-window.330: s32[], reduce-window.331: s32[]) -> s32[] { %reduce-window.330 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.56"} %reduce-window.331 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.56"} ROOT %reduce_window_sum.284 = s32[]{:T(128)} add(%reduce-window.330, %reduce-window.331), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_18.24.clone.1 (reduce-window.334: s32[], reduce-window.335: s32[]) -> s32[] { +%region_17.23.clone.1 (reduce-window.334: s32[], reduce-window.335: s32[]) -> s32[] { %reduce-window.334 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.22"} %reduce-window.335 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.22"} ROOT %reduce_window_sum.286 = s32[]{:T(128)} add(%reduce-window.334, %reduce-window.335), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_19.25.clone.1 (reduce-window.338: s32[], reduce-window.339: s32[]) -> s32[] { +%region_18.24.clone.1 (reduce-window.338: s32[], reduce-window.339: s32[]) -> s32[] { %reduce-window.338 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.57"} %reduce-window.339 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.57"} ROOT %reduce_window_sum.288 = s32[]{:T(128)} add(%reduce-window.338, %reduce-window.339), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_21.27.clone.1 (reduce-window.342: s32[], reduce-window.343: s32[]) -> s32[] { +%region_20.26.clone.1 (reduce-window.342: s32[], reduce-window.343: s32[]) -> s32[] { %reduce-window.342 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.24"} %reduce-window.343 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.24"} ROOT %reduce_window_sum.290 = s32[]{:T(128)} add(%reduce-window.342, %reduce-window.343), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_22.28.clone.1 (reduce-window.346: s32[], reduce-window.347: s32[]) -> s32[] { +%region_21.27.clone.1 (reduce-window.346: s32[], reduce-window.347: s32[]) -> s32[] { %reduce-window.346 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.58"} %reduce-window.347 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.58"} ROOT %reduce_window_sum.292 = s32[]{:T(128)} add(%reduce-window.346, %reduce-window.347), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} @@ -421,32 +421,32 @@ StackFrames %param_1.114 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.19 = s32[1024]{0:T(1024)} custom-call(%param_1.114), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} %slice.926 = s32[263]{0:T(512)} slice(%custom-call.19), slice={[0:263]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - %reshape.3634 = s32[263]{0:T(512)} reshape(%slice.926), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %transpose.1037 = s32[263]{0:T(512)} transpose(%reshape.3634), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %gather.213 = s32[263]{0:T(512)} gather(%param_0.68, %transpose.1037), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - %transpose.1036 = s32[263]{0:T(512)} transpose(%gather.213), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - ROOT %reshape.3633 = s32[263]{0:T(512)S(1)} reshape(%transpose.1036), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %reshape.3761 = s32[263]{0:T(512)} reshape(%slice.926), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %transpose.794 = s32[263]{0:T(512)} transpose(%reshape.3761), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %gather.213 = s32[263]{0:T(512)} gather(%param_0.68, %transpose.794), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %transpose.793 = s32[263]{0:T(512)} transpose(%gather.213), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + ROOT %reshape.3760 = s32[263]{0:T(512)S(1)} reshape(%transpose.793), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} } -%region_27.34.clone.1 (reduce-window.350: s32[], reduce-window.351: s32[]) -> s32[] { +%region_26.33.clone.1 (reduce-window.350: s32[], reduce-window.351: s32[]) -> s32[] { %reduce-window.350 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.26"} %reduce-window.351 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.26"} ROOT %reduce_window_sum.294 = s32[]{:T(128)} add(%reduce-window.350, %reduce-window.351), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_29.36.clone.1 (reduce-window.354: s32[], reduce-window.355: s32[]) -> s32[] { +%region_28.35.clone.1 (reduce-window.354: s32[], reduce-window.355: s32[]) -> s32[] { %reduce-window.354 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.27"} %reduce-window.355 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.27"} ROOT %reduce_window_sum.296 = s32[]{:T(128)} add(%reduce-window.354, %reduce-window.355), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_30.37.clone.1 (reduce-window.358: s32[], reduce-window.359: s32[]) -> s32[] { +%region_29.36.clone.1 (reduce-window.358: s32[], reduce-window.359: s32[]) -> s32[] { %reduce-window.358 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.59"} %reduce-window.359 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.59"} ROOT %reduce_window_sum.298 = s32[]{:T(128)} add(%reduce-window.358, %reduce-window.359), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_13.19 (sort.44: s32[], sort.45: s32[], sort.46: s32[], sort.47: s32[], sort.48: s32[], sort.49: s32[]) -> pred[] { +%region_12.18 (sort.44: s32[], sort.45: s32[], sort.46: s32[], sort.47: s32[], sort.48: s32[], sort.49: s32[]) -> pred[] { %sort.46 = s32[]{:T(128)} parameter(2), metadata={op_name="jit(argsort)/sort"} %sort.47 = s32[]{:T(128)} parameter(3), metadata={op_name="jit(argsort)/sort"} %sort.44 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(argsort)/sort"} @@ -465,14 +465,14 @@ StackFrames %param_1.116 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.21 = s32[4096]{0:T(1024)} custom-call(%param_1.116), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} %slice.928 = s32[4096]{0:T(1024)} slice(%custom-call.21), slice={[0:4096]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3657 = s32[4096]{0:T(1024)} reshape(%slice.928), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.1043 = s32[4096]{0:T(1024)} transpose(%reshape.3657), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.214 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.71, %transpose.1043), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.1042 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.214), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3656 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.1042), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3784 = s32[4096]{0:T(1024)} reshape(%slice.928), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.800 = s32[4096]{0:T(1024)} transpose(%reshape.3784), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.214 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.71, %transpose.800), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.799 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.214), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3783 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.799), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } -%region_31.39 (sort.50: s32[], sort.51: s32[], sort.52: s32[], sort.53: s32[], sort.54: s32[], sort.55: s32[]) -> pred[] { +%region_30.38 (sort.50: s32[], sort.51: s32[], sort.52: s32[], sort.53: s32[], sort.54: s32[], sort.55: s32[]) -> pred[] { %sort.52 = s32[]{:T(128)} parameter(2), metadata={op_name="jit(argsort)/sort"} %sort.53 = s32[]{:T(128)} parameter(3), metadata={op_name="jit(argsort)/sort"} %sort.50 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(argsort)/sort"} @@ -491,11 +491,11 @@ StackFrames %param_1.118 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.23 = s32[4096]{0:T(1024)} custom-call(%param_1.118), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} %slice.930 = s32[4096]{0:T(1024)} slice(%custom-call.23), slice={[0:4096]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3659 = s32[4096]{0:T(1024)} reshape(%slice.930), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.1045 = s32[4096]{0:T(1024)} transpose(%reshape.3659), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.215 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.72, %transpose.1045), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.1044 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.215), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3658 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.1044), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3786 = s32[4096]{0:T(1024)} reshape(%slice.930), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.802 = s32[4096]{0:T(1024)} transpose(%reshape.3786), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.215 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.72, %transpose.802), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.801 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.215), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3785 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.801), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %compare (name: s32[], name.1: s32[], name.2: bf16[], name.3: bf16[]) -> pred[] { @@ -538,459 +538,459 @@ StackFrames ROOT %compare.389 = pred[] compare(%name.16, %name.17), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%called_computation.13 (param_0.4523: s32[256]) -> s32[256] { - %param_0.4523 = s32[256]{0:T(256)} parameter(0) - ROOT %copy.2073 = s32[256]{0:T(256)} copy(%param_0.4523), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"1134","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.13 (param_0.4539: s32[256]) -> s32[256] { + %param_0.4539 = s32[256]{0:T(256)} parameter(0) + ROOT %copy.2071 = s32[256]{0:T(256)} copy(%param_0.4539), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"1134","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.13 (param_0.4524: s32[256]) -> s32[256] { - %param_0.4524 = s32[256]{0:T(256)} parameter(0) - ROOT %copy.2074.cloned.1 = s32[256]{0:T(256)} call(%param_0.4524), to_apply=%called_computation.13 +%async_computation.13 (param_0.4540: s32[256]) -> s32[256] { + %param_0.4540 = s32[256]{0:T(256)} parameter(0) + ROOT %copy.2072.cloned.1 = s32[256]{0:T(256)} call(%param_0.4540), to_apply=%called_computation.13 }, execution_thread="sparsecore" %region_49.59 (scatter-add.14: s32[], scatter-add.15: s32[]) -> s32[] { %scatter-add.14 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.15 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1352 = s32[]{:T(128)S(7)} add(%scatter-add.14, %scatter-add.15), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1318 = s32[]{:T(128)S(7)} add(%scatter-add.14, %scatter-add.15), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.22.clone.clone (param_0.4525: s32[256], param_1.5325: s32[4096], param_2.4494: s32[4096]) -> s32[256] { - %param_0.4525 = s32[256]{0:T(256)} parameter(0) - %param_1.5325 = s32[4096]{0:T(1024)} parameter(1) - %reshape.3923 = s32[4096]{0:T(1024)} reshape(%param_1.5325), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(clip)/max" stack_frame_id=0} - %transpose.1100 = s32[4096]{0:T(1024)} transpose(%reshape.3923), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(clip)/max" stack_frame_id=0} - %param_2.4494 = s32[4096]{0:T(1024)} parameter(2) - %reshape.3924 = s32[4096]{0:T(1024)} reshape(%param_2.4494), metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} - %transpose.1101 = s32[4096]{0:T(1024)} transpose(%reshape.3924), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.237 = s32[256]{0:T(256)} scatter(%param_0.4525, %transpose.1100, %transpose.1101), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_49.59, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} +%fused_computation.22.clone.clone (param_0.4541: s32[256], param_1.5324: s32[4096], param_2.4484: s32[4096]) -> s32[256] { + %param_0.4541 = s32[256]{0:T(256)} parameter(0) + %param_1.5324 = s32[4096]{0:T(1024)} parameter(1) + %reshape.4051 = s32[4096]{0:T(1024)} reshape(%param_1.5324), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(clip)/max" stack_frame_id=0} + %transpose.857 = s32[4096]{0:T(1024)} transpose(%reshape.4051), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(clip)/max" stack_frame_id=0} + %param_2.4484 = s32[4096]{0:T(1024)} parameter(2) + %reshape.4052 = s32[4096]{0:T(1024)} reshape(%param_2.4484), metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + %transpose.858 = s32[4096]{0:T(1024)} transpose(%reshape.4052), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.237 = s32[256]{0:T(256)} scatter(%param_0.4541, %transpose.857, %transpose.858), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_49.59, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.14 (param_0.4526: s32[256], param_1.5326: s32[4096], param_2.4495: s32[4096]) -> s32[256] { - %param_0.4526 = s32[256]{0:T(256)} parameter(0) - %param_1.5326 = s32[4096]{0:T(1024)} parameter(1) - %param_2.4495 = s32[4096]{0:T(1024)} parameter(2) - ROOT %scatter_offload_custom_fusion.39 = s32[256]{0:T(256)} fusion(%param_0.4526, %param_1.5326, %param_2.4495), kind=kCustom, calls=%fused_computation.22.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["256"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"4160","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.14 (param_0.4542: s32[256], param_1.5325: s32[4096], param_2.4485: s32[4096]) -> s32[256] { + %param_0.4542 = s32[256]{0:T(256)} parameter(0) + %param_1.5325 = s32[4096]{0:T(1024)} parameter(1) + %param_2.4485 = s32[4096]{0:T(1024)} parameter(2) + ROOT %scatter_offload_custom_fusion.39 = s32[256]{0:T(256)} fusion(%param_0.4542, %param_1.5325, %param_2.4485), kind=kCustom, calls=%fused_computation.22.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["256"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"4160","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.14 (param_0.4527: s32[256], param_1.5327: s32[4096], param_2.4496: s32[4096]) -> s32[256] { - %param_0.4527 = s32[256]{0:T(256)} parameter(0) - %param_1.5327 = s32[4096]{0:T(1024)} parameter(1) - %param_2.4496 = s32[4096]{0:T(1024)} parameter(2) - ROOT %scatter_offload_custom_fusion.40.cloned.1 = s32[256]{0:T(256)} call(%param_0.4527, %param_1.5327, %param_2.4496), to_apply=%called_computation.14, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} +%async_computation.14 (param_0.4543: s32[256], param_1.5326: s32[4096], param_2.4486: s32[4096]) -> s32[256] { + %param_0.4543 = s32[256]{0:T(256)} parameter(0) + %param_1.5326 = s32[4096]{0:T(1024)} parameter(1) + %param_2.4486 = s32[4096]{0:T(1024)} parameter(2) + ROOT %scatter_offload_custom_fusion.40.cloned.1 = s32[256]{0:T(256)} call(%param_0.4543, %param_1.5326, %param_2.4486), to_apply=%called_computation.14, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation (param_0.84: s32[256], param_1.136: s32[4096], param_2.80: s32[4096], param_3.3085: token[]) -> s32[256] { - %param_3.3085 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation (param_0.84: s32[256], param_1.136: s32[4096], param_2.80: s32[4096], param_3.3090: token[]) -> s32[256] { + %param_3.3090 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.84 = s32[256]{0:T(256)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.136 = s32[4096]{0:T(1024)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.80 = s32[4096]{0:T(1024)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} - %copy.2074.cloned.1.call-start = ((s32[256]{0:T(256)}), s32[256]{0:T(256)}, u32[]{:S(8)}) async-start(%param_0.84), async_execution_thread="sparsecore", calls=%async_computation.13 - %copy.2074.cloned.1.call-done = s32[256]{0:T(256)} async-done(%copy.2074.cloned.1.call-start) - %scatter_offload_custom_fusion.40.cloned.1.call-start = ((s32[256]{0:T(256)}, s32[4096]{0:T(1024)}, s32[4096]{0:T(1024)}), s32[256]{0:T(256)}, u32[]{:S(8)}) async-start(%copy.2074.cloned.1.call-done, %param_1.136, %param_2.80), async_execution_thread="sparsecore", calls=%async_computation.14, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} + %copy.2072.cloned.1.call-start = ((s32[256]{0:T(256)}), s32[256]{0:T(256)}, u32[]{:S(8)}) async-start(%param_0.84), async_execution_thread="sparsecore", calls=%async_computation.13 + %copy.2072.cloned.1.call-done = s32[256]{0:T(256)} async-done(%copy.2072.cloned.1.call-start) + %scatter_offload_custom_fusion.40.cloned.1.call-start = ((s32[256]{0:T(256)}, s32[4096]{0:T(1024)}, s32[4096]{0:T(1024)}), s32[256]{0:T(256)}, u32[]{:S(8)}) async-start(%copy.2072.cloned.1.call-done, %param_1.136, %param_2.80), async_execution_thread="sparsecore", calls=%async_computation.14, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} ROOT %scatter_offload_custom_fusion.40.cloned.1.call-done = s32[256]{0:T(256)} async-done(%scatter_offload_custom_fusion.40.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation (param_0.85: s32[256], param_1.137: s32[4096], param_2.81: s32[4096], param_3.3084: token[]) -> s32[256] { - %param_3.3084 = token[] parameter(3) +%async_computation (param_0.85: s32[256], param_1.137: s32[4096], param_2.81: s32[4096], param_3.3089: token[]) -> s32[256] { + %param_3.3089 = token[] parameter(3) %param_0.85 = s32[256]{0:T(256)} parameter(0) %param_1.137 = s32[4096]{0:T(1024)} parameter(1) %param_2.81 = s32[4096]{0:T(1024)} parameter(2) - ROOT %scatter_offload_custom_fusion.2.cloned.1 = s32[256]{0:T(256)} call(%param_0.85, %param_1.137, %param_2.81, %param_3.3084), to_apply=%called_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.2.cloned.1 = s32[256]{0:T(256)} call(%param_0.85, %param_1.137, %param_2.81, %param_3.3089), to_apply=%called_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.15 (param_0.4528: f32[9]) -> f32[9] { - %param_0.4528 = f32[9]{0:T(128)} parameter(0) - ROOT %copy.2075 = f32[9]{0:T(128)} copy(%param_0.4528), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.15 (param_0.4544: f32[9]) -> f32[9] { + %param_0.4544 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2073 = f32[9]{0:T(128)} copy(%param_0.4544), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.15 (param_0.4529: f32[9]) -> f32[9] { - %param_0.4529 = f32[9]{0:T(128)} parameter(0) - ROOT %copy.2076.cloned.1 = f32[9]{0:T(128)} call(%param_0.4529), to_apply=%called_computation.15 +%async_computation.15 (param_0.4545: f32[9]) -> f32[9] { + %param_0.4545 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2074.cloned.1 = f32[9]{0:T(128)} call(%param_0.4545), to_apply=%called_computation.15 }, execution_thread="sparsecore" %region_61.72 (scatter-add.24: f32[], scatter-add.25: f32[]) -> f32[] { %scatter-add.24 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.25 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1358 = f32[]{:T(128)S(7)} add(%scatter-add.24, %scatter-add.25), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1324 = f32[]{:T(128)S(7)} add(%scatter-add.24, %scatter-add.25), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.24.clone.clone (param_0.4530: f32[9], param_1.5328: s32[256], param_2.4497: f32[256]) -> f32[9] { - %param_0.4530 = f32[9]{0:T(128)} parameter(0) - %param_1.5328 = s32[256]{0:T(256)} parameter(1) - %reshape.3925 = s32[256]{0:T(256)} reshape(%param_1.5328), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1102 = s32[256]{0:T(256)} transpose(%reshape.3925), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %param_2.4497 = f32[256]{0:T(256)} parameter(2) - %reshape.3926 = f32[256]{0:T(256)} reshape(%param_2.4497), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1103 = f32[256]{0:T(256)} transpose(%reshape.3926), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.238 = f32[9]{0:T(128)} scatter(%param_0.4530, %transpose.1102, %transpose.1103), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_61.72, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.24.clone.clone (param_0.4546: f32[9], param_1.5327: s32[256], param_2.4487: f32[256]) -> f32[9] { + %param_0.4546 = f32[9]{0:T(128)} parameter(0) + %param_1.5327 = s32[256]{0:T(256)} parameter(1) + %reshape.4053 = s32[256]{0:T(256)} reshape(%param_1.5327), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.859 = s32[256]{0:T(256)} transpose(%reshape.4053), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %param_2.4487 = f32[256]{0:T(256)} parameter(2) + %reshape.4054 = f32[256]{0:T(256)} reshape(%param_2.4487), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.860 = f32[256]{0:T(256)} transpose(%reshape.4054), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.238 = f32[9]{0:T(128)} scatter(%param_0.4546, %transpose.859, %transpose.860), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_61.72, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.16 (param_0.4531: f32[9], param_1.5329: s32[256], param_2.4498: f32[256]) -> f32[9] { - %param_0.4531 = f32[9]{0:T(128)} parameter(0) - %param_1.5329 = s32[256]{0:T(256)} parameter(1) - %param_2.4498 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.41 = f32[9]{0:T(128)} fusion(%param_0.4531, %param_1.5329, %param_2.4498), kind=kCustom, calls=%fused_computation.24.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.16 (param_0.4547: f32[9], param_1.5328: s32[256], param_2.4488: f32[256]) -> f32[9] { + %param_0.4547 = f32[9]{0:T(128)} parameter(0) + %param_1.5328 = s32[256]{0:T(256)} parameter(1) + %param_2.4488 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.41 = f32[9]{0:T(128)} fusion(%param_0.4547, %param_1.5328, %param_2.4488), kind=kCustom, calls=%fused_computation.24.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.16 (param_0.4532: f32[9], param_1.5330: s32[256], param_2.4499: f32[256]) -> f32[9] { - %param_0.4532 = f32[9]{0:T(128)} parameter(0) - %param_1.5330 = s32[256]{0:T(256)} parameter(1) - %param_2.4499 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.42.cloned.1 = f32[9]{0:T(128)} call(%param_0.4532, %param_1.5330, %param_2.4499), to_apply=%called_computation.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.16 (param_0.4548: f32[9], param_1.5329: s32[256], param_2.4489: f32[256]) -> f32[9] { + %param_0.4548 = f32[9]{0:T(128)} parameter(0) + %param_1.5329 = s32[256]{0:T(256)} parameter(1) + %param_2.4489 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.42.cloned.1 = f32[9]{0:T(128)} call(%param_0.4548, %param_1.5329, %param_2.4489), to_apply=%called_computation.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.1 (param_0.87: f32[9], param_1.139: s32[256], param_2.83: f32[256], param_3.3099: token[]) -> f32[9] { - %param_3.3099 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.1 (param_0.87: f32[9], param_1.139: s32[256], param_2.83: f32[256], param_3.3104: token[]) -> f32[9] { + %param_3.3104 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.87 = f32[9]{0:T(128)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.139 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.83 = f32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} - %copy.2076.cloned.1.call-start = ((f32[9]{0:T(128)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%param_0.87), async_execution_thread="sparsecore", calls=%async_computation.15 - %copy.2076.cloned.1.call-done = f32[9]{0:T(128)} async-done(%copy.2076.cloned.1.call-start) - %scatter_offload_custom_fusion.42.cloned.1.call-start = ((f32[9]{0:T(128)}, s32[256]{0:T(256)}, f32[256]{0:T(256)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%copy.2076.cloned.1.call-done, %param_1.139, %param_2.83), async_execution_thread="sparsecore", calls=%async_computation.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %copy.2074.cloned.1.call-start = ((f32[9]{0:T(128)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%param_0.87), async_execution_thread="sparsecore", calls=%async_computation.15 + %copy.2074.cloned.1.call-done = f32[9]{0:T(128)} async-done(%copy.2074.cloned.1.call-start) + %scatter_offload_custom_fusion.42.cloned.1.call-start = ((f32[9]{0:T(128)}, s32[256]{0:T(256)}, f32[256]{0:T(256)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%copy.2074.cloned.1.call-done, %param_1.139, %param_2.83), async_execution_thread="sparsecore", calls=%async_computation.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} ROOT %scatter_offload_custom_fusion.42.cloned.1.call-done = f32[9]{0:T(128)} async-done(%scatter_offload_custom_fusion.42.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.1 (param_0.88: f32[9], param_1.140: s32[256], param_2.84: f32[256], param_3.3098: token[]) -> f32[9] { - %param_3.3098 = token[] parameter(3) +%async_computation.1 (param_0.88: f32[9], param_1.140: s32[256], param_2.84: f32[256], param_3.3103: token[]) -> f32[9] { + %param_3.3103 = token[] parameter(3) %param_0.88 = f32[9]{0:T(128)} parameter(0) %param_1.140 = s32[256]{0:T(256)} parameter(1) %param_2.84 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.5.cloned.1 = f32[9]{0:T(128)} call(%param_0.88, %param_1.140, %param_2.84, %param_3.3098), to_apply=%called_computation.1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.5.cloned.1 = f32[9]{0:T(128)} call(%param_0.88, %param_1.140, %param_2.84, %param_3.3103), to_apply=%called_computation.1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.17 (param_0.4533: s32[263]) -> s32[263] { - %param_0.4533 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2077 = s32[263]{0:T(512)} copy(%param_0.4533), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.17 (param_0.4549: s32[263]) -> s32[263] { + %param_0.4549 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2075 = s32[263]{0:T(512)} copy(%param_0.4549), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.17 (param_0.4534: s32[263]) -> s32[263] { - %param_0.4534 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2078.cloned.1 = s32[263]{0:T(512)} call(%param_0.4534), to_apply=%called_computation.17 +%async_computation.17 (param_0.4550: s32[263]) -> s32[263] { + %param_0.4550 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2076.cloned.1 = s32[263]{0:T(512)} call(%param_0.4550), to_apply=%called_computation.17 }, execution_thread="sparsecore" %region_63.74 (scatter-add.28: s32[], scatter-add.29: s32[]) -> s32[] { %scatter-add.28 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.29 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1359 = s32[]{:T(128)S(7)} add(%scatter-add.28, %scatter-add.29), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1325 = s32[]{:T(128)S(7)} add(%scatter-add.28, %scatter-add.29), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.25.clone.clone (param_0.4535: s32[263], param_1.5331: s32[8], param_2.4500: s32[8]) -> s32[263] { - %param_0.4535 = s32[263]{0:T(512)} parameter(0) - %param_1.5331 = s32[8]{0:T(128)} parameter(1) - %reshape.3927 = s32[8]{0:T(128)} reshape(%param_1.5331), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1104 = s32[8]{0:T(128)} transpose(%reshape.3927), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %param_2.4500 = s32[8]{0:T(128)} parameter(2) - %reshape.3928 = s32[8]{0:T(128)} reshape(%param_2.4500), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - %transpose.1105 = s32[8]{0:T(128)} transpose(%reshape.3928), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - ROOT %scatter-add.239 = s32[263]{0:T(512)} scatter(%param_0.4535, %transpose.1104, %transpose.1105), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_63.74, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.25.clone.clone (param_0.4551: s32[263], param_1.5330: s32[8], param_2.4490: s32[8]) -> s32[263] { + %param_0.4551 = s32[263]{0:T(512)} parameter(0) + %param_1.5330 = s32[8]{0:T(128)} parameter(1) + %reshape.4055 = s32[8]{0:T(128)} reshape(%param_1.5330), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.861 = s32[8]{0:T(128)} transpose(%reshape.4055), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4490 = s32[8]{0:T(128)} parameter(2) + %reshape.4056 = s32[8]{0:T(128)} reshape(%param_2.4490), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + %transpose.862 = s32[8]{0:T(128)} transpose(%reshape.4056), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + ROOT %scatter-add.239 = s32[263]{0:T(512)} scatter(%param_0.4551, %transpose.861, %transpose.862), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_63.74, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.18 (param_0.4536: s32[263], param_1.5332: s32[8], param_2.4501: s32[8]) -> s32[263] { - %param_0.4536 = s32[263]{0:T(512)} parameter(0) - %param_1.5332 = s32[8]{0:T(128)} parameter(1) - %param_2.4501 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.43 = s32[263]{0:T(512)} fusion(%param_0.4536, %param_1.5332, %param_2.4501), kind=kCustom, calls=%fused_computation.25.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.18 (param_0.4552: s32[263], param_1.5331: s32[8], param_2.4491: s32[8]) -> s32[263] { + %param_0.4552 = s32[263]{0:T(512)} parameter(0) + %param_1.5331 = s32[8]{0:T(128)} parameter(1) + %param_2.4491 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.43 = s32[263]{0:T(512)} fusion(%param_0.4552, %param_1.5331, %param_2.4491), kind=kCustom, calls=%fused_computation.25.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.18 (param_0.4537: s32[263], param_1.5333: s32[8], param_2.4502: s32[8]) -> s32[263] { - %param_0.4537 = s32[263]{0:T(512)} parameter(0) - %param_1.5333 = s32[8]{0:T(128)} parameter(1) - %param_2.4502 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.44.cloned.1 = s32[263]{0:T(512)} call(%param_0.4537, %param_1.5333, %param_2.4502), to_apply=%called_computation.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.18 (param_0.4553: s32[263], param_1.5332: s32[8], param_2.4492: s32[8]) -> s32[263] { + %param_0.4553 = s32[263]{0:T(512)} parameter(0) + %param_1.5332 = s32[8]{0:T(128)} parameter(1) + %param_2.4492 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.44.cloned.1 = s32[263]{0:T(512)} call(%param_0.4553, %param_1.5332, %param_2.4492), to_apply=%called_computation.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.2 (param_0.90: s32[263], param_1.142: s32[8], param_2.86: s32[8], param_3.3105: token[]) -> s32[263] { - %param_3.3105 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.2 (param_0.90: s32[263], param_1.142: s32[8], param_2.86: s32[8], param_3.3110: token[]) -> s32[263] { + %param_3.3110 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.90 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.142 = s32[8]{0:T(128)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.86 = s32[8]{0:T(128)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} - %copy.2078.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.90), async_execution_thread="sparsecore", calls=%async_computation.17 - %copy.2078.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2078.cloned.1.call-start) - %scatter_offload_custom_fusion.44.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[8]{0:T(128)}, s32[8]{0:T(128)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2078.cloned.1.call-done, %param_1.142, %param_2.86), async_execution_thread="sparsecore", calls=%async_computation.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %copy.2076.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.90), async_execution_thread="sparsecore", calls=%async_computation.17 + %copy.2076.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2076.cloned.1.call-start) + %scatter_offload_custom_fusion.44.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[8]{0:T(128)}, s32[8]{0:T(128)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2076.cloned.1.call-done, %param_1.142, %param_2.86), async_execution_thread="sparsecore", calls=%async_computation.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} ROOT %scatter_offload_custom_fusion.44.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.44.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.2 (param_0.91: s32[263], param_1.143: s32[8], param_2.87: s32[8], param_3.3104: token[]) -> s32[263] { - %param_3.3104 = token[] parameter(3) +%async_computation.2 (param_0.91: s32[263], param_1.143: s32[8], param_2.87: s32[8], param_3.3109: token[]) -> s32[263] { + %param_3.3109 = token[] parameter(3) %param_0.91 = s32[263]{0:T(512)} parameter(0) %param_1.143 = s32[8]{0:T(128)} parameter(1) %param_2.87 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.8.cloned.1 = s32[263]{0:T(512)} call(%param_0.91, %param_1.143, %param_2.87, %param_3.3104), to_apply=%called_computation.2, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.8.cloned.1 = s32[263]{0:T(512)} call(%param_0.91, %param_1.143, %param_2.87, %param_3.3109), to_apply=%called_computation.2, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.19 (param_0.4538: s32[263]) -> s32[263] { - %param_0.4538 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2079 = s32[263]{0:T(512)} copy(%param_0.4538), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.19 (param_0.4554: s32[263]) -> s32[263] { + %param_0.4554 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2077 = s32[263]{0:T(512)} copy(%param_0.4554), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.19 (param_0.4539: s32[263]) -> s32[263] { - %param_0.4539 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2080.cloned.1 = s32[263]{0:T(512)} call(%param_0.4539), to_apply=%called_computation.19 +%async_computation.19 (param_0.4555: s32[263]) -> s32[263] { + %param_0.4555 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2078.cloned.1 = s32[263]{0:T(512)} call(%param_0.4555), to_apply=%called_computation.19 }, execution_thread="sparsecore" %region_73.86.clone (scatter-add.163: s32[], scatter-add.164: s32[]) -> s32[] { %scatter-add.163 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.164 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2474 = s32[]{:T(128)S(7)} add(%scatter-add.163, %scatter-add.164), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2432 = s32[]{:T(128)S(7)} add(%scatter-add.163, %scatter-add.164), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.26.clone.clone (param_0.4540: s32[263], param_1.5334: s32[256], param_2.4503: s32[256]) -> s32[263] { - %param_0.4540 = s32[263]{0:T(512)} parameter(0) - %param_1.5334 = s32[256]{0:T(256)} parameter(1) - %reshape.3929 = s32[256]{0:T(256)} reshape(%param_1.5334), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1106 = s32[256]{0:T(256)} transpose(%reshape.3929), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %param_2.4503 = s32[256]{0:T(256)} parameter(2) - %reshape.3930 = s32[256]{0:T(256)} reshape(%param_2.4503), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1107 = s32[256]{0:T(256)} transpose(%reshape.3930), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.240 = s32[263]{0:T(512)} scatter(%param_0.4540, %transpose.1106, %transpose.1107), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_73.86.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.26.clone.clone (param_0.4556: s32[263], param_1.5333: s32[256], param_2.4493: s32[256]) -> s32[263] { + %param_0.4556 = s32[263]{0:T(512)} parameter(0) + %param_1.5333 = s32[256]{0:T(256)} parameter(1) + %reshape.4057 = s32[256]{0:T(256)} reshape(%param_1.5333), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.863 = s32[256]{0:T(256)} transpose(%reshape.4057), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4493 = s32[256]{0:T(256)} parameter(2) + %reshape.4058 = s32[256]{0:T(256)} reshape(%param_2.4493), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.864 = s32[256]{0:T(256)} transpose(%reshape.4058), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.240 = s32[263]{0:T(512)} scatter(%param_0.4556, %transpose.863, %transpose.864), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_73.86.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.20 (param_0.4541: s32[263], param_1.5335: s32[256], param_2.4504: s32[256]) -> s32[263] { - %param_0.4541 = s32[263]{0:T(512)} parameter(0) - %param_1.5335 = s32[256]{0:T(256)} parameter(1) - %param_2.4504 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.45 = s32[263]{0:T(512)} fusion(%param_0.4541, %param_1.5335, %param_2.4504), kind=kCustom, calls=%fused_computation.26.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.20 (param_0.4557: s32[263], param_1.5334: s32[256], param_2.4494: s32[256]) -> s32[263] { + %param_0.4557 = s32[263]{0:T(512)} parameter(0) + %param_1.5334 = s32[256]{0:T(256)} parameter(1) + %param_2.4494 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.45 = s32[263]{0:T(512)} fusion(%param_0.4557, %param_1.5334, %param_2.4494), kind=kCustom, calls=%fused_computation.26.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.20 (param_0.4542: s32[263], param_1.5336: s32[256], param_2.4505: s32[256]) -> s32[263] { - %param_0.4542 = s32[263]{0:T(512)} parameter(0) - %param_1.5336 = s32[256]{0:T(256)} parameter(1) - %param_2.4505 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.46.cloned.1 = s32[263]{0:T(512)} call(%param_0.4542, %param_1.5336, %param_2.4505), to_apply=%called_computation.20, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.20 (param_0.4558: s32[263], param_1.5335: s32[256], param_2.4495: s32[256]) -> s32[263] { + %param_0.4558 = s32[263]{0:T(512)} parameter(0) + %param_1.5335 = s32[256]{0:T(256)} parameter(1) + %param_2.4495 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.46.cloned.1 = s32[263]{0:T(512)} call(%param_0.4558, %param_1.5335, %param_2.4495), to_apply=%called_computation.20, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.3 (param_0.93: s32[263], param_1.145: s32[256], param_2.89: s32[256], param_3.3091: token[]) -> s32[263] { - %param_3.3091 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.3 (param_0.93: s32[263], param_1.145: s32[256], param_2.89: s32[256], param_3.3096: token[]) -> s32[263] { + %param_3.3096 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.93 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.145 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.89 = s32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} - %copy.2080.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.93), async_execution_thread="sparsecore", calls=%async_computation.19 - %copy.2080.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2080.cloned.1.call-start) - %scatter_offload_custom_fusion.46.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[256]{0:T(256)}, s32[256]{0:T(256)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2080.cloned.1.call-done, %param_1.145, %param_2.89), async_execution_thread="sparsecore", calls=%async_computation.20, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %copy.2078.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.93), async_execution_thread="sparsecore", calls=%async_computation.19 + %copy.2078.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2078.cloned.1.call-start) + %scatter_offload_custom_fusion.46.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[256]{0:T(256)}, s32[256]{0:T(256)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2078.cloned.1.call-done, %param_1.145, %param_2.89), async_execution_thread="sparsecore", calls=%async_computation.20, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} ROOT %scatter_offload_custom_fusion.46.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.46.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.3 (param_0.94: s32[263], param_1.146: s32[256], param_2.90: s32[256], param_3.3090: token[]) -> s32[263] { - %param_3.3090 = token[] parameter(3) +%async_computation.3 (param_0.94: s32[263], param_1.146: s32[256], param_2.90: s32[256], param_3.3095: token[]) -> s32[263] { + %param_3.3095 = token[] parameter(3) %param_0.94 = s32[263]{0:T(512)} parameter(0) %param_1.146 = s32[256]{0:T(256)} parameter(1) %param_2.90 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.11.cloned.1 = s32[263]{0:T(512)} call(%param_0.94, %param_1.146, %param_2.90, %param_3.3090), to_apply=%called_computation.3, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.11.cloned.1 = s32[263]{0:T(512)} call(%param_0.94, %param_1.146, %param_2.90, %param_3.3095), to_apply=%called_computation.3, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.21 (param_0.4543: f32[9]) -> f32[9] { - %param_0.4543 = f32[9]{0:T(128)} parameter(0) - ROOT %copy.2081 = f32[9]{0:T(128)} copy(%param_0.4543), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.21 (param_0.4559: f32[9]) -> f32[9] { + %param_0.4559 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2079 = f32[9]{0:T(128)} copy(%param_0.4559), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.21 (param_0.4544: f32[9]) -> f32[9] { - %param_0.4544 = f32[9]{0:T(128)} parameter(0) - ROOT %copy.2082.cloned.1 = f32[9]{0:T(128)} call(%param_0.4544), to_apply=%called_computation.21 +%async_computation.21 (param_0.4560: f32[9]) -> f32[9] { + %param_0.4560 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2080.cloned.1 = f32[9]{0:T(128)} call(%param_0.4560), to_apply=%called_computation.21 }, execution_thread="sparsecore" %region_79.95.clone (scatter-add.167: f32[], scatter-add.168: f32[]) -> f32[] { %scatter-add.167 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.168 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2476 = f32[]{:T(128)S(7)} add(%scatter-add.167, %scatter-add.168), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2434 = f32[]{:T(128)S(7)} add(%scatter-add.167, %scatter-add.168), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.27.clone.clone (param_0.4545: f32[9], param_1.5337: s32[256], param_2.4506: f32[256]) -> f32[9] { - %param_0.4545 = f32[9]{0:T(128)} parameter(0) - %param_1.5337 = s32[256]{0:T(256)} parameter(1) - %reshape.3931 = s32[256]{0:T(256)} reshape(%param_1.5337), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1108 = s32[256]{0:T(256)} transpose(%reshape.3931), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %param_2.4506 = f32[256]{0:T(256)} parameter(2) - %reshape.3932 = f32[256]{0:T(256)} reshape(%param_2.4506), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1109 = f32[256]{0:T(256)} transpose(%reshape.3932), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.241 = f32[9]{0:T(128)} scatter(%param_0.4545, %transpose.1108, %transpose.1109), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_79.95.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.27.clone.clone (param_0.4561: f32[9], param_1.5336: s32[256], param_2.4496: f32[256]) -> f32[9] { + %param_0.4561 = f32[9]{0:T(128)} parameter(0) + %param_1.5336 = s32[256]{0:T(256)} parameter(1) + %reshape.4059 = s32[256]{0:T(256)} reshape(%param_1.5336), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.865 = s32[256]{0:T(256)} transpose(%reshape.4059), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %param_2.4496 = f32[256]{0:T(256)} parameter(2) + %reshape.4060 = f32[256]{0:T(256)} reshape(%param_2.4496), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.866 = f32[256]{0:T(256)} transpose(%reshape.4060), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.241 = f32[9]{0:T(128)} scatter(%param_0.4561, %transpose.865, %transpose.866), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_79.95.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.22 (param_0.4546: f32[9], param_1.5338: s32[256], param_2.4507: f32[256]) -> f32[9] { - %param_0.4546 = f32[9]{0:T(128)} parameter(0) - %param_1.5338 = s32[256]{0:T(256)} parameter(1) - %param_2.4507 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.47 = f32[9]{0:T(128)} fusion(%param_0.4546, %param_1.5338, %param_2.4507), kind=kCustom, calls=%fused_computation.27.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.22 (param_0.4562: f32[9], param_1.5337: s32[256], param_2.4497: f32[256]) -> f32[9] { + %param_0.4562 = f32[9]{0:T(128)} parameter(0) + %param_1.5337 = s32[256]{0:T(256)} parameter(1) + %param_2.4497 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.47 = f32[9]{0:T(128)} fusion(%param_0.4562, %param_1.5337, %param_2.4497), kind=kCustom, calls=%fused_computation.27.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.22 (param_0.4547: f32[9], param_1.5339: s32[256], param_2.4508: f32[256]) -> f32[9] { - %param_0.4547 = f32[9]{0:T(128)} parameter(0) - %param_1.5339 = s32[256]{0:T(256)} parameter(1) - %param_2.4508 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.48.cloned.1 = f32[9]{0:T(128)} call(%param_0.4547, %param_1.5339, %param_2.4508), to_apply=%called_computation.22, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.22 (param_0.4563: f32[9], param_1.5338: s32[256], param_2.4498: f32[256]) -> f32[9] { + %param_0.4563 = f32[9]{0:T(128)} parameter(0) + %param_1.5338 = s32[256]{0:T(256)} parameter(1) + %param_2.4498 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.48.cloned.1 = f32[9]{0:T(128)} call(%param_0.4563, %param_1.5338, %param_2.4498), to_apply=%called_computation.22, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.4 (param_0.96: f32[9], param_1.148: s32[256], param_2.92: f32[256], param_3.3097: token[]) -> f32[9] { - %param_3.3097 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.4 (param_0.96: f32[9], param_1.148: s32[256], param_2.92: f32[256], param_3.3102: token[]) -> f32[9] { + %param_3.3102 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.96 = f32[9]{0:T(128)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.148 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.92 = f32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} - %copy.2082.cloned.1.call-start = ((f32[9]{0:T(128)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%param_0.96), async_execution_thread="sparsecore", calls=%async_computation.21 - %copy.2082.cloned.1.call-done = f32[9]{0:T(128)} async-done(%copy.2082.cloned.1.call-start) - %scatter_offload_custom_fusion.48.cloned.1.call-start = ((f32[9]{0:T(128)}, s32[256]{0:T(256)}, f32[256]{0:T(256)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%copy.2082.cloned.1.call-done, %param_1.148, %param_2.92), async_execution_thread="sparsecore", calls=%async_computation.22, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %copy.2080.cloned.1.call-start = ((f32[9]{0:T(128)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%param_0.96), async_execution_thread="sparsecore", calls=%async_computation.21 + %copy.2080.cloned.1.call-done = f32[9]{0:T(128)} async-done(%copy.2080.cloned.1.call-start) + %scatter_offload_custom_fusion.48.cloned.1.call-start = ((f32[9]{0:T(128)}, s32[256]{0:T(256)}, f32[256]{0:T(256)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%copy.2080.cloned.1.call-done, %param_1.148, %param_2.92), async_execution_thread="sparsecore", calls=%async_computation.22, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} ROOT %scatter_offload_custom_fusion.48.cloned.1.call-done = f32[9]{0:T(128)} async-done(%scatter_offload_custom_fusion.48.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.4 (param_0.97: f32[9], param_1.149: s32[256], param_2.93: f32[256], param_3.3096: token[]) -> f32[9] { - %param_3.3096 = token[] parameter(3) +%async_computation.4 (param_0.97: f32[9], param_1.149: s32[256], param_2.93: f32[256], param_3.3101: token[]) -> f32[9] { + %param_3.3101 = token[] parameter(3) %param_0.97 = f32[9]{0:T(128)} parameter(0) %param_1.149 = s32[256]{0:T(256)} parameter(1) %param_2.93 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.14.cloned.1 = f32[9]{0:T(128)} call(%param_0.97, %param_1.149, %param_2.93, %param_3.3096), to_apply=%called_computation.4, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.14.cloned.1 = f32[9]{0:T(128)} call(%param_0.97, %param_1.149, %param_2.93, %param_3.3101), to_apply=%called_computation.4, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.23 (param_0.4548: s32[263]) -> s32[263] { - %param_0.4548 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2083 = s32[263]{0:T(512)} copy(%param_0.4548), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.23 (param_0.4564: s32[263]) -> s32[263] { + %param_0.4564 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2081 = s32[263]{0:T(512)} copy(%param_0.4564), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.23 (param_0.4549: s32[263]) -> s32[263] { - %param_0.4549 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2084.cloned.1 = s32[263]{0:T(512)} call(%param_0.4549), to_apply=%called_computation.23 +%async_computation.23 (param_0.4565: s32[263]) -> s32[263] { + %param_0.4565 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2082.cloned.1 = s32[263]{0:T(512)} call(%param_0.4565), to_apply=%called_computation.23 }, execution_thread="sparsecore" %region_81.97.clone (scatter-add.171: s32[], scatter-add.172: s32[]) -> s32[] { %scatter-add.171 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.172 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2478 = s32[]{:T(128)S(7)} add(%scatter-add.171, %scatter-add.172), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2436 = s32[]{:T(128)S(7)} add(%scatter-add.171, %scatter-add.172), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.28.clone.clone (param_0.4550: s32[263], param_1.5340: s32[8], param_2.4509: s32[8]) -> s32[263] { - %param_0.4550 = s32[263]{0:T(512)} parameter(0) - %param_1.5340 = s32[8]{0:T(128)} parameter(1) - %reshape.3933 = s32[8]{0:T(128)} reshape(%param_1.5340), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1110 = s32[8]{0:T(128)} transpose(%reshape.3933), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %param_2.4509 = s32[8]{0:T(128)} parameter(2) - %reshape.3934 = s32[8]{0:T(128)} reshape(%param_2.4509), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - %transpose.1111 = s32[8]{0:T(128)} transpose(%reshape.3934), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - ROOT %scatter-add.242 = s32[263]{0:T(512)} scatter(%param_0.4550, %transpose.1110, %transpose.1111), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_81.97.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.28.clone.clone (param_0.4566: s32[263], param_1.5339: s32[8], param_2.4499: s32[8]) -> s32[263] { + %param_0.4566 = s32[263]{0:T(512)} parameter(0) + %param_1.5339 = s32[8]{0:T(128)} parameter(1) + %reshape.4061 = s32[8]{0:T(128)} reshape(%param_1.5339), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.867 = s32[8]{0:T(128)} transpose(%reshape.4061), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4499 = s32[8]{0:T(128)} parameter(2) + %reshape.4062 = s32[8]{0:T(128)} reshape(%param_2.4499), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + %transpose.868 = s32[8]{0:T(128)} transpose(%reshape.4062), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + ROOT %scatter-add.242 = s32[263]{0:T(512)} scatter(%param_0.4566, %transpose.867, %transpose.868), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_81.97.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.24 (param_0.4551: s32[263], param_1.5341: s32[8], param_2.4510: s32[8]) -> s32[263] { - %param_0.4551 = s32[263]{0:T(512)} parameter(0) - %param_1.5341 = s32[8]{0:T(128)} parameter(1) - %param_2.4510 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.49 = s32[263]{0:T(512)} fusion(%param_0.4551, %param_1.5341, %param_2.4510), kind=kCustom, calls=%fused_computation.28.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.24 (param_0.4567: s32[263], param_1.5340: s32[8], param_2.4500: s32[8]) -> s32[263] { + %param_0.4567 = s32[263]{0:T(512)} parameter(0) + %param_1.5340 = s32[8]{0:T(128)} parameter(1) + %param_2.4500 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.49 = s32[263]{0:T(512)} fusion(%param_0.4567, %param_1.5340, %param_2.4500), kind=kCustom, calls=%fused_computation.28.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.24 (param_0.4552: s32[263], param_1.5342: s32[8], param_2.4511: s32[8]) -> s32[263] { - %param_0.4552 = s32[263]{0:T(512)} parameter(0) - %param_1.5342 = s32[8]{0:T(128)} parameter(1) - %param_2.4511 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.50.cloned.1 = s32[263]{0:T(512)} call(%param_0.4552, %param_1.5342, %param_2.4511), to_apply=%called_computation.24, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.24 (param_0.4568: s32[263], param_1.5341: s32[8], param_2.4501: s32[8]) -> s32[263] { + %param_0.4568 = s32[263]{0:T(512)} parameter(0) + %param_1.5341 = s32[8]{0:T(128)} parameter(1) + %param_2.4501 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.50.cloned.1 = s32[263]{0:T(512)} call(%param_0.4568, %param_1.5341, %param_2.4501), to_apply=%called_computation.24, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.5 (param_0.99: s32[263], param_1.151: s32[8], param_2.95: s32[8], param_3.3107: token[]) -> s32[263] { - %param_3.3107 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.5 (param_0.99: s32[263], param_1.151: s32[8], param_2.95: s32[8], param_3.3112: token[]) -> s32[263] { + %param_3.3112 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.99 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.151 = s32[8]{0:T(128)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.95 = s32[8]{0:T(128)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} - %copy.2084.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.99), async_execution_thread="sparsecore", calls=%async_computation.23 - %copy.2084.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2084.cloned.1.call-start) - %scatter_offload_custom_fusion.50.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[8]{0:T(128)}, s32[8]{0:T(128)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2084.cloned.1.call-done, %param_1.151, %param_2.95), async_execution_thread="sparsecore", calls=%async_computation.24, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %copy.2082.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.99), async_execution_thread="sparsecore", calls=%async_computation.23 + %copy.2082.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2082.cloned.1.call-start) + %scatter_offload_custom_fusion.50.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[8]{0:T(128)}, s32[8]{0:T(128)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2082.cloned.1.call-done, %param_1.151, %param_2.95), async_execution_thread="sparsecore", calls=%async_computation.24, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} ROOT %scatter_offload_custom_fusion.50.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.50.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.5 (param_0.100: s32[263], param_1.152: s32[8], param_2.96: s32[8], param_3.3106: token[]) -> s32[263] { - %param_3.3106 = token[] parameter(3) +%async_computation.5 (param_0.100: s32[263], param_1.152: s32[8], param_2.96: s32[8], param_3.3111: token[]) -> s32[263] { + %param_3.3111 = token[] parameter(3) %param_0.100 = s32[263]{0:T(512)} parameter(0) %param_1.152 = s32[8]{0:T(128)} parameter(1) %param_2.96 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.17.cloned.1 = s32[263]{0:T(512)} call(%param_0.100, %param_1.152, %param_2.96, %param_3.3106), to_apply=%called_computation.5, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.17.cloned.1 = s32[263]{0:T(512)} call(%param_0.100, %param_1.152, %param_2.96, %param_3.3111), to_apply=%called_computation.5, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.25 (param_0.4553: s32[263]) -> s32[263] { - %param_0.4553 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2085 = s32[263]{0:T(512)} copy(%param_0.4553), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.25 (param_0.4569: s32[263]) -> s32[263] { + %param_0.4569 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2083 = s32[263]{0:T(512)} copy(%param_0.4569), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.25 (param_0.4554: s32[263]) -> s32[263] { - %param_0.4554 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2086.cloned.1 = s32[263]{0:T(512)} call(%param_0.4554), to_apply=%called_computation.25 +%async_computation.25 (param_0.4570: s32[263]) -> s32[263] { + %param_0.4570 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2084.cloned.1 = s32[263]{0:T(512)} call(%param_0.4570), to_apply=%called_computation.25 }, execution_thread="sparsecore" %region_96.114 (scatter-add.48: s32[], scatter-add.49: s32[]) -> s32[] { %scatter-add.48 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.49 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1396 = s32[]{:T(128)S(7)} add(%scatter-add.48, %scatter-add.49), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1362 = s32[]{:T(128)S(7)} add(%scatter-add.48, %scatter-add.49), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.29.clone.clone (param_0.4555: s32[263], param_1.5343: s32[256], param_2.4512: s32[256]) -> s32[263] { - %param_0.4555 = s32[263]{0:T(512)} parameter(0) - %param_1.5343 = s32[256]{0:T(256)} parameter(1) - %reshape.3935 = s32[256]{0:T(256)} reshape(%param_1.5343), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} - %transpose.1112 = s32[256]{0:T(256)} transpose(%reshape.3935), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} - %param_2.4512 = s32[256]{0:T(256)} parameter(2) - %reshape.3936 = s32[256]{0:T(256)} reshape(%param_2.4512), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1113 = s32[256]{0:T(256)} transpose(%reshape.3936), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.243 = s32[263]{0:T(512)} scatter(%param_0.4555, %transpose.1112, %transpose.1113), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_96.114, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +%fused_computation.29.clone.clone (param_0.4571: s32[263], param_1.5342: s32[256], param_2.4502: s32[256]) -> s32[263] { + %param_0.4571 = s32[263]{0:T(512)} parameter(0) + %param_1.5342 = s32[256]{0:T(256)} parameter(1) + %reshape.4063 = s32[256]{0:T(256)} reshape(%param_1.5342), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %transpose.869 = s32[256]{0:T(256)} transpose(%reshape.4063), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %param_2.4502 = s32[256]{0:T(256)} parameter(2) + %reshape.4064 = s32[256]{0:T(256)} reshape(%param_2.4502), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.870 = s32[256]{0:T(256)} transpose(%reshape.4064), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.243 = s32[263]{0:T(512)} scatter(%param_0.4571, %transpose.869, %transpose.870), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_96.114, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.26 (param_0.4556: s32[263], param_1.5344: s32[256], param_2.4513: s32[256]) -> s32[263] { - %param_0.4556 = s32[263]{0:T(512)} parameter(0) - %param_1.5344 = s32[256]{0:T(256)} parameter(1) - %param_2.4513 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.51 = s32[263]{0:T(512)} fusion(%param_0.4556, %param_1.5344, %param_2.4513), kind=kCustom, calls=%fused_computation.29.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.26 (param_0.4572: s32[263], param_1.5343: s32[256], param_2.4503: s32[256]) -> s32[263] { + %param_0.4572 = s32[263]{0:T(512)} parameter(0) + %param_1.5343 = s32[256]{0:T(256)} parameter(1) + %param_2.4503 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.51 = s32[263]{0:T(512)} fusion(%param_0.4572, %param_1.5343, %param_2.4503), kind=kCustom, calls=%fused_computation.29.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.26 (param_0.4557: s32[263], param_1.5345: s32[256], param_2.4514: s32[256]) -> s32[263] { - %param_0.4557 = s32[263]{0:T(512)} parameter(0) - %param_1.5345 = s32[256]{0:T(256)} parameter(1) - %param_2.4514 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.52.cloned.1 = s32[263]{0:T(512)} call(%param_0.4557, %param_1.5345, %param_2.4514), to_apply=%called_computation.26, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +%async_computation.26 (param_0.4573: s32[263], param_1.5344: s32[256], param_2.4504: s32[256]) -> s32[263] { + %param_0.4573 = s32[263]{0:T(512)} parameter(0) + %param_1.5344 = s32[256]{0:T(256)} parameter(1) + %param_2.4504 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.52.cloned.1 = s32[263]{0:T(512)} call(%param_0.4573, %param_1.5344, %param_2.4504), to_apply=%called_computation.26, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.6 (param_0.102: s32[263], param_1.154: s32[256], param_2.98: s32[256], param_3.3093: token[]) -> s32[263] { - %param_3.3093 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.6 (param_0.102: s32[263], param_1.154: s32[256], param_2.98: s32[256], param_3.3098: token[]) -> s32[263] { + %param_3.3098 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.102 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.154 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.98 = s32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} - %copy.2086.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.102), async_execution_thread="sparsecore", calls=%async_computation.25 - %copy.2086.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2086.cloned.1.call-start) - %scatter_offload_custom_fusion.52.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[256]{0:T(256)}, s32[256]{0:T(256)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2086.cloned.1.call-done, %param_1.154, %param_2.98), async_execution_thread="sparsecore", calls=%async_computation.26, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} + %copy.2084.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.102), async_execution_thread="sparsecore", calls=%async_computation.25 + %copy.2084.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2084.cloned.1.call-start) + %scatter_offload_custom_fusion.52.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[256]{0:T(256)}, s32[256]{0:T(256)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2084.cloned.1.call-done, %param_1.154, %param_2.98), async_execution_thread="sparsecore", calls=%async_computation.26, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} ROOT %scatter_offload_custom_fusion.52.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.52.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.6 (param_0.103: s32[263], param_1.155: s32[256], param_2.99: s32[256], param_3.3092: token[]) -> s32[263] { - %param_3.3092 = token[] parameter(3) +%async_computation.6 (param_0.103: s32[263], param_1.155: s32[256], param_2.99: s32[256], param_3.3097: token[]) -> s32[263] { + %param_3.3097 = token[] parameter(3) %param_0.103 = s32[263]{0:T(512)} parameter(0) %param_1.155 = s32[256]{0:T(256)} parameter(1) %param_2.99 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.20.cloned.1 = s32[263]{0:T(512)} call(%param_0.103, %param_1.155, %param_2.99, %param_3.3092), to_apply=%called_computation.6, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.20.cloned.1 = s32[263]{0:T(512)} call(%param_0.103, %param_1.155, %param_2.99, %param_3.3097), to_apply=%called_computation.6, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %region_102.120 (scatter-add.52: f32[], scatter-add.53: f32[]) -> f32[] { %scatter-add.52 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.53 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1399 = f32[]{:T(128)S(7)} add(%scatter-add.52, %scatter-add.53), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1365 = f32[]{:T(128)S(7)} add(%scatter-add.52, %scatter-add.53), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.30.clone.clone (param_0.4560: f32[9], param_1.5346: s32[256], param_2.4515: f32[256]) -> f32[9] { - %param_0.4560 = f32[9]{0:T(128)} parameter(0) - %param_1.5346 = s32[256]{0:T(256)} parameter(1) - %reshape.3937 = s32[256]{0:T(256)} reshape(%param_1.5346), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1114 = s32[256]{0:T(256)} transpose(%reshape.3937), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/broadcast_in_dim" stack_frame_id=0} - %param_2.4515 = f32[256]{0:T(256)} parameter(2) - %reshape.3938 = f32[256]{0:T(256)} reshape(%param_2.4515), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1115 = f32[256]{0:T(256)} transpose(%reshape.3938), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.244 = f32[9]{0:T(128)} scatter(%param_0.4560, %transpose.1114, %transpose.1115), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_102.120, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +%fused_computation.30.clone.clone (param_0.4576: f32[9], param_1.5345: s32[256], param_2.4505: f32[256]) -> f32[9] { + %param_0.4576 = f32[9]{0:T(128)} parameter(0) + %param_1.5345 = s32[256]{0:T(256)} parameter(1) + %reshape.4065 = s32[256]{0:T(256)} reshape(%param_1.5345), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.871 = s32[256]{0:T(256)} transpose(%reshape.4065), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/broadcast_in_dim" stack_frame_id=0} + %param_2.4505 = f32[256]{0:T(256)} parameter(2) + %reshape.4066 = f32[256]{0:T(256)} reshape(%param_2.4505), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.872 = f32[256]{0:T(256)} transpose(%reshape.4066), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.244 = f32[9]{0:T(128)} scatter(%param_0.4576, %transpose.871, %transpose.872), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_102.120, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.28 (param_0.4561: f32[9], param_1.5347: s32[256], param_2.4516: f32[256]) -> f32[9] { - %param_0.4561 = f32[9]{0:T(128)} parameter(0) - %param_1.5347 = s32[256]{0:T(256)} parameter(1) - %param_2.4516 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.53 = f32[9]{0:T(128)} fusion(%param_0.4561, %param_1.5347, %param_2.4516), kind=kCustom, calls=%fused_computation.30.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.28 (param_0.4577: f32[9], param_1.5346: s32[256], param_2.4506: f32[256]) -> f32[9] { + %param_0.4577 = f32[9]{0:T(128)} parameter(0) + %param_1.5346 = s32[256]{0:T(256)} parameter(1) + %param_2.4506 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.53 = f32[9]{0:T(128)} fusion(%param_0.4577, %param_1.5346, %param_2.4506), kind=kCustom, calls=%fused_computation.30.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.28 (param_0.4562: f32[9], param_1.5348: s32[256], param_2.4517: f32[256]) -> f32[9] { - %param_0.4562 = f32[9]{0:T(128)} parameter(0) - %param_1.5348 = s32[256]{0:T(256)} parameter(1) - %param_2.4517 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.54.cloned.1 = f32[9]{0:T(128)} call(%param_0.4562, %param_1.5348, %param_2.4517), to_apply=%called_computation.28, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +%async_computation.28 (param_0.4578: f32[9], param_1.5347: s32[256], param_2.4507: f32[256]) -> f32[9] { + %param_0.4578 = f32[9]{0:T(128)} parameter(0) + %param_1.5347 = s32[256]{0:T(256)} parameter(1) + %param_2.4507 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.54.cloned.1 = f32[9]{0:T(128)} call(%param_0.4578, %param_1.5347, %param_2.4507), to_apply=%called_computation.28, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.7 (param_0.105: f32[9], param_1.157: s32[256], param_2.101: f32[256], param_3.3101: token[]) -> f32[9] { - %param_3.3101 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.7 (param_0.105: f32[9], param_1.157: s32[256], param_2.101: f32[256], param_3.3106: token[]) -> f32[9] { + %param_3.3106 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.105 = f32[9]{0:T(128)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.157 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.101 = f32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -998,47 +998,47 @@ StackFrames ROOT %scatter_offload_custom_fusion.54.cloned.1.call-done = f32[9]{0:T(128)} async-done(%scatter_offload_custom_fusion.54.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.7 (param_0.106: f32[9], param_1.158: s32[256], param_2.102: f32[256], param_3.3100: token[]) -> f32[9] { - %param_3.3100 = token[] parameter(3) +%async_computation.7 (param_0.106: f32[9], param_1.158: s32[256], param_2.102: f32[256], param_3.3105: token[]) -> f32[9] { + %param_3.3105 = token[] parameter(3) %param_0.106 = f32[9]{0:T(128)} parameter(0) %param_1.158 = s32[256]{0:T(256)} parameter(1) %param_2.102 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.23.cloned.1 = f32[9]{0:T(128)} call(%param_0.106, %param_1.158, %param_2.102, %param_3.3100), to_apply=%called_computation.7, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.23.cloned.1 = f32[9]{0:T(128)} call(%param_0.106, %param_1.158, %param_2.102, %param_3.3105), to_apply=%called_computation.7, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %region_104.122 (scatter-add.83: s32[], scatter-add.84: s32[]) -> s32[] { %scatter-add.83 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.84 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1400 = s32[]{:T(128)S(7)} add(%scatter-add.83, %scatter-add.84), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1366 = s32[]{:T(128)S(7)} add(%scatter-add.83, %scatter-add.84), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.31.clone.clone (param_0.4565: s32[263], param_1.5349: s32[8], param_2.4518: s32[8]) -> s32[263] { - %param_0.4565 = s32[263]{0:T(512)} parameter(0) - %param_1.5349 = s32[8]{0:T(128)} parameter(1) - %reshape.3939 = s32[8]{0:T(128)} reshape(%param_1.5349), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} - %transpose.1116 = s32[8]{0:T(128)} transpose(%reshape.3939), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} - %param_2.4518 = s32[8]{0:T(128)} parameter(2) - %reshape.3940 = s32[8]{0:T(128)} reshape(%param_2.4518), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - %transpose.1117 = s32[8]{0:T(128)} transpose(%reshape.3940), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - ROOT %scatter-add.245 = s32[263]{0:T(512)} scatter(%param_0.4565, %transpose.1116, %transpose.1117), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_104.122, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +%fused_computation.31.clone.clone (param_0.4581: s32[263], param_1.5348: s32[8], param_2.4508: s32[8]) -> s32[263] { + %param_0.4581 = s32[263]{0:T(512)} parameter(0) + %param_1.5348 = s32[8]{0:T(128)} parameter(1) + %reshape.4067 = s32[8]{0:T(128)} reshape(%param_1.5348), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %transpose.873 = s32[8]{0:T(128)} transpose(%reshape.4067), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %param_2.4508 = s32[8]{0:T(128)} parameter(2) + %reshape.4068 = s32[8]{0:T(128)} reshape(%param_2.4508), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + %transpose.874 = s32[8]{0:T(128)} transpose(%reshape.4068), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + ROOT %scatter-add.245 = s32[263]{0:T(512)} scatter(%param_0.4581, %transpose.873, %transpose.874), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_104.122, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.30 (param_0.4566: s32[263], param_1.5350: s32[8], param_2.4519: s32[8]) -> s32[263] { - %param_0.4566 = s32[263]{0:T(512)} parameter(0) - %param_1.5350 = s32[8]{0:T(128)} parameter(1) - %param_2.4519 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.55 = s32[263]{0:T(512)} fusion(%param_0.4566, %param_1.5350, %param_2.4519), kind=kCustom, calls=%fused_computation.31.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.30 (param_0.4582: s32[263], param_1.5349: s32[8], param_2.4509: s32[8]) -> s32[263] { + %param_0.4582 = s32[263]{0:T(512)} parameter(0) + %param_1.5349 = s32[8]{0:T(128)} parameter(1) + %param_2.4509 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.55 = s32[263]{0:T(512)} fusion(%param_0.4582, %param_1.5349, %param_2.4509), kind=kCustom, calls=%fused_computation.31.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.30 (param_0.4567: s32[263], param_1.5351: s32[8], param_2.4520: s32[8]) -> s32[263] { - %param_0.4567 = s32[263]{0:T(512)} parameter(0) - %param_1.5351 = s32[8]{0:T(128)} parameter(1) - %param_2.4520 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.56.cloned.1 = s32[263]{0:T(512)} call(%param_0.4567, %param_1.5351, %param_2.4520), to_apply=%called_computation.30, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +%async_computation.30 (param_0.4583: s32[263], param_1.5350: s32[8], param_2.4510: s32[8]) -> s32[263] { + %param_0.4583 = s32[263]{0:T(512)} parameter(0) + %param_1.5350 = s32[8]{0:T(128)} parameter(1) + %param_2.4510 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.56.cloned.1 = s32[263]{0:T(512)} call(%param_0.4583, %param_1.5350, %param_2.4510), to_apply=%called_computation.30, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.8 (param_0.108: s32[263], param_1.160: s32[8], param_2.104: s32[8], param_3.3109: token[]) -> s32[263] { - %param_3.3109 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.8 (param_0.108: s32[263], param_1.160: s32[8], param_2.104: s32[8], param_3.3114: token[]) -> s32[263] { + %param_3.3114 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.108 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.160 = s32[8]{0:T(128)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.104 = s32[8]{0:T(128)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -1046,47 +1046,47 @@ StackFrames ROOT %scatter_offload_custom_fusion.56.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.56.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.8 (param_0.109: s32[263], param_1.161: s32[8], param_2.105: s32[8], param_3.3108: token[]) -> s32[263] { - %param_3.3108 = token[] parameter(3) +%async_computation.8 (param_0.109: s32[263], param_1.161: s32[8], param_2.105: s32[8], param_3.3113: token[]) -> s32[263] { + %param_3.3113 = token[] parameter(3) %param_0.109 = s32[263]{0:T(512)} parameter(0) %param_1.161 = s32[8]{0:T(128)} parameter(1) %param_2.105 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.26.cloned.1 = s32[263]{0:T(512)} call(%param_0.109, %param_1.161, %param_2.105, %param_3.3108), to_apply=%called_computation.8, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.26.cloned.1 = s32[263]{0:T(512)} call(%param_0.109, %param_1.161, %param_2.105, %param_3.3113), to_apply=%called_computation.8, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%region_14.20 (scatter-add.0: s32[], scatter-add.1: s32[]) -> s32[] { +%region_13.19 (scatter-add.0: s32[], scatter-add.1: s32[]) -> s32[] { %scatter-add.0 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.1 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1312 = s32[]{:T(128)S(7)} add(%scatter-add.0, %scatter-add.1), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1279 = s32[]{:T(128)S(7)} add(%scatter-add.0, %scatter-add.1), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.17.clone.clone.clone (param_0.4570: s32[256], param_1.5352: s32[4096], param_2.4521: s32[4096]) -> s32[256] { - %param_0.4570 = s32[256]{0:T(256)} parameter(0) - %param_1.5352 = s32[4096]{0:T(1024)} parameter(1) - %reshape.3941 = s32[4096]{0:T(1024)} reshape(%param_1.5352), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/select_n" stack_frame_id=0} - %transpose.1118 = s32[4096]{0:T(1024)} transpose(%reshape.3941), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/select_n" stack_frame_id=0} - %param_2.4521 = s32[4096]{0:T(1024)} parameter(2) - %reshape.3942 = s32[4096]{0:T(1024)} reshape(%param_2.4521), metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} - %transpose.1119 = s32[4096]{0:T(1024)} transpose(%reshape.3942), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.246 = s32[256]{0:T(256)} scatter(%param_0.4570, %transpose.1118, %transpose.1119), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_14.20, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} +%fused_computation.17.clone.clone.clone (param_0.4586: s32[256], param_1.5351: s32[4096], param_2.4511: s32[4096]) -> s32[256] { + %param_0.4586 = s32[256]{0:T(256)} parameter(0) + %param_1.5351 = s32[4096]{0:T(1024)} parameter(1) + %reshape.4069 = s32[4096]{0:T(1024)} reshape(%param_1.5351), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/select_n" stack_frame_id=0} + %transpose.875 = s32[4096]{0:T(1024)} transpose(%reshape.4069), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/select_n" stack_frame_id=0} + %param_2.4511 = s32[4096]{0:T(1024)} parameter(2) + %reshape.4070 = s32[4096]{0:T(1024)} reshape(%param_2.4511), metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + %transpose.876 = s32[4096]{0:T(1024)} transpose(%reshape.4070), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.246 = s32[256]{0:T(256)} scatter(%param_0.4586, %transpose.875, %transpose.876), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_13.19, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.32 (param_0.4571: s32[256], param_1.5353: s32[4096], param_2.4522: s32[4096]) -> s32[256] { - %param_0.4571 = s32[256]{0:T(256)} parameter(0) - %param_1.5353 = s32[4096]{0:T(1024)} parameter(1) - %param_2.4522 = s32[4096]{0:T(1024)} parameter(2) - ROOT %scatter_offload_custom_fusion.57 = s32[256]{0:T(256)} fusion(%param_0.4571, %param_1.5353, %param_2.4522), kind=kCustom, calls=%fused_computation.17.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["256"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"4160","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.32 (param_0.4587: s32[256], param_1.5352: s32[4096], param_2.4512: s32[4096]) -> s32[256] { + %param_0.4587 = s32[256]{0:T(256)} parameter(0) + %param_1.5352 = s32[4096]{0:T(1024)} parameter(1) + %param_2.4512 = s32[4096]{0:T(1024)} parameter(2) + ROOT %scatter_offload_custom_fusion.57 = s32[256]{0:T(256)} fusion(%param_0.4587, %param_1.5352, %param_2.4512), kind=kCustom, calls=%fused_computation.17.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["256"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"4160","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.32 (param_0.4572: s32[256], param_1.5354: s32[4096], param_2.4523: s32[4096]) -> s32[256] { - %param_0.4572 = s32[256]{0:T(256)} parameter(0) - %param_1.5354 = s32[4096]{0:T(1024)} parameter(1) - %param_2.4523 = s32[4096]{0:T(1024)} parameter(2) - ROOT %scatter_offload_custom_fusion.58.cloned.1 = s32[256]{0:T(256)} call(%param_0.4572, %param_1.5354, %param_2.4523), to_apply=%called_computation.32, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} +%async_computation.32 (param_0.4588: s32[256], param_1.5353: s32[4096], param_2.4513: s32[4096]) -> s32[256] { + %param_0.4588 = s32[256]{0:T(256)} parameter(0) + %param_1.5353 = s32[4096]{0:T(1024)} parameter(1) + %param_2.4513 = s32[4096]{0:T(1024)} parameter(2) + ROOT %scatter_offload_custom_fusion.58.cloned.1 = s32[256]{0:T(256)} call(%param_0.4588, %param_1.5353, %param_2.4513), to_apply=%called_computation.32, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.9 (param_0.111: s32[256], param_1.163: s32[4096], param_2.107: s32[4096], param_3.3087: token[]) -> s32[256] { - %param_3.3087 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.9 (param_0.111: s32[256], param_1.163: s32[4096], param_2.107: s32[4096], param_3.3092: token[]) -> s32[256] { + %param_3.3092 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.111 = s32[256]{0:T(256)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.163 = s32[4096]{0:T(1024)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.107 = s32[4096]{0:T(1024)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -1094,907 +1094,907 @@ StackFrames ROOT %scatter_offload_custom_fusion.58.cloned.1.call-done = s32[256]{0:T(256)} async-done(%scatter_offload_custom_fusion.58.cloned.1.call-start), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.9 (param_0.112: s32[256], param_1.164: s32[4096], param_2.108: s32[4096], param_3.3086: token[]) -> s32[256] { - %param_3.3086 = token[] parameter(3) +%async_computation.9 (param_0.112: s32[256], param_1.164: s32[4096], param_2.108: s32[4096], param_3.3091: token[]) -> s32[256] { + %param_3.3091 = token[] parameter(3) %param_0.112 = s32[256]{0:T(256)} parameter(0) %param_1.164 = s32[4096]{0:T(1024)} parameter(1) %param_2.108 = s32[4096]{0:T(1024)} parameter(2) - ROOT %scatter_offload_custom_fusion.29.cloned.1 = s32[256]{0:T(256)} call(%param_0.112, %param_1.164, %param_2.108, %param_3.3086), to_apply=%called_computation.9, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.29.cloned.1 = s32[256]{0:T(256)} call(%param_0.112, %param_1.164, %param_2.108, %param_3.3091), to_apply=%called_computation.9, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.33 (param_0.4573: s32[263]) -> s32[263] { - %param_0.4573 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2093 = s32[263]{0:T(512)} copy(%param_0.4573), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.33 (param_0.4589: s32[263]) -> s32[263] { + %param_0.4589 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2091 = s32[263]{0:T(512)} copy(%param_0.4589), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.33 (param_0.4574: s32[263]) -> s32[263] { - %param_0.4574 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2094.cloned.1 = s32[263]{0:T(512)} call(%param_0.4574), to_apply=%called_computation.33 +%async_computation.33 (param_0.4590: s32[263]) -> s32[263] { + %param_0.4590 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2092.cloned.1 = s32[263]{0:T(512)} call(%param_0.4590), to_apply=%called_computation.33 }, execution_thread="sparsecore" -%region_20.26.clone.1 (scatter-add.141: s32[], scatter-add.142: s32[]) -> s32[] { +%region_19.25.clone.1 (scatter-add.141: s32[], scatter-add.142: s32[]) -> s32[] { %scatter-add.141 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.142 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2463 = s32[]{:T(128)S(7)} add(%scatter-add.141, %scatter-add.142), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2421 = s32[]{:T(128)S(7)} add(%scatter-add.141, %scatter-add.142), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.18.clone.clone.clone (param_0.4575: s32[263], param_1.5355: s32[256], param_2.4524: s32[256]) -> s32[263] { - %param_0.4575 = s32[263]{0:T(512)} parameter(0) - %param_1.5355 = s32[256]{0:T(256)} parameter(1) - %reshape.3943 = s32[256]{0:T(256)} reshape(%param_1.5355), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1120 = s32[256]{0:T(256)} transpose(%reshape.3943), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %param_2.4524 = s32[256]{0:T(256)} parameter(2) - %reshape.3944 = s32[256]{0:T(256)} reshape(%param_2.4524) - %transpose.1121 = s32[256]{0:T(256)} transpose(%reshape.3944), dimensions={0} - ROOT %scatter-add.247 = s32[263]{0:T(512)} scatter(%param_0.4575, %transpose.1120, %transpose.1121), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_20.26.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.18.clone.clone.clone (param_0.4591: s32[263], param_1.5354: s32[256], param_2.4514: s32[256]) -> s32[263] { + %param_0.4591 = s32[263]{0:T(512)} parameter(0) + %param_1.5354 = s32[256]{0:T(256)} parameter(1) + %reshape.4071 = s32[256]{0:T(256)} reshape(%param_1.5354), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.877 = s32[256]{0:T(256)} transpose(%reshape.4071), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4514 = s32[256]{0:T(256)} parameter(2) + %reshape.4072 = s32[256]{0:T(256)} reshape(%param_2.4514) + %transpose.878 = s32[256]{0:T(256)} transpose(%reshape.4072), dimensions={0} + ROOT %scatter-add.247 = s32[263]{0:T(512)} scatter(%param_0.4591, %transpose.877, %transpose.878), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_19.25.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.34 (param_0.4576: s32[263], param_1.5356: s32[256], param_2.4525: s32[256]) -> s32[263] { - %param_0.4576 = s32[263]{0:T(512)} parameter(0) - %param_1.5356 = s32[256]{0:T(256)} parameter(1) - %param_2.4525 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.59 = s32[263]{0:T(512)} fusion(%param_0.4576, %param_1.5356, %param_2.4525), kind=kCustom, calls=%fused_computation.18.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.34 (param_0.4592: s32[263], param_1.5355: s32[256], param_2.4515: s32[256]) -> s32[263] { + %param_0.4592 = s32[263]{0:T(512)} parameter(0) + %param_1.5355 = s32[256]{0:T(256)} parameter(1) + %param_2.4515 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.59 = s32[263]{0:T(512)} fusion(%param_0.4592, %param_1.5355, %param_2.4515), kind=kCustom, calls=%fused_computation.18.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.34 (param_0.4577: s32[263], param_1.5357: s32[256], param_2.4526: s32[256]) -> s32[263] { - %param_0.4577 = s32[263]{0:T(512)} parameter(0) - %param_1.5357 = s32[256]{0:T(256)} parameter(1) - %param_2.4526 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.60.cloned.1 = s32[263]{0:T(512)} call(%param_0.4577, %param_1.5357, %param_2.4526), to_apply=%called_computation.34, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.34 (param_0.4593: s32[263], param_1.5356: s32[256], param_2.4516: s32[256]) -> s32[263] { + %param_0.4593 = s32[263]{0:T(512)} parameter(0) + %param_1.5356 = s32[256]{0:T(256)} parameter(1) + %param_2.4516 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.60.cloned.1 = s32[263]{0:T(512)} call(%param_0.4593, %param_1.5356, %param_2.4516), to_apply=%called_computation.34, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.10 (param_0.114: s32[263], param_1.166: s32[256], param_2.110: s32[256], param_3.3089: token[]) -> s32[263] { - %param_3.3089 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.10 (param_0.114: s32[263], param_1.166: s32[256], param_2.110: s32[256], param_3.3094: token[]) -> s32[263] { + %param_3.3094 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.114 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.166 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.110 = s32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} - %copy.2094.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.114), async_execution_thread="sparsecore", calls=%async_computation.33 - %copy.2094.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2094.cloned.1.call-start) - %scatter_offload_custom_fusion.60.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[256]{0:T(256)}, s32[256]{0:T(256)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2094.cloned.1.call-done, %param_1.166, %param_2.110), async_execution_thread="sparsecore", calls=%async_computation.34, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %copy.2092.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.114), async_execution_thread="sparsecore", calls=%async_computation.33 + %copy.2092.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2092.cloned.1.call-start) + %scatter_offload_custom_fusion.60.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[256]{0:T(256)}, s32[256]{0:T(256)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2092.cloned.1.call-done, %param_1.166, %param_2.110), async_execution_thread="sparsecore", calls=%async_computation.34, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} ROOT %scatter_offload_custom_fusion.60.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.60.cloned.1.call-start), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.10 (param_0.115: s32[263], param_1.167: s32[256], param_2.111: s32[256], param_3.3088: token[]) -> s32[263] { - %param_3.3088 = token[] parameter(3) +%async_computation.10 (param_0.115: s32[263], param_1.167: s32[256], param_2.111: s32[256], param_3.3093: token[]) -> s32[263] { + %param_3.3093 = token[] parameter(3) %param_0.115 = s32[263]{0:T(512)} parameter(0) %param_1.167 = s32[256]{0:T(256)} parameter(1) %param_2.111 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.32.cloned.1 = s32[263]{0:T(512)} call(%param_0.115, %param_1.167, %param_2.111, %param_3.3088), to_apply=%called_computation.10, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.32.cloned.1 = s32[263]{0:T(512)} call(%param_0.115, %param_1.167, %param_2.111, %param_3.3093), to_apply=%called_computation.10, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.35 (param_0.4578: f32[9]) -> f32[9] { - %param_0.4578 = f32[9]{0:T(128)} parameter(0) - ROOT %copy.2095 = f32[9]{0:T(128)} copy(%param_0.4578), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.35 (param_0.4594: f32[9]) -> f32[9] { + %param_0.4594 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2093 = f32[9]{0:T(128)} copy(%param_0.4594), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.35 (param_0.4579: f32[9]) -> f32[9] { - %param_0.4579 = f32[9]{0:T(128)} parameter(0) - ROOT %copy.2096.cloned.1 = f32[9]{0:T(128)} call(%param_0.4579), to_apply=%called_computation.35 +%async_computation.35 (param_0.4595: f32[9]) -> f32[9] { + %param_0.4595 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2094.cloned.1 = f32[9]{0:T(128)} call(%param_0.4595), to_apply=%called_computation.35 }, execution_thread="sparsecore" -%region_26.33.clone.1 (scatter-add.145: f32[], scatter-add.146: f32[]) -> f32[] { +%region_25.32.clone.1 (scatter-add.145: f32[], scatter-add.146: f32[]) -> f32[] { %scatter-add.145 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.146 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2465 = f32[]{:T(128)S(7)} add(%scatter-add.145, %scatter-add.146), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2423 = f32[]{:T(128)S(7)} add(%scatter-add.145, %scatter-add.146), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.19.clone.clone.clone (param_0.4580: f32[9], param_1.5358: s32[256], param_2.4527: f32[256]) -> f32[9] { - %param_0.4580 = f32[9]{0:T(128)} parameter(0) - %param_1.5358 = s32[256]{0:T(256)} parameter(1) - %reshape.3945 = s32[256]{0:T(256)} reshape(%param_1.5358), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1122 = s32[256]{0:T(256)} transpose(%reshape.3945), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %param_2.4527 = f32[256]{0:T(256)} parameter(2) - %reshape.3946 = f32[256]{0:T(256)} reshape(%param_2.4527) - %transpose.1123 = f32[256]{0:T(256)} transpose(%reshape.3946), dimensions={0} - ROOT %scatter-add.248 = f32[9]{0:T(128)} scatter(%param_0.4580, %transpose.1122, %transpose.1123), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_26.33.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.19.clone.clone.clone (param_0.4596: f32[9], param_1.5357: s32[256], param_2.4517: f32[256]) -> f32[9] { + %param_0.4596 = f32[9]{0:T(128)} parameter(0) + %param_1.5357 = s32[256]{0:T(256)} parameter(1) + %reshape.4073 = s32[256]{0:T(256)} reshape(%param_1.5357), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.879 = s32[256]{0:T(256)} transpose(%reshape.4073), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %param_2.4517 = f32[256]{0:T(256)} parameter(2) + %reshape.4074 = f32[256]{0:T(256)} reshape(%param_2.4517) + %transpose.880 = f32[256]{0:T(256)} transpose(%reshape.4074), dimensions={0} + ROOT %scatter-add.248 = f32[9]{0:T(128)} scatter(%param_0.4596, %transpose.879, %transpose.880), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_25.32.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.36 (param_0.4581: f32[9], param_1.5359: s32[256], param_2.4528: f32[256]) -> f32[9] { - %param_0.4581 = f32[9]{0:T(128)} parameter(0) - %param_1.5359 = s32[256]{0:T(256)} parameter(1) - %param_2.4528 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.61 = f32[9]{0:T(128)} fusion(%param_0.4581, %param_1.5359, %param_2.4528), kind=kCustom, calls=%fused_computation.19.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.36 (param_0.4597: f32[9], param_1.5358: s32[256], param_2.4518: f32[256]) -> f32[9] { + %param_0.4597 = f32[9]{0:T(128)} parameter(0) + %param_1.5358 = s32[256]{0:T(256)} parameter(1) + %param_2.4518 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.61 = f32[9]{0:T(128)} fusion(%param_0.4597, %param_1.5358, %param_2.4518), kind=kCustom, calls=%fused_computation.19.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.36 (param_0.4582: f32[9], param_1.5360: s32[256], param_2.4529: f32[256]) -> f32[9] { - %param_0.4582 = f32[9]{0:T(128)} parameter(0) - %param_1.5360 = s32[256]{0:T(256)} parameter(1) - %param_2.4529 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.62.cloned.1 = f32[9]{0:T(128)} call(%param_0.4582, %param_1.5360, %param_2.4529), to_apply=%called_computation.36, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.36 (param_0.4598: f32[9], param_1.5359: s32[256], param_2.4519: f32[256]) -> f32[9] { + %param_0.4598 = f32[9]{0:T(128)} parameter(0) + %param_1.5359 = s32[256]{0:T(256)} parameter(1) + %param_2.4519 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.62.cloned.1 = f32[9]{0:T(128)} call(%param_0.4598, %param_1.5359, %param_2.4519), to_apply=%called_computation.36, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.11 (param_0.117: f32[9], param_1.169: s32[256], param_2.113: f32[256], param_3.3095: token[]) -> f32[9] { - %param_3.3095 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.11 (param_0.117: f32[9], param_1.169: s32[256], param_2.113: f32[256], param_3.3100: token[]) -> f32[9] { + %param_3.3100 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.117 = f32[9]{0:T(128)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.169 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.113 = f32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} - %copy.2096.cloned.1.call-start = ((f32[9]{0:T(128)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%param_0.117), async_execution_thread="sparsecore", calls=%async_computation.35 - %copy.2096.cloned.1.call-done = f32[9]{0:T(128)} async-done(%copy.2096.cloned.1.call-start) - %scatter_offload_custom_fusion.62.cloned.1.call-start = ((f32[9]{0:T(128)}, s32[256]{0:T(256)}, f32[256]{0:T(256)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%copy.2096.cloned.1.call-done, %param_1.169, %param_2.113), async_execution_thread="sparsecore", calls=%async_computation.36, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %copy.2094.cloned.1.call-start = ((f32[9]{0:T(128)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%param_0.117), async_execution_thread="sparsecore", calls=%async_computation.35 + %copy.2094.cloned.1.call-done = f32[9]{0:T(128)} async-done(%copy.2094.cloned.1.call-start) + %scatter_offload_custom_fusion.62.cloned.1.call-start = ((f32[9]{0:T(128)}, s32[256]{0:T(256)}, f32[256]{0:T(256)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%copy.2094.cloned.1.call-done, %param_1.169, %param_2.113), async_execution_thread="sparsecore", calls=%async_computation.36, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} ROOT %scatter_offload_custom_fusion.62.cloned.1.call-done = f32[9]{0:T(128)} async-done(%scatter_offload_custom_fusion.62.cloned.1.call-start), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.11 (param_0.118: f32[9], param_1.170: s32[256], param_2.114: f32[256], param_3.3094: token[]) -> f32[9] { - %param_3.3094 = token[] parameter(3) +%async_computation.11 (param_0.118: f32[9], param_1.170: s32[256], param_2.114: f32[256], param_3.3099: token[]) -> f32[9] { + %param_3.3099 = token[] parameter(3) %param_0.118 = f32[9]{0:T(128)} parameter(0) %param_1.170 = s32[256]{0:T(256)} parameter(1) %param_2.114 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.35.cloned.1 = f32[9]{0:T(128)} call(%param_0.118, %param_1.170, %param_2.114, %param_3.3094), to_apply=%called_computation.11, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.35.cloned.1 = f32[9]{0:T(128)} call(%param_0.118, %param_1.170, %param_2.114, %param_3.3099), to_apply=%called_computation.11, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.37 (param_0.4583: s32[263]) -> s32[263] { - %param_0.4583 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2097 = s32[263]{0:T(512)} copy(%param_0.4583), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.37 (param_0.4599: s32[263]) -> s32[263] { + %param_0.4599 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2095 = s32[263]{0:T(512)} copy(%param_0.4599), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.37 (param_0.4584: s32[263]) -> s32[263] { - %param_0.4584 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2098.cloned.1 = s32[263]{0:T(512)} call(%param_0.4584), to_apply=%called_computation.37 +%async_computation.37 (param_0.4600: s32[263]) -> s32[263] { + %param_0.4600 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2096.cloned.1 = s32[263]{0:T(512)} call(%param_0.4600), to_apply=%called_computation.37 }, execution_thread="sparsecore" -%region_28.35.clone.1 (scatter-add.149: s32[], scatter-add.150: s32[]) -> s32[] { +%region_27.34.clone.1 (scatter-add.149: s32[], scatter-add.150: s32[]) -> s32[] { %scatter-add.149 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.150 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2467 = s32[]{:T(128)S(7)} add(%scatter-add.149, %scatter-add.150), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2425 = s32[]{:T(128)S(7)} add(%scatter-add.149, %scatter-add.150), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.20.clone.clone.clone (param_0.4585: s32[263], param_1.5361: s32[8], param_2.4530: s32[8]) -> s32[263] { - %param_0.4585 = s32[263]{0:T(512)} parameter(0) - %param_1.5361 = s32[8]{0:T(128)} parameter(1) - %reshape.3947 = s32[8]{0:T(128)} reshape(%param_1.5361), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1124 = s32[8]{0:T(128)} transpose(%reshape.3947), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %param_2.4530 = s32[8]{0:T(128)} parameter(2) - %reshape.3948 = s32[8]{0:T(128)} reshape(%param_2.4530) - %transpose.1125 = s32[8]{0:T(128)} transpose(%reshape.3948), dimensions={0} - ROOT %scatter-add.249 = s32[263]{0:T(512)} scatter(%param_0.4585, %transpose.1124, %transpose.1125), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_28.35.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.20.clone.clone.clone (param_0.4601: s32[263], param_1.5360: s32[8], param_2.4520: s32[8]) -> s32[263] { + %param_0.4601 = s32[263]{0:T(512)} parameter(0) + %param_1.5360 = s32[8]{0:T(128)} parameter(1) + %reshape.4075 = s32[8]{0:T(128)} reshape(%param_1.5360), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.881 = s32[8]{0:T(128)} transpose(%reshape.4075), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4520 = s32[8]{0:T(128)} parameter(2) + %reshape.4076 = s32[8]{0:T(128)} reshape(%param_2.4520) + %transpose.882 = s32[8]{0:T(128)} transpose(%reshape.4076), dimensions={0} + ROOT %scatter-add.249 = s32[263]{0:T(512)} scatter(%param_0.4601, %transpose.881, %transpose.882), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_27.34.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.38 (param_0.4586: s32[263], param_1.5362: s32[8], param_2.4531: s32[8]) -> s32[263] { - %param_0.4586 = s32[263]{0:T(512)} parameter(0) - %param_1.5362 = s32[8]{0:T(128)} parameter(1) - %param_2.4531 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.63 = s32[263]{0:T(512)} fusion(%param_0.4586, %param_1.5362, %param_2.4531), kind=kCustom, calls=%fused_computation.20.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.38 (param_0.4602: s32[263], param_1.5361: s32[8], param_2.4521: s32[8]) -> s32[263] { + %param_0.4602 = s32[263]{0:T(512)} parameter(0) + %param_1.5361 = s32[8]{0:T(128)} parameter(1) + %param_2.4521 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.63 = s32[263]{0:T(512)} fusion(%param_0.4602, %param_1.5361, %param_2.4521), kind=kCustom, calls=%fused_computation.20.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.38 (param_0.4587: s32[263], param_1.5363: s32[8], param_2.4532: s32[8]) -> s32[263] { - %param_0.4587 = s32[263]{0:T(512)} parameter(0) - %param_1.5363 = s32[8]{0:T(128)} parameter(1) - %param_2.4532 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.64.cloned.1 = s32[263]{0:T(512)} call(%param_0.4587, %param_1.5363, %param_2.4532), to_apply=%called_computation.38, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.38 (param_0.4603: s32[263], param_1.5362: s32[8], param_2.4522: s32[8]) -> s32[263] { + %param_0.4603 = s32[263]{0:T(512)} parameter(0) + %param_1.5362 = s32[8]{0:T(128)} parameter(1) + %param_2.4522 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.64.cloned.1 = s32[263]{0:T(512)} call(%param_0.4603, %param_1.5362, %param_2.4522), to_apply=%called_computation.38, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.12 (param_0.120: s32[263], param_1.172: s32[8], param_2.116: s32[8], param_3.3103: token[]) -> s32[263] { - %param_3.3103 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.12 (param_0.120: s32[263], param_1.172: s32[8], param_2.116: s32[8], param_3.3108: token[]) -> s32[263] { + %param_3.3108 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.120 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.172 = s32[8]{0:T(128)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.116 = s32[8]{0:T(128)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} - %copy.2098.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.120), async_execution_thread="sparsecore", calls=%async_computation.37 - %copy.2098.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2098.cloned.1.call-start) - %scatter_offload_custom_fusion.64.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[8]{0:T(128)}, s32[8]{0:T(128)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2098.cloned.1.call-done, %param_1.172, %param_2.116), async_execution_thread="sparsecore", calls=%async_computation.38, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %copy.2096.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.120), async_execution_thread="sparsecore", calls=%async_computation.37 + %copy.2096.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2096.cloned.1.call-start) + %scatter_offload_custom_fusion.64.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[8]{0:T(128)}, s32[8]{0:T(128)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2096.cloned.1.call-done, %param_1.172, %param_2.116), async_execution_thread="sparsecore", calls=%async_computation.38, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} ROOT %scatter_offload_custom_fusion.64.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.64.cloned.1.call-start), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.12 (param_0.121: s32[263], param_1.173: s32[8], param_2.117: s32[8], param_3.3102: token[]) -> s32[263] { - %param_3.3102 = token[] parameter(3) +%async_computation.12 (param_0.121: s32[263], param_1.173: s32[8], param_2.117: s32[8], param_3.3107: token[]) -> s32[263] { + %param_3.3107 = token[] parameter(3) %param_0.121 = s32[263]{0:T(512)} parameter(0) %param_1.173 = s32[8]{0:T(128)} parameter(1) %param_2.117 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.38.cloned.1 = s32[263]{0:T(512)} call(%param_0.121, %param_1.173, %param_2.117, %param_3.3102), to_apply=%called_computation.12, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.38.cloned.1 = s32[263]{0:T(512)} call(%param_0.121, %param_1.173, %param_2.117, %param_3.3107), to_apply=%called_computation.12, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%region_154.179 (reduce_sum.431: f32[], reduce_sum.254: f32[]) -> f32[] { - %reduce_sum.431 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.254 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.258 = f32[]{:T(128)} add(%reduce_sum.431, %reduce_sum.254), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_154.179 (reduce_sum.502: f32[], reduce_sum.336: f32[]) -> f32[] { + %reduce_sum.502 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.336 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.337 = f32[]{:T(128)} add(%reduce_sum.502, %reduce_sum.336), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.467 (param_0.4170: f32[3,1536,128,192]) -> f32[] { - %param_0.4170 = f32[3,1536,128,192]{2,3,0,1:T(8,128)} parameter(0) - %bitcast.672 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} bitcast(%param_0.4170), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} - %square.564 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%bitcast.672, %bitcast.672), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5105 = f32[]{:T(128)} constant(0) - ROOT %reduce.669 = f32[]{:T(128)} reduce(%square.564, %constant.5105), dimensions={0,1,2,3}, to_apply=%region_154.179, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%fused_computation.466 (param_0.4188: f32[3,1536,128,192]) -> f32[] { + %param_0.4188 = f32[3,1536,128,192]{2,3,0,1:T(8,128)} parameter(0) + %bitcast.654 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} bitcast(%param_0.4188), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.564 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%bitcast.654, %bitcast.654), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5049 = f32[]{:T(128)} constant(0) + ROOT %reduce.612 = f32[]{:T(128)} reduce(%square.564, %constant.5049), dimensions={0,1,2,3}, to_apply=%region_154.179, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%fused_computation.468 (param_0.1421: f32[1536,3,128,192]) -> bf16[3,1536,128,192] { - %param_0.1421 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(0) - %copy.1550 = bf16[1536,3,128,192]{2,0,3,1:T(8,128)(2,1)} copy(%param_0.1421), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'dense_layers\'][\'self_attention\'][\'wq_b\'][\'kernel\']"} - ROOT %bitcast.673 = bf16[3,1536,128,192]{2,1,3,0:T(8,128)(2,1)} bitcast(%copy.1550), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} +%fused_computation.467 (param_0.1439: f32[1536,3,128,192]) -> bf16[3,1536,128,192] { + %param_0.1439 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(0) + %copy.1548 = bf16[1536,3,128,192]{2,0,3,1:T(8,128)(2,1)} copy(%param_0.1439), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'dense_layers\'][\'self_attention\'][\'wq_b\'][\'kernel\']"} + ROOT %bitcast.655 = bf16[3,1536,128,192]{2,1,3,0:T(8,128)(2,1)} bitcast(%copy.1548), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} } -%region_221.246 (reduce_sum.893: f32[], reduce_sum.603: f32[]) -> f32[] { - %reduce_sum.893 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.603 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.604 = f32[]{:T(128)} add(%reduce_sum.893, %reduce_sum.603), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_221.246 (reduce_sum.964: f32[], reduce_sum.965: f32[]) -> f32[] { + %reduce_sum.964 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.965 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.534 = f32[]{:T(128)} add(%reduce_sum.964, %reduce_sum.965), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_187.212 (reduce_sum.655: f32[], reduce_sum.449: f32[]) -> f32[] { - %reduce_sum.655 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.449 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.450 = f32[]{:T(128)} add(%reduce_sum.655, %reduce_sum.449), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_187.212 (reduce_sum.726: f32[], reduce_sum.727: f32[]) -> f32[] { + %reduce_sum.726 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.727 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.452 = f32[]{:T(128)} add(%reduce_sum.726, %reduce_sum.727), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.469 (param_0.4140: f32[1536,3,128,192], param_1.5025: f32[], param_2.4298: f32[], param_3.2951: f32[], param_4.2203: f32[1536,3,128,192], param_5.2006: f32[], param_6.1443: f32[3,1536,128,192], param_7.1124: pred[], param_8.889: f32[1536,3,128,192]) -> (f32[], f32[1536,3,128,192], f32[1536,3,128,192], f32[1536,3,128,192], f32[]) { - %param_0.4140 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(0) - %param_3.2951 = f32[]{:T(128)S(6)} parameter(3) - %mul.4727.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_3.2951), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.468 (param_0.4158: f32[1536,3,128,192], param_1.5026: f32[], param_2.4288: f32[], param_3.2956: f32[], param_4.2203: f32[1536,3,128,192], param_5.2003: f32[], param_6.1444: f32[3,1536,128,192], param_7.1124: pred[], param_8.889: f32[1536,3,128,192]) -> (f32[], f32[1536,3,128,192], f32[1536,3,128,192], f32[1536,3,128,192], f32[]) { + %param_0.4158 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(0) + %param_3.2956 = f32[]{:T(128)S(6)} parameter(3) + %mul.5451.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_3.2956), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.1124 = pred[]{:T(512)S(6)} parameter(7) %select_n.2165.clone.1 = pred[1536,3,128,192]{2,3,1,0:T(8,128)(4,1)} broadcast(%param_7.1124), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.1443 = f32[3,1536,128,192]{2,3,0,1:T(8,128)} parameter(6) - %bitcast.1374.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} bitcast(%param_6.1443), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} - %param_5.2006 = f32[]{:T(128)} parameter(5) - %div.2575.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_5.2006), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2574.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%bitcast.1374.clone.1, %div.2575.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2164.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} select(%select_n.2165.clone.1, %bitcast.1374.clone.1, %div.2574.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4864.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4279.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4864.clone.1), dimensions={}, metadata={op_name="broadcast.334"} - %mul.4733.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%select_n.2164.clone.1, %broadcast.4279.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1444 = f32[3,1536,128,192]{2,3,0,1:T(8,128)} parameter(6) + %bitcast.1356.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} bitcast(%param_6.1444), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} + %param_5.2003 = f32[]{:T(128)} parameter(5) + %div.2573.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_5.2003), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2572.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%bitcast.1356.clone.1, %div.2573.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2164.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} select(%select_n.2165.clone.1, %bitcast.1356.clone.1, %div.2572.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4808.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4140.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4808.clone.1), dimensions={}, metadata={op_name="broadcast.334"} + %mul.5457.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%select_n.2164.clone.1, %broadcast.4140.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.889 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(8) - %constant.4868.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.4734.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4868.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4732.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_8.889, %mul.4734.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3443.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.4733.clone.1, %mul.4732.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.4298 = f32[]{:T(128)S(6)} parameter(2) - %div.2571.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_2.4298), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4812.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.5458.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4812.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5456.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_8.889, %mul.5458.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3401.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.5457.clone.1, %mul.5456.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4288 = f32[]{:T(128)S(6)} parameter(2) + %div.2569.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_2.4288), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.399.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%select_n.2164.clone.1, %select_n.2164.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4867.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.4731.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4867.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4729.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%integer_pow.399.clone.1, %mul.4731.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.4811.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.5455.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4811.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5453.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%integer_pow.399.clone.1, %mul.5455.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_4.2203 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(4) - %constant.4866.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.4730.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4866.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4728.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_4.2203, %mul.4730.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3442.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.4729.clone.1, %mul.4728.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.5025 = f32[]{:T(128)S(6)} parameter(1) - %div.2570.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_1.5025), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2569.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3442.clone.1, %div.2570.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.157.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} sqrt(%div.2569.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4865.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.3441.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4865.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.3440.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%sqrt.157.clone.1, %add.3441.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1293.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%div.2571.clone.1, %add.3440.clone.1), metadata={op_name="multiply.290"} - %div.2568.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3443.clone.1, %multiply.1293.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4726.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_0.4140, %broadcast.4279.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3439.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%div.2568.clone.1, %mul.4726.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4725.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%mul.4727.clone.1, %add.3439.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3438.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%param_0.4140, %mul.4725.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.565 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%add.3438.clone.1, %add.3438.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5075 = f32[]{:T(128)} constant(0) - %reduce.670 = f32[]{:T(128)} reduce(%square.565, %constant.5075), dimensions={0,1,2,3}, to_apply=%region_221.246, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.671.clone.1 = f32[]{:T(128)} reduce(%integer_pow.399.clone.1, %constant.5075), dimensions={0,1,2,3}, to_apply=%region_187.212, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.660 = (f32[]{:T(128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.670, %add.3438.clone.1, %add.3442.clone.1, %add.3443.clone.1, %reduce.671.clone.1) -} - -%region_160.185 (reduce_sum.473: f32[], reduce_sum.293: f32[]) -> f32[] { - %reduce_sum.473 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.293 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.300 = f32[]{:T(128)} add(%reduce_sum.473, %reduce_sum.293), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_158.183 (reduce_sum.459: f32[], reduce_sum.460: f32[]) -> f32[] { - %reduce_sum.459 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.460 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.461 = f32[]{:T(128)} add(%reduce_sum.459, %reduce_sum.460), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.495 (param_0.4166: bf16[256,512,512], param_1.5047: bf16[256,512,512]) -> (f32[], f32[]) { - %param_0.4166 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(0) - %broadcast_in_dim.1358 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4166), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.695 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1358), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} - %square.570 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.695, %bitcast.695), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5101 = f32[]{:T(128)} constant(0) - %reduce.672 = f32[]{:T(128)} reduce(%square.570, %constant.5101), dimensions={0,1,2,3}, to_apply=%region_160.185, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.5047 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(1) - %broadcast_in_dim.1366.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_1.5047), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.703.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1366.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} - %square.576.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.703.clone.1, %bitcast.703.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.674.clone.1 = f32[]{:T(128)} reduce(%square.576.clone.1, %constant.5101), dimensions={0,1,2,3}, to_apply=%region_158.183, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.767 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.672, %reduce.674.clone.1) -} - -%region_159.184 (reduce_sum.466: f32[], reduce_sum.279: f32[]) -> f32[] { - %reduce_sum.466 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.279 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.286 = f32[]{:T(128)} add(%reduce_sum.466, %reduce_sum.279), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.497 (param_0.4165: bf16[256,512,512]) -> f32[] { - %param_0.4165 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(0) - %broadcast_in_dim.1362 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4165), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.699 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1362), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} - %square.573 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.699, %bitcast.699), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5100 = f32[]{:T(128)} constant(0) - ROOT %reduce.673 = f32[]{:T(128)} reduce(%square.573, %constant.5100), dimensions={0,1,2,3}, to_apply=%region_159.184, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} -} - -%region_227.252 (reduce_sum.935: f32[], reduce_sum.631: f32[]) -> f32[] { - %reduce_sum.935 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.631 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.632 = f32[]{:T(128)} add(%reduce_sum.935, %reduce_sum.631), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_193.218 (reduce_sum.697: f32[], reduce_sum.471: f32[]) -> f32[] { - %reduce_sum.697 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.471 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.472 = f32[]{:T(128)} add(%reduce_sum.697, %reduce_sum.471), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.515 (param_0.4134: f32[], param_1.5019: f32[256,1,512,512], param_2.4292: f32[], param_3.2945: f32[256,1,512,512], param_4.2197: f32[], param_5.2000: bf16[256,512,512], param_6.1437: pred[], param_7.1118: f32[], param_8.883: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { + %constant.4810.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.5454.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4810.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5452.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_4.2203, %mul.5454.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3400.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.5453.clone.1, %mul.5452.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.5026 = f32[]{:T(128)S(6)} parameter(1) + %div.2568.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_1.5026), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2567.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3400.clone.1, %div.2568.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.157.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} sqrt(%div.2567.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4809.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3399.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4809.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3398.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%sqrt.157.clone.1, %add.3399.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1086.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%div.2569.clone.1, %add.3398.clone.1), metadata={op_name="multiply.263"} + %div.2566.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3401.clone.1, %multiply.1086.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5450.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_0.4158, %broadcast.4140.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3397.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%div.2566.clone.1, %mul.5450.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5449.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%mul.5451.clone.1, %add.3397.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3396.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%param_0.4158, %mul.5449.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.565 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%add.3396.clone.1, %add.3396.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5019 = f32[]{:T(128)} constant(0) + %reduce.613 = f32[]{:T(128)} reduce(%square.565, %constant.5019), dimensions={0,1,2,3}, to_apply=%region_221.246, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.614.clone.1 = f32[]{:T(128)} reduce(%integer_pow.399.clone.1, %constant.5019), dimensions={0,1,2,3}, to_apply=%region_187.212, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.660 = (f32[]{:T(128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.613, %add.3396.clone.1, %add.3400.clone.1, %add.3401.clone.1, %reduce.614.clone.1) +} + +%region_160.185 (reduce_sum.544: f32[], reduce_sum.364: f32[]) -> f32[] { + %reduce_sum.544 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.364 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.365 = f32[]{:T(128)} add(%reduce_sum.544, %reduce_sum.364), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_158.183 (reduce_sum.530: f32[], reduce_sum.352: f32[]) -> f32[] { + %reduce_sum.530 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.352 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.357 = f32[]{:T(128)} add(%reduce_sum.530, %reduce_sum.352), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.494 (param_0.4184: bf16[256,512,512], param_1.5048: bf16[256,512,512]) -> (f32[], f32[]) { + %param_0.4184 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(0) + %broadcast_in_dim.1383 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4184), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.677 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1383), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.570 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.677, %bitcast.677), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5045 = f32[]{:T(128)} constant(0) + %reduce.615 = f32[]{:T(128)} reduce(%square.570, %constant.5045), dimensions={0,1,2,3}, to_apply=%region_160.185, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.5048 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(1) + %broadcast_in_dim.1391.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_1.5048), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.685.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1391.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.576.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.685.clone.1, %bitcast.685.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.617.clone.1 = f32[]{:T(128)} reduce(%square.576.clone.1, %constant.5045), dimensions={0,1,2,3}, to_apply=%region_158.183, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.767 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.615, %reduce.617.clone.1) +} + +%region_159.184 (reduce_sum.537: f32[], reduce_sum.358: f32[]) -> f32[] { + %reduce_sum.537 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.358 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.359 = f32[]{:T(128)} add(%reduce_sum.537, %reduce_sum.358), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.496 (param_0.4183: bf16[256,512,512]) -> f32[] { + %param_0.4183 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(0) + %broadcast_in_dim.1387 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4183), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.681 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1387), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.573 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.681, %bitcast.681), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5044 = f32[]{:T(128)} constant(0) + ROOT %reduce.616 = f32[]{:T(128)} reduce(%square.573, %constant.5044), dimensions={0,1,2,3}, to_apply=%region_159.184, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%region_227.252 (reduce_sum.1006: f32[], reduce_sum.1007: f32[]) -> f32[] { + %reduce_sum.1006 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.1007 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.548 = f32[]{:T(128)} add(%reduce_sum.1006, %reduce_sum.1007), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_193.218 (reduce_sum.768: f32[], reduce_sum.769: f32[]) -> f32[] { + %reduce_sum.768 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.769 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.466 = f32[]{:T(128)} add(%reduce_sum.768, %reduce_sum.769), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.514 (param_0.4152: f32[], param_1.5020: f32[256,1,512,512], param_2.4282: f32[], param_3.2950: f32[256,1,512,512], param_4.2197: f32[], param_5.1997: bf16[256,512,512], param_6.1438: pred[], param_7.1118: f32[], param_8.883: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { %param_8.883 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) - %bitcast.1359.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.883), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} + %bitcast.1341.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.883), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} %param_7.1118 = f32[]{:T(128)S(6)} parameter(7) - %mul.4676.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1118), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_6.1437 = pred[]{:T(512)S(6)} parameter(6) - %select_n.2147.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1437), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_5.2000 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) - %broadcast_in_dim.1572.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.2000), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.1361.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1572.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %mul.5400.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1118), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1438 = pred[]{:T(512)S(6)} parameter(6) + %select_n.2147.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1438), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_5.1997 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) + %broadcast_in_dim.1597.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.1997), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.1343.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1597.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} %param_4.2197 = f32[]{:T(128)} parameter(4) - %div.2533.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2197), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2532.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1361.clone.1, %div.2533.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2146.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2147.clone.1, %bitcast.1361.clone.1, %div.2532.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4834.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4259.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4834.clone.1), dimensions={}, metadata={op_name="broadcast.2345"} - %mul.4678.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2146.clone.1, %broadcast.4259.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_3.2945 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) - %bitcast.1360.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2945), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} - %constant.4833.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.4258.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4833.clone.1), dimensions={}, metadata={op_name="broadcast.329"} - %mul.4677.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1360.clone.1, %broadcast.4258.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3408.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4678.clone.1, %mul.4677.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.4292 = f32[]{:T(128)S(6)} parameter(2) - %div.2531.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4292), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2531.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2197), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2530.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1343.clone.1, %div.2531.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2146.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2147.clone.1, %bitcast.1343.clone.1, %div.2530.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4778.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4120.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4778.clone.1), dimensions={}, metadata={op_name="broadcast.2222"} + %mul.5402.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2146.clone.1, %broadcast.4120.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_3.2950 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) + %bitcast.1342.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2950), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} + %constant.4777.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.4119.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4777.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.5401.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1342.clone.1, %broadcast.4119.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3366.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5402.clone.1, %mul.5401.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4282 = f32[]{:T(128)S(6)} parameter(2) + %div.2529.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4282), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.393.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2146.clone.1, %select_n.2146.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4832.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.4261.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4832.clone.1), dimensions={}, metadata={op_name="broadcast.2348"} - %mul.4680.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.393.clone.1, %broadcast.4261.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_1.5019 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) - %bitcast.1362.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5019), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} - %constant.4831.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.4260.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4831.clone.1), dimensions={}, metadata={op_name="broadcast.312"} - %mul.4679.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1362.clone.1, %broadcast.4260.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3409.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4680.clone.1, %mul.4679.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_0.4134 = f32[]{:T(128)S(6)} parameter(0) - %div.2530.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4134), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2529.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3409.clone.1, %div.2530.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.151.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2529.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4835.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.4257.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4835.clone.1), dimensions={}, metadata={op_name="broadcast.305"} - %add.3407.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.151.clone.1, %broadcast.4257.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1287.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2531.clone.1, %add.3407.clone.1), metadata={op_name="multiply.296"} - %div.2528.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3408.clone.1, %multiply.1287.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4675.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1359.clone.1, %broadcast.4259.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3406.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2528.clone.1, %mul.4675.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4674.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4676.clone.1, %add.3406.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3405.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1359.clone.1, %mul.4674.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.577 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3405.clone.1, %add.3405.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5069 = f32[]{:T(128)} constant(0) - %reduce.675 = f32[]{:T(128)} reduce(%square.577, %constant.5069), dimensions={0,1,2,3}, to_apply=%region_227.252, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %bitcast.849.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3409.clone.1) - %bitcast.822.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3408.clone.1) - %reduce.684.clone.1 = f32[]{:T(128)} reduce(%integer_pow.393.clone.1, %constant.5069), dimensions={0,1,2,3}, to_apply=%region_193.218, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.670 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.675, %add.3405.clone.1, %bitcast.849.clone.1, %bitcast.822.clone.1, %reduce.684.clone.1) -} - -%region_226.251 (reduce_sum.928: f32[], reduce_sum.625: f32[]) -> f32[] { - %reduce_sum.928 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.625 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.626 = f32[]{:T(128)} add(%reduce_sum.928, %reduce_sum.625), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_192.217 (reduce_sum.690: f32[], reduce_sum.465: f32[]) -> f32[] { - %reduce_sum.690 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.465 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.470 = f32[]{:T(128)} add(%reduce_sum.690, %reduce_sum.465), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.516 (param_0.4135: f32[], param_1.5020: f32[256,1,512,512], param_2.4293: f32[], param_3.2946: f32[256,1,512,512], param_4.2198: f32[], param_5.2001: bf16[256,512,512], param_6.1438: pred[], param_7.1119: f32[], param_8.884: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { + %constant.4776.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.4122.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4776.clone.1), dimensions={}, metadata={op_name="broadcast.2225"} + %mul.5404.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.393.clone.1, %broadcast.4122.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_1.5020 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) + %bitcast.1344.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5020), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} + %constant.4775.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.4121.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4775.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.5403.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1344.clone.1, %broadcast.4121.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3367.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5404.clone.1, %mul.5403.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_0.4152 = f32[]{:T(128)S(6)} parameter(0) + %div.2528.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4152), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2527.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3367.clone.1, %div.2528.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.151.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2527.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4779.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.4118.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4779.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3365.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.151.clone.1, %broadcast.4118.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1080.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2529.clone.1, %add.3365.clone.1), metadata={op_name="multiply.269"} + %div.2526.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3366.clone.1, %multiply.1080.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5399.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1341.clone.1, %broadcast.4120.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3364.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2526.clone.1, %mul.5399.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5398.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.5400.clone.1, %add.3364.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3363.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1341.clone.1, %mul.5398.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.577 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3363.clone.1, %add.3363.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5013 = f32[]{:T(128)} constant(0) + %reduce.618 = f32[]{:T(128)} reduce(%square.577, %constant.5013), dimensions={0,1,2,3}, to_apply=%region_227.252, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %bitcast.831.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3367.clone.1) + %bitcast.804.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3366.clone.1) + %reduce.627.clone.1 = f32[]{:T(128)} reduce(%integer_pow.393.clone.1, %constant.5013), dimensions={0,1,2,3}, to_apply=%region_193.218, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.670 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.618, %add.3363.clone.1, %bitcast.831.clone.1, %bitcast.804.clone.1, %reduce.627.clone.1) +} + +%region_226.251 (reduce_sum.999: f32[], reduce_sum.1000: f32[]) -> f32[] { + %reduce_sum.999 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.1000 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.543 = f32[]{:T(128)} add(%reduce_sum.999, %reduce_sum.1000), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_192.217 (reduce_sum.761: f32[], reduce_sum.762: f32[]) -> f32[] { + %reduce_sum.761 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.762 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.465 = f32[]{:T(128)} add(%reduce_sum.761, %reduce_sum.762), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.515 (param_0.4153: f32[], param_1.5021: f32[256,1,512,512], param_2.4283: f32[], param_3.2951: f32[256,1,512,512], param_4.2198: f32[], param_5.1998: bf16[256,512,512], param_6.1439: pred[], param_7.1119: f32[], param_8.884: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { %param_8.884 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) - %bitcast.1363.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.884), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} + %bitcast.1345.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.884), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} %param_7.1119 = f32[]{:T(128)S(6)} parameter(7) - %mul.4683.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1119), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_6.1438 = pred[]{:T(512)S(6)} parameter(6) - %select_n.2149.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1438), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_5.2001 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) - %broadcast_in_dim.1573.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.2001), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.1365.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1573.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %mul.5407.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1119), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1439 = pred[]{:T(512)S(6)} parameter(6) + %select_n.2149.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1439), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_5.1998 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) + %broadcast_in_dim.1598.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.1998), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.1347.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1598.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} %param_4.2198 = f32[]{:T(128)} parameter(4) - %div.2539.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2198), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2538.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1365.clone.1, %div.2539.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2148.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2149.clone.1, %bitcast.1365.clone.1, %div.2538.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4839.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4264.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4839.clone.1), dimensions={}, metadata={op_name="broadcast.2345"} - %mul.4685.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2148.clone.1, %broadcast.4264.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_3.2946 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) - %bitcast.1364.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2946), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} - %constant.4838.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.4263.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4838.clone.1), dimensions={}, metadata={op_name="broadcast.329"} - %mul.4684.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1364.clone.1, %broadcast.4263.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3413.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4685.clone.1, %mul.4684.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.4293 = f32[]{:T(128)S(6)} parameter(2) - %div.2537.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4293), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2537.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2198), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2536.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1347.clone.1, %div.2537.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2148.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2149.clone.1, %bitcast.1347.clone.1, %div.2536.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4783.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4125.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4783.clone.1), dimensions={}, metadata={op_name="broadcast.2222"} + %mul.5409.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2148.clone.1, %broadcast.4125.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_3.2951 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) + %bitcast.1346.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2951), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} + %constant.4782.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.4124.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4782.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.5408.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1346.clone.1, %broadcast.4124.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3371.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5409.clone.1, %mul.5408.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4283 = f32[]{:T(128)S(6)} parameter(2) + %div.2535.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4283), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.394.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2148.clone.1, %select_n.2148.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4837.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.4266.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4837.clone.1), dimensions={}, metadata={op_name="broadcast.2348"} - %mul.4687.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.394.clone.1, %broadcast.4266.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_1.5020 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) - %bitcast.1366.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5020), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} - %constant.4836.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.4265.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4836.clone.1), dimensions={}, metadata={op_name="broadcast.312"} - %mul.4686.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1366.clone.1, %broadcast.4265.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3414.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4687.clone.1, %mul.4686.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_0.4135 = f32[]{:T(128)S(6)} parameter(0) - %div.2536.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4135), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2535.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3414.clone.1, %div.2536.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.152.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2535.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4840.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.4262.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4840.clone.1), dimensions={}, metadata={op_name="broadcast.305"} - %add.3412.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.152.clone.1, %broadcast.4262.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1288.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2537.clone.1, %add.3412.clone.1), metadata={op_name="multiply.295"} - %div.2534.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3413.clone.1, %multiply.1288.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4682.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1363.clone.1, %broadcast.4264.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3411.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2534.clone.1, %mul.4682.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4681.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4683.clone.1, %add.3411.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3410.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1363.clone.1, %mul.4681.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.578 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3410.clone.1, %add.3410.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5070 = f32[]{:T(128)} constant(0) - %reduce.676 = f32[]{:T(128)} reduce(%square.578, %constant.5070), dimensions={0,1,2,3}, to_apply=%region_226.251, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %bitcast.840.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3414.clone.1) - %bitcast.813.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3413.clone.1) - %reduce.685.clone.1 = f32[]{:T(128)} reduce(%integer_pow.394.clone.1, %constant.5070), dimensions={0,1,2,3}, to_apply=%region_192.217, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.669 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.676, %add.3410.clone.1, %bitcast.840.clone.1, %bitcast.813.clone.1, %reduce.685.clone.1) -} - -%region_225.250 (reduce_sum.921: f32[], reduce_sum.619: f32[]) -> f32[] { - %reduce_sum.921 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.619 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.624 = f32[]{:T(128)} add(%reduce_sum.921, %reduce_sum.619), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_191.216 (reduce_sum.683: f32[], reduce_sum.463: f32[]) -> f32[] { - %reduce_sum.683 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.463 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.464 = f32[]{:T(128)} add(%reduce_sum.683, %reduce_sum.463), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.517 (param_0.4136: f32[], param_1.5021: f32[256,1,512,512], param_2.4294: f32[], param_3.2947: f32[256,1,512,512], param_4.2199: f32[], param_5.2002: bf16[256,512,512], param_6.1439: pred[], param_7.1120: f32[], param_8.885: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { + %constant.4781.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.4127.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4781.clone.1), dimensions={}, metadata={op_name="broadcast.2225"} + %mul.5411.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.394.clone.1, %broadcast.4127.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_1.5021 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) + %bitcast.1348.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5021), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} + %constant.4780.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.4126.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4780.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.5410.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1348.clone.1, %broadcast.4126.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3372.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5411.clone.1, %mul.5410.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_0.4153 = f32[]{:T(128)S(6)} parameter(0) + %div.2534.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4153), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2533.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3372.clone.1, %div.2534.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.152.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2533.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4784.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.4123.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4784.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3370.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.152.clone.1, %broadcast.4123.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1081.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2535.clone.1, %add.3370.clone.1), metadata={op_name="multiply.268"} + %div.2532.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3371.clone.1, %multiply.1081.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5406.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1345.clone.1, %broadcast.4125.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3369.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2532.clone.1, %mul.5406.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5405.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.5407.clone.1, %add.3369.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3368.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1345.clone.1, %mul.5405.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.578 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3368.clone.1, %add.3368.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5014 = f32[]{:T(128)} constant(0) + %reduce.619 = f32[]{:T(128)} reduce(%square.578, %constant.5014), dimensions={0,1,2,3}, to_apply=%region_226.251, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %bitcast.822.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3372.clone.1) + %bitcast.795.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3371.clone.1) + %reduce.628.clone.1 = f32[]{:T(128)} reduce(%integer_pow.394.clone.1, %constant.5014), dimensions={0,1,2,3}, to_apply=%region_192.217, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.669 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.619, %add.3368.clone.1, %bitcast.822.clone.1, %bitcast.795.clone.1, %reduce.628.clone.1) +} + +%region_225.250 (reduce_sum.992: f32[], reduce_sum.993: f32[]) -> f32[] { + %reduce_sum.992 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.993 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.542 = f32[]{:T(128)} add(%reduce_sum.992, %reduce_sum.993), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_191.216 (reduce_sum.754: f32[], reduce_sum.755: f32[]) -> f32[] { + %reduce_sum.754 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.755 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.464 = f32[]{:T(128)} add(%reduce_sum.754, %reduce_sum.755), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.516 (param_0.4154: f32[], param_1.5022: f32[256,1,512,512], param_2.4284: f32[], param_3.2952: f32[256,1,512,512], param_4.2199: f32[], param_5.1999: bf16[256,512,512], param_6.1440: pred[], param_7.1120: f32[], param_8.885: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { %param_8.885 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) - %bitcast.1367.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.885), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} + %bitcast.1349.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.885), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} %param_7.1120 = f32[]{:T(128)S(6)} parameter(7) - %mul.4690.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1120), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_6.1439 = pred[]{:T(512)S(6)} parameter(6) - %select_n.2151.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1439), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_5.2002 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) - %broadcast_in_dim.1574.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.2002), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.1369.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1574.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %mul.5414.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1120), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1440 = pred[]{:T(512)S(6)} parameter(6) + %select_n.2151.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1440), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_5.1999 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) + %broadcast_in_dim.1599.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.1999), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.1351.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1599.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} %param_4.2199 = f32[]{:T(128)} parameter(4) - %div.2545.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2199), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2544.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1369.clone.1, %div.2545.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2150.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2151.clone.1, %bitcast.1369.clone.1, %div.2544.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4844.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4269.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4844.clone.1), dimensions={}, metadata={op_name="broadcast.2345"} - %mul.4692.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2150.clone.1, %broadcast.4269.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_3.2947 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) - %bitcast.1368.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2947), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} - %constant.4843.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.4268.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4843.clone.1), dimensions={}, metadata={op_name="broadcast.329"} - %mul.4691.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1368.clone.1, %broadcast.4268.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3418.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4692.clone.1, %mul.4691.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.4294 = f32[]{:T(128)S(6)} parameter(2) - %div.2543.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4294), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2543.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2199), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2542.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1351.clone.1, %div.2543.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2150.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2151.clone.1, %bitcast.1351.clone.1, %div.2542.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4788.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4130.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4788.clone.1), dimensions={}, metadata={op_name="broadcast.2222"} + %mul.5416.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2150.clone.1, %broadcast.4130.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_3.2952 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) + %bitcast.1350.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2952), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} + %constant.4787.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.4129.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4787.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.5415.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1350.clone.1, %broadcast.4129.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3376.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5416.clone.1, %mul.5415.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4284 = f32[]{:T(128)S(6)} parameter(2) + %div.2541.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4284), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.395.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2150.clone.1, %select_n.2150.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4842.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.4271.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4842.clone.1), dimensions={}, metadata={op_name="broadcast.2348"} - %mul.4694.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.395.clone.1, %broadcast.4271.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_1.5021 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) - %bitcast.1370.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5021), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} - %constant.4841.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.4270.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4841.clone.1), dimensions={}, metadata={op_name="broadcast.312"} - %mul.4693.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1370.clone.1, %broadcast.4270.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3419.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4694.clone.1, %mul.4693.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_0.4136 = f32[]{:T(128)S(6)} parameter(0) - %div.2542.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4136), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2541.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3419.clone.1, %div.2542.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.153.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2541.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4845.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.4267.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4845.clone.1), dimensions={}, metadata={op_name="broadcast.305"} - %add.3417.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.153.clone.1, %broadcast.4267.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1289.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2543.clone.1, %add.3417.clone.1), metadata={op_name="multiply.294"} - %div.2540.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3418.clone.1, %multiply.1289.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4689.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1367.clone.1, %broadcast.4269.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3416.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2540.clone.1, %mul.4689.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4688.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4690.clone.1, %add.3416.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3415.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1367.clone.1, %mul.4688.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.579 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3415.clone.1, %add.3415.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5071 = f32[]{:T(128)} constant(0) - %reduce.677 = f32[]{:T(128)} reduce(%square.579, %constant.5071), dimensions={0,1,2,3}, to_apply=%region_225.250, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %bitcast.831.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3419.clone.1) - %bitcast.804.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3418.clone.1) - %reduce.686.clone.1 = f32[]{:T(128)} reduce(%integer_pow.395.clone.1, %constant.5071), dimensions={0,1,2,3}, to_apply=%region_191.216, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.668 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.677, %add.3415.clone.1, %bitcast.831.clone.1, %bitcast.804.clone.1, %reduce.686.clone.1) -} - -%region_155.180 (reduce_sum.438: f32[], reduce_sum.259: f32[]) -> f32[] { - %reduce_sum.438 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.259 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.260 = f32[]{:T(128)} add(%reduce_sum.438, %reduce_sum.259), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.529.clone.clone.clone (param_0.4079: bf16[4,128,129280], param_1.4953: s32[4,128], param_2.4225: f32[4,128], param_3.2913: f32[4,128], param_4.2170: bf16[4,128], param_5.1978: f32[4,128]) -> bf16[4,128,129280] { - %param_5.1978 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %mul.4903 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_5.1978), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_3.2913 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %mul.4902 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.2913), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_0.4079 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.3163 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4079), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %constant.4786.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.4132.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4786.clone.1), dimensions={}, metadata={op_name="broadcast.2225"} + %mul.5418.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.395.clone.1, %broadcast.4132.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_1.5022 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) + %bitcast.1352.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5022), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} + %constant.4785.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.4131.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4785.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.5417.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1352.clone.1, %broadcast.4131.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3377.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5418.clone.1, %mul.5417.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_0.4154 = f32[]{:T(128)S(6)} parameter(0) + %div.2540.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4154), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2539.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3377.clone.1, %div.2540.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.153.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2539.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4789.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.4128.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4789.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3375.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.153.clone.1, %broadcast.4128.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1082.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2541.clone.1, %add.3375.clone.1), metadata={op_name="multiply.267"} + %div.2538.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3376.clone.1, %multiply.1082.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5413.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1349.clone.1, %broadcast.4130.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3374.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2538.clone.1, %mul.5413.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5412.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.5414.clone.1, %add.3374.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3373.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1349.clone.1, %mul.5412.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.579 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3373.clone.1, %add.3373.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5015 = f32[]{:T(128)} constant(0) + %reduce.620 = f32[]{:T(128)} reduce(%square.579, %constant.5015), dimensions={0,1,2,3}, to_apply=%region_225.250, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %bitcast.813.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3377.clone.1) + %bitcast.786.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3376.clone.1) + %reduce.629.clone.1 = f32[]{:T(128)} reduce(%integer_pow.395.clone.1, %constant.5015), dimensions={0,1,2,3}, to_apply=%region_191.216, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.668 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.620, %add.3373.clone.1, %bitcast.813.clone.1, %bitcast.786.clone.1, %reduce.629.clone.1) +} + +%region_155.180 (reduce_sum.509: f32[], reduce_sum.338: f32[]) -> f32[] { + %reduce_sum.509 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.338 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.343 = f32[]{:T(128)} add(%reduce_sum.509, %reduce_sum.338), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.528.clone.clone.clone (param_0.4097: bf16[4,128,129280], param_1.4954: s32[4,128], param_2.4215: f32[4,128], param_3.2918: f32[4,128], param_4.2170: bf16[4,128], param_5.1975: f32[4,128]) -> bf16[4,128,129280] { + %param_5.1975 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.5651 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_5.1975), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.2918 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.5650 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.2918), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.4097 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.3111 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4097), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} %param_4.2170 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) %sub.804 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_4.2170), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.803 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.3163, %sub.804), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.803 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.3111, %sub.804), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %exp.534 = f32[4,128,129280]{2,1,0:T(8,128)} exponential(%sub.803), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %mul.4901 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.4902, %exp.534), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_2.4225 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %div.2698 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.4225), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %div.2697 = f32[4,128,129280]{2,1,0:T(8,128)} divide(%mul.4901, %div.2698), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %param_1.4953 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %eq.371 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.4953), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %mul.5649 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.5650, %exp.534), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.4215 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.2696 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.4215), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.2695 = f32[4,128,129280]{2,1,0:T(8,128)} divide(%mul.5649, %div.2696), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.4954 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.371 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.4954), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.370 = s32[4,128,129280]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.369 = pred[4,128,129280]{2,1,0:T(8,128)(4,1)} compare(%eq.371, %eq.370), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %convert_element_type.3162 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%eq.369), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} - %sub.802 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%div.2697, %convert_element_type.3162), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} - %mul.4900 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.4903, %sub.802), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - ROOT %convert_element_type.3161 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} convert(%mul.4900), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} -} - -%fused_computation.939.clone.clone (param_0.4080: f32[4,128], param_1.4954: bf16[4,128,512], param_2.4227: bf16[512]) -> bf16[4,128,512] { - %param_2.4227 = bf16[512]{0:T(512)(128)(2,1)S(1)} parameter(2) - %dot_general.831 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.4227), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.4954 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.3165 = f32[4,128,512]{2,1,0:T(8,128)} convert(%param_1.4954), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_0.4080 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.4905 = f32[4,128,512]{2,1,0:T(8,128)} broadcast(%param_0.4080), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.4904 = f32[4,128,512]{2,1,0:T(8,128)} multiply(%convert_element_type.3165, %mul.4905), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.3164 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} convert(%mul.4904), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - ROOT %dot_general.830 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.831, %convert_element_type.3164), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.518 (param_0.4169: bf16[4,128,129280], param_1.5049: s32[4,128], param_2.4319: f32[4,128], param_3.2969: f32[4,128], param_4.2219: bf16[4,128], param_5.2020: f32[4,128], param_6.1457: f32[4,128], param_7.1138: bf16[4,128,512], param_8.902: bf16[512]) -> (f32[], bf16[512,129280,1]) { - %param_6.1457 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %convert_element_type.3110 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%eq.369), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.802 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%div.2695, %convert_element_type.3110), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.5648 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.5651, %sub.802), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.3109 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} convert(%mul.5648), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} +} + +%fused_computation.935.clone.clone (param_0.4098: f32[4,128], param_1.4955: bf16[4,128,512], param_2.4217: bf16[512]) -> bf16[4,128,512] { + %param_1.4955 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.3113 = f32[4,128,512]{2,1,0:T(8,128)} convert(%param_1.4955), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.4098 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.5654 = f32[4,128,512]{2,1,0:T(8,128)} broadcast(%param_0.4098), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.5653 = f32[4,128,512]{2,1,0:T(8,128)} multiply(%convert_element_type.3113, %mul.5654), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.3112 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} convert(%mul.5653), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.4217 = bf16[512]{0:T(512)(128)(2,1)S(1)} parameter(2) + %mul.5655 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.4217), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.5652 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.3112, %mul.5655), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%fused_computation.517 (param_0.4187: bf16[4,128,129280], param_1.5050: s32[4,128], param_2.4309: f32[4,128], param_3.2974: f32[4,128], param_4.2219: bf16[4,128], param_5.2017: f32[4,128], param_6.1458: f32[4,128], param_7.1138: bf16[4,128,512], param_8.902: bf16[512]) -> (f32[], bf16[512,129280,1]) { + %param_6.1458 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) %param_7.1138 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)S(1)} parameter(7) %param_8.902 = bf16[512]{0:T(512)(128)(2,1)S(1)} parameter(8) - %fusion.577.clone.1 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} fusion(%param_6.1457, %param_7.1138, %param_8.902), kind=kLoop, calls=%fused_computation.939.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.4169 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.5049 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.4319 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.2969 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %fusion.574.clone.1 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} fusion(%param_6.1458, %param_7.1138, %param_8.902), kind=kLoop, calls=%fused_computation.935.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.4187 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.5050 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.4309 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.2974 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) %param_4.2219 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) - %param_5.2020 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %multiply_convert_fusion.1.clone.1 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} fusion(%param_0.4169, %param_1.5049, %param_2.4319, %param_3.2969, %param_4.2219, /*index=5*/%param_5.2020), kind=kLoop, calls=%fused_computation.529.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %convolution.141.clone.1 = bf16[512,129280,1]{1,0,2:T(8,128)(2,1)} convolution(%fusion.577.clone.1, %multiply_convert_fusion.1.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %bitcast.776 = bf16[512,129280]{1,0:T(8,128)(2,1)} bitcast(%convolution.141.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %convert_element_type.2665 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.776), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} - %square.581 = f32[512,129280]{1,0:T(8,128)} multiply(%convert_element_type.2665, %convert_element_type.2665), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5104 = f32[]{:T(128)} constant(0) - %reduce.678 = f32[]{:T(128)} reduce(%square.581, %constant.5104), dimensions={0,1}, to_apply=%region_155.180, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.757 = (f32[]{:T(128)}, bf16[512,129280,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.678, %convolution.141.clone.1) + %param_5.2017 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %multiply_convert_fusion.1.clone.1 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} fusion(%param_0.4187, %param_1.5050, %param_2.4309, %param_3.2974, %param_4.2219, /*index=5*/%param_5.2017), kind=kLoop, calls=%fused_computation.528.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %convolution.141.clone.1 = bf16[512,129280,1]{1,0,2:T(8,128)(2,1)} convolution(%fusion.574.clone.1, %multiply_convert_fusion.1.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %bitcast.758 = bf16[512,129280]{1,0:T(8,128)(2,1)} bitcast(%convolution.141.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %convert_element_type.2618 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.758), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %square.581 = f32[512,129280]{1,0:T(8,128)} multiply(%convert_element_type.2618, %convert_element_type.2618), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5048 = f32[]{:T(128)} constant(0) + %reduce.621 = f32[]{:T(128)} reduce(%square.581, %constant.5048), dimensions={0,1}, to_apply=%region_155.180, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.757 = (f32[]{:T(128)}, bf16[512,129280,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.621, %convolution.141.clone.1) } -%region_174.199 (reduce_sum.564: f32[], reduce_sum.387: f32[]) -> f32[] { - %reduce_sum.564 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.387 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.388 = f32[]{:T(128)} add(%reduce_sum.564, %reduce_sum.387), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_174.199 (reduce_sum.635: f32[], reduce_sum.636: f32[]) -> f32[] { + %reduce_sum.635 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.636 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.424 = f32[]{:T(128)} add(%reduce_sum.635, %reduce_sum.636), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.519 (param_0.4153: bf16[129280,512]) -> f32[] { - %param_0.4153 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.2667 = f32[129280,512]{1,0:T(8,128)} convert(%param_0.4153), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %square.583 = f32[129280,512]{1,0:T(8,128)} multiply(%convert_element_type.2667, %convert_element_type.2667), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5088 = f32[]{:T(128)} constant(0) - ROOT %reduce.679 = f32[]{:T(128)} reduce(%square.583, %constant.5088), dimensions={0,1}, to_apply=%region_174.199, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%fused_computation.518 (param_0.4171: bf16[129280,512]) -> f32[] { + %param_0.4171 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.2620 = f32[129280,512]{1,0:T(8,128)} convert(%param_0.4171), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %square.583 = f32[129280,512]{1,0:T(8,128)} multiply(%convert_element_type.2620, %convert_element_type.2620), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5032 = f32[]{:T(128)} constant(0) + ROOT %reduce.622 = f32[]{:T(128)} reduce(%square.583, %constant.5032), dimensions={0,1}, to_apply=%region_174.199, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%region_240.265 (reduce_sum.1026: f32[], reduce_sum.689: f32[]) -> f32[] { - %reduce_sum.1026 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.689 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.694 = f32[]{:T(128)} add(%reduce_sum.1026, %reduce_sum.689), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_240.265 (reduce_sum.1097: f32[], reduce_sum.1098: f32[]) -> f32[] { + %reduce_sum.1097 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.1098 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.577 = f32[]{:T(128)} add(%reduce_sum.1097, %reduce_sum.1098), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_206.231 (reduce_sum.788: f32[], reduce_sum.533: f32[]) -> f32[] { - %reduce_sum.788 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.533 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.534 = f32[]{:T(128)} add(%reduce_sum.788, %reduce_sum.533), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_206.231 (reduce_sum.859: f32[], reduce_sum.860: f32[]) -> f32[] { + %reduce_sum.859 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.860 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.499 = f32[]{:T(128)} add(%reduce_sum.859, %reduce_sum.860), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.520 (param_0.4121: f32[129280,512], param_1.5006: f32[], param_2.4279: f32[], param_3.2932: f32[], param_4.2184: f32[129280,512], param_5.1987: f32[], param_6.1424: bf16[129280,512], param_7.1105: pred[], param_8.870: f32[129280,512]) -> (f32[], f32[129280,512], f32[129280,512], f32[129280,512], f32[]) { - %param_0.4121 = f32[129280,512]{1,0:T(8,128)} parameter(0) - %param_3.2932 = f32[]{:T(128)S(6)} parameter(3) - %mul.4564.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_3.2932), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.519 (param_0.4139: f32[129280,512], param_1.5007: f32[], param_2.4269: f32[], param_3.2937: f32[], param_4.2184: f32[129280,512], param_5.1984: f32[], param_6.1425: bf16[129280,512], param_7.1105: pred[], param_8.870: f32[129280,512]) -> (f32[], f32[129280,512], f32[129280,512], f32[129280,512], f32[]) { + %param_0.4139 = f32[129280,512]{1,0:T(8,128)} parameter(0) + %param_3.2937 = f32[]{:T(128)S(6)} parameter(3) + %mul.5288.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_3.2937), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.1105 = pred[]{:T(512)S(6)} parameter(7) %select_n.2105.clone.1 = pred[129280,512]{1,0:T(8,128)(4,1)} broadcast(%param_7.1105), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.1424 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(6) - %convert_element_type.3106.clone.1 = f32[129280,512]{1,0:T(8,128)} convert(%param_6.1424), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %param_5.1987 = f32[]{:T(128)} parameter(5) - %div.2439.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_5.1987), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2438.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%convert_element_type.3106.clone.1, %div.2439.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2104.clone.1 = f32[129280,512]{1,0:T(8,128)} select(%select_n.2105.clone.1, %convert_element_type.3106.clone.1, %div.2438.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4754.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4209.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4754.clone.1), dimensions={}, metadata={op_name="broadcast.318"} - %mul.4570.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%select_n.2104.clone.1, %broadcast.4209.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1425 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(6) + %convert_element_type.3054.clone.1 = f32[129280,512]{1,0:T(8,128)} convert(%param_6.1425), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %param_5.1984 = f32[]{:T(128)} parameter(5) + %div.2437.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_5.1984), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2436.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%convert_element_type.3054.clone.1, %div.2437.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2104.clone.1 = f32[129280,512]{1,0:T(8,128)} select(%select_n.2105.clone.1, %convert_element_type.3054.clone.1, %div.2436.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4698.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4070.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4698.clone.1), dimensions={}, metadata={op_name="broadcast.318"} + %mul.5294.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%select_n.2104.clone.1, %broadcast.4070.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.870 = f32[129280,512]{1,0:T(8,128)} parameter(8) - %constant.4758.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.4571.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4758.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4569.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_8.870, %mul.4571.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3338.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.4570.clone.1, %mul.4569.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.4279 = f32[]{:T(128)S(6)} parameter(2) - %div.2435.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_2.4279), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4702.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.5295.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4702.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5293.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_8.870, %mul.5295.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3296.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.5294.clone.1, %mul.5293.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4269 = f32[]{:T(128)S(6)} parameter(2) + %div.2433.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_2.4269), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.380.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%select_n.2104.clone.1, %select_n.2104.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4757.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.4568.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4757.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4566.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%integer_pow.380.clone.1, %mul.4568.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.4701.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.5292.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4701.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5290.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%integer_pow.380.clone.1, %mul.5292.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_4.2184 = f32[129280,512]{1,0:T(8,128)} parameter(4) - %constant.4756.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.4567.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4756.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4565.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_4.2184, %mul.4567.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3337.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.4566.clone.1, %mul.4565.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.5006 = f32[]{:T(128)S(6)} parameter(1) - %div.2434.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_1.5006), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2433.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3337.clone.1, %div.2434.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.138.clone.1 = f32[129280,512]{1,0:T(8,128)} sqrt(%div.2433.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4755.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.3336.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4755.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.3335.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%sqrt.138.clone.1, %add.3336.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1274.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%div.2435.clone.1, %add.3335.clone.1), metadata={op_name="multiply.309"} - %div.2432.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3338.clone.1, %multiply.1274.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4563.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_0.4121, %broadcast.4209.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3334.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%div.2432.clone.1, %mul.4563.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4562.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%mul.4564.clone.1, %add.3334.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3333.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%param_0.4121, %mul.4562.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.584 = f32[129280,512]{1,0:T(8,128)} multiply(%add.3333.clone.1, %add.3333.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5056 = f32[]{:T(128)} constant(0) - %reduce.680 = f32[]{:T(128)} reduce(%square.584, %constant.5056), dimensions={0,1}, to_apply=%region_240.265, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.687.clone.1 = f32[]{:T(128)} reduce(%integer_pow.380.clone.1, %constant.5056), dimensions={0,1}, to_apply=%region_206.231, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.671 = (f32[]{:T(128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.680, %add.3333.clone.1, %add.3337.clone.1, %add.3338.clone.1, %reduce.687.clone.1) -} - -%region_222.247 (reduce_sum.900: f32[], reduce_sum.605: f32[]) -> f32[] { - %reduce_sum.900 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.605 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.610 = f32[]{:T(128)} add(%reduce_sum.900, %reduce_sum.605), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_188.213 (reduce_sum.662: f32[], reduce_sum.451: f32[]) -> f32[] { - %reduce_sum.662 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.451 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.455 = f32[]{:T(128)} add(%reduce_sum.662, %reduce_sum.451), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.521 (param_0.4139: f32[512,129280], param_1.5024: f32[], param_2.4297: f32[], param_3.2950: f32[], param_4.2202: f32[512,129280], param_5.2005: f32[], param_6.1442: bf16[512,129280,1], param_7.1123: pred[], param_8.888: f32[512,129280]) -> (f32[], f32[512,129280], f32[512,129280], f32[512,129280], f32[]) { - %param_0.4139 = f32[512,129280]{1,0:T(8,128)} parameter(0) - %param_3.2950 = f32[]{:T(128)S(6)} parameter(3) - %mul.4717.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_3.2950), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.4700.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.5291.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4700.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5289.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_4.2184, %mul.5291.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3295.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.5290.clone.1, %mul.5289.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.5007 = f32[]{:T(128)S(6)} parameter(1) + %div.2432.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_1.5007), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2431.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3295.clone.1, %div.2432.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.138.clone.1 = f32[129280,512]{1,0:T(8,128)} sqrt(%div.2431.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4699.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3294.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4699.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3293.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%sqrt.138.clone.1, %add.3294.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1067.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%div.2433.clone.1, %add.3293.clone.1), metadata={op_name="multiply.282"} + %div.2430.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3296.clone.1, %multiply.1067.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5287.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_0.4139, %broadcast.4070.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3292.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%div.2430.clone.1, %mul.5287.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5286.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%mul.5288.clone.1, %add.3292.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3291.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%param_0.4139, %mul.5286.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.584 = f32[129280,512]{1,0:T(8,128)} multiply(%add.3291.clone.1, %add.3291.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5000 = f32[]{:T(128)} constant(0) + %reduce.623 = f32[]{:T(128)} reduce(%square.584, %constant.5000), dimensions={0,1}, to_apply=%region_240.265, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.630.clone.1 = f32[]{:T(128)} reduce(%integer_pow.380.clone.1, %constant.5000), dimensions={0,1}, to_apply=%region_206.231, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.671 = (f32[]{:T(128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.623, %add.3291.clone.1, %add.3295.clone.1, %add.3296.clone.1, %reduce.630.clone.1) +} + +%region_222.247 (reduce_sum.971: f32[], reduce_sum.972: f32[]) -> f32[] { + %reduce_sum.971 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.972 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.535 = f32[]{:T(128)} add(%reduce_sum.971, %reduce_sum.972), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_188.213 (reduce_sum.733: f32[], reduce_sum.734: f32[]) -> f32[] { + %reduce_sum.733 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.734 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.457 = f32[]{:T(128)} add(%reduce_sum.733, %reduce_sum.734), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.520 (param_0.4157: f32[512,129280], param_1.5025: f32[], param_2.4287: f32[], param_3.2955: f32[], param_4.2202: f32[512,129280], param_5.2002: f32[], param_6.1443: bf16[512,129280,1], param_7.1123: pred[], param_8.888: f32[512,129280]) -> (f32[], f32[512,129280], f32[512,129280], f32[512,129280], f32[]) { + %param_0.4157 = f32[512,129280]{1,0:T(8,128)} parameter(0) + %param_3.2955 = f32[]{:T(128)S(6)} parameter(3) + %mul.5441.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_3.2955), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.1123 = pred[]{:T(512)S(6)} parameter(7) %select_n.2161.clone.1 = pred[512,129280]{1,0:T(8,128)(4,1)} broadcast(%param_7.1123), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.1442 = bf16[512,129280,1]{1,0,2:T(8,128)(2,1)} parameter(6) - %bitcast.1372.clone.1 = bf16[512,129280]{1,0:T(8,128)(2,1)} bitcast(%param_6.1442), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %convert_element_type.3108.clone.1 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.1372.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} - %param_5.2005 = f32[]{:T(128)} parameter(5) - %div.2567.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_5.2005), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2566.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%convert_element_type.3108.clone.1, %div.2567.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2160.clone.1 = f32[512,129280]{1,0:T(8,128)} select(%select_n.2161.clone.1, %convert_element_type.3108.clone.1, %div.2566.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4858.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4277.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4858.clone.1), dimensions={}, metadata={op_name="broadcast.333"} - %mul.4723.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%select_n.2160.clone.1, %broadcast.4277.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1443 = bf16[512,129280,1]{1,0,2:T(8,128)(2,1)} parameter(6) + %bitcast.1354.clone.1 = bf16[512,129280]{1,0:T(8,128)(2,1)} bitcast(%param_6.1443), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %convert_element_type.3056.clone.1 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.1354.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %param_5.2002 = f32[]{:T(128)} parameter(5) + %div.2565.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_5.2002), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2564.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%convert_element_type.3056.clone.1, %div.2565.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2160.clone.1 = f32[512,129280]{1,0:T(8,128)} select(%select_n.2161.clone.1, %convert_element_type.3056.clone.1, %div.2564.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4802.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4138.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4802.clone.1), dimensions={}, metadata={op_name="broadcast.333"} + %mul.5447.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%select_n.2160.clone.1, %broadcast.4138.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.888 = f32[512,129280]{1,0:T(8,128)} parameter(8) - %constant.4862.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.4724.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4862.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4722.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_8.888, %mul.4724.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3437.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.4723.clone.1, %mul.4722.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.4297 = f32[]{:T(128)S(6)} parameter(2) - %div.2563.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_2.4297), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4806.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.5448.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4806.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5446.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_8.888, %mul.5448.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3395.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.5447.clone.1, %mul.5446.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4287 = f32[]{:T(128)S(6)} parameter(2) + %div.2561.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_2.4287), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.398.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%select_n.2160.clone.1, %select_n.2160.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4861.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.4721.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4861.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4719.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%integer_pow.398.clone.1, %mul.4721.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.4805.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.5445.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4805.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5443.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%integer_pow.398.clone.1, %mul.5445.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_4.2202 = f32[512,129280]{1,0:T(8,128)} parameter(4) - %constant.4860.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.4720.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4860.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4718.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_4.2202, %mul.4720.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3436.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.4719.clone.1, %mul.4718.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.5024 = f32[]{:T(128)S(6)} parameter(1) - %div.2562.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_1.5024), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2561.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3436.clone.1, %div.2562.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.156.clone.1 = f32[512,129280]{1,0:T(8,128)} sqrt(%div.2561.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4859.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.3435.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4859.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.3434.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%sqrt.156.clone.1, %add.3435.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1292.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%div.2563.clone.1, %add.3434.clone.1), metadata={op_name="multiply.291"} - %div.2560.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3437.clone.1, %multiply.1292.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4716.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_0.4139, %broadcast.4277.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3433.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%div.2560.clone.1, %mul.4716.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4715.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%mul.4717.clone.1, %add.3433.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3432.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%param_0.4139, %mul.4715.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.585 = f32[512,129280]{1,0:T(8,128)} multiply(%add.3432.clone.1, %add.3432.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5074 = f32[]{:T(128)} constant(0) - %reduce.681 = f32[]{:T(128)} reduce(%square.585, %constant.5074), dimensions={0,1}, to_apply=%region_222.247, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.688.clone.1 = f32[]{:T(128)} reduce(%integer_pow.398.clone.1, %constant.5074), dimensions={0,1}, to_apply=%region_188.213, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.672 = (f32[]{:T(128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.681, %add.3432.clone.1, %add.3436.clone.1, %add.3437.clone.1, %reduce.688.clone.1) -} - -%region_207.232 (reduce_sum.795: f32[], reduce_sum.535: f32[]) -> f32[] { - %reduce_sum.795 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.535 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.540 = f32[]{:T(128)} add(%reduce_sum.795, %reduce_sum.535), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.522 (param_0.4190: bf16[4,128,129280], param_1.5063: f32[4,128], param_2.4329: s32[4,128], param_3.2977: bf16[4,128]) -> f32[4,128] { - %param_2.4329 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %eq.307 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.4329), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %constant.4804.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.5444.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4804.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5442.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_4.2202, %mul.5444.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3394.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.5443.clone.1, %mul.5442.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.5025 = f32[]{:T(128)S(6)} parameter(1) + %div.2560.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_1.5025), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2559.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3394.clone.1, %div.2560.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.156.clone.1 = f32[512,129280]{1,0:T(8,128)} sqrt(%div.2559.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4803.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3393.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4803.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3392.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%sqrt.156.clone.1, %add.3393.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1085.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%div.2561.clone.1, %add.3392.clone.1), metadata={op_name="multiply.264"} + %div.2558.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3395.clone.1, %multiply.1085.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5440.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_0.4157, %broadcast.4138.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3391.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%div.2558.clone.1, %mul.5440.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5439.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%mul.5441.clone.1, %add.3391.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3390.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%param_0.4157, %mul.5439.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.585 = f32[512,129280]{1,0:T(8,128)} multiply(%add.3390.clone.1, %add.3390.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5018 = f32[]{:T(128)} constant(0) + %reduce.624 = f32[]{:T(128)} reduce(%square.585, %constant.5018), dimensions={0,1}, to_apply=%region_222.247, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.631.clone.1 = f32[]{:T(128)} reduce(%integer_pow.398.clone.1, %constant.5018), dimensions={0,1}, to_apply=%region_188.213, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.672 = (f32[]{:T(128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.624, %add.3390.clone.1, %add.3394.clone.1, %add.3395.clone.1, %reduce.631.clone.1) +} + +%region_207.232 (reduce_sum.866: f32[], reduce_sum.867: f32[]) -> f32[] { + %reduce_sum.866 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.867 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.500 = f32[]{:T(128)} add(%reduce_sum.866, %reduce_sum.867), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.521 (param_0.4207: bf16[4,128,129280], param_1.5063: f32[4,128], param_2.4319: s32[4,128], param_3.2982: bf16[4,128]) -> f32[4,128] { + %param_2.4319 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %eq.307 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.4319), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.302 = s32[4,128,129280]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.301 = pred[4,128,129280]{2,1,0:T(8,128)(4,1)} compare(%eq.307, %eq.302), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %param_0.4190 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.2672 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4190), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_3.2977 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) - %sub.665 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.2977), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.656 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2672, %sub.665), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %param_0.4207 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.2625 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4207), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_3.2982 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) + %sub.665 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.2982), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.656 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2625, %sub.665), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %param_1.5063 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) %sub.663 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5063), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %sub.652 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%sub.656, %sub.663), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %constant.5128 = f32[]{:T(128)} constant(0) - %broadcast.3784 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%constant.5128), dimensions={}, metadata={op_name="broadcast.518"} - %mul.3624 = f32[4,128,129280]{2,1,0:T(8,128)} select(%eq.301, %sub.652, %broadcast.3784), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - ROOT %reduce.682 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.3624, %constant.5128), dimensions={2}, to_apply=%region_207.232, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %constant.5071 = f32[]{:T(128)} constant(0) + %broadcast.3645 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%constant.5071), dimensions={}, metadata={op_name="broadcast.500"} + %mul.4239 = f32[4,128,129280]{2,1,0:T(8,128)} select(%eq.301, %sub.652, %broadcast.3645), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + ROOT %reduce.625 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.4239, %constant.5071), dimensions={2}, to_apply=%region_207.232, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} } -%region_37.47 (reduce_sum.76: f32[], reduce_sum.80: f32[]) -> f32[] { +%region_37.47 (reduce_sum.76: f32[], reduce_sum.82: f32[]) -> f32[] { %reduce_sum.76 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.80 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.83 = f32[]{:T(128)} add(%reduce_sum.76, %reduce_sum.80), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %reduce_sum.82 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.88 = f32[]{:T(128)} add(%reduce_sum.76, %reduce_sum.82), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.533 (param_0.4191: bf16[4,128,129280], param_1.5064: bf16[4,128]) -> f32[4,128] { - %param_0.4191 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.2678 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4191), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} +%fused_computation.532 (param_0.4208: bf16[4,128,129280], param_1.5064: bf16[4,128]) -> f32[4,128] { + %param_0.4208 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.2631 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4208), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} %param_1.5064 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) %sub.666 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5064), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.662 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2678, %sub.666), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.662 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2631, %sub.666), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %exp.448 = f32[4,128,129280]{2,1,0:T(8,128)} exponential(%sub.662), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %constant.5129 = f32[]{:T(128)} constant(0) - ROOT %reduce.683 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.448, %constant.5129), dimensions={2}, to_apply=%region_37.47, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %constant.5072 = f32[]{:T(128)} constant(0) + ROOT %reduce.626 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.448, %constant.5072), dimensions={2}, to_apply=%region_37.47, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} } -%region_152.177 (reduce_sum.417: f32[], reduce_sum.244: f32[]) -> f32[] { - %reduce_sum.417 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.244 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.251 = f32[]{:T(128)} add(%reduce_sum.417, %reduce_sum.244), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_152.177 (reduce_sum.488: f32[], reduce_sum.324: f32[]) -> f32[] { + %reduce_sum.488 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.324 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.329 = f32[]{:T(128)} add(%reduce_sum.488, %reduce_sum.324), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.541 (param_0.4172: f32[3,512,128,256]) -> f32[] { - %param_0.4172 = f32[3,512,128,256]{3,2,0,1:T(8,128)} parameter(0) - %bitcast.752 = f32[512,3,128,256]{3,2,1,0:T(8,128)} bitcast(%param_0.4172), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} - %square.588 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%bitcast.752, %bitcast.752), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5107 = f32[]{:T(128)} constant(0) - ROOT %reduce.689 = f32[]{:T(128)} reduce(%square.588, %constant.5107), dimensions={0,1,2,3}, to_apply=%region_152.177, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%fused_computation.540 (param_0.4190: f32[3,512,128,256]) -> f32[] { + %param_0.4190 = f32[3,512,128,256]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.734 = f32[512,3,128,256]{3,2,1,0:T(8,128)} bitcast(%param_0.4190), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.588 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%bitcast.734, %bitcast.734), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5051 = f32[]{:T(128)} constant(0) + ROOT %reduce.632 = f32[]{:T(128)} reduce(%square.588, %constant.5051), dimensions={0,1,2,3}, to_apply=%region_152.177, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%fused_computation.542 (param_0.1602: f32[512,3,128,256]) -> bf16[3,512,128,256] { - %param_0.1602 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(0) - %copy.1551 = bf16[512,3,128,256]{3,0,2,1:T(8,128)(2,1)} copy(%param_0.1602), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'dense_layers\'][\'self_attention\'][\'wkv_b\'][\'kernel\']"} - ROOT %bitcast.753 = bf16[3,512,128,256]{3,1,2,0:T(8,128)(2,1)} bitcast(%copy.1551), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} +%fused_computation.541 (param_0.1620: f32[512,3,128,256]) -> bf16[3,512,128,256] { + %param_0.1620 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(0) + %copy.1549 = bf16[512,3,128,256]{3,0,2,1:T(8,128)(2,1)} copy(%param_0.1620), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'dense_layers\'][\'self_attention\'][\'wkv_b\'][\'kernel\']"} + ROOT %bitcast.735 = bf16[3,512,128,256]{3,1,2,0:T(8,128)(2,1)} bitcast(%copy.1549), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} } -%region_219.244 (reduce_sum.879: f32[], reduce_sum.591: f32[]) -> f32[] { - %reduce_sum.879 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.591 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.596 = f32[]{:T(128)} add(%reduce_sum.879, %reduce_sum.591), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_219.244 (reduce_sum.950: f32[], reduce_sum.951: f32[]) -> f32[] { + %reduce_sum.950 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.951 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.528 = f32[]{:T(128)} add(%reduce_sum.950, %reduce_sum.951), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_185.210 (reduce_sum.641: f32[], reduce_sum.437: f32[]) -> f32[] { - %reduce_sum.641 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.437 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.442 = f32[]{:T(128)} add(%reduce_sum.641, %reduce_sum.437), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_185.210 (reduce_sum.712: f32[], reduce_sum.713: f32[]) -> f32[] { + %reduce_sum.712 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.713 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.450 = f32[]{:T(128)} add(%reduce_sum.712, %reduce_sum.713), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.543 (param_0.4142: f32[512,3,128,256], param_1.5027: f32[], param_2.4300: f32[], param_3.2953: f32[], param_4.2205: f32[512,3,128,256], param_5.2008: f32[], param_6.1445: f32[3,512,128,256], param_7.1126: pred[], param_8.891: f32[512,3,128,256]) -> (f32[], f32[512,3,128,256], f32[512,3,128,256], f32[512,3,128,256], f32[]) { - %param_0.4142 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(0) - %param_3.2953 = f32[]{:T(128)S(6)} parameter(3) - %mul.4747.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_3.2953), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.542 (param_0.4160: f32[512,3,128,256], param_1.5028: f32[], param_2.4290: f32[], param_3.2958: f32[], param_4.2205: f32[512,3,128,256], param_5.2005: f32[], param_6.1446: f32[3,512,128,256], param_7.1126: pred[], param_8.891: f32[512,3,128,256]) -> (f32[], f32[512,3,128,256], f32[512,3,128,256], f32[512,3,128,256], f32[]) { + %param_0.4160 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(0) + %param_3.2958 = f32[]{:T(128)S(6)} parameter(3) + %mul.5471.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_3.2958), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.1126 = pred[]{:T(512)S(6)} parameter(7) %select_n.2173.clone.1 = pred[512,3,128,256]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.1126), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.1445 = f32[3,512,128,256]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.1378.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} bitcast(%param_6.1445), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} - %param_5.2008 = f32[]{:T(128)} parameter(5) - %div.2591.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_5.2008), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2590.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%bitcast.1378.clone.1, %div.2591.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2172.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} select(%select_n.2173.clone.1, %bitcast.1378.clone.1, %div.2590.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4876.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4283.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4876.clone.1), dimensions={}, metadata={op_name="broadcast.336"} - %mul.4753.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%select_n.2172.clone.1, %broadcast.4283.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1446 = f32[3,512,128,256]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.1360.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} bitcast(%param_6.1446), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} + %param_5.2005 = f32[]{:T(128)} parameter(5) + %div.2589.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_5.2005), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2588.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%bitcast.1360.clone.1, %div.2589.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2172.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} select(%select_n.2173.clone.1, %bitcast.1360.clone.1, %div.2588.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4820.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4144.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4820.clone.1), dimensions={}, metadata={op_name="broadcast.336"} + %mul.5477.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%select_n.2172.clone.1, %broadcast.4144.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.891 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(8) - %constant.4880.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.4754.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4880.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4752.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_8.891, %mul.4754.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3455.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.4753.clone.1, %mul.4752.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.4300 = f32[]{:T(128)S(6)} parameter(2) - %div.2587.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_2.4300), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4824.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.5478.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4824.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5476.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_8.891, %mul.5478.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3413.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.5477.clone.1, %mul.5476.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4290 = f32[]{:T(128)S(6)} parameter(2) + %div.2585.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_2.4290), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.401.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%select_n.2172.clone.1, %select_n.2172.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4879.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.4751.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4879.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4749.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%integer_pow.401.clone.1, %mul.4751.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.4823.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.5475.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4823.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5473.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%integer_pow.401.clone.1, %mul.5475.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_4.2205 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(4) - %constant.4878.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.4750.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4878.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4748.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_4.2205, %mul.4750.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3454.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.4749.clone.1, %mul.4748.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.5027 = f32[]{:T(128)S(6)} parameter(1) - %div.2586.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_1.5027), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2585.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3454.clone.1, %div.2586.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.159.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} sqrt(%div.2585.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4877.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.3453.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4877.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.3452.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%sqrt.159.clone.1, %add.3453.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1295.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%div.2587.clone.1, %add.3452.clone.1), metadata={op_name="multiply.288"} - %div.2584.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3455.clone.1, %multiply.1295.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4746.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_0.4142, %broadcast.4283.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3451.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%div.2584.clone.1, %mul.4746.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4745.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%mul.4747.clone.1, %add.3451.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3450.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%param_0.4142, %mul.4745.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.589 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%add.3450.clone.1, %add.3450.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5077 = f32[]{:T(128)} constant(0) - %reduce.690 = f32[]{:T(128)} reduce(%square.589, %constant.5077), dimensions={0,1,2,3}, to_apply=%region_219.244, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.691.clone.1 = f32[]{:T(128)} reduce(%integer_pow.401.clone.1, %constant.5077), dimensions={0,1,2,3}, to_apply=%region_185.210, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.667 = (f32[]{:T(128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.690, %add.3450.clone.1, %add.3454.clone.1, %add.3455.clone.1, %reduce.691.clone.1) -} - -%region_172.197 (reduce_sum.557: f32[], reduce_sum.381: f32[]) -> f32[] { - %reduce_sum.557 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.381 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.386 = f32[]{:T(128)} add(%reduce_sum.557, %reduce_sum.381), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.783.clone.clone (param_0.4106: f32[4,128], param_1.4998: bf16[4,128,1536], param_2.4261: bf16[1536]) -> bf16[4,128,1536,1] { - %param_2.4261 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.851 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.4261), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.4998 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.3187 = f32[4,128,1536]{2,1,0:T(8,128)} convert(%param_1.4998), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=0} - %param_0.4106 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.4951 = f32[4,128,1536]{2,1,0:T(8,128)} broadcast(%param_0.4106), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} - %mul.4950 = f32[4,128,1536]{2,1,0:T(8,128)} multiply(%convert_element_type.3187, %mul.4951), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} - %convert_element_type.3186 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} convert(%mul.4950), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=0} - %dot_general.850 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.851, %convert_element_type.3186), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %bitcast.1466 = bf16[4,128,1536,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dot_general.850), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=0} + %constant.4822.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.5474.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4822.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5472.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_4.2205, %mul.5474.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3412.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.5473.clone.1, %mul.5472.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.5028 = f32[]{:T(128)S(6)} parameter(1) + %div.2584.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_1.5028), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2583.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3412.clone.1, %div.2584.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.159.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} sqrt(%div.2583.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4821.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3411.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4821.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3410.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%sqrt.159.clone.1, %add.3411.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1088.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%div.2585.clone.1, %add.3410.clone.1), metadata={op_name="multiply.261"} + %div.2582.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3413.clone.1, %multiply.1088.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5470.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_0.4160, %broadcast.4144.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3409.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%div.2582.clone.1, %mul.5470.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5469.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%mul.5471.clone.1, %add.3409.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3408.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%param_0.4160, %mul.5469.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.589 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%add.3408.clone.1, %add.3408.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5021 = f32[]{:T(128)} constant(0) + %reduce.633 = f32[]{:T(128)} reduce(%square.589, %constant.5021), dimensions={0,1,2,3}, to_apply=%region_219.244, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.634.clone.1 = f32[]{:T(128)} reduce(%integer_pow.401.clone.1, %constant.5021), dimensions={0,1,2,3}, to_apply=%region_185.210, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.667 = (f32[]{:T(128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.633, %add.3408.clone.1, %add.3412.clone.1, %add.3413.clone.1, %reduce.634.clone.1) +} + +%region_172.197 (reduce_sum.628: f32[], reduce_sum.629: f32[]) -> f32[] { + %reduce_sum.628 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.629 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.423 = f32[]{:T(128)} add(%reduce_sum.628, %reduce_sum.629), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.781.clone.clone (param_0.4124: f32[4,128], param_1.4999: bf16[4,128,1536], param_2.4251: bf16[1536]) -> bf16[4,128,1536,1] { + %param_1.4999 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.3135 = f32[4,128,1536]{2,1,0:T(8,128)} convert(%param_1.4999), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=0} + %param_0.4124 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.5720 = f32[4,128,1536]{2,1,0:T(8,128)} broadcast(%param_0.4124), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + %mul.5719 = f32[4,128,1536]{2,1,0:T(8,128)} multiply(%convert_element_type.3135, %mul.5720), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + %convert_element_type.3134 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} convert(%mul.5719), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=0} + %param_2.4251 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.5721 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.4251), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + %mul.5718 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.3134, %mul.5721), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + ROOT %bitcast.1448 = bf16[4,128,1536,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%mul.5718), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} } %bitcast_fusion.12 (bitcast_input.12: bf16[4,128,128,192]) -> bf16[4,128,128,192] { %bitcast_input.12 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)S(1)} parameter(0) - ROOT %bitcast.1488 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)} bitcast(%bitcast_input.12) + ROOT %bitcast.1470 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)} bitcast(%bitcast_input.12) } -%fused_computation.552 (param_0.4154: bf16[4,128,128,192], param_1.5038: f32[4,128], param_2.4311: bf16[4,128,1536], param_3.2964: bf16[1536]) -> (f32[], bf16[1536,128,192,1]) { - %param_1.5038 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.4311 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %param_3.2964 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %fusion.460.clone.1 = bf16[4,128,1536,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_1.5038, %param_2.4311, %param_3.2964), kind=kLoop, calls=%fused_computation.783.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.4154 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)S(1)} parameter(0) - %fusion.751 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.4154), kind=kLoop, calls=%bitcast_fusion.12 - %convolution.146.clone.1 = bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)} convolution(%fusion.460.clone.1, %fusion.751), window={size=192x4 pad=191_191x0_0 rhs_reversal=1x0}, dim_labels=1fb0_1io0->bf01, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=0} - %bitcast.861 = bf16[1536,128,192]{1,0,2:T(8,128)(2,1)} bitcast(%convolution.146.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=0} - %broadcast_in_dim.1388 = f32[1536,128,192]{1,0,2:T(8,128)} convert(%bitcast.861), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.763 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} bitcast(%broadcast_in_dim.1388), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} - %square.592 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%bitcast.763, %bitcast.763), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5089 = f32[]{:T(128)} constant(0) - %reduce.692 = f32[]{:T(128)} reduce(%square.592, %constant.5089), dimensions={0,1,2,3}, to_apply=%region_172.197, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.766 = (f32[]{:T(128)}, bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)}) tuple(%reduce.692, %convolution.146.clone.1) +%fused_computation.551 (param_0.4172: bf16[4,128,128,192], param_1.5039: f32[4,128], param_2.4301: bf16[4,128,1536], param_3.2969: bf16[1536]) -> (f32[], bf16[1536,128,192,1]) { + %param_1.5039 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.4301 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.2969 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.458.clone.1 = bf16[4,128,1536,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_1.5039, %param_2.4301, %param_3.2969), kind=kLoop, calls=%fused_computation.781.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + %param_0.4172 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)S(1)} parameter(0) + %fusion.748 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.4172), kind=kLoop, calls=%bitcast_fusion.12 + %convolution.146.clone.1 = bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)} convolution(%fusion.458.clone.1, %fusion.748), window={size=192x4 pad=191_191x0_0 rhs_reversal=1x0}, dim_labels=1fb0_1io0->bf01, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=0} + %bitcast.843 = bf16[1536,128,192]{1,0,2:T(8,128)(2,1)} bitcast(%convolution.146.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=0} + %broadcast_in_dim.1413 = f32[1536,128,192]{1,0,2:T(8,128)} convert(%bitcast.843), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.745 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} bitcast(%broadcast_in_dim.1413), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.592 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%bitcast.745, %bitcast.745), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5033 = f32[]{:T(128)} constant(0) + %reduce.635 = f32[]{:T(128)} reduce(%square.592, %constant.5033), dimensions={0,1,2,3}, to_apply=%region_172.197, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.766 = (f32[]{:T(128)}, bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)}) tuple(%reduce.635, %convolution.146.clone.1) } -%region_239.264 (reduce_sum.1019: f32[], reduce_sum.687: f32[]) -> f32[] { - %reduce_sum.1019 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.687 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.688 = f32[]{:T(128)} add(%reduce_sum.1019, %reduce_sum.687), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_239.264 (reduce_sum.1090: f32[], reduce_sum.1091: f32[]) -> f32[] { + %reduce_sum.1090 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.1091 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.576 = f32[]{:T(128)} add(%reduce_sum.1090, %reduce_sum.1091), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_205.230 (reduce_sum.781: f32[], reduce_sum.527: f32[]) -> f32[] { - %reduce_sum.781 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.527 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.528 = f32[]{:T(128)} add(%reduce_sum.781, %reduce_sum.527), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_205.230 (reduce_sum.852: f32[], reduce_sum.853: f32[]) -> f32[] { + %reduce_sum.852 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.853 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.494 = f32[]{:T(128)} add(%reduce_sum.852, %reduce_sum.853), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.557 (param_0.4122: f32[], param_1.5007: f32[], param_2.4280: f32[], param_3.2933: f32[1536,1,128,192], param_4.2185: f32[1536,1,128,192], param_5.1988: f32[], param_6.1425: bf16[1536,128,192,1], param_7.1106: pred[], param_8.871: f32[1536,1,128,192]) -> (f32[], f32[1536,1,128,192], f32[1536,1,128,192], f32[1536,1,128,192], f32[]) { +%fused_computation.556 (param_0.4140: f32[], param_1.5008: f32[], param_2.4270: f32[], param_3.2938: f32[1536,1,128,192], param_4.2185: f32[1536,1,128,192], param_5.1985: f32[], param_6.1426: bf16[1536,128,192,1], param_7.1106: pred[], param_8.871: f32[1536,1,128,192]) -> (f32[], f32[1536,1,128,192], f32[1536,1,128,192], f32[1536,1,128,192], f32[]) { diff --git a/tests/utils/reference_hlo_llama3_8b.txt b/tests/utils/reference_hlo_llama3_8b.txt index 27c6529df2..19a83b1a84 100644 --- a/tests/utils/reference_hlo_llama3_8b.txt +++ b/tests/utils/reference_hlo_llama3_8b.txt @@ -14,1355 +14,1355 @@ StackFrames %param_1.7 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.1 = s32[1024]{0:T(1024)} custom-call(%param_1.7), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} %slice.6 = s32[512]{0:T(512)} slice(%custom-call.1), slice={[0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %reshape.342 = s32[4,128]{1,0:T(4,128)} reshape(%slice.6), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %transpose.326 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.342), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %gather.4 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} gather(%param_0.2, %transpose.326), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,4096}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %transpose.325 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} transpose(%gather.4), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - ROOT %reshape.341 = bf16[512,4096]{1,0:T(8,128)(2,1)} reshape(%transpose.325), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %reshape.380 = s32[4,128]{1,0:T(4,128)} reshape(%slice.6), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.241 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.380), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %gather.4 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} gather(%param_0.2, %transpose.241), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,4096}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %transpose.240 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} transpose(%gather.4), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %reshape.379 = bf16[512,4096]{1,0:T(8,128)(2,1)} reshape(%transpose.240), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} } %region_33.38.clone (scatter-add.6: bf16[], scatter-add.7: bf16[]) -> bf16[] { %scatter-add.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} %scatter-add.7 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} - ROOT %add.476 = bf16[]{:T(256)} add(%scatter-add.6, %scatter-add.7), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.462 = bf16[]{:T(256)} add(%scatter-add.6, %scatter-add.7), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %fused_computation.1 (param_0.3: bf16[128256,4096], param_1.5: s32[512], param_2.4: bf16[512,4096]) -> bf16[128256,4096] { %param_0.3 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(0) %param_1.5 = s32[512]{0:T(512)S(1)} parameter(1) - %reshape.349 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.5), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %transpose.331 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.349), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %reshape.387 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.5), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.246 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.387), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} %param_2.4 = bf16[512,4096]{1,0:T(8,128)(2,1)S(1)} parameter(2) - %reshape.350 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} reshape(%param_2.4), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} - %transpose.332 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} transpose(%reshape.350), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} - ROOT %scatter.2 = bf16[128256,4096]{1,0:T(8,128)(2,1)} scatter(%param_0.3, %transpose.331, %transpose.332), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_33.38.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} + %reshape.388 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} reshape(%param_2.4), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} + %transpose.247 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} transpose(%reshape.388), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} + ROOT %scatter.2 = bf16[128256,4096]{1,0:T(8,128)(2,1)} scatter(%param_0.3, %transpose.246, %transpose.247), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_33.38.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} } -%region_32.37 (reduce_sum.190: f32[], reduce_sum.191: f32[]) -> f32[] { - %reduce_sum.190 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.191 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.192 = f32[]{:T(128)} add(%reduce_sum.190, %reduce_sum.191), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_32.37 (reduce_sum.244: f32[], reduce_sum.245: f32[]) -> f32[] { + %reduce_sum.244 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.245 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.249 = f32[]{:T(128)} add(%reduce_sum.244, %reduce_sum.245), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.280.clone.clone.clone (param_0.1099: bf16[4,128,128256], param_1.1265: s32[4,128], param_2.1086: f32[4,128], param_3.785: f32[4,128], param_4.487: bf16[4,128], param_5.412: f32[4,128]) -> bf16[4,128,128256] { +%fused_computation.280.clone.clone.clone (param_0.1106: bf16[4,128,128256], param_1.1269: s32[4,128], param_2.1080: f32[4,128], param_3.773: f32[4,128], param_4.481: bf16[4,128], param_5.412: f32[4,128]) -> bf16[4,128,128256] { %param_5.412 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %mul.1613 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_5.412), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_3.785 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %mul.1612 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.785), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_0.1099 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.1044 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1099), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_4.487 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) - %sub.94 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_4.487), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.93 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.1044, %sub.94), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %mul.1937 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_5.412), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.773 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.1936 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.773), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.1106 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1060 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1106), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_4.481 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.94 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_4.481), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.93 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.1060, %sub.94), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %exp.62 = f32[4,128,128256]{2,1,0:T(8,128)} exponential(%sub.93), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %mul.1611 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1612, %exp.62), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_2.1086 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %div.823 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1086), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %div.822 = f32[4,128,128256]{2,1,0:T(8,128)} divide(%mul.1611, %div.823), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %param_1.1265 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %eq.49 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1265), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %mul.1935 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1936, %exp.62), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.1080 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.823 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1080), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.822 = f32[4,128,128256]{2,1,0:T(8,128)} divide(%mul.1935, %div.823), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.1269 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.49 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1269), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.48 = s32[4,128,128256]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.47 = pred[4,128,128256]{2,1,0:T(8,128)(4,1)} compare(%eq.49, %eq.48), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %convert_element_type.1043 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%eq.47), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} - %sub.92 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%div.822, %convert_element_type.1043), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} - %mul.1610 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1613, %sub.92), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - ROOT %convert_element_type.1042 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convert(%mul.1610), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} -} - -%fused_computation.316.clone.clone (param_0.1100: f32[4,128], param_1.1266: bf16[4,128,4096], param_2.1088: bf16[4096]) -> bf16[4,128,4096] { - %param_2.1088 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.387 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1088), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1266 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1046 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1266), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_0.1100 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1615 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1100), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1614 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1046, %mul.1615), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.1045 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1614), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - ROOT %dot_general.386 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.387, %convert_element_type.1045), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.219 (param_0.1119: bf16[4,128,128256], param_1.1281: s32[4,128], param_2.1112: f32[4,128], param_3.801: f32[4,128], param_4.502: bf16[4,128], param_5.427: f32[4,128], param_6.299: f32[4,128], param_7.198: bf16[4,128,4096], param_8.116: bf16[4096]) -> (f32[], bf16[4096,128256,1]) { - %param_6.299 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %convert_element_type.1059 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%eq.47), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.92 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%div.822, %convert_element_type.1059), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.1934 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1937, %sub.92), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.1058 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convert(%mul.1934), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} +} + +%fused_computation.316.clone.clone (param_0.1107: f32[4,128], param_1.1270: bf16[4,128,4096], param_2.1082: bf16[4096]) -> bf16[4,128,4096] { + %param_1.1270 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1062 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1270), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1107 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1940 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1107), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1939 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1062, %mul.1940), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1061 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1939), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.1082 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.1941 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1082), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.1938 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1061, %mul.1941), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%fused_computation.219 (param_0.1126: bf16[4,128,128256], param_1.1285: s32[4,128], param_2.1106: f32[4,128], param_3.789: f32[4,128], param_4.496: bf16[4,128], param_5.427: f32[4,128], param_6.300: f32[4,128], param_7.198: bf16[4,128,4096], param_8.116: bf16[4096]) -> (f32[], bf16[4096,128256,1]) { + %param_6.300 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) %param_7.198 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(7) %param_8.116 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(8) - %fusion.239.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_6.299, %param_7.198, %param_8.116), kind=kLoop, calls=%fused_computation.316.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1119 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1281 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.1112 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.801 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %param_4.502 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %fusion.239.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_6.300, %param_7.198, %param_8.116), kind=kLoop, calls=%fused_computation.316.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.1126 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1285 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1106 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.789 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %param_4.496 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) %param_5.427 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %multiply_convert_fusion.1.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1119, %param_1.1281, %param_2.1112, %param_3.801, %param_4.502, /*index=5*/%param_5.427), kind=kLoop, calls=%fused_computation.280.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %multiply_convert_fusion.1.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1126, %param_1.1285, %param_2.1106, %param_3.789, %param_4.496, /*index=5*/%param_5.427), kind=kLoop, calls=%fused_computation.280.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} %convolution.88.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} convolution(%fusion.239.clone.1, %multiply_convert_fusion.1.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} %bitcast.306 = bf16[4096,128256]{1,0:T(8,128)(2,1)} bitcast(%convolution.88.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %convert_element_type.923 = f32[4096,128256]{1,0:T(8,128)} convert(%bitcast.306), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} - %square.157 = f32[4096,128256]{1,0:T(8,128)} multiply(%convert_element_type.923, %convert_element_type.923), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1006 = f32[]{:T(128)} constant(0) - %reduce.118 = f32[]{:T(128)} reduce(%square.157, %constant.1006), dimensions={0,1}, to_apply=%region_32.37, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.154 = (f32[]{:T(128)}, bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.118, %convolution.88.clone.1) + %convert_element_type.939 = f32[4096,128256]{1,0:T(8,128)} convert(%bitcast.306), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %square.157 = f32[4096,128256]{1,0:T(8,128)} multiply(%convert_element_type.939, %convert_element_type.939), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.998 = f32[]{:T(128)} constant(0) + %reduce.79 = f32[]{:T(128)} reduce(%square.157, %constant.998), dimensions={0,1}, to_apply=%region_32.37, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.154 = (f32[]{:T(128)}, bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.79, %convolution.88.clone.1) } -%region_34.39 (reduce_sum.196: f32[], reduce_sum.197: f32[]) -> f32[] { - %reduce_sum.196 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.197 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.198 = f32[]{:T(128)} add(%reduce_sum.196, %reduce_sum.197), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_34.39 (reduce_sum.250: f32[], reduce_sum.251: f32[]) -> f32[] { + %reduce_sum.250 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.251 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.252 = f32[]{:T(128)} add(%reduce_sum.250, %reduce_sum.251), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.220 (param_0.1118: bf16[128256,4096]) -> f32[] { - %param_0.1118 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.925 = f32[128256,4096]{1,0:T(8,128)} convert(%param_0.1118), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %square.159 = f32[128256,4096]{1,0:T(8,128)} multiply(%convert_element_type.925, %convert_element_type.925), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1005 = f32[]{:T(128)} constant(0) - ROOT %reduce.119 = f32[]{:T(128)} reduce(%square.159, %constant.1005), dimensions={0,1}, to_apply=%region_34.39, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%fused_computation.220 (param_0.1125: bf16[128256,4096]) -> f32[] { + %param_0.1125 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.941 = f32[128256,4096]{1,0:T(8,128)} convert(%param_0.1125), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %square.159 = f32[128256,4096]{1,0:T(8,128)} multiply(%convert_element_type.941, %convert_element_type.941), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.997 = f32[]{:T(128)} constant(0) + ROOT %reduce.80 = f32[]{:T(128)} reduce(%square.159, %constant.997), dimensions={0,1}, to_apply=%region_34.39, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%region_60.65 (reduce_sum.338: f32[], reduce_sum.339: f32[]) -> f32[] { - %reduce_sum.338 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.339 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.329 = f32[]{:T(128)} add(%reduce_sum.338, %reduce_sum.339), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_60.65 (reduce_sum.385: f32[], reduce_sum.389: f32[]) -> f32[] { + %reduce_sum.385 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.389 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.390 = f32[]{:T(128)} add(%reduce_sum.385, %reduce_sum.389), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_46.51 (reduce_sum.259: f32[], reduce_sum.260: f32[]) -> f32[] { - %reduce_sum.259 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.260 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.261 = f32[]{:T(128)} add(%reduce_sum.259, %reduce_sum.260), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_46.51 (reduce_sum.313: f32[], reduce_sum.314: f32[]) -> f32[] { + %reduce_sum.313 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.314 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.315 = f32[]{:T(128)} add(%reduce_sum.313, %reduce_sum.314), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.221 (param_0.1106: f32[128256,4096], param_1.1269: f32[], param_2.1100: f32[], param_3.789: f32[], param_4.490: f32[128256,4096], param_5.415: f32[], param_6.287: bf16[128256,4096], param_7.186: pred[], param_8.104: f32[128256,4096]) -> (f32[], f32[128256,4096], f32[128256,4096], f32[128256,4096], f32[]) { - %param_0.1106 = f32[128256,4096]{1,0:T(8,128)} parameter(0) - %param_3.789 = f32[]{:T(128)S(6)} parameter(3) - %mul.1482.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_3.789), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.221 (param_0.1113: f32[128256,4096], param_1.1273: f32[], param_2.1094: f32[], param_3.777: f32[], param_4.484: f32[128256,4096], param_5.415: f32[], param_6.288: bf16[128256,4096], param_7.186: pred[], param_8.104: f32[128256,4096]) -> (f32[], f32[128256,4096], f32[128256,4096], f32[128256,4096], f32[]) { + %param_0.1113 = f32[128256,4096]{1,0:T(8,128)} parameter(0) + %param_3.777 = f32[]{:T(128)S(6)} parameter(3) + %mul.1800.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_3.777), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.186 = pred[]{:T(512)S(6)} parameter(7) %select_n.242.clone.1 = pred[128256,4096]{1,0:T(8,128)(4,1)} broadcast(%param_7.186), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.287 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(6) - %convert_element_type.1017.clone.1 = f32[128256,4096]{1,0:T(8,128)} convert(%param_6.287), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %param_6.288 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(6) + %convert_element_type.1033.clone.1 = f32[128256,4096]{1,0:T(8,128)} convert(%param_6.288), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} %param_5.415 = f32[]{:T(128)} parameter(5) %div.725.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_5.415), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.724.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%convert_element_type.1017.clone.1, %div.725.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.241.clone.1 = f32[128256,4096]{1,0:T(8,128)} select(%select_n.242.clone.1, %convert_element_type.1017.clone.1, %div.724.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.907.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.554.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.907.clone.1), dimensions={}, metadata={op_name="broadcast.61"} - %mul.1488.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%select_n.241.clone.1, %broadcast.554.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.724.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%convert_element_type.1033.clone.1, %div.725.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.241.clone.1 = f32[128256,4096]{1,0:T(8,128)} select(%select_n.242.clone.1, %convert_element_type.1033.clone.1, %div.724.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.899.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.515.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.899.clone.1), dimensions={}, metadata={op_name="broadcast.61"} + %mul.1806.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%select_n.241.clone.1, %broadcast.515.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.104 = f32[128256,4096]{1,0:T(8,128)} parameter(8) - %constant.911.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1489.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.911.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1487.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_8.104, %mul.1489.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.776.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%mul.1488.clone.1, %mul.1487.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1100 = f32[]{:T(128)S(6)} parameter(2) - %div.721.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_2.1100), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.903.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1807.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.903.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1805.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_8.104, %mul.1807.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.762.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%mul.1806.clone.1, %mul.1805.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1094 = f32[]{:T(128)S(6)} parameter(2) + %div.721.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_2.1094), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.60.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%select_n.241.clone.1, %select_n.241.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.910.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1486.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.910.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1484.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%integer_pow.60.clone.1, %mul.1486.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.490 = f32[128256,4096]{1,0:T(8,128)} parameter(4) - %constant.909.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1485.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.909.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1483.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_4.490, %mul.1485.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.775.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%mul.1484.clone.1, %mul.1483.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1269 = f32[]{:T(128)S(6)} parameter(1) - %div.720.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_1.1269), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.719.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%add.775.clone.1, %div.720.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.902.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1804.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.902.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1802.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%integer_pow.60.clone.1, %mul.1804.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.484 = f32[128256,4096]{1,0:T(8,128)} parameter(4) + %constant.901.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1803.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.901.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1801.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_4.484, %mul.1803.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.761.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%mul.1802.clone.1, %mul.1801.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1273 = f32[]{:T(128)S(6)} parameter(1) + %div.720.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_1.1273), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.719.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%add.761.clone.1, %div.720.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.58.clone.1 = f32[128256,4096]{1,0:T(8,128)} sqrt(%div.719.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.908.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.774.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.908.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.773.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%sqrt.58.clone.1, %add.774.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.256.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%div.721.clone.1, %add.773.clone.1), metadata={op_name="multiply.42"} - %div.718.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%add.776.clone.1, %multiply.256.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1481.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_0.1106, %broadcast.554.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.772.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%div.718.clone.1, %mul.1481.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1480.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%mul.1482.clone.1, %add.772.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.771.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%param_0.1106, %mul.1480.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.160 = f32[128256,4096]{1,0:T(8,128)} multiply(%add.771.clone.1, %add.771.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.993 = f32[]{:T(128)} constant(0) - %reduce.120 = f32[]{:T(128)} reduce(%square.160, %constant.993), dimensions={0,1}, to_apply=%region_60.65, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.122.clone.1 = f32[]{:T(128)} reduce(%integer_pow.60.clone.1, %constant.993), dimensions={0,1}, to_apply=%region_46.51, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.135 = (f32[]{:T(128)}, f32[128256,4096]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.120, %add.771.clone.1, %add.775.clone.1, %add.776.clone.1, %reduce.122.clone.1) -} - -%region_59.64 (reduce_sum.331: f32[], reduce_sum.332: f32[]) -> f32[] { - %reduce_sum.331 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.332 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.323 = f32[]{:T(128)} add(%reduce_sum.331, %reduce_sum.332), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_45.50 (reduce_sum.253: f32[], reduce_sum.254: f32[]) -> f32[] { - %reduce_sum.253 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.254 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.255 = f32[]{:T(128)} add(%reduce_sum.253, %reduce_sum.254), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.900.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.760.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.900.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.759.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%sqrt.58.clone.1, %add.760.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.183.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%div.721.clone.1, %add.759.clone.1), metadata={op_name="multiply.33"} + %div.718.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%add.762.clone.1, %multiply.183.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1799.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_0.1113, %broadcast.515.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.758.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%div.718.clone.1, %mul.1799.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1798.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%mul.1800.clone.1, %add.758.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.757.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%param_0.1113, %mul.1798.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.160 = f32[128256,4096]{1,0:T(8,128)} multiply(%add.757.clone.1, %add.757.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.985 = f32[]{:T(128)} constant(0) + %reduce.81 = f32[]{:T(128)} reduce(%square.160, %constant.985), dimensions={0,1}, to_apply=%region_60.65, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.83.clone.1 = f32[]{:T(128)} reduce(%integer_pow.60.clone.1, %constant.985), dimensions={0,1}, to_apply=%region_46.51, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.135 = (f32[]{:T(128)}, f32[128256,4096]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.81, %add.757.clone.1, %add.761.clone.1, %add.762.clone.1, %reduce.83.clone.1) +} + +%region_59.64 (reduce_sum.382: f32[], reduce_sum.383: f32[]) -> f32[] { + %reduce_sum.382 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.383 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.384 = f32[]{:T(128)} add(%reduce_sum.382, %reduce_sum.383), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_45.50 (reduce_sum.307: f32[], reduce_sum.308: f32[]) -> f32[] { + %reduce_sum.307 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.308 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.312 = f32[]{:T(128)} add(%reduce_sum.307, %reduce_sum.308), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.222 (param_0.1107: f32[4096,128256], param_1.1270: f32[], param_2.1101: f32[], param_3.790: f32[], param_4.491: f32[4096,128256], param_5.416: f32[], param_6.288: bf16[4096,128256,1], param_7.187: pred[], param_8.105: f32[4096,128256]) -> (f32[], f32[4096,128256], f32[4096,128256], f32[4096,128256], f32[]) { - %param_0.1107 = f32[4096,128256]{1,0:T(8,128)} parameter(0) - %param_3.790 = f32[]{:T(128)S(6)} parameter(3) - %mul.1492.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_3.790), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.222 (param_0.1114: f32[4096,128256], param_1.1274: f32[], param_2.1095: f32[], param_3.778: f32[], param_4.485: f32[4096,128256], param_5.416: f32[], param_6.289: bf16[4096,128256,1], param_7.187: pred[], param_8.105: f32[4096,128256]) -> (f32[], f32[4096,128256], f32[4096,128256], f32[4096,128256], f32[]) { + %param_0.1114 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %param_3.778 = f32[]{:T(128)S(6)} parameter(3) + %mul.1810.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_3.778), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.187 = pred[]{:T(512)S(6)} parameter(7) %select_n.246.clone.1 = pred[4096,128256]{1,0:T(8,128)(4,1)} broadcast(%param_7.187), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.288 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} parameter(6) - %bitcast.409.clone.1 = bf16[4096,128256]{1,0:T(8,128)(2,1)} bitcast(%param_6.288), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %convert_element_type.1019.clone.1 = f32[4096,128256]{1,0:T(8,128)} convert(%bitcast.409.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %param_6.289 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} parameter(6) + %bitcast.409.clone.1 = bf16[4096,128256]{1,0:T(8,128)(2,1)} bitcast(%param_6.289), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %convert_element_type.1035.clone.1 = f32[4096,128256]{1,0:T(8,128)} convert(%bitcast.409.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} %param_5.416 = f32[]{:T(128)} parameter(5) %div.733.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_5.416), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.732.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%convert_element_type.1019.clone.1, %div.733.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.245.clone.1 = f32[4096,128256]{1,0:T(8,128)} select(%select_n.246.clone.1, %convert_element_type.1019.clone.1, %div.732.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.913.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.556.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.913.clone.1), dimensions={}, metadata={op_name="broadcast.62"} - %mul.1498.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%select_n.245.clone.1, %broadcast.556.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.732.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%convert_element_type.1035.clone.1, %div.733.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.245.clone.1 = f32[4096,128256]{1,0:T(8,128)} select(%select_n.246.clone.1, %convert_element_type.1035.clone.1, %div.732.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.905.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.517.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.905.clone.1), dimensions={}, metadata={op_name="broadcast.62"} + %mul.1816.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%select_n.245.clone.1, %broadcast.517.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.105 = f32[4096,128256]{1,0:T(8,128)} parameter(8) - %constant.917.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1499.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.917.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1497.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_8.105, %mul.1499.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.782.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%mul.1498.clone.1, %mul.1497.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1101 = f32[]{:T(128)S(6)} parameter(2) - %div.729.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_2.1101), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.909.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1817.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.909.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1815.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_8.105, %mul.1817.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.768.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%mul.1816.clone.1, %mul.1815.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1095 = f32[]{:T(128)S(6)} parameter(2) + %div.729.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_2.1095), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.61.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%select_n.245.clone.1, %select_n.245.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.916.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1496.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.916.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1494.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%integer_pow.61.clone.1, %mul.1496.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.491 = f32[4096,128256]{1,0:T(8,128)} parameter(4) - %constant.915.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1495.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.915.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1493.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_4.491, %mul.1495.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.781.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%mul.1494.clone.1, %mul.1493.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1270 = f32[]{:T(128)S(6)} parameter(1) - %div.728.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_1.1270), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.727.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%add.781.clone.1, %div.728.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.908.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1814.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.908.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1812.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%integer_pow.61.clone.1, %mul.1814.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.485 = f32[4096,128256]{1,0:T(8,128)} parameter(4) + %constant.907.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1813.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.907.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1811.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_4.485, %mul.1813.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.767.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%mul.1812.clone.1, %mul.1811.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1274 = f32[]{:T(128)S(6)} parameter(1) + %div.728.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_1.1274), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.727.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%add.767.clone.1, %div.728.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.59.clone.1 = f32[4096,128256]{1,0:T(8,128)} sqrt(%div.727.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.914.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.780.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.914.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.779.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%sqrt.59.clone.1, %add.780.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.257.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%div.729.clone.1, %add.779.clone.1), metadata={op_name="multiply.41"} - %div.726.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%add.782.clone.1, %multiply.257.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1491.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_0.1107, %broadcast.556.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.778.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%div.726.clone.1, %mul.1491.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1490.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%mul.1492.clone.1, %add.778.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.777.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%param_0.1107, %mul.1490.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.161 = f32[4096,128256]{1,0:T(8,128)} multiply(%add.777.clone.1, %add.777.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.994 = f32[]{:T(128)} constant(0) - %reduce.121 = f32[]{:T(128)} reduce(%square.161, %constant.994), dimensions={0,1}, to_apply=%region_59.64, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.123.clone.1 = f32[]{:T(128)} reduce(%integer_pow.61.clone.1, %constant.994), dimensions={0,1}, to_apply=%region_45.50, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.136 = (f32[]{:T(128)}, f32[4096,128256]{1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.121, %add.777.clone.1, %add.781.clone.1, %add.782.clone.1, %reduce.123.clone.1) -} - -%region_25.30 (reduce_sum.154: f32[], reduce_sum.155: f32[]) -> f32[] { - %reduce_sum.154 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.155 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.156 = f32[]{:T(128)} add(%reduce_sum.154, %reduce_sum.155), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.239 (param_0.1124: f32[4,14336,4096]) -> f32[] { - %param_0.1124 = f32[4,14336,4096]{2,0,1:T(4,128)} parameter(0) - %bitcast.314 = f32[14336,4,4096]{2,1,0:T(4,128)} bitcast(%param_0.1124), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %constant.906.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.766.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.906.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.765.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%sqrt.59.clone.1, %add.766.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.184.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%div.729.clone.1, %add.765.clone.1), metadata={op_name="multiply.32"} + %div.726.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%add.768.clone.1, %multiply.184.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1809.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_0.1114, %broadcast.517.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.764.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%div.726.clone.1, %mul.1809.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1808.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%mul.1810.clone.1, %add.764.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.763.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%param_0.1114, %mul.1808.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.161 = f32[4096,128256]{1,0:T(8,128)} multiply(%add.763.clone.1, %add.763.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.986 = f32[]{:T(128)} constant(0) + %reduce.82 = f32[]{:T(128)} reduce(%square.161, %constant.986), dimensions={0,1}, to_apply=%region_59.64, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.84.clone.1 = f32[]{:T(128)} reduce(%integer_pow.61.clone.1, %constant.986), dimensions={0,1}, to_apply=%region_45.50, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.136 = (f32[]{:T(128)}, f32[4096,128256]{1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.82, %add.763.clone.1, %add.767.clone.1, %add.768.clone.1, %reduce.84.clone.1) +} + +%region_25.30 (reduce_sum.208: f32[], reduce_sum.209: f32[]) -> f32[] { + %reduce_sum.208 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.209 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.210 = f32[]{:T(128)} add(%reduce_sum.208, %reduce_sum.209), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.239 (param_0.1131: f32[4,14336,4096]) -> f32[] { + %param_0.1131 = f32[4,14336,4096]{2,0,1:T(4,128)} parameter(0) + %bitcast.314 = f32[14336,4,4096]{2,1,0:T(4,128)} bitcast(%param_0.1131), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.164 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%bitcast.314, %bitcast.314), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1011 = f32[]{:T(128)} constant(0) - ROOT %reduce.124 = f32[]{:T(128)} reduce(%square.164, %constant.1011), dimensions={0,1,2}, to_apply=%region_25.30, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %constant.1003 = f32[]{:T(128)} constant(0) + ROOT %reduce.85 = f32[]{:T(128)} reduce(%square.164, %constant.1003), dimensions={0,1,2}, to_apply=%region_25.30, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%region_24.29 (reduce_sum.148: f32[], reduce_sum.149: f32[]) -> f32[] { - %reduce_sum.148 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.149 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.150 = f32[]{:T(128)} add(%reduce_sum.148, %reduce_sum.149), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_24.29 (reduce_sum.202: f32[], reduce_sum.203: f32[]) -> f32[] { + %reduce_sum.202 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.203 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.207 = f32[]{:T(128)} add(%reduce_sum.202, %reduce_sum.203), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_23.28 (reduce_sum.142: f32[], reduce_sum.143: f32[]) -> f32[] { - %reduce_sum.142 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.143 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.147 = f32[]{:T(128)} add(%reduce_sum.142, %reduce_sum.143), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_23.28 (reduce_sum.196: f32[], reduce_sum.200: f32[]) -> f32[] { + %reduce_sum.196 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.200 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.201 = f32[]{:T(128)} add(%reduce_sum.196, %reduce_sum.200), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.241 (param_0.1125: f32[4,4096,14336], param_1.1284: f32[4,4096,14336]) -> (f32[], f32[]) { - %param_0.1125 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(0) - %bitcast.318 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_0.1125), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.241 (param_0.1132: f32[4,4096,14336], param_1.1288: f32[4,4096,14336]) -> (f32[], f32[]) { + %param_0.1132 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(0) + %bitcast.318 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_0.1132), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.167 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%bitcast.318, %bitcast.318), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1012 = f32[]{:T(128)} constant(0) - %reduce.125 = f32[]{:T(128)} reduce(%square.167, %constant.1012), dimensions={0,1,2}, to_apply=%region_24.29, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.1284 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(1) - %bitcast.322.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_1.1284), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %constant.1004 = f32[]{:T(128)} constant(0) + %reduce.86 = f32[]{:T(128)} reduce(%square.167, %constant.1004), dimensions={0,1,2}, to_apply=%region_24.29, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1288 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(1) + %bitcast.322.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_1.1288), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.170.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%bitcast.322.clone.1, %bitcast.322.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.126.clone.1 = f32[]{:T(128)} reduce(%square.170.clone.1, %constant.1012), dimensions={0,1,2}, to_apply=%region_23.28, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.155 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.125, %reduce.126.clone.1) + %reduce.87.clone.1 = f32[]{:T(128)} reduce(%square.170.clone.1, %constant.1004), dimensions={0,1,2}, to_apply=%region_23.28, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.155 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.86, %reduce.87.clone.1) } -%fused_computation.244 (param_0.694: f32[14336,4,4096]) -> bf16[4,14336,4096] { - %param_0.694 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(0) - %copy.234 = bf16[14336,4,4096]{2,0,1:T(8,128)(2,1)} copy(%param_0.694), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wo\'][\'kernel\']"} +%fused_computation.244 (param_0.699: f32[14336,4,4096]) -> bf16[4,14336,4096] { + %param_0.699 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(0) + %copy.234 = bf16[14336,4,4096]{2,0,1:T(8,128)(2,1)} copy(%param_0.699), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wo\'][\'kernel\']"} ROOT %bitcast.323 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} bitcast(%copy.234), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%fused_computation.245 (param_0.696: f32[4096,4,14336]) -> bf16[4,4096,14336] { - %param_0.696 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) - %copy.235 = bf16[4096,4,14336]{2,0,1:T(8,128)(2,1)} copy(%param_0.696), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_1\'][\'kernel\']"} +%fused_computation.245 (param_0.701: f32[4096,4,14336]) -> bf16[4,4096,14336] { + %param_0.701 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %copy.235 = bf16[4096,4,14336]{2,0,1:T(8,128)(2,1)} copy(%param_0.701), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_1\'][\'kernel\']"} ROOT %bitcast.324 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} bitcast(%copy.235), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%fused_computation.246 (param_0.698: f32[4096,4,14336]) -> bf16[4,4096,14336] { - %param_0.698 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) - %copy.236 = bf16[4096,4,14336]{2,0,1:T(8,128)(2,1)} copy(%param_0.698), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_0\'][\'kernel\']"} +%fused_computation.246 (param_0.703: f32[4096,4,14336]) -> bf16[4,4096,14336] { + %param_0.703 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %copy.236 = bf16[4096,4,14336]{2,0,1:T(8,128)(2,1)} copy(%param_0.703), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_0\'][\'kernel\']"} ROOT %bitcast.325 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} bitcast(%copy.236), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%region_52.57 (reduce_sum.289: f32[], reduce_sum.290: f32[]) -> f32[] { - %reduce_sum.289 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.290 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.294 = f32[]{:T(128)} add(%reduce_sum.289, %reduce_sum.290), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_52.57 (reduce_sum.343: f32[], reduce_sum.347: f32[]) -> f32[] { + %reduce_sum.343 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.347 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.348 = f32[]{:T(128)} add(%reduce_sum.343, %reduce_sum.347), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_38.43 (reduce_sum.217: f32[], reduce_sum.218: f32[]) -> f32[] { - %reduce_sum.217 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.218 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.219 = f32[]{:T(128)} add(%reduce_sum.217, %reduce_sum.218), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_38.43 (reduce_sum.271: f32[], reduce_sum.272: f32[]) -> f32[] { + %reduce_sum.271 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.272 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.273 = f32[]{:T(128)} add(%reduce_sum.271, %reduce_sum.272), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.247 (param_0.1114: f32[14336,4,4096], param_1.1277: f32[], param_2.1108: f32[], param_3.797: f32[], param_4.498: f32[14336,4,4096], param_5.423: f32[], param_6.295: f32[4,14336,4096], param_7.194: pred[], param_8.112: f32[14336,4,4096]) -> (f32[], f32[14336,4,4096], f32[14336,4,4096], f32[14336,4,4096], f32[]) { - %param_0.1114 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(0) - %param_3.797 = f32[]{:T(128)S(6)} parameter(3) - %mul.1550.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_3.797), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.247 (param_0.1121: f32[14336,4,4096], param_1.1281: f32[], param_2.1102: f32[], param_3.785: f32[], param_4.492: f32[14336,4,4096], param_5.423: f32[], param_6.296: f32[4,14336,4096], param_7.194: pred[], param_8.112: f32[14336,4,4096]) -> (f32[], f32[14336,4,4096], f32[14336,4,4096], f32[14336,4,4096], f32[]) { + %param_0.1121 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(0) + %param_3.785 = f32[]{:T(128)S(6)} parameter(3) + %mul.1868.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_3.785), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.194 = pred[]{:T(512)S(6)} parameter(7) %select_n.274.clone.1 = pred[14336,4,4096]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.194), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.295 = f32[4,14336,4096]{2,0,1:T(4,128)} parameter(6) - %bitcast.423.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} bitcast(%param_6.295), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.296 = f32[4,14336,4096]{2,0,1:T(4,128)} parameter(6) + %bitcast.423.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} bitcast(%param_6.296), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.423 = f32[]{:T(128)} parameter(5) %div.789.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_5.423), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.788.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%bitcast.423.clone.1, %div.789.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.273.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} select(%select_n.274.clone.1, %bitcast.423.clone.1, %div.788.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.955.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.586.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.955.clone.1), dimensions={}, metadata={op_name="broadcast.69"} - %mul.1556.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%select_n.273.clone.1, %broadcast.586.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.947.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.547.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.947.clone.1), dimensions={}, metadata={op_name="broadcast.69"} + %mul.1874.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%select_n.273.clone.1, %broadcast.547.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.112 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(8) - %constant.959.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1557.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.959.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1555.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_8.112, %mul.1557.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.820.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%mul.1556.clone.1, %mul.1555.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1108 = f32[]{:T(128)S(6)} parameter(2) - %div.785.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_2.1108), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.951.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1875.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.951.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1873.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_8.112, %mul.1875.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.806.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%mul.1874.clone.1, %mul.1873.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1102 = f32[]{:T(128)S(6)} parameter(2) + %div.785.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_2.1102), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.68.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%select_n.273.clone.1, %select_n.273.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.958.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1554.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.958.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1552.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%integer_pow.68.clone.1, %mul.1554.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.498 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(4) - %constant.957.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1553.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.957.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1551.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_4.498, %mul.1553.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.819.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%mul.1552.clone.1, %mul.1551.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1277 = f32[]{:T(128)S(6)} parameter(1) - %div.784.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_1.1277), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.783.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%add.819.clone.1, %div.784.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.950.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1872.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.950.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1870.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%integer_pow.68.clone.1, %mul.1872.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.492 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(4) + %constant.949.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1871.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.949.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1869.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_4.492, %mul.1871.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.805.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%mul.1870.clone.1, %mul.1869.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1281 = f32[]{:T(128)S(6)} parameter(1) + %div.784.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_1.1281), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.783.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%add.805.clone.1, %div.784.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.66.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} sqrt(%div.783.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.956.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.818.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.956.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.817.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%sqrt.66.clone.1, %add.818.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.264.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%div.785.clone.1, %add.817.clone.1), metadata={op_name="multiply.34"} - %div.782.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%add.820.clone.1, %multiply.264.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1549.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_0.1114, %broadcast.586.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.816.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%div.782.clone.1, %mul.1549.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1548.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%mul.1550.clone.1, %add.816.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.815.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%param_0.1114, %mul.1548.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.171 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%add.815.clone.1, %add.815.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1001 = f32[]{:T(128)} constant(0) - %reduce.127 = f32[]{:T(128)} reduce(%square.171, %constant.1001), dimensions={0,1,2}, to_apply=%region_52.57, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.130.clone.1 = f32[]{:T(128)} reduce(%integer_pow.68.clone.1, %constant.1001), dimensions={0,1,2}, to_apply=%region_38.43, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.137 = (f32[]{:T(128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.127, %add.815.clone.1, %add.819.clone.1, %add.820.clone.1, %reduce.130.clone.1) + %constant.948.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.804.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.948.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.803.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%sqrt.66.clone.1, %add.804.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.191.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%div.785.clone.1, %add.803.clone.1), metadata={op_name="multiply.25"} + %div.782.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%add.806.clone.1, %multiply.191.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1867.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_0.1121, %broadcast.547.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.802.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%div.782.clone.1, %mul.1867.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1866.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%mul.1868.clone.1, %add.802.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.801.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%param_0.1121, %mul.1866.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.171 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%add.801.clone.1, %add.801.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.993 = f32[]{:T(128)} constant(0) + %reduce.88 = f32[]{:T(128)} reduce(%square.171, %constant.993), dimensions={0,1,2}, to_apply=%region_52.57, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.91.clone.1 = f32[]{:T(128)} reduce(%integer_pow.68.clone.1, %constant.993), dimensions={0,1,2}, to_apply=%region_38.43, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.137 = (f32[]{:T(128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.88, %add.801.clone.1, %add.805.clone.1, %add.806.clone.1, %reduce.91.clone.1) } -%region_51.56 (reduce_sum.283: f32[], reduce_sum.287: f32[]) -> f32[] { - %reduce_sum.283 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.287 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.288 = f32[]{:T(128)} add(%reduce_sum.283, %reduce_sum.287), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_51.56 (reduce_sum.340: f32[], reduce_sum.341: f32[]) -> f32[] { + %reduce_sum.340 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.341 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.342 = f32[]{:T(128)} add(%reduce_sum.340, %reduce_sum.341), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_37.42 (reduce_sum.211: f32[], reduce_sum.212: f32[]) -> f32[] { - %reduce_sum.211 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.212 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.213 = f32[]{:T(128)} add(%reduce_sum.211, %reduce_sum.212), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_37.42 (reduce_sum.265: f32[], reduce_sum.266: f32[]) -> f32[] { + %reduce_sum.265 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.266 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.270 = f32[]{:T(128)} add(%reduce_sum.265, %reduce_sum.266), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.248 (param_0.1115: f32[4096,4,14336], param_1.1278: f32[], param_2.1109: f32[], param_3.798: f32[], param_4.499: f32[4096,4,14336], param_5.424: f32[], param_6.296: f32[4,4096,14336], param_7.195: pred[], param_8.113: f32[4096,4,14336]) -> (f32[], f32[4096,4,14336], f32[4096,4,14336], f32[4096,4,14336], f32[]) { - %param_0.1115 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) - %param_3.798 = f32[]{:T(128)S(6)} parameter(3) - %mul.1560.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_3.798), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.248 (param_0.1122: f32[4096,4,14336], param_1.1282: f32[], param_2.1103: f32[], param_3.786: f32[], param_4.493: f32[4096,4,14336], param_5.424: f32[], param_6.297: f32[4,4096,14336], param_7.195: pred[], param_8.113: f32[4096,4,14336]) -> (f32[], f32[4096,4,14336], f32[4096,4,14336], f32[4096,4,14336], f32[]) { + %param_0.1122 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %param_3.786 = f32[]{:T(128)S(6)} parameter(3) + %mul.1878.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_3.786), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.195 = pred[]{:T(512)S(6)} parameter(7) %select_n.278.clone.1 = pred[4096,4,14336]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.195), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.296 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(6) - %bitcast.425.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_6.296), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.297 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(6) + %bitcast.425.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_6.297), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.424 = f32[]{:T(128)} parameter(5) %div.797.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_5.424), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.796.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%bitcast.425.clone.1, %div.797.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.277.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} select(%select_n.278.clone.1, %bitcast.425.clone.1, %div.796.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.961.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.592.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.961.clone.1), dimensions={}, metadata={op_name="broadcast.71"} - %mul.1564.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.277.clone.1, %broadcast.592.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.953.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.553.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.953.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.1882.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.277.clone.1, %broadcast.553.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.113 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(8) - %constant.965.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.591.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.965.clone.1), dimensions={}, metadata={op_name="broadcast.70"} - %mul.1563.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_8.113, %broadcast.591.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.825.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1564.clone.1, %mul.1563.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1109 = f32[]{:T(128)S(6)} parameter(2) - %div.793.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_2.1109), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.957.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.552.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.957.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.1881.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_8.113, %broadcast.552.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.811.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1882.clone.1, %mul.1881.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1103 = f32[]{:T(128)S(6)} parameter(2) + %div.793.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_2.1103), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.69.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.277.clone.1, %select_n.277.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.964.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.590.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.964.clone.1), dimensions={}, metadata={op_name="broadcast.60"} - %mul.1562.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%integer_pow.69.clone.1, %broadcast.590.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.499 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(4) - %constant.963.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.589.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.963.clone.1), dimensions={}, metadata={op_name="broadcast.59"} - %mul.1561.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_4.499, %broadcast.589.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.824.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1562.clone.1, %mul.1561.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1278 = f32[]{:T(128)S(6)} parameter(1) - %div.792.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_1.1278), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.791.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.824.clone.1, %div.792.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.956.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.551.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.956.clone.1), dimensions={}, metadata={op_name="broadcast.60"} + %mul.1880.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%integer_pow.69.clone.1, %broadcast.551.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.493 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(4) + %constant.955.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.550.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.955.clone.1), dimensions={}, metadata={op_name="broadcast.59"} + %mul.1879.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_4.493, %broadcast.550.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.810.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1880.clone.1, %mul.1879.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1282 = f32[]{:T(128)S(6)} parameter(1) + %div.792.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_1.1282), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.791.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.810.clone.1, %div.792.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.67.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} sqrt(%div.791.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.962.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.587.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.962.clone.1), dimensions={}, metadata={op_name="broadcast.54"} - %add.823.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%sqrt.67.clone.1, %broadcast.587.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.265.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%div.793.clone.1, %add.823.clone.1), metadata={op_name="multiply.33"} - %div.790.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.825.clone.1, %multiply.265.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1559.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_0.1115, %broadcast.592.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.822.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%div.790.clone.1, %mul.1559.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1558.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%mul.1560.clone.1, %add.822.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.821.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%param_0.1115, %mul.1558.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.172 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%add.821.clone.1, %add.821.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1002 = f32[]{:T(128)} constant(0) - %reduce.128 = f32[]{:T(128)} reduce(%square.172, %constant.1002), dimensions={0,1,2}, to_apply=%region_51.56, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.131.clone.1 = f32[]{:T(128)} reduce(%integer_pow.69.clone.1, %constant.1002), dimensions={0,1,2}, to_apply=%region_37.42, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.138 = (f32[]{:T(128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.128, %add.821.clone.1, %add.824.clone.1, %add.825.clone.1, %reduce.131.clone.1) + %constant.954.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.548.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.954.clone.1), dimensions={}, metadata={op_name="broadcast.54"} + %add.809.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%sqrt.67.clone.1, %broadcast.548.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.192.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%div.793.clone.1, %add.809.clone.1), metadata={op_name="multiply.24"} + %div.790.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.811.clone.1, %multiply.192.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1877.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_0.1122, %broadcast.553.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.808.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%div.790.clone.1, %mul.1877.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1876.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%mul.1878.clone.1, %add.808.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.807.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%param_0.1122, %mul.1876.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.172 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%add.807.clone.1, %add.807.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.994 = f32[]{:T(128)} constant(0) + %reduce.89 = f32[]{:T(128)} reduce(%square.172, %constant.994), dimensions={0,1,2}, to_apply=%region_51.56, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.92.clone.1 = f32[]{:T(128)} reduce(%integer_pow.69.clone.1, %constant.994), dimensions={0,1,2}, to_apply=%region_37.42, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.138 = (f32[]{:T(128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.89, %add.807.clone.1, %add.810.clone.1, %add.811.clone.1, %reduce.92.clone.1) } -%region_50.55 (reduce_sum.280: f32[], reduce_sum.281: f32[]) -> f32[] { - %reduce_sum.280 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.281 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.282 = f32[]{:T(128)} add(%reduce_sum.280, %reduce_sum.281), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_50.55 (reduce_sum.334: f32[], reduce_sum.335: f32[]) -> f32[] { + %reduce_sum.334 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.335 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.336 = f32[]{:T(128)} add(%reduce_sum.334, %reduce_sum.335), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_36.41 (reduce_sum.205: f32[], reduce_sum.206: f32[]) -> f32[] { - %reduce_sum.205 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.206 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.210 = f32[]{:T(128)} add(%reduce_sum.205, %reduce_sum.206), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_36.41 (reduce_sum.259: f32[], reduce_sum.263: f32[]) -> f32[] { + %reduce_sum.259 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.263 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.264 = f32[]{:T(128)} add(%reduce_sum.259, %reduce_sum.263), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.249 (param_0.1116: f32[4096,4,14336], param_1.1279: f32[], param_2.1110: f32[], param_3.799: f32[], param_4.500: f32[4096,4,14336], param_5.425: f32[], param_6.297: f32[4,4096,14336], param_7.196: pred[], param_8.114: f32[4096,4,14336]) -> (f32[], f32[4096,4,14336], f32[4096,4,14336], f32[4096,4,14336], f32[]) { - %param_0.1116 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) - %param_3.799 = f32[]{:T(128)S(6)} parameter(3) - %mul.1567.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_3.799), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.249 (param_0.1123: f32[4096,4,14336], param_1.1283: f32[], param_2.1104: f32[], param_3.787: f32[], param_4.494: f32[4096,4,14336], param_5.425: f32[], param_6.298: f32[4,4096,14336], param_7.196: pred[], param_8.114: f32[4096,4,14336]) -> (f32[], f32[4096,4,14336], f32[4096,4,14336], f32[4096,4,14336], f32[]) { + %param_0.1123 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %param_3.787 = f32[]{:T(128)S(6)} parameter(3) + %mul.1885.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_3.787), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.196 = pred[]{:T(512)S(6)} parameter(7) %select_n.282.clone.1 = pred[4096,4,14336]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.196), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.297 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(6) - %bitcast.427.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_6.297), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.298 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(6) + %bitcast.427.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_6.298), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.425 = f32[]{:T(128)} parameter(5) %div.805.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_5.425), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.804.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%bitcast.427.clone.1, %div.805.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.281.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} select(%select_n.282.clone.1, %bitcast.427.clone.1, %div.804.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.967.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.598.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.967.clone.1), dimensions={}, metadata={op_name="broadcast.71"} - %mul.1571.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.281.clone.1, %broadcast.598.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.959.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.559.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.959.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.1889.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.281.clone.1, %broadcast.559.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.114 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(8) - %constant.971.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.597.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.971.clone.1), dimensions={}, metadata={op_name="broadcast.70"} - %mul.1570.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_8.114, %broadcast.597.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.830.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1571.clone.1, %mul.1570.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1110 = f32[]{:T(128)S(6)} parameter(2) - %div.801.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_2.1110), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.963.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.558.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.963.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.1888.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_8.114, %broadcast.558.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.816.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1889.clone.1, %mul.1888.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1104 = f32[]{:T(128)S(6)} parameter(2) + %div.801.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_2.1104), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.70.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.281.clone.1, %select_n.281.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.970.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.596.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.970.clone.1), dimensions={}, metadata={op_name="broadcast.60"} - %mul.1569.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%integer_pow.70.clone.1, %broadcast.596.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.500 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(4) - %constant.969.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.595.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.969.clone.1), dimensions={}, metadata={op_name="broadcast.59"} - %mul.1568.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_4.500, %broadcast.595.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.829.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1569.clone.1, %mul.1568.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1279 = f32[]{:T(128)S(6)} parameter(1) - %div.800.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_1.1279), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.799.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.829.clone.1, %div.800.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.962.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.557.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.962.clone.1), dimensions={}, metadata={op_name="broadcast.60"} + %mul.1887.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%integer_pow.70.clone.1, %broadcast.557.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.494 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(4) + %constant.961.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.556.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.961.clone.1), dimensions={}, metadata={op_name="broadcast.59"} + %mul.1886.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_4.494, %broadcast.556.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.815.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1887.clone.1, %mul.1886.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1283 = f32[]{:T(128)S(6)} parameter(1) + %div.800.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_1.1283), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.799.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.815.clone.1, %div.800.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.68.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} sqrt(%div.799.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.968.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.593.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.968.clone.1), dimensions={}, metadata={op_name="broadcast.54"} - %add.828.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%sqrt.68.clone.1, %broadcast.593.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.266.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%div.801.clone.1, %add.828.clone.1), metadata={op_name="multiply.32"} - %div.798.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.830.clone.1, %multiply.266.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1566.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_0.1116, %broadcast.598.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.827.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%div.798.clone.1, %mul.1566.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1565.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%mul.1567.clone.1, %add.827.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.826.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%param_0.1116, %mul.1565.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.173 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%add.826.clone.1, %add.826.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1003 = f32[]{:T(128)} constant(0) - %reduce.129 = f32[]{:T(128)} reduce(%square.173, %constant.1003), dimensions={0,1,2}, to_apply=%region_50.55, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.132.clone.1 = f32[]{:T(128)} reduce(%integer_pow.70.clone.1, %constant.1003), dimensions={0,1,2}, to_apply=%region_36.41, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.139 = (f32[]{:T(128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.129, %add.826.clone.1, %add.829.clone.1, %add.830.clone.1, %reduce.132.clone.1) + %constant.960.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.554.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.960.clone.1), dimensions={}, metadata={op_name="broadcast.54"} + %add.814.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%sqrt.68.clone.1, %broadcast.554.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.193.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%div.801.clone.1, %add.814.clone.1), metadata={op_name="multiply.23"} + %div.798.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.816.clone.1, %multiply.193.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1884.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_0.1123, %broadcast.559.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.813.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%div.798.clone.1, %mul.1884.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1883.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%mul.1885.clone.1, %add.813.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.812.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%param_0.1123, %mul.1883.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.173 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%add.812.clone.1, %add.812.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.995 = f32[]{:T(128)} constant(0) + %reduce.90 = f32[]{:T(128)} reduce(%square.173, %constant.995), dimensions={0,1,2}, to_apply=%region_50.55, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.93.clone.1 = f32[]{:T(128)} reduce(%integer_pow.70.clone.1, %constant.995), dimensions={0,1,2}, to_apply=%region_36.41, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.139 = (f32[]{:T(128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.90, %add.812.clone.1, %add.815.clone.1, %add.816.clone.1, %reduce.93.clone.1) } -%region_30.35 (reduce_sum.178: f32[], reduce_sum.182: f32[]) -> f32[] { - %reduce_sum.178 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.182 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.183 = f32[]{:T(128)} add(%reduce_sum.178, %reduce_sum.182), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_30.35 (reduce_sum.235: f32[], reduce_sum.236: f32[]) -> f32[] { + %reduce_sum.235 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.236 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.237 = f32[]{:T(128)} add(%reduce_sum.235, %reduce_sum.236), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.267 (param_0.1120: f32[4,4096,32,128]) -> f32[] { - %param_0.1120 = f32[4,4096,32,128]{3,2,0,1:T(8,128)} parameter(0) - %bitcast.329 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1120), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.267 (param_0.1127: f32[4,4096,32,128]) -> f32[] { + %param_0.1127 = f32[4,4096,32,128]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.329 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1127), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.176 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%bitcast.329, %bitcast.329), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1007 = f32[]{:T(128)} constant(0) - ROOT %reduce.133 = f32[]{:T(128)} reduce(%square.176, %constant.1007), dimensions={0,1,2,3}, to_apply=%region_30.35, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %constant.999 = f32[]{:T(128)} constant(0) + ROOT %reduce.94 = f32[]{:T(128)} reduce(%square.176, %constant.999), dimensions={0,1,2,3}, to_apply=%region_30.35, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%region_29.34 (reduce_sum.175: f32[], reduce_sum.176: f32[]) -> f32[] { - %reduce_sum.175 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.176 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.177 = f32[]{:T(128)} add(%reduce_sum.175, %reduce_sum.176), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_29.34 (reduce_sum.229: f32[], reduce_sum.230: f32[]) -> f32[] { + %reduce_sum.229 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.230 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.231 = f32[]{:T(128)} add(%reduce_sum.229, %reduce_sum.230), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.269 (param_0.1121: f32[4,32,128,4096]) -> f32[] { - %param_0.1121 = f32[4,32,128,4096]{3,2,0,1:T(8,128)} parameter(0) - %bitcast.333 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} bitcast(%param_0.1121), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.269 (param_0.1128: f32[4,32,128,4096]) -> f32[] { + %param_0.1128 = f32[4,32,128,4096]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.333 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} bitcast(%param_0.1128), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.179 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%bitcast.333, %bitcast.333), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1008 = f32[]{:T(128)} constant(0) - ROOT %reduce.134 = f32[]{:T(128)} reduce(%square.179, %constant.1008), dimensions={0,1,2,3}, to_apply=%region_29.34, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %constant.1000 = f32[]{:T(128)} constant(0) + ROOT %reduce.95 = f32[]{:T(128)} reduce(%square.179, %constant.1000), dimensions={0,1,2,3}, to_apply=%region_29.34, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%fused_computation.270 (param_0.748: f32[32,4,128,4096]) -> bf16[4,32,128,4096] { - %param_0.748 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(0) - %copy.237 = bf16[32,4,128,4096]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.748), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'out\'][\'kernel\']"} +%fused_computation.270 (param_0.753: f32[32,4,128,4096]) -> bf16[4,32,128,4096] { + %param_0.753 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(0) + %copy.237 = bf16[32,4,128,4096]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.753), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'out\'][\'kernel\']"} ROOT %bitcast.334 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.237), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%region_57.62 (reduce_sum.317: f32[], reduce_sum.318: f32[]) -> f32[] { - %reduce_sum.317 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.318 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.316 = f32[]{:T(128)} add(%reduce_sum.317, %reduce_sum.318), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_57.62 (reduce_sum.370: f32[], reduce_sum.371: f32[]) -> f32[] { + %reduce_sum.370 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.371 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.375 = f32[]{:T(128)} add(%reduce_sum.370, %reduce_sum.371), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_43.48 (reduce_sum.241: f32[], reduce_sum.245: f32[]) -> f32[] { - %reduce_sum.241 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.245 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.246 = f32[]{:T(128)} add(%reduce_sum.241, %reduce_sum.245), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_43.48 (reduce_sum.298: f32[], reduce_sum.299: f32[]) -> f32[] { + %reduce_sum.298 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.299 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.300 = f32[]{:T(128)} add(%reduce_sum.298, %reduce_sum.299), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.271 (param_0.1109: f32[4096,4,32,128], param_1.1272: f32[], param_2.1103: f32[], param_3.792: f32[], param_4.493: f32[4096,4,32,128], param_5.418: f32[], param_6.290: f32[4,4096,32,128], param_7.189: pred[], param_8.107: f32[4096,4,32,128]) -> (f32[], f32[4096,4,32,128], f32[4096,4,32,128], f32[4096,4,32,128], f32[]) { - %param_0.1109 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(0) - %param_3.792 = f32[]{:T(128)S(6)} parameter(3) - %mul.1509.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_3.792), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.271 (param_0.1116: f32[4096,4,32,128], param_1.1276: f32[], param_2.1097: f32[], param_3.780: f32[], param_4.487: f32[4096,4,32,128], param_5.418: f32[], param_6.291: f32[4,4096,32,128], param_7.189: pred[], param_8.107: f32[4096,4,32,128]) -> (f32[], f32[4096,4,32,128], f32[4096,4,32,128], f32[4096,4,32,128], f32[]) { + %param_0.1116 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.780 = f32[]{:T(128)S(6)} parameter(3) + %mul.1827.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_3.780), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.189 = pred[]{:T(512)S(6)} parameter(7) %select_n.254.clone.1 = pred[4096,4,32,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.189), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.290 = f32[4,4096,32,128]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.413.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} bitcast(%param_6.290), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.291 = f32[4,4096,32,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.413.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} bitcast(%param_6.291), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.418 = f32[]{:T(128)} parameter(5) %div.749.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_5.418), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.748.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%bitcast.413.clone.1, %div.749.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.253.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} select(%select_n.254.clone.1, %bitcast.413.clone.1, %div.748.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.925.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.564.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.925.clone.1), dimensions={}, metadata={op_name="broadcast.63"} - %mul.1515.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%select_n.253.clone.1, %broadcast.564.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.917.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.525.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.917.clone.1), dimensions={}, metadata={op_name="broadcast.63"} + %mul.1833.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%select_n.253.clone.1, %broadcast.525.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.107 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(8) - %constant.929.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1516.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.929.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1514.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_8.107, %mul.1516.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.793.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%mul.1515.clone.1, %mul.1514.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1103 = f32[]{:T(128)S(6)} parameter(2) - %div.745.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1103), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.921.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1834.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.921.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1832.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_8.107, %mul.1834.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.779.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%mul.1833.clone.1, %mul.1832.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1097 = f32[]{:T(128)S(6)} parameter(2) + %div.745.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1097), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.63.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%select_n.253.clone.1, %select_n.253.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.928.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1513.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.928.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1511.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.63.clone.1, %mul.1513.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.493 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(4) - %constant.927.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1512.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.927.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1510.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_4.493, %mul.1512.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.792.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%mul.1511.clone.1, %mul.1510.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1272 = f32[]{:T(128)S(6)} parameter(1) - %div.744.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1272), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.743.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%add.792.clone.1, %div.744.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.920.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1831.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.920.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1829.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.63.clone.1, %mul.1831.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.487 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.919.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1830.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.919.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1828.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_4.487, %mul.1830.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.778.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%mul.1829.clone.1, %mul.1828.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1276 = f32[]{:T(128)S(6)} parameter(1) + %div.744.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1276), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.743.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%add.778.clone.1, %div.744.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.61.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} sqrt(%div.743.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.926.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.791.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.926.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.790.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%sqrt.61.clone.1, %add.791.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.259.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%div.745.clone.1, %add.790.clone.1), metadata={op_name="multiply.39"} - %div.742.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%add.793.clone.1, %multiply.259.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1508.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_0.1109, %broadcast.564.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.789.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%div.742.clone.1, %mul.1508.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1507.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%mul.1509.clone.1, %add.789.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.788.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%param_0.1109, %mul.1507.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.180 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%add.788.clone.1, %add.788.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.996 = f32[]{:T(128)} constant(0) - %reduce.135 = f32[]{:T(128)} reduce(%square.180, %constant.996), dimensions={0,1,2,3}, to_apply=%region_57.62, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.139.clone.1 = f32[]{:T(128)} reduce(%integer_pow.63.clone.1, %constant.996), dimensions={0,1,2,3}, to_apply=%region_43.48, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.140 = (f32[]{:T(128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.135, %add.788.clone.1, %add.792.clone.1, %add.793.clone.1, %reduce.139.clone.1) -} - -%region_56.61 (reduce_sum.310: f32[], reduce_sum.311: f32[]) -> f32[] { - %reduce_sum.310 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.311 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.315 = f32[]{:T(128)} add(%reduce_sum.310, %reduce_sum.311), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_42.47 (reduce_sum.238: f32[], reduce_sum.239: f32[]) -> f32[] { - %reduce_sum.238 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.239 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.240 = f32[]{:T(128)} add(%reduce_sum.238, %reduce_sum.239), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.272 (param_0.1110: f32[32,4,128,4096], param_1.1273: f32[], param_2.1104: f32[], param_3.793: f32[], param_4.494: f32[32,4,128,4096], param_5.419: f32[], param_6.291: f32[4,32,128,4096], param_7.190: pred[], param_8.108: f32[32,4,128,4096]) -> (f32[], f32[32,4,128,4096], f32[32,4,128,4096], f32[32,4,128,4096], f32[]) { - %param_0.1110 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(0) - %param_3.793 = f32[]{:T(128)S(6)} parameter(3) - %mul.1519.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_3.793), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.918.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.777.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.918.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.776.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%sqrt.61.clone.1, %add.777.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.186.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%div.745.clone.1, %add.776.clone.1), metadata={op_name="multiply.30"} + %div.742.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%add.779.clone.1, %multiply.186.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1826.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_0.1116, %broadcast.525.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.775.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%div.742.clone.1, %mul.1826.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1825.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%mul.1827.clone.1, %add.775.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.774.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%param_0.1116, %mul.1825.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.180 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%add.774.clone.1, %add.774.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.988 = f32[]{:T(128)} constant(0) + %reduce.96 = f32[]{:T(128)} reduce(%square.180, %constant.988), dimensions={0,1,2,3}, to_apply=%region_57.62, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.100.clone.1 = f32[]{:T(128)} reduce(%integer_pow.63.clone.1, %constant.988), dimensions={0,1,2,3}, to_apply=%region_43.48, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.140 = (f32[]{:T(128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.96, %add.774.clone.1, %add.778.clone.1, %add.779.clone.1, %reduce.100.clone.1) +} + +%region_56.61 (reduce_sum.364: f32[], reduce_sum.368: f32[]) -> f32[] { + %reduce_sum.364 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.368 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.369 = f32[]{:T(128)} add(%reduce_sum.364, %reduce_sum.368), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_42.47 (reduce_sum.292: f32[], reduce_sum.293: f32[]) -> f32[] { + %reduce_sum.292 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.293 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.294 = f32[]{:T(128)} add(%reduce_sum.292, %reduce_sum.293), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.272 (param_0.1117: f32[32,4,128,4096], param_1.1277: f32[], param_2.1098: f32[], param_3.781: f32[], param_4.488: f32[32,4,128,4096], param_5.419: f32[], param_6.292: f32[4,32,128,4096], param_7.190: pred[], param_8.108: f32[32,4,128,4096]) -> (f32[], f32[32,4,128,4096], f32[32,4,128,4096], f32[32,4,128,4096], f32[]) { + %param_0.1117 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(0) + %param_3.781 = f32[]{:T(128)S(6)} parameter(3) + %mul.1837.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_3.781), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.190 = pred[]{:T(512)S(6)} parameter(7) %select_n.258.clone.1 = pred[32,4,128,4096]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.190), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.291 = f32[4,32,128,4096]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.415.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} bitcast(%param_6.291), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.292 = f32[4,32,128,4096]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.415.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} bitcast(%param_6.292), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.419 = f32[]{:T(128)} parameter(5) %div.757.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_5.419), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.756.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%bitcast.415.clone.1, %div.757.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.257.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} select(%select_n.258.clone.1, %bitcast.415.clone.1, %div.756.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.931.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.566.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.931.clone.1), dimensions={}, metadata={op_name="broadcast.64"} - %mul.1525.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%select_n.257.clone.1, %broadcast.566.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.923.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.527.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.923.clone.1), dimensions={}, metadata={op_name="broadcast.64"} + %mul.1843.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%select_n.257.clone.1, %broadcast.527.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.108 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(8) - %constant.935.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1526.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.935.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1524.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_8.108, %mul.1526.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.799.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%mul.1525.clone.1, %mul.1524.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1104 = f32[]{:T(128)S(6)} parameter(2) - %div.753.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_2.1104), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.927.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1844.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.927.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1842.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_8.108, %mul.1844.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.785.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%mul.1843.clone.1, %mul.1842.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1098 = f32[]{:T(128)S(6)} parameter(2) + %div.753.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_2.1098), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.64.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%select_n.257.clone.1, %select_n.257.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.934.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1523.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.934.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1521.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%integer_pow.64.clone.1, %mul.1523.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.494 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(4) - %constant.933.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1522.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.933.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1520.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_4.494, %mul.1522.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.798.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%mul.1521.clone.1, %mul.1520.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1273 = f32[]{:T(128)S(6)} parameter(1) - %div.752.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_1.1273), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.751.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%add.798.clone.1, %div.752.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.926.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1841.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.926.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1839.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%integer_pow.64.clone.1, %mul.1841.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.488 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(4) + %constant.925.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1840.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.925.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1838.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_4.488, %mul.1840.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.784.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%mul.1839.clone.1, %mul.1838.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1277 = f32[]{:T(128)S(6)} parameter(1) + %div.752.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_1.1277), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.751.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%add.784.clone.1, %div.752.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.62.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} sqrt(%div.751.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.932.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.797.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.932.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.796.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%sqrt.62.clone.1, %add.797.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.260.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%div.753.clone.1, %add.796.clone.1), metadata={op_name="multiply.38"} - %div.750.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%add.799.clone.1, %multiply.260.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1518.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_0.1110, %broadcast.566.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.795.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%div.750.clone.1, %mul.1518.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1517.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%mul.1519.clone.1, %add.795.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.794.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%param_0.1110, %mul.1517.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.181 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%add.794.clone.1, %add.794.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.997 = f32[]{:T(128)} constant(0) - %reduce.136 = f32[]{:T(128)} reduce(%square.181, %constant.997), dimensions={0,1,2,3}, to_apply=%region_56.61, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.140.clone.1 = f32[]{:T(128)} reduce(%integer_pow.64.clone.1, %constant.997), dimensions={0,1,2,3}, to_apply=%region_42.47, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.141 = (f32[]{:T(128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.136, %add.794.clone.1, %add.798.clone.1, %add.799.clone.1, %reduce.140.clone.1) -} - -%region_47.52 (reduce_sum.262: f32[], reduce_sum.266: f32[]) -> f32[] { - %reduce_sum.262 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.266 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.267 = f32[]{:T(128)} add(%reduce_sum.262, %reduce_sum.266), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.279 (param_0.1129: bf16[4,128,128256], param_1.1288: f32[4,128], param_2.1115: s32[4,128], param_3.803: bf16[4,128]) -> f32[4,128] { - %param_2.1115 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %eq.30 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1115), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %constant.924.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.783.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.924.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.782.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%sqrt.62.clone.1, %add.783.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.187.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%div.753.clone.1, %add.782.clone.1), metadata={op_name="multiply.29"} + %div.750.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%add.785.clone.1, %multiply.187.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1836.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_0.1117, %broadcast.527.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.781.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%div.750.clone.1, %mul.1836.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1835.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%mul.1837.clone.1, %add.781.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.780.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%param_0.1117, %mul.1835.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.181 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%add.780.clone.1, %add.780.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.989 = f32[]{:T(128)} constant(0) + %reduce.97 = f32[]{:T(128)} reduce(%square.181, %constant.989), dimensions={0,1,2,3}, to_apply=%region_56.61, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.101.clone.1 = f32[]{:T(128)} reduce(%integer_pow.64.clone.1, %constant.989), dimensions={0,1,2,3}, to_apply=%region_42.47, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.141 = (f32[]{:T(128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.97, %add.780.clone.1, %add.784.clone.1, %add.785.clone.1, %reduce.101.clone.1) +} + +%region_47.52 (reduce_sum.319: f32[], reduce_sum.320: f32[]) -> f32[] { + %reduce_sum.319 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.320 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.321 = f32[]{:T(128)} add(%reduce_sum.319, %reduce_sum.320), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.279 (param_0.1136: bf16[4,128,128256], param_1.1292: f32[4,128], param_2.1109: s32[4,128], param_3.791: bf16[4,128]) -> f32[4,128] { + %param_2.1109 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %eq.30 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1109), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.25 = s32[4,128,128256]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.24 = pred[4,128,128256]{2,1,0:T(8,128)(4,1)} compare(%eq.30, %eq.25), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %param_0.1129 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.950 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1129), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_3.803 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) - %sub.73 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.803), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.64 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.950, %sub.73), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %param_1.1288 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %sub.71 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1288), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %param_0.1136 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.966 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1136), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_3.791 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) + %sub.73 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.791), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.64 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.966, %sub.73), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %param_1.1292 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %sub.71 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1292), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %sub.60 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%sub.64, %sub.71), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %constant.1017 = f32[]{:T(128)} constant(0) - %broadcast.511 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%constant.1017), dimensions={}, metadata={op_name="broadcast.83"} - %mul.1373 = f32[4,128,128256]{2,1,0:T(8,128)} select(%eq.24, %sub.60, %broadcast.511), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - ROOT %reduce.137 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1373, %constant.1017), dimensions={2}, to_apply=%region_47.52, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %constant.1009 = f32[]{:T(128)} constant(0) + %broadcast.472 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%constant.1009), dimensions={}, metadata={op_name="broadcast.39"} + %mul.1674 = f32[4,128,128256]{2,1,0:T(8,128)} select(%eq.24, %sub.60, %broadcast.472), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + ROOT %reduce.98 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1674, %constant.1009), dimensions={2}, to_apply=%region_47.52, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} } -%region_7.10 (reduce_sum.93: f32[], reduce_sum.94: f32[]) -> f32[] { - %reduce_sum.93 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.94 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.95 = f32[]{:T(128)} add(%reduce_sum.93, %reduce_sum.94), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_7.10 (reduce_sum.123: f32[], reduce_sum.127: f32[]) -> f32[] { + %reduce_sum.123 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.127 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.128 = f32[]{:T(128)} add(%reduce_sum.123, %reduce_sum.127), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.284 (param_0.1130: bf16[4,128,128256], param_1.1289: bf16[4,128]) -> f32[4,128] { - %param_0.1130 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.956 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1130), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_1.1289 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) - %sub.74 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1289), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.70 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.956, %sub.74), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} +%fused_computation.284 (param_0.1137: bf16[4,128,128256], param_1.1293: bf16[4,128]) -> f32[4,128] { + %param_0.1137 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.972 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1137), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_1.1293 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) + %sub.74 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1293), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.70 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.972, %sub.74), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %exp.54 = f32[4,128,128256]{2,1,0:T(8,128)} exponential(%sub.70), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %constant.1018 = f32[]{:T(128)} constant(0) - ROOT %reduce.138 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.54, %constant.1018), dimensions={2}, to_apply=%region_7.10, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %constant.1010 = f32[]{:T(128)} constant(0) + ROOT %reduce.99 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.54, %constant.1010), dimensions={2}, to_apply=%region_7.10, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} } -%region_31.36 (reduce_sum.184: f32[], reduce_sum.185: f32[]) -> f32[] { - %reduce_sum.184 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.185 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.189 = f32[]{:T(128)} add(%reduce_sum.184, %reduce_sum.185), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_31.36 (reduce_sum.238: f32[], reduce_sum.242: f32[]) -> f32[] { + %reduce_sum.238 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.242 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.243 = f32[]{:T(128)} add(%reduce_sum.238, %reduce_sum.242), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_28.33 (reduce_sum.169: f32[], reduce_sum.170: f32[]) -> f32[] { - %reduce_sum.169 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.170 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.171 = f32[]{:T(128)} add(%reduce_sum.169, %reduce_sum.170), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_28.33 (reduce_sum.223: f32[], reduce_sum.224: f32[]) -> f32[] { + %reduce_sum.223 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.224 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.228 = f32[]{:T(128)} add(%reduce_sum.223, %reduce_sum.224), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.290 (param_0.1122: f32[4,4096,8,128], param_1.1282: f32[4,4096,8,128]) -> (f32[], f32[]) { - %param_0.1122 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(0) - %bitcast.350 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1122), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.290 (param_0.1129: f32[4,4096,8,128], param_1.1286: f32[4,4096,8,128]) -> (f32[], f32[]) { + %param_0.1129 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.350 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1129), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.184 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.350, %bitcast.350), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1009 = f32[]{:T(128)} constant(0) - %reduce.141 = f32[]{:T(128)} reduce(%square.184, %constant.1009), dimensions={0,1,2,3}, to_apply=%region_31.36, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.1282 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(1) - %bitcast.354.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_1.1282), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %constant.1001 = f32[]{:T(128)} constant(0) + %reduce.102 = f32[]{:T(128)} reduce(%square.184, %constant.1001), dimensions={0,1,2,3}, to_apply=%region_31.36, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1286 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(1) + %bitcast.354.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_1.1286), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.187.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.354.clone.1, %bitcast.354.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.142.clone.1 = f32[]{:T(128)} reduce(%square.187.clone.1, %constant.1009), dimensions={0,1,2,3}, to_apply=%region_28.33, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.156 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.141, %reduce.142.clone.1) + %reduce.103.clone.1 = f32[]{:T(128)} reduce(%square.187.clone.1, %constant.1001), dimensions={0,1,2,3}, to_apply=%region_28.33, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.156 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.102, %reduce.103.clone.1) } -%fused_computation.293 (param_0.807: f32[4096,4,8,128]) -> bf16[4,4096,8,128] { - %param_0.807 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) - %copy.238 = bf16[4096,4,8,128]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.807), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'value\'][\'kernel\']"} +%fused_computation.293 (param_0.812: f32[4096,4,8,128]) -> bf16[4,4096,8,128] { + %param_0.812 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %copy.238 = bf16[4096,4,8,128]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.812), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'value\'][\'kernel\']"} ROOT %bitcast.355 = bf16[4,4096,8,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%copy.238), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%region_58.63 (reduce_sum.324: f32[], reduce_sum.325: f32[]) -> f32[] { - %reduce_sum.324 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.325 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.322 = f32[]{:T(128)} add(%reduce_sum.324, %reduce_sum.325), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_58.63 (reduce_sum.376: f32[], reduce_sum.377: f32[]) -> f32[] { + %reduce_sum.376 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.377 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.378 = f32[]{:T(128)} add(%reduce_sum.376, %reduce_sum.377), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_44.49 (reduce_sum.247: f32[], reduce_sum.248: f32[]) -> f32[] { - %reduce_sum.247 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.248 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.252 = f32[]{:T(128)} add(%reduce_sum.247, %reduce_sum.248), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_44.49 (reduce_sum.301: f32[], reduce_sum.305: f32[]) -> f32[] { + %reduce_sum.301 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.305 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.306 = f32[]{:T(128)} add(%reduce_sum.301, %reduce_sum.305), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.294 (param_0.1108: f32[4096,4,8,128], param_1.1271: f32[], param_2.1102: f32[], param_3.791: f32[], param_4.492: f32[4096,4,8,128], param_5.417: f32[], param_6.289: f32[4,4096,8,128], param_7.188: pred[], param_8.106: f32[4096,4,8,128]) -> (f32[], f32[4096,4,8,128], f32[4096,4,8,128], f32[4096,4,8,128], f32[]) { - %param_0.1108 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) - %param_3.791 = f32[]{:T(128)S(6)} parameter(3) - %mul.1502.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.791), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.294 (param_0.1115: f32[4096,4,8,128], param_1.1275: f32[], param_2.1096: f32[], param_3.779: f32[], param_4.486: f32[4096,4,8,128], param_5.417: f32[], param_6.290: f32[4,4096,8,128], param_7.188: pred[], param_8.106: f32[4096,4,8,128]) -> (f32[], f32[4096,4,8,128], f32[4096,4,8,128], f32[4096,4,8,128], f32[]) { + %param_0.1115 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.779 = f32[]{:T(128)S(6)} parameter(3) + %mul.1820.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.779), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.188 = pred[]{:T(512)S(6)} parameter(7) %select_n.250.clone.1 = pred[4096,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.188), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.289 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.411.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.289), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.290 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.411.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.290), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.417 = f32[]{:T(128)} parameter(5) %div.741.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.417), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.740.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.411.clone.1, %div.741.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.249.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.250.clone.1, %bitcast.411.clone.1, %div.740.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.919.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.562.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.919.clone.1), dimensions={}, metadata={op_name="broadcast.66"} - %mul.1506.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.249.clone.1, %broadcast.562.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.911.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.523.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.911.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.1824.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.249.clone.1, %broadcast.523.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.106 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(8) - %constant.923.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.561.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.923.clone.1), dimensions={}, metadata={op_name="broadcast.65"} - %mul.1505.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.106, %broadcast.561.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.787.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1506.clone.1, %mul.1505.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1102 = f32[]{:T(128)S(6)} parameter(2) - %div.737.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1102), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.915.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.522.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.915.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %mul.1823.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.106, %broadcast.522.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.773.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1824.clone.1, %mul.1823.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1096 = f32[]{:T(128)S(6)} parameter(2) + %div.737.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1096), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.62.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.249.clone.1, %select_n.249.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.922.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.560.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.922.clone.1), dimensions={}, metadata={op_name="broadcast.56"} - %mul.1504.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.62.clone.1, %broadcast.560.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.492 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(4) - %constant.921.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.559.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.921.clone.1), dimensions={}, metadata={op_name="broadcast.55"} - %mul.1503.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.492, %broadcast.559.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.786.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1504.clone.1, %mul.1503.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1271 = f32[]{:T(128)S(6)} parameter(1) - %div.736.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1271), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.735.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.786.clone.1, %div.736.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.914.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.521.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.914.clone.1), dimensions={}, metadata={op_name="broadcast.56"} + %mul.1822.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.62.clone.1, %broadcast.521.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.486 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.913.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.520.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.913.clone.1), dimensions={}, metadata={op_name="broadcast.55"} + %mul.1821.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.486, %broadcast.520.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.772.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1822.clone.1, %mul.1821.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1275 = f32[]{:T(128)S(6)} parameter(1) + %div.736.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1275), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.735.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.772.clone.1, %div.736.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.60.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.735.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.920.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.557.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.920.clone.1), dimensions={}, metadata={op_name="broadcast.52"} - %add.785.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.60.clone.1, %broadcast.557.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.258.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.737.clone.1, %add.785.clone.1), metadata={op_name="multiply.40"} - %div.734.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.787.clone.1, %multiply.258.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1501.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1108, %broadcast.562.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.784.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%div.734.clone.1, %mul.1501.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1500.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1502.clone.1, %add.784.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.783.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1108, %mul.1500.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.188 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.783.clone.1, %add.783.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.995 = f32[]{:T(128)} constant(0) - %reduce.143 = f32[]{:T(128)} reduce(%square.188, %constant.995), dimensions={0,1,2,3}, to_apply=%region_58.63, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.145.clone.1 = f32[]{:T(128)} reduce(%integer_pow.62.clone.1, %constant.995), dimensions={0,1,2,3}, to_apply=%region_44.49, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.142 = (f32[]{:T(128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.143, %add.783.clone.1, %add.786.clone.1, %add.787.clone.1, %reduce.145.clone.1) -} - -%region_55.60 (reduce_sum.304: f32[], reduce_sum.308: f32[]) -> f32[] { - %reduce_sum.304 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.308 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.309 = f32[]{:T(128)} add(%reduce_sum.304, %reduce_sum.308), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_41.46 (reduce_sum.232: f32[], reduce_sum.233: f32[]) -> f32[] { - %reduce_sum.232 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.233 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.234 = f32[]{:T(128)} add(%reduce_sum.232, %reduce_sum.233), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.912.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.518.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.912.clone.1), dimensions={}, metadata={op_name="broadcast.52"} + %add.771.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.60.clone.1, %broadcast.518.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.185.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.737.clone.1, %add.771.clone.1), metadata={op_name="multiply.31"} + %div.734.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.773.clone.1, %multiply.185.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1819.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1115, %broadcast.523.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.770.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%div.734.clone.1, %mul.1819.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1818.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1820.clone.1, %add.770.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.769.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1115, %mul.1818.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.188 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.769.clone.1, %add.769.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.987 = f32[]{:T(128)} constant(0) + %reduce.104 = f32[]{:T(128)} reduce(%square.188, %constant.987), dimensions={0,1,2,3}, to_apply=%region_58.63, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.106.clone.1 = f32[]{:T(128)} reduce(%integer_pow.62.clone.1, %constant.987), dimensions={0,1,2,3}, to_apply=%region_44.49, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.142 = (f32[]{:T(128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.104, %add.769.clone.1, %add.772.clone.1, %add.773.clone.1, %reduce.106.clone.1) +} + +%region_55.60 (reduce_sum.361: f32[], reduce_sum.362: f32[]) -> f32[] { + %reduce_sum.361 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.362 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.363 = f32[]{:T(128)} add(%reduce_sum.361, %reduce_sum.362), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_41.46 (reduce_sum.286: f32[], reduce_sum.287: f32[]) -> f32[] { + %reduce_sum.286 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.287 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.291 = f32[]{:T(128)} add(%reduce_sum.286, %reduce_sum.287), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.295 (param_0.1111: f32[4096,4,8,128], param_1.1274: f32[], param_2.1105: f32[], param_3.794: f32[], param_4.495: f32[4096,4,8,128], param_5.420: f32[], param_6.292: f32[4,4096,8,128], param_7.191: pred[], param_8.109: f32[4096,4,8,128]) -> (f32[], f32[4096,4,8,128], f32[4096,4,8,128], f32[4096,4,8,128], f32[]) { - %param_0.1111 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) - %param_3.794 = f32[]{:T(128)S(6)} parameter(3) - %mul.1529.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.794), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.295 (param_0.1118: f32[4096,4,8,128], param_1.1278: f32[], param_2.1099: f32[], param_3.782: f32[], param_4.489: f32[4096,4,8,128], param_5.420: f32[], param_6.293: f32[4,4096,8,128], param_7.191: pred[], param_8.109: f32[4096,4,8,128]) -> (f32[], f32[4096,4,8,128], f32[4096,4,8,128], f32[4096,4,8,128], f32[]) { + %param_0.1118 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.782 = f32[]{:T(128)S(6)} parameter(3) + %mul.1847.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.782), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.191 = pred[]{:T(512)S(6)} parameter(7) %select_n.262.clone.1 = pred[4096,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.191), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.292 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.417.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.292), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.293 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.417.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.293), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.420 = f32[]{:T(128)} parameter(5) %div.765.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.420), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.764.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.417.clone.1, %div.765.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.261.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.262.clone.1, %bitcast.417.clone.1, %div.764.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.937.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.572.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.937.clone.1), dimensions={}, metadata={op_name="broadcast.66"} - %mul.1533.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.261.clone.1, %broadcast.572.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.929.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.533.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.929.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.1851.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.261.clone.1, %broadcast.533.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.109 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(8) - %constant.941.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.571.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.941.clone.1), dimensions={}, metadata={op_name="broadcast.65"} - %mul.1532.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.109, %broadcast.571.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.804.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1533.clone.1, %mul.1532.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1105 = f32[]{:T(128)S(6)} parameter(2) - %div.761.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1105), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.933.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.532.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.933.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %mul.1850.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.109, %broadcast.532.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.790.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1851.clone.1, %mul.1850.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1099 = f32[]{:T(128)S(6)} parameter(2) + %div.761.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1099), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.65.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.261.clone.1, %select_n.261.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.940.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.570.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.940.clone.1), dimensions={}, metadata={op_name="broadcast.56"} - %mul.1531.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.65.clone.1, %broadcast.570.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.495 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(4) - %constant.939.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.569.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.939.clone.1), dimensions={}, metadata={op_name="broadcast.55"} - %mul.1530.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.495, %broadcast.569.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.803.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1531.clone.1, %mul.1530.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1274 = f32[]{:T(128)S(6)} parameter(1) - %div.760.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1274), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.759.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.803.clone.1, %div.760.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.932.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.531.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.932.clone.1), dimensions={}, metadata={op_name="broadcast.56"} + %mul.1849.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.65.clone.1, %broadcast.531.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.489 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.931.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.530.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.931.clone.1), dimensions={}, metadata={op_name="broadcast.55"} + %mul.1848.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.489, %broadcast.530.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.789.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1849.clone.1, %mul.1848.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1278 = f32[]{:T(128)S(6)} parameter(1) + %div.760.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1278), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.759.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.789.clone.1, %div.760.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.63.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.759.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.938.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.567.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.938.clone.1), dimensions={}, metadata={op_name="broadcast.52"} - %add.802.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.63.clone.1, %broadcast.567.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.261.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.761.clone.1, %add.802.clone.1), metadata={op_name="multiply.37"} - %div.758.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.804.clone.1, %multiply.261.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1528.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1111, %broadcast.572.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.801.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%div.758.clone.1, %mul.1528.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1527.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1529.clone.1, %add.801.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.800.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1111, %mul.1527.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.189 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.800.clone.1, %add.800.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.998 = f32[]{:T(128)} constant(0) - %reduce.144 = f32[]{:T(128)} reduce(%square.189, %constant.998), dimensions={0,1,2,3}, to_apply=%region_55.60, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.146.clone.1 = f32[]{:T(128)} reduce(%integer_pow.65.clone.1, %constant.998), dimensions={0,1,2,3}, to_apply=%region_41.46, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.143 = (f32[]{:T(128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.144, %add.800.clone.1, %add.803.clone.1, %add.804.clone.1, %reduce.146.clone.1) -} - -%fused_computation.311 (param_0.872: bf16[4,128,4096], param_1.941: f32[4,128], param_2.726: f32[4,128], param_3.452: bf16[4,128,4096], param_4.271: bf16[4096]) -> bf16[4,128,4096] { - %param_3.452 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %param_4.271 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %dot_general.375 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_4.271), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %dot_general.365 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_3.452, %dot_general.375), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %convert_element_type.973 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%dot_general.365), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_2.726 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %mul.1423 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_2.726), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1415 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.973, %mul.1423), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %param_0.872 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert_element_type.984 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.872), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_1.941 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %mul.1422 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.941), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1421 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.984, %mul.1422), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %add_any.138 = f32[4,128,4096]{2,1,0:T(8,128)} add(%mul.1415, %mul.1421), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add_any" stack_frame_id=0} - ROOT %convert_element_type.971 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%add_any.138), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} -} - -%region_5.8 (reduce_sum.87: f32[], reduce_sum.88: f32[]) -> f32[] { - %reduce_sum.87 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - %reduce_sum.88 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - ROOT %reduce_sum.92 = f32[]{:T(128)} add(%reduce_sum.87, %reduce_sum.88), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.312 (param_0.1131: bf16[4,128,4096]) -> f32[4,128] { - %param_0.1131 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert_element_type.975 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1131), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %square.192 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.975, %convert_element_type.975), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/square" stack_frame_id=0} - %constant.1019 = f32[]{:T(128)} constant(0) - ROOT %reduce.147 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.192, %constant.1019), dimensions={2}, to_apply=%region_5.8, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} -} - -%region_10.13 (reduce_sum.102: f32[], reduce_sum.106: f32[]) -> f32[] { - %reduce_sum.102 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - %reduce_sum.106 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - ROOT %reduce_sum.107 = f32[]{:T(128)} add(%reduce_sum.102, %reduce_sum.106), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.314 (param_0.1126: bf16[4,128,4096], param_1.1285: bf16[4,128,4096], param_2.1113: bf16[4096]) -> f32[4,128] { - %param_0.1126 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert_element_type.982 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1126), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_1.1285 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %param_2.1113 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.374 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1113), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %dot_general.364 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1285, %dot_general.374), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %convert_element_type.981 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%dot_general.364), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %mul.1419 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.982, %convert_element_type.981), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %constant.1013 = f32[]{:T(128)} constant(0) - ROOT %reduce.148 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1419, %constant.1013), dimensions={2}, to_apply=%region_10.13, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} -} - -%region_8.11 (dot_general.182: bf16[], dot_general.183: bf16[]) -> bf16[] { - %dot_general.182 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general"} - %dot_general.183 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general"} - ROOT %add.168 = bf16[]{:T(256)} add(%dot_general.182, %dot_general.183), metadata={op_name="add.54"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.235.clone.clone (param_0.1095: f32[4096,128256]) -> bf16[4096,128256,1] { - %param_0.1095 = f32[4096,128256]{1,0:T(8,128)} parameter(0) - %convert_element_type.1033 = bf16[4096,128256]{1,0:T(8,128)(2,1)} convert(%param_0.1095), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} - ROOT %bitcast.449 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} bitcast(%convert_element_type.1033), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} -} - -%fused_computation.280.clone.1.clone.clone (param_0.1096: bf16[4,128,128256], param_1.1261: s32[4,128], param_2.1081: f32[4,128], param_3.782: f32[4,128], param_4.484: bf16[4,128], param_5.409: f32[4,128]) -> bf16[4,128,128256] { + %constant.930.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.528.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.930.clone.1), dimensions={}, metadata={op_name="broadcast.52"} + %add.788.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.63.clone.1, %broadcast.528.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.188.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.761.clone.1, %add.788.clone.1), metadata={op_name="multiply.28"} + %div.758.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.790.clone.1, %multiply.188.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1846.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1118, %broadcast.533.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.787.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%div.758.clone.1, %mul.1846.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1845.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1847.clone.1, %add.787.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.786.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1118, %mul.1845.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.189 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.786.clone.1, %add.786.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.990 = f32[]{:T(128)} constant(0) + %reduce.105 = f32[]{:T(128)} reduce(%square.189, %constant.990), dimensions={0,1,2,3}, to_apply=%region_55.60, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.107.clone.1 = f32[]{:T(128)} reduce(%integer_pow.65.clone.1, %constant.990), dimensions={0,1,2,3}, to_apply=%region_41.46, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.143 = (f32[]{:T(128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.105, %add.786.clone.1, %add.789.clone.1, %add.790.clone.1, %reduce.107.clone.1) +} + +%fused_computation.311 (param_0.877: bf16[4,128,4096], param_1.943: f32[4,128], param_2.720: f32[4,128], param_3.440: bf16[4,128,4096], param_4.265: bf16[4096]) -> bf16[4,128,4096] { + %param_3.440 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.265 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %mul.1754 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_4.265), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1728 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_3.440, %mul.1754), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.989 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%mul.1728), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.720 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %mul.1725 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_2.720), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1716 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.989, %mul.1725), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.877 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1000 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.877), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_1.943 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.1723 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.943), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1722 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1000, %mul.1723), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %add_any.138 = f32[4,128,4096]{2,1,0:T(8,128)} add(%mul.1716, %mul.1722), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add_any" stack_frame_id=0} + ROOT %convert_element_type.987 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%add_any.138), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} +} + +%region_4.7 (reduce_sum.114: f32[], reduce_sum.115: f32[]) -> f32[] { + %reduce_sum.114 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.115 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.116 = f32[]{:T(128)} add(%reduce_sum.114, %reduce_sum.115), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.312 (param_0.1138: bf16[4,128,4096]) -> f32[4,128] { + %param_0.1138 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.991 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1138), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %square.192 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.991, %convert_element_type.991), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/square" stack_frame_id=0} + %constant.1011 = f32[]{:T(128)} constant(0) + ROOT %reduce.108 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.192, %constant.1011), dimensions={2}, to_apply=%region_4.7, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} +} + +%region_10.13 (reduce_sum.141: f32[], reduce_sum.142: f32[]) -> f32[] { + %reduce_sum.141 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.142 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.143 = f32[]{:T(128)} add(%reduce_sum.141, %reduce_sum.142), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.314 (param_0.1133: bf16[4,128,4096], param_1.1289: bf16[4,128,4096], param_2.1107: bf16[4096]) -> f32[4,128] { + %param_0.1133 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.998 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1133), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_1.1289 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %param_2.1107 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.1753 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1107), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1727 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1289, %mul.1753), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.997 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%mul.1727), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %mul.1720 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.998, %convert_element_type.997), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.1005 = f32[]{:T(128)} constant(0) + ROOT %reduce.109 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1720, %constant.1005), dimensions={2}, to_apply=%region_10.13, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} +} + +%region_8.11 (reduce_sum.129: bf16[], reduce_sum.130: bf16[]) -> bf16[] { + %reduce_sum.129 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.130 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.134 = bf16[]{:T(256)} add(%reduce_sum.129, %reduce_sum.130), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.235.clone.clone (param_0.1102: f32[4096,128256]) -> bf16[4096,128256,1] { + %param_0.1102 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %convert_element_type.1049 = bf16[4096,128256]{1,0:T(8,128)(2,1)} convert(%param_0.1102), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + ROOT %bitcast.449 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} bitcast(%convert_element_type.1049), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} +} + +%fused_computation.280.clone.1.clone.clone (param_0.1103: bf16[4,128,128256], param_1.1265: s32[4,128], param_2.1075: f32[4,128], param_3.770: f32[4,128], param_4.478: bf16[4,128], param_5.409: f32[4,128]) -> bf16[4,128,128256] { %param_5.409 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %mul.1603 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_5.409), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_3.782 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %mul.1602 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.782), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_0.1096 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.1036 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1096), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_4.484 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) - %sub.88 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_4.484), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.87 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.1036, %sub.88), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %mul.1925 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_5.409), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.770 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.1924 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.770), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.1103 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1052 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1103), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_4.478 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.88 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_4.478), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.87 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.1052, %sub.88), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %exp.60 = f32[4,128,128256]{2,1,0:T(8,128)} exponential(%sub.87), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %mul.1601 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1602, %exp.60), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_2.1081 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %div.819 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1081), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %div.818 = f32[4,128,128256]{2,1,0:T(8,128)} divide(%mul.1601, %div.819), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %param_1.1261 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %eq.43 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1261), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %mul.1923 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1924, %exp.60), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.1075 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.819 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1075), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.818 = f32[4,128,128256]{2,1,0:T(8,128)} divide(%mul.1923, %div.819), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.1265 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.43 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1265), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.42 = s32[4,128,128256]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.41 = pred[4,128,128256]{2,1,0:T(8,128)(4,1)} compare(%eq.43, %eq.42), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %convert_element_type.1035 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%eq.41), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} - %sub.86 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%div.818, %convert_element_type.1035), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} - %mul.1600 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1603, %sub.86), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - ROOT %convert_element_type.1034 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convert(%mul.1600), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} -} - -%fused_computation.315 (param_0.1094: f32[4,128], param_1.1260: bf16[4,128,4096], param_2.1082: f32[4096,128256], param_3.783: bf16[4,128,128256], param_4.485: s32[4,128], param_5.410: f32[4,128], param_6.284: f32[4,128], param_7.183: bf16[4,128], param_8.102: f32[4,128]) -> (bf16[4096], bf16[4,128,4096]) { - %param_3.783 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(3) - %param_4.485 = s32[4,128]{1,0:T(4,128)S(1)} parameter(4) + %convert_element_type.1051 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%eq.41), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.86 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%div.818, %convert_element_type.1051), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.1922 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1925, %sub.86), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.1050 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convert(%mul.1922), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} +} + +%fused_computation.315 (param_0.1101: f32[4,128], param_1.1264: bf16[4,128,4096], param_2.1076: f32[4096,128256], param_3.771: bf16[4,128,128256], param_4.479: s32[4,128], param_5.410: f32[4,128], param_6.285: f32[4,128], param_7.183: bf16[4,128], param_8.102: f32[4,128]) -> (bf16[4096], bf16[4,128,4096]) { + %param_1.1264 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1010 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1264), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1101 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1742 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1101), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1741 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1010, %mul.1742), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1009 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1741), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_3.771 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(3) + %param_4.479 = s32[4,128]{1,0:T(4,128)S(1)} parameter(4) %param_5.410 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %param_6.284 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_6.285 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) %param_7.183 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(7) %param_8.102 = f32[4,128]{1,0:T(4,128)S(1)} parameter(8) - %multiply_convert_fusion.2.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} fusion(%param_3.783, %param_4.485, %param_5.410, %param_6.284, %param_7.183, /*index=5*/%param_8.102), kind=kLoop, calls=%fused_computation.280.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_2.1082 = f32[4096,128256]{1,0:T(8,128)} parameter(2) - %fusion.219.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} fusion(%param_2.1082), kind=kLoop, calls=%fused_computation.235.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %multiply_convert_fusion.2.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} fusion(%param_3.771, %param_4.479, %param_5.410, %param_6.285, %param_7.183, /*index=5*/%param_8.102), kind=kLoop, calls=%fused_computation.280.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_2.1076 = f32[4096,128256]{1,0:T(8,128)} parameter(2) + %fusion.219.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} fusion(%param_2.1076), kind=kLoop, calls=%fused_computation.235.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} %convolution.86.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} convolution(%multiply_convert_fusion.2.clone.1, %fusion.219.clone.1), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %param_1.1260 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.994 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1260), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_0.1094 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1434 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1094), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1433 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.994, %mul.1434), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.993 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1433), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %multiply.252 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convolution.86.clone.1, %convert_element_type.993), metadata={op_name="multiply.206"} - %constant.874 = bf16[]{:T(256)} constant(0) - %reduce.149 = bf16[4096]{0:T(1024)(128)(2,1)} reduce(%multiply.252, %constant.874), dimensions={0,1}, to_apply=%region_8.11, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %tuple.153 = (bf16[4096]{0:T(1024)(128)(2,1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.149, %convolution.86.clone.1) -} - -%fused_computation.323 (param_0.904: f32[64], param_1.974: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { - %param_1.974 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %div.621 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.974), dimensions={0,1}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} - %param_0.904 = f32[64]{0:T(128)S(1)} parameter(0) - %div.619 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.904), dimensions={3}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + %mul.1724 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1009, %convolution.86.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.878 = bf16[]{:T(256)} constant(0) + %reduce.110 = bf16[4096]{0:T(1024)(128)(2,1)} reduce(%mul.1724, %constant.878), dimensions={0,1}, to_apply=%region_8.11, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} + ROOT %tuple.153 = (bf16[4096]{0:T(1024)(128)(2,1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.110, %convolution.86.clone.1) +} + +%fused_computation.323 (param_0.911: f32[64], param_1.978: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.978 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %div.621 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.978), dimensions={0,1}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + %param_0.911 = f32[64]{0:T(128)S(1)} parameter(0) + %div.619 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.911), dimensions={3}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} %div.618 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.621, %div.619), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} %sin.38 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.618), metadata={op_name="jit(train_step)/layers/sin" stack_frame_id=0} - %convert_element_type.1002 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.38), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %convert_element_type.1018 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.38), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} %cos.41.clone.1 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.618), metadata={op_name="jit(train_step)/layers/cos" stack_frame_id=0} - %convert_element_type.1001.clone.1 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.41.clone.1), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} - ROOT %tuple.150 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1002, %convert_element_type.1001.clone.1) + %convert_element_type.1017.clone.1 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.41.clone.1), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.150 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1018, %convert_element_type.1017.clone.1) } -%fused_computation.324 (param_0.901: bf16[4,128,1,64]) -> bf16[4,128,1,128] { - %param_0.901 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.866 = bf16[]{:T(256)} constant(-inf) - %pad.38 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.901, %constant.866), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} - %pad.37 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.901, %constant.866), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} +%fused_computation.324 (param_0.908: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.908 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.858 = bf16[]{:T(256)} constant(-inf) + %pad.38 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.908, %constant.858), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + %pad.37 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.908, %constant.858), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} ROOT %maximum.34 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.38, %pad.37), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} } -%fused_computation.325 (param_0.903: bf16[4,128,1,64]) -> bf16[4,128,1,128] { - %param_0.903 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.865 = bf16[]{:T(256)} constant(-inf) - %pad.40 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.903, %constant.865), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} - %pad.39 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.903, %constant.865), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} +%fused_computation.325 (param_0.910: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.910 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.857 = bf16[]{:T(256)} constant(-inf) + %pad.40 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.910, %constant.857), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + %pad.39 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.910, %constant.857), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} ROOT %maximum.35 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.40, %pad.39), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} } -%region_27.32 (reduce_sum.163: f32[], reduce_sum.164: f32[]) -> f32[] { - %reduce_sum.163 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.164 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.168 = f32[]{:T(128)} add(%reduce_sum.163, %reduce_sum.164), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_27.32 (reduce_sum.217: f32[], reduce_sum.221: f32[]) -> f32[] { + %reduce_sum.217 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.221 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.222 = f32[]{:T(128)} add(%reduce_sum.217, %reduce_sum.221), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_26.31 (reduce_sum.157: f32[], reduce_sum.161: f32[]) -> f32[] { - %reduce_sum.157 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.161 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.162 = f32[]{:T(128)} add(%reduce_sum.157, %reduce_sum.161), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_26.31 (reduce_sum.214: f32[], reduce_sum.215: f32[]) -> f32[] { + %reduce_sum.214 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.215 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.216 = f32[]{:T(128)} add(%reduce_sum.214, %reduce_sum.215), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.329 (param_0.1123: f32[4,4096], param_1.1283: f32[4,4096]) -> (f32[], f32[]) { - %param_0.1123 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(0) - %bitcast.371 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_0.1123), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.329 (param_0.1130: f32[4,4096], param_1.1287: f32[4,4096]) -> (f32[], f32[]) { + %param_0.1130 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(0) + %bitcast.371 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_0.1130), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.195 = f32[4096,4]{0,1:T(4,128)} multiply(%bitcast.371, %bitcast.371), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1010 = f32[]{:T(128)} constant(0) - %reduce.150 = f32[]{:T(128)} reduce(%square.195, %constant.1010), dimensions={0,1}, to_apply=%region_27.32, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.1283 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(1) - %bitcast.375.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_1.1283), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %constant.1002 = f32[]{:T(128)} constant(0) + %reduce.111 = f32[]{:T(128)} reduce(%square.195, %constant.1002), dimensions={0,1}, to_apply=%region_27.32, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1287 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(1) + %bitcast.375.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_1.1287), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.198.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%bitcast.375.clone.1, %bitcast.375.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.151.clone.1 = f32[]{:T(128)} reduce(%square.198.clone.1, %constant.1010), dimensions={0,1}, to_apply=%region_26.31, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.157 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.150, %reduce.151.clone.1) + %reduce.112.clone.1 = f32[]{:T(128)} reduce(%square.198.clone.1, %constant.1002), dimensions={0,1}, to_apply=%region_26.31, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.157 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.111, %reduce.112.clone.1) } -%region_54.59 (reduce_sum.301: f32[], reduce_sum.302: f32[]) -> f32[] { - %reduce_sum.301 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.302 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.303 = f32[]{:T(128)} add(%reduce_sum.301, %reduce_sum.302), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_54.59 (reduce_sum.355: f32[], reduce_sum.356: f32[]) -> f32[] { + %reduce_sum.355 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.356 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.357 = f32[]{:T(128)} add(%reduce_sum.355, %reduce_sum.356), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_40.45 (reduce_sum.226: f32[], reduce_sum.227: f32[]) -> f32[] { - %reduce_sum.226 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.227 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.231 = f32[]{:T(128)} add(%reduce_sum.226, %reduce_sum.227), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_40.45 (reduce_sum.280: f32[], reduce_sum.284: f32[]) -> f32[] { + %reduce_sum.280 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.284 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.285 = f32[]{:T(128)} add(%reduce_sum.280, %reduce_sum.284), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.332 (param_0.1112: f32[4096,4], param_1.1275: f32[], param_2.1106: f32[], param_3.795: f32[], param_4.496: f32[4096,4], param_5.421: f32[], param_6.293: f32[4,4096], param_7.192: pred[], param_8.110: f32[4096,4]) -> (f32[], f32[4096,4], f32[4096,4], f32[4096,4], f32[]) { - %param_0.1112 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) - %param_3.795 = f32[]{:T(128)S(6)} parameter(3) - %mul.1536.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_3.795), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.332 (param_0.1119: f32[4096,4], param_1.1279: f32[], param_2.1100: f32[], param_3.783: f32[], param_4.490: f32[4096,4], param_5.421: f32[], param_6.294: f32[4,4096], param_7.192: pred[], param_8.110: f32[4096,4]) -> (f32[], f32[4096,4], f32[4096,4], f32[4096,4], f32[]) { + %param_0.1119 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.783 = f32[]{:T(128)S(6)} parameter(3) + %mul.1854.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_3.783), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.192 = pred[]{:T(512)S(6)} parameter(7) %select_n.266.clone.1 = pred[4096,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.192), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.293 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(6) - %bitcast.419.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_6.293), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.294 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(6) + %bitcast.419.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_6.294), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.421 = f32[]{:T(128)} parameter(5) %div.773.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_5.421), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.772.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%bitcast.419.clone.1, %div.773.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.265.clone.1 = f32[4096,4]{0,1:T(4,128)} select(%select_n.266.clone.1, %bitcast.419.clone.1, %div.772.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.943.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.578.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.943.clone.1), dimensions={}, metadata={op_name="broadcast.68"} - %mul.1540.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.265.clone.1, %broadcast.578.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.935.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.539.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.935.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.1858.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.265.clone.1, %broadcast.539.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.110 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(8) - %constant.947.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.577.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.947.clone.1), dimensions={}, metadata={op_name="broadcast.67"} - %mul.1539.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_8.110, %broadcast.577.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.809.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1540.clone.1, %mul.1539.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1106 = f32[]{:T(128)S(6)} parameter(2) - %div.769.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_2.1106), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.939.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.538.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.939.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.1857.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_8.110, %broadcast.538.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.795.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1858.clone.1, %mul.1857.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1100 = f32[]{:T(128)S(6)} parameter(2) + %div.769.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_2.1100), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.66.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.265.clone.1, %select_n.265.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.946.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.576.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.946.clone.1), dimensions={}, metadata={op_name="broadcast.58"} - %mul.1538.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%integer_pow.66.clone.1, %broadcast.576.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.496 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(4) - %constant.945.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.575.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.945.clone.1), dimensions={}, metadata={op_name="broadcast.57"} - %mul.1537.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_4.496, %broadcast.575.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.808.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1538.clone.1, %mul.1537.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1275 = f32[]{:T(128)S(6)} parameter(1) - %div.768.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_1.1275), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.767.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.808.clone.1, %div.768.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.938.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.537.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.938.clone.1), dimensions={}, metadata={op_name="broadcast.58"} + %mul.1856.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%integer_pow.66.clone.1, %broadcast.537.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.490 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.937.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.536.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.937.clone.1), dimensions={}, metadata={op_name="broadcast.57"} + %mul.1855.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_4.490, %broadcast.536.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.794.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1856.clone.1, %mul.1855.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1279 = f32[]{:T(128)S(6)} parameter(1) + %div.768.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_1.1279), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.767.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.794.clone.1, %div.768.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.64.clone.1 = f32[4096,4]{0,1:T(4,128)} sqrt(%div.767.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.944.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.573.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.944.clone.1), dimensions={}, metadata={op_name="broadcast.53"} - %add.807.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%sqrt.64.clone.1, %broadcast.573.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.262.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%div.769.clone.1, %add.807.clone.1), metadata={op_name="multiply.36"} - %div.766.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.809.clone.1, %multiply.262.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1535.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_0.1112, %broadcast.578.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.806.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%div.766.clone.1, %mul.1535.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1534.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%mul.1536.clone.1, %add.806.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.805.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%param_0.1112, %mul.1534.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.199 = f32[4096,4]{0,1:T(4,128)} multiply(%add.805.clone.1, %add.805.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.999 = f32[]{:T(128)} constant(0) - %reduce.152 = f32[]{:T(128)} reduce(%square.199, %constant.999), dimensions={0,1}, to_apply=%region_54.59, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.154.clone.1 = f32[]{:T(128)} reduce(%integer_pow.66.clone.1, %constant.999), dimensions={0,1}, to_apply=%region_40.45, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.144 = (f32[]{:T(128)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.152, %add.805.clone.1, %add.808.clone.1, %add.809.clone.1, %reduce.154.clone.1) -} - -%region_53.58 (reduce_sum.295: f32[], reduce_sum.296: f32[]) -> f32[] { - %reduce_sum.295 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.296 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.297 = f32[]{:T(128)} add(%reduce_sum.295, %reduce_sum.296), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_39.44 (reduce_sum.220: f32[], reduce_sum.224: f32[]) -> f32[] { - %reduce_sum.220 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.224 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.225 = f32[]{:T(128)} add(%reduce_sum.220, %reduce_sum.224), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.333 (param_0.1113: f32[4096,4], param_1.1276: f32[], param_2.1107: f32[], param_3.796: f32[], param_4.497: f32[4096,4], param_5.422: f32[], param_6.294: f32[4,4096], param_7.193: pred[], param_8.111: f32[4096,4]) -> (f32[], f32[4096,4], f32[4096,4], f32[4096,4], f32[]) { - %param_0.1113 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) - %param_3.796 = f32[]{:T(128)S(6)} parameter(3) - %mul.1543.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_3.796), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.936.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.534.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.936.clone.1), dimensions={}, metadata={op_name="broadcast.53"} + %add.793.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%sqrt.64.clone.1, %broadcast.534.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.189.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%div.769.clone.1, %add.793.clone.1), metadata={op_name="multiply.27"} + %div.766.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.795.clone.1, %multiply.189.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1853.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_0.1119, %broadcast.539.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.792.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%div.766.clone.1, %mul.1853.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1852.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%mul.1854.clone.1, %add.792.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.791.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%param_0.1119, %mul.1852.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.199 = f32[4096,4]{0,1:T(4,128)} multiply(%add.791.clone.1, %add.791.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.991 = f32[]{:T(128)} constant(0) + %reduce.113 = f32[]{:T(128)} reduce(%square.199, %constant.991), dimensions={0,1}, to_apply=%region_54.59, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.115.clone.1 = f32[]{:T(128)} reduce(%integer_pow.66.clone.1, %constant.991), dimensions={0,1}, to_apply=%region_40.45, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.144 = (f32[]{:T(128)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.113, %add.791.clone.1, %add.794.clone.1, %add.795.clone.1, %reduce.115.clone.1) +} + +%region_53.58 (reduce_sum.349: f32[], reduce_sum.350: f32[]) -> f32[] { + %reduce_sum.349 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.350 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.354 = f32[]{:T(128)} add(%reduce_sum.349, %reduce_sum.350), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_39.44 (reduce_sum.277: f32[], reduce_sum.278: f32[]) -> f32[] { + %reduce_sum.277 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.278 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.279 = f32[]{:T(128)} add(%reduce_sum.277, %reduce_sum.278), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.333 (param_0.1120: f32[4096,4], param_1.1280: f32[], param_2.1101: f32[], param_3.784: f32[], param_4.491: f32[4096,4], param_5.422: f32[], param_6.295: f32[4,4096], param_7.193: pred[], param_8.111: f32[4096,4]) -> (f32[], f32[4096,4], f32[4096,4], f32[4096,4], f32[]) { + %param_0.1120 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.784 = f32[]{:T(128)S(6)} parameter(3) + %mul.1861.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_3.784), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.193 = pred[]{:T(512)S(6)} parameter(7) %select_n.270.clone.1 = pred[4096,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.193), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.294 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(6) - %bitcast.421.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_6.294), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.295 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(6) + %bitcast.421.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_6.295), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.422 = f32[]{:T(128)} parameter(5) %div.781.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_5.422), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.780.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%bitcast.421.clone.1, %div.781.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.269.clone.1 = f32[4096,4]{0,1:T(4,128)} select(%select_n.270.clone.1, %bitcast.421.clone.1, %div.780.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.949.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.584.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.949.clone.1), dimensions={}, metadata={op_name="broadcast.68"} - %mul.1547.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.269.clone.1, %broadcast.584.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.941.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.545.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.941.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.1865.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.269.clone.1, %broadcast.545.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.111 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(8) - %constant.953.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.583.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.953.clone.1), dimensions={}, metadata={op_name="broadcast.67"} - %mul.1546.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_8.111, %broadcast.583.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.814.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1547.clone.1, %mul.1546.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1107 = f32[]{:T(128)S(6)} parameter(2) - %div.777.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_2.1107), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.945.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.544.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.945.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.1864.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_8.111, %broadcast.544.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.800.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1865.clone.1, %mul.1864.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1101 = f32[]{:T(128)S(6)} parameter(2) + %div.777.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_2.1101), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.67.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.269.clone.1, %select_n.269.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.952.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.582.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.952.clone.1), dimensions={}, metadata={op_name="broadcast.58"} - %mul.1545.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%integer_pow.67.clone.1, %broadcast.582.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.497 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(4) - %constant.951.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.581.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.951.clone.1), dimensions={}, metadata={op_name="broadcast.57"} - %mul.1544.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_4.497, %broadcast.581.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.813.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1545.clone.1, %mul.1544.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1276 = f32[]{:T(128)S(6)} parameter(1) - %div.776.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_1.1276), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.775.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.813.clone.1, %div.776.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.944.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.543.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.944.clone.1), dimensions={}, metadata={op_name="broadcast.58"} + %mul.1863.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%integer_pow.67.clone.1, %broadcast.543.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.491 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.943.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.542.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.943.clone.1), dimensions={}, metadata={op_name="broadcast.57"} + %mul.1862.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_4.491, %broadcast.542.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.799.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1863.clone.1, %mul.1862.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1280 = f32[]{:T(128)S(6)} parameter(1) + %div.776.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_1.1280), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.775.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.799.clone.1, %div.776.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.65.clone.1 = f32[4096,4]{0,1:T(4,128)} sqrt(%div.775.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.950.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.579.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.950.clone.1), dimensions={}, metadata={op_name="broadcast.53"} - %add.812.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%sqrt.65.clone.1, %broadcast.579.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.263.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%div.777.clone.1, %add.812.clone.1), metadata={op_name="multiply.35"} - %div.774.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.814.clone.1, %multiply.263.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1542.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_0.1113, %broadcast.584.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.811.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%div.774.clone.1, %mul.1542.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1541.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%mul.1543.clone.1, %add.811.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.810.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%param_0.1113, %mul.1541.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.200 = f32[4096,4]{0,1:T(4,128)} multiply(%add.810.clone.1, %add.810.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1000 = f32[]{:T(128)} constant(0) - %reduce.153 = f32[]{:T(128)} reduce(%square.200, %constant.1000), dimensions={0,1}, to_apply=%region_53.58, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.155.clone.1 = f32[]{:T(128)} reduce(%integer_pow.67.clone.1, %constant.1000), dimensions={0,1}, to_apply=%region_39.44, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.145 = (f32[]{:T(128)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.153, %add.810.clone.1, %add.813.clone.1, %add.814.clone.1, %reduce.155.clone.1) -} - -%region_9.12 (reduce_sum.99: f32[], reduce_sum.100: f32[]) -> f32[] { - %reduce_sum.100 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.99 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.101 = f32[]{:T(128)} add(%reduce_sum.99, %reduce_sum.100), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.344 (param_0.1127: bf16[4096]) -> f32[] { - %param_0.1127 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) - %convert_element_type.1006 = f32[4096]{0:T(1024)} convert(%param_0.1127), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %square.203 = f32[4096]{0:T(1024)} multiply(%convert_element_type.1006, %convert_element_type.1006), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1014 = f32[]{:T(128)} constant(0) - ROOT %reduce.156 = f32[]{:T(128)} reduce(%square.203, %constant.1014), dimensions={0}, to_apply=%region_9.12, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %constant.942.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.540.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.942.clone.1), dimensions={}, metadata={op_name="broadcast.53"} + %add.798.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%sqrt.65.clone.1, %broadcast.540.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.190.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%div.777.clone.1, %add.798.clone.1), metadata={op_name="multiply.26"} + %div.774.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.800.clone.1, %multiply.190.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1860.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_0.1120, %broadcast.545.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.797.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%div.774.clone.1, %mul.1860.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1859.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%mul.1861.clone.1, %add.797.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.796.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%param_0.1120, %mul.1859.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.200 = f32[4096,4]{0,1:T(4,128)} multiply(%add.796.clone.1, %add.796.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.992 = f32[]{:T(128)} constant(0) + %reduce.114 = f32[]{:T(128)} reduce(%square.200, %constant.992), dimensions={0,1}, to_apply=%region_53.58, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.116.clone.1 = f32[]{:T(128)} reduce(%integer_pow.67.clone.1, %constant.992), dimensions={0,1}, to_apply=%region_39.44, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.145 = (f32[]{:T(128)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.114, %add.796.clone.1, %add.799.clone.1, %add.800.clone.1, %reduce.116.clone.1) +} + +%region_9.12 (reduce_sum.135: f32[], reduce_sum.136: f32[]) -> f32[] { + %reduce_sum.135 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.136 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.137 = f32[]{:T(128)} add(%reduce_sum.135, %reduce_sum.136), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.344 (param_0.1134: bf16[4096]) -> f32[] { + %param_0.1134 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %convert_element_type.1022 = f32[4096]{0:T(1024)} convert(%param_0.1134), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %square.203 = f32[4096]{0:T(1024)} multiply(%convert_element_type.1022, %convert_element_type.1022), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1006 = f32[]{:T(128)} constant(0) + ROOT %reduce.117 = f32[]{:T(128)} reduce(%square.203, %constant.1006), dimensions={0}, to_apply=%region_9.12, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%region_49.54 (reduce_sum.274: f32[], reduce_sum.275: f32[]) -> f32[] { - %reduce_sum.274 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.275 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.276 = f32[]{:T(128)} add(%reduce_sum.274, %reduce_sum.275), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_49.54 (reduce_sum.328: f32[], reduce_sum.329: f32[]) -> f32[] { + %reduce_sum.328 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.329 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.333 = f32[]{:T(128)} add(%reduce_sum.328, %reduce_sum.329), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_35.40 (reduce_sum.199: f32[], reduce_sum.203: f32[]) -> f32[] { - %reduce_sum.199 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.203 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.204 = f32[]{:T(128)} add(%reduce_sum.199, %reduce_sum.203), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_35.40 (reduce_sum.256: f32[], reduce_sum.257: f32[]) -> f32[] { + %reduce_sum.256 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.257 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.258 = f32[]{:T(128)} add(%reduce_sum.256, %reduce_sum.257), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.345 (param_0.1117: f32[4096], param_1.1280: f32[], param_2.1111: f32[], param_3.800: f32[], param_4.501: f32[4096], param_5.426: f32[], param_6.298: bf16[4096], param_7.197: pred[], param_8.115: f32[4096]) -> (f32[], f32[4096], f32[4096], f32[4096], f32[]) { - %param_0.1117 = f32[4096]{0:T(1024)S(1)} parameter(0) - %param_3.800 = f32[]{:T(128)S(6)} parameter(3) - %mul.1574.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_3.800), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.345 (param_0.1124: f32[4096], param_1.1284: f32[], param_2.1105: f32[], param_3.788: f32[], param_4.495: f32[4096], param_5.426: f32[], param_6.299: bf16[4096], param_7.197: pred[], param_8.115: f32[4096]) -> (f32[], f32[4096], f32[4096], f32[4096], f32[]) { + %param_0.1124 = f32[4096]{0:T(1024)S(1)} parameter(0) + %param_3.788 = f32[]{:T(128)S(6)} parameter(3) + %mul.1892.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_3.788), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.197 = pred[]{:T(512)S(6)} parameter(7) %select_n.286.clone.1 = pred[4096]{0:T(1024)(128)(4,1)} broadcast(%param_7.197), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.298 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(6) - %convert_element_type.1021.clone.1 = f32[4096]{0:T(1024)} convert(%param_6.298), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_6.299 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(6) + %convert_element_type.1037.clone.1 = f32[4096]{0:T(1024)} convert(%param_6.299), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} %param_5.426 = f32[]{:T(128)} parameter(5) %div.813.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_5.426), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.812.clone.1 = f32[4096]{0:T(1024)} divide(%convert_element_type.1021.clone.1, %div.813.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.285.clone.1 = f32[4096]{0:T(1024)} select(%select_n.286.clone.1, %convert_element_type.1021.clone.1, %div.812.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.973.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.600.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.973.clone.1), dimensions={}, metadata={op_name="broadcast.72"} - %mul.1580.clone.1 = f32[4096]{0:T(1024)} multiply(%select_n.285.clone.1, %broadcast.600.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.812.clone.1 = f32[4096]{0:T(1024)} divide(%convert_element_type.1037.clone.1, %div.813.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.285.clone.1 = f32[4096]{0:T(1024)} select(%select_n.286.clone.1, %convert_element_type.1037.clone.1, %div.812.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.965.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.561.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.965.clone.1), dimensions={}, metadata={op_name="broadcast.72"} + %mul.1898.clone.1 = f32[4096]{0:T(1024)} multiply(%select_n.285.clone.1, %broadcast.561.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.115 = f32[4096]{0:T(1024)S(1)} parameter(8) - %constant.977.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1581.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.977.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1579.clone.1 = f32[4096]{0:T(1024)} multiply(%param_8.115, %mul.1581.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.836.clone.1 = f32[4096]{0:T(1024)S(1)} add(%mul.1580.clone.1, %mul.1579.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1111 = f32[]{:T(128)S(6)} parameter(2) - %div.809.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_2.1111), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.969.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1899.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.969.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1897.clone.1 = f32[4096]{0:T(1024)} multiply(%param_8.115, %mul.1899.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.822.clone.1 = f32[4096]{0:T(1024)S(1)} add(%mul.1898.clone.1, %mul.1897.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1105 = f32[]{:T(128)S(6)} parameter(2) + %div.809.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_2.1105), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.71.clone.1 = f32[4096]{0:T(1024)} multiply(%select_n.285.clone.1, %select_n.285.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.976.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1578.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.976.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1576.clone.1 = f32[4096]{0:T(1024)} multiply(%integer_pow.71.clone.1, %mul.1578.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.501 = f32[4096]{0:T(1024)S(1)} parameter(4) - %constant.975.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1577.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.975.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1575.clone.1 = f32[4096]{0:T(1024)} multiply(%param_4.501, %mul.1577.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.835.clone.1 = f32[4096]{0:T(1024)S(1)} add(%mul.1576.clone.1, %mul.1575.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1280 = f32[]{:T(128)S(6)} parameter(1) - %div.808.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_1.1280), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.807.clone.1 = f32[4096]{0:T(1024)} divide(%add.835.clone.1, %div.808.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.968.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1896.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.968.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1894.clone.1 = f32[4096]{0:T(1024)} multiply(%integer_pow.71.clone.1, %mul.1896.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.495 = f32[4096]{0:T(1024)S(1)} parameter(4) + %constant.967.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1895.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.967.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1893.clone.1 = f32[4096]{0:T(1024)} multiply(%param_4.495, %mul.1895.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.821.clone.1 = f32[4096]{0:T(1024)S(1)} add(%mul.1894.clone.1, %mul.1893.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1284 = f32[]{:T(128)S(6)} parameter(1) + %div.808.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_1.1284), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.807.clone.1 = f32[4096]{0:T(1024)} divide(%add.821.clone.1, %div.808.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.69.clone.1 = f32[4096]{0:T(1024)} sqrt(%div.807.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.974.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.834.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.974.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.833.clone.1 = f32[4096]{0:T(1024)} add(%sqrt.69.clone.1, %add.834.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.267.clone.1 = f32[4096]{0:T(1024)} multiply(%div.809.clone.1, %add.833.clone.1), metadata={op_name="multiply.31"} - %div.806.clone.1 = f32[4096]{0:T(1024)} divide(%add.836.clone.1, %multiply.267.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1573.clone.1 = f32[4096]{0:T(1024)} multiply(%param_0.1117, %broadcast.600.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.832.clone.1 = f32[4096]{0:T(1024)} add(%div.806.clone.1, %mul.1573.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1572.clone.1 = f32[4096]{0:T(1024)} multiply(%mul.1574.clone.1, %add.832.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.831.clone.1 = f32[4096]{0:T(1024)S(1)} add(%param_0.1117, %mul.1572.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.204 = f32[4096]{0:T(1024)} multiply(%add.831.clone.1, %add.831.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1004 = f32[]{:T(128)} constant(0) - %reduce.157 = f32[]{:T(128)} reduce(%square.204, %constant.1004), dimensions={0}, to_apply=%region_49.54, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.158.clone.1 = f32[]{:T(128)} reduce(%integer_pow.71.clone.1, %constant.1004), dimensions={0}, to_apply=%region_35.40, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.148 = (f32[]{:T(128)}, f32[4096]{0:T(1024)S(1)}, f32[4096]{0:T(1024)S(1)}, f32[4096]{0:T(1024)S(1)}, f32[]{:T(128)}) tuple(%reduce.157, %add.831.clone.1, %add.835.clone.1, %add.836.clone.1, %reduce.158.clone.1) -} - -%fused_computation.351 (param_0.964: s32[512]) -> s32[1024] { - %constant.801 = s32[] constant(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %broadcast.539 = s32[1024]{0:T(1024)} broadcast(%constant.801), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %param_0.964 = s32[512]{0:T(512)S(1)} parameter(0) - %constant.802 = s32[] constant(2147483647), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %pad.41 = s32[1024]{0:T(1024)} pad(%param_0.964, %constant.802), padding=0_512, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %constant.800 = s32[] constant(128255), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %broadcast.538 = s32[1024]{0:T(1024)} broadcast(%constant.800), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - ROOT %clamp.1 = s32[1024]{0:T(1024)} clamp(%broadcast.539, %pad.41, %broadcast.538), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} -} - -%fused_computation.352 (param_0.963: s32[4,128]) -> s32[512] { - %param_0.963 = s32[4,128]{1,0:T(4,128)} parameter(0) - %constant.888 = s32[]{:T(128)} constant(0) - %broadcast.546 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.888), dimensions={}, metadata={op_name="broadcast.81"} - %lt.32 = pred[4,128]{1,0:T(4,128)(4,1)} compare(%param_0.963, %broadcast.546), direction=LT, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/lt" stack_frame_id=0} - %constant.875 = s32[]{:T(128)} constant(128256) - %add.760 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.875), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} - %add.748 = s32[4,128]{1,0:T(4,128)} add(%param_0.963, %add.760), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} - %select_n.178 = s32[4,128]{1,0:T(4,128)} select(%lt.32, %add.748, %param_0.963), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/select_n" stack_frame_id=0} + %constant.966.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.820.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.966.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.819.clone.1 = f32[4096]{0:T(1024)} add(%sqrt.69.clone.1, %add.820.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.194.clone.1 = f32[4096]{0:T(1024)} multiply(%div.809.clone.1, %add.819.clone.1), metadata={op_name="multiply.22"} + %div.806.clone.1 = f32[4096]{0:T(1024)} divide(%add.822.clone.1, %multiply.194.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1891.clone.1 = f32[4096]{0:T(1024)} multiply(%param_0.1124, %broadcast.561.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.818.clone.1 = f32[4096]{0:T(1024)} add(%div.806.clone.1, %mul.1891.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1890.clone.1 = f32[4096]{0:T(1024)} multiply(%mul.1892.clone.1, %add.818.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.817.clone.1 = f32[4096]{0:T(1024)S(1)} add(%param_0.1124, %mul.1890.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.204 = f32[4096]{0:T(1024)} multiply(%add.817.clone.1, %add.817.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.996 = f32[]{:T(128)} constant(0) + %reduce.118 = f32[]{:T(128)} reduce(%square.204, %constant.996), dimensions={0}, to_apply=%region_49.54, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.119.clone.1 = f32[]{:T(128)} reduce(%integer_pow.71.clone.1, %constant.996), dimensions={0}, to_apply=%region_35.40, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.148 = (f32[]{:T(128)}, f32[4096]{0:T(1024)S(1)}, f32[4096]{0:T(1024)S(1)}, f32[4096]{0:T(1024)S(1)}, f32[]{:T(128)}) tuple(%reduce.118, %add.817.clone.1, %add.821.clone.1, %add.822.clone.1, %reduce.119.clone.1) +} + +%fused_computation.351 (param_0.971: s32[512]) -> s32[1024] { + %constant.793 = s32[] constant(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %broadcast.500 = s32[1024]{0:T(1024)} broadcast(%constant.793), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %param_0.971 = s32[512]{0:T(512)S(1)} parameter(0) + %constant.794 = s32[] constant(2147483647), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %pad.41 = s32[1024]{0:T(1024)} pad(%param_0.971, %constant.794), padding=0_512, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %constant.792 = s32[] constant(128255), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %broadcast.499 = s32[1024]{0:T(1024)} broadcast(%constant.792), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %clamp.1 = s32[1024]{0:T(1024)} clamp(%broadcast.500, %pad.41, %broadcast.499), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} +} + +%fused_computation.352 (param_0.970: s32[4,128]) -> s32[512] { + %param_0.970 = s32[4,128]{1,0:T(4,128)} parameter(0) + %constant.882 = s32[]{:T(128)} constant(0) + %broadcast.508 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.882), dimensions={}, metadata={op_name="broadcast.81"} + %lt.32 = pred[4,128]{1,0:T(4,128)(4,1)} compare(%param_0.970, %broadcast.508), direction=LT, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/lt" stack_frame_id=0} + %constant.879 = s32[]{:T(128)} constant(128256) + %add.746 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.879), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} + %add.734 = s32[4,128]{1,0:T(4,128)} add(%param_0.970, %add.746), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} + %select_n.178 = s32[4,128]{1,0:T(4,128)} select(%lt.32, %add.734, %param_0.970), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/select_n" stack_frame_id=0} ROOT %bitcast.376 = s32[512]{0:T(512)S(1)} bitcast(%select_n.178), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} } -%region_61.66 (reduce_sum.345: f32[], reduce_sum.346: f32[]) -> f32[] { - %reduce_sum.345 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.346 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.330 = f32[]{:T(128)} add(%reduce_sum.345, %reduce_sum.346), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_48.53 (reduce_sum.268: f32[], reduce_sum.269: f32[]) -> f32[] { - %reduce_sum.268 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.269 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.273 = f32[]{:T(128)} add(%reduce_sum.268, %reduce_sum.269), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.353 (param_0.1128: bf16[4,128], param_1.1287: f32[4,128], param_2.1114: f32[4,128], param_3.802: s32[4,128]) -> (f32[], f32[], pred[4,128], f32[4,128]) { - %param_3.802 = s32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %constant.979.clone.1 = s32[]{:T(128)} constant(0) - %broadcast.601.clone.1 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.979.clone.1), dimensions={}, metadata={op_name="broadcast.81"} - %ne.6.clone.1 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} compare(%param_3.802, %broadcast.601.clone.1), direction=NE, metadata={op_name="jit(train_step)/jvp()/ne" stack_frame_id=0} - %param_1.1287 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %log.16 = f32[4,128]{1,0:T(4,128)} log(%param_1.1287), metadata={op_name="jit(train_step)/jvp()/log" stack_frame_id=0} - %param_0.1128 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(0) - %reduce_max.15 = f32[4,128]{1,0:T(4,128)} convert(%param_0.1128), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} - %add.762 = f32[4,128]{1,0:T(4,128)} add(%log.16, %reduce_max.15), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} - %square.207 = f32[4,128]{1,0:T(4,128)} multiply(%add.762, %add.762), metadata={op_name="jit(train_step)/jvp()/square" stack_frame_id=0} - %constant.1016 = f32[]{:T(128)} constant(0) - %broadcast.543 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1016), dimensions={}, metadata={op_name="broadcast.32"} - %mul.1473 = f32[4,128]{1,0:T(4,128)} multiply(%square.207, %broadcast.543), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - %mul.1465 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %mul.1473, %broadcast.543), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - %reduce.159 = f32[]{:T(128)} reduce(%mul.1465, %constant.1016), dimensions={0,1}, to_apply=%region_61.66, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} - %param_2.1114 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %neg.115.clone.1 = f32[4,128]{1,0:T(4,128)} negate(%param_2.1114), metadata={op_name="jit(train_step)/jvp()/neg" stack_frame_id=0} - %add.749.clone.1 = f32[4,128]{1,0:T(4,128)} add(%neg.115.clone.1, %mul.1473), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} - %mul.1466.clone.1 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %add.749.clone.1, %broadcast.543), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - %reduce.160.clone.1 = f32[]{:T(128)} reduce(%mul.1466.clone.1, %constant.1016), dimensions={0,1}, to_apply=%region_48.53, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} - %mul.1471.clone.1 = f32[4,128]{1,0:T(4,128)} multiply(%add.762, %broadcast.543), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %constant.891.clone.1 = f32[]{:T(128)} constant(1) - %add.757.clone.1 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.891.clone.1), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} - %add.750.clone.1 = f32[4,128]{1,0:T(4,128)S(1)} add(%mul.1471.clone.1, %add.757.clone.1), metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} - ROOT %tuple.149 = (f32[]{:T(128)}, f32[]{:T(128)}, pred[4,128]{1,0:T(4,128)(4,1)S(1)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%reduce.159, %reduce.160.clone.1, %ne.6.clone.1, %add.750.clone.1) -} - -%fused_computation.356 (param_0.987: f32[4,128], param_1.1101: f32[4,128]) -> f32[4,128] { - %param_0.987 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %param_1.1101 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %constant.869 = f32[]{:T(128)} constant(0.000244140625) - %broadcast.549 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.869), dimensions={}, metadata={op_name="broadcast.264"} - %div.656 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1101, %broadcast.549), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} - %constant.867 = f32[]{:T(128)} constant(1e-05) - %add.770 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.867), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - %add.769 = f32[4,128]{1,0:T(4,128)} add(%div.656, %add.770), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - %rsqrt.90 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.769), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} - %div.649 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.90, %add.769), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} - %constant.864 = f32[]{:T(128)} constant(-0.5) - %mul.1477 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.864), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1470 = f32[4,128]{1,0:T(4,128)} multiply(%div.649, %mul.1477), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1469 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.987, %mul.1470), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %constant.863 = f32[]{:T(128)} constant(0.00048828125) - %mul.1476 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.863), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - ROOT %mul.1468 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.1469, %mul.1476), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} -} - -%region_0.1 (reduce_sum.67: s32[], reduce_sum.71: s32[]) -> s32[] { - %reduce_sum.67 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.71 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.72 = s32[]{:T(128)} add(%reduce_sum.67, %reduce_sum.71), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["0","2"]}]}} -} - -%fused_computation.360 (param_0.1004: pred[4,128]) -> s32[] { - %param_0.1004 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) - %convert_element_type.1013 = s32[4,128]{1,0:T(4,128)} convert(%param_0.1004), metadata={op_name="jit(train_step)/jvp()/convert_element_type" stack_frame_id=0} - %constant.889 = s32[]{:T(128)} constant(0) - ROOT %reduce.161 = s32[]{:T(128)} reduce(%convert_element_type.1013, %constant.889), dimensions={0,1}, to_apply=%region_0.1, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} -} - -%fused_computation.361 (param_0.989: f32[4,128]) -> f32[4,128] { - %param_0.989 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %constant.870 = f32[]{:T(128)} constant(0.000244140625) - %broadcast.541 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.870), dimensions={}, metadata={op_name="broadcast.264"} - %div.654 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.989, %broadcast.541), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} - %constant.868 = f32[]{:T(128)} constant(1e-05) - %add.759 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.868), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - %add.756 = f32[4,128]{1,0:T(4,128)} add(%div.654, %add.759), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - ROOT %rsqrt.88 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.756), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} -} - -%fused_computation.362 (param_0.990: pred[4,128], param_1.1286: f32[]) -> f32[4,128] { - %param_0.990 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) - %param_1.1286 = f32[]{:T(128)S(6)} parameter(1) - %broadcast_in_dim.272 = f32[4,128]{1,0:T(4,128)} broadcast(%param_1.1286), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/broadcast_in_dim" stack_frame_id=0} - %constant.1015 = f32[]{:T(128)} constant(0) - %broadcast.545 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1015), dimensions={}, metadata={op_name="broadcast.32"} - ROOT %mul.1478 = f32[4,128]{1,0:T(4,128)S(1)} select(%param_0.990, %broadcast_in_dim.272, %broadcast.545), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} +%region_61.66 (reduce_sum.391: f32[], reduce_sum.392: f32[]) -> f32[] { + %reduce_sum.391 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.392 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.396 = f32[]{:T(128)} add(%reduce_sum.391, %reduce_sum.392), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_48.53 (reduce_sum.322: f32[], reduce_sum.326: f32[]) -> f32[] { + %reduce_sum.322 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.326 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.327 = f32[]{:T(128)} add(%reduce_sum.322, %reduce_sum.326), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.353 (param_0.1135: bf16[4,128], param_1.1291: f32[4,128], param_2.1108: f32[4,128], param_3.790: s32[4,128]) -> (f32[], f32[], pred[4,128], f32[4,128]) { + %param_3.790 = s32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %constant.971.clone.1 = s32[]{:T(128)} constant(0) + %broadcast.562.clone.1 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.971.clone.1), dimensions={}, metadata={op_name="broadcast.81"} + %ne.6.clone.1 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} compare(%param_3.790, %broadcast.562.clone.1), direction=NE, metadata={op_name="jit(train_step)/jvp()/ne" stack_frame_id=0} + %param_1.1291 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %log.16 = f32[4,128]{1,0:T(4,128)} log(%param_1.1291), metadata={op_name="jit(train_step)/jvp()/log" stack_frame_id=0} + %param_0.1135 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(0) + %reduce_max.15 = f32[4,128]{1,0:T(4,128)} convert(%param_0.1135), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} + %add.748 = f32[4,128]{1,0:T(4,128)} add(%log.16, %reduce_max.15), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} + %square.207 = f32[4,128]{1,0:T(4,128)} multiply(%add.748, %add.748), metadata={op_name="jit(train_step)/jvp()/square" stack_frame_id=0} + %constant.1008 = f32[]{:T(128)} constant(0) + %broadcast.502 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1008), dimensions={}, metadata={op_name="broadcast.32"} + %mul.1791 = f32[4,128]{1,0:T(4,128)} multiply(%square.207, %broadcast.502), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %mul.1783 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %mul.1791, %broadcast.502), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %reduce.120 = f32[]{:T(128)} reduce(%mul.1783, %constant.1008), dimensions={0,1}, to_apply=%region_61.66, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %param_2.1108 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %neg.115.clone.1 = f32[4,128]{1,0:T(4,128)} negate(%param_2.1108), metadata={op_name="jit(train_step)/jvp()/neg" stack_frame_id=0} + %add.735.clone.1 = f32[4,128]{1,0:T(4,128)} add(%neg.115.clone.1, %mul.1791), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} + %mul.1784.clone.1 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %add.735.clone.1, %broadcast.502), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %reduce.121.clone.1 = f32[]{:T(128)} reduce(%mul.1784.clone.1, %constant.1008), dimensions={0,1}, to_apply=%region_48.53, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %mul.1789.clone.1 = f32[4,128]{1,0:T(4,128)} multiply(%add.748, %broadcast.502), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %constant.883.clone.1 = f32[]{:T(128)} constant(1) + %add.743.clone.1 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.883.clone.1), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} + %add.736.clone.1 = f32[4,128]{1,0:T(4,128)S(1)} add(%mul.1789.clone.1, %add.743.clone.1), metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} + ROOT %tuple.149 = (f32[]{:T(128)}, f32[]{:T(128)}, pred[4,128]{1,0:T(4,128)(4,1)S(1)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%reduce.120, %reduce.121.clone.1, %ne.6.clone.1, %add.736.clone.1) +} + +%fused_computation.356 (param_0.994: f32[4,128], param_1.1105: f32[4,128]) -> f32[4,128] { + %param_0.994 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %param_1.1105 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %constant.873 = f32[]{:T(128)} constant(0.000244140625) + %broadcast.510 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.873), dimensions={}, metadata={op_name="broadcast.245"} + %div.656 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1105, %broadcast.510), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.871 = f32[]{:T(128)} constant(1e-05) + %add.756 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.871), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %add.755 = f32[4,128]{1,0:T(4,128)} add(%div.656, %add.756), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %rsqrt.90 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.755), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} + %div.649 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.90, %add.755), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.856 = f32[]{:T(128)} constant(-0.5) + %mul.1795 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.856), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1788 = f32[4,128]{1,0:T(4,128)} multiply(%div.649, %mul.1795), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1787 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.994, %mul.1788), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.855 = f32[]{:T(128)} constant(0.00048828125) + %mul.1794 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.855), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.1786 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.1787, %mul.1794), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%region_5.8 (reduce_sum.120: s32[], reduce_sum.121: s32[]) -> s32[] { + %reduce_sum.120 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.121 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.122 = s32[]{:T(128)} add(%reduce_sum.120, %reduce_sum.121), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["0","2"]}]}} +} + +%fused_computation.359 (param_0.1011: pred[4,128]) -> s32[] { + %param_0.1011 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %convert_element_type.1029 = s32[4,128]{1,0:T(4,128)} convert(%param_0.1011), metadata={op_name="jit(train_step)/jvp()/convert_element_type" stack_frame_id=0} + %constant.881 = s32[]{:T(128)} constant(0) + ROOT %reduce.122 = s32[]{:T(128)} reduce(%convert_element_type.1029, %constant.881), dimensions={0,1}, to_apply=%region_5.8, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} +} + +%fused_computation.361 (param_0.997: f32[4,128]) -> f32[4,128] { + %param_0.997 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.874 = f32[]{:T(128)} constant(0.000244140625) + %broadcast.506 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.874), dimensions={}, metadata={op_name="broadcast.245"} + %div.654 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.997, %broadcast.506), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.872 = f32[]{:T(128)} constant(1e-05) + %add.745 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.872), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %add.742 = f32[4,128]{1,0:T(4,128)} add(%div.654, %add.745), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + ROOT %rsqrt.88 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.742), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} +} + +%fused_computation.362 (param_0.996: pred[4,128], param_1.1290: f32[]) -> f32[4,128] { + %param_0.996 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %param_1.1290 = f32[]{:T(128)S(6)} parameter(1) + %broadcast_in_dim.283 = f32[4,128]{1,0:T(4,128)} broadcast(%param_1.1290), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/broadcast_in_dim" stack_frame_id=0} + %constant.1007 = f32[]{:T(128)} constant(0) + %broadcast.504 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1007), dimensions={}, metadata={op_name="broadcast.32"} + ROOT %mul.1796 = f32[4,128]{1,0:T(4,128)S(1)} select(%param_0.996, %broadcast_in_dim.283, %broadcast.504), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} } %fused_computation.364 () -> f32[64] { - %constant.873 = f32[]{:T(128)} constant(500000) - %broadcast.552 = f32[64]{0:T(128)} broadcast(%constant.873), dimensions={}, metadata={op_name="broadcast.255"} + %constant.877 = f32[]{:T(128)} constant(500000) + %broadcast.513 = f32[64]{0:T(128)} broadcast(%constant.877), dimensions={}, metadata={op_name="broadcast.236"} %iota.46 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/layers/iota" stack_frame_id=0} - %constant.872 = s32[]{:T(128)} constant(2) - %broadcast.551 = s32[64]{0:T(128)} broadcast(%constant.872), dimensions={}, metadata={op_name="broadcast.256"} - %mul.1479 = s32[64]{0:T(128)} multiply(%iota.46, %broadcast.551), metadata={op_name="jit(train_step)/layers/mul" stack_frame_id=0} - %convert_element_type.1014 = f32[64]{0:T(128)} convert(%mul.1479), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} - %constant.871 = f32[]{:T(128)} constant(0.0078125) - %broadcast.550 = f32[64]{0:T(128)} broadcast(%constant.871), dimensions={}, metadata={op_name="broadcast.257"} - %div.657 = f32[64]{0:T(128)} multiply(%convert_element_type.1014, %broadcast.550), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} - ROOT %pow.36 = f32[64]{0:T(128)S(1)} power(%broadcast.552, %div.657), metadata={op_name="jit(train_step)/layers/pow" stack_frame_id=0} + %constant.876 = s32[]{:T(128)} constant(2) + %broadcast.512 = s32[64]{0:T(128)} broadcast(%constant.876), dimensions={}, metadata={op_name="broadcast.237"} + %mul.1797 = s32[64]{0:T(128)} multiply(%iota.46, %broadcast.512), metadata={op_name="jit(train_step)/layers/mul" stack_frame_id=0} + %convert_element_type.1030 = f32[64]{0:T(128)} convert(%mul.1797), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %constant.875 = f32[]{:T(128)} constant(0.0078125) + %broadcast.511 = f32[64]{0:T(128)} broadcast(%constant.875), dimensions={}, metadata={op_name="broadcast.238"} + %div.657 = f32[64]{0:T(128)} multiply(%convert_element_type.1030, %broadcast.511), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + ROOT %pow.36 = f32[64]{0:T(128)S(1)} power(%broadcast.513, %div.657), metadata={op_name="jit(train_step)/layers/pow" stack_frame_id=0} } -%fused_computation.365 (param_0.1002: s32[4,128]) -> (f32[4,128,1,1], f32[4,128]) { - %param_0.1002 = s32[4,128]{1,0:T(4,128)} parameter(0) - %convert_element_type.1015 = f32[4,128]{1,0:T(4,128)S(1)} convert(%param_0.1002), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} - %bitcast.377 = f32[4,128,1,1]{1,0,3,2:T(4,128)} bitcast(%convert_element_type.1015), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %tuple.151 = (f32[4,128,1,1]{1,0,3,2:T(4,128)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%bitcast.377, %convert_element_type.1015) +%fused_computation.365 (param_0.1009: s32[4,128]) -> (f32[4,128,1,1], f32[4,128]) { + %param_0.1009 = s32[4,128]{1,0:T(4,128)} parameter(0) + %convert_element_type.1031 = f32[4,128]{1,0:T(4,128)S(1)} convert(%param_0.1009), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %bitcast.377 = f32[4,128,1,1]{1,0,3,2:T(4,128)} bitcast(%convert_element_type.1031), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.151 = (f32[4,128,1,1]{1,0,3,2:T(4,128)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%bitcast.377, %convert_element_type.1031) } -%fused_computation.369 (param_0.1103: f32[4096,4]) -> bf16[4,4096] { - %param_0.1103 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) - %bitcast.451 = f32[4,4096]{1,0:T(4,128)} bitcast(%param_0.1103), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.106 = bf16[4,4096]{1,0:T(4,128)(2,1)} convert(%bitcast.451) +%fused_computation.369 (param_0.1110: f32[4096,4]) -> bf16[4,4096] { + %param_0.1110 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) + %bitcast.451 = f32[4,4096]{1,0:T(4,128)} bitcast(%param_0.1110), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.69 = bf16[4,4096]{1,0:T(4,128)(2,1)S(1)} convert(%bitcast.451) } -%fused_computation.370 (param_0.1104: f32[4096,4]) -> bf16[4,4096] { - %param_0.1104 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) - %bitcast.452 = f32[4,4096]{1,0:T(4,128)} bitcast(%param_0.1104), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.108 = bf16[4,4096]{1,0:T(4,128)(2,1)S(1)} convert(%bitcast.452) +%fused_computation.370 (param_0.1111: f32[4096,4]) -> bf16[4,4096] { + %param_0.1111 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) + %bitcast.452 = f32[4,4096]{1,0:T(4,128)} bitcast(%param_0.1111), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.71 = bf16[4,4096]{1,0:T(4,128)(2,1)} convert(%bitcast.452) } %region_6.9 (reduce_max.6: bf16[], reduce_max.8: bf16[]) -> bf16[] { @@ -1371,364 +1371,364 @@ StackFrames ROOT %reduce_max.9 = bf16[]{:T(256)} maximum(%reduce_max.6, %reduce_max.8), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.237.clone.clone (param_0.1090: f32[4096,128256]) -> bf16[4096,128256,1] { - %param_0.1090 = f32[4096,128256]{1,0:T(8,128)} parameter(0) - %convert_element_type.1026 = bf16[4096,128256]{1,0:T(8,128)(2,1)} convert(%param_0.1090), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} - ROOT %bitcast.447 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} bitcast(%convert_element_type.1026), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} +%fused_computation.237.clone.clone (param_0.1097: f32[4096,128256]) -> bf16[4096,128256,1] { + %param_0.1097 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %convert_element_type.1042 = bf16[4096,128256]{1,0:T(8,128)(2,1)} convert(%param_0.1097), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + ROOT %bitcast.447 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} bitcast(%convert_element_type.1042), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} } -%fused_computation.317.clone.clone (param_0.1091: f32[4,128], param_1.1257: bf16[4,128,4096], param_2.1077: bf16[4096]) -> bf16[4,128,4096] { - %param_2.1077 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.383 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1077), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1257 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1028 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1257), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_0.1091 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1595 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1091), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1594 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1028, %mul.1595), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.1027 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1594), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - ROOT %dot_general.382 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.383, %convert_element_type.1027), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} +%fused_computation.317.clone.clone (param_0.1098: f32[4,128], param_1.1261: bf16[4,128,4096], param_2.1071: bf16[4096]) -> bf16[4,128,4096] { + %param_1.1261 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1044 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1261), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1098 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1916 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1098), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1915 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1044, %mul.1916), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1043 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1915), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.1071 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.1917 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1071), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.1914 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1043, %mul.1917), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} } -%fused_computation.371 (param_0.1105: f32[4096,128256], param_1.1268: f32[4,128], param_2.1099: bf16[4,128,4096], param_3.788: bf16[4096]) -> (bf16[4,128], bf16[4,128,128256]) { - %param_1.1268 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.1099 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %param_3.788 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %fusion.240.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1268, %param_2.1099, %param_3.788), kind=kLoop, calls=%fused_computation.317.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1105 = f32[4096,128256]{1,0:T(8,128)} parameter(0) - %fusion.221.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1105), kind=kLoop, calls=%fused_computation.237.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} +%fused_computation.371 (param_0.1112: f32[4096,128256], param_1.1272: f32[4,128], param_2.1093: bf16[4,128,4096], param_3.776: bf16[4096]) -> (bf16[4,128], bf16[4,128,128256]) { + %param_1.1272 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1093 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.776 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.240.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1272, %param_2.1093, %param_3.776), kind=kLoop, calls=%fused_computation.317.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.1112 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %fusion.221.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1112), kind=kLoop, calls=%fused_computation.237.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} %convolution.87.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convolution(%fusion.240.clone.1, %fusion.221.clone.1), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %constant.992 = bf16[]{:T(256)} constant(-inf) - %reduce.162 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} reduce(%convolution.87.clone.1, %constant.992), dimensions={2}, to_apply=%region_6.9, metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} - ROOT %tuple.152 = (bf16[4,128]{1,0:T(4,128)(2,1)S(1)}, bf16[4,128,128256]{2,1,0:T(8,128)(2,1)}) tuple(%reduce.162, %convolution.87.clone.1) + %constant.984 = bf16[]{:T(256)} constant(-inf) + %reduce.123 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} reduce(%convolution.87.clone.1, %constant.984), dimensions={2}, to_apply=%region_6.9, metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} + ROOT %tuple.152 = (bf16[4,128]{1,0:T(4,128)(2,1)S(1)}, bf16[4,128,128256]{2,1,0:T(8,128)(2,1)}) tuple(%reduce.123, %convolution.87.clone.1) } -%fused_computation.372 (param_0.1102: f32[4096,4,8,128]) -> bf16[4,4096,8,128] { - %param_0.1102 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) - %bitcast.450 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} bitcast(%param_0.1102), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.110 = bf16[4,4096,8,128]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.450) +%fused_computation.372 (param_0.1109: f32[4096,4,8,128]) -> bf16[4,4096,8,128] { + %param_0.1109 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %bitcast.450 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} bitcast(%param_0.1109), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.73 = bf16[4,4096,8,128]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.450) } -%convert_element_type.525.reduce_sub_computation (lhs.1: bf16[], rhs.1: bf16[]) -> bf16[] { - %lhs.1 = bf16[] parameter(0) - %rhs.1 = bf16[] parameter(1) - ROOT %add.624 = bf16[] add(%lhs.1, %rhs.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%convert_element_type.541.reduce_sub_computation (lhs: bf16[], rhs: bf16[]) -> bf16[] { + %lhs = bf16[] parameter(0) + %rhs = bf16[] parameter(1) + ROOT %add.609 = bf16[] add(%lhs, %rhs), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.121.clone.clone (param_0.1242: bf16[4,4096], param_1.1376: s32[]) -> bf16[4096] { - %param_0.1242 = bf16[4,4096]{1,0:T(4,128)(2,1)} parameter(0) - %param_1.1376 = s32[]{:T(128)S(6)} parameter(1) - %constant.1116 = s32[]{:T(128)} constant(0) - %dynamic_slice.316 = bf16[1,4096]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1242, %param_1.1376, %constant.1116), dynamic_slice_sizes={1,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %constant.1117 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %reduce.174 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.316, %constant.1117), dimensions={0}, to_apply=%convert_element_type.525.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.122.clone.clone (param_0.1249: bf16[4,4096], param_1.1380: s32[]) -> bf16[4096] { + %param_0.1249 = bf16[4,4096]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1380 = s32[]{:T(128)S(6)} parameter(1) + %constant.1108 = s32[]{:T(128)} constant(0) + %dynamic_slice.316 = bf16[1,4096]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1249, %param_1.1380, %constant.1108), dynamic_slice_sizes={1,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %constant.1109 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %reduce.135 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.316, %constant.1109), dimensions={0}, to_apply=%convert_element_type.541.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%region_12.14 (reduce_sum.108: f32[], reduce_sum.109: f32[]) -> f32[] { - %reduce_sum.108 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - %reduce_sum.109 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - ROOT %reduce_sum.113 = f32[]{:T(128)} add(%reduce_sum.108, %reduce_sum.109), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_12.14 (reduce_sum.144: f32[], reduce_sum.148: f32[]) -> f32[] { + %reduce_sum.144 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.148 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.149 = f32[]{:T(128)} add(%reduce_sum.144, %reduce_sum.148), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.58.clone.clone (param_0.1243: bf16[4,4,128,4096], param_1.1377: s32[]) -> f32[4,128] { - %param_0.1243 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1377 = s32[]{:T(128)S(6)} parameter(1) - %constant.1118 = s32[]{:T(128)} constant(0) - %dynamic_slice.317 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1243, %param_1.1377, %constant.1118, %constant.1118, %constant.1118), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.58.clone.clone (param_0.1250: bf16[4,4,128,4096], param_1.1381: s32[]) -> f32[4,128] { + %param_0.1250 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1381 = s32[]{:T(128)S(6)} parameter(1) + %constant.1110 = s32[]{:T(128)} constant(0) + %dynamic_slice.317 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1250, %param_1.1381, %constant.1110, %constant.1110, %constant.1110), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} %bitcast.548 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.317), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1093 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.548), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %square.214 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1093, %convert_element_type.1093), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} - %constant.1119 = f32[]{:T(128)} constant(0) - ROOT %reduce.175 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.214, %constant.1119), dimensions={2}, to_apply=%region_12.14, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} + %convert_element_type.1109 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.548), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.214 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1109, %convert_element_type.1109), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1111 = f32[]{:T(128)} constant(0) + ROOT %reduce.136 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.214, %constant.1111), dimensions={2}, to_apply=%region_12.14, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} } -%fused_computation.143.clone.1.clone (param_0.1244: f32[4,128]) -> f32[4,128] { - %param_0.1244 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %constant.1121 = f32[]{:T(128)} constant(0.000244140625) - %closed_call.81 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1121), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.842 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1244, %closed_call.81), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1120 = f32[]{:T(128)} constant(1e-05) - %closed_call.80 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1120), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %add.858 = f32[4,128]{1,0:T(4,128)} add(%div.842, %closed_call.80), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %rsqrt.97 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.858), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +%fused_computation.143.clone.1.clone (param_0.1251: f32[4,128]) -> f32[4,128] { + %param_0.1251 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1113 = f32[]{:T(128)} constant(0.000244140625) + %closed_call.81 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1113), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.842 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1251, %closed_call.81), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1112 = f32[]{:T(128)} constant(1e-05) + %closed_call.80 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1112), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.844 = f32[4,128]{1,0:T(4,128)} add(%div.842, %closed_call.80), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.97 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.844), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} } -%fused_computation.24.clone.1.clone.clone (param_0.1258: bf16[4,4096,32,128], param_1.1387: s32[]) -> bf16[4096,32,128,1] { - %param_0.1258 = bf16[4,4096,32,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %param_1.1387 = s32[]{:T(128)S(6)} parameter(1) - %constant.1134 = s32[]{:T(128)} constant(0) - %dynamic_slice.323 = bf16[1,4096,32,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1258, %param_1.1387, %constant.1134, %constant.1134, %constant.1134), dynamic_slice_sizes={1,4096,32,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.24.clone.1.clone.clone (param_0.1265: bf16[4,4096,32,128], param_1.1391: s32[]) -> bf16[4096,32,128,1] { + %param_0.1265 = bf16[4,4096,32,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.1391 = s32[]{:T(128)S(6)} parameter(1) + %constant.1126 = s32[]{:T(128)} constant(0) + %dynamic_slice.323 = bf16[1,4096,32,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1265, %param_1.1391, %constant.1126, %constant.1126, %constant.1126), dynamic_slice_sizes={1,4096,32,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.559 = bf16[4096,32,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.323), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.91.clone.clone (param_0.1259: f32[4,128], param_1.1388: bf16[4,4,128,4096], param_2.1176: s32[], param_3.847: bf16[4096]) -> bf16[4,128,4096,1] { - %param_3.847 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %dot_general.428 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.847), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1388 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1176 = s32[]{:T(128)S(6)} parameter(2) - %constant.1135 = s32[]{:T(128)} constant(0) - %dynamic_slice.324 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1388, %param_2.1176, %constant.1135, %constant.1135, %constant.1135), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.91.clone.clone (param_0.1266: bf16[4096], param_1.1392: f32[4,128], param_2.1170: bf16[4,4,128,4096], param_3.835: s32[]) -> bf16[4,128,4096,1] { + %param_2.1170 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.835 = s32[]{:T(128)S(6)} parameter(3) + %constant.1127 = s32[]{:T(128)} constant(0) + %dynamic_slice.324 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1170, %param_3.835, %constant.1127, %constant.1127, %constant.1127), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} %bitcast.561 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.324), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1101 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.561), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1259 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1709 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1259), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1708 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1101, %mul.1709), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1100 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1708), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.427 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.428, %convert_element_type.1100), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %bitcast.560 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.427), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.36.clone.clone (param_0.1260: bf16[4,4096,32,128], param_1.1389: s32[], param_2.1177: f32[4,128], param_3.848: bf16[4,4,128,4096], param_4.530: bf16[4096]) -> bf16[4,128,32,128] { - %param_2.1177 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.848 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) - %param_1.1389 = s32[]{:T(128)S(6)} parameter(1) - %param_4.530 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.343 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1177, %param_3.848, %param_1.1389, %param_4.530), kind=kLoop, calls=%fused_computation.91.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1260 = bf16[4,4096,32,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %fusion.342 = bf16[4096,32,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1260, %param_1.1389), kind=kLoop, calls=%fused_computation.24.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convert_element_type.1117 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.561), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1392 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2081 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1392), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2080 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1117, %mul.2081), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1116 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2080), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1266 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %mul.2079 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_0.1266), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2078 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1116, %mul.2079), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.560 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2078), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.36.clone.clone (param_0.1267: bf16[4,4096,32,128], param_1.1393: s32[], param_2.1171: f32[4,128], param_3.836: bf16[4,4,128,4096], param_4.524: bf16[4096]) -> bf16[4,128,32,128] { + %param_4.524 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %param_2.1171 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.836 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1393 = s32[]{:T(128)S(6)} parameter(1) + %fusion.343 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_4.524, %param_2.1171, %param_3.836, %param_1.1393), kind=kLoop, calls=%fused_computation.91.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1267 = bf16[4,4096,32,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.342 = bf16[4096,32,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1267, %param_1.1393), kind=kLoop, calls=%fused_computation.24.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} ROOT %convolution.113 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.343, %fusion.342), window={size=1x32 pad=0_0x31_31 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} } -%fused_computation.70.clone.clone (param_0.1261: bf16[4,128,32,128]) -> (bf16[4,128,32,64], bf16[4,128,32,64]) { - %param_0.1261 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %split.160 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1261), slice={[0:4], [0:128], [0:32], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} +%fused_computation.70.clone.clone (param_0.1268: bf16[4,128,32,128]) -> (bf16[4,128,32,64], bf16[4,128,32,64]) { + %param_0.1268 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.160 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1268), slice={[0:4], [0:128], [0:32], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} %neg.129 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.160), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=0} - %split.161 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1261), slice={[0:4], [0:128], [0:32], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + %split.161 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1268), slice={[0:4], [0:128], [0:32], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} ROOT %tuple.187 = (bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.129, %split.161) } %fused_computation.145.clone.clone () -> f32[64] { - %constant.1124 = f32[]{:T(128)} constant(500000) - %closed_call.84 = f32[64]{0:T(128)} broadcast(%constant.1124), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %constant.1116 = f32[]{:T(128)} constant(500000) + %closed_call.84 = f32[64]{0:T(128)} broadcast(%constant.1116), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} %iota.51 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/iota" stack_frame_id=0} - %constant.1123 = s32[]{:T(128)} constant(2) - %closed_call.83 = s32[64]{0:T(128)} broadcast(%constant.1123), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %mul.1699 = s32[64]{0:T(128)} multiply(%iota.51, %closed_call.83), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1094 = f32[64]{0:T(128)} convert(%mul.1699), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %constant.1122 = f32[]{:T(128)} constant(0.0078125) - %closed_call.82 = f32[64]{0:T(128)} broadcast(%constant.1122), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.843 = f32[64]{0:T(128)} multiply(%convert_element_type.1094, %closed_call.82), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1115 = s32[]{:T(128)} constant(2) + %closed_call.83 = s32[64]{0:T(128)} broadcast(%constant.1115), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %mul.2065 = s32[64]{0:T(128)} multiply(%iota.51, %closed_call.83), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1110 = f32[64]{0:T(128)} convert(%mul.2065), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %constant.1114 = f32[]{:T(128)} constant(0.0078125) + %closed_call.82 = f32[64]{0:T(128)} broadcast(%constant.1114), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.843 = f32[64]{0:T(128)} multiply(%convert_element_type.1110, %closed_call.82), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} ROOT %pow.38 = f32[64]{0:T(128)S(1)} power(%closed_call.84, %div.843), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/pow" stack_frame_id=0} } -%fused_computation.117.clone.clone (param_0.1245: f32[64], param_1.1378: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { - %param_1.1378 = f32[4,128]{1,0:T(4,128)} parameter(1) - %div.846 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1378), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %param_0.1245 = f32[64]{0:T(128)S(1)} parameter(0) - %div.845 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1245), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} +%fused_computation.117.clone.clone (param_0.1252: f32[64], param_1.1382: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.1382 = f32[4,128]{1,0:T(4,128)} parameter(1) + %div.846 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1382), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %param_0.1252 = f32[64]{0:T(128)S(1)} parameter(0) + %div.845 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1252), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} %div.844 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.846, %div.845), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} %cos.43 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.844), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/cos" stack_frame_id=0} - %convert_element_type.1095 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.43), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convert_element_type.1111 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.43), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} %sin.35.clone.3 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.844), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/sin" stack_frame_id=0} - %convert_element_type.829.clone.3 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.35.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %tuple.185 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1095, %convert_element_type.829.clone.3) + %convert_element_type.845.clone.3 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.35.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.185 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1111, %convert_element_type.845.clone.3) } -%fused_computation.120.clone.clone (param_0.1252: bf16[4,128,1,64]) -> bf16[4,128,128] { - %param_0.1252 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.1130 = bf16[]{:T(256)} constant(-inf) - %pad.61 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1252, %constant.1130), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %pad.60 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1252, %constant.1130), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} +%fused_computation.120.clone.clone (param_0.1259: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1259 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1122 = bf16[]{:T(256)} constant(-inf) + %pad.61 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1259, %constant.1122), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %pad.60 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1259, %constant.1122), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.45 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.61, %pad.60), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} ROOT %bitcast.554 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.45), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} } -%fused_computation.119.clone.clone (param_0.1246: bf16[4,128,1,64]) -> bf16[4,128,128] { - %param_0.1246 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.1125 = bf16[]{:T(256)} constant(-inf) - %pad.59 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1246, %constant.1125), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %pad.58 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1246, %constant.1125), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} +%fused_computation.119.clone.clone (param_0.1253: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1253 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1117 = bf16[]{:T(256)} constant(-inf) + %pad.59 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1253, %constant.1117), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %pad.58 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1253, %constant.1117), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.44 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.59, %pad.58), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} ROOT %bitcast.549 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.44), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} } -%fused_computation.73.clone.clone (param_0.1262: bf16[4,128,32,64], param_1.1390: bf16[4,128,32,64], param_2.1178: bf16[4,128,32,128], param_3.849: bf16[4,128,128], param_4.531: bf16[4,128,128]) -> bf16[4,32,128,128] { - %param_2.1178 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(2) - %param_4.531 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(4) - %mul.1713 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.531), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1711 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%param_2.1178, %mul.1713), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %param_1.1390 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) - %constant.1136 = bf16[]{:T(256)} constant(-inf) - %pad.65 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1390, %constant.1136), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_0.1262 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %pad.64 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1262, %constant.1136), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} +%fused_computation.73.clone.clone (param_0.1269: bf16[4,128,32,64], param_1.1394: bf16[4,128,32,64], param_2.1172: bf16[4,128,32,128], param_3.837: bf16[4,128,128], param_4.525: bf16[4,128,128]) -> bf16[4,32,128,128] { + %param_2.1172 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(2) + %param_4.525 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(4) + %mul.2085 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.525), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2083 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%param_2.1172, %mul.2085), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_1.1394 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1128 = bf16[]{:T(256)} constant(-inf) + %pad.65 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1394, %constant.1128), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_0.1269 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.64 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1269, %constant.1128), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.47 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.65, %pad.64), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_3.849 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %mul.1712 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.849), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1710 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.47, %mul.1712), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %add.860 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.1711, %mul.1710), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %bitcast.562 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.860), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} -} - -%fused_computation.90.clone.clone (param_0.1254: f32[4,128], param_1.1384: bf16[4,4,128,4096], param_2.1173: s32[], param_3.844: bf16[4096]) -> bf16[4,128,4096,1] { - %param_3.844 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %dot_general.426 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.844), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1384 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1173 = s32[]{:T(128)S(6)} parameter(2) - %constant.1132 = s32[]{:T(128)} constant(0) - %dynamic_slice.322 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1384, %param_2.1173, %constant.1132, %constant.1132, %constant.1132), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %param_3.837 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.2084 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.837), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2082 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.47, %mul.2084), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %add.846 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2083, %mul.2082), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %bitcast.562 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.846), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.90.clone.clone (param_0.1261: bf16[4096], param_1.1388: f32[4,128], param_2.1167: bf16[4,4,128,4096], param_3.832: s32[]) -> bf16[4,128,4096,1] { + %param_2.1167 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.832 = s32[]{:T(128)S(6)} parameter(3) + %constant.1124 = s32[]{:T(128)} constant(0) + %dynamic_slice.322 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1167, %param_3.832, %constant.1124, %constant.1124, %constant.1124), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} %bitcast.557 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.322), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1099 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.557), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1254 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1703 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1254), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1702 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1099, %mul.1703), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1098 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1702), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.425 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.426, %convert_element_type.1098), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %bitcast.556 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.425), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.64.clone.1.clone.clone (param_0.1253: bf16[4,4096,8,128], param_1.1383: s32[]) -> bf16[4096,8,128,1] { - %param_0.1253 = bf16[4,4096,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %param_1.1383 = s32[]{:T(128)S(6)} parameter(1) - %constant.1131 = s32[]{:T(128)} constant(0) - %dynamic_slice.321 = bf16[1,4096,8,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1253, %param_1.1383, %constant.1131, %constant.1131, %constant.1131), dynamic_slice_sizes={1,4096,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %convert_element_type.1115 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.557), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1388 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2073 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1388), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2072 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1115, %mul.2073), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1114 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2072), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1261 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %mul.2071 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_0.1261), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2070 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1114, %mul.2071), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.556 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2070), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.64.clone.1.clone.clone (param_0.1260: bf16[4,4096,8,128], param_1.1387: s32[]) -> bf16[4096,8,128,1] { + %param_0.1260 = bf16[4,4096,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.1387 = s32[]{:T(128)S(6)} parameter(1) + %constant.1123 = s32[]{:T(128)} constant(0) + %dynamic_slice.321 = bf16[1,4096,8,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1260, %param_1.1387, %constant.1123, %constant.1123, %constant.1123), dynamic_slice_sizes={1,4096,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.555 = bf16[4096,8,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.321), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.89.clone.clone (param_0.1255: bf16[4,4096,8,128], param_1.1385: s32[], param_2.1174: f32[4,128], param_3.845: bf16[4,4,128,4096], param_4.528: bf16[4096]) -> bf16[4,128,8,128] { - %param_2.1174 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.845 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) - %param_1.1385 = s32[]{:T(128)S(6)} parameter(1) - %param_4.528 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.340 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1174, %param_3.845, %param_1.1385, %param_4.528), kind=kLoop, calls=%fused_computation.90.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1255 = bf16[4,4096,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %fusion.341 = bf16[4096,8,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1255, %param_1.1385), kind=kLoop, calls=%fused_computation.64.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.89.clone.clone (param_0.1262: bf16[4,4096,8,128], param_1.1389: s32[], param_2.1168: f32[4,128], param_3.833: bf16[4,4,128,4096], param_4.522: bf16[4096]) -> bf16[4,128,8,128] { + %param_4.522 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %param_2.1168 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.833 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1389 = s32[]{:T(128)S(6)} parameter(1) + %fusion.340 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_4.522, %param_2.1168, %param_3.833, %param_1.1389), kind=kLoop, calls=%fused_computation.90.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1262 = bf16[4,4096,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.341 = bf16[4096,8,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1262, %param_1.1389), kind=kLoop, calls=%fused_computation.64.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} ROOT %convolution.112 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.340, %fusion.341), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} } -%fused_computation.106.clone.clone (param_0.1256: bf16[4,128,8,128]) -> (bf16[4,128,8,64], bf16[4,128,8,64]) { - %param_0.1256 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %split.158 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1256), slice={[0:4], [0:128], [0:8], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} +%fused_computation.106.clone.clone (param_0.1263: bf16[4,128,8,128]) -> (bf16[4,128,8,64], bf16[4,128,8,64]) { + %param_0.1263 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.158 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1263), slice={[0:4], [0:128], [0:8], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} %neg.128 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.158), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=0} - %split.159 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1256), slice={[0:4], [0:128], [0:8], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + %split.159 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1263), slice={[0:4], [0:128], [0:8], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} ROOT %tuple.186 = (bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.128, %split.159) } -%fused_computation.109.clone.clone (param_0.1257: bf16[4,128,8,64], param_1.1386: bf16[4,128,8,64], param_2.1175: bf16[4,128,8,128], param_3.846: bf16[4,128,128], param_4.529: bf16[4,128,128]) -> bf16[4,8,128,128] { - %param_2.1175 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(2) - %param_4.529 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(4) - %mul.1707 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.529), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1705 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%param_2.1175, %mul.1707), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %param_1.1386 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) - %constant.1133 = bf16[]{:T(256)} constant(-inf) - %pad.63 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1386, %constant.1133), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_0.1257 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %pad.62 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1257, %constant.1133), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} +%fused_computation.109.clone.clone (param_0.1264: bf16[4,128,8,64], param_1.1390: bf16[4,128,8,64], param_2.1169: bf16[4,128,8,128], param_3.834: bf16[4,128,128], param_4.523: bf16[4,128,128]) -> bf16[4,8,128,128] { + %param_2.1169 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(2) + %param_4.523 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(4) + %mul.2077 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.523), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2075 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%param_2.1169, %mul.2077), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_1.1390 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1125 = bf16[]{:T(256)} constant(-inf) + %pad.63 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1390, %constant.1125), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_0.1264 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.62 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1264, %constant.1125), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.46 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.63, %pad.62), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_3.846 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %mul.1706 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.846), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1704 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.46, %mul.1706), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %add.859 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.1705, %mul.1704), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %bitcast.558 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.859), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} + %param_3.834 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.2076 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.834), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2074 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.46, %mul.2076), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %add.845 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2075, %mul.2074), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %bitcast.558 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.845), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} } -%fused_computation.135.clone.clone (param_0.1248: bf16[4,4096,8,128], param_1.1380: s32[]) -> bf16[1,4096,8,128] { - %param_0.1248 = bf16[4,4096,8,128]{3,2,0,1:T(8,128)(2,1)} parameter(0) - %param_1.1380 = s32[]{:T(128)S(6)} parameter(1) - %constant.1128 = s32[]{:T(128)} constant(0) - ROOT %dynamic_slice.319 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} dynamic-slice(%param_0.1248, %param_1.1380, %constant.1128, %constant.1128, %constant.1128), dynamic_slice_sizes={1,4096,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.135.clone.clone (param_0.1255: bf16[4,4096,8,128], param_1.1384: s32[]) -> bf16[1,4096,8,128] { + %param_0.1255 = bf16[4,4096,8,128]{3,2,0,1:T(8,128)(2,1)} parameter(0) + %param_1.1384 = s32[]{:T(128)S(6)} parameter(1) + %constant.1120 = s32[]{:T(128)} constant(0) + ROOT %dynamic_slice.319 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} dynamic-slice(%param_0.1255, %param_1.1384, %constant.1120, %constant.1120, %constant.1120), dynamic_slice_sizes={1,4096,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} } -%fused_computation.65.clone.1.clone.clone.clone.clone (param_0.1249: bf16[1,4096,8,128]) -> bf16[4096,8,128,1] { - %param_0.1249 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) - %copy.248 = bf16[1,4096,8,128]{3,1,2,0:T(8,128)(2,1)} copy(%param_0.1249), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0} +%fused_computation.65.clone.1.clone.clone.clone.clone (param_0.1256: bf16[1,4096,8,128]) -> bf16[4096,8,128,1] { + %param_0.1256 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) + %copy.248 = bf16[1,4096,8,128]{3,1,2,0:T(8,128)(2,1)} copy(%param_0.1256), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0} ROOT %bitcast.550 = bf16[4096,8,128,1]{2,0,1,3:T(8,128)(2,1)} bitcast(%copy.248), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.88.clone.clone.clone.clone (param_0.1250: f32[4,128], param_1.1381: bf16[4,4,128,4096], param_2.1171: s32[], param_3.842: bf16[4096]) -> bf16[4,128,4096,1] { - %param_3.842 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %dot_general.424 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.842), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1381 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1171 = s32[]{:T(128)S(6)} parameter(2) - %constant.1129 = s32[]{:T(128)} constant(0) - %dynamic_slice.320 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1381, %param_2.1171, %constant.1129, %constant.1129, %constant.1129), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.88.clone.clone.clone.clone (param_0.1257: bf16[4096], param_1.1385: f32[4,128], param_2.1165: bf16[4,4,128,4096], param_3.830: s32[]) -> bf16[4,128,4096,1] { + %param_2.1165 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.830 = s32[]{:T(128)S(6)} parameter(3) + %constant.1121 = s32[]{:T(128)} constant(0) + %dynamic_slice.320 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1165, %param_3.830, %constant.1121, %constant.1121, %constant.1121), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} %bitcast.552 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.320), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1097 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.552), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1250 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1701 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1250), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1700 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1097, %mul.1701), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1096 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1700), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.423 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.424, %convert_element_type.1096), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %bitcast.551 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.423), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.114.clone.clone (param_0.1251: bf16[1,4096,8,128], param_1.1382: f32[4,128], param_2.1172: bf16[4,4,128,4096], param_3.843: s32[], param_4.527: bf16[4096]) -> bf16[4,8,128,128] { - %param_1.1382 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.1172 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) - %param_3.843 = s32[]{:T(128)S(6)} parameter(3) - %param_4.527 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.339 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_1.1382, %param_2.1172, %param_3.843, %param_4.527), kind=kLoop, calls=%fused_computation.88.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1251 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) - %fusion.338 = bf16[4096,8,128,1]{2,0,1,3:T(8,128)(2,1)} fusion(%param_0.1251), kind=kLoop, calls=%fused_computation.65.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convert_element_type.1113 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.552), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1385 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2069 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1385), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2068 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1113, %mul.2069), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1112 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2068), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1257 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %mul.2067 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_0.1257), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2066 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1112, %mul.2067), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.551 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2066), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.114.clone.clone (param_0.1258: bf16[1,4096,8,128], param_1.1386: f32[4,128], param_2.1166: bf16[4,4,128,4096], param_3.831: s32[], param_4.521: bf16[4096]) -> bf16[4,8,128,128] { + %param_4.521 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %param_1.1386 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1166 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.831 = s32[]{:T(128)S(6)} parameter(3) + %fusion.339 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_4.521, %param_1.1386, %param_2.1166, %param_3.831), kind=kLoop, calls=%fused_computation.88.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1258 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) + %fusion.338 = bf16[4096,8,128,1]{2,0,1,3:T(8,128)(2,1)} fusion(%param_0.1258), kind=kLoop, calls=%fused_computation.65.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} %convolution.111 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convolution(%fusion.339, %fusion.338), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} ROOT %bitcast.553 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%convolution.111), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} } -%fused_computation.366.clone.clone (param_0.1286: f32[4,32,128,128]) -> (f32[4,32,128,1], f32[4,32,128]) { - %param_0.1286 = f32[4,32,128,128]{2,1,0,3:T(8,128)S(1)} parameter(0) - %slice.11 = f32[4,32,128,1]{2,1,0,3:T(8,128)S(1)} slice(%param_0.1286), slice={[0:4], [0:32], [0:128], [0:1]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/slice" stack_frame_id=0} +%fused_computation.366.clone.clone (param_0.1293: f32[4,32,128,128]) -> (f32[4,32,128,1], f32[4,32,128]) { + %param_0.1293 = f32[4,32,128,128]{2,1,0,3:T(8,128)S(1)} parameter(0) + %slice.11 = f32[4,32,128,1]{2,1,0,3:T(8,128)S(1)} slice(%param_0.1293), slice={[0:4], [0:32], [0:128], [0:1]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/slice" stack_frame_id=0} %bitcast.262.clone.3 = f32[4,32,128]{2,1,0:T(8,128)S(1)} bitcast(%slice.11), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/squeeze" stack_frame_id=0} ROOT %tuple.192 = (f32[4,32,128,1]{2,1,0,3:T(8,128)S(1)}, f32[4,32,128]{2,1,0:T(8,128)S(1)}) tuple(%slice.11, %bitcast.262.clone.3) } -%region_13.16 (reduce_sum.120: f32[], reduce_sum.121: f32[]) -> f32[] { - %reduce_sum.120 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - %reduce_sum.121 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - ROOT %reduce_sum.122 = f32[]{:T(128)} add(%reduce_sum.120, %reduce_sum.121), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_13.16 (reduce_sum.150: f32[], reduce_sum.151: f32[]) -> f32[] { + %reduce_sum.150 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.151 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.155 = f32[]{:T(128)} add(%reduce_sum.150, %reduce_sum.151), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.25.clone.1.clone.clone.clone.clone.clone.clone (param_0.1263: bf16[4,32,128,4096], param_1.1391: s32[]) -> bf16[32,128,4096,1] { - %param_0.1263 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1391 = s32[]{:T(128)S(6)} parameter(1) - %constant.1137 = s32[]{:T(128)} constant(0) - %dynamic_slice.325 = bf16[1,32,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1263, %param_1.1391, %constant.1137, %constant.1137, %constant.1137), dynamic_slice_sizes={1,32,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.25.clone.1.clone.clone.clone.clone.clone.clone (param_0.1270: bf16[4,32,128,4096], param_1.1395: s32[]) -> bf16[32,128,4096,1] { + %param_0.1270 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1395 = s32[]{:T(128)S(6)} parameter(1) + %constant.1129 = s32[]{:T(128)} constant(0) + %dynamic_slice.325 = bf16[1,32,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1270, %param_1.1395, %constant.1129, %constant.1129, %constant.1129), dynamic_slice_sizes={1,32,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.563 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dynamic_slice.325), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.80.clone.clone.clone.clone.clone.clone (param_0.1264: bf16[4,32,128,128]) -> bf16[4,128,32,128] { - %param_0.1264 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) - ROOT %bitcast.564 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} bitcast(%param_0.1264), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +%fused_computation.81.clone.clone.clone.clone.clone.clone (param_0.1271: bf16[4,32,128,128]) -> bf16[4,128,32,128] { + %param_0.1271 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) + ROOT %bitcast.564 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} bitcast(%param_0.1271), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} } -%fused_computation.61.clone.clone (param_0.1265: bf16[4,32,128,4096], param_1.1392: s32[], param_2.1179: bf16[4,32,128,128], param_3.850: bf16[4,4,128,4096]) -> (f32[4,128], bf16[4,128,4096]) { - %param_3.850 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) - %param_1.1392 = s32[]{:T(128)S(6)} parameter(1) - %constant.365.clone.1.clone.3 = s32[]{:T(128)} constant(0) - %dynamic_slice.208.clone.3 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_3.850, %param_1.1392, %constant.365.clone.1.clone.3, %constant.365.clone.1.clone.3, %constant.365.clone.1.clone.3), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.61.clone.clone (param_0.1272: bf16[4,32,128,4096], param_1.1396: s32[], param_2.1173: bf16[4,32,128,128], param_3.838: bf16[4,4,128,4096]) -> (f32[4,128], bf16[4,128,4096]) { + %param_3.838 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1396 = s32[]{:T(128)S(6)} parameter(1) + %constant.357.clone.1.clone.3 = s32[]{:T(128)} constant(0) + %dynamic_slice.208.clone.3 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_3.838, %param_1.1396, %constant.357.clone.1.clone.3, %constant.357.clone.1.clone.3, %constant.357.clone.1.clone.3), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} %bitcast.207.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.208.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %param_2.1179 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %fusion.83.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} fusion(%param_2.1179), kind=kLoop, calls=%fused_computation.80.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} - %param_0.1265 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) - %fusion.82.clone.3 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.1265, %param_1.1392), kind=kLoop, calls=%fused_computation.25.clone.1.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1173 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %fusion.83.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} fusion(%param_2.1173), kind=kLoop, calls=%fused_computation.81.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} + %param_0.1272 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %fusion.82.clone.3 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.1272, %param_1.1396), kind=kLoop, calls=%fused_computation.25.clone.1.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} %convolution.62.clone.3 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} convolution(%fusion.83.clone.3, %fusion.82.clone.3), window={size=1x32}, dim_labels=0b1f_1io0->0bf1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} %bitcast.182.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%convolution.62.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} - %add.635.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%bitcast.207.clone.3, %bitcast.182.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - %convert_element_type.1102 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%add.635.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %square.215 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1102, %convert_element_type.1102), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} - %constant.1138 = f32[]{:T(128)} constant(0) - %reduce.177 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.215, %constant.1138), dimensions={2}, to_apply=%region_13.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} - ROOT %tuple.188 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.177, %add.635.clone.3) + %add.621.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%bitcast.207.clone.3, %bitcast.182.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %convert_element_type.1118 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%add.621.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.215 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1118, %convert_element_type.1118), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1130 = f32[]{:T(128)} constant(0) + %reduce.138 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.215, %constant.1130), dimensions={2}, to_apply=%region_13.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.188 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.138, %add.621.clone.3) } -%convert_element_type.523.reduce_sub_computation (lhs: bf16[], rhs: bf16[]) -> bf16[] { - %lhs = bf16[] parameter(0) - %rhs = bf16[] parameter(1) - ROOT %add.623 = bf16[] add(%lhs, %rhs), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%convert_element_type.556.reduce_sub_computation (lhs.1: bf16[], rhs.1: bf16[]) -> bf16[] { + %lhs.1 = bf16[] parameter(0) + %rhs.1 = bf16[] parameter(1) + ROOT %add.610 = bf16[] add(%lhs.1, %rhs.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.122.clone.clone (param_0.1247: bf16[4,4096], param_1.1379: s32[]) -> bf16[4096] { - %param_0.1247 = bf16[4,4096]{1,0:T(4,128)(2,1)} parameter(0) - %param_1.1379 = s32[]{:T(128)S(6)} parameter(1) - %constant.1126 = s32[]{:T(128)} constant(0) - %dynamic_slice.318 = bf16[1,4096]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1247, %param_1.1379, %constant.1126), dynamic_slice_sizes={1,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %constant.1127 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %reduce.176 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.318, %constant.1127), dimensions={0}, to_apply=%convert_element_type.523.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.121.clone.clone (param_0.1254: bf16[4,4096], param_1.1383: s32[]) -> bf16[4096] { + %param_0.1254 = bf16[4,4096]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1383 = s32[]{:T(128)S(6)} parameter(1) + %constant.1118 = s32[]{:T(128)} constant(0) + %dynamic_slice.318 = bf16[1,4096]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1254, %param_1.1383, %constant.1118), dynamic_slice_sizes={1,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %constant.1119 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %reduce.137 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.318, %constant.1119), dimensions={0}, to_apply=%convert_element_type.556.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.12.clone.clone.clone (param_0.1266: bf16[4,14336,4096], param_1.1393: s32[]) -> bf16[14336,4096,1] { - %param_0.1266 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1393 = s32[]{:T(128)S(6)} parameter(1) - %constant.1139 = s32[]{:T(128)} constant(0) - %dynamic_slice.326 = bf16[1,14336,4096]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1266, %param_1.1393, %constant.1139, %constant.1139), dynamic_slice_sizes={1,14336,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.12.clone.clone.clone (param_0.1273: bf16[4,14336,4096], param_1.1397: s32[]) -> bf16[14336,4096,1] { + %param_0.1273 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1397 = s32[]{:T(128)S(6)} parameter(1) + %constant.1131 = s32[]{:T(128)} constant(0) + %dynamic_slice.326 = bf16[1,14336,4096]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1273, %param_1.1397, %constant.1131, %constant.1131), dynamic_slice_sizes={1,14336,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.566 = bf16[14336,4096,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.326), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } @@ -1737,264 +1737,264 @@ StackFrames ROOT %bitcast.565 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%bitcast_input.12) } -%fused_computation.13.clone.clone (param_0.1267: bf16[4,128,4096], param_1.1394: bf16[4,14336,4096], param_2.1180: s32[]) -> bf16[14336,4,128] { - %param_1.1394 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1180 = s32[]{:T(128)S(6)} parameter(2) - %fusion.344 = bf16[14336,4096,1]{1,0,2:T(8,128)(2,1)} fusion(%param_1.1394, %param_2.1180), kind=kLoop, calls=%fused_computation.12.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1267 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %fusion.345 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1267), kind=kLoop, calls=%bitcast_fusion.3.clone.clone +%fused_computation.13.clone.clone (param_0.1274: bf16[4,128,4096], param_1.1398: bf16[4,14336,4096], param_2.1174: s32[]) -> bf16[14336,4,128] { + %param_1.1398 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} parameter(1) + %param_2.1174 = s32[]{:T(128)S(6)} parameter(2) + %fusion.344 = bf16[14336,4096,1]{1,0,2:T(8,128)(2,1)} fusion(%param_1.1398, %param_2.1174), kind=kLoop, calls=%fused_computation.12.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1274 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %fusion.345 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1274), kind=kLoop, calls=%bitcast_fusion.3.clone.clone ROOT %convolution.114 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} convolution(%fusion.344, %fusion.345), window={size=4 pad=3_3 rhs_reversal=1}, dim_labels=bf0_0oi->b0f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} } -%fused_computation.144.clone.1.clone (param_0.1268: f32[4,128]) -> f32[4,128] { - %param_0.1268 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %constant.1141 = f32[]{:T(128)} constant(0.000244140625) - %closed_call.86 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1141), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.847 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1268, %closed_call.86), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1140 = f32[]{:T(128)} constant(1e-05) - %closed_call.85 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1140), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %add.861 = f32[4,128]{1,0:T(4,128)} add(%div.847, %closed_call.85), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %rsqrt.98 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.861), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +%fused_computation.144.clone.1.clone (param_0.1275: f32[4,128]) -> f32[4,128] { + %param_0.1275 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1133 = f32[]{:T(128)} constant(0.000244140625) + %closed_call.86 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1133), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.847 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1275, %closed_call.86), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1132 = f32[]{:T(128)} constant(1e-05) + %closed_call.85 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1132), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.847 = f32[4,128]{1,0:T(4,128)} add(%div.847, %closed_call.85), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.98 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.847), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} } -%fused_computation.11.clone.1.clone.clone (param_0.1272: bf16[4,4096,14336], param_1.1398: s32[]) -> bf16[4096,14336,1] { - %param_0.1272 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1398 = s32[]{:T(128)S(6)} parameter(1) - %constant.1143 = s32[]{:T(128)} constant(0) - %dynamic_slice.328 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1272, %param_1.1398, %constant.1143, %constant.1143), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.11.clone.1.clone.clone (param_0.1279: bf16[4,4096,14336], param_1.1402: s32[]) -> bf16[4096,14336,1] { + %param_0.1279 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1402 = s32[]{:T(128)S(6)} parameter(1) + %constant.1135 = s32[]{:T(128)} constant(0) + %dynamic_slice.328 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1279, %param_1.1402, %constant.1135, %constant.1135), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.568 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.328), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.96.clone.2.clone.clone (param_0.1273: f32[4,128], param_1.1399: bf16[4,128,4096], param_2.1183: bf16[4096]) -> bf16[4,128,4096] { - %param_2.1183 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.432 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1183), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1399 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1106 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1399), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1273 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1717 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1273), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1716 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1106, %mul.1717), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1105 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1716), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %dot_general.431 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.432, %convert_element_type.1105), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.23.clone.clone (param_0.1274: bf16[4,4096,14336], param_1.1400: s32[], param_2.1184: f32[4,128], param_3.852: bf16[4,128,4096], param_4.533: bf16[4096]) -> bf16[4,128,14336] { - %param_2.1184 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.852 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %param_4.533 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.349 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1184, %param_3.852, %param_4.533), kind=kLoop, calls=%fused_computation.96.clone.2.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1274 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1400 = s32[]{:T(128)S(6)} parameter(1) - %fusion.348 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1274, %param_1.1400), kind=kLoop, calls=%fused_computation.11.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.96.clone.2.clone.clone (param_0.1280: f32[4,128], param_1.1403: bf16[4,128,4096], param_2.1177: bf16[4096]) -> bf16[4,128,4096] { + %param_1.1403 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1122 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1403), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1280 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2092 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1280), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2091 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1122, %mul.2092), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1121 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2091), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1177 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2093 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1177), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2090 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1121, %mul.2093), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.23.clone.clone (param_0.1281: bf16[4,4096,14336], param_1.1404: s32[], param_2.1178: f32[4,128], param_3.840: bf16[4,128,4096], param_4.527: bf16[4096]) -> bf16[4,128,14336] { + %param_2.1178 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.840 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.527 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.349 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1178, %param_3.840, %param_4.527), kind=kLoop, calls=%fused_computation.96.clone.2.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1281 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1404 = s32[]{:T(128)S(6)} parameter(1) + %fusion.348 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1281, %param_1.1404), kind=kLoop, calls=%fused_computation.11.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} ROOT %convolution.116 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.349, %fusion.348), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} } -%fused_computation.14.clone.1.clone.clone (param_0.1275: bf16[4,4096,14336], param_1.1401: s32[]) -> bf16[4096,14336,1] { - %param_0.1275 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1401 = s32[]{:T(128)S(6)} parameter(1) - %constant.1144 = s32[]{:T(128)} constant(0) - %dynamic_slice.329 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1275, %param_1.1401, %constant.1144, %constant.1144), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.14.clone.1.clone.clone (param_0.1282: bf16[4,4096,14336], param_1.1405: s32[]) -> bf16[4096,14336,1] { + %param_0.1282 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1405 = s32[]{:T(128)S(6)} parameter(1) + %constant.1136 = s32[]{:T(128)} constant(0) + %dynamic_slice.329 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1282, %param_1.1405, %constant.1136, %constant.1136), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.569 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.329), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.39.clone.1.clone.clone (param_0.1276: bf16[14336,4,128], param_1.1402: bf16[4,128,14336]) -> bf16[4,128,14336] { - %param_1.1402 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %constant.1145 = bf16[]{:T(256)} constant(1) - %jit_silu_.44 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} broadcast(%constant.1145), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)" stack_frame_id=0} - %neg.130 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} negate(%param_1.1402), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/neg" stack_frame_id=0} +%fused_computation.39.clone.1.clone.clone (param_0.1283: bf16[14336,4,128], param_1.1406: bf16[4,128,14336]) -> bf16[4,128,14336] { + %param_1.1406 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1137 = bf16[]{:T(256)} constant(1) + %jit_silu_.44 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} broadcast(%constant.1137), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)" stack_frame_id=0} + %neg.130 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} negate(%param_1.1406), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/neg" stack_frame_id=0} %exp.69 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} exponential(%neg.130), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/exp" stack_frame_id=0} - %add.862 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%exp.69, %jit_silu_.44), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/add" stack_frame_id=0} - %div.848 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} divide(%jit_silu_.44, %add.862), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/div" stack_frame_id=0} - %mul.1719 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1402, %div.848), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/mul" stack_frame_id=0} - %param_0.1276 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(0) - %bitcast.570 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} bitcast(%param_0.1276), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} - ROOT %mul.1718 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.1719, %bitcast.570), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} -} - -%fused_computation.21.clone.clone (param_0.1277: bf16[4,4096,14336], param_1.1403: s32[], param_2.1185: bf16[14336,4,128], param_3.853: bf16[4,128,14336]) -> bf16[4,128,4096] { - %param_2.1185 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(2) - %param_3.853 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %bitcast_multiply_fusion.15 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1185, %param_3.853), kind=kLoop, calls=%fused_computation.39.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - %param_0.1277 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1403 = s32[]{:T(128)S(6)} parameter(1) - %fusion.350 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1277, %param_1.1403), kind=kLoop, calls=%fused_computation.14.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %add.848 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%exp.69, %jit_silu_.44), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/add" stack_frame_id=0} + %div.848 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} divide(%jit_silu_.44, %add.848), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/div" stack_frame_id=0} + %mul.2095 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1406, %div.848), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/mul" stack_frame_id=0} + %param_0.1283 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(0) + %bitcast.570 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} bitcast(%param_0.1283), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} + ROOT %mul.2094 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.2095, %bitcast.570), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} +} + +%fused_computation.21.clone.clone (param_0.1284: bf16[4,4096,14336], param_1.1407: s32[], param_2.1179: bf16[14336,4,128], param_3.841: bf16[4,128,14336]) -> bf16[4,128,4096] { + %param_2.1179 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(2) + %param_3.841 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %bitcast_multiply_fusion.15 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1179, %param_3.841), kind=kLoop, calls=%fused_computation.39.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %param_0.1284 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1407 = s32[]{:T(128)S(6)} parameter(1) + %fusion.350 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1284, %param_1.1407), kind=kLoop, calls=%fused_computation.14.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} ROOT %convolution.117 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} convolution(%bitcast_multiply_fusion.15, %fusion.350), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} } -%fused_computation.14.clone.clone.clone (param_0.1269: bf16[4,4096,14336], param_1.1395: s32[]) -> bf16[4096,14336,1] { - %param_0.1269 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1395 = s32[]{:T(128)S(6)} parameter(1) - %constant.1142 = s32[]{:T(128)} constant(0) - %dynamic_slice.327 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1269, %param_1.1395, %constant.1142, %constant.1142), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.14.clone.clone.clone (param_0.1276: bf16[4,4096,14336], param_1.1399: s32[]) -> bf16[4096,14336,1] { + %param_0.1276 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1399 = s32[]{:T(128)S(6)} parameter(1) + %constant.1134 = s32[]{:T(128)} constant(0) + %dynamic_slice.327 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1276, %param_1.1399, %constant.1134, %constant.1134), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.567 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.327), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.96.clone.1.clone.clone (param_0.1270: f32[4,128], param_1.1396: bf16[4,128,4096], param_2.1181: bf16[4096]) -> bf16[4,128,4096] { - %param_2.1181 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.430 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1181), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1396 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1104 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1396), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1270 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1715 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1270), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1714 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1104, %mul.1715), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1103 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1714), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %dot_general.429 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.430, %convert_element_type.1103), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.20.clone.clone (param_0.1271: bf16[4,4096,14336], param_1.1397: s32[], param_2.1182: f32[4,128], param_3.851: bf16[4,128,4096], param_4.532: bf16[4096]) -> bf16[4,128,14336] { - %param_2.1182 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.851 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %param_4.532 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.347 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1182, %param_3.851, %param_4.532), kind=kLoop, calls=%fused_computation.96.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1271 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1397 = s32[]{:T(128)S(6)} parameter(1) - %fusion.346 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1271, %param_1.1397), kind=kLoop, calls=%fused_computation.14.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.96.clone.1.clone.clone (param_0.1277: f32[4,128], param_1.1400: bf16[4,128,4096], param_2.1175: bf16[4096]) -> bf16[4,128,4096] { + %param_1.1400 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1120 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1400), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1277 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2088 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1277), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2087 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1120, %mul.2088), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1119 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2087), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1175 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2089 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1175), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2086 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1119, %mul.2089), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.20.clone.clone (param_0.1278: bf16[4,4096,14336], param_1.1401: s32[], param_2.1176: f32[4,128], param_3.839: bf16[4,128,4096], param_4.526: bf16[4096]) -> bf16[4,128,14336] { + %param_2.1176 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.839 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.526 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.347 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1176, %param_3.839, %param_4.526), kind=kLoop, calls=%fused_computation.96.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1278 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1401 = s32[]{:T(128)S(6)} parameter(1) + %fusion.346 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1278, %param_1.1401), kind=kLoop, calls=%fused_computation.14.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} ROOT %convolution.115 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.347, %fusion.346), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} } -%region_14.17 (reduce_sum.126: f32[], reduce_sum.127: f32[]) -> f32[] { - %reduce_sum.126 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} - %reduce_sum.127 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} - ROOT %reduce_sum.128 = f32[]{:T(128)} add(%reduce_sum.126, %reduce_sum.127), metadata={op_name="checkpoint/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_14.17 (reduce_sum.166: f32[], reduce_sum.167: f32[]) -> f32[] { + %reduce_sum.166 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} + %reduce_sum.167 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} + ROOT %reduce_sum.168 = f32[]{:T(128)} add(%reduce_sum.166, %reduce_sum.167), metadata={op_name="checkpoint/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.11.clone.clone.clone.clone.clone.clone.clone (param_0.1278: bf16[4,4096,14336], param_1.1404: s32[]) -> bf16[4096,14336,1] { - %param_0.1278 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1404 = s32[]{:T(128)S(6)} parameter(1) - %constant.1146 = s32[]{:T(128)} constant(0) - %dynamic_slice.330 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1278, %param_1.1404, %constant.1146, %constant.1146), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.11.clone.clone.clone.clone.clone.clone.clone (param_0.1285: bf16[4,4096,14336], param_1.1408: s32[]) -> bf16[4096,14336,1] { + %param_0.1285 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1408 = s32[]{:T(128)S(6)} parameter(1) + %constant.1138 = s32[]{:T(128)} constant(0) + %dynamic_slice.330 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1285, %param_1.1408, %constant.1138, %constant.1138), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.571 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.330), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.38.clone.1.clone.clone.clone.clone (param_0.1279: bf16[4,128,14336], param_1.1405: bf16[4,128,14336], param_2.1186: bf16[14336,4,128]) -> bf16[4,128,14336] { - %param_2.1186 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(2) - %bitcast.572 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} bitcast(%param_2.1186), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} - %param_1.1405 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %mul.1724 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%bitcast.572, %param_1.1405), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - %constant.1147 = bf16[]{:T(256)} constant(1) - %jit_silu_.45 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} broadcast(%constant.1147), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)" stack_frame_id=0} - %param_0.1279 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %neg.131 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} negate(%param_0.1279), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/neg" stack_frame_id=0} +%fused_computation.38.clone.1.clone.clone.clone.clone (param_0.1286: bf16[4,128,14336], param_1.1409: bf16[4,128,14336], param_2.1180: bf16[14336,4,128]) -> bf16[4,128,14336] { + %param_2.1180 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(2) + %bitcast.572 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} bitcast(%param_2.1180), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} + %param_1.1409 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %mul.2100 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%bitcast.572, %param_1.1409), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %constant.1139 = bf16[]{:T(256)} constant(1) + %jit_silu_.45 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} broadcast(%constant.1139), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)" stack_frame_id=0} + %param_0.1286 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %neg.131 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} negate(%param_0.1286), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/neg" stack_frame_id=0} %exp.70 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} exponential(%neg.131), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/exp" stack_frame_id=0} - %add.863 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%exp.70, %jit_silu_.45), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/add" stack_frame_id=0} - %div.849 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} divide(%jit_silu_.45, %add.863), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/div" stack_frame_id=0} - %mul.1723 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.1724, %div.849), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} - %mul.1722 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%param_0.1279, %mul.1724), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} + %add.849 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%exp.70, %jit_silu_.45), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/add" stack_frame_id=0} + %div.849 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} divide(%jit_silu_.45, %add.849), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/div" stack_frame_id=0} + %mul.2099 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.2100, %div.849), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} + %mul.2098 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%param_0.1286, %mul.2100), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} %sub.98 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} subtract(%jit_silu_.45, %div.849), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/sub" stack_frame_id=0} - %mul.1721 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%div.849, %sub.98), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/mul" stack_frame_id=0} - %mul.1720 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.1722, %mul.1721), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} - ROOT %add_any.145 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%mul.1723, %mul.1720), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/add_any" stack_frame_id=0} + %mul.2097 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%div.849, %sub.98), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/mul" stack_frame_id=0} + %mul.2096 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.2098, %mul.2097), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} + ROOT %add_any.145 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%mul.2099, %mul.2096), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/add_any" stack_frame_id=0} } -%fused_computation.63.clone.clone (param_0.1280: bf16[4,128,4096], param_1.1406: bf16[4096], param_2.1187: bf16[4,128,4096], param_3.854: bf16[4,4096,14336], param_4.534: s32[], param_5.435: bf16[4,128,14336], param_6.304: bf16[4,128,14336], param_7.200: bf16[14336,4,128]) -> (f32[4,128], bf16[4,128,4096]) { - %param_0.1280 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert_element_type.1108 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1280), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_2.1187 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) +%fused_computation.63.clone.clone (param_0.1287: bf16[4,128,4096], param_1.1410: bf16[4096], param_2.1181: bf16[4,128,4096], param_3.842: bf16[4,4096,14336], param_4.528: s32[], param_5.435: bf16[4,128,14336], param_6.305: bf16[4,128,14336], param_7.200: bf16[14336,4,128]) -> (f32[4,128], bf16[4,128,4096]) { + %param_0.1287 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1124 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1287), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1181 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) %param_5.435 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(5) - %param_6.304 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(6) + %param_6.305 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(6) %param_7.200 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(7) - %fusion.134.clone.3 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} fusion(%param_5.435, %param_6.304, %param_7.200), kind=kLoop, calls=%fused_computation.38.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/add_any" stack_frame_id=0} - %param_3.854 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(3) - %param_4.534 = s32[]{:T(128)S(6)} parameter(4) - %fusion.79.clone.3 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_3.854, %param_4.534), kind=kLoop, calls=%fused_computation.11.clone.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %fusion.134.clone.3 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} fusion(%param_5.435, %param_6.305, %param_7.200), kind=kLoop, calls=%fused_computation.38.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/add_any" stack_frame_id=0} + %param_3.842 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(3) + %param_4.528 = s32[]{:T(128)S(6)} parameter(4) + %fusion.79.clone.3 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_3.842, %param_4.528), kind=kLoop, calls=%fused_computation.11.clone.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} %convolution.60.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convolution(%fusion.134.clone.3, %fusion.79.clone.3), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} - %add_any.132.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%param_2.1187, %convolution.60.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} - %param_1.1406 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(1) - %dot_general.434 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_1.1406), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %dot_general.433 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%add_any.132.clone.3, %dot_general.434), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/...k,k->...k/dot_general" stack_frame_id=0} - %convert_element_type.1107 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%dot_general.433), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} - %mul.1725 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1108, %convert_element_type.1107), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - %constant.1148 = f32[]{:T(128)} constant(0) - %reduce.178 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1725, %constant.1148), dimensions={2}, to_apply=%region_14.17, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum" stack_frame_id=0} - ROOT %tuple.189 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.178, %add_any.132.clone.3) -} - -%fused_computation.140.clone.clone (param_0.1281: f32[4,128], param_1.1407: f32[4,128]) -> f32[4,128] { - %param_0.1281 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %param_1.1407 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %constant.1152 = f32[]{:T(128)} constant(0.000244140625) - %closed_call.89 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1152), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.851 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1407, %closed_call.89), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1151 = f32[]{:T(128)} constant(1e-05) - %closed_call.88 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1151), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %add.864 = f32[4,128]{1,0:T(4,128)} add(%div.851, %closed_call.88), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - %rsqrt.99 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.864), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} - %div.850 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.99, %add.864), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1150 = f32[]{:T(128)} constant(-0.5) - %closed_call.87 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1150), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %mul.1728 = f32[4,128]{1,0:T(4,128)} multiply(%div.850, %closed_call.87), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1727 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1281, %mul.1728), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - %constant.1149 = f32[]{:T(128)} constant(0.00048828125) - %mul.1729 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1149), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - ROOT %mul.1726 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.1727, %mul.1729), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} -} - -%region_20.24 (dot_general.187: bf16[], dot_general.188: bf16[]) -> bf16[] { - %dot_general.187 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/...k,k->...k/dot_general"} - %dot_general.188 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/...k,k->...k/dot_general"} - ROOT %add.173 = bf16[]{:T(256)} add(%dot_general.187, %dot_general.188), metadata={op_name="add.39"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.94.clone.clone (param_0.1282: bf16[4,128,4096], param_1.1408: f32[4,128], param_2.1188: bf16[4,128,4096], param_3.855: bf16[4,128,4096], param_4.535: f32[4,128], param_5.436: bf16[4096]) -> (bf16[4096], bf16[4,128,4096]) { - %param_0.1282 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %param_2.1188 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %convert_element_type.1110 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_2.1188), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_1.1408 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %mul.1731 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1408), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1730 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1110, %mul.1731), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1109 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1730), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %multiply.271 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_0.1282, %convert_element_type.1109), metadata={op_name="multiply.204"} - %constant.1153 = bf16[]{:T(256)} constant(0) - %reduce.179 = bf16[4096]{0:T(1024)(128)(2,1)} reduce(%multiply.271, %constant.1153), dimensions={0,1}, to_apply=%region_20.24, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_3.855 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %add_any.132.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%param_2.1181, %convolution.60.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} + %param_1.1410 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(1) + %mul.2103 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_1.1410), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2102 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%add_any.132.clone.3, %mul.2103), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %convert_element_type.1123 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%mul.2102), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} + %mul.2101 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1124, %convert_element_type.1123), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %constant.1140 = f32[]{:T(128)} constant(0) + %reduce.139 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.2101, %constant.1140), dimensions={2}, to_apply=%region_14.17, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.189 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.139, %add_any.132.clone.3) +} + +%fused_computation.140.clone.clone (param_0.1288: f32[4,128], param_1.1411: f32[4,128]) -> f32[4,128] { + %param_0.1288 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %param_1.1411 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %constant.1144 = f32[]{:T(128)} constant(0.000244140625) + %closed_call.89 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1144), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.851 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1411, %closed_call.89), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1143 = f32[]{:T(128)} constant(1e-05) + %closed_call.88 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1143), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.850 = f32[4,128]{1,0:T(4,128)} add(%div.851, %closed_call.88), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %rsqrt.99 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.850), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} + %div.850 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.99, %add.850), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1142 = f32[]{:T(128)} constant(-0.5) + %closed_call.87 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1142), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %mul.2106 = f32[4,128]{1,0:T(4,128)} multiply(%div.850, %closed_call.87), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2105 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1288, %mul.2106), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %constant.1141 = f32[]{:T(128)} constant(0.00048828125) + %mul.2107 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1141), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + ROOT %mul.2104 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.2105, %mul.2107), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} +} + +%region_20.24 (reduce_sum.175: bf16[], reduce_sum.179: bf16[]) -> bf16[] { + %reduce_sum.175 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} + %reduce_sum.179 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} + ROOT %reduce_sum.180 = bf16[]{:T(256)} add(%reduce_sum.175, %reduce_sum.179), metadata={op_name="checkpoint/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.93.clone.clone (param_0.1289: bf16[4,128,4096], param_1.1412: f32[4,128], param_2.1182: bf16[4,128,4096], param_3.843: bf16[4,128,4096], param_4.529: f32[4,128], param_5.436: bf16[4096]) -> (bf16[4096], bf16[4,128,4096]) { + %param_2.1182 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %convert_element_type.1126 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_2.1182), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1412 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2110 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1412), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2109 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1126, %mul.2110), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1125 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2109), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1289 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %mul.2108 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1125, %param_0.1289), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %constant.1145 = bf16[]{:T(256)} constant(0) + %reduce.140 = bf16[4096]{0:T(1024)(128)(2,1)} reduce(%mul.2108, %constant.1145), dimensions={0,1}, to_apply=%region_20.24, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum" stack_frame_id=0} + %param_3.843 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) %param_5.436 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(5) - %dot_general.286.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_5.436), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %dot_general.263.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_0.1282, %dot_general.286.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/...k,k->...k/dot_general" stack_frame_id=0} - %convert_element_type.753.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%dot_general.263.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} - %mul.1142.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.753.clone.3, %mul.1731), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - %param_4.535 = f32[4,128]{1,0:T(4,128)S(1)} parameter(4) - %mul.1151.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_4.535), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - %mul.1141.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1110, %mul.1151.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - %add_any.126.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} add(%mul.1142.clone.3, %mul.1141.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} - %convert_element_type.751.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%add_any.126.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} - %add_any.124.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%param_3.855, %convert_element_type.751.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} - ROOT %tuple.190 = (bf16[4096]{0:T(1024)(128)(2,1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.179, %add_any.124.clone.3) -} - -%region_15.18 (dot_general.184: f32[], dot_general.185: f32[]) -> f32[] { - %dot_general.184 = f32[]{:T(128)} parameter(0), metadata={op_name="vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general"} - %dot_general.185 = f32[]{:T(128)} parameter(1), metadata={op_name="vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general"} - ROOT %add.169 = f32[]{:T(128)} add(%dot_general.184, %dot_general.185), metadata={op_name="add.31"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.25.clone.clone.clone.clone.clone.clone.clone (param_0.1283: bf16[4,32,128,4096], param_1.1409: s32[]) -> bf16[32,128,4096,1] { - %param_0.1283 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1409 = s32[]{:T(128)S(6)} parameter(1) - %constant.1154 = s32[]{:T(128)} constant(0) - %dynamic_slice.331 = bf16[1,32,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1283, %param_1.1409, %constant.1154, %constant.1154, %constant.1154), dynamic_slice_sizes={1,32,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %mul.1399.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_5.436), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.1353.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_0.1289, %mul.1399.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %convert_element_type.769.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%mul.1353.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} + %mul.1333.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.769.clone.3, %mul.2110), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %param_4.529 = f32[4,128]{1,0:T(4,128)S(1)} parameter(4) + %mul.1344.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_4.529), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %mul.1332.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1126, %mul.1344.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %add_any.126.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} add(%mul.1333.clone.3, %mul.1332.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} + %convert_element_type.767.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%add_any.126.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} + %add_any.124.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%param_3.843, %convert_element_type.767.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} + ROOT %tuple.190 = (bf16[4096]{0:T(1024)(128)(2,1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.140, %add_any.124.clone.3) +} + +%region_15.18 (dot_general.157: f32[], dot_general.158: f32[]) -> f32[] { + %dot_general.157 = f32[]{:T(128)} parameter(0), metadata={op_name="vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general"} + %dot_general.158 = f32[]{:T(128)} parameter(1), metadata={op_name="vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general"} + ROOT %add.157 = f32[]{:T(128)} add(%dot_general.157, %dot_general.158), metadata={op_name="add.31"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.25.clone.clone.clone.clone.clone.clone.clone (param_0.1290: bf16[4,32,128,4096], param_1.1413: s32[]) -> bf16[32,128,4096,1] { + %param_0.1290 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1413 = s32[]{:T(128)S(6)} parameter(1) + %constant.1146 = s32[]{:T(128)} constant(0) + %dynamic_slice.331 = bf16[1,32,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1290, %param_1.1413, %constant.1146, %constant.1146, %constant.1146), dynamic_slice_sizes={1,32,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.573 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dynamic_slice.331), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.76.clone.clone.clone.clone.clone.clone (param_0.1284: bf16[4,128,4096]) -> bf16[4,128,4096,1] { - %param_0.1284 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - ROOT %bitcast.574 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%param_0.1284), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} +%fused_computation.76.clone.clone.clone.clone.clone.clone (param_0.1291: bf16[4,128,4096]) -> bf16[4,128,4096,1] { + %param_0.1291 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + ROOT %bitcast.574 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%param_0.1291), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} } -%fused_computation.66.clone.clone (param_0.1285: bf16[4,32,128,128], param_1.1410: bf16[4,32,128,4096], param_2.1189: s32[], param_3.856: bf16[4,128,4096]) -> (f32[4,32,128], bf16[4,32,128,128]) { - %param_0.1285 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert.124 = f32[4,32,128,128]{3,2,1,0:T(8,128)} convert(%param_0.1285), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/convert" stack_frame_id=0} - %param_3.856 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %fusion.95.clone.3 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_3.856), kind=kLoop, calls=%fused_computation.76.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} - %param_1.1410 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1189 = s32[]{:T(128)S(6)} parameter(2) - %fusion.94.clone.3 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_1.1410, %param_2.1189), kind=kLoop, calls=%fused_computation.25.clone.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.66.clone.clone (param_0.1292: bf16[4,32,128,128], param_1.1414: bf16[4,32,128,4096], param_2.1183: s32[], param_3.844: bf16[4,128,4096]) -> (f32[4,32,128], bf16[4,32,128,128]) { + %param_0.1292 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert.87 = f32[4,32,128,128]{3,2,1,0:T(8,128)} convert(%param_0.1292), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/convert" stack_frame_id=0} + %param_3.844 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %fusion.95.clone.3 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_3.844), kind=kLoop, calls=%fused_computation.76.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} + %param_1.1414 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(1) + %param_2.1183 = s32[]{:T(128)S(6)} parameter(2) + %fusion.94.clone.3 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_1.1414, %param_2.1183), kind=kLoop, calls=%fused_computation.25.clone.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} %convolution.64.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} convolution(%fusion.95.clone.3, %fusion.94.clone.3), window={size=1x32 pad=0_0x31_31 rhs_reversal=0x1}, dim_labels=0bf1_1oi0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} - %constant.619.clone.3 = bf16[]{:T(256)} constant(0.25) - %div.442.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%constant.619.clone.3), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/div" stack_frame_id=0} + %constant.611.clone.3 = bf16[]{:T(256)} constant(0.25) + %div.442.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%constant.611.clone.3), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/div" stack_frame_id=0} %div.441.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%convolution.64.clone.3, %div.442.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/div" stack_frame_id=0} %bitcast.209.clone.3 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%div.441.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/div" stack_frame_id=0} - %convert.123 = f32[4,32,128,128]{3,2,1,0:T(8,128)} convert(%bitcast.209.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/convert.1" stack_frame_id=0} - %multiply.272 = f32[4,32,128,128]{3,2,1,0:T(8,128)} multiply(%convert.124, %convert.123), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/multiply" stack_frame_id=0} - %constant.1155 = f32[]{:T(128)} constant(0) - %dot_general.435 = f32[4,32,128]{2,1,0:T(8,128)S(1)} reduce(%multiply.272, %constant.1155), dimensions={3}, to_apply=%region_15.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general" stack_frame_id=0} - ROOT %tuple.191 = (f32[4,32,128]{2,1,0:T(8,128)S(1)}, bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)}) tuple(%dot_general.435, %bitcast.209.clone.3) + %convert.86 = f32[4,32,128,128]{3,2,1,0:T(8,128)} convert(%bitcast.209.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/convert.1" stack_frame_id=0} + %multiply.196 = f32[4,32,128,128]{3,2,1,0:T(8,128)} multiply(%convert.87, %convert.86), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/multiply" stack_frame_id=0} + %constant.1147 = f32[]{:T(128)} constant(0) + %dot_general.189 = f32[4,32,128]{2,1,0:T(8,128)S(1)} reduce(%multiply.196, %constant.1147), dimensions={3}, to_apply=%region_15.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general" stack_frame_id=0} + ROOT %tuple.191 = (f32[4,32,128]{2,1,0:T(8,128)S(1)}, bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)}) tuple(%dot_general.189, %bitcast.209.clone.3) } diff --git a/tests/utils/reference_hlo_qwen3_1.7b.txt b/tests/utils/reference_hlo_qwen3_1.7b.txt index f1ede66966..6bdc2b6141 100644 --- a/tests/utils/reference_hlo_qwen3_1.7b.txt +++ b/tests/utils/reference_hlo_qwen3_1.7b.txt @@ -14,1446 +14,1446 @@ StackFrames %param_1.7 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.1 = s32[1024]{0:T(1024)} custom-call(%param_1.7), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} %slice.6 = s32[512]{0:T(512)} slice(%custom-call.1), slice={[0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %reshape.444 = s32[4,128]{1,0:T(4,128)} reshape(%slice.6), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %transpose.461 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.444), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %gather.4 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} gather(%param_0.2, %transpose.461), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,2048}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %transpose.460 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} transpose(%gather.4), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - ROOT %reshape.443 = bf16[512,2048]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.460), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %reshape.554 = s32[4,128]{1,0:T(4,128)} reshape(%slice.6), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.261 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.554), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %gather.4 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} gather(%param_0.2, %transpose.261), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,2048}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %transpose.260 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} transpose(%gather.4), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %reshape.553 = bf16[512,2048]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.260), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} } %region_42.47.clone (scatter-add.6: bf16[], scatter-add.7: bf16[]) -> bf16[] { %scatter-add.7 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} %scatter-add.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} - ROOT %add.584 = bf16[]{:T(256)} add(%scatter-add.6, %scatter-add.7), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.560 = bf16[]{:T(256)} add(%scatter-add.6, %scatter-add.7), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %fused_computation.1 (param_0.3: bf16[151936,2048], param_1.5: s32[512], param_2.4: bf16[512,2048]) -> bf16[151936,2048] { %param_0.3 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) %param_1.5 = s32[512]{0:T(512)S(1)} parameter(1) - %reshape.451 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.5), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %transpose.466 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.451), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %param_2.4 = bf16[512,2048]{1,0:T(8,128)(2,1)S(1)} parameter(2) - %reshape.452 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} reshape(%param_2.4), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} - %transpose.467 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} transpose(%reshape.452), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} - ROOT %scatter.2 = bf16[151936,2048]{1,0:T(8,128)(2,1)} scatter(%param_0.3, %transpose.466, %transpose.467), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_42.47.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} + %reshape.561 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.5), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.266 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.561), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %param_2.4 = bf16[512,2048]{1,0:T(8,128)(2,1)} parameter(2) + %reshape.562 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} reshape(%param_2.4), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} + %transpose.267 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} transpose(%reshape.562), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} + ROOT %scatter.2 = bf16[151936,2048]{1,0:T(8,128)(2,1)} scatter(%param_0.3, %transpose.266, %transpose.267), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_42.47.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} } -%region_71.76 (reduce_sum.464: f32[], reduce_sum.465: f32[]) -> f32[] { - %reduce_sum.465 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.464 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.466 = f32[]{:T(128)} add(%reduce_sum.464, %reduce_sum.465), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_71.76 (reduce_sum.569: f32[], reduce_sum.570: f32[]) -> f32[] { + %reduce_sum.570 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.569 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.571 = f32[]{:T(128)} add(%reduce_sum.569, %reduce_sum.570), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_56.61 (reduce_sum.386: f32[], reduce_sum.387: f32[]) -> f32[] { - %reduce_sum.387 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.386 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.388 = f32[]{:T(128)} add(%reduce_sum.386, %reduce_sum.387), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_56.61 (reduce_sum.488: f32[], reduce_sum.492: f32[]) -> f32[] { + %reduce_sum.492 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.488 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.493 = f32[]{:T(128)} add(%reduce_sum.488, %reduce_sum.492), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.277 (param_0.1368: f32[151936,2048], param_1.1556: f32[], param_2.1314: f32[], param_3.918: f32[], param_4.556: f32[151936,2048], param_5.468: f32[], param_6.358: bf16[151936,2048], param_7.201: bf16[151936,2048,1], param_8.118: pred[], param_9.97: f32[151936,2048]) -> (f32[], f32[151936,2048], f32[151936,2048], f32[151936,2048], f32[]) { - %param_0.1368 = f32[151936,2048]{1,0:T(8,128)} parameter(0) - %param_3.918 = f32[]{:T(128)S(6)} parameter(3) - %mul.1926.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_3.918), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.288 (param_0.1409: f32[151936,2048], param_1.1583: f32[], param_2.1325: f32[], param_3.908: f32[], param_4.547: f32[151936,2048], param_5.481: f32[], param_6.356: bf16[151936,2048], param_7.200: bf16[151936,2048,1], param_8.118: pred[], param_9.97: f32[151936,2048]) -> (f32[], f32[151936,2048], f32[151936,2048], f32[151936,2048], f32[]) { + %param_0.1409 = f32[151936,2048]{1,0:T(8,128)} parameter(0) + %param_3.908 = f32[]{:T(128)S(6)} parameter(3) + %mul.2449.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_3.908), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.118 = pred[]{:T(512)S(6)} parameter(8) %select_n.268.clone.1 = pred[151936,2048]{1,0:T(8,128)(4,1)} broadcast(%param_8.118), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_7.201 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} parameter(7) - %bitcast.464.clone.1 = bf16[151936,2048]{1,0:T(8,128)(2,1)} bitcast(%param_7.201), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} - %convert_element_type.1409.clone.1 = f32[151936,2048]{1,0:T(8,128)} convert(%bitcast.464.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_6.358 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(6) - %convert_element_type.1408.clone.1 = f32[151936,2048]{1,0:T(8,128)} convert(%param_6.358), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %add_any.197.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%convert_element_type.1409.clone.1, %convert_element_type.1408.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add_any" stack_frame_id=0} - %param_5.468 = f32[]{:T(128)} parameter(5) - %div.860.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_5.468), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.859.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add_any.197.clone.1, %div.860.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.267.clone.1 = f32[151936,2048]{1,0:T(8,128)} select(%select_n.268.clone.1, %add_any.197.clone.1, %div.859.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1092.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.844.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1092.clone.1), dimensions={}, metadata={op_name="broadcast.74"} - %mul.1932.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%select_n.267.clone.1, %broadcast.844.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.200 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} parameter(7) + %bitcast.445.clone.1 = bf16[151936,2048]{1,0:T(8,128)(2,1)} bitcast(%param_7.200), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %convert_element_type.1433.clone.1 = f32[151936,2048]{1,0:T(8,128)} convert(%bitcast.445.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_6.356 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(6) + %convert_element_type.1432.clone.1 = f32[151936,2048]{1,0:T(8,128)} convert(%param_6.356), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %add_any.188.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%convert_element_type.1433.clone.1, %convert_element_type.1432.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add_any" stack_frame_id=0} + %param_5.481 = f32[]{:T(128)} parameter(5) + %div.860.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_5.481), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.859.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add_any.188.clone.1, %div.860.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.267.clone.1 = f32[151936,2048]{1,0:T(8,128)} select(%select_n.268.clone.1, %add_any.188.clone.1, %div.859.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1080.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.754.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1080.clone.1), dimensions={}, metadata={op_name="broadcast.74"} + %mul.2455.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%select_n.267.clone.1, %broadcast.754.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_9.97 = f32[151936,2048]{1,0:T(8,128)} parameter(9) - %constant.1096.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1933.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1096.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1931.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_9.97, %mul.1933.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.941.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%mul.1932.clone.1, %mul.1931.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1314 = f32[]{:T(128)S(6)} parameter(2) - %div.856.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_2.1314), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1084.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2456.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1084.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2454.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_9.97, %mul.2456.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.917.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%mul.2455.clone.1, %mul.2454.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1325 = f32[]{:T(128)S(6)} parameter(2) + %div.856.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_2.1325), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.65.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%select_n.267.clone.1, %select_n.267.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1095.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1930.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1095.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1928.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%integer_pow.65.clone.1, %mul.1930.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.556 = f32[151936,2048]{1,0:T(8,128)} parameter(4) - %constant.1094.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1929.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1094.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1927.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_4.556, %mul.1929.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.940.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%mul.1928.clone.1, %mul.1927.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1556 = f32[]{:T(128)S(6)} parameter(1) - %div.855.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_1.1556), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.854.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add.940.clone.1, %div.855.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1083.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2453.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1083.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2451.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%integer_pow.65.clone.1, %mul.2453.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.547 = f32[151936,2048]{1,0:T(8,128)} parameter(4) + %constant.1082.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2452.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1082.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2450.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_4.547, %mul.2452.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.916.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%mul.2451.clone.1, %mul.2450.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1583 = f32[]{:T(128)S(6)} parameter(1) + %div.855.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_1.1583), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.854.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add.916.clone.1, %div.855.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.62.clone.1 = f32[151936,2048]{1,0:T(8,128)} sqrt(%div.854.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1093.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.939.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1093.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.938.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%sqrt.62.clone.1, %add.939.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.426.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%div.856.clone.1, %add.938.clone.1), metadata={op_name="multiply.61"} - %div.853.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add.941.clone.1, %multiply.426.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1925.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_0.1368, %broadcast.844.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.937.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%div.853.clone.1, %mul.1925.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1924.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%mul.1926.clone.1, %add.937.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.936.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%param_0.1368, %mul.1924.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.214 = f32[151936,2048]{1,0:T(8,128)} multiply(%add.936.clone.1, %add.936.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1200 = f32[]{:T(128)} constant(0) - %reduce.176 = f32[]{:T(128)} reduce(%square.214, %constant.1200), dimensions={0,1}, to_apply=%region_71.76, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.178.clone.1 = f32[]{:T(128)} reduce(%integer_pow.65.clone.1, %constant.1200), dimensions={0,1}, to_apply=%region_56.61, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.144 = (f32[]{:T(128)}, f32[151936,2048]{1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.176, %add.936.clone.1, %add.940.clone.1, %add.941.clone.1, %reduce.178.clone.1) -} - -%region_43.48 (reduce_sum.317: f32[], reduce_sum.318: f32[]) -> f32[] { - %reduce_sum.318 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.317 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.319 = f32[]{:T(128)} add(%reduce_sum.317, %reduce_sum.318), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.367.clone.clone (param_0.1355: f32[4,128], param_1.1549: bf16[4,128,2048], param_2.1290: bf16[2048]) -> bf16[4,128,2048] { - %param_2.1290 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.480 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1290), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1549 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1451 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1549), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_0.1355 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.2083 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1355), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.2082 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1451, %mul.2083), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.1450 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2082), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - ROOT %dot_general.479 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.480, %convert_element_type.1450), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.289.clone.clone.clone (param_0.1356: bf16[4,128,151936], param_1.1550: s32[4,128], param_2.1291: f32[4,128], param_3.911: f32[4,128], param_4.546: bf16[4,128], param_5.446: f32[4,128]) -> bf16[4,128,151936] { - %param_5.446 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %mul.2087 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_5.446), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_3.911 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %mul.2086 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.911), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_0.1356 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.1454 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1356), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_4.546 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) - %sub.94 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_4.546), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.93 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1454, %sub.94), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %exp.62 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.93), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %mul.2085 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2086, %exp.62), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_2.1291 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %div.966 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1291), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %div.965 = f32[4,128,151936]{2,1,0:T(8,128)} divide(%mul.2085, %div.966), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %param_1.1550 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %eq.49 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1550), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %constant.1081.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.915.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1081.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.914.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%sqrt.62.clone.1, %add.915.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.287.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%div.856.clone.1, %add.914.clone.1), metadata={op_name="multiply.46"} + %div.853.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add.917.clone.1, %multiply.287.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2448.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_0.1409, %broadcast.754.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.913.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%div.853.clone.1, %mul.2448.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2447.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%mul.2449.clone.1, %add.913.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.912.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%param_0.1409, %mul.2447.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.214 = f32[151936,2048]{1,0:T(8,128)} multiply(%add.912.clone.1, %add.912.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1188 = f32[]{:T(128)} constant(0) + %reduce.106 = f32[]{:T(128)} reduce(%square.214, %constant.1188), dimensions={0,1}, to_apply=%region_71.76, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.108.clone.1 = f32[]{:T(128)} reduce(%integer_pow.65.clone.1, %constant.1188), dimensions={0,1}, to_apply=%region_56.61, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.145 = (f32[]{:T(128)}, f32[151936,2048]{1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.106, %add.912.clone.1, %add.916.clone.1, %add.917.clone.1, %reduce.108.clone.1) +} + +%region_43.48 (reduce_sum.422: f32[], reduce_sum.423: f32[]) -> f32[] { + %reduce_sum.423 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.422 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.424 = f32[]{:T(128)} add(%reduce_sum.422, %reduce_sum.423), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.378.clone.clone (param_0.1396: f32[4,128], param_1.1576: bf16[4,128,2048], param_2.1301: bf16[2048]) -> bf16[4,128,2048] { + %param_1.1576 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1475 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1576), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1396 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2627 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1396), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2626 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1475, %mul.2627), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1474 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2626), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.1301 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2628 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1301), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.2625 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1474, %mul.2628), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%fused_computation.300.clone.clone.clone (param_0.1397: bf16[4,128,151936], param_1.1577: s32[4,128], param_2.1302: f32[4,128], param_3.901: f32[4,128], param_4.537: bf16[4,128], param_5.459: f32[4,128]) -> bf16[4,128,151936] { + %param_5.459 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.2632 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_5.459), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.901 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.2631 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.901), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.1397 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1478 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1397), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_4.537 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.92 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_4.537), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.91 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1478, %sub.92), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %exp.62 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.91), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} + %mul.2630 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2631, %exp.62), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.1302 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.966 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1302), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.965 = f32[4,128,151936]{2,1,0:T(8,128)} divide(%mul.2630, %div.966), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.1577 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.49 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1577), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.48 = s32[4,128,151936]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.47 = pred[4,128,151936]{2,1,0:T(8,128)(4,1)} compare(%eq.49, %eq.48), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %convert_element_type.1453 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%eq.47), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} - %sub.92 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%div.965, %convert_element_type.1453), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} - %mul.2084 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2087, %sub.92), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - ROOT %convert_element_type.1452 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convert(%mul.2084), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %convert_element_type.1477 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%eq.47), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.90 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%div.965, %convert_element_type.1477), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.2629 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2632, %sub.90), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.1476 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convert(%mul.2629), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} } -%fused_computation.281 (param_0.1381: bf16[151936,2048], param_1.1569: f32[4,128], param_2.1327: bf16[4,128,2048], param_3.931: bf16[2048], param_4.569: bf16[4,128,151936], param_5.481: s32[4,128], param_6.371: f32[4,128], param_7.214: f32[4,128], param_8.131: bf16[4,128], param_9.98: f32[4,128]) -> (f32[], bf16[151936,2048,1]) { - %param_4.569 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(4) - %param_5.481 = s32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %param_6.371 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) - %param_7.214 = f32[4,128]{1,0:T(4,128)S(1)} parameter(7) +%fused_computation.292 (param_0.1422: bf16[151936,2048], param_1.1596: f32[4,128], param_2.1338: bf16[4,128,2048], param_3.921: bf16[2048], param_4.560: bf16[4,128,151936], param_5.494: s32[4,128], param_6.369: f32[4,128], param_7.213: f32[4,128], param_8.131: bf16[4,128], param_9.98: f32[4,128]) -> (f32[], bf16[151936,2048,1]) { + %param_4.560 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(4) + %param_5.494 = s32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %param_6.369 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_7.213 = f32[4,128]{1,0:T(4,128)S(1)} parameter(7) %param_8.131 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(8) %param_9.98 = f32[4,128]{1,0:T(4,128)S(1)} parameter(9) - %multiply_convert_fusion.1.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} fusion(%param_4.569, %param_5.481, %param_6.371, %param_7.214, %param_8.131, /*index=5*/%param_9.98), kind=kLoop, calls=%fused_computation.289.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_1.1569 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.1327 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %param_3.931 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %fusion.269.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1569, %param_2.1327, %param_3.931), kind=kLoop, calls=%fused_computation.367.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %convolution.86.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} convolution(%multiply_convert_fusion.1.clone.1, %fusion.269.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} - %bitcast.333 = bf16[151936,2048]{1,0:T(8,128)(2,1)} bitcast(%convolution.86.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} - %convert_element_type.1323 = f32[151936,2048]{1,0:T(8,128)} convert(%bitcast.333), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_0.1381 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.1322 = f32[151936,2048]{1,0:T(8,128)} convert(%param_0.1381), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %add_any.184 = f32[151936,2048]{1,0:T(8,128)} add(%convert_element_type.1323, %convert_element_type.1322), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add_any" stack_frame_id=0} - %square.215 = f32[151936,2048]{1,0:T(8,128)} multiply(%add_any.184, %add_any.184), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1213 = f32[]{:T(128)} constant(0) - %reduce.177 = f32[]{:T(128)} reduce(%square.215, %constant.1213), dimensions={0,1}, to_apply=%region_43.48, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.166 = (f32[]{:T(128)}, bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.177, %convolution.86.clone.1) + %multiply_convert_fusion.2.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} fusion(%param_4.560, %param_5.494, %param_6.369, %param_7.213, %param_8.131, /*index=5*/%param_9.98), kind=kLoop, calls=%fused_computation.300.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_1.1596 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1338 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.921 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.279.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1596, %param_2.1338, %param_3.921), kind=kLoop, calls=%fused_computation.378.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convolution.86.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} convolution(%multiply_convert_fusion.2.clone.1, %fusion.279.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %bitcast.314 = bf16[151936,2048]{1,0:T(8,128)(2,1)} bitcast(%convolution.86.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %convert_element_type.1347 = f32[151936,2048]{1,0:T(8,128)} convert(%bitcast.314), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_0.1422 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1346 = f32[151936,2048]{1,0:T(8,128)} convert(%param_0.1422), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %add_any.175 = f32[151936,2048]{1,0:T(8,128)} add(%convert_element_type.1347, %convert_element_type.1346), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add_any" stack_frame_id=0} + %square.215 = f32[151936,2048]{1,0:T(8,128)} multiply(%add_any.175, %add_any.175), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1201 = f32[]{:T(128)} constant(0) + %reduce.107 = f32[]{:T(128)} reduce(%square.215, %constant.1201), dimensions={0,1}, to_apply=%region_43.48, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.167 = (f32[]{:T(128)}, bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.107, %convolution.86.clone.1) } -%region_57.62 (reduce_sum.389: f32[], reduce_sum.393: f32[]) -> f32[] { - %reduce_sum.393 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.389 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.394 = f32[]{:T(128)} add(%reduce_sum.389, %reduce_sum.393), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_57.62 (reduce_sum.494: f32[], reduce_sum.495: f32[]) -> f32[] { + %reduce_sum.495 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.494 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.499 = f32[]{:T(128)} add(%reduce_sum.494, %reduce_sum.495), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.288 (param_0.1392: bf16[4,128,151936], param_1.1577: f32[4,128], param_2.1330: s32[4,128], param_3.933: bf16[4,128]) -> f32[4,128] { - %param_2.1330 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %eq.30 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1330), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} +%fused_computation.299 (param_0.1433: bf16[4,128,151936], param_1.1604: f32[4,128], param_2.1341: s32[4,128], param_3.923: bf16[4,128]) -> f32[4,128] { + %param_2.1341 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %eq.30 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1341), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.25 = s32[4,128,151936]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.24 = pred[4,128,151936]{2,1,0:T(8,128)(4,1)} compare(%eq.30, %eq.25), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %param_0.1392 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.1340 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1392), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_3.933 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) - %sub.73 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.933), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.64 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1340, %sub.73), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %param_1.1577 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %sub.71 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1577), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %param_0.1433 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1364 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1433), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_3.923 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) + %sub.73 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.923), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.64 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1364, %sub.73), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %param_1.1604 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %sub.71 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1604), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %sub.60 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%sub.64, %sub.71), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %constant.1225 = f32[]{:T(128)} constant(0) - %broadcast.769 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%constant.1225), dimensions={}, metadata={op_name="broadcast.109"} - %mul.1765 = f32[4,128,151936]{2,1,0:T(8,128)} select(%eq.24, %sub.60, %broadcast.769), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - ROOT %reduce.179 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1765, %constant.1225), dimensions={2}, to_apply=%region_57.62, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %constant.1213 = f32[]{:T(128)} constant(0) + %broadcast.681 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%constant.1213), dimensions={}, metadata={op_name="broadcast.99"} + %mul.2269 = f32[4,128,151936]{2,1,0:T(8,128)} select(%eq.24, %sub.60, %broadcast.681), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + ROOT %reduce.109 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.2269, %constant.1213), dimensions={2}, to_apply=%region_57.62, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} } -%region_9.12 (reduce_sum.186: f32[], reduce_sum.190: f32[]) -> f32[] { - %reduce_sum.190 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.186 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.191 = f32[]{:T(128)} add(%reduce_sum.186, %reduce_sum.190), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_9.12 (reduce_sum.237: f32[], reduce_sum.241: f32[]) -> f32[] { + %reduce_sum.241 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.237 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.262 = f32[]{:T(128)} add(%reduce_sum.237, %reduce_sum.241), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.293 (param_0.1393: bf16[4,128,151936], param_1.1578: bf16[4,128]) -> f32[4,128] { - %param_0.1393 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.1346 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1393), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_1.1578 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) - %sub.74 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1578), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.70 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1346, %sub.74), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} +%fused_computation.304 (param_0.1434: bf16[4,128,151936], param_1.1605: bf16[4,128]) -> f32[4,128] { + %param_0.1434 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1370 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1434), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_1.1605 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) + %sub.74 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1605), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.70 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1370, %sub.74), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %exp.54 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.70), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %constant.1226 = f32[]{:T(128)} constant(0) - ROOT %reduce.180 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.54, %constant.1226), dimensions={2}, to_apply=%region_9.12, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %constant.1214 = f32[]{:T(128)} constant(0) + ROOT %reduce.110 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.54, %constant.1214), dimensions={2}, to_apply=%region_9.12, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} } -%region_33.38 (reduce_sum.269: f32[], reduce_sum.270: f32[]) -> f32[] { - %reduce_sum.270 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.269 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.274 = f32[]{:T(128)} add(%reduce_sum.269, %reduce_sum.270), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_33.38 (reduce_sum.374: f32[], reduce_sum.375: f32[]) -> f32[] { + %reduce_sum.375 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.374 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.376 = f32[]{:T(128)} add(%reduce_sum.374, %reduce_sum.375), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.298 (param_0.1387: f32[4,6144,2048]) -> f32[] { - %param_0.1387 = f32[4,6144,2048]{2,0,1:T(4,128)} parameter(0) - %bitcast.347 = f32[6144,4,2048]{2,1,0:T(4,128)} bitcast(%param_0.1387), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.218 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%bitcast.347, %bitcast.347), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1219 = f32[]{:T(128)} constant(0) - ROOT %reduce.181 = f32[]{:T(128)} reduce(%square.218, %constant.1219), dimensions={0,1,2}, to_apply=%region_33.38, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%fused_computation.309 (param_0.1428: f32[4,6144,2048]) -> f32[] { + %param_0.1428 = f32[4,6144,2048]{2,0,1:T(4,128)} parameter(0) + %bitcast.328 = f32[6144,4,2048]{2,1,0:T(4,128)} bitcast(%param_0.1428), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.218 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%bitcast.328, %bitcast.328), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1207 = f32[]{:T(128)} constant(0) + ROOT %reduce.111 = f32[]{:T(128)} reduce(%square.218, %constant.1207), dimensions={0,1,2}, to_apply=%region_33.38, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%region_32.37 (reduce_sum.263: f32[], reduce_sum.267: f32[]) -> f32[] { - %reduce_sum.267 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.263 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.268 = f32[]{:T(128)} add(%reduce_sum.263, %reduce_sum.267), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_32.37 (reduce_sum.368: f32[], reduce_sum.369: f32[]) -> f32[] { + %reduce_sum.369 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.368 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.373 = f32[]{:T(128)} add(%reduce_sum.368, %reduce_sum.369), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_31.36 (reduce_sum.260: f32[], reduce_sum.261: f32[]) -> f32[] { - %reduce_sum.261 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.260 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.262 = f32[]{:T(128)} add(%reduce_sum.260, %reduce_sum.261), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_31.36 (reduce_sum.362: f32[], reduce_sum.366: f32[]) -> f32[] { + %reduce_sum.366 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.362 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.367 = f32[]{:T(128)} add(%reduce_sum.362, %reduce_sum.366), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.300 (param_0.1388: f32[4,2048,6144], param_1.1573: f32[4,2048,6144]) -> (f32[], f32[]) { - %param_0.1388 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(0) - %bitcast.351 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_0.1388), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.221 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%bitcast.351, %bitcast.351), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1220 = f32[]{:T(128)} constant(0) - %reduce.182 = f32[]{:T(128)} reduce(%square.221, %constant.1220), dimensions={0,1,2}, to_apply=%region_32.37, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.1573 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(1) - %bitcast.355.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_1.1573), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.224.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%bitcast.355.clone.1, %bitcast.355.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.183.clone.1 = f32[]{:T(128)} reduce(%square.224.clone.1, %constant.1220), dimensions={0,1,2}, to_apply=%region_31.36, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.167 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.182, %reduce.183.clone.1) +%fused_computation.311 (param_0.1429: f32[4,2048,6144], param_1.1600: f32[4,2048,6144]) -> (f32[], f32[]) { + %param_0.1429 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(0) + %bitcast.332 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_0.1429), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.221 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%bitcast.332, %bitcast.332), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1208 = f32[]{:T(128)} constant(0) + %reduce.112 = f32[]{:T(128)} reduce(%square.221, %constant.1208), dimensions={0,1,2}, to_apply=%region_32.37, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1600 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(1) + %bitcast.336.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_1.1600), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.224.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%bitcast.336.clone.1, %bitcast.336.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.113.clone.1 = f32[]{:T(128)} reduce(%square.224.clone.1, %constant.1208), dimensions={0,1,2}, to_apply=%region_31.36, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.168 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.112, %reduce.113.clone.1) } -%fused_computation.303 (param_0.901: f32[6144,4,2048]) -> bf16[4,6144,2048] { - %param_0.901 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(0) - %copy.190 = bf16[6144,4,2048]{2,0,1:T(8,128)(2,1)} copy(%param_0.901), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wo\'][\'kernel\']"} - ROOT %bitcast.356 = bf16[4,6144,2048]{2,1,0:T(8,128)(2,1)} bitcast(%copy.190), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.314 (param_0.940: f32[6144,4,2048]) -> bf16[4,6144,2048] { + %param_0.940 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(0) + %copy.186 = bf16[6144,4,2048]{2,0,1:T(8,128)(2,1)} copy(%param_0.940), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wo\'][\'kernel\']"} + ROOT %bitcast.337 = bf16[4,6144,2048]{2,1,0:T(8,128)(2,1)} bitcast(%copy.186), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%fused_computation.304 (param_0.903: f32[2048,4,6144]) -> bf16[4,2048,6144] { - %param_0.903 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) - %copy.191 = bf16[2048,4,6144]{2,0,1:T(8,128)(2,1)} copy(%param_0.903), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_1\'][\'kernel\']"} - ROOT %bitcast.357 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} bitcast(%copy.191), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.315 (param_0.942: f32[2048,4,6144]) -> bf16[4,2048,6144] { + %param_0.942 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %copy.187 = bf16[2048,4,6144]{2,0,1:T(8,128)(2,1)} copy(%param_0.942), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_1\'][\'kernel\']"} + ROOT %bitcast.338 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} bitcast(%copy.187), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%fused_computation.305 (param_0.905: f32[2048,4,6144]) -> bf16[4,2048,6144] { - %param_0.905 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) - %copy.192 = bf16[2048,4,6144]{2,0,1:T(8,128)(2,1)} copy(%param_0.905), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_0\'][\'kernel\']"} - ROOT %bitcast.358 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} bitcast(%copy.192), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.316 (param_0.944: f32[2048,4,6144]) -> bf16[4,2048,6144] { + %param_0.944 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %copy.188 = bf16[2048,4,6144]{2,0,1:T(8,128)(2,1)} copy(%param_0.944), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_0\'][\'kernel\']"} + ROOT %bitcast.339 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} bitcast(%copy.188), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%region_62.67 (reduce_sum.416: f32[], reduce_sum.417: f32[]) -> f32[] { - %reduce_sum.417 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.416 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.421 = f32[]{:T(128)} add(%reduce_sum.416, %reduce_sum.417), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_62.67 (reduce_sum.521: f32[], reduce_sum.522: f32[]) -> f32[] { + %reduce_sum.522 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.521 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.523 = f32[]{:T(128)} add(%reduce_sum.521, %reduce_sum.522), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_47.52 (reduce_sum.338: f32[], reduce_sum.339: f32[]) -> f32[] { - %reduce_sum.339 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.338 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.340 = f32[]{:T(128)} add(%reduce_sum.338, %reduce_sum.339), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_47.52 (reduce_sum.443: f32[], reduce_sum.444: f32[]) -> f32[] { + %reduce_sum.444 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.443 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.445 = f32[]{:T(128)} add(%reduce_sum.443, %reduce_sum.444), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.306 (param_0.1377: f32[6144,4,2048], param_1.1565: f32[], param_2.1323: f32[], param_3.927: f32[], param_4.565: f32[6144,4,2048], param_5.477: f32[], param_6.367: f32[4,6144,2048], param_7.210: pred[], param_8.127: f32[6144,4,2048]) -> (f32[], f32[6144,4,2048], f32[6144,4,2048], f32[6144,4,2048], f32[]) { - %param_0.1377 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(0) - %param_3.927 = f32[]{:T(128)S(6)} parameter(3) - %mul.1998.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_3.927), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.210 = pred[]{:T(512)S(6)} parameter(7) - %select_n.304.clone.1 = pred[6144,4,2048]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.210), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.367 = f32[4,6144,2048]{2,0,1:T(4,128)} parameter(6) - %bitcast.482.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} bitcast(%param_6.367), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.477 = f32[]{:T(128)} parameter(5) - %div.932.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_5.477), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.931.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%bitcast.482.clone.1, %div.932.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.303.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} select(%select_n.304.clone.1, %bitcast.482.clone.1, %div.931.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1146.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.886.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1146.clone.1), dimensions={}, metadata={op_name="broadcast.83"} - %mul.2004.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%select_n.303.clone.1, %broadcast.886.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.317 (param_0.1418: f32[6144,4,2048], param_1.1592: f32[], param_2.1334: f32[], param_3.917: f32[], param_4.556: f32[6144,4,2048], param_5.490: f32[], param_6.365: f32[4,6144,2048], param_7.209: pred[], param_8.127: f32[6144,4,2048]) -> (f32[], f32[6144,4,2048], f32[6144,4,2048], f32[6144,4,2048], f32[]) { + %param_0.1418 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(0) + %param_3.917 = f32[]{:T(128)S(6)} parameter(3) + %mul.2521.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_3.917), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.209 = pred[]{:T(512)S(6)} parameter(7) + %select_n.304.clone.1 = pred[6144,4,2048]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.209), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.365 = f32[4,6144,2048]{2,0,1:T(4,128)} parameter(6) + %bitcast.463.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} bitcast(%param_6.365), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.490 = f32[]{:T(128)} parameter(5) + %div.932.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_5.490), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.931.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%bitcast.463.clone.1, %div.932.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.303.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} select(%select_n.304.clone.1, %bitcast.463.clone.1, %div.931.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1134.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.796.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1134.clone.1), dimensions={}, metadata={op_name="broadcast.83"} + %mul.2527.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%select_n.303.clone.1, %broadcast.796.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.127 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(8) - %constant.1150.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.2005.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1150.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.2003.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_8.127, %mul.2005.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.989.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%mul.2004.clone.1, %mul.2003.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1323 = f32[]{:T(128)S(6)} parameter(2) - %div.928.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_2.1323), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1138.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2528.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1138.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2526.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_8.127, %mul.2528.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.965.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%mul.2527.clone.1, %mul.2526.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1334 = f32[]{:T(128)S(6)} parameter(2) + %div.928.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_2.1334), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.74.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%select_n.303.clone.1, %select_n.303.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1149.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.2002.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1149.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.2000.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%integer_pow.74.clone.1, %mul.2002.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.565 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(4) - %constant.1148.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.2001.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1148.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1999.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_4.565, %mul.2001.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.988.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%mul.2000.clone.1, %mul.1999.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1565 = f32[]{:T(128)S(6)} parameter(1) - %div.927.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_1.1565), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.926.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%add.988.clone.1, %div.927.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1137.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2525.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1137.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2523.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%integer_pow.74.clone.1, %mul.2525.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.556 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(4) + %constant.1136.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2524.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1136.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2522.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_4.556, %mul.2524.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.964.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%mul.2523.clone.1, %mul.2522.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1592 = f32[]{:T(128)S(6)} parameter(1) + %div.927.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_1.1592), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.926.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%add.964.clone.1, %div.927.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.71.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} sqrt(%div.926.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1147.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.987.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1147.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.986.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%sqrt.71.clone.1, %add.987.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.435.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%div.928.clone.1, %add.986.clone.1), metadata={op_name="multiply.52"} - %div.925.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%add.989.clone.1, %multiply.435.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1997.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_0.1377, %broadcast.886.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.985.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%div.925.clone.1, %mul.1997.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1996.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%mul.1998.clone.1, %add.985.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.984.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%param_0.1377, %mul.1996.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.225 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%add.984.clone.1, %add.984.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1209 = f32[]{:T(128)} constant(0) - %reduce.184 = f32[]{:T(128)} reduce(%square.225, %constant.1209), dimensions={0,1,2}, to_apply=%region_62.67, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.187.clone.1 = f32[]{:T(128)} reduce(%integer_pow.74.clone.1, %constant.1209), dimensions={0,1,2}, to_apply=%region_47.52, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.145 = (f32[]{:T(128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.184, %add.984.clone.1, %add.988.clone.1, %add.989.clone.1, %reduce.187.clone.1) -} - -%region_61.66 (reduce_sum.410: f32[], reduce_sum.414: f32[]) -> f32[] { - %reduce_sum.414 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.410 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.415 = f32[]{:T(128)} add(%reduce_sum.410, %reduce_sum.414), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_46.51 (reduce_sum.332: f32[], reduce_sum.333: f32[]) -> f32[] { - %reduce_sum.333 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.332 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.337 = f32[]{:T(128)} add(%reduce_sum.332, %reduce_sum.333), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.1135.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.963.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1135.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.962.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%sqrt.71.clone.1, %add.963.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.296.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%div.928.clone.1, %add.962.clone.1), metadata={op_name="multiply.37"} + %div.925.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%add.965.clone.1, %multiply.296.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2520.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_0.1418, %broadcast.796.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.961.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%div.925.clone.1, %mul.2520.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2519.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%mul.2521.clone.1, %add.961.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.960.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%param_0.1418, %mul.2519.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.225 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%add.960.clone.1, %add.960.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1197 = f32[]{:T(128)} constant(0) + %reduce.114 = f32[]{:T(128)} reduce(%square.225, %constant.1197), dimensions={0,1,2}, to_apply=%region_62.67, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.117.clone.1 = f32[]{:T(128)} reduce(%integer_pow.74.clone.1, %constant.1197), dimensions={0,1,2}, to_apply=%region_47.52, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.146 = (f32[]{:T(128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.114, %add.960.clone.1, %add.964.clone.1, %add.965.clone.1, %reduce.117.clone.1) +} + +%region_61.66 (reduce_sum.515: f32[], reduce_sum.516: f32[]) -> f32[] { + %reduce_sum.516 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.515 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.520 = f32[]{:T(128)} add(%reduce_sum.515, %reduce_sum.516), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_46.51 (reduce_sum.437: f32[], reduce_sum.438: f32[]) -> f32[] { + %reduce_sum.438 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.437 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.439 = f32[]{:T(128)} add(%reduce_sum.437, %reduce_sum.438), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.307 (param_0.1378: f32[2048,4,6144], param_1.1566: f32[], param_2.1324: f32[], param_3.928: f32[], param_4.566: f32[2048,4,6144], param_5.478: f32[], param_6.368: f32[4,2048,6144], param_7.211: pred[], param_8.128: f32[2048,4,6144]) -> (f32[], f32[2048,4,6144], f32[2048,4,6144], f32[2048,4,6144], f32[]) { - %param_0.1378 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) - %param_3.928 = f32[]{:T(128)S(6)} parameter(3) - %mul.2008.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_3.928), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.211 = pred[]{:T(512)S(6)} parameter(7) - %select_n.308.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.211), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.368 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(6) - %bitcast.484.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_6.368), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.478 = f32[]{:T(128)} parameter(5) - %div.940.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_5.478), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.939.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%bitcast.484.clone.1, %div.940.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.307.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%select_n.308.clone.1, %bitcast.484.clone.1, %div.939.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1152.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.892.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1152.clone.1), dimensions={}, metadata={op_name="broadcast.85"} - %mul.2012.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.307.clone.1, %broadcast.892.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.318 (param_0.1419: f32[2048,4,6144], param_1.1593: f32[], param_2.1335: f32[], param_3.918: f32[], param_4.557: f32[2048,4,6144], param_5.491: f32[], param_6.366: f32[4,2048,6144], param_7.210: pred[], param_8.128: f32[2048,4,6144]) -> (f32[], f32[2048,4,6144], f32[2048,4,6144], f32[2048,4,6144], f32[]) { + %param_0.1419 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %param_3.918 = f32[]{:T(128)S(6)} parameter(3) + %mul.2531.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_3.918), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.210 = pred[]{:T(512)S(6)} parameter(7) + %select_n.308.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.210), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.366 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(6) + %bitcast.465.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_6.366), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.491 = f32[]{:T(128)} parameter(5) + %div.940.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_5.491), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.939.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%bitcast.465.clone.1, %div.940.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.307.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%select_n.308.clone.1, %bitcast.465.clone.1, %div.939.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1140.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.802.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1140.clone.1), dimensions={}, metadata={op_name="broadcast.85"} + %mul.2535.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.307.clone.1, %broadcast.802.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.128 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(8) - %constant.1156.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.891.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1156.clone.1), dimensions={}, metadata={op_name="broadcast.84"} - %mul.2011.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_8.128, %broadcast.891.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.994.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2012.clone.1, %mul.2011.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1324 = f32[]{:T(128)S(6)} parameter(2) - %div.936.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_2.1324), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1144.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.801.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1144.clone.1), dimensions={}, metadata={op_name="broadcast.84"} + %mul.2534.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_8.128, %broadcast.801.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.970.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2535.clone.1, %mul.2534.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1335 = f32[]{:T(128)S(6)} parameter(2) + %div.936.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_2.1335), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.75.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.307.clone.1, %select_n.307.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1155.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.890.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1155.clone.1), dimensions={}, metadata={op_name="broadcast.73"} - %mul.2010.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%integer_pow.75.clone.1, %broadcast.890.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.566 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(4) - %constant.1154.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.889.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1154.clone.1), dimensions={}, metadata={op_name="broadcast.72"} - %mul.2009.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_4.566, %broadcast.889.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.993.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2010.clone.1, %mul.2009.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1566 = f32[]{:T(128)S(6)} parameter(1) - %div.935.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_1.1566), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.934.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.993.clone.1, %div.935.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1143.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.800.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1143.clone.1), dimensions={}, metadata={op_name="broadcast.73"} + %mul.2533.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%integer_pow.75.clone.1, %broadcast.800.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.557 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(4) + %constant.1142.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.799.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1142.clone.1), dimensions={}, metadata={op_name="broadcast.72"} + %mul.2532.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_4.557, %broadcast.799.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.969.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2533.clone.1, %mul.2532.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1593 = f32[]{:T(128)S(6)} parameter(1) + %div.935.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_1.1593), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.934.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.969.clone.1, %div.935.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.72.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} sqrt(%div.934.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1153.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.887.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1153.clone.1), dimensions={}, metadata={op_name="broadcast.65"} - %add.992.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%sqrt.72.clone.1, %broadcast.887.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.436.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%div.936.clone.1, %add.992.clone.1), metadata={op_name="multiply.51"} - %div.933.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.994.clone.1, %multiply.436.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.2007.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_0.1378, %broadcast.892.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.991.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%div.933.clone.1, %mul.2007.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.2006.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%mul.2008.clone.1, %add.991.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.990.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%param_0.1378, %mul.2006.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.226 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%add.990.clone.1, %add.990.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1210 = f32[]{:T(128)} constant(0) - %reduce.185 = f32[]{:T(128)} reduce(%square.226, %constant.1210), dimensions={0,1,2}, to_apply=%region_61.66, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.188.clone.1 = f32[]{:T(128)} reduce(%integer_pow.75.clone.1, %constant.1210), dimensions={0,1,2}, to_apply=%region_46.51, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.146 = (f32[]{:T(128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.185, %add.990.clone.1, %add.993.clone.1, %add.994.clone.1, %reduce.188.clone.1) -} - -%region_60.65 (reduce_sum.407: f32[], reduce_sum.408: f32[]) -> f32[] { - %reduce_sum.408 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.407 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.409 = f32[]{:T(128)} add(%reduce_sum.407, %reduce_sum.408), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_45.50 (reduce_sum.326: f32[], reduce_sum.330: f32[]) -> f32[] { - %reduce_sum.330 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.326 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.331 = f32[]{:T(128)} add(%reduce_sum.326, %reduce_sum.330), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.1141.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.797.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1141.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %add.968.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%sqrt.72.clone.1, %broadcast.797.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.297.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%div.936.clone.1, %add.968.clone.1), metadata={op_name="multiply.36"} + %div.933.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.970.clone.1, %multiply.297.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2530.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_0.1419, %broadcast.802.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.967.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%div.933.clone.1, %mul.2530.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2529.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%mul.2531.clone.1, %add.967.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.966.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%param_0.1419, %mul.2529.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.226 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%add.966.clone.1, %add.966.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1198 = f32[]{:T(128)} constant(0) + %reduce.115 = f32[]{:T(128)} reduce(%square.226, %constant.1198), dimensions={0,1,2}, to_apply=%region_61.66, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.118.clone.1 = f32[]{:T(128)} reduce(%integer_pow.75.clone.1, %constant.1198), dimensions={0,1,2}, to_apply=%region_46.51, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.147 = (f32[]{:T(128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.115, %add.966.clone.1, %add.969.clone.1, %add.970.clone.1, %reduce.118.clone.1) +} + +%region_60.65 (reduce_sum.509: f32[], reduce_sum.513: f32[]) -> f32[] { + %reduce_sum.513 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.509 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.514 = f32[]{:T(128)} add(%reduce_sum.509, %reduce_sum.513), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_45.50 (reduce_sum.431: f32[], reduce_sum.432: f32[]) -> f32[] { + %reduce_sum.432 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.431 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.436 = f32[]{:T(128)} add(%reduce_sum.431, %reduce_sum.432), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.308 (param_0.1379: f32[2048,4,6144], param_1.1567: f32[], param_2.1325: f32[], param_3.929: f32[], param_4.567: f32[2048,4,6144], param_5.479: f32[], param_6.369: f32[4,2048,6144], param_7.212: pred[], param_8.129: f32[2048,4,6144]) -> (f32[], f32[2048,4,6144], f32[2048,4,6144], f32[2048,4,6144], f32[]) { - %param_0.1379 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) - %param_3.929 = f32[]{:T(128)S(6)} parameter(3) - %mul.2015.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_3.929), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.212 = pred[]{:T(512)S(6)} parameter(7) - %select_n.312.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.212), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.369 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(6) - %bitcast.486.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_6.369), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.479 = f32[]{:T(128)} parameter(5) - %div.948.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_5.479), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.947.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%bitcast.486.clone.1, %div.948.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.311.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%select_n.312.clone.1, %bitcast.486.clone.1, %div.947.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1158.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.898.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1158.clone.1), dimensions={}, metadata={op_name="broadcast.85"} - %mul.2019.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.311.clone.1, %broadcast.898.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.319 (param_0.1420: f32[2048,4,6144], param_1.1594: f32[], param_2.1336: f32[], param_3.919: f32[], param_4.558: f32[2048,4,6144], param_5.492: f32[], param_6.367: f32[4,2048,6144], param_7.211: pred[], param_8.129: f32[2048,4,6144]) -> (f32[], f32[2048,4,6144], f32[2048,4,6144], f32[2048,4,6144], f32[]) { + %param_0.1420 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %param_3.919 = f32[]{:T(128)S(6)} parameter(3) + %mul.2538.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_3.919), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.211 = pred[]{:T(512)S(6)} parameter(7) + %select_n.312.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.211), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.367 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(6) + %bitcast.467.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_6.367), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.492 = f32[]{:T(128)} parameter(5) + %div.948.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_5.492), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.947.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%bitcast.467.clone.1, %div.948.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.311.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%select_n.312.clone.1, %bitcast.467.clone.1, %div.947.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1146.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.808.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1146.clone.1), dimensions={}, metadata={op_name="broadcast.85"} + %mul.2542.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.311.clone.1, %broadcast.808.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.129 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(8) - %constant.1162.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.897.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1162.clone.1), dimensions={}, metadata={op_name="broadcast.84"} - %mul.2018.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_8.129, %broadcast.897.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.999.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2019.clone.1, %mul.2018.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1325 = f32[]{:T(128)S(6)} parameter(2) - %div.944.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_2.1325), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1150.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.807.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1150.clone.1), dimensions={}, metadata={op_name="broadcast.84"} + %mul.2541.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_8.129, %broadcast.807.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.975.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2542.clone.1, %mul.2541.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1336 = f32[]{:T(128)S(6)} parameter(2) + %div.944.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_2.1336), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.76.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.311.clone.1, %select_n.311.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1161.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.896.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1161.clone.1), dimensions={}, metadata={op_name="broadcast.73"} - %mul.2017.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%integer_pow.76.clone.1, %broadcast.896.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.567 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(4) - %constant.1160.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.895.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1160.clone.1), dimensions={}, metadata={op_name="broadcast.72"} - %mul.2016.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_4.567, %broadcast.895.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.998.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2017.clone.1, %mul.2016.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1567 = f32[]{:T(128)S(6)} parameter(1) - %div.943.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_1.1567), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.942.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.998.clone.1, %div.943.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1149.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.806.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1149.clone.1), dimensions={}, metadata={op_name="broadcast.73"} + %mul.2540.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%integer_pow.76.clone.1, %broadcast.806.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.558 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(4) + %constant.1148.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.805.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1148.clone.1), dimensions={}, metadata={op_name="broadcast.72"} + %mul.2539.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_4.558, %broadcast.805.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.974.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2540.clone.1, %mul.2539.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1594 = f32[]{:T(128)S(6)} parameter(1) + %div.943.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_1.1594), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.942.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.974.clone.1, %div.943.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.73.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} sqrt(%div.942.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1159.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.893.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1159.clone.1), dimensions={}, metadata={op_name="broadcast.65"} - %add.997.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%sqrt.73.clone.1, %broadcast.893.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.437.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%div.944.clone.1, %add.997.clone.1), metadata={op_name="multiply.50"} - %div.941.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.999.clone.1, %multiply.437.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.2014.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_0.1379, %broadcast.898.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.996.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%div.941.clone.1, %mul.2014.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.2013.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%mul.2015.clone.1, %add.996.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.995.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%param_0.1379, %mul.2013.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.227 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%add.995.clone.1, %add.995.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1211 = f32[]{:T(128)} constant(0) - %reduce.186 = f32[]{:T(128)} reduce(%square.227, %constant.1211), dimensions={0,1,2}, to_apply=%region_60.65, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.189.clone.1 = f32[]{:T(128)} reduce(%integer_pow.76.clone.1, %constant.1211), dimensions={0,1,2}, to_apply=%region_45.50, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.147 = (f32[]{:T(128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.186, %add.995.clone.1, %add.998.clone.1, %add.999.clone.1, %reduce.189.clone.1) -} - -%region_39.44 (reduce_sum.302: f32[], reduce_sum.303: f32[]) -> f32[] { - %reduce_sum.303 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.302 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.304 = f32[]{:T(128)} add(%reduce_sum.302, %reduce_sum.303), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.1147.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.803.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1147.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %add.973.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%sqrt.73.clone.1, %broadcast.803.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.298.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%div.944.clone.1, %add.973.clone.1), metadata={op_name="multiply.35"} + %div.941.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.975.clone.1, %multiply.298.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2537.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_0.1420, %broadcast.808.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.972.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%div.941.clone.1, %mul.2537.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2536.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%mul.2538.clone.1, %add.972.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.971.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%param_0.1420, %mul.2536.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.227 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%add.971.clone.1, %add.971.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1199 = f32[]{:T(128)} constant(0) + %reduce.116 = f32[]{:T(128)} reduce(%square.227, %constant.1199), dimensions={0,1,2}, to_apply=%region_60.65, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.119.clone.1 = f32[]{:T(128)} reduce(%integer_pow.76.clone.1, %constant.1199), dimensions={0,1,2}, to_apply=%region_45.50, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.148 = (f32[]{:T(128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.116, %add.971.clone.1, %add.974.clone.1, %add.975.clone.1, %reduce.119.clone.1) +} + +%region_39.44 (reduce_sum.404: f32[], reduce_sum.408: f32[]) -> f32[] { + %reduce_sum.408 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.404 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.409 = f32[]{:T(128)} add(%reduce_sum.404, %reduce_sum.408), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.324 (param_0.1382: f32[4,2048,16,128]) -> f32[] { - %param_0.1382 = f32[4,2048,16,128]{3,2,0,1:T(8,128)} parameter(0) - %bitcast.362 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1382), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.230 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%bitcast.362, %bitcast.362), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1214 = f32[]{:T(128)} constant(0) - ROOT %reduce.190 = f32[]{:T(128)} reduce(%square.230, %constant.1214), dimensions={0,1,2,3}, to_apply=%region_39.44, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%fused_computation.335 (param_0.1423: f32[4,2048,16,128]) -> f32[] { + %param_0.1423 = f32[4,2048,16,128]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.343 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1423), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.230 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%bitcast.343, %bitcast.343), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1202 = f32[]{:T(128)} constant(0) + ROOT %reduce.120 = f32[]{:T(128)} reduce(%square.230, %constant.1202), dimensions={0,1,2,3}, to_apply=%region_39.44, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%region_38.43 (reduce_sum.296: f32[], reduce_sum.297: f32[]) -> f32[] { - %reduce_sum.297 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.296 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.298 = f32[]{:T(128)} add(%reduce_sum.296, %reduce_sum.297), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_38.43 (reduce_sum.401: f32[], reduce_sum.402: f32[]) -> f32[] { + %reduce_sum.402 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.401 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.403 = f32[]{:T(128)} add(%reduce_sum.401, %reduce_sum.402), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.326 (param_0.1383: f32[4,16,128,2048]) -> f32[] { - %param_0.1383 = f32[4,16,128,2048]{3,2,0,1:T(8,128)} parameter(0) - %bitcast.366 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} bitcast(%param_0.1383), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.233 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%bitcast.366, %bitcast.366), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1215 = f32[]{:T(128)} constant(0) - ROOT %reduce.191 = f32[]{:T(128)} reduce(%square.233, %constant.1215), dimensions={0,1,2,3}, to_apply=%region_38.43, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%fused_computation.337 (param_0.1424: f32[4,16,128,2048]) -> f32[] { + %param_0.1424 = f32[4,16,128,2048]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.347 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} bitcast(%param_0.1424), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.233 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%bitcast.347, %bitcast.347), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1203 = f32[]{:T(128)} constant(0) + ROOT %reduce.121 = f32[]{:T(128)} reduce(%square.233, %constant.1203), dimensions={0,1,2,3}, to_apply=%region_38.43, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%fused_computation.327 (param_0.950: f32[16,4,128,2048]) -> bf16[4,16,128,2048] { - %param_0.950 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(0) - %copy.193 = bf16[16,4,128,2048]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.950), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'out\'][\'kernel\']"} - ROOT %bitcast.367 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.193), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.338 (param_0.989: f32[16,4,128,2048]) -> bf16[4,16,128,2048] { + %param_0.989 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(0) + %copy.189 = bf16[16,4,128,2048]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.989), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'out\'][\'kernel\']"} + ROOT %bitcast.348 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.189), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%region_68.73 (reduce_sum.449: f32[], reduce_sum.450: f32[]) -> f32[] { - %reduce_sum.450 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.449 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.451 = f32[]{:T(128)} add(%reduce_sum.449, %reduce_sum.450), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_68.73 (reduce_sum.551: f32[], reduce_sum.555: f32[]) -> f32[] { + %reduce_sum.555 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.551 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.556 = f32[]{:T(128)} add(%reduce_sum.551, %reduce_sum.555), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_53.58 (reduce_sum.368: f32[], reduce_sum.372: f32[]) -> f32[] { - %reduce_sum.372 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.368 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.373 = f32[]{:T(128)} add(%reduce_sum.368, %reduce_sum.372), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_53.58 (reduce_sum.473: f32[], reduce_sum.474: f32[]) -> f32[] { + %reduce_sum.474 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.473 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.478 = f32[]{:T(128)} add(%reduce_sum.473, %reduce_sum.474), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.328 (param_0.1371: f32[2048,4,16,128], param_1.1559: f32[], param_2.1317: f32[], param_3.921: f32[], param_4.559: f32[2048,4,16,128], param_5.471: f32[], param_6.361: f32[4,2048,16,128], param_7.204: pred[], param_8.121: f32[2048,4,16,128]) -> (f32[], f32[2048,4,16,128], f32[2048,4,16,128], f32[2048,4,16,128], f32[]) { - %param_0.1371 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(0) - %param_3.921 = f32[]{:T(128)S(6)} parameter(3) - %mul.1950.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_3.921), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.204 = pred[]{:T(512)S(6)} parameter(7) - %select_n.280.clone.1 = pred[2048,4,16,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.204), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.361 = f32[4,2048,16,128]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.470.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} bitcast(%param_6.361), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.471 = f32[]{:T(128)} parameter(5) - %div.884.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_5.471), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.883.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%bitcast.470.clone.1, %div.884.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.279.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} select(%select_n.280.clone.1, %bitcast.470.clone.1, %div.883.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1110.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.858.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1110.clone.1), dimensions={}, metadata={op_name="broadcast.75"} - %mul.1956.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%select_n.279.clone.1, %broadcast.858.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.339 (param_0.1412: f32[2048,4,16,128], param_1.1586: f32[], param_2.1328: f32[], param_3.911: f32[], param_4.550: f32[2048,4,16,128], param_5.484: f32[], param_6.359: f32[4,2048,16,128], param_7.203: pred[], param_8.121: f32[2048,4,16,128]) -> (f32[], f32[2048,4,16,128], f32[2048,4,16,128], f32[2048,4,16,128], f32[]) { + %param_0.1412 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.911 = f32[]{:T(128)S(6)} parameter(3) + %mul.2473.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_3.911), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.203 = pred[]{:T(512)S(6)} parameter(7) + %select_n.280.clone.1 = pred[2048,4,16,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.203), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.359 = f32[4,2048,16,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.451.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} bitcast(%param_6.359), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.484 = f32[]{:T(128)} parameter(5) + %div.884.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_5.484), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.883.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%bitcast.451.clone.1, %div.884.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.279.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} select(%select_n.280.clone.1, %bitcast.451.clone.1, %div.883.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1098.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.768.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1098.clone.1), dimensions={}, metadata={op_name="broadcast.75"} + %mul.2479.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%select_n.279.clone.1, %broadcast.768.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.121 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(8) - %constant.1114.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1957.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1114.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1955.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_8.121, %mul.1957.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.957.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%mul.1956.clone.1, %mul.1955.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1317 = f32[]{:T(128)S(6)} parameter(2) - %div.880.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1317), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1102.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2480.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1102.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2478.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_8.121, %mul.2480.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.933.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%mul.2479.clone.1, %mul.2478.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1328 = f32[]{:T(128)S(6)} parameter(2) + %div.880.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1328), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.68.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%select_n.279.clone.1, %select_n.279.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1113.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1954.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1113.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1952.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.68.clone.1, %mul.1954.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.559 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(4) - %constant.1112.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1953.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1112.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1951.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_4.559, %mul.1953.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.956.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%mul.1952.clone.1, %mul.1951.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1559 = f32[]{:T(128)S(6)} parameter(1) - %div.879.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1559), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.878.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%add.956.clone.1, %div.879.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1101.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2477.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1101.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2475.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.68.clone.1, %mul.2477.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.550 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.1100.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2476.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1100.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2474.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_4.550, %mul.2476.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.932.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%mul.2475.clone.1, %mul.2474.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1586 = f32[]{:T(128)S(6)} parameter(1) + %div.879.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1586), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.878.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%add.932.clone.1, %div.879.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.65.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} sqrt(%div.878.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1111.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.955.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1111.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.954.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%sqrt.65.clone.1, %add.955.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.429.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%div.880.clone.1, %add.954.clone.1), metadata={op_name="multiply.58"} - %div.877.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%add.957.clone.1, %multiply.429.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1949.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_0.1371, %broadcast.858.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.953.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%div.877.clone.1, %mul.1949.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1948.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%mul.1950.clone.1, %add.953.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.952.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%param_0.1371, %mul.1948.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.234 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%add.952.clone.1, %add.952.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1203 = f32[]{:T(128)} constant(0) - %reduce.192 = f32[]{:T(128)} reduce(%square.234, %constant.1203), dimensions={0,1,2,3}, to_apply=%region_68.73, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.194.clone.1 = f32[]{:T(128)} reduce(%integer_pow.68.clone.1, %constant.1203), dimensions={0,1,2,3}, to_apply=%region_53.58, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.148 = (f32[]{:T(128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.192, %add.952.clone.1, %add.956.clone.1, %add.957.clone.1, %reduce.194.clone.1) + %constant.1099.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.931.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1099.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.930.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%sqrt.65.clone.1, %add.931.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.290.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%div.880.clone.1, %add.930.clone.1), metadata={op_name="multiply.43"} + %div.877.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%add.933.clone.1, %multiply.290.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2472.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_0.1412, %broadcast.768.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.929.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%div.877.clone.1, %mul.2472.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2471.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%mul.2473.clone.1, %add.929.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.928.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%param_0.1412, %mul.2471.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.234 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%add.928.clone.1, %add.928.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1191 = f32[]{:T(128)} constant(0) + %reduce.122 = f32[]{:T(128)} reduce(%square.234, %constant.1191), dimensions={0,1,2,3}, to_apply=%region_68.73, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.124.clone.1 = f32[]{:T(128)} reduce(%integer_pow.68.clone.1, %constant.1191), dimensions={0,1,2,3}, to_apply=%region_53.58, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.149 = (f32[]{:T(128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.122, %add.928.clone.1, %add.932.clone.1, %add.933.clone.1, %reduce.124.clone.1) +} + +%region_67.72 (reduce_sum.548: f32[], reduce_sum.549: f32[]) -> f32[] { + %reduce_sum.549 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.548 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.550 = f32[]{:T(128)} add(%reduce_sum.548, %reduce_sum.549), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_52.57 (reduce_sum.467: f32[], reduce_sum.471: f32[]) -> f32[] { + %reduce_sum.471 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.467 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.472 = f32[]{:T(128)} add(%reduce_sum.467, %reduce_sum.471), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.340 (param_0.1413: f32[16,4,128,2048], param_1.1587: f32[], param_2.1329: f32[], param_3.912: f32[], param_4.551: f32[16,4,128,2048], param_5.485: f32[], param_6.360: f32[4,16,128,2048], param_7.204: pred[], param_8.122: f32[16,4,128,2048]) -> (f32[], f32[16,4,128,2048], f32[16,4,128,2048], f32[16,4,128,2048], f32[]) { + %param_0.1413 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(0) + %param_3.912 = f32[]{:T(128)S(6)} parameter(3) + %mul.2483.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_3.912), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.204 = pred[]{:T(512)S(6)} parameter(7) + %select_n.284.clone.1 = pred[16,4,128,2048]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.204), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.360 = f32[4,16,128,2048]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.453.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} bitcast(%param_6.360), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.485 = f32[]{:T(128)} parameter(5) + %div.892.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_5.485), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.891.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%bitcast.453.clone.1, %div.892.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.283.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} select(%select_n.284.clone.1, %bitcast.453.clone.1, %div.891.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1104.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.770.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1104.clone.1), dimensions={}, metadata={op_name="broadcast.76"} + %mul.2489.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%select_n.283.clone.1, %broadcast.770.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.122 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(8) + %constant.1108.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2490.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1108.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2488.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_8.122, %mul.2490.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.939.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%mul.2489.clone.1, %mul.2488.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1329 = f32[]{:T(128)S(6)} parameter(2) + %div.888.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_2.1329), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.69.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%select_n.283.clone.1, %select_n.283.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1107.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2487.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1107.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2485.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%integer_pow.69.clone.1, %mul.2487.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.551 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(4) + %constant.1106.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2486.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1106.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2484.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_4.551, %mul.2486.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.938.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%mul.2485.clone.1, %mul.2484.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1587 = f32[]{:T(128)S(6)} parameter(1) + %div.887.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_1.1587), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.886.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%add.938.clone.1, %div.887.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.66.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} sqrt(%div.886.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1105.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.937.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1105.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.936.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%sqrt.66.clone.1, %add.937.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.291.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%div.888.clone.1, %add.936.clone.1), metadata={op_name="multiply.42"} + %div.885.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%add.939.clone.1, %multiply.291.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2482.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_0.1413, %broadcast.770.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.935.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%div.885.clone.1, %mul.2482.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2481.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%mul.2483.clone.1, %add.935.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.934.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%param_0.1413, %mul.2481.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.235 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%add.934.clone.1, %add.934.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1192 = f32[]{:T(128)} constant(0) + %reduce.123 = f32[]{:T(128)} reduce(%square.235, %constant.1192), dimensions={0,1,2,3}, to_apply=%region_67.72, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.125.clone.1 = f32[]{:T(128)} reduce(%integer_pow.69.clone.1, %constant.1192), dimensions={0,1,2,3}, to_apply=%region_52.57, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.150 = (f32[]{:T(128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.123, %add.934.clone.1, %add.938.clone.1, %add.939.clone.1, %reduce.125.clone.1) +} + +%region_41.46 (reduce_sum.416: f32[], reduce_sum.417: f32[]) -> f32[] { + %reduce_sum.417 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.416 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.418 = f32[]{:T(128)} add(%reduce_sum.416, %reduce_sum.417), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_67.72 (reduce_sum.443: f32[], reduce_sum.444: f32[]) -> f32[] { - %reduce_sum.444 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.443 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.445 = f32[]{:T(128)} add(%reduce_sum.443, %reduce_sum.444), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_36.41 (reduce_sum.389: f32[], reduce_sum.390: f32[]) -> f32[] { + %reduce_sum.390 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.389 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.394 = f32[]{:T(128)} add(%reduce_sum.389, %reduce_sum.390), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_52.57 (reduce_sum.365: f32[], reduce_sum.366: f32[]) -> f32[] { - %reduce_sum.366 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.365 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.367 = f32[]{:T(128)} add(%reduce_sum.365, %reduce_sum.366), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%fused_computation.352 (param_0.1426: f32[4,2048,8,128], param_1.1598: f32[4,2048,8,128]) -> (f32[], f32[]) { + %param_0.1426 = f32[4,2048,8,128]{3,2,0,1:T(8,128)S(1)} parameter(0) + %bitcast.352 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1426), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.238 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.352, %bitcast.352), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1205 = f32[]{:T(128)} constant(0) + %reduce.126 = f32[]{:T(128)} reduce(%square.238, %constant.1205), dimensions={0,1,2,3}, to_apply=%region_41.46, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1598 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} parameter(1) + %bitcast.356.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_1.1598), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.241.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.356.clone.1, %bitcast.356.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.127.clone.1 = f32[]{:T(128)} reduce(%square.241.clone.1, %constant.1205), dimensions={0,1,2,3}, to_apply=%region_36.41, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.169 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.126, %reduce.127.clone.1) +} + +%fused_computation.355 (param_0.1021: f32[2048,4,8,128]) -> bf16[4,2048,8,128] { + %param_0.1021 = f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)} parameter(0) + %copy.190 = bf16[2048,4,8,128]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.1021), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'value\'][\'kernel\']"} + ROOT %bitcast.357 = bf16[4,2048,8,128]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.190), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +} + +%region_70.75 (reduce_sum.563: f32[], reduce_sum.564: f32[]) -> f32[] { + %reduce_sum.564 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.563 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.565 = f32[]{:T(128)} add(%reduce_sum.563, %reduce_sum.564), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_55.60 (reduce_sum.485: f32[], reduce_sum.486: f32[]) -> f32[] { + %reduce_sum.486 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.485 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.487 = f32[]{:T(128)} add(%reduce_sum.485, %reduce_sum.486), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.356 (param_0.1410: f32[2048,4,8,128], param_1.1584: f32[], param_2.1326: f32[], param_3.909: f32[], param_4.548: f32[2048,4,8,128], param_5.482: f32[], param_6.357: f32[4,2048,8,128], param_7.201: pred[], param_8.119: f32[2048,4,8,128]) -> (f32[], f32[2048,4,8,128], f32[2048,4,8,128], f32[2048,4,8,128], f32[]) { + %param_0.1410 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.909 = f32[]{:T(128)S(6)} parameter(3) + %mul.2459.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.909), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.201 = pred[]{:T(512)S(6)} parameter(7) + %select_n.272.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.201), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.357 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.447.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.357), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.482 = f32[]{:T(128)} parameter(5) + %div.868.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.482), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.867.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.447.clone.1, %div.868.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.271.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.272.clone.1, %bitcast.447.clone.1, %div.867.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1086.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.760.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1086.clone.1), dimensions={}, metadata={op_name="broadcast.80"} + %mul.2463.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.271.clone.1, %broadcast.760.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.119 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(8) + %constant.1090.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.759.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1090.clone.1), dimensions={}, metadata={op_name="broadcast.79"} + %mul.2462.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.119, %broadcast.759.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.922.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.2463.clone.1, %mul.2462.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1326 = f32[]{:T(128)S(6)} parameter(2) + %div.864.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1326), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.66.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.271.clone.1, %select_n.271.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1089.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.758.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1089.clone.1), dimensions={}, metadata={op_name="broadcast.69"} + %mul.2461.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.66.clone.1, %broadcast.758.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.548 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.1088.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.757.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1088.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.2460.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.548, %broadcast.757.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.921.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.2461.clone.1, %mul.2460.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1584 = f32[]{:T(128)S(6)} parameter(1) + %div.863.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1584), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.862.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.921.clone.1, %div.863.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.63.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.862.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1087.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.755.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1087.clone.1), dimensions={}, metadata={op_name="broadcast.63"} + %add.920.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.63.clone.1, %broadcast.755.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.288.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.864.clone.1, %add.920.clone.1), metadata={op_name="multiply.45"} + %div.861.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.922.clone.1, %multiply.288.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2458.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1410, %broadcast.760.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.919.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%div.861.clone.1, %mul.2458.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2457.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.2459.clone.1, %add.919.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.918.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1410, %mul.2457.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.242 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.918.clone.1, %add.918.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1189 = f32[]{:T(128)} constant(0) + %reduce.128 = f32[]{:T(128)} reduce(%square.242, %constant.1189), dimensions={0,1,2,3}, to_apply=%region_70.75, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.130.clone.1 = f32[]{:T(128)} reduce(%integer_pow.66.clone.1, %constant.1189), dimensions={0,1,2,3}, to_apply=%region_55.60, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.151 = (f32[]{:T(128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.128, %add.918.clone.1, %add.921.clone.1, %add.922.clone.1, %reduce.130.clone.1) +} + +%region_65.70 (reduce_sum.536: f32[], reduce_sum.537: f32[]) -> f32[] { + %reduce_sum.537 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.536 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.541 = f32[]{:T(128)} add(%reduce_sum.536, %reduce_sum.537), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_50.55 (reduce_sum.458: f32[], reduce_sum.459: f32[]) -> f32[] { + %reduce_sum.459 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.458 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.460 = f32[]{:T(128)} add(%reduce_sum.458, %reduce_sum.459), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.329 (param_0.1372: f32[16,4,128,2048], param_1.1560: f32[], param_2.1318: f32[], param_3.922: f32[], param_4.560: f32[16,4,128,2048], param_5.472: f32[], param_6.362: f32[4,16,128,2048], param_7.205: pred[], param_8.122: f32[16,4,128,2048]) -> (f32[], f32[16,4,128,2048], f32[16,4,128,2048], f32[16,4,128,2048], f32[]) { - %param_0.1372 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(0) - %param_3.922 = f32[]{:T(128)S(6)} parameter(3) - %mul.1960.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_3.922), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.205 = pred[]{:T(512)S(6)} parameter(7) - %select_n.284.clone.1 = pred[16,4,128,2048]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.205), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.362 = f32[4,16,128,2048]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.472.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} bitcast(%param_6.362), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.472 = f32[]{:T(128)} parameter(5) - %div.892.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_5.472), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.891.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%bitcast.472.clone.1, %div.892.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.283.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} select(%select_n.284.clone.1, %bitcast.472.clone.1, %div.891.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} +%fused_computation.357 (param_0.1415: f32[2048,4,8,128], param_1.1589: f32[], param_2.1331: f32[], param_3.914: f32[], param_4.553: f32[2048,4,8,128], param_5.487: f32[], param_6.362: f32[4,2048,8,128], param_7.206: pred[], param_8.124: f32[2048,4,8,128]) -> (f32[], f32[2048,4,8,128], f32[2048,4,8,128], f32[2048,4,8,128], f32[]) { + %param_0.1415 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.914 = f32[]{:T(128)S(6)} parameter(3) + %mul.2500.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.914), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.206 = pred[]{:T(512)S(6)} parameter(7) + %select_n.292.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.206), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.362 = f32[4,2048,8,128]{3,2,0,1:T(8,128)S(1)} parameter(6) + %bitcast.457.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.362), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.487 = f32[]{:T(128)} parameter(5) + %div.908.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.487), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.907.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.457.clone.1, %div.908.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.291.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.292.clone.1, %bitcast.457.clone.1, %div.907.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} %constant.1116.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.860.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1116.clone.1), dimensions={}, metadata={op_name="broadcast.76"} - %mul.1966.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%select_n.283.clone.1, %broadcast.860.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_8.122 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(8) + %broadcast.782.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1116.clone.1), dimensions={}, metadata={op_name="broadcast.80"} + %mul.2504.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.291.clone.1, %broadcast.782.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.124 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(8) %constant.1120.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1967.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1120.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1965.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_8.122, %mul.1967.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.963.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%mul.1966.clone.1, %mul.1965.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1318 = f32[]{:T(128)S(6)} parameter(2) - %div.888.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_2.1318), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %integer_pow.69.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%select_n.283.clone.1, %select_n.283.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %broadcast.781.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1120.clone.1), dimensions={}, metadata={op_name="broadcast.79"} + %mul.2503.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.124, %broadcast.781.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.949.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.2504.clone.1, %mul.2503.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1331 = f32[]{:T(128)S(6)} parameter(2) + %div.904.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1331), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.71.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.291.clone.1, %select_n.291.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} %constant.1119.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1964.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1119.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1962.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%integer_pow.69.clone.1, %mul.1964.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.560 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(4) + %broadcast.780.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1119.clone.1), dimensions={}, metadata={op_name="broadcast.69"} + %mul.2502.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.71.clone.1, %broadcast.780.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.553 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(4) %constant.1118.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1963.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1118.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1961.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_4.560, %mul.1963.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.962.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%mul.1962.clone.1, %mul.1961.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1560 = f32[]{:T(128)S(6)} parameter(1) - %div.887.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_1.1560), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.886.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%add.962.clone.1, %div.887.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.66.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} sqrt(%div.886.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %broadcast.779.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1118.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.2501.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.553, %broadcast.779.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.948.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.2502.clone.1, %mul.2501.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1589 = f32[]{:T(128)S(6)} parameter(1) + %div.903.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1589), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.902.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.948.clone.1, %div.903.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.68.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.902.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} %constant.1117.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.961.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1117.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.960.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%sqrt.66.clone.1, %add.961.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.430.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%div.888.clone.1, %add.960.clone.1), metadata={op_name="multiply.57"} - %div.885.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%add.963.clone.1, %multiply.430.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1959.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_0.1372, %broadcast.860.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.959.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%div.885.clone.1, %mul.1959.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1958.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%mul.1960.clone.1, %add.959.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.958.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%param_0.1372, %mul.1958.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.235 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%add.958.clone.1, %add.958.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1204 = f32[]{:T(128)} constant(0) - %reduce.193 = f32[]{:T(128)} reduce(%square.235, %constant.1204), dimensions={0,1,2,3}, to_apply=%region_67.72, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.195.clone.1 = f32[]{:T(128)} reduce(%integer_pow.69.clone.1, %constant.1204), dimensions={0,1,2,3}, to_apply=%region_52.57, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.149 = (f32[]{:T(128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.193, %add.958.clone.1, %add.962.clone.1, %add.963.clone.1, %reduce.195.clone.1) -} - -%region_41.46 (reduce_sum.311: f32[], reduce_sum.312: f32[]) -> f32[] { - %reduce_sum.312 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.311 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.316 = f32[]{:T(128)} add(%reduce_sum.311, %reduce_sum.312), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_36.41 (reduce_sum.284: f32[], reduce_sum.288: f32[]) -> f32[] { - %reduce_sum.288 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.284 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.289 = f32[]{:T(128)} add(%reduce_sum.284, %reduce_sum.288), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.341 (param_0.1385: f32[4,2048,8,128], param_1.1571: f32[4,2048,8,128]) -> (f32[], f32[]) { - %param_0.1385 = f32[4,2048,8,128]{3,2,0,1:T(8,128)S(1)} parameter(0) - %bitcast.371 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1385), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.238 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.371, %bitcast.371), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1217 = f32[]{:T(128)} constant(0) - %reduce.196 = f32[]{:T(128)} reduce(%square.238, %constant.1217), dimensions={0,1,2,3}, to_apply=%region_41.46, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.1571 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} parameter(1) - %bitcast.375.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_1.1571), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.241.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.375.clone.1, %bitcast.375.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.197.clone.1 = f32[]{:T(128)} reduce(%square.241.clone.1, %constant.1217), dimensions={0,1,2,3}, to_apply=%region_36.41, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.168 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.196, %reduce.197.clone.1) -} - -%fused_computation.344 (param_0.982: f32[2048,4,8,128]) -> bf16[4,2048,8,128] { - %param_0.982 = f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)} parameter(0) - %copy.194 = bf16[2048,4,8,128]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.982), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'value\'][\'kernel\']"} - ROOT %bitcast.376 = bf16[4,2048,8,128]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.194), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} -} - -%region_70.75 (reduce_sum.458: f32[], reduce_sum.459: f32[]) -> f32[] { - %reduce_sum.459 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.458 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.463 = f32[]{:T(128)} add(%reduce_sum.458, %reduce_sum.459), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %broadcast.777.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1117.clone.1), dimensions={}, metadata={op_name="broadcast.63"} + %add.947.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.68.clone.1, %broadcast.777.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.293.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.904.clone.1, %add.947.clone.1), metadata={op_name="multiply.40"} + %div.901.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.949.clone.1, %multiply.293.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2499.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1415, %broadcast.782.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.946.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%div.901.clone.1, %mul.2499.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2498.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.2500.clone.1, %add.946.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.945.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1415, %mul.2498.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.243 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.945.clone.1, %add.945.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1194 = f32[]{:T(128)} constant(0) + %reduce.129 = f32[]{:T(128)} reduce(%square.243, %constant.1194), dimensions={0,1,2,3}, to_apply=%region_65.70, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.131.clone.1 = f32[]{:T(128)} reduce(%integer_pow.71.clone.1, %constant.1194), dimensions={0,1,2,3}, to_apply=%region_50.55, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.152 = (f32[]{:T(128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.129, %add.945.clone.1, %add.948.clone.1, %add.949.clone.1, %reduce.131.clone.1) +} + +%fused_computation.373 (param_0.1095: bf16[4,128,2048], param_1.1142: f32[4,128], param_2.842: f32[4,128], param_3.484: bf16[4,128,2048], param_4.283: bf16[2048]) -> bf16[4,128,2048] { + %param_3.484 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.283 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %mul.2385 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_4.283), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2359 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%param_3.484, %mul.2385), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1387 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%mul.2359), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.842 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %mul.2356 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_2.842), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2347 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1387, %mul.2356), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.1095 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1398 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1095), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_1.1142 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2354 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_1.1142), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2353 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1398, %mul.2354), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %add_any.184 = f32[4,128,2048]{2,1,0:T(8,128)} add(%mul.2347, %mul.2353), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add_any" stack_frame_id=0} + ROOT %convert_element_type.1385 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%add_any.184), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} +} + +%region_6.9 (reduce_sum.228: f32[], reduce_sum.229: f32[]) -> f32[] { + %reduce_sum.229 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.228 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.230 = f32[]{:T(128)} add(%reduce_sum.228, %reduce_sum.229), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.374 (param_0.1435: bf16[4,128,2048]) -> f32[4,128] { + %param_0.1435 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1389 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1435), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %square.246 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1389, %convert_element_type.1389), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/square" stack_frame_id=0} + %constant.1215 = f32[]{:T(128)} constant(0) + ROOT %reduce.132 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.246, %constant.1215), dimensions={2}, to_apply=%region_6.9, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} } -%region_55.60 (reduce_sum.380: f32[], reduce_sum.381: f32[]) -> f32[] { - %reduce_sum.381 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.380 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.382 = f32[]{:T(128)} add(%reduce_sum.380, %reduce_sum.381), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_12.15 (reduce_sum.275: f32[], reduce_sum.276: f32[]) -> f32[] { + %reduce_sum.276 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.275 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.277 = f32[]{:T(128)} add(%reduce_sum.275, %reduce_sum.276), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.345 (param_0.1369: f32[2048,4,8,128], param_1.1557: f32[], param_2.1315: f32[], param_3.919: f32[], param_4.557: f32[2048,4,8,128], param_5.469: f32[], param_6.359: f32[4,2048,8,128], param_7.202: pred[], param_8.119: f32[2048,4,8,128]) -> (f32[], f32[2048,4,8,128], f32[2048,4,8,128], f32[2048,4,8,128], f32[]) { - %param_0.1369 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(0) - %param_3.919 = f32[]{:T(128)S(6)} parameter(3) - %mul.1936.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.919), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.202 = pred[]{:T(512)S(6)} parameter(7) - %select_n.272.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.202), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.359 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.466.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.359), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.469 = f32[]{:T(128)} parameter(5) - %div.868.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.469), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.867.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.466.clone.1, %div.868.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.271.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.272.clone.1, %bitcast.466.clone.1, %div.867.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1098.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.850.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1098.clone.1), dimensions={}, metadata={op_name="broadcast.80"} - %mul.1940.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.271.clone.1, %broadcast.850.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_8.119 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(8) - %constant.1102.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.849.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1102.clone.1), dimensions={}, metadata={op_name="broadcast.79"} - %mul.1939.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.119, %broadcast.849.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.946.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1940.clone.1, %mul.1939.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1315 = f32[]{:T(128)S(6)} parameter(2) - %div.864.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1315), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %integer_pow.66.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.271.clone.1, %select_n.271.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1101.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.848.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1101.clone.1), dimensions={}, metadata={op_name="broadcast.69"} - %mul.1938.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.66.clone.1, %broadcast.848.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.557 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(4) - %constant.1100.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.847.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1100.clone.1), dimensions={}, metadata={op_name="broadcast.68"} - %mul.1937.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.557, %broadcast.847.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.945.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1938.clone.1, %mul.1937.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1557 = f32[]{:T(128)S(6)} parameter(1) - %div.863.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1557), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.862.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.945.clone.1, %div.863.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.63.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.862.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1099.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.845.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1099.clone.1), dimensions={}, metadata={op_name="broadcast.63"} - %add.944.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.63.clone.1, %broadcast.845.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.427.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.864.clone.1, %add.944.clone.1), metadata={op_name="multiply.60"} - %div.861.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.946.clone.1, %multiply.427.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1935.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1369, %broadcast.850.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.943.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%div.861.clone.1, %mul.1935.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1934.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1936.clone.1, %add.943.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.942.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1369, %mul.1934.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.242 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.942.clone.1, %add.942.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1201 = f32[]{:T(128)} constant(0) - %reduce.198 = f32[]{:T(128)} reduce(%square.242, %constant.1201), dimensions={0,1,2,3}, to_apply=%region_70.75, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.200.clone.1 = f32[]{:T(128)} reduce(%integer_pow.66.clone.1, %constant.1201), dimensions={0,1,2,3}, to_apply=%region_55.60, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.150 = (f32[]{:T(128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.198, %add.942.clone.1, %add.945.clone.1, %add.946.clone.1, %reduce.200.clone.1) +%fused_computation.376 (param_0.1430: bf16[4,128,2048], param_1.1601: bf16[4,128,2048], param_2.1339: bf16[2048]) -> f32[4,128] { + %param_0.1430 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1396 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1430), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_1.1601 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %param_2.1339 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2384 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1339), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2358 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1601, %mul.2384), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1395 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%mul.2358), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %mul.2351 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1396, %convert_element_type.1395), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.1209 = f32[]{:T(128)} constant(0) + ROOT %reduce.133 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.2351, %constant.1209), dimensions={2}, to_apply=%region_12.15, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} } -%region_65.70 (reduce_sum.431: f32[], reduce_sum.435: f32[]) -> f32[] { - %reduce_sum.435 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.431 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.436 = f32[]{:T(128)} add(%reduce_sum.431, %reduce_sum.435), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_10.13 (reduce_sum.263: bf16[], reduce_sum.264: bf16[]) -> bf16[] { + %reduce_sum.264 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.263 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.268 = bf16[]{:T(256)} add(%reduce_sum.263, %reduce_sum.264), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_50.55 (reduce_sum.353: f32[], reduce_sum.354: f32[]) -> f32[] { - %reduce_sum.354 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.353 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.358 = f32[]{:T(128)} add(%reduce_sum.353, %reduce_sum.354), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%fused_computation.296.clone.clone (param_0.1392: bf16[151936,2048]) -> bf16[151936,2048,1] { + %param_0.1392 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + ROOT %bitcast.505 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%param_0.1392), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} } -%fused_computation.346 (param_0.1374: f32[2048,4,8,128], param_1.1562: f32[], param_2.1320: f32[], param_3.924: f32[], param_4.562: f32[2048,4,8,128], param_5.474: f32[], param_6.364: f32[4,2048,8,128], param_7.207: pred[], param_8.124: f32[2048,4,8,128]) -> (f32[], f32[2048,4,8,128], f32[2048,4,8,128], f32[2048,4,8,128], f32[]) { - %param_0.1374 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(0) - %param_3.924 = f32[]{:T(128)S(6)} parameter(3) - %mul.1977.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.924), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.207 = pred[]{:T(512)S(6)} parameter(7) - %select_n.292.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.207), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.364 = f32[4,2048,8,128]{3,2,0,1:T(8,128)S(1)} parameter(6) - %bitcast.476.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.364), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.474 = f32[]{:T(128)} parameter(5) - %div.908.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.474), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.907.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.476.clone.1, %div.908.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.291.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.292.clone.1, %bitcast.476.clone.1, %div.907.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1128.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.872.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1128.clone.1), dimensions={}, metadata={op_name="broadcast.80"} - %mul.1981.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.291.clone.1, %broadcast.872.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_8.124 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(8) - %constant.1132.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.871.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1132.clone.1), dimensions={}, metadata={op_name="broadcast.79"} - %mul.1980.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.124, %broadcast.871.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.973.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1981.clone.1, %mul.1980.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1320 = f32[]{:T(128)S(6)} parameter(2) - %div.904.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1320), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %integer_pow.71.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.291.clone.1, %select_n.291.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1131.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.870.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1131.clone.1), dimensions={}, metadata={op_name="broadcast.69"} - %mul.1979.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.71.clone.1, %broadcast.870.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.562 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(4) - %constant.1130.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.869.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1130.clone.1), dimensions={}, metadata={op_name="broadcast.68"} - %mul.1978.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.562, %broadcast.869.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.972.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1979.clone.1, %mul.1978.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1562 = f32[]{:T(128)S(6)} parameter(1) - %div.903.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1562), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.902.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.972.clone.1, %div.903.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.68.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.902.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1129.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.867.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1129.clone.1), dimensions={}, metadata={op_name="broadcast.63"} - %add.971.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.68.clone.1, %broadcast.867.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.432.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.904.clone.1, %add.971.clone.1), metadata={op_name="multiply.55"} - %div.901.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.973.clone.1, %multiply.432.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1976.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1374, %broadcast.872.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.970.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%div.901.clone.1, %mul.1976.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1975.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1977.clone.1, %add.970.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.969.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1374, %mul.1975.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.243 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.969.clone.1, %add.969.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1206 = f32[]{:T(128)} constant(0) - %reduce.199 = f32[]{:T(128)} reduce(%square.243, %constant.1206), dimensions={0,1,2,3}, to_apply=%region_65.70, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.201.clone.1 = f32[]{:T(128)} reduce(%integer_pow.71.clone.1, %constant.1206), dimensions={0,1,2,3}, to_apply=%region_50.55, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.151 = (f32[]{:T(128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.199, %add.969.clone.1, %add.972.clone.1, %add.973.clone.1, %reduce.201.clone.1) -} - -%fused_computation.362 (param_0.1056: bf16[4,128,2048], param_1.1117: f32[4,128], param_2.830: f32[4,128], param_3.495: bf16[4,128,2048], param_4.296: bf16[2048]) -> bf16[4,128,2048] { - %param_3.495 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %param_4.296 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %dot_general.448 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_4.296), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %dot_general.438 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%param_3.495, %dot_general.448), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %convert_element_type.1363 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%dot_general.438), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_2.830 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %mul.1851 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_2.830), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1843 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1363, %mul.1851), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %param_0.1056 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert_element_type.1374 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1056), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_1.1117 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %mul.1850 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_1.1117), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1849 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1374, %mul.1850), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %add_any.193 = f32[4,128,2048]{2,1,0:T(8,128)} add(%mul.1843, %mul.1849), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add_any" stack_frame_id=0} - ROOT %convert_element_type.1361 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%add_any.193), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} -} - -%region_7.10 (reduce_sum.171: f32[], reduce_sum.184: f32[]) -> f32[] { - %reduce_sum.184 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - %reduce_sum.171 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - ROOT %reduce_sum.185 = f32[]{:T(128)} add(%reduce_sum.171, %reduce_sum.184), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.363 (param_0.1394: bf16[4,128,2048]) -> f32[4,128] { - %param_0.1394 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert_element_type.1365 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1394), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %square.246 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1365, %convert_element_type.1365), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/square" stack_frame_id=0} - %constant.1227 = f32[]{:T(128)} constant(0) - ROOT %reduce.202 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.246, %constant.1227), dimensions={2}, to_apply=%region_7.10, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} -} - -%region_12.15 (reduce_sum.198: f32[], reduce_sum.199: f32[]) -> f32[] { - %reduce_sum.199 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - %reduce_sum.198 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - ROOT %reduce_sum.200 = f32[]{:T(128)} add(%reduce_sum.198, %reduce_sum.199), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.365 (param_0.1389: bf16[4,128,2048], param_1.1574: bf16[4,128,2048], param_2.1328: bf16[2048]) -> f32[4,128] { - %param_0.1389 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert_element_type.1372 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1389), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_1.1574 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %param_2.1328 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.447 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1328), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %dot_general.437 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1574, %dot_general.447), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %convert_element_type.1371 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%dot_general.437), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %mul.1847 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1372, %convert_element_type.1371), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %constant.1221 = f32[]{:T(128)} constant(0) - ROOT %reduce.203 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1847, %constant.1221), dimensions={2}, to_apply=%region_12.15, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} -} - -%region_10.13 (dot_general.190: bf16[], dot_general.191: bf16[]) -> bf16[] { - %dot_general.191 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general"} - %dot_general.190 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general"} - ROOT %add.419 = bf16[]{:T(256)} add(%dot_general.190, %dot_general.191), metadata={op_name="add.82"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.285.clone.clone (param_0.1351: bf16[151936,2048]) -> bf16[151936,2048,1] { - %param_0.1351 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) - ROOT %bitcast.528 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%param_0.1351), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} -} - -%fused_computation.289.clone.1.clone.clone (param_0.1352: bf16[4,128,151936], param_1.1546: s32[4,128], param_2.1285: f32[4,128], param_3.906: f32[4,128], param_4.542: bf16[4,128], param_5.442: f32[4,128]) -> bf16[4,128,151936] { - %param_5.442 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %mul.2075 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_5.442), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_3.906 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %mul.2074 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.906), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_0.1352 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.1444 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1352), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_4.542 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) - %sub.88 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_4.542), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.87 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1444, %sub.88), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %exp.60 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.87), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %mul.2073 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2074, %exp.60), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_2.1285 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %div.962 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1285), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %div.961 = f32[4,128,151936]{2,1,0:T(8,128)} divide(%mul.2073, %div.962), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %param_1.1546 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %eq.43 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1546), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} +%fused_computation.300.clone.1.clone.clone (param_0.1393: bf16[4,128,151936], param_1.1573: s32[4,128], param_2.1296: f32[4,128], param_3.896: f32[4,128], param_4.533: bf16[4,128], param_5.455: f32[4,128]) -> bf16[4,128,151936] { + %param_5.455 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.2616 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_5.455), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.896 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.2615 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.896), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.1393 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1468 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1393), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_4.533 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.86 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_4.533), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.85 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1468, %sub.86), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %exp.60 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.85), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} + %mul.2614 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2615, %exp.60), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.1296 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.962 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1296), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.961 = f32[4,128,151936]{2,1,0:T(8,128)} divide(%mul.2614, %div.962), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.1573 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.43 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1573), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.42 = s32[4,128,151936]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.41 = pred[4,128,151936]{2,1,0:T(8,128)(4,1)} compare(%eq.43, %eq.42), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %convert_element_type.1443 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%eq.41), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} - %sub.86 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%div.961, %convert_element_type.1443), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} - %mul.2072 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2075, %sub.86), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - ROOT %convert_element_type.1442 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convert(%mul.2072), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} -} - -%fused_computation.366 (param_0.1350: f32[4,128], param_1.1545: bf16[4,128,2048], param_2.1286: bf16[151936,2048], param_3.907: bf16[4,128,151936], param_4.543: s32[4,128], param_5.443: f32[4,128], param_6.340: f32[4,128], param_7.199: bf16[4,128], param_8.116: f32[4,128]) -> (bf16[2048], bf16[4,128,2048]) { - %param_3.907 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(3) - %param_4.543 = s32[4,128]{1,0:T(4,128)S(1)} parameter(4) - %param_5.443 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %param_6.340 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) - %param_7.199 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(7) + %convert_element_type.1467 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%eq.41), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.84 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%div.961, %convert_element_type.1467), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.2613 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2616, %sub.84), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.1466 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convert(%mul.2613), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} +} + +%fused_computation.377 (param_0.1391: f32[4,128], param_1.1572: bf16[4,128,2048], param_2.1297: bf16[151936,2048], param_3.897: bf16[4,128,151936], param_4.534: s32[4,128], param_5.456: f32[4,128], param_6.342: f32[4,128], param_7.198: bf16[4,128], param_8.116: f32[4,128]) -> (bf16[2048], bf16[4,128,2048]) { + %param_1.1572 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1408 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1572), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1391 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2373 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1391), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2372 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1408, %mul.2373), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1407 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2372), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_3.897 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(3) + %param_4.534 = s32[4,128]{1,0:T(4,128)S(1)} parameter(4) + %param_5.456 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %param_6.342 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_7.198 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(7) %param_8.116 = f32[4,128]{1,0:T(4,128)S(1)} parameter(8) - %multiply_convert_fusion.2.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} fusion(%param_3.907, %param_4.543, %param_5.443, %param_6.340, %param_7.199, /*index=5*/%param_8.116), kind=kLoop, calls=%fused_computation.289.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_2.1286 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(2) - %fusion.251.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} fusion(%param_2.1286), kind=kLoop, calls=%fused_computation.285.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %convolution.84.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} convolution(%multiply_convert_fusion.2.clone.1, %fusion.251.clone.1), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} - %param_1.1545 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1384 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1545), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_0.1350 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1862 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1350), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1861 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1384, %mul.1862), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.1383 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.1861), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %multiply.420 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convolution.84.clone.1, %convert_element_type.1383), metadata={op_name="multiply.362"} - %constant.1050 = bf16[]{:T(256)} constant(0) - %reduce.204 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%multiply.420, %constant.1050), dimensions={0,1}, to_apply=%region_10.13, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %tuple.165 = (bf16[2048]{0:T(1024)(128)(2,1)S(1)}, bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.204, %convolution.84.clone.1) -} - -%fused_computation.374 (param_0.1088: f32[64], param_1.1150: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { - %param_1.1150 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %div.720 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1150), dimensions={0,1}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} - %param_0.1088 = f32[64]{0:T(128)S(1)} parameter(0) - %div.718 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1088), dimensions={3}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + %multiply_convert_fusion.3.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} fusion(%param_3.897, %param_4.534, %param_5.456, %param_6.342, %param_7.198, /*index=5*/%param_8.116), kind=kLoop, calls=%fused_computation.300.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_2.1297 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(2) + %fusion.261.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} fusion(%param_2.1297), kind=kLoop, calls=%fused_computation.296.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %convolution.84.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} convolution(%multiply_convert_fusion.3.clone.1, %fusion.261.clone.1), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %mul.2355 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1407, %convolution.84.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.1051 = bf16[]{:T(256)} constant(0) + %reduce.134 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%mul.2355, %constant.1051), dimensions={0,1}, to_apply=%region_10.13, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} + ROOT %tuple.166 = (bf16[2048]{0:T(1024)(128)(2,1)S(1)}, bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.134, %convolution.84.clone.1) +} + +%fused_computation.385 (param_0.1129: f32[64], param_1.1177: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.1177 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %div.720 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1177), dimensions={0,1}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + %param_0.1129 = f32[64]{0:T(128)S(1)} parameter(0) + %div.718 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1129), dimensions={3}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} %div.717 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.720, %div.718), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} %sin.38 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.717), metadata={op_name="jit(train_step)/layers/sin" stack_frame_id=0} - %convert_element_type.1392 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.38), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %convert_element_type.1416 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.38), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} %cos.41.clone.1 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.717), metadata={op_name="jit(train_step)/layers/cos" stack_frame_id=0} - %convert_element_type.1391.clone.1 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.41.clone.1), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} - ROOT %tuple.158 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1392, %convert_element_type.1391.clone.1) + %convert_element_type.1415.clone.1 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.41.clone.1), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.159 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1416, %convert_element_type.1415.clone.1) } -%fused_computation.375 (param_0.1085: bf16[4,128,1,64]) -> bf16[4,128,1,128] { - %param_0.1085 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.1042 = bf16[]{:T(256)} constant(-inf) - %pad.46 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1085, %constant.1042), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} - %pad.45 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1085, %constant.1042), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} +%fused_computation.386 (param_0.1126: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.1126 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1030 = bf16[]{:T(256)} constant(-inf) + %pad.46 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1126, %constant.1030), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + %pad.45 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1126, %constant.1030), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} ROOT %maximum.42 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.46, %pad.45), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} } -%fused_computation.376 (param_0.1087: bf16[4,128,1,64]) -> bf16[4,128,1,128] { - %param_0.1087 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.1041 = bf16[]{:T(256)} constant(-inf) - %pad.48 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1087, %constant.1041), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} - %pad.47 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1087, %constant.1041), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} +%fused_computation.387 (param_0.1128: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.1128 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1029 = bf16[]{:T(256)} constant(-inf) + %pad.48 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1128, %constant.1029), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + %pad.47 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1128, %constant.1029), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} ROOT %maximum.43 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.48, %pad.47), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} } -%region_35.40 (reduce_sum.281: f32[], reduce_sum.282: f32[]) -> f32[] { - %reduce_sum.282 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.281 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.283 = f32[]{:T(128)} add(%reduce_sum.281, %reduce_sum.282), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_35.40 (reduce_sum.383: f32[], reduce_sum.387: f32[]) -> f32[] { + %reduce_sum.387 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.383 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.388 = f32[]{:T(128)} add(%reduce_sum.383, %reduce_sum.387), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_34.39 (reduce_sum.275: f32[], reduce_sum.276: f32[]) -> f32[] { - %reduce_sum.276 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.275 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.277 = f32[]{:T(128)} add(%reduce_sum.275, %reduce_sum.276), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_34.39 (reduce_sum.380: f32[], reduce_sum.381: f32[]) -> f32[] { + %reduce_sum.381 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.380 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.382 = f32[]{:T(128)} add(%reduce_sum.380, %reduce_sum.381), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.380 (param_0.1386: f32[4,2048], param_1.1572: f32[4,2048]) -> (f32[], f32[]) { - %param_0.1386 = f32[4,2048]{1,0:T(4,128)S(1)} parameter(0) - %bitcast.404 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_0.1386), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.249 = f32[2048,4]{0,1:T(4,128)} multiply(%bitcast.404, %bitcast.404), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1218 = f32[]{:T(128)} constant(0) - %reduce.205 = f32[]{:T(128)} reduce(%square.249, %constant.1218), dimensions={0,1}, to_apply=%region_35.40, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.1572 = f32[4,2048]{1,0:T(4,128)} parameter(1) - %bitcast.408.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_1.1572), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.252.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%bitcast.408.clone.1, %bitcast.408.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.206.clone.1 = f32[]{:T(128)} reduce(%square.252.clone.1, %constant.1218), dimensions={0,1}, to_apply=%region_34.39, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.169 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.205, %reduce.206.clone.1) +%fused_computation.391 (param_0.1427: f32[4,2048], param_1.1599: f32[4,2048]) -> (f32[], f32[]) { + %param_0.1427 = f32[4,2048]{1,0:T(4,128)S(1)} parameter(0) + %bitcast.385 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_0.1427), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.249 = f32[2048,4]{0,1:T(4,128)} multiply(%bitcast.385, %bitcast.385), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1206 = f32[]{:T(128)} constant(0) + %reduce.135 = f32[]{:T(128)} reduce(%square.249, %constant.1206), dimensions={0,1}, to_apply=%region_35.40, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1599 = f32[4,2048]{1,0:T(4,128)} parameter(1) + %bitcast.389.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_1.1599), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.252.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%bitcast.389.clone.1, %bitcast.389.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.136.clone.1 = f32[]{:T(128)} reduce(%square.252.clone.1, %constant.1206), dimensions={0,1}, to_apply=%region_34.39, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.170 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.135, %reduce.136.clone.1) } -%region_64.69 (reduce_sum.428: f32[], reduce_sum.429: f32[]) -> f32[] { - %reduce_sum.429 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.428 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.430 = f32[]{:T(128)} add(%reduce_sum.428, %reduce_sum.429), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_64.69 (reduce_sum.530: f32[], reduce_sum.534: f32[]) -> f32[] { + %reduce_sum.534 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.530 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.535 = f32[]{:T(128)} add(%reduce_sum.530, %reduce_sum.534), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_49.54 (reduce_sum.347: f32[], reduce_sum.351: f32[]) -> f32[] { - %reduce_sum.351 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.347 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.352 = f32[]{:T(128)} add(%reduce_sum.347, %reduce_sum.351), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_49.54 (reduce_sum.452: f32[], reduce_sum.453: f32[]) -> f32[] { + %reduce_sum.453 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.452 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.457 = f32[]{:T(128)} add(%reduce_sum.452, %reduce_sum.453), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.383 (param_0.1375: f32[2048,4], param_1.1563: f32[], param_2.1321: f32[], param_3.925: f32[], param_4.563: f32[2048,4], param_5.475: f32[], param_6.365: f32[4,2048], param_7.208: pred[], param_8.125: f32[2048,4]) -> (f32[], f32[2048,4], f32[2048,4], f32[2048,4], f32[]) { - %param_0.1375 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(0) - %param_3.925 = f32[]{:T(128)S(6)} parameter(3) - %mul.1984.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_3.925), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.208 = pred[]{:T(512)S(6)} parameter(7) - %select_n.296.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.208), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.365 = f32[4,2048]{1,0:T(4,128)S(1)} parameter(6) - %bitcast.478.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_6.365), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.475 = f32[]{:T(128)} parameter(5) - %div.916.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_5.475), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.915.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%bitcast.478.clone.1, %div.916.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.295.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%select_n.296.clone.1, %bitcast.478.clone.1, %div.915.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1134.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.878.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1134.clone.1), dimensions={}, metadata={op_name="broadcast.82"} - %mul.1988.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.295.clone.1, %broadcast.878.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.394 (param_0.1416: f32[2048,4], param_1.1590: f32[], param_2.1332: f32[], param_3.915: f32[], param_4.554: f32[2048,4], param_5.488: f32[], param_6.363: f32[4,2048], param_7.207: pred[], param_8.125: f32[2048,4]) -> (f32[], f32[2048,4], f32[2048,4], f32[2048,4], f32[]) { + %param_0.1416 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.915 = f32[]{:T(128)S(6)} parameter(3) + %mul.2507.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_3.915), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.207 = pred[]{:T(512)S(6)} parameter(7) + %select_n.296.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.207), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.363 = f32[4,2048]{1,0:T(4,128)S(1)} parameter(6) + %bitcast.459.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_6.363), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.488 = f32[]{:T(128)} parameter(5) + %div.916.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_5.488), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.915.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%bitcast.459.clone.1, %div.916.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.295.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%select_n.296.clone.1, %bitcast.459.clone.1, %div.915.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1122.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.788.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1122.clone.1), dimensions={}, metadata={op_name="broadcast.82"} + %mul.2511.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.295.clone.1, %broadcast.788.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.125 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(8) - %constant.1138.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.877.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1138.clone.1), dimensions={}, metadata={op_name="broadcast.81"} - %mul.1987.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_8.125, %broadcast.877.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.978.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.1988.clone.1, %mul.1987.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1321 = f32[]{:T(128)S(6)} parameter(2) - %div.912.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_2.1321), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1126.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.787.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1126.clone.1), dimensions={}, metadata={op_name="broadcast.81"} + %mul.2510.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_8.125, %broadcast.787.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.954.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.2511.clone.1, %mul.2510.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1332 = f32[]{:T(128)S(6)} parameter(2) + %div.912.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_2.1332), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.72.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.295.clone.1, %select_n.295.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1137.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.876.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1137.clone.1), dimensions={}, metadata={op_name="broadcast.71"} - %mul.1986.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%integer_pow.72.clone.1, %broadcast.876.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.563 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(4) - %constant.1136.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.875.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1136.clone.1), dimensions={}, metadata={op_name="broadcast.70"} - %mul.1985.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_4.563, %broadcast.875.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.977.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.1986.clone.1, %mul.1985.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1563 = f32[]{:T(128)S(6)} parameter(1) - %div.911.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_1.1563), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.910.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.977.clone.1, %div.911.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1125.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.786.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1125.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.2509.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%integer_pow.72.clone.1, %broadcast.786.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.554 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1124.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.785.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1124.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.2508.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_4.554, %broadcast.785.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.953.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.2509.clone.1, %mul.2508.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1590 = f32[]{:T(128)S(6)} parameter(1) + %div.911.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_1.1590), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.910.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.953.clone.1, %div.911.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.69.clone.1 = f32[2048,4]{0,1:T(4,128)} sqrt(%div.910.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1135.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.873.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1135.clone.1), dimensions={}, metadata={op_name="broadcast.64"} - %add.976.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%sqrt.69.clone.1, %broadcast.873.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.433.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%div.912.clone.1, %add.976.clone.1), metadata={op_name="multiply.54"} - %div.909.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.978.clone.1, %multiply.433.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1983.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_0.1375, %broadcast.878.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.975.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%div.909.clone.1, %mul.1983.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1982.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%mul.1984.clone.1, %add.975.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.974.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%param_0.1375, %mul.1982.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.253 = f32[2048,4]{0,1:T(4,128)} multiply(%add.974.clone.1, %add.974.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1207 = f32[]{:T(128)} constant(0) - %reduce.207 = f32[]{:T(128)} reduce(%square.253, %constant.1207), dimensions={0,1}, to_apply=%region_64.69, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.209.clone.1 = f32[]{:T(128)} reduce(%integer_pow.72.clone.1, %constant.1207), dimensions={0,1}, to_apply=%region_49.54, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.152 = (f32[]{:T(128)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.207, %add.974.clone.1, %add.977.clone.1, %add.978.clone.1, %reduce.209.clone.1) -} - -%region_63.68 (reduce_sum.422: f32[], reduce_sum.423: f32[]) -> f32[] { - %reduce_sum.423 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.422 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.424 = f32[]{:T(128)} add(%reduce_sum.422, %reduce_sum.423), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_48.53 (reduce_sum.344: f32[], reduce_sum.345: f32[]) -> f32[] { - %reduce_sum.345 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.344 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.346 = f32[]{:T(128)} add(%reduce_sum.344, %reduce_sum.345), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.1123.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.783.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1123.clone.1), dimensions={}, metadata={op_name="broadcast.64"} + %add.952.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%sqrt.69.clone.1, %broadcast.783.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.294.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%div.912.clone.1, %add.952.clone.1), metadata={op_name="multiply.39"} + %div.909.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.954.clone.1, %multiply.294.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2506.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_0.1416, %broadcast.788.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.951.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%div.909.clone.1, %mul.2506.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2505.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%mul.2507.clone.1, %add.951.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.950.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%param_0.1416, %mul.2505.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.253 = f32[2048,4]{0,1:T(4,128)} multiply(%add.950.clone.1, %add.950.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1195 = f32[]{:T(128)} constant(0) + %reduce.137 = f32[]{:T(128)} reduce(%square.253, %constant.1195), dimensions={0,1}, to_apply=%region_64.69, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.139.clone.1 = f32[]{:T(128)} reduce(%integer_pow.72.clone.1, %constant.1195), dimensions={0,1}, to_apply=%region_49.54, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.153 = (f32[]{:T(128)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.137, %add.950.clone.1, %add.953.clone.1, %add.954.clone.1, %reduce.139.clone.1) +} + +%region_63.68 (reduce_sum.527: f32[], reduce_sum.528: f32[]) -> f32[] { + %reduce_sum.528 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.527 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.529 = f32[]{:T(128)} add(%reduce_sum.527, %reduce_sum.528), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_48.53 (reduce_sum.446: f32[], reduce_sum.450: f32[]) -> f32[] { + %reduce_sum.450 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.446 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.451 = f32[]{:T(128)} add(%reduce_sum.446, %reduce_sum.450), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.384 (param_0.1376: f32[2048,4], param_1.1564: f32[], param_2.1322: f32[], param_3.926: f32[], param_4.564: f32[2048,4], param_5.476: f32[], param_6.366: f32[4,2048], param_7.209: pred[], param_8.126: f32[2048,4]) -> (f32[], f32[2048,4], f32[2048,4], f32[2048,4], f32[]) { - %param_0.1376 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(0) - %param_3.926 = f32[]{:T(128)S(6)} parameter(3) - %mul.1991.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_3.926), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.209 = pred[]{:T(512)S(6)} parameter(7) - %select_n.300.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.209), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.366 = f32[4,2048]{1,0:T(4,128)} parameter(6) - %bitcast.480.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_6.366), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.476 = f32[]{:T(128)} parameter(5) - %div.924.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_5.476), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.923.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%bitcast.480.clone.1, %div.924.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.299.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%select_n.300.clone.1, %bitcast.480.clone.1, %div.923.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1140.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.884.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1140.clone.1), dimensions={}, metadata={op_name="broadcast.82"} - %mul.1995.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.299.clone.1, %broadcast.884.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.395 (param_0.1417: f32[2048,4], param_1.1591: f32[], param_2.1333: f32[], param_3.916: f32[], param_4.555: f32[2048,4], param_5.489: f32[], param_6.364: f32[4,2048], param_7.208: pred[], param_8.126: f32[2048,4]) -> (f32[], f32[2048,4], f32[2048,4], f32[2048,4], f32[]) { + %param_0.1417 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.916 = f32[]{:T(128)S(6)} parameter(3) + %mul.2514.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_3.916), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.208 = pred[]{:T(512)S(6)} parameter(7) + %select_n.300.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.208), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.364 = f32[4,2048]{1,0:T(4,128)} parameter(6) + %bitcast.461.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_6.364), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.489 = f32[]{:T(128)} parameter(5) + %div.924.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_5.489), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.923.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%bitcast.461.clone.1, %div.924.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.299.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%select_n.300.clone.1, %bitcast.461.clone.1, %div.923.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1128.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.794.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1128.clone.1), dimensions={}, metadata={op_name="broadcast.82"} + %mul.2518.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.299.clone.1, %broadcast.794.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.126 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(8) - %constant.1144.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.883.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1144.clone.1), dimensions={}, metadata={op_name="broadcast.81"} - %mul.1994.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_8.126, %broadcast.883.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.983.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.1995.clone.1, %mul.1994.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1322 = f32[]{:T(128)S(6)} parameter(2) - %div.920.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_2.1322), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1132.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.793.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1132.clone.1), dimensions={}, metadata={op_name="broadcast.81"} + %mul.2517.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_8.126, %broadcast.793.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.959.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.2518.clone.1, %mul.2517.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1333 = f32[]{:T(128)S(6)} parameter(2) + %div.920.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_2.1333), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.73.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.299.clone.1, %select_n.299.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1143.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.882.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1143.clone.1), dimensions={}, metadata={op_name="broadcast.71"} - %mul.1993.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%integer_pow.73.clone.1, %broadcast.882.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.564 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(4) - %constant.1142.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.881.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1142.clone.1), dimensions={}, metadata={op_name="broadcast.70"} - %mul.1992.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_4.564, %broadcast.881.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.982.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.1993.clone.1, %mul.1992.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1564 = f32[]{:T(128)S(6)} parameter(1) - %div.919.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_1.1564), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.918.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.982.clone.1, %div.919.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1131.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.792.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1131.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.2516.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%integer_pow.73.clone.1, %broadcast.792.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.555 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1130.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.791.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1130.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.2515.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_4.555, %broadcast.791.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.958.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.2516.clone.1, %mul.2515.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1591 = f32[]{:T(128)S(6)} parameter(1) + %div.919.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_1.1591), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.918.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.958.clone.1, %div.919.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.70.clone.1 = f32[2048,4]{0,1:T(4,128)} sqrt(%div.918.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1141.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.879.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1141.clone.1), dimensions={}, metadata={op_name="broadcast.64"} - %add.981.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%sqrt.70.clone.1, %broadcast.879.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.434.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%div.920.clone.1, %add.981.clone.1), metadata={op_name="multiply.53"} - %div.917.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.983.clone.1, %multiply.434.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1990.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_0.1376, %broadcast.884.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.980.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%div.917.clone.1, %mul.1990.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1989.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%mul.1991.clone.1, %add.980.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.979.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%param_0.1376, %mul.1989.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.254 = f32[2048,4]{0,1:T(4,128)} multiply(%add.979.clone.1, %add.979.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1208 = f32[]{:T(128)} constant(0) - %reduce.208 = f32[]{:T(128)} reduce(%square.254, %constant.1208), dimensions={0,1}, to_apply=%region_63.68, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.210.clone.1 = f32[]{:T(128)} reduce(%integer_pow.73.clone.1, %constant.1208), dimensions={0,1}, to_apply=%region_48.53, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.153 = (f32[]{:T(128)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.208, %add.979.clone.1, %add.982.clone.1, %add.983.clone.1, %reduce.210.clone.1) + %constant.1129.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.789.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1129.clone.1), dimensions={}, metadata={op_name="broadcast.64"} + %add.957.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%sqrt.70.clone.1, %broadcast.789.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.295.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%div.920.clone.1, %add.957.clone.1), metadata={op_name="multiply.38"} + %div.917.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.959.clone.1, %multiply.295.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2513.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_0.1417, %broadcast.794.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.956.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%div.917.clone.1, %mul.2513.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2512.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%mul.2514.clone.1, %add.956.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.955.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%param_0.1417, %mul.2512.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.254 = f32[2048,4]{0,1:T(4,128)} multiply(%add.955.clone.1, %add.955.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1196 = f32[]{:T(128)} constant(0) + %reduce.138 = f32[]{:T(128)} reduce(%square.254, %constant.1196), dimensions={0,1}, to_apply=%region_63.68, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.140.clone.1 = f32[]{:T(128)} reduce(%integer_pow.73.clone.1, %constant.1196), dimensions={0,1}, to_apply=%region_48.53, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.154 = (f32[]{:T(128)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.138, %add.955.clone.1, %add.958.clone.1, %add.959.clone.1, %reduce.140.clone.1) +} + +%region_11.14 (reduce_sum.269: f32[], reduce_sum.270: f32[]) -> f32[] { + %reduce_sum.270 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.269 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.271 = f32[]{:T(128)} add(%reduce_sum.269, %reduce_sum.270), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_11.14 (reduce_sum.192: f32[], reduce_sum.193: f32[]) -> f32[] { - %reduce_sum.193 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.192 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.197 = f32[]{:T(128)} add(%reduce_sum.192, %reduce_sum.193), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%fused_computation.406 (param_0.1431: bf16[2048]) -> f32[] { + %param_0.1431 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %convert_element_type.1420 = f32[2048]{0:T(1024)} convert(%param_0.1431), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %square.257 = f32[2048]{0:T(1024)} multiply(%convert_element_type.1420, %convert_element_type.1420), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1210 = f32[]{:T(128)} constant(0) + ROOT %reduce.141 = f32[]{:T(128)} reduce(%square.257, %constant.1210), dimensions={0}, to_apply=%region_11.14, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%fused_computation.395 (param_0.1390: bf16[2048]) -> f32[] { - %param_0.1390 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(0) - %convert_element_type.1396 = f32[2048]{0:T(1024)} convert(%param_0.1390), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %square.257 = f32[2048]{0:T(1024)} multiply(%convert_element_type.1396, %convert_element_type.1396), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1222 = f32[]{:T(128)} constant(0) - ROOT %reduce.211 = f32[]{:T(128)} reduce(%square.257, %constant.1222), dimensions={0}, to_apply=%region_11.14, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%region_59.64 (reduce_sum.506: f32[], reduce_sum.507: f32[]) -> f32[] { + %reduce_sum.507 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.506 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.508 = f32[]{:T(128)} add(%reduce_sum.506, %reduce_sum.507), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_59.64 (reduce_sum.401: f32[], reduce_sum.402: f32[]) -> f32[] { - %reduce_sum.402 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.401 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.403 = f32[]{:T(128)} add(%reduce_sum.401, %reduce_sum.402), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_44.49 (reduce_sum.425: f32[], reduce_sum.429: f32[]) -> f32[] { + %reduce_sum.429 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.425 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.430 = f32[]{:T(128)} add(%reduce_sum.425, %reduce_sum.429), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_44.49 (reduce_sum.323: f32[], reduce_sum.324: f32[]) -> f32[] { - %reduce_sum.324 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.323 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.325 = f32[]{:T(128)} add(%reduce_sum.323, %reduce_sum.324), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.396 (param_0.1380: f32[2048], param_1.1568: f32[], param_2.1326: f32[], param_3.930: f32[], param_4.568: f32[2048], param_5.480: f32[], param_6.370: bf16[2048], param_7.213: pred[], param_8.130: f32[2048]) -> (f32[], f32[2048], f32[2048], f32[2048], f32[]) { - %param_0.1380 = f32[2048]{0:T(1024)S(1)} parameter(0) - %param_3.930 = f32[]{:T(128)S(6)} parameter(3) - %mul.2022.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_3.930), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.213 = pred[]{:T(512)S(6)} parameter(7) - %select_n.316.clone.1 = pred[2048]{0:T(1024)(128)(4,1)} broadcast(%param_7.213), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.370 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(6) - %convert_element_type.1411.clone.1 = f32[2048]{0:T(1024)} convert(%param_6.370), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_5.480 = f32[]{:T(128)} parameter(5) - %div.956.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_5.480), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.955.clone.1 = f32[2048]{0:T(1024)} divide(%convert_element_type.1411.clone.1, %div.956.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.315.clone.1 = f32[2048]{0:T(1024)} select(%select_n.316.clone.1, %convert_element_type.1411.clone.1, %div.955.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1164.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.900.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1164.clone.1), dimensions={}, metadata={op_name="broadcast.86"} - %mul.2028.clone.1 = f32[2048]{0:T(1024)} multiply(%select_n.315.clone.1, %broadcast.900.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.407 (param_0.1421: f32[2048], param_1.1595: f32[], param_2.1337: f32[], param_3.920: f32[], param_4.559: f32[2048], param_5.493: f32[], param_6.368: bf16[2048], param_7.212: pred[], param_8.130: f32[2048]) -> (f32[], f32[2048], f32[2048], f32[2048], f32[]) { + %param_0.1421 = f32[2048]{0:T(1024)S(1)} parameter(0) + %param_3.920 = f32[]{:T(128)S(6)} parameter(3) + %mul.2545.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_3.920), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.212 = pred[]{:T(512)S(6)} parameter(7) + %select_n.316.clone.1 = pred[2048]{0:T(1024)(128)(4,1)} broadcast(%param_7.212), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.368 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(6) + %convert_element_type.1435.clone.1 = f32[2048]{0:T(1024)} convert(%param_6.368), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_5.493 = f32[]{:T(128)} parameter(5) + %div.956.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_5.493), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.955.clone.1 = f32[2048]{0:T(1024)} divide(%convert_element_type.1435.clone.1, %div.956.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.315.clone.1 = f32[2048]{0:T(1024)} select(%select_n.316.clone.1, %convert_element_type.1435.clone.1, %div.955.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1152.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.810.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1152.clone.1), dimensions={}, metadata={op_name="broadcast.86"} + %mul.2551.clone.1 = f32[2048]{0:T(1024)} multiply(%select_n.315.clone.1, %broadcast.810.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.130 = f32[2048]{0:T(1024)S(1)} parameter(8) - %constant.1168.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.2029.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1168.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.2027.clone.1 = f32[2048]{0:T(1024)} multiply(%param_8.130, %mul.2029.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.1005.clone.1 = f32[2048]{0:T(1024)S(1)} add(%mul.2028.clone.1, %mul.2027.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1326 = f32[]{:T(128)S(6)} parameter(2) - %div.952.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_2.1326), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1156.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2552.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1156.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2550.clone.1 = f32[2048]{0:T(1024)} multiply(%param_8.130, %mul.2552.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.981.clone.1 = f32[2048]{0:T(1024)S(1)} add(%mul.2551.clone.1, %mul.2550.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1337 = f32[]{:T(128)S(6)} parameter(2) + %div.952.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_2.1337), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.77.clone.1 = f32[2048]{0:T(1024)} multiply(%select_n.315.clone.1, %select_n.315.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1167.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.2026.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1167.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.2024.clone.1 = f32[2048]{0:T(1024)} multiply(%integer_pow.77.clone.1, %mul.2026.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.568 = f32[2048]{0:T(1024)S(1)} parameter(4) - %constant.1166.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.2025.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1166.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.2023.clone.1 = f32[2048]{0:T(1024)} multiply(%param_4.568, %mul.2025.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.1004.clone.1 = f32[2048]{0:T(1024)S(1)} add(%mul.2024.clone.1, %mul.2023.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1568 = f32[]{:T(128)S(6)} parameter(1) - %div.951.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_1.1568), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.950.clone.1 = f32[2048]{0:T(1024)} divide(%add.1004.clone.1, %div.951.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1155.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2549.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1155.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2547.clone.1 = f32[2048]{0:T(1024)} multiply(%integer_pow.77.clone.1, %mul.2549.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.559 = f32[2048]{0:T(1024)S(1)} parameter(4) + %constant.1154.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2548.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1154.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2546.clone.1 = f32[2048]{0:T(1024)} multiply(%param_4.559, %mul.2548.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.980.clone.1 = f32[2048]{0:T(1024)S(1)} add(%mul.2547.clone.1, %mul.2546.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1595 = f32[]{:T(128)S(6)} parameter(1) + %div.951.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_1.1595), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.950.clone.1 = f32[2048]{0:T(1024)} divide(%add.980.clone.1, %div.951.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.74.clone.1 = f32[2048]{0:T(1024)} sqrt(%div.950.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1165.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.1003.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1165.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.1002.clone.1 = f32[2048]{0:T(1024)} add(%sqrt.74.clone.1, %add.1003.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.438.clone.1 = f32[2048]{0:T(1024)} multiply(%div.952.clone.1, %add.1002.clone.1), metadata={op_name="multiply.49"} - %div.949.clone.1 = f32[2048]{0:T(1024)} divide(%add.1005.clone.1, %multiply.438.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.2021.clone.1 = f32[2048]{0:T(1024)} multiply(%param_0.1380, %broadcast.900.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.1001.clone.1 = f32[2048]{0:T(1024)} add(%div.949.clone.1, %mul.2021.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.2020.clone.1 = f32[2048]{0:T(1024)} multiply(%mul.2022.clone.1, %add.1001.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.1000.clone.1 = f32[2048]{0:T(1024)S(1)} add(%param_0.1380, %mul.2020.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.258 = f32[2048]{0:T(1024)} multiply(%add.1000.clone.1, %add.1000.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1212 = f32[]{:T(128)} constant(0) - %reduce.212 = f32[]{:T(128)} reduce(%square.258, %constant.1212), dimensions={0}, to_apply=%region_59.64, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.213.clone.1 = f32[]{:T(128)} reduce(%integer_pow.77.clone.1, %constant.1212), dimensions={0}, to_apply=%region_44.49, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.156 = (f32[]{:T(128)}, f32[2048]{0:T(1024)S(1)}, f32[2048]{0:T(1024)S(1)}, f32[2048]{0:T(1024)S(1)}, f32[]{:T(128)}) tuple(%reduce.212, %add.1000.clone.1, %add.1004.clone.1, %add.1005.clone.1, %reduce.213.clone.1) -} - -%fused_computation.402 (param_0.1150: s32[512]) -> s32[1024] { - %constant.972 = s32[] constant(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %broadcast.815 = s32[1024]{0:T(1024)} broadcast(%constant.972), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %param_0.1150 = s32[512]{0:T(512)S(1)} parameter(0) - %constant.973 = s32[] constant(2147483647), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %pad.49 = s32[1024]{0:T(1024)} pad(%param_0.1150, %constant.973), padding=0_512, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %constant.971 = s32[] constant(151935), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %broadcast.814 = s32[1024]{0:T(1024)} broadcast(%constant.971), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - ROOT %clamp.1 = s32[1024]{0:T(1024)} clamp(%broadcast.815, %pad.49, %broadcast.814), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} -} - -%fused_computation.405 (param_0.1149: s32[4,128]) -> s32[512] { - %param_0.1149 = s32[4,128]{1,0:T(4,128)} parameter(0) - %constant.1065 = s32[]{:T(128)} constant(0) - %broadcast.834 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1065), dimensions={}, metadata={op_name="broadcast.95"} - %lt.32 = pred[4,128]{1,0:T(4,128)(4,1)} compare(%param_0.1149, %broadcast.834), direction=LT, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/lt" stack_frame_id=0} - %constant.1051 = s32[]{:T(128)} constant(151936) - %add.925 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1051), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} - %add.903 = s32[4,128]{1,0:T(4,128)} add(%param_0.1149, %add.925), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} - %select_n.178 = s32[4,128]{1,0:T(4,128)} select(%lt.32, %add.903, %param_0.1149), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/select_n" stack_frame_id=0} - ROOT %bitcast.409 = s32[512]{0:T(512)S(1)} bitcast(%select_n.178), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} -} - -%region_40.45 (reduce_sum.305: f32[], reduce_sum.309: f32[]) -> f32[] { - %reduce_sum.309 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.305 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.310 = f32[]{:T(128)} add(%reduce_sum.305, %reduce_sum.309), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_37.42 (reduce_sum.290: f32[], reduce_sum.291: f32[]) -> f32[] { - %reduce_sum.291 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.290 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.295 = f32[]{:T(128)} add(%reduce_sum.290, %reduce_sum.291), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.407 (param_0.1384: f32[4,128], param_1.1570: f32[4,128]) -> (f32[], f32[]) { - %param_0.1384 = f32[4,128]{1,0:T(4,128)} parameter(0) - %bitcast.413 = f32[128,4]{0,1:T(4,128)} bitcast(%param_0.1384), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.261 = f32[128,4]{0,1:T(4,128)} multiply(%bitcast.413, %bitcast.413), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1216 = f32[]{:T(128)} constant(0) - %reduce.214 = f32[]{:T(128)} reduce(%square.261, %constant.1216), dimensions={0,1}, to_apply=%region_40.45, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.1570 = f32[4,128]{1,0:T(4,128)} parameter(1) - %bitcast.417.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_1.1570), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.264.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%bitcast.417.clone.1, %bitcast.417.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.215.clone.1 = f32[]{:T(128)} reduce(%square.264.clone.1, %constant.1216), dimensions={0,1}, to_apply=%region_37.42, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.170 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.214, %reduce.215.clone.1) -} - -%region_72.77 (reduce_sum.470: f32[], reduce_sum.471: f32[]) -> f32[] { - %reduce_sum.471 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.470 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.472 = f32[]{:T(128)} add(%reduce_sum.470, %reduce_sum.471), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_58.63 (reduce_sum.395: f32[], reduce_sum.396: f32[]) -> f32[] { - %reduce_sum.396 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.395 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.400 = f32[]{:T(128)} add(%reduce_sum.395, %reduce_sum.396), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.410 (param_0.1391: bf16[4,128], param_1.1576: f32[4,128], param_2.1329: f32[4,128], param_3.932: s32[4,128]) -> (f32[], f32[], pred[4,128], f32[4,128]) { - %param_3.932 = s32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %constant.1170.clone.1 = s32[]{:T(128)} constant(0) - %broadcast.901.clone.1 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1170.clone.1), dimensions={}, metadata={op_name="broadcast.95"} - %ne.6.clone.1 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} compare(%param_3.932, %broadcast.901.clone.1), direction=NE, metadata={op_name="jit(train_step)/jvp()/ne" stack_frame_id=0} - %param_1.1576 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %log.16 = f32[4,128]{1,0:T(4,128)} log(%param_1.1576), metadata={op_name="jit(train_step)/jvp()/log" stack_frame_id=0} - %param_0.1391 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(0) - %reduce_max.15 = f32[4,128]{1,0:T(4,128)} convert(%param_0.1391), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} - %add.927 = f32[4,128]{1,0:T(4,128)} add(%log.16, %reduce_max.15), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} - %square.269 = f32[4,128]{1,0:T(4,128)} multiply(%add.927, %add.927), metadata={op_name="jit(train_step)/jvp()/square" stack_frame_id=0} - %constant.1224 = f32[]{:T(128)} constant(0) - %broadcast.831 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1224), dimensions={}, metadata={op_name="broadcast.99"} - %mul.1913 = f32[4,128]{1,0:T(4,128)} multiply(%square.269, %broadcast.831), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - %mul.1893 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %mul.1913, %broadcast.831), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - %reduce.216 = f32[]{:T(128)} reduce(%mul.1893, %constant.1224), dimensions={0,1}, to_apply=%region_72.77, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} - %param_2.1329 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %neg.115.clone.1 = f32[4,128]{1,0:T(4,128)} negate(%param_2.1329), metadata={op_name="jit(train_step)/jvp()/neg" stack_frame_id=0} - %add.904.clone.1 = f32[4,128]{1,0:T(4,128)} add(%neg.115.clone.1, %mul.1913), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} - %mul.1894.clone.1 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %add.904.clone.1, %broadcast.831), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - %reduce.219.clone.1 = f32[]{:T(128)} reduce(%mul.1894.clone.1, %constant.1224), dimensions={0,1}, to_apply=%region_58.63, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} - %mul.1911.clone.1 = f32[4,128]{1,0:T(4,128)} multiply(%add.927, %broadcast.831), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %constant.1068.clone.1 = f32[]{:T(128)} constant(1) - %add.922.clone.1 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1068.clone.1), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} - %add.915.clone.1 = f32[4,128]{1,0:T(4,128)S(1)} add(%mul.1911.clone.1, %add.922.clone.1), metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} - ROOT %tuple.157 = (f32[]{:T(128)}, f32[]{:T(128)}, pred[4,128]{1,0:T(4,128)(4,1)S(1)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%reduce.216, %reduce.219.clone.1, %ne.6.clone.1, %add.915.clone.1) -} - -%region_69.74 (reduce_sum.452: f32[], reduce_sum.456: f32[]) -> f32[] { - %reduce_sum.456 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.452 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.457 = f32[]{:T(128)} add(%reduce_sum.452, %reduce_sum.456), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.1153.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.979.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1153.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.978.clone.1 = f32[2048]{0:T(1024)} add(%sqrt.74.clone.1, %add.979.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.299.clone.1 = f32[2048]{0:T(1024)} multiply(%div.952.clone.1, %add.978.clone.1), metadata={op_name="multiply.34"} + %div.949.clone.1 = f32[2048]{0:T(1024)} divide(%add.981.clone.1, %multiply.299.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2544.clone.1 = f32[2048]{0:T(1024)} multiply(%param_0.1421, %broadcast.810.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.977.clone.1 = f32[2048]{0:T(1024)} add(%div.949.clone.1, %mul.2544.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2543.clone.1 = f32[2048]{0:T(1024)} multiply(%mul.2545.clone.1, %add.977.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.976.clone.1 = f32[2048]{0:T(1024)S(1)} add(%param_0.1421, %mul.2543.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.258 = f32[2048]{0:T(1024)} multiply(%add.976.clone.1, %add.976.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1200 = f32[]{:T(128)} constant(0) + %reduce.142 = f32[]{:T(128)} reduce(%square.258, %constant.1200), dimensions={0}, to_apply=%region_59.64, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.143.clone.1 = f32[]{:T(128)} reduce(%integer_pow.77.clone.1, %constant.1200), dimensions={0}, to_apply=%region_44.49, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.157 = (f32[]{:T(128)}, f32[2048]{0:T(1024)S(1)}, f32[2048]{0:T(1024)S(1)}, f32[2048]{0:T(1024)S(1)}, f32[]{:T(128)}) tuple(%reduce.142, %add.976.clone.1, %add.980.clone.1, %add.981.clone.1, %reduce.143.clone.1) +} + +%fused_computation.413 (param_0.1191: s32[512]) -> s32[1024] { + %constant.960 = s32[] constant(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %broadcast.727 = s32[1024]{0:T(1024)} broadcast(%constant.960), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %param_0.1191 = s32[512]{0:T(512)S(1)} parameter(0) + %constant.961 = s32[] constant(2147483647), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %pad.49 = s32[1024]{0:T(1024)} pad(%param_0.1191, %constant.961), padding=0_512, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %constant.959 = s32[] constant(151935), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %broadcast.726 = s32[1024]{0:T(1024)} broadcast(%constant.959), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %clamp.1 = s32[1024]{0:T(1024)} clamp(%broadcast.727, %pad.49, %broadcast.726), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} +} + +%fused_computation.416 (param_0.1190: s32[4,128]) -> s32[512] { + %param_0.1190 = s32[4,128]{1,0:T(4,128)} parameter(0) + %constant.1055 = s32[]{:T(128)} constant(0) + %broadcast.747 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1055), dimensions={}, metadata={op_name="broadcast.95"} + %lt.32 = pred[4,128]{1,0:T(4,128)(4,1)} compare(%param_0.1190, %broadcast.747), direction=LT, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/lt" stack_frame_id=0} + %constant.1052 = s32[]{:T(128)} constant(151936) + %add.901 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1052), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} + %add.879 = s32[4,128]{1,0:T(4,128)} add(%param_0.1190, %add.901), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} + %select_n.178 = s32[4,128]{1,0:T(4,128)} select(%lt.32, %add.879, %param_0.1190), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/select_n" stack_frame_id=0} + ROOT %bitcast.390 = s32[512]{0:T(512)S(1)} bitcast(%select_n.178), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} +} + +%region_40.45 (reduce_sum.410: f32[], reduce_sum.411: f32[]) -> f32[] { + %reduce_sum.411 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.410 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.415 = f32[]{:T(128)} add(%reduce_sum.410, %reduce_sum.411), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_54.59 (reduce_sum.374: f32[], reduce_sum.375: f32[]) -> f32[] { - %reduce_sum.375 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.374 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.379 = f32[]{:T(128)} add(%reduce_sum.374, %reduce_sum.375), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_37.42 (reduce_sum.395: f32[], reduce_sum.396: f32[]) -> f32[] { + %reduce_sum.396 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.395 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.397 = f32[]{:T(128)} add(%reduce_sum.395, %reduce_sum.396), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.411 (param_0.1370: f32[128,4], param_1.1558: f32[], param_2.1316: f32[], param_3.920: f32[], param_4.558: f32[128,4], param_5.470: f32[], param_6.360: f32[4,128], param_7.203: pred[], param_8.120: f32[128,4]) -> (f32[], f32[128,4], f32[128,4], f32[128,4], f32[]) { - %param_0.1370 = f32[128,4]{0,1:T(4,128)S(1)} parameter(0) - %param_3.920 = f32[]{:T(128)S(6)} parameter(3) - %mul.1943.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_3.920), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.203 = pred[]{:T(512)S(6)} parameter(7) - %select_n.276.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.203), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.360 = f32[4,128]{1,0:T(4,128)} parameter(6) - %bitcast.468.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_6.360), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.470 = f32[]{:T(128)} parameter(5) - %div.876.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_5.470), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.875.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%bitcast.468.clone.1, %div.876.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.275.clone.1 = f32[128,4]{0,1:T(4,128)} select(%select_n.276.clone.1, %bitcast.468.clone.1, %div.875.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1104.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.856.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1104.clone.1), dimensions={}, metadata={op_name="broadcast.78"} - %mul.1947.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.275.clone.1, %broadcast.856.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.418 (param_0.1425: f32[4,128], param_1.1597: f32[4,128]) -> (f32[], f32[]) { + %param_0.1425 = f32[4,128]{1,0:T(4,128)} parameter(0) + %bitcast.394 = f32[128,4]{0,1:T(4,128)} bitcast(%param_0.1425), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.261 = f32[128,4]{0,1:T(4,128)} multiply(%bitcast.394, %bitcast.394), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1204 = f32[]{:T(128)} constant(0) + %reduce.144 = f32[]{:T(128)} reduce(%square.261, %constant.1204), dimensions={0,1}, to_apply=%region_40.45, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1597 = f32[4,128]{1,0:T(4,128)} parameter(1) + %bitcast.398.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_1.1597), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.264.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%bitcast.398.clone.1, %bitcast.398.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.145.clone.1 = f32[]{:T(128)} reduce(%square.264.clone.1, %constant.1204), dimensions={0,1}, to_apply=%region_37.42, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.171 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.144, %reduce.145.clone.1) +} + +%region_72.77 (reduce_sum.572: f32[], reduce_sum.576: f32[]) -> f32[] { + %reduce_sum.576 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.572 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.577 = f32[]{:T(128)} add(%reduce_sum.572, %reduce_sum.576), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_58.63 (reduce_sum.500: f32[], reduce_sum.501: f32[]) -> f32[] { + %reduce_sum.501 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.500 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.502 = f32[]{:T(128)} add(%reduce_sum.500, %reduce_sum.501), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.421 (param_0.1432: bf16[4,128], param_1.1603: f32[4,128], param_2.1340: f32[4,128], param_3.922: s32[4,128]) -> (f32[], f32[], pred[4,128], f32[4,128]) { + %param_3.922 = s32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %constant.1158.clone.1 = s32[]{:T(128)} constant(0) + %broadcast.811.clone.1 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1158.clone.1), dimensions={}, metadata={op_name="broadcast.95"} + %ne.6.clone.1 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} compare(%param_3.922, %broadcast.811.clone.1), direction=NE, metadata={op_name="jit(train_step)/jvp()/ne" stack_frame_id=0} + %param_1.1603 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %log.16 = f32[4,128]{1,0:T(4,128)} log(%param_1.1603), metadata={op_name="jit(train_step)/jvp()/log" stack_frame_id=0} + %param_0.1432 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(0) + %reduce_max.15 = f32[4,128]{1,0:T(4,128)} convert(%param_0.1432), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} + %add.903 = f32[4,128]{1,0:T(4,128)} add(%log.16, %reduce_max.15), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} + %square.269 = f32[4,128]{1,0:T(4,128)} multiply(%add.903, %add.903), metadata={op_name="jit(train_step)/jvp()/square" stack_frame_id=0} + %constant.1212 = f32[]{:T(128)} constant(0) + %broadcast.741 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1212), dimensions={}, metadata={op_name="broadcast.50"} + %mul.2434 = f32[4,128]{1,0:T(4,128)} multiply(%square.269, %broadcast.741), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %mul.2414 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %mul.2434, %broadcast.741), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %reduce.146 = f32[]{:T(128)} reduce(%mul.2414, %constant.1212), dimensions={0,1}, to_apply=%region_72.77, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %param_2.1340 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %neg.115.clone.1 = f32[4,128]{1,0:T(4,128)} negate(%param_2.1340), metadata={op_name="jit(train_step)/jvp()/neg" stack_frame_id=0} + %add.880.clone.1 = f32[4,128]{1,0:T(4,128)} add(%neg.115.clone.1, %mul.2434), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} + %mul.2415.clone.1 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %add.880.clone.1, %broadcast.741), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %reduce.149.clone.1 = f32[]{:T(128)} reduce(%mul.2415.clone.1, %constant.1212), dimensions={0,1}, to_apply=%region_58.63, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %mul.2432.clone.1 = f32[4,128]{1,0:T(4,128)} multiply(%add.903, %broadcast.741), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %constant.1056.clone.1 = f32[]{:T(128)} constant(1) + %add.898.clone.1 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1056.clone.1), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} + %add.891.clone.1 = f32[4,128]{1,0:T(4,128)S(1)} add(%mul.2432.clone.1, %add.898.clone.1), metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} + ROOT %tuple.158 = (f32[]{:T(128)}, f32[]{:T(128)}, pred[4,128]{1,0:T(4,128)(4,1)S(1)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%reduce.146, %reduce.149.clone.1, %ne.6.clone.1, %add.891.clone.1) +} + +%region_69.74 (reduce_sum.557: f32[], reduce_sum.558: f32[]) -> f32[] { + %reduce_sum.558 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.557 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.562 = f32[]{:T(128)} add(%reduce_sum.557, %reduce_sum.558), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_54.59 (reduce_sum.479: f32[], reduce_sum.480: f32[]) -> f32[] { + %reduce_sum.480 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.479 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.481 = f32[]{:T(128)} add(%reduce_sum.479, %reduce_sum.480), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.422 (param_0.1411: f32[128,4], param_1.1585: f32[], param_2.1327: f32[], param_3.910: f32[], param_4.549: f32[128,4], param_5.483: f32[], param_6.358: f32[4,128], param_7.202: pred[], param_8.120: f32[128,4]) -> (f32[], f32[128,4], f32[128,4], f32[128,4], f32[]) { + %param_0.1411 = f32[128,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.910 = f32[]{:T(128)S(6)} parameter(3) + %mul.2466.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_3.910), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.202 = pred[]{:T(512)S(6)} parameter(7) + %select_n.276.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.202), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.358 = f32[4,128]{1,0:T(4,128)} parameter(6) + %bitcast.449.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_6.358), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.483 = f32[]{:T(128)} parameter(5) + %div.876.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_5.483), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.875.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%bitcast.449.clone.1, %div.876.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.275.clone.1 = f32[128,4]{0,1:T(4,128)} select(%select_n.276.clone.1, %bitcast.449.clone.1, %div.875.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1092.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.766.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1092.clone.1), dimensions={}, metadata={op_name="broadcast.78"} + %mul.2470.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.275.clone.1, %broadcast.766.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.120 = f32[128,4]{0,1:T(4,128)S(1)} parameter(8) - %constant.1108.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.855.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1108.clone.1), dimensions={}, metadata={op_name="broadcast.77"} - %mul.1946.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_8.120, %broadcast.855.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.951.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.1947.clone.1, %mul.1946.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1316 = f32[]{:T(128)S(6)} parameter(2) - %div.872.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_2.1316), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1096.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.765.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1096.clone.1), dimensions={}, metadata={op_name="broadcast.77"} + %mul.2469.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_8.120, %broadcast.765.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.927.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.2470.clone.1, %mul.2469.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1327 = f32[]{:T(128)S(6)} parameter(2) + %div.872.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_2.1327), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.67.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.275.clone.1, %select_n.275.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1107.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.854.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1107.clone.1), dimensions={}, metadata={op_name="broadcast.67"} - %mul.1945.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%integer_pow.67.clone.1, %broadcast.854.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.558 = f32[128,4]{0,1:T(4,128)S(1)} parameter(4) - %constant.1106.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.853.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1106.clone.1), dimensions={}, metadata={op_name="broadcast.66"} - %mul.1944.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_4.558, %broadcast.853.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.950.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.1945.clone.1, %mul.1944.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1558 = f32[]{:T(128)S(6)} parameter(1) - %div.871.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_1.1558), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.870.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.950.clone.1, %div.871.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1095.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.764.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1095.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.2468.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%integer_pow.67.clone.1, %broadcast.764.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.549 = f32[128,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1094.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.763.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1094.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.2467.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_4.549, %broadcast.763.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.926.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.2468.clone.1, %mul.2467.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1585 = f32[]{:T(128)S(6)} parameter(1) + %div.871.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_1.1585), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.870.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.926.clone.1, %div.871.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.64.clone.1 = f32[128,4]{0,1:T(4,128)} sqrt(%div.870.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1105.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.851.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1105.clone.1), dimensions={}, metadata={op_name="broadcast.62"} - %add.949.clone.1 = f32[128,4]{0,1:T(4,128)} add(%sqrt.64.clone.1, %broadcast.851.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.428.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%div.872.clone.1, %add.949.clone.1), metadata={op_name="multiply.59"} - %div.869.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.951.clone.1, %multiply.428.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1942.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_0.1370, %broadcast.856.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.948.clone.1 = f32[128,4]{0,1:T(4,128)} add(%div.869.clone.1, %mul.1942.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1941.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%mul.1943.clone.1, %add.948.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.947.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%param_0.1370, %mul.1941.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.265 = f32[128,4]{0,1:T(4,128)} multiply(%add.947.clone.1, %add.947.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1202 = f32[]{:T(128)} constant(0) - %reduce.217 = f32[]{:T(128)} reduce(%square.265, %constant.1202), dimensions={0,1}, to_apply=%region_69.74, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.221.clone.1 = f32[]{:T(128)} reduce(%integer_pow.67.clone.1, %constant.1202), dimensions={0,1}, to_apply=%region_54.59, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.159 = (f32[]{:T(128)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.217, %add.947.clone.1, %add.950.clone.1, %add.951.clone.1, %reduce.221.clone.1) -} - -%region_66.71 (reduce_sum.437: f32[], reduce_sum.438: f32[]) -> f32[] { - %reduce_sum.438 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.437 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.442 = f32[]{:T(128)} add(%reduce_sum.437, %reduce_sum.438), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_51.56 (reduce_sum.359: f32[], reduce_sum.360: f32[]) -> f32[] { - %reduce_sum.360 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.359 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.361 = f32[]{:T(128)} add(%reduce_sum.359, %reduce_sum.360), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.1093.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.761.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1093.clone.1), dimensions={}, metadata={op_name="broadcast.62"} + %add.925.clone.1 = f32[128,4]{0,1:T(4,128)} add(%sqrt.64.clone.1, %broadcast.761.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.289.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%div.872.clone.1, %add.925.clone.1), metadata={op_name="multiply.44"} + %div.869.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.927.clone.1, %multiply.289.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2465.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_0.1411, %broadcast.766.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.924.clone.1 = f32[128,4]{0,1:T(4,128)} add(%div.869.clone.1, %mul.2465.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2464.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%mul.2466.clone.1, %add.924.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.923.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%param_0.1411, %mul.2464.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.265 = f32[128,4]{0,1:T(4,128)} multiply(%add.923.clone.1, %add.923.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1190 = f32[]{:T(128)} constant(0) + %reduce.147 = f32[]{:T(128)} reduce(%square.265, %constant.1190), dimensions={0,1}, to_apply=%region_69.74, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.151.clone.1 = f32[]{:T(128)} reduce(%integer_pow.67.clone.1, %constant.1190), dimensions={0,1}, to_apply=%region_54.59, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.160 = (f32[]{:T(128)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.147, %add.923.clone.1, %add.926.clone.1, %add.927.clone.1, %reduce.151.clone.1) +} + +%region_66.71 (reduce_sum.542: f32[], reduce_sum.543: f32[]) -> f32[] { + %reduce_sum.543 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.542 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.544 = f32[]{:T(128)} add(%reduce_sum.542, %reduce_sum.543), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_51.56 (reduce_sum.464: f32[], reduce_sum.465: f32[]) -> f32[] { + %reduce_sum.465 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.464 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.466 = f32[]{:T(128)} add(%reduce_sum.464, %reduce_sum.465), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.412 (param_0.1373: f32[128,4], param_1.1561: f32[], param_2.1319: f32[], param_3.923: f32[], param_4.561: f32[128,4], param_5.473: f32[], param_6.363: f32[4,128], param_7.206: pred[], param_8.123: f32[128,4]) -> (f32[], f32[128,4], f32[128,4], f32[128,4], f32[]) { - %param_0.1373 = f32[128,4]{0,1:T(4,128)S(1)} parameter(0) - %param_3.923 = f32[]{:T(128)S(6)} parameter(3) - %mul.1970.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_3.923), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.206 = pred[]{:T(512)S(6)} parameter(7) - %select_n.288.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.206), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.363 = f32[4,128]{1,0:T(4,128)} parameter(6) - %bitcast.474.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_6.363), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.473 = f32[]{:T(128)} parameter(5) - %div.900.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_5.473), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.899.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%bitcast.474.clone.1, %div.900.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.287.clone.1 = f32[128,4]{0,1:T(4,128)} select(%select_n.288.clone.1, %bitcast.474.clone.1, %div.899.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1122.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.866.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1122.clone.1), dimensions={}, metadata={op_name="broadcast.78"} - %mul.1974.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.287.clone.1, %broadcast.866.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.423 (param_0.1414: f32[128,4], param_1.1588: f32[], param_2.1330: f32[], param_3.913: f32[], param_4.552: f32[128,4], param_5.486: f32[], param_6.361: f32[4,128], param_7.205: pred[], param_8.123: f32[128,4]) -> (f32[], f32[128,4], f32[128,4], f32[128,4], f32[]) { + %param_0.1414 = f32[128,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.913 = f32[]{:T(128)S(6)} parameter(3) + %mul.2493.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_3.913), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.205 = pred[]{:T(512)S(6)} parameter(7) + %select_n.288.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.205), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.361 = f32[4,128]{1,0:T(4,128)} parameter(6) + %bitcast.455.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_6.361), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.486 = f32[]{:T(128)} parameter(5) + %div.900.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_5.486), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.899.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%bitcast.455.clone.1, %div.900.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.287.clone.1 = f32[128,4]{0,1:T(4,128)} select(%select_n.288.clone.1, %bitcast.455.clone.1, %div.899.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1110.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.776.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1110.clone.1), dimensions={}, metadata={op_name="broadcast.78"} + %mul.2497.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.287.clone.1, %broadcast.776.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.123 = f32[128,4]{0,1:T(4,128)S(1)} parameter(8) - %constant.1126.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.865.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1126.clone.1), dimensions={}, metadata={op_name="broadcast.77"} - %mul.1973.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_8.123, %broadcast.865.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.968.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.1974.clone.1, %mul.1973.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1319 = f32[]{:T(128)S(6)} parameter(2) - %div.896.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_2.1319), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1114.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.775.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1114.clone.1), dimensions={}, metadata={op_name="broadcast.77"} + %mul.2496.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_8.123, %broadcast.775.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.944.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.2497.clone.1, %mul.2496.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1330 = f32[]{:T(128)S(6)} parameter(2) + %div.896.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_2.1330), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.70.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.287.clone.1, %select_n.287.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1125.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.864.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1125.clone.1), dimensions={}, metadata={op_name="broadcast.67"} - %mul.1972.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%integer_pow.70.clone.1, %broadcast.864.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.561 = f32[128,4]{0,1:T(4,128)S(1)} parameter(4) - %constant.1124.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.863.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1124.clone.1), dimensions={}, metadata={op_name="broadcast.66"} - %mul.1971.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_4.561, %broadcast.863.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.967.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.1972.clone.1, %mul.1971.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1561 = f32[]{:T(128)S(6)} parameter(1) - %div.895.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_1.1561), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.894.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.967.clone.1, %div.895.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1113.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.774.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1113.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.2495.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%integer_pow.70.clone.1, %broadcast.774.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.552 = f32[128,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1112.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.773.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1112.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.2494.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_4.552, %broadcast.773.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.943.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.2495.clone.1, %mul.2494.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1588 = f32[]{:T(128)S(6)} parameter(1) + %div.895.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_1.1588), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.894.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.943.clone.1, %div.895.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.67.clone.1 = f32[128,4]{0,1:T(4,128)} sqrt(%div.894.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1123.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.861.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1123.clone.1), dimensions={}, metadata={op_name="broadcast.62"} - %add.966.clone.1 = f32[128,4]{0,1:T(4,128)} add(%sqrt.67.clone.1, %broadcast.861.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.431.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%div.896.clone.1, %add.966.clone.1), metadata={op_name="multiply.56"} - %div.893.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.968.clone.1, %multiply.431.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1969.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_0.1373, %broadcast.866.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.965.clone.1 = f32[128,4]{0,1:T(4,128)} add(%div.893.clone.1, %mul.1969.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1968.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%mul.1970.clone.1, %add.965.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.964.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%param_0.1373, %mul.1968.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.266 = f32[128,4]{0,1:T(4,128)} multiply(%add.964.clone.1, %add.964.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1205 = f32[]{:T(128)} constant(0) - %reduce.218 = f32[]{:T(128)} reduce(%square.266, %constant.1205), dimensions={0,1}, to_apply=%region_66.71, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.222.clone.1 = f32[]{:T(128)} reduce(%integer_pow.70.clone.1, %constant.1205), dimensions={0,1}, to_apply=%region_51.56, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.160 = (f32[]{:T(128)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.218, %add.964.clone.1, %add.967.clone.1, %add.968.clone.1, %reduce.222.clone.1) -} - -%fused_computation.421 (param_0.1201: f32[4,128], param_1.1323: f32[4,128]) -> f32[4,128] { - %param_0.1201 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %param_1.1323 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %constant.1045 = f32[]{:T(128)} constant(0.00048828125) - %broadcast.837 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1045), dimensions={}, metadata={op_name="broadcast.399"} - %div.767 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1323, %broadcast.837), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} - %constant.1043 = f32[]{:T(128)} constant(1e-06) - %add.935 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1043), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - %add.934 = f32[4,128]{1,0:T(4,128)} add(%div.767, %add.935), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - %rsqrt.168 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.934), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} - %div.754 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.168, %add.934), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} - %constant.1040 = f32[]{:T(128)} constant(-0.5) - %mul.1919 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1040), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1910 = f32[4,128]{1,0:T(4,128)} multiply(%div.754, %mul.1919), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1909 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1201, %mul.1910), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %constant.1039 = f32[]{:T(128)} constant(0.0009765625) - %mul.1918 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1039), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - ROOT %mul.1908 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.1909, %mul.1918), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} -} - -%region_0.1 (reduce_sum.137: s32[], reduce_sum.138: s32[]) -> s32[] { - %reduce_sum.138 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.137 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.139 = s32[]{:T(128)} add(%reduce_sum.137, %reduce_sum.138), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["0","2"]}]}} -} - -%fused_computation.425 (param_0.1220: pred[4,128]) -> s32[] { - %param_0.1220 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) - %convert_element_type.1403 = s32[4,128]{1,0:T(4,128)} convert(%param_0.1220), metadata={op_name="jit(train_step)/jvp()/convert_element_type" stack_frame_id=0} - %constant.1066 = s32[]{:T(128)} constant(0) - ROOT %reduce.220 = s32[]{:T(128)} reduce(%convert_element_type.1403, %constant.1066), dimensions={0,1}, to_apply=%region_0.1, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} -} - -%fused_computation.428 (param_0.1203: f32[4,128]) -> f32[4,128] { - %param_0.1203 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1111.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.771.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1111.clone.1), dimensions={}, metadata={op_name="broadcast.62"} + %add.942.clone.1 = f32[128,4]{0,1:T(4,128)} add(%sqrt.67.clone.1, %broadcast.771.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.292.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%div.896.clone.1, %add.942.clone.1), metadata={op_name="multiply.41"} + %div.893.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.944.clone.1, %multiply.292.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2492.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_0.1414, %broadcast.776.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.941.clone.1 = f32[128,4]{0,1:T(4,128)} add(%div.893.clone.1, %mul.2492.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2491.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%mul.2493.clone.1, %add.941.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.940.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%param_0.1414, %mul.2491.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.266 = f32[128,4]{0,1:T(4,128)} multiply(%add.940.clone.1, %add.940.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1193 = f32[]{:T(128)} constant(0) + %reduce.148 = f32[]{:T(128)} reduce(%square.266, %constant.1193), dimensions={0,1}, to_apply=%region_66.71, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.152.clone.1 = f32[]{:T(128)} reduce(%integer_pow.70.clone.1, %constant.1193), dimensions={0,1}, to_apply=%region_51.56, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.161 = (f32[]{:T(128)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.148, %add.940.clone.1, %add.943.clone.1, %add.944.clone.1, %reduce.152.clone.1) +} + +%fused_computation.432 (param_0.1242: f32[4,128], param_1.1350: f32[4,128]) -> f32[4,128] { + %param_0.1242 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %param_1.1350 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) %constant.1046 = f32[]{:T(128)} constant(0.00048828125) - %broadcast.829 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1046), dimensions={}, metadata={op_name="broadcast.399"} - %div.759 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1203, %broadcast.829), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %broadcast.749 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1046), dimensions={}, metadata={op_name="broadcast.362"} + %div.767 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1350, %broadcast.749), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} %constant.1044 = f32[]{:T(128)} constant(1e-06) - %add.924 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1044), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - %add.921 = f32[4,128]{1,0:T(4,128)} add(%div.759, %add.924), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - ROOT %rsqrt.166 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.921), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} -} - -%fused_computation.429 (param_0.1204: pred[4,128], param_1.1575: f32[]) -> f32[4,128] { - %param_0.1204 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) - %param_1.1575 = f32[]{:T(128)S(6)} parameter(1) - %broadcast_in_dim.288 = f32[4,128]{1,0:T(4,128)} broadcast(%param_1.1575), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/broadcast_in_dim" stack_frame_id=0} - %constant.1223 = f32[]{:T(128)} constant(0) - %broadcast.833 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1223), dimensions={}, metadata={op_name="broadcast.99"} - ROOT %mul.1920 = f32[4,128]{1,0:T(4,128)S(1)} select(%param_0.1204, %broadcast_in_dim.288, %broadcast.833), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %add.911 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1044), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %add.910 = f32[4,128]{1,0:T(4,128)} add(%div.767, %add.911), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %rsqrt.168 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.910), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} + %div.754 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.168, %add.910), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.1028 = f32[]{:T(128)} constant(-0.5) + %mul.2440 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1028), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2431 = f32[4,128]{1,0:T(4,128)} multiply(%div.754, %mul.2440), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2430 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1242, %mul.2431), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.1027 = f32[]{:T(128)} constant(0.0009765625) + %mul.2439 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1027), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.2429 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.2430, %mul.2439), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%region_7.10 (reduce_sum.234: s32[], reduce_sum.235: s32[]) -> s32[] { + %reduce_sum.235 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.234 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.236 = s32[]{:T(128)} add(%reduce_sum.234, %reduce_sum.235), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["0","2"]}]}} +} + +%fused_computation.435 (param_0.1261: pred[4,128]) -> s32[] { + %param_0.1261 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %convert_element_type.1427 = s32[4,128]{1,0:T(4,128)} convert(%param_0.1261), metadata={op_name="jit(train_step)/jvp()/convert_element_type" stack_frame_id=0} + %constant.1054 = s32[]{:T(128)} constant(0) + ROOT %reduce.150 = s32[]{:T(128)} reduce(%convert_element_type.1427, %constant.1054), dimensions={0,1}, to_apply=%region_7.10, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} +} + +%fused_computation.439 (param_0.1245: f32[4,128]) -> f32[4,128] { + %param_0.1245 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1047 = f32[]{:T(128)} constant(0.00048828125) + %broadcast.745 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1047), dimensions={}, metadata={op_name="broadcast.362"} + %div.759 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1245, %broadcast.745), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.1045 = f32[]{:T(128)} constant(1e-06) + %add.900 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1045), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %add.897 = f32[4,128]{1,0:T(4,128)} add(%div.759, %add.900), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + ROOT %rsqrt.166 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.897), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} +} + +%fused_computation.440 (param_0.1244: pred[4,128], param_1.1602: f32[]) -> f32[4,128] { + %param_0.1244 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %param_1.1602 = f32[]{:T(128)S(6)} parameter(1) + %broadcast_in_dim.309 = f32[4,128]{1,0:T(4,128)} broadcast(%param_1.1602), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/broadcast_in_dim" stack_frame_id=0} + %constant.1211 = f32[]{:T(128)} constant(0) + %broadcast.743 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1211), dimensions={}, metadata={op_name="broadcast.50"} + ROOT %mul.2441 = f32[4,128]{1,0:T(4,128)S(1)} select(%param_0.1244, %broadcast_in_dim.309, %broadcast.743), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} } -%fused_computation.431 () -> f32[64] { - %constant.1049 = f32[]{:T(128)} constant(1e+06) - %broadcast.840 = f32[64]{0:T(128)} broadcast(%constant.1049), dimensions={}, metadata={op_name="broadcast.390"} +%fused_computation.442 () -> f32[64] { + %constant.1050 = f32[]{:T(128)} constant(1e+06) + %broadcast.752 = f32[64]{0:T(128)} broadcast(%constant.1050), dimensions={}, metadata={op_name="broadcast.353"} %iota.46 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/layers/iota" stack_frame_id=0} - %constant.1048 = s32[]{:T(128)} constant(2) - %broadcast.839 = s32[64]{0:T(128)} broadcast(%constant.1048), dimensions={}, metadata={op_name="broadcast.391"} - %mul.1921 = s32[64]{0:T(128)} multiply(%iota.46, %broadcast.839), metadata={op_name="jit(train_step)/layers/mul" stack_frame_id=0} - %convert_element_type.1404 = f32[64]{0:T(128)} convert(%mul.1921), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} - %constant.1047 = f32[]{:T(128)} constant(0.0078125) - %broadcast.838 = f32[64]{0:T(128)} broadcast(%constant.1047), dimensions={}, metadata={op_name="broadcast.392"} - %div.768 = f32[64]{0:T(128)} multiply(%convert_element_type.1404, %broadcast.838), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} - ROOT %pow.36 = f32[64]{0:T(128)S(1)} power(%broadcast.840, %div.768), metadata={op_name="jit(train_step)/layers/pow" stack_frame_id=0} + %constant.1049 = s32[]{:T(128)} constant(2) + %broadcast.751 = s32[64]{0:T(128)} broadcast(%constant.1049), dimensions={}, metadata={op_name="broadcast.354"} + %mul.2442 = s32[64]{0:T(128)} multiply(%iota.46, %broadcast.751), metadata={op_name="jit(train_step)/layers/mul" stack_frame_id=0} + %convert_element_type.1428 = f32[64]{0:T(128)} convert(%mul.2442), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %constant.1048 = f32[]{:T(128)} constant(0.0078125) + %broadcast.750 = f32[64]{0:T(128)} broadcast(%constant.1048), dimensions={}, metadata={op_name="broadcast.355"} + %div.768 = f32[64]{0:T(128)} multiply(%convert_element_type.1428, %broadcast.750), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + ROOT %pow.36 = f32[64]{0:T(128)S(1)} power(%broadcast.752, %div.768), metadata={op_name="jit(train_step)/layers/pow" stack_frame_id=0} } -%fused_computation.432 (param_0.1218: s32[4,128]) -> (f32[4,128,1,1], f32[4,128]) { - %param_0.1218 = s32[4,128]{1,0:T(4,128)} parameter(0) - %convert_element_type.1405 = f32[4,128]{1,0:T(4,128)S(1)} convert(%param_0.1218), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} - %bitcast.418 = f32[4,128,1,1]{1,0,3,2:T(4,128)} bitcast(%convert_element_type.1405), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %tuple.162 = (f32[4,128,1,1]{1,0,3,2:T(4,128)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%bitcast.418, %convert_element_type.1405) +%fused_computation.443 (param_0.1259: s32[4,128]) -> (f32[4,128,1,1], f32[4,128]) { + %param_0.1259 = s32[4,128]{1,0:T(4,128)} parameter(0) + %convert_element_type.1429 = f32[4,128]{1,0:T(4,128)S(1)} convert(%param_0.1259), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %bitcast.399 = f32[4,128,1,1]{1,0,3,2:T(4,128)} bitcast(%convert_element_type.1429), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.163 = (f32[4,128,1,1]{1,0,3,2:T(4,128)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%bitcast.399, %convert_element_type.1429) } -%fused_computation.435 (param_0.1360: f32[2048,4]) -> bf16[4,2048] { - %param_0.1360 = f32[2048,4]{0,1:T(4,128)} parameter(0) - %bitcast.531 = f32[4,2048]{1,0:T(4,128)} bitcast(%param_0.1360), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.145 = bf16[4,2048]{1,0:T(4,128)(2,1)} convert(%bitcast.531) +%fused_computation.446 (param_0.1400: f32[2048,4]) -> bf16[4,2048] { + %param_0.1400 = f32[2048,4]{0,1:T(4,128)} parameter(0) + %bitcast.507 = f32[4,2048]{1,0:T(4,128)} bitcast(%param_0.1400), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.79 = bf16[4,2048]{1,0:T(4,128)(2,1)} convert(%bitcast.507) } -%fused_computation.436 (param_0.1359: f32[2048,4]) -> bf16[4,2048] { - %param_0.1359 = f32[2048,4]{0,1:T(4,128)} parameter(0) - %bitcast.530 = f32[4,2048]{1,0:T(4,128)} bitcast(%param_0.1359), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.147 = bf16[4,2048]{1,0:T(4,128)(2,1)} convert(%bitcast.530) +%fused_computation.447 (param_0.1401: f32[2048,4]) -> bf16[4,2048] { + %param_0.1401 = f32[2048,4]{0,1:T(4,128)} parameter(0) + %bitcast.508 = f32[4,2048]{1,0:T(4,128)} bitcast(%param_0.1401), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.81 = bf16[4,2048]{1,0:T(4,128)(2,1)} convert(%bitcast.508) } -%fused_computation.437 (param_0.1361: f32[128,4]) -> bf16[4,128] { - %param_0.1361 = f32[128,4]{0,1:T(4,128)} parameter(0) - %bitcast.532 = f32[4,128]{1,0:T(4,128)} bitcast(%param_0.1361), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.149 = bf16[4,128]{1,0:T(4,128)(2,1)} convert(%bitcast.532) +%fused_computation.448 (param_0.1402: f32[128,4]) -> bf16[4,128] { + %param_0.1402 = f32[128,4]{0,1:T(4,128)} parameter(0) + %bitcast.509 = f32[4,128]{1,0:T(4,128)} bitcast(%param_0.1402), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.83 = bf16[4,128]{1,0:T(4,128)(2,1)} convert(%bitcast.509) } -%fused_computation.438 (param_0.1362: f32[128,4]) -> bf16[4,128] { - %param_0.1362 = f32[128,4]{0,1:T(4,128)} parameter(0) - %bitcast.533 = f32[4,128]{1,0:T(4,128)} bitcast(%param_0.1362), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.151 = bf16[4,128]{1,0:T(4,128)(2,1)} convert(%bitcast.533) +%fused_computation.449 (param_0.1403: f32[128,4]) -> bf16[4,128] { + %param_0.1403 = f32[128,4]{0,1:T(4,128)} parameter(0) + %bitcast.510 = f32[4,128]{1,0:T(4,128)} bitcast(%param_0.1403), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.85 = bf16[4,128]{1,0:T(4,128)(2,1)} convert(%bitcast.510) } %region_8.11 (reduce_max.6: bf16[], reduce_max.8: bf16[]) -> bf16[] { @@ -1462,539 +1462,539 @@ StackFrames ROOT %reduce_max.9 = bf16[]{:T(256)} maximum(%reduce_max.6, %reduce_max.8), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.287.clone.clone (param_0.1346: bf16[151936,2048]) -> bf16[151936,2048,1] { - %param_0.1346 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) - ROOT %bitcast.526 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%param_0.1346), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} +%fused_computation.298.clone.clone (param_0.1387: bf16[151936,2048]) -> bf16[151936,2048,1] { + %param_0.1387 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + ROOT %bitcast.503 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%param_0.1387), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} } -%fused_computation.368.clone.clone (param_0.1347: f32[4,128], param_1.1542: bf16[4,128,2048], param_2.1281: bf16[2048]) -> bf16[4,128,2048] { - %param_2.1281 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.476 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1281), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1542 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1438 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1542), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_0.1347 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.2067 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1347), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.2066 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1438, %mul.2067), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.1437 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2066), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - ROOT %dot_general.475 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.476, %convert_element_type.1437), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} +%fused_computation.379.clone.clone (param_0.1388: f32[4,128], param_1.1569: bf16[4,128,2048], param_2.1292: bf16[2048]) -> bf16[4,128,2048] { + %param_1.1569 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1462 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1569), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1388 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2607 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1388), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2606 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1462, %mul.2607), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1461 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2606), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.1292 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2608 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1292), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.2605 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1461, %mul.2608), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} } -%fused_computation.439 (param_0.1363: bf16[151936,2048], param_1.1551: f32[4,128], param_2.1305: bf16[4,128,2048], param_3.913: bf16[2048]) -> (bf16[4,128], bf16[4,128,151936]) { - %param_1.1551 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.1305 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %param_3.913 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %fusion.270.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1551, %param_2.1305, %param_3.913), kind=kLoop, calls=%fused_computation.368.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1363 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) - %fusion.253.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1363), kind=kLoop, calls=%fused_computation.287.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %convolution.85.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convolution(%fusion.270.clone.1, %fusion.253.clone.1), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} - %constant.1195 = bf16[]{:T(256)} constant(-inf) - %reduce.223 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} reduce(%convolution.85.clone.1, %constant.1195), dimensions={2}, to_apply=%region_8.11, metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} - ROOT %tuple.164 = (bf16[4,128]{1,0:T(4,128)(2,1)S(1)}, bf16[4,128,151936]{2,1,0:T(8,128)(2,1)}) tuple(%reduce.223, %convolution.85.clone.1) +%fused_computation.450 (param_0.1404: bf16[151936,2048], param_1.1578: f32[4,128], param_2.1316: bf16[4,128,2048], param_3.903: bf16[2048]) -> (bf16[4,128], bf16[4,128,151936]) { + %param_1.1578 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1316 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.903 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.280.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1578, %param_2.1316, %param_3.903), kind=kLoop, calls=%fused_computation.379.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.1404 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + %fusion.263.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1404), kind=kLoop, calls=%fused_computation.298.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %convolution.85.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convolution(%fusion.280.clone.1, %fusion.263.clone.1), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %constant.1183 = bf16[]{:T(256)} constant(-inf) + %reduce.153 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} reduce(%convolution.85.clone.1, %constant.1183), dimensions={2}, to_apply=%region_8.11, metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} + ROOT %tuple.165 = (bf16[4,128]{1,0:T(4,128)(2,1)S(1)}, bf16[4,128,151936]{2,1,0:T(8,128)(2,1)}) tuple(%reduce.153, %convolution.85.clone.1) } -%fused_computation.440 (param_0.1358: f32[2048,4,8,128]) -> bf16[4,2048,8,128] { - %param_0.1358 = f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)} parameter(0) - %bitcast.529 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} bitcast(%param_0.1358), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.153 = bf16[4,2048,8,128]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.529) +%fused_computation.451 (param_0.1399: f32[2048,4,8,128]) -> bf16[4,2048,8,128] { + %param_0.1399 = f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)} parameter(0) + %bitcast.506 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} bitcast(%param_0.1399), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.87 = bf16[4,2048,8,128]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.506) } -%convert_element_type.767.reduce_sub_computation (lhs.1: bf16[], rhs.1: bf16[]) -> bf16[] { - %rhs.1 = bf16[] parameter(1) - %lhs.1 = bf16[] parameter(0) - ROOT %add.755 = bf16[] add(%lhs.1, %rhs.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%convert_element_type.785.reduce_sub_computation (lhs: bf16[], rhs: bf16[]) -> bf16[] { + %rhs = bf16[] parameter(1) + %lhs = bf16[] parameter(0) + ROOT %add.730 = bf16[] add(%lhs, %rhs), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.155.clone.clone (param_0.1534: bf16[4,2048], param_1.1687: s32[]) -> bf16[2048] { - %param_0.1534 = bf16[4,2048]{1,0:T(4,128)(2,1)} parameter(0) - %param_1.1687 = s32[]{:T(128)S(6)} parameter(1) - %constant.1361 = s32[]{:T(128)} constant(0) - %dynamic_slice.388 = bf16[1,2048]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1534, %param_1.1687, %constant.1361), dynamic_slice_sizes={1,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %constant.1362 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %reduce.244 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.388, %constant.1362), dimensions={0}, to_apply=%convert_element_type.767.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.167.clone.clone (param_0.1572: bf16[4,2048], param_1.1711: s32[]) -> bf16[2048] { + %param_0.1572 = bf16[4,2048]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1711 = s32[]{:T(128)S(6)} parameter(1) + %constant.1348 = s32[]{:T(128)} constant(0) + %dynamic_slice.394 = bf16[1,2048]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1572, %param_1.1711, %constant.1348), dynamic_slice_sizes={1,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %constant.1349 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %reduce.174 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.394, %constant.1349), dimensions={0}, to_apply=%convert_element_type.785.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%region_14.16 (reduce_sum.204: f32[], reduce_sum.205: f32[]) -> f32[] { - %reduce_sum.205 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - %reduce_sum.204 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - ROOT %reduce_sum.206 = f32[]{:T(128)} add(%reduce_sum.204, %reduce_sum.205), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_14.16 (reduce_sum.278: f32[], reduce_sum.282: f32[]) -> f32[] { + %reduce_sum.282 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.278 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.283 = f32[]{:T(128)} add(%reduce_sum.278, %reduce_sum.282), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.58.clone.clone (param_0.1535: bf16[4,4,128,2048], param_1.1688: s32[]) -> f32[4,128] { - %param_0.1535 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1688 = s32[]{:T(128)S(6)} parameter(1) +%fused_computation.61.clone.clone (param_0.1573: bf16[4,4,128,2048], param_1.1712: s32[]) -> f32[4,128] { + %param_0.1573 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1712 = s32[]{:T(128)S(6)} parameter(1) + %constant.1350 = s32[]{:T(128)} constant(0) + %dynamic_slice.395 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1573, %param_1.1712, %constant.1350, %constant.1350, %constant.1350), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %bitcast.602 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.395), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} + %convert_element_type.1585 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%bitcast.602), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.280 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1585, %convert_element_type.1585), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1351 = f32[]{:T(128)} constant(0) + ROOT %reduce.175 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.280, %constant.1351), dimensions={2}, to_apply=%region_14.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} +} + +%fused_computation.190.clone.1.clone (param_0.1574: f32[4,128]) -> f32[4,128] { + %param_0.1574 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1353 = f32[]{:T(128)} constant(0.00048828125) + %closed_call.106 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1353), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.999 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1574, %closed_call.106), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1352 = f32[]{:T(128)} constant(1e-06) + %closed_call.105 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1352), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.1015 = f32[4,128]{1,0:T(4,128)} add(%div.999, %closed_call.105), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.181 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.1015), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +} + +%region_15.17 (reduce_sum.284: f32[], reduce_sum.285: f32[]) -> f32[] { + %reduce_sum.285 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.284 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.289 = f32[]{:T(128)} add(%reduce_sum.284, %reduce_sum.285), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.25.clone.1.clone.clone.clone.clone (param_0.1587: bf16[4,2048,16,128], param_1.1721: s32[]) -> bf16[2048,16,128,1] { + %param_0.1587 = bf16[4,2048,16,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.1721 = s32[]{:T(128)S(6)} parameter(1) %constant.1363 = s32[]{:T(128)} constant(0) - %dynamic_slice.389 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1535, %param_1.1688, %constant.1363, %constant.1363, %constant.1363), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %bitcast.633 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.389), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1564 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%bitcast.633), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %square.280 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1564, %convert_element_type.1564), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} - %constant.1364 = f32[]{:T(128)} constant(0) - ROOT %reduce.245 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.280, %constant.1364), dimensions={2}, to_apply=%region_14.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} -} - -%fused_computation.179.clone.1.clone (param_0.1536: f32[4,128]) -> f32[4,128] { - %param_0.1536 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %constant.1366 = f32[]{:T(128)} constant(0.00048828125) - %closed_call.106 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1366), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.999 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1536, %closed_call.106), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1365 = f32[]{:T(128)} constant(1e-06) - %closed_call.105 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1365), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %add.1039 = f32[4,128]{1,0:T(4,128)} add(%div.999, %closed_call.105), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %rsqrt.181 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.1039), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} -} - -%region_15.17 (reduce_sum.207: f32[], reduce_sum.211: f32[]) -> f32[] { - %reduce_sum.211 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - %reduce_sum.207 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - ROOT %reduce_sum.212 = f32[]{:T(128)} add(%reduce_sum.207, %reduce_sum.211), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.25.clone.1.clone.clone.clone.clone (param_0.1550: bf16[4,2048,16,128], param_1.1698: s32[]) -> bf16[2048,16,128,1] { - %param_0.1550 = bf16[4,2048,16,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %param_1.1698 = s32[]{:T(128)S(6)} parameter(1) - %constant.1377 = s32[]{:T(128)} constant(0) - %dynamic_slice.395 = bf16[1,2048,16,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1550, %param_1.1698, %constant.1377, %constant.1377, %constant.1377), dynamic_slice_sizes={1,2048,16,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.644 = bf16[2048,16,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.395), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} -} - -%fused_computation.114.clone.clone.clone.clone (param_0.1551: f32[4,128], param_1.1699: bf16[4,4,128,2048], param_2.1405: s32[], param_3.982: bf16[2048]) -> bf16[4,128,2048,1] { - %param_3.982 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %dot_general.571 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.982), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1699 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1405 = s32[]{:T(128)S(6)} parameter(2) - %constant.1378 = s32[]{:T(128)} constant(0) - %dynamic_slice.396 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1699, %param_2.1405, %constant.1378, %constant.1378, %constant.1378), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %bitcast.646 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.396), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1575 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%bitcast.646), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1551 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.2256 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1551), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2255 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1575, %mul.2256), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1574 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2255), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.570 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.571, %convert_element_type.1574), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %bitcast.645 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.570), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.61.clone.clone (param_0.1552: bf16[4,2048,16,128], param_1.1700: s32[], param_2.1406: f32[4,128], param_3.983: bf16[4,4,128,2048], param_4.604: bf16[2048]) -> (f32[4,128,16], bf16[4,128,16,128]) { - %param_2.1406 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.983 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(3) - %param_1.1700 = s32[]{:T(128)S(6)} parameter(1) - %param_4.604 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.74.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1406, %param_3.983, %param_1.1700, %param_4.604), kind=kLoop, calls=%fused_computation.114.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1552 = bf16[4,2048,16,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %fusion.49.clone.3 = bf16[2048,16,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1552, %param_1.1700), kind=kLoop, calls=%fused_computation.25.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %convolution.44.clone.3 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.74.clone.3, %fusion.49.clone.3), window={size=1x16 pad=0_0x15_15 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} - %convert_element_type.1576 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%convolution.44.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %square.282 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1576, %convert_element_type.1576), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} - %constant.1379 = f32[]{:T(128)} constant(0) - %reduce.247 = f32[4,128,16]{1,2,0:T(8,128)S(1)} reduce(%square.282, %constant.1379), dimensions={3}, to_apply=%region_15.17, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} - ROOT %tuple.208 = (f32[4,128,16]{1,2,0:T(8,128)S(1)}, bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%reduce.247, %convolution.44.clone.3) -} - -%fused_computation.151.clone.1.clone (param_0.1553: f32[4,128,16]) -> f32[4,128,16] { - %param_0.1553 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(0) - %constant.1380 = f32[]{:T(128)} constant(0.0078125) - %closed_call.108 = f32[4,128,16]{1,2,0:T(8,128)} broadcast(%constant.1380), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.1001 = f32[4,128,16]{1,2,0:T(8,128)} multiply(%param_0.1553, %closed_call.108), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1381 = f32[]{:T(128)} constant(1e-06) - %add.1044 = f32[4,128,16]{1,2,0:T(8,128)} broadcast(%constant.1381), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - %add.1043 = f32[4,128,16]{1,2,0:T(8,128)} add(%div.1001, %add.1044), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %rsqrt.183 = f32[4,128,16]{1,2,0:T(8,128)S(1)} rsqrt(%add.1043), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} -} - -%fused_computation.182.clone.clone (param_0.1549: bf16[4,128], param_1.1697: s32[]) -> bf16[128] { - %param_0.1549 = bf16[4,128]{1,0:T(4,128)(2,1)} parameter(0) - %param_1.1697 = s32[]{:T(128)S(6)} parameter(1) - %constant.1376 = s32[]{:T(128)} constant(0) - %dynamic_slice.394 = bf16[1,128]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1549, %param_1.1697, %constant.1376), dynamic_slice_sizes={1,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.643 = bf16[128]{0:T(256)(128)(2,1)S(1)} bitcast(%dynamic_slice.394), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %dynamic_slice.400 = bf16[1,2048,16,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1587, %param_1.1721, %constant.1363, %constant.1363, %constant.1363), dynamic_slice_sizes={1,2048,16,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.610 = bf16[2048,16,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.400), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.89.clone.2.clone.clone.clone.clone (param_0.1588: f32[4,128], param_1.1722: bf16[2048], param_2.1414: bf16[4,4,128,2048], param_3.972: s32[]) -> bf16[4,128,2048,1] { + %param_2.1414 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.972 = s32[]{:T(128)S(6)} parameter(3) + %constant.1364 = s32[]{:T(128)} constant(0) + %dynamic_slice.401 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1414, %param_3.972, %constant.1364, %constant.1364, %constant.1364), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %convert_element_type.1596 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} convert(%dynamic_slice.401), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1588 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2890 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_0.1588), dimensions={1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2889 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} multiply(%convert_element_type.1596, %mul.2890), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1595 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} convert(%mul.2889), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1722 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(1) + %mul.2891 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} broadcast(%param_1.1722), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2888 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1595, %mul.2891), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.611 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2888), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.64.clone.clone (param_0.1589: bf16[4,2048,16,128], param_1.1723: s32[], param_2.1415: f32[4,128], param_3.973: bf16[2048], param_4.596: bf16[4,4,128,2048]) -> (f32[4,128,16], bf16[4,128,16,128]) { + %param_2.1415 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.973 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %param_4.596 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(4) + %param_1.1723 = s32[]{:T(128)S(6)} parameter(1) + %fusion.91.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1415, %param_3.973, %param_4.596, %param_1.1723), kind=kLoop, calls=%fused_computation.89.clone.2.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1589 = bf16[4,2048,16,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.49.clone.3 = bf16[2048,16,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1589, %param_1.1723), kind=kLoop, calls=%fused_computation.25.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.44.clone.3 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.91.clone.3, %fusion.49.clone.3), window={size=1x16 pad=0_0x15_15 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + %convert_element_type.1597 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%convolution.44.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.282 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1597, %convert_element_type.1597), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1365 = f32[]{:T(128)} constant(0) + %reduce.177 = f32[4,128,16]{1,2,0:T(8,128)S(1)} reduce(%square.282, %constant.1365), dimensions={3}, to_apply=%region_15.17, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.210 = (f32[4,128,16]{1,2,0:T(8,128)S(1)}, bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%reduce.177, %convolution.44.clone.3) +} + +%fused_computation.162.clone.1.clone (param_0.1590: f32[4,128,16]) -> f32[4,128,16] { + %param_0.1590 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(0) + %constant.1366 = f32[]{:T(128)} constant(0.0078125) + %closed_call.108 = f32[4,128,16]{1,2,0:T(8,128)} broadcast(%constant.1366), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.1001 = f32[4,128,16]{1,2,0:T(8,128)} multiply(%param_0.1590, %closed_call.108), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1367 = f32[]{:T(128)} constant(1e-06) + %add.1020 = f32[4,128,16]{1,2,0:T(8,128)} broadcast(%constant.1367), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %add.1019 = f32[4,128,16]{1,2,0:T(8,128)} add(%div.1001, %add.1020), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.183 = f32[4,128,16]{1,2,0:T(8,128)S(1)} rsqrt(%add.1019), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +} + +%fused_computation.193.clone.clone (param_0.1591: bf16[4,128], param_1.1724: s32[]) -> bf16[128] { + %param_0.1591 = bf16[4,128]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1724 = s32[]{:T(128)S(6)} parameter(1) + %constant.1368 = s32[]{:T(128)} constant(0) + %dynamic_slice.402 = bf16[1,128]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1591, %param_1.1724, %constant.1368), dynamic_slice_sizes={1,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.612 = bf16[128]{0:T(256)(128)(2,1)S(1)} bitcast(%dynamic_slice.402), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.121.clone.1.clone (param_0.1554: f32[4,128,16], param_1.1701: bf16[4,128,16,128], param_2.1407: bf16[128]) -> bf16[4,128,16,128] { - %param_2.1407 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(2) - %dot_general.573 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1407), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1701 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1578 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%param_1.1701), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1554 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(0) - %mul.2258 = f32[4,128,16,128]{3,1,2,0:T(8,128)} broadcast(%param_0.1554), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2257 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1578, %mul.2258), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1577 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2257), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %dot_general.572 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} multiply(%dot_general.573, %convert_element_type.1577), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} +%fused_computation.118.clone.1.clone (param_0.1592: f32[4,128,16], param_1.1725: bf16[4,128,16,128], param_2.1416: bf16[128]) -> bf16[4,128,16,128] { + %param_1.1725 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1599 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%param_1.1725), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1592 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(0) + %mul.2894 = f32[4,128,16,128]{3,1,2,0:T(8,128)} broadcast(%param_0.1592), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2893 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1599, %mul.2894), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1598 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2893), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1416 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(2) + %mul.2895 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1416), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2892 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} multiply(%convert_element_type.1598, %mul.2895), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} } -%fused_computation.90.clone.clone (param_0.1555: bf16[4,128,16,128]) -> (bf16[4,128,16,64], bf16[4,128,16,64]) { - %param_0.1555 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %split.160 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1555), slice={[0:4], [0:128], [0:16], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} +%fused_computation.93.clone.clone (param_0.1593: bf16[4,128,16,128]) -> (bf16[4,128,16,64], bf16[4,128,16,64]) { + %param_0.1593 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.160 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1593), slice={[0:4], [0:128], [0:16], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} %neg.129 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.160), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=0} - %split.161 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1555), slice={[0:4], [0:128], [0:16], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} - ROOT %tuple.209 = (bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.129, %split.161) + %split.161 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1593), slice={[0:4], [0:128], [0:16], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + ROOT %tuple.211 = (bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.129, %split.161) } -%fused_computation.187.clone.clone () -> f32[64] { - %constant.1355 = f32[]{:T(128)} constant(1e+06) - %closed_call.104 = f32[64]{0:T(128)} broadcast(%constant.1355), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} +%fused_computation.198.clone.clone () -> f32[64] { + %constant.1343 = f32[]{:T(128)} constant(1e+06) + %closed_call.104 = f32[64]{0:T(128)} broadcast(%constant.1343), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} %iota.51 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/iota" stack_frame_id=0} - %constant.1354 = s32[]{:T(128)} constant(2) - %closed_call.103 = s32[64]{0:T(128)} broadcast(%constant.1354), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %mul.2242 = s32[64]{0:T(128)} multiply(%iota.51, %closed_call.103), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1562 = f32[64]{0:T(128)} convert(%mul.2242), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %constant.1356 = f32[]{:T(128)} constant(0.0078125) - %closed_call.102 = f32[64]{0:T(128)} broadcast(%constant.1356), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.995 = f32[64]{0:T(128)} multiply(%convert_element_type.1562, %closed_call.102), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1342 = s32[]{:T(128)} constant(2) + %closed_call.103 = s32[64]{0:T(128)} broadcast(%constant.1342), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %mul.2867 = s32[64]{0:T(128)} multiply(%iota.51, %closed_call.103), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1583 = f32[64]{0:T(128)} convert(%mul.2867), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %constant.1344 = f32[]{:T(128)} constant(0.0078125) + %closed_call.102 = f32[64]{0:T(128)} broadcast(%constant.1344), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.995 = f32[64]{0:T(128)} multiply(%convert_element_type.1583, %closed_call.102), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} ROOT %pow.38 = f32[64]{0:T(128)S(1)} power(%closed_call.104, %div.995), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/pow" stack_frame_id=0} } -%fused_computation.143.clone.clone (param_0.1529: f32[64], param_1.1683: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { - %param_1.1683 = f32[4,128]{1,0:T(4,128)} parameter(1) - %div.998 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1683), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %param_0.1529 = f32[64]{0:T(128)S(1)} parameter(0) - %div.997 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1529), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} +%fused_computation.154.clone.clone (param_0.1569: f32[64], param_1.1709: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.1709 = f32[4,128]{1,0:T(4,128)} parameter(1) + %div.998 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1709), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %param_0.1569 = f32[64]{0:T(128)S(1)} parameter(0) + %div.997 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1569), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} %div.996 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.998, %div.997), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} %cos.43 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.996), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/cos" stack_frame_id=0} - %convert_element_type.1563 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.43), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convert_element_type.1584 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.43), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} %sin.35.clone.3 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.996), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/sin" stack_frame_id=0} - %convert_element_type.1189.clone.3 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.35.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %tuple.205 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1563, %convert_element_type.1189.clone.3) + %convert_element_type.1213.clone.3 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.35.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.207 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1584, %convert_element_type.1213.clone.3) } -%fused_computation.146.clone.1.clone (param_0.1530: bf16[4,128,1,64]) -> bf16[4,128,128] { - %param_0.1530 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.1357 = bf16[]{:T(256)} constant(-inf) - %pad.69 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1530, %constant.1357), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %pad.68 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1530, %constant.1357), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} +%fused_computation.157.clone.1.clone (param_0.1570: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1570 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1345 = bf16[]{:T(256)} constant(-inf) + %pad.69 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1570, %constant.1345), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %pad.68 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1570, %constant.1345), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.53 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.69, %pad.68), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - ROOT %bitcast.630 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.53), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.601 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.53), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} } -%fused_computation.145.clone.1.clone (param_0.1545: bf16[4,128,1,64]) -> bf16[4,128,128] { - %param_0.1545 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.1374 = bf16[]{:T(256)} constant(-inf) - %pad.71 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1545, %constant.1374), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %pad.70 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1545, %constant.1374), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} +%fused_computation.156.clone.1.clone (param_0.1583: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1583 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1361 = bf16[]{:T(256)} constant(-inf) + %pad.71 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1583, %constant.1361), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %pad.70 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1583, %constant.1361), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.54 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.71, %pad.70), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - ROOT %bitcast.641 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.54), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} -} - -%fused_computation.94.clone.clone (param_0.1556: bf16[4,128,16,64], param_1.1702: bf16[4,128,16,64], param_2.1408: bf16[4,128,128], param_3.984: bf16[4,128,128], param_4.605: f32[4,128,16], param_5.499: bf16[4,128,16,128], param_6.384: bf16[128]) -> bf16[4,16,128,128] { - %param_6.384 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(6) - %dot_general.575 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_6.384), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_5.499 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(5) - %convert_element_type.1580 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%param_5.499), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_4.605 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(4) - %mul.2265 = f32[4,128,16,128]{3,1,2,0:T(8,128)} broadcast(%param_4.605), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2264 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1580, %mul.2265), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1579 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2264), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.574 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%dot_general.575, %convert_element_type.1579), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_3.984 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %mul.2263 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.984), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2261 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%dot_general.574, %mul.2263), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %param_1.1702 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) - %constant.1382 = bf16[]{:T(256)} constant(-inf) - %pad.75 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1702, %constant.1382), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_0.1556 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %pad.74 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1556, %constant.1382), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + ROOT %bitcast.608 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.54), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.102.clone.clone (param_0.1594: bf16[4,128,16,64], param_1.1726: bf16[4,128,16,64], param_2.1417: bf16[4,128,128], param_3.974: bf16[4,128,128], param_4.597: bf16[128], param_5.512: f32[4,128,16], param_6.382: bf16[4,128,16,128]) -> bf16[4,16,128,128] { + %param_6.382 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(6) + %convert_element_type.1601 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%param_6.382), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_5.512 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(5) + %mul.2904 = f32[4,128,16,128]{3,1,2,0:T(8,128)} broadcast(%param_5.512), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2903 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1601, %mul.2904), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1600 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2903), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_4.597 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(4) + %mul.2902 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.597), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2901 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%convert_element_type.1600, %mul.2902), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_3.974 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.2900 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.974), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2898 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%mul.2901, %mul.2900), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_1.1726 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1369 = bf16[]{:T(256)} constant(-inf) + %pad.75 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1726, %constant.1369), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_0.1594 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.74 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1594, %constant.1369), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.56 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.75, %pad.74), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_2.1408 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %mul.2262 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1408), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2260 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.56, %mul.2262), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %add.1045 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2261, %mul.2260), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - %constant.1383 = bf16[]{:T(256)} constant(0.08838) - %closed_call.109 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%constant.1383), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %mul.2259 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%add.1045, %closed_call.109), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - ROOT %bitcast.647 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%mul.2259), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} -} - -%region_16.18 (reduce_sum.213: f32[], reduce_sum.214: f32[]) -> f32[] { - %reduce_sum.214 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - %reduce_sum.213 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - ROOT %reduce_sum.218 = f32[]{:T(128)} add(%reduce_sum.213, %reduce_sum.214), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.69.clone.1.clone.clone.clone.clone (param_0.1541: bf16[4,2048,8,128], param_1.1692: s32[]) -> bf16[2048,8,128,1] { - %param_0.1541 = bf16[4,2048,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %param_1.1692 = s32[]{:T(128)S(6)} parameter(1) - %constant.1369 = s32[]{:T(128)} constant(0) - %dynamic_slice.392 = bf16[1,2048,8,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1541, %param_1.1692, %constant.1369, %constant.1369, %constant.1369), dynamic_slice_sizes={1,2048,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.638 = bf16[2048,8,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.392), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} -} - -%fused_computation.113.clone.clone.clone.clone (param_0.1542: f32[4,128], param_1.1693: bf16[4,4,128,2048], param_2.1401: s32[], param_3.979: bf16[2048]) -> bf16[4,128,2048,1] { - %param_3.979 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %dot_general.565 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.979), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1693 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1401 = s32[]{:T(128)S(6)} parameter(2) - %constant.1370 = s32[]{:T(128)} constant(0) - %dynamic_slice.393 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1693, %param_2.1401, %constant.1370, %constant.1370, %constant.1370), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %bitcast.640 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.393), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1568 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%bitcast.640), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1542 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.2246 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1542), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2245 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1568, %mul.2246), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1567 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2245), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.564 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.565, %convert_element_type.1567), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %bitcast.639 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.564), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.84.clone.clone (param_0.1543: bf16[4,2048,8,128], param_1.1694: s32[], param_2.1402: f32[4,128], param_3.980: bf16[4,4,128,2048], param_4.602: bf16[2048]) -> (f32[4,128,8], bf16[4,128,8,128]) { - %param_2.1402 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.980 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(3) - %param_1.1694 = s32[]{:T(128)S(6)} parameter(1) - %param_4.602 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.73.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1402, %param_3.980, %param_1.1694, %param_4.602), kind=kLoop, calls=%fused_computation.113.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1543 = bf16[4,2048,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %fusion.87.clone.3 = bf16[2048,8,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1543, %param_1.1694), kind=kLoop, calls=%fused_computation.69.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %convolution.50.clone.3 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.73.clone.3, %fusion.87.clone.3), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} - %convert_element_type.1569 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%convolution.50.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %square.281 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1569, %convert_element_type.1569), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} - %constant.1371 = f32[]{:T(128)} constant(0) - %reduce.246 = f32[4,128,8]{1,2,0:T(8,128)S(1)} reduce(%square.281, %constant.1371), dimensions={3}, to_apply=%region_16.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} - ROOT %tuple.206 = (f32[4,128,8]{1,2,0:T(8,128)S(1)}, bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%reduce.246, %convolution.50.clone.3) -} - -%fused_computation.154.clone.1.clone (param_0.1544: f32[4,128,8]) -> f32[4,128,8] { - %param_0.1544 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(0) - %constant.1372 = f32[]{:T(128)} constant(0.0078125) - %closed_call.107 = f32[4,128,8]{1,2,0:T(8,128)} broadcast(%constant.1372), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.1000 = f32[4,128,8]{1,2,0:T(8,128)} multiply(%param_0.1544, %closed_call.107), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1373 = f32[]{:T(128)} constant(1e-06) - %add.1041 = f32[4,128,8]{1,2,0:T(8,128)} broadcast(%constant.1373), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - %add.1040 = f32[4,128,8]{1,2,0:T(8,128)} add(%div.1000, %add.1041), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %rsqrt.182 = f32[4,128,8]{1,2,0:T(8,128)S(1)} rsqrt(%add.1040), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} -} - -%fused_computation.184.clone.clone (param_0.1528: bf16[4,128], param_1.1682: s32[]) -> bf16[128] { - %param_0.1528 = bf16[4,128]{1,0:T(4,128)(2,1)} parameter(0) - %param_1.1682 = s32[]{:T(128)S(6)} parameter(1) - %constant.1353 = s32[]{:T(128)} constant(0) - %dynamic_slice.385 = bf16[1,128]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1528, %param_1.1682, %constant.1353), dynamic_slice_sizes={1,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.629 = bf16[128]{0:T(256)(128)(2,1)S(1)} bitcast(%dynamic_slice.385), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} -} - -%fused_computation.139.clone.1.clone (param_0.1546: f32[4,128,8], param_1.1695: bf16[4,128,8,128], param_2.1403: bf16[128]) -> bf16[4,128,8,128] { - %param_2.1403 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(2) - %dot_general.567 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1403), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1695 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1571 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%param_1.1695), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1546 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(0) - %mul.2248 = f32[4,128,8,128]{3,1,2,0:T(8,128)} broadcast(%param_0.1546), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2247 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1571, %mul.2248), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1570 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2247), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %dot_general.566 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} multiply(%dot_general.567, %convert_element_type.1570), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.126.clone.clone (param_0.1547: bf16[4,128,8,128]) -> (bf16[4,128,8,64], bf16[4,128,8,64]) { - %param_0.1547 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %split.158 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1547), slice={[0:4], [0:128], [0:8], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + %param_2.1417 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %mul.2899 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1417), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2897 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.56, %mul.2899), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %add.1021 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2898, %mul.2897), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %constant.1370 = bf16[]{:T(256)} constant(0.08838) + %closed_call.109 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%constant.1370), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %mul.2896 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%add.1021, %closed_call.109), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.613 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%mul.2896), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%region_16.18 (reduce_sum.290: f32[], reduce_sum.291: f32[]) -> f32[] { + %reduce_sum.291 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.290 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.292 = f32[]{:T(128)} add(%reduce_sum.290, %reduce_sum.291), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.70.clone.1.clone.clone.clone.clone (param_0.1579: bf16[4,2048,8,128], param_1.1716: s32[]) -> bf16[2048,8,128,1] { + %param_0.1579 = bf16[4,2048,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.1716 = s32[]{:T(128)S(6)} parameter(1) + %constant.1356 = s32[]{:T(128)} constant(0) + %dynamic_slice.398 = bf16[1,2048,8,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1579, %param_1.1716, %constant.1356, %constant.1356, %constant.1356), dynamic_slice_sizes={1,2048,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.606 = bf16[2048,8,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.398), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.89.clone.1.clone.clone.clone.clone (param_0.1580: f32[4,128], param_1.1717: bf16[2048], param_2.1410: bf16[4,4,128,2048], param_3.969: s32[]) -> bf16[4,128,2048,1] { + %param_2.1410 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.969 = s32[]{:T(128)S(6)} parameter(3) + %constant.1357 = s32[]{:T(128)} constant(0) + %dynamic_slice.399 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1410, %param_3.969, %constant.1357, %constant.1357, %constant.1357), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %convert_element_type.1589 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} convert(%dynamic_slice.399), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1580 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2874 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_0.1580), dimensions={1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2873 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} multiply(%convert_element_type.1589, %mul.2874), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1588 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} convert(%mul.2873), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1717 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(1) + %mul.2875 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} broadcast(%param_1.1717), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2872 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1588, %mul.2875), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.607 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2872), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.85.clone.clone (param_0.1581: bf16[4,2048,8,128], param_1.1718: s32[], param_2.1411: f32[4,128], param_3.970: bf16[2048], param_4.594: bf16[4,4,128,2048]) -> (f32[4,128,8], bf16[4,128,8,128]) { + %param_2.1411 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.970 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %param_4.594 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(4) + %param_1.1718 = s32[]{:T(128)S(6)} parameter(1) + %fusion.90.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1411, %param_3.970, %param_4.594, %param_1.1718), kind=kLoop, calls=%fused_computation.89.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1581 = bf16[4,2048,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.85.clone.3 = bf16[2048,8,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1581, %param_1.1718), kind=kLoop, calls=%fused_computation.70.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.56.clone.3 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.90.clone.3, %fusion.85.clone.3), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + %convert_element_type.1590 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%convolution.56.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.281 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1590, %convert_element_type.1590), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1358 = f32[]{:T(128)} constant(0) + %reduce.176 = f32[4,128,8]{1,2,0:T(8,128)S(1)} reduce(%square.281, %constant.1358), dimensions={3}, to_apply=%region_16.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.208 = (f32[4,128,8]{1,2,0:T(8,128)S(1)}, bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%reduce.176, %convolution.56.clone.3) +} + +%fused_computation.165.clone.1.clone (param_0.1582: f32[4,128,8]) -> f32[4,128,8] { + %param_0.1582 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(0) + %constant.1359 = f32[]{:T(128)} constant(0.0078125) + %closed_call.107 = f32[4,128,8]{1,2,0:T(8,128)} broadcast(%constant.1359), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.1000 = f32[4,128,8]{1,2,0:T(8,128)} multiply(%param_0.1582, %closed_call.107), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1360 = f32[]{:T(128)} constant(1e-06) + %add.1017 = f32[4,128,8]{1,2,0:T(8,128)} broadcast(%constant.1360), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %add.1016 = f32[4,128,8]{1,2,0:T(8,128)} add(%div.1000, %add.1017), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.182 = f32[4,128,8]{1,2,0:T(8,128)S(1)} rsqrt(%add.1016), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +} + +%fused_computation.195.clone.clone (param_0.1568: bf16[4,128], param_1.1708: s32[]) -> bf16[128] { + %param_0.1568 = bf16[4,128]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1708 = s32[]{:T(128)S(6)} parameter(1) + %constant.1341 = s32[]{:T(128)} constant(0) + %dynamic_slice.392 = bf16[1,128]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1568, %param_1.1708, %constant.1341), dynamic_slice_sizes={1,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.600 = bf16[128]{0:T(256)(128)(2,1)S(1)} bitcast(%dynamic_slice.392), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.144.clone.1.clone (param_0.1584: f32[4,128,8], param_1.1719: bf16[4,128,8,128], param_2.1412: bf16[128]) -> bf16[4,128,8,128] { + %param_1.1719 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1592 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%param_1.1719), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1584 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(0) + %mul.2878 = f32[4,128,8,128]{3,1,2,0:T(8,128)} broadcast(%param_0.1584), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2877 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1592, %mul.2878), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1591 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2877), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1412 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(2) + %mul.2879 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1412), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2876 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} multiply(%convert_element_type.1591, %mul.2879), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.125.clone.clone (param_0.1585: bf16[4,128,8,128]) -> (bf16[4,128,8,64], bf16[4,128,8,64]) { + %param_0.1585 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.158 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1585), slice={[0:4], [0:128], [0:8], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} %neg.128 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.158), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=0} - %split.159 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1547), slice={[0:4], [0:128], [0:8], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} - ROOT %tuple.207 = (bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.128, %split.159) -} - -%fused_computation.129.clone.clone (param_0.1548: bf16[4,128,8,64], param_1.1696: bf16[4,128,8,64], param_2.1404: bf16[4,128,128], param_3.981: bf16[4,128,128], param_4.603: f32[4,128,8], param_5.498: bf16[4,128,8,128], param_6.383: bf16[128]) -> bf16[4,8,128,128] { - %param_6.383 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(6) - %dot_general.569 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_6.383), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_5.498 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(5) - %convert_element_type.1573 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%param_5.498), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_4.603 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(4) - %mul.2254 = f32[4,128,8,128]{3,1,2,0:T(8,128)} broadcast(%param_4.603), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2253 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1573, %mul.2254), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1572 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2253), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.568 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%dot_general.569, %convert_element_type.1572), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_3.981 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %mul.2252 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.981), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2250 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%dot_general.568, %mul.2252), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %param_1.1696 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) - %constant.1375 = bf16[]{:T(256)} constant(-inf) - %pad.73 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1696, %constant.1375), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_0.1548 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %pad.72 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1548, %constant.1375), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %split.159 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1585), slice={[0:4], [0:128], [0:8], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + ROOT %tuple.209 = (bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.128, %split.159) +} + +%fused_computation.128.clone.clone (param_0.1586: bf16[4,128,8,64], param_1.1720: bf16[4,128,8,64], param_2.1413: bf16[4,128,128], param_3.971: bf16[4,128,128], param_4.595: bf16[128], param_5.511: f32[4,128,8], param_6.381: bf16[4,128,8,128]) -> bf16[4,8,128,128] { + %param_6.381 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(6) + %convert_element_type.1594 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%param_6.381), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_5.511 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(5) + %mul.2887 = f32[4,128,8,128]{3,1,2,0:T(8,128)} broadcast(%param_5.511), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2886 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1594, %mul.2887), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1593 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2886), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_4.595 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(4) + %mul.2885 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.595), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2884 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%convert_element_type.1593, %mul.2885), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_3.971 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.2883 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.971), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2881 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%mul.2884, %mul.2883), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_1.1720 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1362 = bf16[]{:T(256)} constant(-inf) + %pad.73 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1720, %constant.1362), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_0.1586 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.72 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1586, %constant.1362), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.55 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.73, %pad.72), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_2.1404 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %mul.2251 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1404), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2249 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.55, %mul.2251), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %add.1042 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2250, %mul.2249), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %bitcast.642 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.1042), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} -} - -%fused_computation.169.clone.clone (param_0.1537: bf16[4,2048,8,128], param_1.1689: s32[]) -> bf16[1,2048,8,128] { - %param_0.1537 = bf16[4,2048,8,128]{3,2,0,1:T(8,128)(2,1)} parameter(0) - %param_1.1689 = s32[]{:T(128)S(6)} parameter(1) - %constant.1367 = s32[]{:T(128)} constant(0) - ROOT %dynamic_slice.390 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} dynamic-slice(%param_0.1537, %param_1.1689, %constant.1367, %constant.1367, %constant.1367), dynamic_slice_sizes={1,2048,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} -} - -%fused_computation.70.clone.1.clone.clone.clone.clone (param_0.1538: bf16[1,2048,8,128]) -> bf16[2048,8,128,1] { - %param_0.1538 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) - %copy.204 = bf16[1,2048,8,128]{3,1,2,0:T(8,128)(2,1)} copy(%param_0.1538), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0} - ROOT %bitcast.634 = bf16[2048,8,128,1]{2,0,1,3:T(8,128)(2,1)} bitcast(%copy.204), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} -} - -%fused_computation.111.clone.clone.clone.clone (param_0.1539: f32[4,128], param_1.1690: bf16[4,4,128,2048], param_2.1399: s32[], param_3.977: bf16[2048]) -> bf16[4,128,2048,1] { - %param_3.977 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %dot_general.563 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.977), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1690 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1399 = s32[]{:T(128)S(6)} parameter(2) - %constant.1368 = s32[]{:T(128)} constant(0) - %dynamic_slice.391 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1690, %param_2.1399, %constant.1368, %constant.1368, %constant.1368), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %bitcast.636 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.391), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1566 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%bitcast.636), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1539 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.2244 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1539), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2243 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1566, %mul.2244), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1565 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2243), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.562 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.563, %convert_element_type.1565), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %bitcast.635 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.562), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.140.clone.clone (param_0.1540: bf16[1,2048,8,128], param_1.1691: f32[4,128], param_2.1400: bf16[4,4,128,2048], param_3.978: s32[], param_4.601: bf16[2048]) -> bf16[4,8,128,128] { - %param_1.1691 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.1400 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(2) - %param_3.978 = s32[]{:T(128)S(6)} parameter(3) - %param_4.601 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.373 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_1.1691, %param_2.1400, %param_3.978, %param_4.601), kind=kLoop, calls=%fused_computation.111.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1540 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) - %fusion.372 = bf16[2048,8,128,1]{2,0,1,3:T(8,128)(2,1)} fusion(%param_0.1540), kind=kLoop, calls=%fused_computation.70.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %convolution.106 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convolution(%fusion.373, %fusion.372), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} - ROOT %bitcast.637 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%convolution.106), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} -} - -%fused_computation.188.clone.clone (param_0.1578: f32[4,16,128,128]) -> (f32[4,16,128], f32[4,16,128,1]) { - %param_0.1578 = f32[4,16,128,128]{2,1,0,3:T(8,128)S(1)} parameter(0) - %slice.11 = f32[4,16,128,1]{2,1,0,3:T(8,128)S(1)} slice(%param_0.1578), slice={[0:4], [0:16], [0:128], [0:1]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/slice" stack_frame_id=0} - %bitcast.660 = f32[4,16,128]{2,1,0:T(8,128)S(1)} bitcast(%slice.11), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/squeeze" stack_frame_id=0} - ROOT %tuple.213 = (f32[4,16,128]{2,1,0:T(8,128)S(1)}, f32[4,16,128,1]{2,1,0,3:T(8,128)S(1)}) tuple(%bitcast.660, %slice.11) -} - -%region_17.20 (reduce_sum.219: f32[], reduce_sum.220: f32[]) -> f32[] { - %reduce_sum.220 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - %reduce_sum.219 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - ROOT %reduce_sum.221 = f32[]{:T(128)} add(%reduce_sum.219, %reduce_sum.220), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.26.clone.1.clone.clone.clone.clone.clone.clone (param_0.1557: bf16[4,16,128,2048], param_1.1703: s32[]) -> bf16[16,128,2048,1] { - %param_0.1557 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1703 = s32[]{:T(128)S(6)} parameter(1) - %constant.1384 = s32[]{:T(128)} constant(0) - %dynamic_slice.397 = bf16[1,16,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1557, %param_1.1703, %constant.1384, %constant.1384, %constant.1384), dynamic_slice_sizes={1,16,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.648 = bf16[16,128,2048,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dynamic_slice.397), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} -} - -%fused_computation.103.clone.clone.clone.clone.clone.clone (param_0.1558: bf16[4,16,128,128]) -> bf16[4,128,16,128] { - %param_0.1558 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) - ROOT %bitcast.649 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} bitcast(%param_0.1558), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} -} - -%fused_computation.64.clone.clone (param_0.1559: bf16[4,16,128,2048], param_1.1704: s32[], param_2.1409: bf16[4,16,128,128], param_3.985: bf16[4,4,128,2048]) -> (f32[4,128], bf16[4,128,2048]) { - %param_3.985 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(3) - %param_1.1704 = s32[]{:T(128)S(6)} parameter(1) - %constant.436.clone.1.clone.3 = s32[]{:T(128)} constant(0) - %dynamic_slice.242.clone.3 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_3.985, %param_1.1704, %constant.436.clone.1.clone.3, %constant.436.clone.1.clone.3, %constant.436.clone.1.clone.3), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %bitcast.227.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.242.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %param_2.1409 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %fusion.96.clone.3 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} fusion(%param_2.1409), kind=kLoop, calls=%fused_computation.103.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} - %param_0.1559 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) - %fusion.95.clone.3 = bf16[16,128,2048,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.1559, %param_1.1704), kind=kLoop, calls=%fused_computation.26.clone.1.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %convolution.62.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} convolution(%fusion.96.clone.3, %fusion.95.clone.3), window={size=1x16}, dim_labels=0b1f_1io0->0bf1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} - %bitcast.203.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%convolution.62.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} - %add.768.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} add(%bitcast.227.clone.3, %bitcast.203.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - %convert_element_type.1581 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%add.768.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %square.283 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1581, %convert_element_type.1581), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} - %constant.1385 = f32[]{:T(128)} constant(0) - %reduce.248 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.283, %constant.1385), dimensions={2}, to_apply=%region_17.20, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} - ROOT %tuple.210 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.248, %add.768.clone.3) -} - -%convert_element_type.763.reduce_sub_computation (lhs: bf16[], rhs: bf16[]) -> bf16[] { - %rhs = bf16[] parameter(1) - %lhs = bf16[] parameter(0) - ROOT %add.754 = bf16[] add(%lhs, %rhs), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %param_2.1413 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %mul.2882 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1413), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2880 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.55, %mul.2882), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %add.1018 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2881, %mul.2880), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %bitcast.609 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.1018), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.181.clone.clone (param_0.1575: bf16[4,2048,8,128], param_1.1713: s32[]) -> bf16[1,2048,8,128] { + %param_0.1575 = bf16[4,2048,8,128]{3,2,0,1:T(8,128)(2,1)} parameter(0) + %param_1.1713 = s32[]{:T(128)S(6)} parameter(1) + %constant.1354 = s32[]{:T(128)} constant(0) + ROOT %dynamic_slice.396 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} dynamic-slice(%param_0.1575, %param_1.1713, %constant.1354, %constant.1354, %constant.1354), dynamic_slice_sizes={1,2048,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +} + +%fused_computation.71.clone.1.clone.clone.clone.clone (param_0.1576: bf16[1,2048,8,128]) -> bf16[2048,8,128,1] { + %param_0.1576 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) + %copy.200 = bf16[1,2048,8,128]{3,1,2,0:T(8,128)(2,1)} copy(%param_0.1576), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0} + ROOT %bitcast.603 = bf16[2048,8,128,1]{2,0,1,3:T(8,128)(2,1)} bitcast(%copy.200), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.89.clone.clone.clone.clone.clone (param_0.1577: f32[4,128], param_1.1714: bf16[2048], param_2.1408: bf16[4,4,128,2048], param_3.967: s32[]) -> bf16[4,128,2048,1] { + %param_2.1408 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.967 = s32[]{:T(128)S(6)} parameter(3) + %constant.1355 = s32[]{:T(128)} constant(0) + %dynamic_slice.397 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1408, %param_3.967, %constant.1355, %constant.1355, %constant.1355), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %convert_element_type.1587 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} convert(%dynamic_slice.397), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1577 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2870 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_0.1577), dimensions={1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2869 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} multiply(%convert_element_type.1587, %mul.2870), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1586 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} convert(%mul.2869), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1714 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(1) + %mul.2871 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} broadcast(%param_1.1714), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2868 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1586, %mul.2871), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.604 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2868), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.151.clone.clone (param_0.1578: bf16[1,2048,8,128], param_1.1715: f32[4,128], param_2.1409: bf16[2048], param_3.968: bf16[4,4,128,2048], param_4.593: s32[]) -> bf16[4,8,128,128] { + %param_1.1715 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1409 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %param_3.968 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_4.593 = s32[]{:T(128)S(6)} parameter(4) + %fusion.380 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_1.1715, %param_2.1409, %param_3.968, %param_4.593), kind=kLoop, calls=%fused_computation.89.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1578 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) + %fusion.379 = bf16[2048,8,128,1]{2,0,1,3:T(8,128)(2,1)} fusion(%param_0.1578), kind=kLoop, calls=%fused_computation.71.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.105 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convolution(%fusion.380, %fusion.379), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + ROOT %bitcast.605 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%convolution.105), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.199.clone.clone (param_0.1618: f32[4,16,128,128]) -> (f32[4,16,128], f32[4,16,128,1]) { + %param_0.1618 = f32[4,16,128,128]{2,1,0,3:T(8,128)S(1)} parameter(0) + %slice.11 = f32[4,16,128,1]{2,1,0,3:T(8,128)S(1)} slice(%param_0.1618), slice={[0:4], [0:16], [0:128], [0:1]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/slice" stack_frame_id=0} + %bitcast.626 = f32[4,16,128]{2,1,0:T(8,128)S(1)} bitcast(%slice.11), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/squeeze" stack_frame_id=0} + ROOT %tuple.216 = (f32[4,16,128]{2,1,0:T(8,128)S(1)}, f32[4,16,128,1]{2,1,0,3:T(8,128)S(1)}) tuple(%bitcast.626, %slice.11) +} + +%region_17.20 (reduce_sum.296: f32[], reduce_sum.297: f32[]) -> f32[] { + %reduce_sum.297 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.296 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.298 = f32[]{:T(128)} add(%reduce_sum.296, %reduce_sum.297), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.26.clone.1.clone.clone.clone.clone.clone.clone (param_0.1595: bf16[4,16,128,2048], param_1.1727: s32[]) -> bf16[16,128,2048,1] { + %param_0.1595 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1727 = s32[]{:T(128)S(6)} parameter(1) + %constant.1371 = s32[]{:T(128)} constant(0) + %dynamic_slice.403 = bf16[1,16,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1595, %param_1.1727, %constant.1371, %constant.1371, %constant.1371), dynamic_slice_sizes={1,16,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.614 = bf16[16,128,2048,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dynamic_slice.403), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.110.clone.clone.clone.clone.clone.clone (param_0.1596: bf16[4,16,128,128]) -> bf16[4,128,16,128] { + %param_0.1596 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) + ROOT %bitcast.615 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} bitcast(%param_0.1596), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.67.clone.clone (param_0.1597: bf16[4,16,128,2048], param_1.1728: s32[], param_2.1418: bf16[4,16,128,128], param_3.975: bf16[4,4,128,2048]) -> (f32[4,128], bf16[4,128,2048]) { + %param_3.975 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1728 = s32[]{:T(128)S(6)} parameter(1) + %constant.414.clone.1.clone.3 = s32[]{:T(128)} constant(0) + %dynamic_slice.265.clone.3 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_3.975, %param_1.1728, %constant.414.clone.1.clone.3, %constant.414.clone.1.clone.3, %constant.414.clone.1.clone.3), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %bitcast.212.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.265.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} + %param_2.1418 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %fusion.100.clone.3 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} fusion(%param_2.1418), kind=kLoop, calls=%fused_computation.110.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} + %param_0.1597 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %fusion.99.clone.3 = bf16[16,128,2048,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.1597, %param_1.1728), kind=kLoop, calls=%fused_computation.26.clone.1.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.62.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} convolution(%fusion.100.clone.3, %fusion.99.clone.3), window={size=1x16}, dim_labels=0b1f_1io0->0bf1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + %bitcast.204.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%convolution.62.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + %add.744.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} add(%bitcast.212.clone.3, %bitcast.204.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %convert_element_type.1602 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%add.744.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.283 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1602, %convert_element_type.1602), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1372 = f32[]{:T(128)} constant(0) + %reduce.178 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.283, %constant.1372), dimensions={2}, to_apply=%region_17.20, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.212 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.178, %add.744.clone.3) +} + +%convert_element_type.808.reduce_sub_computation (lhs.1: bf16[], rhs.1: bf16[]) -> bf16[] { + %rhs.1 = bf16[] parameter(1) + %lhs.1 = bf16[] parameter(0) + ROOT %add.731 = bf16[] add(%lhs.1, %rhs.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.156.clone.clone (param_0.1531: bf16[4,2048], param_1.1684: s32[]) -> bf16[2048] { - %param_0.1531 = bf16[4,2048]{1,0:T(4,128)(2,1)} parameter(0) - %param_1.1684 = s32[]{:T(128)S(6)} parameter(1) - %constant.1358 = s32[]{:T(128)} constant(0) - %dynamic_slice.386 = bf16[1,2048]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1531, %param_1.1684, %constant.1358), dynamic_slice_sizes={1,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %constant.1359 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %reduce.243 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.386, %constant.1359), dimensions={0}, to_apply=%convert_element_type.763.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.166.clone.clone (param_0.1571: bf16[4,2048], param_1.1710: s32[]) -> bf16[2048] { + %param_0.1571 = bf16[4,2048]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1710 = s32[]{:T(128)S(6)} parameter(1) + %constant.1346 = s32[]{:T(128)} constant(0) + %dynamic_slice.393 = bf16[1,2048]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1571, %param_1.1710, %constant.1346), dynamic_slice_sizes={1,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %constant.1347 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %reduce.173 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.393, %constant.1347), dimensions={0}, to_apply=%convert_element_type.808.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.13.clone.clone.clone (param_0.1532: bf16[4,6144,2048], param_1.1685: s32[]) -> bf16[6144,2048,1] { - %param_0.1532 = bf16[4,6144,2048]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1685 = s32[]{:T(128)S(6)} parameter(1) - %constant.1360 = s32[]{:T(128)} constant(0) - %dynamic_slice.387 = bf16[1,6144,2048]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1532, %param_1.1685, %constant.1360, %constant.1360), dynamic_slice_sizes={1,6144,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.632 = bf16[6144,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.387), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.191.clone.1.clone (param_0.1598: f32[4,128]) -> f32[4,128] { + %param_0.1598 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1374 = f32[]{:T(128)} constant(0.00048828125) + %closed_call.111 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1374), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.1002 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1598, %closed_call.111), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1373 = f32[]{:T(128)} constant(1e-06) + %closed_call.110 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1373), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.1022 = f32[4,128]{1,0:T(4,128)} add(%div.1002, %closed_call.110), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.184 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.1022), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +} + +%fused_computation.11.clone.1.clone.clone (param_0.1599: bf16[4,2048,6144], param_1.1729: s32[]) -> bf16[2048,6144,1] { + %param_0.1599 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1729 = s32[]{:T(128)S(6)} parameter(1) + %constant.1375 = s32[]{:T(128)} constant(0) + %dynamic_slice.404 = bf16[1,2048,6144]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1599, %param_1.1729, %constant.1375, %constant.1375), dynamic_slice_sizes={1,2048,6144}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.616 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.404), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.116.clone.3.clone.clone (param_0.1600: f32[4,128], param_1.1730: bf16[4,128,2048], param_2.1419: bf16[2048]) -> bf16[4,128,2048] { + %param_1.1730 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1604 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1730), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1600 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2907 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1600), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2906 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1604, %mul.2907), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1603 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2906), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1419 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2908 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1419), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2905 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1603, %mul.2908), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.23.clone.clone (param_0.1601: bf16[4,2048,6144], param_1.1731: s32[], param_2.1420: f32[4,128], param_3.976: bf16[4,128,2048], param_4.598: bf16[2048]) -> bf16[4,128,6144] { + %param_2.1420 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.976 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.598 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.382 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1420, %param_3.976, %param_4.598), kind=kLoop, calls=%fused_computation.116.clone.3.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1601 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1731 = s32[]{:T(128)S(6)} parameter(1) + %fusion.381 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1601, %param_1.1731), kind=kLoop, calls=%fused_computation.11.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %convolution.106 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.382, %fusion.381), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} +} + +%fused_computation.14.clone.clone.clone (param_0.1604: bf16[4,2048,6144], param_1.1734: s32[]) -> bf16[2048,6144,1] { + %param_0.1604 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1734 = s32[]{:T(128)S(6)} parameter(1) + %constant.1377 = s32[]{:T(128)} constant(0) + %dynamic_slice.406 = bf16[1,2048,6144]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1604, %param_1.1734, %constant.1377, %constant.1377), dynamic_slice_sizes={1,2048,6144}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.619 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.406), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.116.clone.2.clone.clone (param_0.1605: f32[4,128], param_1.1735: bf16[4,128,2048], param_2.1422: bf16[2048]) -> bf16[4,128,2048] { + %param_1.1735 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1606 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1735), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1605 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2911 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1605), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2910 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1606, %mul.2911), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1605 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2910), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1422 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2912 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1422), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2909 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1605, %mul.2912), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.20.clone.clone (param_0.1606: bf16[4,2048,6144], param_1.1736: s32[], param_2.1423: f32[4,128], param_3.977: bf16[4,128,2048], param_4.599: bf16[2048]) -> bf16[4,128,6144] { + %param_2.1423 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.977 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.599 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.386 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1423, %param_3.977, %param_4.599), kind=kLoop, calls=%fused_computation.116.clone.2.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1606 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1736 = s32[]{:T(128)S(6)} parameter(1) + %fusion.385 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1606, %param_1.1736), kind=kLoop, calls=%fused_computation.14.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %convolution.108 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.386, %fusion.385), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} +} + +%fused_computation.12.clone.clone.clone (param_0.1602: bf16[4,6144,2048], param_1.1732: s32[]) -> bf16[6144,2048,1] { + %param_0.1602 = bf16[4,6144,2048]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1732 = s32[]{:T(128)S(6)} parameter(1) + %constant.1376 = s32[]{:T(128)} constant(0) + %dynamic_slice.405 = bf16[1,6144,2048]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1602, %param_1.1732, %constant.1376, %constant.1376), dynamic_slice_sizes={1,6144,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.618 = bf16[6144,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.405), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } %bitcast_fusion.1.clone.clone (bitcast_input.4: bf16[4,128,2048]) -> bf16[4,128,2048] { - %bitcast_input.4 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - ROOT %bitcast.631 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%bitcast_input.4) -} - -%fused_computation.14.clone.clone (param_0.1533: bf16[4,128,2048], param_1.1686: bf16[4,6144,2048], param_2.1398: s32[]) -> bf16[6144,4,128] { - %param_1.1686 = bf16[4,6144,2048]{2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1398 = s32[]{:T(128)S(6)} parameter(2) - %fusion.370 = bf16[6144,2048,1]{1,0,2:T(8,128)(2,1)} fusion(%param_1.1686, %param_2.1398), kind=kLoop, calls=%fused_computation.13.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1533 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %fusion.371 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1533), kind=kLoop, calls=%bitcast_fusion.1.clone.clone - ROOT %convolution.105 = bf16[6144,4,128]{0,2,1:T(8,128)(2,1)S(1)} convolution(%fusion.370, %fusion.371), window={size=4 pad=3_3 rhs_reversal=1}, dim_labels=bf0_0oi->b0f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} -} - -%fused_computation.180.clone.1.clone (param_0.1560: f32[4,128]) -> f32[4,128] { - %param_0.1560 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %constant.1387 = f32[]{:T(128)} constant(0.00048828125) - %closed_call.111 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1387), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.1002 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1560, %closed_call.111), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1386 = f32[]{:T(128)} constant(1e-06) - %closed_call.110 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1386), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %add.1046 = f32[4,128]{1,0:T(4,128)} add(%div.1002, %closed_call.110), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %rsqrt.184 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.1046), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} -} - -%fused_computation.12.clone.1.clone.clone (param_0.1564: bf16[4,2048,6144], param_1.1708: s32[]) -> bf16[2048,6144,1] { - %param_0.1564 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1708 = s32[]{:T(128)S(6)} parameter(1) - %constant.1389 = s32[]{:T(128)} constant(0) - %dynamic_slice.399 = bf16[1,2048,6144]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1564, %param_1.1708, %constant.1389, %constant.1389), dynamic_slice_sizes={1,2048,6144}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.651 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.399), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} -} - -%fused_computation.119.clone.3.clone.clone (param_0.1565: f32[4,128], param_1.1709: bf16[4,128,2048], param_2.1412: bf16[2048]) -> bf16[4,128,2048] { - %param_2.1412 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.579 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1412), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1709 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1585 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1709), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1565 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.2269 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1565), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2268 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1585, %mul.2269), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1584 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2268), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %dot_general.578 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.579, %convert_element_type.1584), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.21.clone.clone (param_0.1566: bf16[4,2048,6144], param_1.1710: s32[], param_2.1413: f32[4,128], param_3.987: bf16[4,128,2048], param_4.607: bf16[2048]) -> bf16[4,128,6144] { - %param_2.1413 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.987 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %param_4.607 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.377 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1413, %param_3.987, %param_4.607), kind=kLoop, calls=%fused_computation.119.clone.3.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1566 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1710 = s32[]{:T(128)S(6)} parameter(1) - %fusion.376 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1566, %param_1.1710), kind=kLoop, calls=%fused_computation.12.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %convolution.108 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.377, %fusion.376), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} -} - -%fused_computation.11.clone.1.clone.clone (param_0.1568: bf16[4,2048,6144], param_1.1712: s32[]) -> bf16[2048,6144,1] { - %param_0.1568 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1712 = s32[]{:T(128)S(6)} parameter(1) - %constant.1391 = s32[]{:T(128)} constant(0) - %dynamic_slice.400 = bf16[1,2048,6144]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1568, %param_1.1712, %constant.1391, %constant.1391), dynamic_slice_sizes={1,2048,6144}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.653 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.400), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} -} - -%fused_computation.47.clone.1.clone.clone (param_0.1567: bf16[6144,4,128], param_1.1711: bf16[4,128,6144]) -> bf16[4,128,6144] { - %param_1.1711 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %constant.1390 = bf16[]{:T(256)} constant(1) - %jit_silu_.44 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)} broadcast(%constant.1390), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)" stack_frame_id=0} - %neg.130 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)} negate(%param_1.1711), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/neg" stack_frame_id=0} - %exp.69 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)} exponential(%neg.130), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/exp" stack_frame_id=0} - %add.1047 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)} add(%exp.69, %jit_silu_.44), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/add" stack_frame_id=0} - %div.1003 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)} divide(%jit_silu_.44, %add.1047), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/div" stack_frame_id=0} - %mul.2271 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1711, %div.1003), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/mul" stack_frame_id=0} + %bitcast_input.4 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} parameter(0) + ROOT %bitcast.617 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%bitcast_input.4)