refactor(pt): full refactor of HybridMuon optimizer#5275
refactor(pt): full refactor of HybridMuon optimizer#5275OutisLi wants to merge 2 commits intodeepmodeling:masterfrom
Conversation
📝 WalkthroughWalkthroughReworks HybridMuonOptimizer routing to a name-aware, mode-driven system (muon_mode: "2d"/"flat"/"slice"), adds Magma-lite damping for Muon updates, introduces batched Newton–Schulz orthogonalization for slice-mode, adjusts optimizer construction and registration, and expands tests for routing and Magma behavior. Changes
Sequence DiagramsequenceDiagram
actor User
participant Optimizer as HybridMuonOptimizer
participant Router as Routing Logic
participant Shape as Shape Analysis
participant Magma as Magma Scaler
participant Step as Optimizer Step
User->>Optimizer: step(closure)
Optimizer->>Router: evaluate param name -> route (Adam/AdamW/Muon)
Router-->>Optimizer: route decision
alt Muon route
Optimizer->>Shape: compute effective shape & matrix view (muon_mode)
Shape-->>Optimizer: matrix/view dims
alt magma_muon enabled
Optimizer->>Magma: compute damping scales (per-bucket/per-param EMA)
Magma-->>Optimizer: damping scales
end
Optimizer->>Step: apply Newton–Schulz orth (batched if slice)
Step-->>Optimizer: muon-updated params/state
else Adam/AdamW route
Optimizer->>Step: apply Adam/AdamW update
Step-->>Optimizer: adam-updated params/state
end
Optimizer-->>User: return loss (from closure)
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Tip Try Coding Plans. Let us write the prompt for your AI agent so you can ship faster (with fewer bugs). Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
deepmd/utils/argcheck.py (1)
3758-3768: Validatemuon_modevalues at argcheck time.
muon_modeis free-formstrhere, so typos pass schema normalization and fail later during optimizer construction. Consider constraining accepted values to{"2d", "flat", "slice"}in this layer for earlier, clearer errors.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/utils/argcheck.py` around lines 3758 - 3768, The muon_mode argument in the argcheck schema (the "muon_mode" param definition) is currently an unconstrained str which lets typos slip through; change the schema to restrict allowed values to the set {"2d", "flat", "slice"} (e.g. use an enum/choices validator or an explicit check) so validation fails early with a clear message referencing muon_mode when an invalid value is provided.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/optimizer/hybrid_muon.py`:
- Around line 343-345: Define specific exception classes (e.g.,
InvalidTensorShapeError(ValueError) and InvalidMuonModeError(ValueError)) near
the top of the module with the full explanatory messages as their default
docstring/message, then replace the three inline multi-line ValueError raises
with simple raises of those classes: replace the shape check in
batched_newton_schulz (the current raise ValueError(... "Batched Newton-Schulz
expects a 3D tensor...")) with raise InvalidTensorShapeError, and replace both
muon_mode validation raises (the f-string multi-line and the single-line check)
with raise InvalidMuonModeError; run ruff check/format before committing.
In `@source/tests/pt/test_hybrid_muon.py`:
- Around line 358-361: The assertion comparing optimizer state uses exact float
equality which is fragile on CUDA; update the torch.allclose call for
model1.adam_scale vs model2.adam_scale to use a small nonzero tolerance (e.g.
atol=1e-6 and/or rtol=1e-6) instead of atol=0.0, rtol=0.0 so the test checks
near-equality while remaining stable; locate the comparison around
model1.adam_scale in the test_hybrid_muon.py and replace the zero tolerances
with a tight positive tolerance.
---
Nitpick comments:
In `@deepmd/utils/argcheck.py`:
- Around line 3758-3768: The muon_mode argument in the argcheck schema (the
"muon_mode" param definition) is currently an unconstrained str which lets typos
slip through; change the schema to restrict allowed values to the set {"2d",
"flat", "slice"} (e.g. use an enum/choices validator or an explicit check) so
validation fails early with a clear message referencing muon_mode when an
invalid value is provided.
ℹ️ Review info
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
deepmd/pt/optimizer/hybrid_muon.pydeepmd/pt/train/training.pydeepmd/utils/argcheck.pysource/tests/pt/test_hybrid_muon.py
| raise ValueError( | ||
| "Batched Newton-Schulz expects a 3D tensor with shape (B, m, n)." | ||
| ) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, check if the file exists and get line count
wc -l deepmd/pt/optimizer/hybrid_muon.pyRepository: deepmodeling/deepmd-kit
Length of output: 106
🏁 Script executed:
# Extract the specific lines mentioned in the review comment
sed -n '343,345p; 477,479p; 604p' deepmd/pt/optimizer/hybrid_muon.py | cat -nRepository: deepmodeling/deepmd-kit
Length of output: 428
🏁 Script executed:
# Run ruff check to detect TRY003 violations in the file
ruff check deepmd/pt/optimizer/hybrid_muon.py --select TRY003Repository: deepmodeling/deepmd-kit
Length of output: 1997
🏁 Script executed:
# Get more context around the mentioned lines
echo "=== Lines 343-345 ===" && sed -n '341,347p' deepmd/pt/optimizer/hybrid_muon.py | cat -n
echo "=== Lines 477-479 ===" && sed -n '475,481p' deepmd/pt/optimizer/hybrid_muon.py | cat -n
echo "=== Lines 604 ===" && sed -n '602,606p' deepmd/pt/optimizer/hybrid_muon.py | cat -nRepository: deepmodeling/deepmd-kit
Length of output: 925
Resolve TRY003 linting warnings for ValueError messages.
Three instances violate Ruff's TRY003 rule (avoid specifying long messages outside the exception class):
- Lines 343-345: Multi-line error message for tensor shape validation
- Lines 477-479: Multi-line error message with f-string for muon_mode validation
- Line 604: Single-line error message for muon_mode validation
Consolidate messages into exception class definitions or use dedicated exception types to keep the file lint-clean. Per coding guidelines, run ruff check . and ruff format . before committing.
🧰 Tools
🪛 Ruff (0.15.2)
[warning] 343-345: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@deepmd/pt/optimizer/hybrid_muon.py` around lines 343 - 345, Define specific
exception classes (e.g., InvalidTensorShapeError(ValueError) and
InvalidMuonModeError(ValueError)) near the top of the module with the full
explanatory messages as their default docstring/message, then replace the three
inline multi-line ValueError raises with simple raises of those classes: replace
the shape check in batched_newton_schulz (the current raise ValueError(...
"Batched Newton-Schulz expects a 3D tensor...")) with raise
InvalidTensorShapeError, and replace both muon_mode validation raises (the
f-string multi-line and the single-line check) with raise InvalidMuonModeError;
run ruff check/format before committing.
| self.assertFalse(torch.allclose(model1.weight, model2.weight)) | ||
| self.assertTrue( | ||
| torch.allclose(model1.adam_scale, model2.adam_scale, atol=0.0, rtol=0.0) | ||
| ) |
There was a problem hiding this comment.
Avoid exact float equality for Adam-path invariance assertions.
Using atol=0.0, rtol=0.0 can be flaky on CUDA due tiny nondeterministic differences. A tight tolerance keeps intent while improving stability.
💡 Suggested test tweak
- self.assertTrue(
- torch.allclose(model1.adam_scale, model2.adam_scale, atol=0.0, rtol=0.0)
- )
+ self.assertTrue(
+ torch.allclose(model1.adam_scale, model2.adam_scale, atol=1e-7, rtol=1e-6)
+ )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@source/tests/pt/test_hybrid_muon.py` around lines 358 - 361, The assertion
comparing optimizer state uses exact float equality which is fragile on CUDA;
update the torch.allclose call for model1.adam_scale vs model2.adam_scale to use
a small nonzero tolerance (e.g. atol=1e-6 and/or rtol=1e-6) instead of atol=0.0,
rtol=0.0 so the test checks near-equality while remaining stable; locate the
comparison around model1.adam_scale in the test_hybrid_muon.py and replace the
zero tolerances with a tight positive tolerance.
There was a problem hiding this comment.
Pull request overview
This PR refactors the PyTorch HybridMuonOptimizer to use name-based routing, adds a new muon_mode routing scheme (including per-slice Muon for higher-rank tensors), and introduces optional “Magma-lite” damping applied only on the Muon update path. It also updates training/config plumbing and expands tests to cover the new routing and damping behavior.
Changes:
- Replace
muon_2d_only/min_2d_dimrouting withmuon_mode(2d/flat/slice) and parameter-name-based routing rules. - Add
magma_muonoption implementing per-block momentum/gradient alignment scoring and damping on Muon updates. - Update training arg schema + trainer optimizer construction; expand unit tests for slice-mode routing and Magma damping.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
deepmd/pt/optimizer/hybrid_muon.py |
Implements muon_mode routing, name-based Adam/AdamW routing, batched NS for slice mode, and Magma-lite damping. |
deepmd/pt/train/training.py |
Wires new optimizer args (muon_mode, magma_muon) and passes named parameters for name-based routing. |
deepmd/utils/argcheck.py |
Updates the training config schema/docs for HybridMuon to use muon_mode and adds magma_muon. |
source/tests/pt/test_hybrid_muon.py |
Removes outdated tests and adds new coverage for slice routing, 2d routing behavior, and Magma damping state/range. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
deepmd/pt/train/training.py
Outdated
| muon_2d_only=bool(self.opt_param["muon_2d_only"]), | ||
| min_2d_dim=int(self.opt_param["min_2d_dim"]), | ||
| muon_mode=str(self.opt_param["muon_mode"]), | ||
| named_parameters=tuple(self.wrapper.named_parameters()), |
There was a problem hiding this comment.
named_parameters=tuple(self.wrapper.named_parameters()) eagerly materializes all (name, param) pairs, which can be expensive in memory/time for large models. Since HybridMuonOptimizer only needs to iterate once to build an id->name map, pass self.wrapper.named_parameters() directly (or another lazy iterable) instead of converting to a tuple.
| named_parameters=tuple(self.wrapper.named_parameters()), | |
| named_parameters=self.wrapper.named_parameters(), |
| - Parameters are routed by effective shape (singleton dimensions removed). | ||
| - ``muon_mode="2d"``: | ||
| - effective rank 2 parameters use Muon. | ||
| - effective rank >2 parameters use Adam. |
There was a problem hiding this comment.
In the HybridMuonOptimizer docstring, muon_mode="2d" currently says effective-rank >2 parameters use plain Adam, but _build_param_routing() routes these to the decoupled-decay AdamW-style path (adam_decay). Please update the docstring to match the actual behavior (Adam + decoupled weight decay for non-matrix shapes in 2d mode).
| - effective rank >2 parameters use Adam. | |
| - effective rank >2 parameters use Adam with decoupled weight decay | |
| (AdamW-style) fallback. |
| .. [4] Flash-Muon: Triton-accelerated symmetric matmul for Newton-Schulz. | ||
| https://github.com/lintianyang/flash-muon (MIT License, Tianyang Lin) | ||
| .. [5] Magma: Momentum-Aligned Gradient Masking for Stable Optimizer Updates. | ||
| arXiv:2602.15322, 2025. |
There was a problem hiding this comment.
The reference arXiv:2602.15322, 2025 is internally inconsistent: arXiv IDs starting with 2602 correspond to Feb 2026. Please correct the year in the citation (or adjust the identifier) so the reference is accurate.
| arXiv:2602.15322, 2025. | |
| arXiv:2602.15322, 2026. |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5275 +/- ##
==========================================
+ Coverage 82.32% 82.33% +0.01%
==========================================
Files 768 768
Lines 77097 77179 +82
Branches 3659 3659
==========================================
+ Hits 63468 63549 +81
+ Misses 12459 12458 -1
- Partials 1170 1172 +2 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
- Implement block-wise momentum-gradient alignment with EMA smoothing and soft scaling [0.1, 1.0] on Muon updates (magma_muon option) - Fix AdamW weight decay to use adam_lr instead of base lr - Wire magma_muon through training config and argcheck - Clean up redundant optimizer tests
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (1)
source/tests/pt/test_hybrid_muon.py (1)
358-361:⚠️ Potential issue | 🟡 MinorAvoid exact float equality for the Adam-path invariance assertion.
atol=0.0, rtol=0.0is still brittle on CUDA. A very small tolerance keeps the intent while avoiding flaky failures.💡 Suggested test tweak
self.assertTrue( - torch.allclose(model1.adam_scale, model2.adam_scale, atol=0.0, rtol=0.0) + torch.allclose(model1.adam_scale, model2.adam_scale, atol=1e-7, rtol=1e-6) )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt/test_hybrid_muon.py` around lines 358 - 361, The test currently asserts exact equality on model1.adam_scale vs model2.adam_scale using atol=0.0, rtol=0.0 which is brittle on CUDA; update the assertion in test_hybrid_muon.py to allow a tiny tolerance (e.g. atol=1e-6 or rtol=1e-6) when comparing torch.allclose(model1.adam_scale, model2.adam_scale) so the Adam-path invariance intent remains but avoids flaky failures on GPU.
🧹 Nitpick comments (1)
deepmd/utils/argcheck.py (1)
2950-2972: Update the HybridMuon option docs to match the new name-based routes.The help text here still describes
adam_beta1/adam_beta2as 1D-only andweight_decayas Muon-only, but the optimizer now applies those settings to explicitadam_/adamw_routes too. Right now the generated config docs under-document the new behavior for higher-rank parameters.💡 Suggested doc fix
Argument( "adam_beta1", float, optional=True, default=0.9, doc=doc_only_pt_supported - + "Adam beta1 coefficient for 1D parameters (biases, norms).", + + "Adam beta1 coefficient for Adam-routed parameters " + "(1D params and explicit `adam_` / `adamw_` routes).", ), Argument( "adam_beta2", float, optional=True, default=0.95, doc=doc_only_pt_supported - + "Adam beta2 coefficient for 1D parameters (biases, norms).", + + "Adam beta2 coefficient for Adam-routed parameters " + "(1D params and explicit `adam_` / `adamw_` routes).", ), Argument( "weight_decay", float, optional=True, default=0.001, doc=doc_only_pt_supported - + "Weight decay coefficient. Applied only to Muon-routed parameters", + + "Weight decay coefficient. Applied to Muon-routed parameters " + "and `adamw_`-routed parameters.", ),🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/utils/argcheck.py` around lines 2950 - 2972, The docs for the Argument entries adam_beta1, adam_beta2, and weight_decay are stale: update their doc strings so they no longer claim the settings apply only to 1D parameters or only to Muon-routed params; instead state that these values are applied to explicit name-based routes (e.g., parameters routed by prefixes like "adam_" and "adamw_") as well as the prior special cases. Modify the doc text concatenated with doc_only_pt_supported in the Argument(...) calls for "adam_beta1", "adam_beta2", and "weight_decay" to mention name-based routes (adam_/adamw_) and that the optimizer also applies these settings to higher-rank parameters when routed by name.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/optimizer/hybrid_muon.py`:
- Around line 1027-1031: The paired loops using zip over
adam_no_decay_exp_avgs/adam_no_decay_grads_fp32 and
adam_no_decay_exp_avg_sqs/grad_sq should use strict=True to ensure lengths
remain aligned; update the two zip(...) calls in the function that computes
exponential moving averages (the blocks that call ea.lerp_(...) and
eas.lerp_(...)) to zip(..., strict=True) so Ruff B905 is satisfied and any
length drift raises immediately.
---
Duplicate comments:
In `@source/tests/pt/test_hybrid_muon.py`:
- Around line 358-361: The test currently asserts exact equality on
model1.adam_scale vs model2.adam_scale using atol=0.0, rtol=0.0 which is brittle
on CUDA; update the assertion in test_hybrid_muon.py to allow a tiny tolerance
(e.g. atol=1e-6 or rtol=1e-6) when comparing torch.allclose(model1.adam_scale,
model2.adam_scale) so the Adam-path invariance intent remains but avoids flaky
failures on GPU.
---
Nitpick comments:
In `@deepmd/utils/argcheck.py`:
- Around line 2950-2972: The docs for the Argument entries adam_beta1,
adam_beta2, and weight_decay are stale: update their doc strings so they no
longer claim the settings apply only to 1D parameters or only to Muon-routed
params; instead state that these values are applied to explicit name-based
routes (e.g., parameters routed by prefixes like "adam_" and "adamw_") as well
as the prior special cases. Modify the doc text concatenated with
doc_only_pt_supported in the Argument(...) calls for "adam_beta1", "adam_beta2",
and "weight_decay" to mention name-based routes (adam_/adamw_) and that the
optimizer also applies these settings to higher-rank parameters when routed by
name.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 5f87feaa-db3d-471d-b82f-b99922b5aab4
📒 Files selected for processing (4)
deepmd/pt/optimizer/hybrid_muon.pydeepmd/pt/train/training.pydeepmd/utils/argcheck.pysource/tests/pt/test_hybrid_muon.py
| for ea, g in zip(adam_no_decay_exp_avgs, adam_no_decay_grads_fp32): | ||
| ea.lerp_(g, 1 - adam_betas[0]) | ||
| grad_sq = [g * g for g in adam_grads_fp32] | ||
| for eas, gsq in zip(adam_exp_avg_sqs, grad_sq): | ||
| grad_sq = [g * g for g in adam_no_decay_grads_fp32] | ||
| for eas, gsq in zip(adam_no_decay_exp_avg_sqs, grad_sq): | ||
| eas.lerp_(gsq, 1 - adam_betas[1]) |
There was a problem hiding this comment.
Add strict=True to these paired zip() loops.
These lists are meant to stay perfectly aligned; plain zip() can silently truncate if they ever drift, and Ruff already flags B905 on both blocks.
🔧 Suggested fix
- for ea, g in zip(adam_no_decay_exp_avgs, adam_no_decay_grads_fp32):
+ for ea, g in zip(
+ adam_no_decay_exp_avgs, adam_no_decay_grads_fp32, strict=True
+ ):
ea.lerp_(g, 1 - adam_betas[0])
grad_sq = [g * g for g in adam_no_decay_grads_fp32]
- for eas, gsq in zip(adam_no_decay_exp_avg_sqs, grad_sq):
+ for eas, gsq in zip(
+ adam_no_decay_exp_avg_sqs, grad_sq, strict=True
+ ):
eas.lerp_(gsq, 1 - adam_betas[1])
...
- for ea, g in zip(adam_decay_exp_avgs, adam_decay_grads_fp32):
+ for ea, g in zip(
+ adam_decay_exp_avgs, adam_decay_grads_fp32, strict=True
+ ):
ea.lerp_(g, 1 - adam_betas[0])
grad_sq = [g * g for g in adam_decay_grads_fp32]
- for eas, gsq in zip(adam_decay_exp_avg_sqs, grad_sq):
+ for eas, gsq in zip(
+ adam_decay_exp_avg_sqs, grad_sq, strict=True
+ ):
eas.lerp_(gsq, 1 - adam_betas[1])As per coding guidelines, **/*.py: Always run ruff check . and ruff format . before committing changes or CI will fail.
Also applies to: 1086-1090
🧰 Tools
🪛 Ruff (0.15.4)
[warning] 1027-1027: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
[warning] 1030-1030: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@deepmd/pt/optimizer/hybrid_muon.py` around lines 1027 - 1031, The paired
loops using zip over adam_no_decay_exp_avgs/adam_no_decay_grads_fp32 and
adam_no_decay_exp_avg_sqs/grad_sq should use strict=True to ensure lengths
remain aligned; update the two zip(...) calls in the function that computes
exponential moving averages (the blocks that call ea.lerp_(...) and
eas.lerp_(...)) to zip(..., strict=True) so Ruff B905 is satisfied and any
length drift raises immediately.
Summary by CodeRabbit
New Features
Documentation
Tests