
fix(superoffload) preserve multi-group updates with shared cpu buffers (#7905) #7906

Open

xylian86 wants to merge 15 commits into master from xinyu/issue7905

Conversation

@xylian86 (Collaborator)

Fix issue #7905

  • Preserve optimizer param-group metadata across ZeRO-3 subgroup splitting so SuperOffload handles multiple optimizer groups correctly.
  • Switch the CPU worker path to shared CPU parameter and gradient buffers, removing the need to send updated parameters back through the result queue.
  • Make the GPU-to-CPU gradient copy asynchronous and submit CPU optimizer work only after the copy is ready (a minimal sketch follows this list).
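
A minimal sketch of the copy-then-step pattern the last bullet describes, assuming PyTorch CUDA streams/events and a thread-pool CPU worker; the names cpu_grad_buffer, cpu_optimizer, subgroup_id and the step_subgroup call shape are illustrative, not the actual DeepSpeed internals:

import torch
from concurrent.futures import ThreadPoolExecutor

# Illustrative only: a shared pinned CPU gradient buffer plus an async
# device-to-host copy, with the CPU optimizer step submitted once the copy is done.
cpu_worker = ThreadPoolExecutor(max_workers=1)
copy_stream = torch.cuda.Stream()

def offload_and_step(gpu_grad, cpu_grad_buffer, cpu_optimizer, subgroup_id):
    # cpu_grad_buffer is a pre-allocated pinned tensor shared with the CPU
    # optimizer, so updated values never need to travel back through a result queue.
    copy_stream.wait_stream(torch.cuda.current_stream())  # order after grad production
    copy_done = torch.cuda.Event()
    with torch.cuda.stream(copy_stream):
        cpu_grad_buffer.copy_(gpu_grad, non_blocking=True)  # async GPU -> CPU copy
        copy_done.record(copy_stream)

    def _cpu_step():
        copy_done.synchronize()                   # block only the worker thread until the copy is ready
        cpu_optimizer.step_subgroup(subgroup_id)  # hypothetical per-subgroup CPUAdam step

    return cpu_worker.submit(_cpu_step)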

The figures below compare per-iteration time and GPU memory usage against the non-offload baseline. The second figure presents a correctness check of the updated version.
[Figure 1: per-iteration time and GPU memory usage vs. the non-offload baseline]

[Figure 2: correctness check of the updated SuperOffload path]


@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f35fcc8790

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Original:
if getattr(param, "ds_grad_is_ready", True):
Suggested:
if getattr(param, "ds_grad_is_ready", True) and param.grad is not None:
Collaborator


While this approach works, I am curious why execution gets this far for params without grads. We have a mechanism for excluding frozen params in Z1/2 and Z3. Could it be that we are incorrectly adding backward hooks to frozen parameters?

Looking further, I see that for Z1/Z2 we exclude frozen parameters from backward hooks. It seems we don't have a similar exclusion in Z3/SuperOffload. Can you please double-check this logic?
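
For reference, a hedged sketch of the kind of exclusion described above, assuming PyTorch's Tensor.register_post_accumulate_grad_hook rather than DeepSpeed's actual hook-registration code; skipping requires_grad=False params keeps them out of the grad-partitioning path entirely:

def register_grad_hooks(module, make_hook):
    for param in module.parameters():
        if not param.requires_grad:
            # Frozen parameter: attach no backward hook, mirroring the Z1/Z2
            # exclusion, so it never reaches partition_grads without a grad.
            continue
        param.register_post_accumulate_grad_hook(make_hook(param))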

@xylian86 xylian86 requested a review from loadams as a code owner March 21, 2026 19:38
@xylian86 xylian86 marked this pull request as draft March 21, 2026 19:40
@xylian86 xylian86 marked this pull request as ready for review March 22, 2026 15:07

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7da1d33ca9


Comment on lines +215 to +218
if self.clip_grad:
    self._step_with_clipping(scaled_global_grad_norm, timer_names)
else:
    self._step_without_clipping(scaled_global_grad_norm, timer_names)


P1: Gate the synchronous path on actual clipping

DeepSpeed's default gradient_clipping is 1.0 (deepspeed/runtime/constants.py:251), so branching on self.clip_grad here sends the default SuperOffload configuration into _step_with_clipping() on every step, even when check_clip_grads(...) would be false. Because partition_grads() now only launches async CPUAdam work when not self.clip_grad, this change disables the new backward/CPU overlap for most users and turns the patch into a throughput regression unless they explicitly set clipping to 0.0.

Useful? React with 👍 / 👎.
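
As a point of reference, a hedged sketch of the gating this comment suggests, reusing the names from the quoted snippet; whether check_clip_grads is a method with exactly this signature is an assumption:

# Only take the synchronous clipping path when clipping would actually trigger,
# not merely because a clip threshold (default 1.0) is configured.
if self.clip_grad and self.check_clip_grads(scaled_global_grad_norm):
    self._step_with_clipping(scaled_global_grad_norm, timer_names)
else:
    # Preserves the async CPUAdam work launched earlier in partition_grads().
    self._step_without_clipping(scaled_global_grad_norm, timer_names)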

Comment on lines +49 to +55
optimizer_configs.append({
    "lr": pg["lr"],
    "betas": pg["betas"],
    "eps": pg["eps"],
    "weight_decay": pg["weight_decay"],
    "amsgrad": pg["amsgrad"],
})


P2: Preserve bias_correction when copying optimizer groups

The worker-side DeepSpeedCPUAdam still reads group['bias_correction'] during step_subgroup() (deepspeed/ops/adam/cpu_adam.py:199-202), but the new per-group config copy only forwards lr, betas, eps, weight_decay, and amsgrad. If a SuperOffload run has multiple optimizer groups with different bias_correction settings, every mirrored worker group silently falls back to the constructor default instead of the original group's value, so those subgroups will use the wrong Adam update rule.

Useful? React with 👍 / 👎.
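
A hedged sketch of the fix this comment points at: forward bias_correction along with the other per-group hyperparameters so the mirrored worker groups match the original optimizer groups. Falling back to True when a group omits the key is an assumption about DeepSpeedCPUAdam's constructor default:

optimizer_configs = []
for pg in optimizer.param_groups:
    optimizer_configs.append({
        "lr": pg["lr"],
        "betas": pg["betas"],
        "eps": pg["eps"],
        "weight_decay": pg["weight_decay"],
        "amsgrad": pg["amsgrad"],
        # Field the review flags as missing; step_subgroup() reads it per group.
        "bias_correction": pg.get("bias_correction", True),
    })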


Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] SuperOffload Fails when having more than one optimizer group

2 participants