[draft] bug for MoE distributed parallelism #752
Conversation
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Codecov Report ❌
Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #752      +/-   ##
==========================================
- Coverage   74.65%   74.63%   -0.03%
==========================================
  Files         192      192
  Lines       18969    18984      +15
==========================================
+ Hits        14162    14169       +7
- Misses       4807     4815       +8
==========================================
sync_quantizer_amax_across_dp_ep(
    child, module.parallel_state, get_module_device(module)
)
Could you please test locally that all MoE quantizers have amax after this line?
if "experts" in name and "weight_quantizer" in name:
    assert child.amax is not None
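A minimal sketch of such a local check might look like the following (the `experts`/`weight_quantizer` naming convention and the top-level `model` handle are assumptions, not confirmed by the diff):

# Hypothetical sanity check: after sync_quantizer_amax_across_dp_ep, every MoE
# expert weight quantizer should hold a non-None amax on every rank.
for name, child in model.named_modules():
    if "experts" in name and "weight_quantizer" in name:
        assert child.amax is not None, f"amax missing for {name}"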
if synced_amax is not None:
    # Move to target device
    if target_device is not None:
        synced_amax = synced_amax.to(target_device)
We need to add
synced_amax = synced_amax.clone().detach()
otherwise the sharding metadata of global_offset=(0, 0) on all ranks is kept when the checkpoint is saved.
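In context, the suggested fix might look roughly like this (variable names taken from the diff above; the exact placement is an assumption):

if synced_amax is not None:
    # Clone and detach so the synced tensor is a plain local tensor and does not
    # carry over sharding metadata such as a stale global_offset=(0, 0).
    synced_amax = synced_amax.clone().detach()
    # Move to target device
    if target_device is not None:
        synced_amax = synced_amax.to(target_device)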
Good catch, I am hoping you could take over the PR and address this
added below
Signed-off-by: jenchen13 <jennifchen@nvidia.com>
# Iterative max handles both scalar and tensor amax values
result = valid_amaxs[0]
for amax in valid_amaxs[1:]:
    result = torch.maximum(result, amax)
what happens if this line is comparing a scalar vs a tensor? how does it determine the max?
See https://docs.pytorch.org/docs/stable/generated/torch.maximum.html
It simply performs an element-wise maximum with broadcasting, so the shapes do not need to match as long as both operands are PyTorch tensors (including scalar tensors).
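For example, broadcasting makes the scalar-vs-tensor case work directly:

import torch

scalar_amax = torch.tensor(448.0)           # 0-dim (scalar) tensor
tensor_amax = torch.tensor([100.0, 500.0])  # per-channel amax
print(torch.maximum(scalar_amax, tensor_amax))  # tensor([448., 500.])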
| "supported by the current distributed backend. This warning can be ignored" | ||
| "if happening during modelopt restore." | ||
| ) | ||
| def sync_amax_across_distributed_group( |
The current sync_amax_across_distributed_group moves the amax to CPU to accommodate the case where some amaxs are None and some are tensors. However, this typically happens only for MoEs.
So can we keep the old sync method for non-MoEs:
dist.all_reduce(self._amax, op=dist.ReduceOp.MAX, group=parallel_group.group)
and use the object-based sync via CPU only for MoEs? A rough sketch is below.
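The split could look roughly like the following (the function name, `is_moe` flag, and attribute access are assumptions for illustration, not the actual modelopt API):

import torch
import torch.distributed as dist

def sync_amax(quantizer, parallel_group, is_moe: bool):
    """Hypothetical sketch: fast in-place all_reduce for non-MoE quantizers,
    object-based CPU sync only for MoE quantizers where some ranks may have amax=None."""
    if not is_moe and quantizer.amax is not None:
        # Old fast path: element-wise MAX reduction directly on the device tensor.
        dist.all_reduce(quantizer._amax, op=dist.ReduceOp.MAX, group=parallel_group.group)
        return
    # MoE path: gather amax values as Python objects via CPU so that ranks
    # holding None can still participate in the collective.
    local = quantizer.amax.cpu() if quantizer.amax is not None else None
    gathered = [None] * dist.get_world_size(parallel_group.group)
    dist.all_gather_object(gathered, local, group=parallel_group.group)
    valid = [a for a in gathered if a is not None]
    if valid:
        result = valid[0]
        for a in valid[1:]:
            result = torch.maximum(result, a)
        quantizer.amax = result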
What does this PR do?
Type of change: ?
Overview: ?
Usage
# Add a code snippet demonstrating how to use this
Testing
Before your PR is "Ready for review"
Additional Information