diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 20f17e59db29..a8a5bde37d0f 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -1506,6 +1506,84 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
         return torch.tensor(total_norm, device=self.device, dtype=torch.float)
 
     ############################################################################################
+    def _apply_muon_update_for_cpu_offload(self, param):
+        """Apply muon_update for a parameter in the CPU offload path.
+
+        For Muon parameters (use_muon=True), runs Newton-Schulz
+        orthogonalization on GPU (momentum is temporarily copied from
+        CPU to GPU) and writes only the partition slice back to the
+        CPU FP32 grad buffer. Cross-boundary parameters are
+        redundantly processed by each involved rank with the full
+        gradient, matching the non-offload path behavior in
+        get_flat_partition.
+
+        Returns True if muon_update was applied (caller should skip
+        the normal copy for this param). Returns False when the param
+        is not flagged use_muon, the wrapped optimizer's class name does
+        not contain 'muon', or the param has no gradient attribute set.
+        """
+        if not getattr(param, 'use_muon', False):
+            return False
+        if 'muon' not in self.optimizer.__class__.__name__.lower():
+            return False
+
+        param_id = self.get_param_id(param)
+        [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
+
+        grad_accum = self.get_param_gradient_attribute(param)
+        if grad_accum is None:
+            return False
+
+        group = self.optimizer.param_groups[i]
+        state = self.optimizer.state[group['params'][0]]
+        if "momentum_buffer" not in state:
+            # One flat FP32 buffer with a full-size momentum slot for every param in this partition.
+            total_size = sum(p.numel() for p in self.params_in_partition[i])
+            state["momentum_buffer"] = torch.zeros(total_size, dtype=torch.float32, device=self.device)
+        momentum_flat = state["momentum_buffer"]
+
+        # This param's slot starts after the params that precede it in params_in_partition[i].
+        muon_offset = 0
+        for p in self.params_in_partition[i]:
+            if p is param:
+                break
+            muon_offset += p.numel()
+
+        # View into momentum_flat: in-place writes to momentum_cpu update the persistent buffer.
+        momentum_cpu = momentum_flat[muon_offset:muon_offset + param.numel()].view(param.size())
+
+        beta = group.get('momentum', 0.95)
+        ns_method = group.get('ns_method', 'gram')
+
+        # Run NS on GPU: keep grad on GPU, temporarily move momentum to GPU
+        gpu_device = grad_accum.device
+        # copy=True always yields a private FP32 tensor, so the accumulated grad is never aliased.
+        grad_gpu = grad_accum.detach().to(dtype=torch.float32, copy=True)
+        momentum_gpu = momentum_cpu.to(device=gpu_device, dtype=torch.float32)
+        update = muon_update(grad_gpu.view(param.size()), momentum_gpu, beta=beta, ns_method=ns_method)
+        if self.check_grad_overflow and not torch.isfinite(update).all():
+            self.local_overflow = True
+        momentum_cpu.copy_(momentum_gpu)
+        update_cpu = update.to(device='cpu')
+        del grad_gpu, momentum_gpu
+
+        # Write only the partition slice of the update to CPU FP32 grad buffer
+        tensor_offset = 0
+        actual_num_elements = param.numel()
+        if source_offset > 0:
+            tensor_offset = source_offset
+            actual_num_elements = param.numel() - tensor_offset
+        if actual_num_elements > num_elements:
+            actual_num_elements = num_elements
+
+        dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, actual_num_elements)
+        # reshape (not view) so the update is flattened correctly even if it is not contiguous.
+        update_slice = update_cpu.reshape(-1).narrow(0, tensor_offset, actual_num_elements)
+        dest_tensor.copy_(update_slice.to(self.master_weights_and_grads_dtype))
+
+        self.clear_grad_attribute(param)
+        return True
+
     def copy_grads_in_partition(self, param):
         if self.cpu_offload:
             # Accumulate when there were prior backwards in this step (restore from
@@ -1520,7 +1598,8 @@ def copy_grads_in_partition(self, param):
 
                 self.update_offload_overflow_tracker_for_param_grad(param)
 
-                self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
+                if not self._apply_muon_update_for_cpu_offload(param):
+                    self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
 
             return
             #print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")
diff --git a/tests/unit/ops/muon/test_muon_cpu_offload.py b/tests/unit/ops/muon/test_muon_cpu_offload.py
new file mode 100644
index 000000000000..083de623d2f9
--- /dev/null
+++ b/tests/unit/ops/muon/test_muon_cpu_offload.py
@@ -0,0 +1,144 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import deepspeed
+import torch
+import pytest
+
+from unit.common import DistributedTest
+from unit.simple_model import SimpleModel
+from deepspeed.accelerator import get_accelerator
+
+if torch.half not in get_accelerator().supported_dtypes():
+    pytest.skip("fp16 not supported", allow_module_level=True)
+
+
+@pytest.mark.parametrize('zero_stage', [2])
+class TestMuonCPUOffload(DistributedTest):
+
+    def test_momentum_buffer_on_cpu(self, zero_stage):
+        """Verify Muon CPU offload creates momentum buffer on CPU.
+
+        This is the key invariant: after a training step with CPU offload,
+        the Muon momentum buffer must reside on CPU (not GPU), confirming
+        that momentum is persisted on the host and not kept in GPU memory.
+        """
+        hidden_dim = 32
+        batch_size = 8
+        config_dict = {
+            "train_batch_size": batch_size,
+            "optimizer": {
+                "type": "muon",
+                "params": {
+                    "lr": 0.01
+                }
+            },
+            "fp16": {
+                "enabled": True
+            },
+            "zero_optimization": {
+                "stage": zero_stage,
+                "reduce_scatter": False,
+                "offload_optimizer": {
+                    "device": "cpu",
+                    "pin_memory": True,
+                },
+            },
+        }
+
+        model = SimpleModel(hidden_dim=hidden_dim, nlayers=5)
+        engine, optimizer, _, _ = deepspeed.initialize(
+            config=config_dict,
+            model=model,
+            model_parameters=model.parameters(),
+            dist_init_required=False,
+        )
+
+        x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
+        y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
+        loss = engine(x, y)
+        engine.backward(loss)
+        engine.step()
+
+        # Muon momentum buffer must exist and be on CPU.
+        # If muon_update was silently skipped, momentum_buffer would not be created.
+        flatten_copy = optimizer.optimizer.param_groups[0]['params'][0]
+        state = optimizer.optimizer.state[flatten_copy]
+        assert 'momentum_buffer' in state, ("momentum_buffer not found in optimizer state. "
+                                            "muon_update was not called in the CPU offload path.")
+        assert state['momentum_buffer'].device.type == 'cpu', (
+            f"Momentum buffer is on {state['momentum_buffer'].device}, expected CPU")
+
+
+@pytest.mark.parametrize('zero_stage', [2])
+class TestMuonCPUOffloadCosim(DistributedTest):
+
+    def test_cosim_offload_vs_no_offload(self, zero_stage):
+        """Verify CPU offload produces results consistent with GPU path.
+
+        With the same random seed, offload and non-offload should produce
+        close parameters. If muon_update is skipped or wrong in either path,
+        the results diverge significantly.
+        """
+        hidden_dim = 32
+        batch_size = 8
+
+        def train(offload):
+            torch.manual_seed(42)
+            config_dict = {
+                "train_batch_size": batch_size,
+                "optimizer": {
+                    "type": "muon",
+                    "params": {
+                        "lr": 0.01
+                    }
+                },
+                "fp16": {
+                    "enabled": True
+                },
+                "zero_optimization": {
+                    "stage": zero_stage,
+                    "reduce_scatter": False,
+                },
+            }
+            if offload:
+                config_dict["zero_optimization"]["offload_optimizer"] = {
+                    "device": "cpu",
+                    "pin_memory": True,
+                }
+
+            model = SimpleModel(hidden_dim=hidden_dim, nlayers=5)
+            engine, _, _, _ = deepspeed.initialize(
+                config=config_dict,
+                model=model,
+                model_parameters=model.parameters(),
+                dist_init_required=False,
+            )
+
+            for _ in range(3):
+                x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
+                y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
+                loss = engine(x, y)
+                engine.backward(loss)
+                engine.step()
+
+            return {n: p.clone().detach().float().cpu() for n, p in model.named_parameters()}
+
+        params_offload = train(offload=True)
+        params_no_offload = train(offload=False)
+
+        for name in params_offload:
+            p_off = params_offload[name]
+            p_no = params_no_offload[name]
+            # Both paths should produce the same NaN pattern
+            nan_mask = p_off.isnan()
+            assert nan_mask.equal(p_no.isnan()), (f"{name}: NaN pattern differs between offload and non-offload. "
+                                                  "muon_update produced different results.")
+            # On non-NaN elements, cosine similarity should be very high
+            valid = ~nan_mask
+            if valid.sum() > 0:
+                cos_sim = torch.nn.functional.cosine_similarity(p_off[valid].unsqueeze(0),
+                                                                p_no[valid].unsqueeze(0)).item()
+                assert cos_sim > 0.99, (f"{name}: cosine similarity {cos_sim:.4f} between offload and "
+                                        f"non-offload is too low, indicating muon_update results diverge.")