fix(zero): guard ds_grads_remaining #7904
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d997d509c8
  @torch.compiler.disable
  def _post_backward_module_hook(module, inputs):
-     module.ds_grads_remaining = 0
+     _ensure_ds_grads_remaining(module)
Reset ds_grads_remaining at each forward-pre hook
In _post_backward_module_hook, replacing the per-forward reset with _ensure_ds_grads_remaining(module) keeps any stale ds_grads_remaining value from previous iterations instead of clearing it. This counter is decremented in PostBackwardFunctionModule.backward, and the nearby comment already notes some backward hooks may not trigger in edge cases; when that happens, a leftover positive count now carries into the next forward, so remaining can stop reaching 0 and post_sub_module_backward_function is skipped indefinitely, which can block parameter release and lead to persistent memory growth.
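The concern above can be illustrated with a small, framework-free sketch. It assumes a hypothetical body for `_ensure_ds_grads_remaining` (the PR's actual helper is not shown in this thread) that only sets the counter to 0 when the attribute is missing, and it assumes the counter is incremented once per output during forward. `FakeModule`, `forward_pre_reset`, `forward_pre_lazy`, `run_iteration`, and `simulate` are illustrative names, not DeepSpeed APIs.

```python
class FakeModule:
    """Stand-in for a torch.nn.Module; it only carries the counter."""


def _ensure_ds_grads_remaining(module):
    # Hypothetical body: set the counter only when the attribute is missing.
    if not hasattr(module, "ds_grads_remaining"):
        module.ds_grads_remaining = 0
    return module.ds_grads_remaining


def forward_pre_reset(module):
    # Pre-PR behavior: unconditionally clear the counter on every forward.
    module.ds_grads_remaining = 0


def forward_pre_lazy(module):
    # Reviewed behavior: keep whatever value is already stored.
    _ensure_ds_grads_remaining(module)


def run_iteration(module, forward_pre, num_outputs, num_backward_hooks):
    """Simulate one step; return True if the counter reached zero."""
    forward_pre(module)
    # Assumption: forward registers one pending gradient per output.
    module.ds_grads_remaining += num_outputs
    released = False
    for _ in range(num_backward_hooks):  # hooks that actually fired
        module.ds_grads_remaining -= 1
        if module.ds_grads_remaining == 0:
            released = True  # where the release callback would run
    return released


def simulate(forward_pre):
    module = FakeModule()
    # Iteration 1: one of the two backward hooks does not fire (edge case).
    first = run_iteration(module, forward_pre, 2, 1)
    # Iteration 2: a normal step where both hooks fire.
    second = run_iteration(module, forward_pre, 2, 2)
    return first, second, module.ds_grads_remaining
```

Under these assumptions, `simulate(forward_pre_reset)` recovers on the second iteration, while `simulate(forward_pre_lazy)` ends with a leftover count of 1 and never reaches zero.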
@ailuntz thanks for the PR. Given the large number of modifications this flag has had in the past, can you please explain the issue that triggered this PR?
@ailuntz is this PR still needed? Thanks!
Signed-off-by: ailuntz <ailuntz@ailuntzdeMac-mini.local>
d997d50 to 7380eb9
The PR is still needed on current […]. The trigger for this fix is an edge case where a module reaches the backward hook path before […]. This patch keeps the behavior the same for the normal path, but lazily initializes the counter to […]. I rebased the branch onto the latest upstream […].
@ailuntz thanks for explaining the possible need for this fix. However, it would be better to have an actual repro of the scenario you described earlier: I can't think of how the above could happen, and the provided unit test does not show it either. An actual repro would be useful to rule out some other true root cause, or to find a better solution. For example, I am curious whether #7929, which ensures full initialization of ZeRO states before usage, could be extended to this case.
  def _run_after_backward_hook(*unused):
-     module.ds_grads_remaining = module.ds_grads_remaining - 1
+     remaining = _ensure_ds_grads_remaining(module) - 1
It seems that remaining could become negative here, right?
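A minimal sketch of how that could happen, assuming a hypothetical `_ensure_ds_grads_remaining` that lazily initializes a missing counter to 0 (the helper's real body is not shown in this thread). `FakeModule`, `run_after_backward_hook`, and the `on_zero` callback are illustrative names only.

```python
class FakeModule:
    """Stand-in for a torch.nn.Module; it only carries the counter."""


def _ensure_ds_grads_remaining(module):
    # Hypothetical body: set the counter only when the attribute is missing.
    if not hasattr(module, "ds_grads_remaining"):
        module.ds_grads_remaining = 0
    return module.ds_grads_remaining


def run_after_backward_hook(module, on_zero):
    # Mirrors the reviewed line: lazily read the counter, then decrement.
    remaining = _ensure_ds_grads_remaining(module) - 1
    module.ds_grads_remaining = remaining
    if remaining == 0:
        on_zero()  # where the post-backward release would be triggered
    return remaining


# A module that reaches the backward hook without any forward-side setup.
module = FakeModule()
calls = []
first = run_after_backward_hook(module, lambda: calls.append("released"))
second = run_after_backward_hook(module, lambda: calls.append("released"))
```

In this sketch the counter starts at 0, the first decrement yields -1, and the `== 0` check is skipped on every later call, so the release callback never runs.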
Signed-off-by: ailuntz <ailuntz@ailuntzdeMac-mini.local>
You're right that the previous version still did not provide a concrete standalone repro, and the broader read/write guards were more invasive than necessary. I narrowed the patch in the latest push:
That makes the state fully initialized before usage, which seems closer to the direction you pointed out around ensuring ZeRO state initialization before use, without masking other potential root causes by guarding every access. If you still want a concrete runtime repro before considering even this smaller hardening, I can keep digging there too. I mainly wanted to reduce the patch to the smallest defensible version first.
Closing this for now because I do not have a stable reproduction or a confirmed root cause yet. I will reopen or send a fresh PR once I can provide a concrete repro.
Summary
ds_grads_remaining to avoid unexpected errors in edge cases.
Testing