
feat(zero2): add CPU offload support for Muon optimizer#7939

Closed
delock wants to merge 1 commit into deepspeedai:master from delock:gma/muon_cpuoffload

@delock delock commented Mar 31, 2026

Collaborator

Add Muon optimizer support to the ZeRO Stage 1 and 2 CPU offload path:
momentum is stored in CPU memory, and the Newton-Schulz iteration runs on the GPU.

This PR completes the ZeRO-2 piece and makes the CPU offload path numerically consistent with the non-offload ZeRO-2 path.

Note: this PR also contains code from PR #7953 and is intended to be merged after #7953.

@delock delock marked this pull request as draft March 31, 2026 07:02

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 54364fbe9a

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread deepspeed/runtime/zero/stage_1_and_2.py Outdated
pad_tensor = torch.zeros(padded_size - self.bit16_groups_flat[i].numel(),
                         dtype=self.bit16_groups_flat[i].dtype,
                         device=self.bit16_groups_flat[i].device)
self.bit16_groups_flat[i] = torch.cat([self.bit16_groups_flat[i], pad_tensor])

P1 Badge Insert per-partition padding before Muon equal split

Appending a single padding block at the tail does not guarantee parameter-boundary partitioning: when an earlier partition is smaller than max_partition_size (e.g., sizes [4,5,1] for dp=3), get_data_parallel_partitions() still cuts at fixed max_partition_size offsets and splits a parameter across ranks. That breaks the new CPU-offload Muon path, which assumes unsplit parameters and writes a full update.view(-1) into a partition slice computed from grad_position, leading to shape mismatch or incorrect updates when source_offset != 0.
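The `[4, 5, 1]` case from this comment can be reproduced with plain offset arithmetic. The sketch below is illustrative only: `param_spans` and `split_params` are hypothetical helpers (not DeepSpeed functions), and it assumes one parameter per intended partition.

```python
from typing import List, Tuple


def param_spans(sizes: List[int]) -> List[Tuple[int, int]]:
    """Return the [start, end) offset of each parameter in a contiguous flat buffer."""
    spans, start = [], 0
    for n in sizes:
        spans.append((start, start + n))
        start += n
    return spans


def split_params(spans: List[Tuple[int, int]], part_size: int) -> List[int]:
    """Indices of parameters that straddle a fixed-size partition boundary."""
    return [i for i, (s, e) in enumerate(spans)
            if e > s and s // part_size != (e - 1) // part_size]


# One parameter per intended partition; sizes come from the review example.
sizes = [4, 5, 1]
max_partition_size = max(sizes)  # 5, so the padded buffer holds 15 elements

# Tail padding: parameters stay contiguous and all zeros go at the very end.
tail_spans = param_spans(sizes)                      # [(0, 4), (4, 9), (9, 10)]
tail_split = split_params(tail_spans, max_partition_size)

# Per-partition padding: each partition is padded up to max_partition_size,
# so every parameter starts at its own rank's boundary.
padded_spans = [(rank * max_partition_size, rank * max_partition_size + n)
                for rank, n in enumerate(sizes)]
padded_split = split_params(padded_spans, max_partition_size)

print(tail_split)    # parameter 1 crosses the cut at offset 5
print(padded_split)  # no parameter crosses a cut
```

With tail padding, the second parameter occupies offsets 4 to 8 and is cut at offset 5; with per-partition padding, no parameter crosses a rank boundary.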

Comment thread deepspeed/runtime/zero/stage_1_and_2.py Outdated
if self._is_muon_param_group(i):
    dp_size = dist.get_world_size(group=self.real_dp_process_group[i])
    max_ps = self._get_muon_max_partition_size(self.round_robin_bit16_groups[i], dp_size, orig_group_numel)
    padded_size = max_ps * dp_size

P1 Badge Keep Muon partition size aligned for NCCL boundaries

max_partition_size is used directly to set padded_size, but it is not rounded to the existing NCCL start-alignment factor. If max_partition_size is odd with fp16/bf16 tensors, partition starts after rank 0 become 2-byte shifted and fail the existing 4-byte alignment assertion in the same initialization flow. This makes valid Muon configurations crash depending on parameter shapes.
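A small arithmetic sketch of the alignment issue, assuming 2-byte elements and the 4-byte start alignment this comment refers to. `align_up` and `partition_start_bytes` are hypothetical helpers, and the sizes are made up for illustration.

```python
from typing import List


def align_up(n: int, factor: int) -> int:
    """Round n up to the nearest multiple of factor."""
    return ((n + factor - 1) // factor) * factor


def partition_start_bytes(partition_size: int, dp_size: int, elem_bytes: int) -> List[int]:
    """Byte offset at which each rank's partition starts in the flat buffer."""
    return [rank * partition_size * elem_bytes for rank in range(dp_size)]


ELEM_BYTES = 2    # fp16 / bf16
ALIGN_BYTES = 4   # start alignment named in the review comment
alignment_factor = ALIGN_BYTES // ELEM_BYTES  # 2 elements per aligned unit

max_partition_size = 5  # odd element count (hypothetical)
dp_size = 3

unaligned = partition_start_bytes(max_partition_size, dp_size, ELEM_BYTES)
aligned = partition_start_bytes(align_up(max_partition_size, alignment_factor),
                                dp_size, ELEM_BYTES)

print(unaligned)  # rank 1 starts at byte 10, which is not a multiple of 4
print(aligned)    # every start is a multiple of 4 once the size is rounded up to 6
```

Rounding the partition size up before computing `padded_size` costs at most `alignment_factor - 1` extra elements per rank.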

@delock delock force-pushed the gma/muon_cpuoffload branch 2 times, most recently from d802f0e to c058864 Compare March 31, 2026 10:07
@delock delock force-pushed the gma/muon_cpuoffload branch 2 times, most recently from 0a4d98e to 456b565 Compare April 9, 2026 06:44
@delock delock changed the title [DRAFT] feat(zero2): add CPU offload support for Muon optimizer feat(zero2): add CPU offload support for Muon optimizer Apr 9, 2026
@delock delock marked this pull request as ready for review April 9, 2026 07:38

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 456b5650b1

    Q = Z.clone()
    Q.diagonal().add_(a)
else:
    Q = torch.addmm(Q, Z, Q, beta=a, alpha=1.0)

P1 Badge Support batched tensors in Gram Newton-Schulz update

The new default ns_method="gram" regresses Muon's stated batched input support (grad.ndim >= 2): this path uses torch.addmm, which only accepts 2D inputs, so a Muon parameter with shape like (B, N, M) will now fail at runtime in _zeropower_via_gram_newtonschulz. Previously, the standard Newton-Schulz implementation used batched matmuls and handled these shapes, so this commit introduces a crash for valid prior inputs unless users manually switch to ns_method="standard".
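For reference, the quoted lines can be read as the following update. This relies on the documented semantics of `torch.addmm`, which returns `beta * input + alpha * (mat1 @ mat2)` for 2D matrices. The condition guarding the first branch is not shown in the snippet, so its interpretation here is an assumption.

```latex
\begin{aligned}
\text{first branch:}\quad & Q \leftarrow Z + aI
  && \text{(equals } aQ + ZQ \text{ when } Q = I\text{)} \\
\text{else branch:}\quad & Q \leftarrow aQ + ZQ \\
\text{batched form:}\quad & Q_b \leftarrow aQ_b + Z_b Q_b
  && \text{for each batch index } b
\end{aligned}
```

The batched form is what a `(B, N, M)` input would need. In PyTorch it can be written as `a * Q + Z @ Q`, since `@` broadcasts over leading dimensions, or with `torch.baddbmm` for strictly 3D inputs. The first branch would also need the identity added per batch element, for example via `Q.diagonal(dim1=-2, dim2=-1)`. Which option the PR would adopt is not stated here.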

Comment thread deepspeed/runtime/zero/stage_1_and_2.py
    raise ValueError(f"All Muon parameter groups must have the same momentum (beta). "
                     f"Found {self.muon_beta} and {group_beta}.")
self.muon_beta = group_beta
self.muon_ns_method = param_group.get('ns_method', 'gram')

P2 Badge Preserve ns_method per Muon param group in ZeRO-3

ZeRO-3 stores ns_method in a single self.muon_ns_method while iterating all Muon param groups, so later groups overwrite earlier values. _muon_update_grads_in_place then applies that one method to every Muon subgroup, which silently ignores per-group configuration when users provide multiple Muon groups (a pattern already handled for momentum via explicit consistency checks).
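One way to picture the difference, as a sketch with plain dictionaries rather than real optimizer state. `collect_ns_methods` and the `use_muon` flag are hypothetical names used only for illustration; the `'gram'` default is taken from the quoted diff.

```python
from typing import Dict, List, Optional

DEFAULT_NS_METHOD = "gram"  # default shown in the quoted diff


def collect_ns_methods(param_groups: List[dict]) -> Dict[int, str]:
    """Map each Muon param-group index to its own ns_method."""
    methods = {}
    for idx, group in enumerate(param_groups):
        if group.get("use_muon", False):  # hypothetical marker for Muon groups
            methods[idx] = group.get("ns_method", DEFAULT_NS_METHOD)
    return methods


param_groups = [
    {"use_muon": True, "ns_method": "standard"},
    {"use_muon": True},    # falls back to the default
    {"use_muon": False},   # non-Muon group, ignored
]

# Per-group storage keeps both settings.
ns_methods = collect_ns_methods(param_groups)

# Single-attribute storage: each Muon group overwrites the previous value.
last_wins: Optional[str] = None
for group in param_groups:
    if group.get("use_muon", False):
        last_wins = group.get("ns_method", DEFAULT_NS_METHOD)

print(ns_methods)  # group 0 keeps "standard", group 1 gets "gram"
print(last_wins)   # only "gram" survives, so group 0's setting is lost
```

An alternative, mirroring the existing momentum check, would be to raise an error when groups disagree instead of storing one value per group.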

@delock delock marked this pull request as draft April 17, 2026 06:45
@delock delock force-pushed the gma/muon_cpuoffload branch from 456b565 to a09e346 Compare May 22, 2026 07:33
Enable Muon optimizer with ZeRO Stage 2 CPU offload. The Newton-Schulz
orthogonalization always runs on GPU for performance (momentum is
temporarily moved to GPU), while the momentum buffer stays on CPU to
save GPU memory.

The _apply_muon_update_for_cpu_offload method intercepts the gradient
copy path in copy_grads_in_partition to apply muon_update before
writing to the CPU FP32 grad buffer. Cross-boundary parameters are
handled by processing the full gradient on each involved rank.
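A rough sketch of the cross-boundary handling described above, using plain Python lists. `rank_slices` and `full_update` are hypothetical stand-ins: the real `muon_update` operates on the whole gradient matrix, which is why each involved rank processes the full gradient and then keeps only its own slice.

```python
from typing import Dict, List, Tuple


def rank_slices(param_start: int, param_numel: int,
                partition_size: int, dp_size: int) -> Dict[int, Tuple[int, int]]:
    """For each rank whose partition overlaps the parameter, return the
    parameter-local [lo, hi) range of flattened elements that rank owns."""
    param_end = param_start + param_numel
    owned = {}
    for rank in range(dp_size):
        part_lo = rank * partition_size
        part_hi = part_lo + partition_size
        lo, hi = max(param_start, part_lo), min(param_end, part_hi)
        if lo < hi:
            owned[rank] = (lo - param_start, hi - param_start)
    return owned


def full_update(grad: List[float]) -> List[float]:
    """Stand-in for an update that needs the whole gradient at once."""
    return [-g for g in grad]


# A 6-element parameter starting at flat offset 3, partitions of 5 elements.
grad = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
slices = rank_slices(param_start=3, param_numel=6, partition_size=5, dp_size=3)

# Each involved rank computes the full update, then writes only its slice.
per_rank_writes = {rank: full_update(grad)[lo:hi] for rank, (lo, hi) in slices.items()}

print(slices)           # ranks 0 and 1 share the parameter; rank 2 is not involved
print(per_rank_writes)
```

The cost of this approach is that the same update is computed once per involved rank.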

Includes a cosimulation test verifying that the offload and non-offload
paths produce consistent results.

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
@delock delock force-pushed the gma/muon_cpuoffload branch from a09e346 to 35a7f4b Compare May 22, 2026 08:08
delock commented May 22, 2026

Collaborator Author

Closing this PR since CPU offload in Z2 slows down performance, which makes me wonder whether Z3 is the better choice at this point. We can revisit this if we see a need to combine Muon CPU offload with Z2.

@delock delock closed this May 22, 2026