-
Notifications
You must be signed in to change notification settings - Fork 4.9k
feat(zero2): add CPU offload support for Muon optimizer #7939
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,144 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # DeepSpeed Team | ||
|
|
||
| import deepspeed | ||
| import torch | ||
| import pytest | ||
|
|
||
| from unit.common import DistributedTest | ||
| from unit.simple_model import SimpleModel | ||
| from deepspeed.accelerator import get_accelerator | ||
|
|
||
| if torch.half not in get_accelerator().supported_dtypes(): | ||
| pytest.skip(f"fp16 not supported", allow_module_level=True) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize('zero_stage', [2]) | ||
| class TestMuonCPUOffload(DistributedTest): | ||
|
|
||
| def test_momentum_buffer_on_cpu(self, zero_stage): | ||
| """Verify Muon CPU offload creates momentum buffer on CPU. | ||
|
|
||
| This is the key invariant: after a training step with CPU offload, | ||
| the Muon momentum buffer must reside on CPU (not GPU), confirming | ||
| that muon_update ran on CPU and no GPU memory is wasted. | ||
| """ | ||
| hidden_dim = 32 | ||
| batch_size = 8 | ||
| config_dict = { | ||
| "train_batch_size": batch_size, | ||
| "optimizer": { | ||
| "type": "muon", | ||
| "params": { | ||
| "lr": 0.01 | ||
| } | ||
| }, | ||
| "fp16": { | ||
| "enabled": True | ||
| }, | ||
| "zero_optimization": { | ||
| "stage": zero_stage, | ||
| "reduce_scatter": False, | ||
| "offload_optimizer": { | ||
| "device": "cpu", | ||
| "pin_memory": True, | ||
| }, | ||
| }, | ||
| } | ||
|
|
||
| model = SimpleModel(hidden_dim=hidden_dim, nlayers=5) | ||
| engine, optimizer, _, _ = deepspeed.initialize( | ||
| config=config_dict, | ||
| model=model, | ||
| model_parameters=model.parameters(), | ||
| dist_init_required=False, | ||
| ) | ||
|
|
||
| x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half) | ||
| y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device) | ||
| loss = engine(x, y) | ||
| engine.backward(loss) | ||
| engine.step() | ||
|
|
||
| # Muon momentum buffer must exist and be on CPU. | ||
| # If muon_update was silently skipped, momentum_buffer would not be created. | ||
| flatten_copy = optimizer.optimizer.param_groups[0]['params'][0] | ||
| state = optimizer.optimizer.state[flatten_copy] | ||
| assert 'momentum_buffer' in state, ("momentum_buffer not found in optimizer state. " | ||
| "muon_update was not called in the CPU offload path.") | ||
| assert state['momentum_buffer'].device.type == 'cpu', ( | ||
| f"Momentum buffer is on {state['momentum_buffer'].device}, expected CPU") | ||
|
|
||
|
|
||
| @pytest.mark.parametrize('zero_stage', [2]) | ||
| class TestMuonCPUOffloadCosim(DistributedTest): | ||
|
|
||
| def test_cosim_offload_vs_no_offload(self, zero_stage): | ||
| """Verify CPU offload produces results consistent with GPU path. | ||
|
|
||
| With the same random seed, offload and non-offload should produce | ||
| close parameters. If muon_update is skipped or wrong in either path, | ||
| the results diverge significantly. | ||
| """ | ||
| hidden_dim = 32 | ||
| batch_size = 8 | ||
|
|
||
| def train(offload): | ||
| torch.manual_seed(42) | ||
| config_dict = { | ||
| "train_batch_size": batch_size, | ||
| "optimizer": { | ||
| "type": "muon", | ||
| "params": { | ||
| "lr": 0.01 | ||
| } | ||
| }, | ||
| "fp16": { | ||
| "enabled": True | ||
| }, | ||
| "zero_optimization": { | ||
| "stage": zero_stage, | ||
| "reduce_scatter": False, | ||
| }, | ||
| } | ||
| if offload: | ||
| config_dict["zero_optimization"]["offload_optimizer"] = { | ||
| "device": "cpu", | ||
| "pin_memory": True, | ||
| } | ||
|
|
||
| model = SimpleModel(hidden_dim=hidden_dim, nlayers=5) | ||
| engine, _, _, _ = deepspeed.initialize( | ||
| config=config_dict, | ||
| model=model, | ||
| model_parameters=model.parameters(), | ||
| dist_init_required=False, | ||
| ) | ||
|
|
||
| for _ in range(3): | ||
| x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half) | ||
| y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device) | ||
| loss = engine(x, y) | ||
| engine.backward(loss) | ||
| engine.step() | ||
|
|
||
| return {n: p.clone().detach().float().cpu() for n, p in model.named_parameters()} | ||
|
|
||
| params_offload = train(offload=True) | ||
| params_no_offload = train(offload=False) | ||
|
|
||
| for name in params_offload: | ||
| p_off = params_offload[name] | ||
| p_no = params_no_offload[name] | ||
| # Both paths should produce the same NaN pattern | ||
| nan_mask = p_off.isnan() | p_no.isnan() | ||
| assert nan_mask.equal(p_off.isnan()), (f"{name}: NaN pattern differs between offload and non-offload. " | ||
| "muon_update produced different results.") | ||
| # On non-NaN elements, cosine similarity should be very high | ||
| valid = ~nan_mask | ||
| if valid.sum() > 0: | ||
| cos_sim = torch.nn.functional.cosine_similarity(p_off[valid].unsqueeze(0), | ||
| p_no[valid].unsqueeze(0)).item() | ||
| assert cos_sim > 0.99, (f"{name}: cosine similarity {cos_sim:.4f} between offload and " | ||
| f"non-offload is too low, indicating muon_update results diverge.") |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.