feat(zero2): add CPU offload support for Muon optimizer #7939
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 54364fbe9a
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
    pad_tensor = torch.zeros(padded_size - self.bit16_groups_flat[i].numel(),
                             dtype=self.bit16_groups_flat[i].dtype,
                             device=self.bit16_groups_flat[i].device)
    self.bit16_groups_flat[i] = torch.cat([self.bit16_groups_flat[i], pad_tensor])
Insert per-partition padding before Muon equal split
Appending a single padding block at the tail does not guarantee parameter-boundary partitioning: when an earlier partition is smaller than max_partition_size (e.g., sizes [4,5,1] for dp=3), get_data_parallel_partitions() still cuts at fixed max_partition_size offsets and splits a parameter across ranks. That breaks the new CPU-offload Muon path, which assumes unsplit parameters and writes a full update.view(-1) into a partition slice computed from grad_position, leading to shape mismatch or incorrect updates when source_offset != 0.
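To illustrate the failure mode, here is a minimal pure-Python sketch. It is not DeepSpeed code: `param_spans`, `split_by_cuts`, and the per-rank layout `[[4], [2, 3], [1]]` are hypothetical stand-ins chosen so that the per-rank totals match the `[4, 5, 1]` example above. It only models where fixed-stride cuts land relative to parameter boundaries.

```python
def param_spans(partitions, pad_each_to=None):
    """Return (start, end) offsets of every parameter in the flat buffer.

    partitions: per-rank lists of parameter sizes (hypothetical layout).
    pad_each_to: if set, each rank's region is padded up to this size;
                 if None, all padding is assumed to sit at the tail.
    """
    spans, cursor = [], 0
    for sizes in partitions:
        region_start = cursor
        for n in sizes:
            spans.append((cursor, cursor + n))
            cursor += n
        if pad_each_to is not None:
            cursor = region_start + pad_each_to  # per-partition padding
    return spans


def split_by_cuts(spans, partition_size, dp):
    """Return the parameters that straddle a fixed-stride rank boundary."""
    cuts = [r * partition_size for r in range(1, dp)]
    return [s for s in spans if any(s[0] < c < s[1] for c in cuts)]


partitions = [[4], [2, 3], [1]]            # per-rank totals: 4, 5, 1
max_ps = max(sum(p) for p in partitions)   # 5

tail_padded = split_by_cuts(param_spans(partitions), max_ps, dp=3)
per_partition = split_by_cuts(
    param_spans(partitions, pad_each_to=max_ps), max_ps, dp=3)

print(tail_padded)    # [(4, 6)]  the first parameter of rank 1 is split
print(per_partition)  # []        every parameter stays on one rank
```

With tail-only padding the cut at offset 5 falls inside the parameter occupying `[4, 6)`; padding each region to `max_ps` moves every region start onto a cut.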
    if self._is_muon_param_group(i):
        dp_size = dist.get_world_size(group=self.real_dp_process_group[i])
        max_ps = self._get_muon_max_partition_size(self.round_robin_bit16_groups[i], dp_size, orig_group_numel)
        padded_size = max_ps * dp_size
Keep Muon partition size aligned for NCCL boundaries
max_partition_size is used directly to set padded_size, but it is not rounded to the existing NCCL start-alignment factor. If max_partition_size is odd with fp16/bf16 tensors, partition starts after rank 0 become 2-byte shifted and fail the existing 4-byte alignment assertion in the same initialization flow. This makes valid Muon configurations crash depending on parameter shapes.
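The arithmetic behind this can be checked with a small sketch. The constants and helper names (`ALIGN_BYTES`, `ELEM_BYTES`, `align_up`, `misaligned_ranks`) are assumptions made for illustration, not identifiers from the PR; the sketch assumes 2-byte elements and a 4-byte start alignment, as described in the comment.

```python
ALIGN_BYTES = 4   # start alignment the existing assertion is said to enforce
ELEM_BYTES = 2    # fp16 / bf16 element size


def align_up(n, factor):
    """Round n up to the nearest multiple of factor."""
    return -(-n // factor) * factor


def misaligned_ranks(partition_size, dp, elem_bytes=ELEM_BYTES):
    """Ranks whose partition start is not a multiple of ALIGN_BYTES."""
    return [r for r in range(dp)
            if (r * partition_size * elem_bytes) % ALIGN_BYTES != 0]


factor = ALIGN_BYTES // ELEM_BYTES   # alignment expressed in elements: 2
max_ps = 5                           # odd partition size

print(misaligned_ranks(max_ps, dp=4))                    # [1, 3]
print(misaligned_ranks(align_up(max_ps, factor), dp=4))  # []
```

Rounding `max_ps` up before computing `padded_size` costs at most `factor - 1` extra elements per rank.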
d802f0e to c058864
0a4d98e to 456b565
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 456b5650b1
        Q = Z.clone()
        Q.diagonal().add_(a)
    else:
        Q = torch.addmm(Q, Z, Q, beta=a, alpha=1.0)
Support batched tensors in Gram Newton-Schulz update
The new default ns_method="gram" regresses Muon's stated batched input support (grad.ndim >= 2): this path uses torch.addmm, which only accepts 2D inputs, so a Muon parameter with shape like (B, N, M) will now fail at runtime in _zeropower_via_gram_newtonschulz. Previously, the standard Newton-Schulz implementation used batched matmuls and handled these shapes, so this commit introduces a crash for valid prior inputs unless users manually switch to ns_method="standard".
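One possible shape-agnostic replacement is sketched below. The helper name `scaled_add_matmul` is hypothetical and this is not the fix adopted in the PR; it only shows that `torch.baddbmm` (for 3D inputs) and a plain broadcasting matmul (for higher ranks) compute the same `a * Q + Z @ Q` that the 2D `torch.addmm` call does.

```python
import torch


def scaled_add_matmul(Q, Z, a):
    """Compute a * Q + Z @ Q for 2D or batched (..., N, N) tensors."""
    if Q.ndim == 2:
        # Same call as in the diff; addmm requires 2D operands.
        return torch.addmm(Q, Z, Q, beta=a, alpha=1.0)
    if Q.ndim == 3:
        # Batched counterpart of addmm for (B, N, N) operands.
        return torch.baddbmm(Q, Z, Q, beta=a, alpha=1.0)
    # Generic fallback: the @ operator broadcasts over leading batch dims.
    return a * Q + Z @ Q


torch.manual_seed(0)
Q = torch.randn(3, 4, 4)   # a batch of three 4x4 matrices
Z = torch.randn(3, 4, 4)
out = scaled_add_matmul(Q, Z, 0.5)
print(out.shape)           # torch.Size([3, 4, 4])
```

The fused calls avoid a temporary for the matmul result; the fallback trades that for generality.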
        raise ValueError(f"All Muon parameter groups must have the same momentum (beta). "
                         f"Found {self.muon_beta} and {group_beta}.")
    self.muon_beta = group_beta
    self.muon_ns_method = param_group.get('ns_method', 'gram')
Preserve ns_method per Muon param group in ZeRO-3
ZeRO-3 stores ns_method in a single self.muon_ns_method while iterating all Muon param groups, so later groups overwrite earlier values. _muon_update_grads_in_place then applies that one method to every Muon subgroup, which silently ignores per-group configuration when users provide multiple Muon groups (a pattern already handled for momentum via explicit consistency checks).
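A minimal sketch of the difference, using plain dictionaries in place of optimizer param groups. The `use_muon` key and both helper names are assumptions for illustration; only the `ns_method` key and its `'gram'` default come from the diff.

```python
def last_ns_method(param_groups, default="gram"):
    """Mimics the single-attribute pattern: later groups overwrite earlier ones."""
    method = default
    for group in param_groups:
        if group.get("use_muon", False):   # hypothetical marker for Muon groups
            method = group.get("ns_method", default)
    return method


def ns_method_per_group(param_groups, default="gram"):
    """Keeps one entry per Muon group, keyed by group index."""
    return {
        idx: group.get("ns_method", default)
        for idx, group in enumerate(param_groups)
        if group.get("use_muon", False)
    }


groups = [
    {"use_muon": True, "ns_method": "standard"},
    {"use_muon": False},
    {"use_muon": True},                    # falls back to the default
]

print(last_ns_method(groups))        # gram
print(ns_method_per_group(groups))   # {0: 'standard', 2: 'gram'}
```

An alternative with less bookkeeping would be to reject mismatched values outright, mirroring the existing momentum consistency check.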
456b565 to a09e346
Enable Muon optimizer with ZeRO Stage 2 CPU offload.

The Newton-Schulz orthogonalization always runs on GPU for performance (momentum is temporarily moved to GPU), while the momentum buffer stays on CPU to save GPU memory.

The _apply_muon_update_for_cpu_offload method intercepts the gradient copy path in copy_grads_in_partition to apply muon_update before writing to the CPU FP32 grad buffer. Cross-boundary parameters are handled by processing the full gradient on each involved rank.

Includes cosimulation test verifying offload vs non-offload produce consistent results.

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
a09e346 to 35a7f4b
Closing this PR: CPU offload in Z2 slows down performance, which makes me wonder whether Z3 is the better choice at this point. We can revisit this if we see a need to combine Muon CPU offload with Z2.
Add Muon optimizer support to the ZeRO Stage 1 & 2 CPU offload path:
Momentum is stored in CPU memory, and the Newton-Schulz iteration runs on the GPU.
This PR completes that piece in ZeRO-2 and gives CPU offload the same numerical behavior as the non-offload ZeRO-2 path.
Note that this PR also contains code from PR #7953 and is intended to be merged after #7953.
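The data flow described here can be sketched as follows. This is not the PR's implementation: `orthogonalize` and `offloaded_muon_step` are hypothetical names, the Newton-Schulz coefficients and the `lerp_` momentum form are taken from the public Muon reference implementation rather than from this PR, and the sketch falls back to CPU compute when no GPU is available.

```python
import torch


def orthogonalize(M, steps=5):
    """Quintic Newton-Schulz iteration on a 2D matrix (assumed variant)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = M / (M.norm() + 1e-7)
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.T if transposed else X


def offloaded_muon_step(grad, momentum_cpu, beta=0.95, compute_device=None):
    """Update the CPU momentum buffer and return the update as CPU FP32."""
    if compute_device is None:
        compute_device = "cuda" if torch.cuda.is_available() else "cpu"
    buf = momentum_cpu.to(compute_device)          # temporary copy on GPU
    g = grad.to(compute_device, dtype=buf.dtype)
    buf.lerp_(g, 1 - beta)                         # momentum update (assumed form)
    update = orthogonalize(buf)                    # heavy matmuls on compute device
    if buf is not momentum_cpu:                    # .to() is a no-op on CPU-only runs
        momentum_cpu.copy_(buf)                    # persist momentum on CPU
    return update.to("cpu", dtype=torch.float32)


torch.manual_seed(0)
grad = torch.randn(4, 6)
momentum = torch.zeros(4, 6)                       # lives in CPU memory
update = offloaded_muon_step(grad, momentum)
print(update.shape, momentum.device)
```

The two host-device transfers per parameter are the cost of keeping the momentum buffer out of GPU memory.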