Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1506,6 +1506,81 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
return torch.tensor(total_norm, device=self.device, dtype=torch.float)

############################################################################################
def _apply_muon_update_for_cpu_offload(self, param):
"""Apply muon_update for a parameter in the CPU offload path.

For Muon parameters (use_muon=True), runs Newton-Schulz
orthogonalization on GPU (momentum is temporarily copied from
CPU to GPU) and writes only the partition slice back to the
CPU FP32 grad buffer. Cross-boundary parameters are
redundantly processed by each involved rank with the full
gradient, matching the non-offload path behavior in
get_flat_partition.

Returns True if muon_update was applied (caller should skip
the normal copy for this param).
"""
if not getattr(param, 'use_muon', False):
return False
if 'muon' not in self.optimizer.__class__.__name__.lower():
return False

param_id = self.get_param_id(param)
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]

grad_accum = self.get_param_gradient_attribute(param)
if grad_accum is None:
return False

flatten_copy = self.optimizer.param_groups[i]['params'][0]
if "momentum_buffer" not in self.optimizer.state[flatten_copy]:
total_size = sum(p.numel() for p in self.params_in_partition[i])
self.optimizer.state[flatten_copy]["momentum_buffer"] = torch.zeros(total_size,
dtype=torch.float32,
device=self.device)

momentum_flat = self.optimizer.state[flatten_copy]["momentum_buffer"]

muon_offset = 0
for p in self.params_in_partition[i]:
if p is param:
break
muon_offset += p.numel()

momentum_cpu = momentum_flat[muon_offset:muon_offset + param.numel()].view(param.size())

beta = self.optimizer.param_groups[i].get('momentum', 0.95)
ns_method = self.optimizer.param_groups[i].get('ns_method', 'gram')

# Run NS on GPU: keep grad on GPU, temporarily move momentum to GPU
gpu_device = grad_accum.device
grad_gpu = grad_accum.detach().clone().to(dtype=torch.float32)
momentum_gpu = momentum_cpu.to(device=gpu_device, dtype=torch.float32)
update = muon_update(grad_gpu.view(param.size()), momentum_gpu, beta=beta, ns_method=ns_method)
if self.check_grad_overflow and (update.isinf().any() or update.isnan().any()):
self.local_overflow = True
momentum_cpu.copy_(momentum_gpu.to(device='cpu'))
update_cpu = update.to(device='cpu')
del grad_gpu, momentum_gpu

momentum_flat[muon_offset:muon_offset + param.numel()] = momentum_cpu.view(-1)

# Write only the partition slice of the update to CPU FP32 grad buffer
tensor_offset = 0
actual_num_elements = param.numel()
if source_offset > 0:
tensor_offset = source_offset
actual_num_elements = param.numel() - tensor_offset
if actual_num_elements > num_elements:
actual_num_elements = num_elements

dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, actual_num_elements)
update_slice = update_cpu.view(-1).narrow(0, tensor_offset, actual_num_elements)
dest_tensor.copy_(update_slice.to(self.master_weights_and_grads_dtype))

self.clear_grad_attribute(param)
return True

def copy_grads_in_partition(self, param):
if self.cpu_offload:
# Accumulate when there were prior backwards in this step (restore from
Expand All @@ -1520,7 +1595,8 @@ def copy_grads_in_partition(self, param):

self.update_offload_overflow_tracker_for_param_grad(param)

self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
if not self._apply_muon_update_for_cpu_offload(param):
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
Comment thread
delock marked this conversation as resolved.

return
#print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")
Expand Down
144 changes: 144 additions & 0 deletions tests/unit/ops/muon/test_muon_cpu_offload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
import torch
import pytest

from unit.common import DistributedTest
from unit.simple_model import SimpleModel
from deepspeed.accelerator import get_accelerator

if torch.half not in get_accelerator().supported_dtypes():
pytest.skip(f"fp16 not supported", allow_module_level=True)


@pytest.mark.parametrize('zero_stage', [2])
class TestMuonCPUOffload(DistributedTest):

def test_momentum_buffer_on_cpu(self, zero_stage):
"""Verify Muon CPU offload creates momentum buffer on CPU.

This is the key invariant: after a training step with CPU offload,
the Muon momentum buffer must reside on CPU (not GPU), confirming
that muon_update ran on CPU and no GPU memory is wasted.
"""
hidden_dim = 32
batch_size = 8
config_dict = {
"train_batch_size": batch_size,
"optimizer": {
"type": "muon",
"params": {
"lr": 0.01
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage,
"reduce_scatter": False,
"offload_optimizer": {
"device": "cpu",
"pin_memory": True,
},
},
}

model = SimpleModel(hidden_dim=hidden_dim, nlayers=5)
engine, optimizer, _, _ = deepspeed.initialize(
config=config_dict,
model=model,
model_parameters=model.parameters(),
dist_init_required=False,
)

x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
loss = engine(x, y)
engine.backward(loss)
engine.step()

# Muon momentum buffer must exist and be on CPU.
# If muon_update was silently skipped, momentum_buffer would not be created.
flatten_copy = optimizer.optimizer.param_groups[0]['params'][0]
state = optimizer.optimizer.state[flatten_copy]
assert 'momentum_buffer' in state, ("momentum_buffer not found in optimizer state. "
"muon_update was not called in the CPU offload path.")
assert state['momentum_buffer'].device.type == 'cpu', (
f"Momentum buffer is on {state['momentum_buffer'].device}, expected CPU")


@pytest.mark.parametrize('zero_stage', [2])
class TestMuonCPUOffloadCosim(DistributedTest):

def test_cosim_offload_vs_no_offload(self, zero_stage):
"""Verify CPU offload produces results consistent with GPU path.

With the same random seed, offload and non-offload should produce
close parameters. If muon_update is skipped or wrong in either path,
the results diverge significantly.
"""
hidden_dim = 32
batch_size = 8

def train(offload):
torch.manual_seed(42)
config_dict = {
"train_batch_size": batch_size,
"optimizer": {
"type": "muon",
"params": {
"lr": 0.01
}
},
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": zero_stage,
"reduce_scatter": False,
},
}
if offload:
config_dict["zero_optimization"]["offload_optimizer"] = {
"device": "cpu",
"pin_memory": True,
}

model = SimpleModel(hidden_dim=hidden_dim, nlayers=5)
engine, _, _, _ = deepspeed.initialize(
config=config_dict,
model=model,
model_parameters=model.parameters(),
dist_init_required=False,
)

for _ in range(3):
x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
loss = engine(x, y)
engine.backward(loss)
engine.step()

return {n: p.clone().detach().float().cpu() for n, p in model.named_parameters()}

params_offload = train(offload=True)
params_no_offload = train(offload=False)

for name in params_offload:
p_off = params_offload[name]
p_no = params_no_offload[name]
# Both paths should produce the same NaN pattern
nan_mask = p_off.isnan() | p_no.isnan()
assert nan_mask.equal(p_off.isnan()), (f"{name}: NaN pattern differs between offload and non-offload. "
"muon_update produced different results.")
# On non-NaN elements, cosine similarity should be very high
valid = ~nan_mask
if valid.sum() > 0:
cos_sim = torch.nn.functional.cosine_similarity(p_off[valid].unsqueeze(0),
p_no[valid].unsqueeze(0)).item()
assert cos_sim > 0.99, (f"{name}: cosine similarity {cos_sim:.4f} between offload and "
f"non-offload is too low, indicating muon_update results diverge.")
Loading