-
Notifications
You must be signed in to change notification settings - Fork 518
[ROCm] optimizations to reduce temp size prior to training starts #3922
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -88,8 +88,11 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> | |
| scale = jax.device_put(scale, max_utils.device_space()) | ||
|
|
||
| scale = jnp.asarray(scale, self.dtype) | ||
| effective_scale = scale + self.scale_offset | ||
| return jnp.einsum("...k,k->...k", y, effective_scale, out_sharding=out_sharding) | ||
| effective_scale = scale + self.scale_offset if self.scale_offset != 0.0 else scale | ||
| y = y * effective_scale | ||
| if out_sharding is not None: | ||
| y = jax.lax.with_sharding_constraint(y, out_sharding) | ||
| return y | ||
|
Collaborator
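There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We intentionally coded this as `jnp.einsum("...k,k->...k", y, effective_scale, out_sharding=out_sharding)` instead of an elementwise multiply followed by `jax.lax.with_sharding_constraint`. [The reviewer's inline code snippets and the reason introduced by "Since" were lost in extraction.] Could you try training a llama-like model using [configuration lost in extraction]? |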
||
|
|
||
|
|
||
| class GlobalRMSNorm(RMSNorm): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -118,6 +118,7 @@ def _maybe_shard_with_logical(self, inputs, logical_axes): | |
| rules=self.config.logical_axis_rules, | ||
| debug_sharding=self.config.debug_sharding, | ||
| extra_stack_level=1, | ||
| skip_trivial_specs=True, | ||
| ) | ||
|
|
||
| def _maybe_shard_with_name(self, inputs, sharding_name): | ||
|
|
@@ -139,7 +140,6 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): | |
| # Setup potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage) | ||
| state_io_batch_idx = loop_iteration % self.microbatches_per_stage | ||
| state_io_slice = state_io[:, state_io_batch_idx] | ||
| shift = self._maybe_shard_with_logical(shift, self.stages_in_logical) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are the sharding hints being removed? |
||
|
|
||
| if self.use_circ_storage: | ||
| # Setup potential input from circ_storage, which also has a rotating index for microbatch, | ||
|
|
@@ -154,7 +154,6 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): | |
| # state_io we instead grab from the last stage's output (possibly buffered when num_microbatches > num_stages, e.g. | ||
| # from circ_storage). | ||
| first_stage_in = jnp.where(loop_iteration < self.config.num_pipeline_microbatches, state_io_slice, circular_stage_in) | ||
| first_stage_in = self._maybe_shard_with_logical(first_stage_in, self.stages_in_logical) | ||
|
|
||
| # Note that first_stage_in may correspond to bubble computation during the last few iterations. | ||
| # However, these bubble computation results remain in the shift buffer (do not make it back to state_io) and are | ||
|
|
@@ -164,11 +163,7 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): | |
|
|
||
| def select_state_or_input(first_stage_in, shift): | ||
| # Selects input for stage 0, shift for other stages | ||
| return jnp.where( | ||
| jax.lax.broadcasted_iota("int32", shift.shape, 0, out_sharding=self.stages_in_sharding) == 0, | ||
| first_stage_in, | ||
| shift, | ||
| ) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is `out_sharding` being removed? |
||
| return jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_stage_in, shift) | ||
|
|
||
| # Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output) | ||
| stages_in = select_state_or_input(first_stage_in, shift) | ||
|
|
@@ -180,7 +175,6 @@ def get_microbatch_and_repeat_ids(self, loop_iteration): | |
| non-circular""" | ||
| # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is 1 behind due to bubble, etc for other stages | ||
| microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) | ||
| microbatches_processed = self._maybe_shard_with_name(microbatches_processed, NamedSharding(self.mesh, P("stage"))) | ||
| microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches | ||
| repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches | ||
| return microbatch_ids, repeat_ids | ||
|
|
@@ -190,7 +184,7 @@ def get_pipeline_remat_policy(self): | |
| if self.config.remat_policy == "custom": | ||
| return self.remat_policy | ||
|
|
||
| save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input") | ||
| save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input") | ||
| if self.remat_policy is not None: | ||
| remat_policy = jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy) | ||
| else: | ||
|
|
@@ -247,16 +241,6 @@ def get_main_vmap_func_for_iterations(self): | |
| def func_to_vmap( | ||
| body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode | ||
| ): | ||
| weights = meta.remove_axis( | ||
| weights, | ||
| 0, | ||
| { | ||
| nn.PARTITION_NAME: "layers", | ||
| "sub_weight_split_dims_mapping": (None,), | ||
| "is_initializing": self.is_initializing(), | ||
| "x_times": self.num_stages, | ||
| }, | ||
| ) | ||
| return body_instance.apply(weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) | ||
|
|
||
| vmap_func = nn.vmap( | ||
|
|
@@ -490,9 +474,7 @@ def vmap_gather(self, xs, ids, ids_dim): | |
| """ | ||
|
|
||
| def _gather_one(x, i): | ||
| idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim)) | ||
| replicated_sharding = NamedSharding(self.mesh, P()) | ||
| return x.at[idx].get(out_sharding=replicated_sharding) | ||
| return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, i, 1, ids_dim), ids_dim) | ||
|
|
||
| ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None) | ||
| outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) | ||
|
|
@@ -516,21 +498,16 @@ def get_new_loop_state(self, output, loop_state): | |
| loop_iteration = loop_state["loop_iteration"] | ||
| old_prev_outputs = loop_state["prev_outputs"] | ||
|
|
||
| @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) | ||
| def _rotate_right(arr): | ||
| # we use +1 for right shifting | ||
| stage_size = jax.lax.axis_size("stage") | ||
| perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] | ||
| arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) | ||
| return arr | ||
| # Use lax.slice to avoid generating a gather. | ||
| last = jax.lax.slice_in_dim(arr, self.num_stages - 1, self.num_stages, axis=0) | ||
| except_last = jax.lax.slice_in_dim(arr, 0, self.num_stages - 1, axis=0) | ||
| return jnp.concatenate([last, except_last], axis=0) | ||
|
|
||
| @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) | ||
| def _shift_right(arr): | ||
| stage_idx = jax.lax.axis_index("stage") | ||
| stage_size = jax.lax.axis_size("stage") | ||
| perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] | ||
| arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) | ||
| return jnp.where(stage_idx == 0, jnp.zeros_like(arr), arr) | ||
| padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1) | ||
| # Use lax.slice to guarantee the gradient is a pad. | ||
| return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape) | ||
|
|
||
| # Shift either rotates or shifts depending on if the last stage immediately must send to first or not | ||
| # For non-circular pipelines, the last stage does not need to send to first | ||
|
|
@@ -574,29 +551,17 @@ def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): | |
| stream_buf_idx = loop_iteration % self.microbatches_per_stage | ||
| stream_slice = old_state_io[:, stream_buf_idx] | ||
|
|
||
| def _rotate_left(arr, stage_size): | ||
| # we use -1 for left shifting | ||
| perm = [(i, (i - 1) % stage_size) for i in range(stage_size)] | ||
| return jax.lax.ppermute(arr, axis_name="stage", perm=perm) | ||
|
|
||
| def _shift_left(arr, stage_size, output): | ||
| stage_idx = jax.lax.axis_index("stage") | ||
| arr = _rotate_left(arr, stage_size) | ||
| return jnp.where(stage_idx == stage_size - 1, output, arr) | ||
|
|
||
| @jax.shard_map( | ||
| mesh=self.mesh, | ||
| in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()), | ||
| out_specs=self.state_io_spec, | ||
| ) | ||
| def _update_state_io(state_in, stream_slice, output, stream_buf_idx): | ||
| def _update_state_io(state_in, stream_slice, output): | ||
| # Shift the current slice to the left, then fill the last stage with the final output. | ||
| stage_size = jax.lax.axis_size("stage") | ||
| stream_slice = _shift_left(stream_slice, stage_size, output) | ||
| padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) | ||
| stream_slice = jax.lax.slice_in_dim(jnp.pad(stream_slice, padding), 1, stream_slice.shape[0] + 1, axis=0) | ||
| stream_slice = jnp.where( | ||
| jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice | ||
| ) | ||
| stream_slice = jnp.expand_dims(stream_slice, 1) | ||
| return jax.lax.dynamic_update_slice_in_dim(state_in, stream_slice, stream_buf_idx, axis=1) | ||
|
|
||
| new_state = _update_state_io(old_state_io, stream_slice, output, stream_buf_idx) | ||
| new_state = _update_state_io(old_state_io, stream_slice, output) | ||
|
|
||
| new_loop_state = { | ||
| "state_io": new_state, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we making this change?