[OMNIML-4998] Per-expert weight quantization on TEGroupedLinear#1671
[OMNIML-4998] Per-expert weight quantization on TEGroupedLinear#1671hychiang-git wants to merge 14 commits into
Conversation
Adds a megatron_install_path field on SlurmConfig and a symmetric
parameter on slurm_factory, mirroring the existing modelopt_install_path
plumbing. build_slurm_executor and build_docker_executor now append a
{megatron_src}:{megatron_dst} bind-mount onto container_mounts so the
packaged modules/Megatron-LM/megatron source shadows the container's
pre-installed Megatron-LM at /usr/local/lib/python3.12/dist-packages/megatron.
This unblocks running modelopt-from-source against containers whose
bundled Megatron-LM is too old for newer modelopt features — concretely,
test_moe_sharded_state_dict against nemo:25.11.nemotron_3_nano, where
make_tp_sharded_tensor_for_checkpoint silently drops allow_shape_mismatch
and the resulting per-rank shape divergence deadlocks dist-checkpoint
metadata exchange. Path 3 from the OMNIML-5029 fix-paths analysis.
Tests:
- test_slurm_config.py: assert default megatron_install_path
- test_slurm_executor.py: assert megatron mount appears in container_mounts
- test_docker_execution.py: new test_megatron_mount for the Docker path
Full launcher suite: 63 passed, 2 skipped, 0 failed.
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…sharded_state_dict
Megatron-LM's MoELayer.forward asserts sequence_parallel=True whenever the
attn_tp_group has size > 1 (see megatron/core/transformer/moe/moe_layer.py).
test_moe_sharded_state_dict's TEGroupedMLP parameterization uses tp_size=2,
so the assertion fires before the dist-checkpoint logic under test even runs.
The SequentialMLP branch sidesteps it by dropping to tp_size=1.
The hybrid-Mamba model path in _gpt_model_provider already passes
sequence_parallel=(tp_size > 1) for exactly this reason; the plain-GPT
branch did not. This commit propagates the same convention:
- Add a sequence_parallel: bool = False kwarg to get_mcore_gpt_model
and forward it into TransformerConfig (replacing the previously
hardcoded sequence_parallel=False).
- In _gpt_model_provider's get_mcore_gpt_model call, pass
sequence_parallel=(tp_size > 1).
After this change the GroupedMLP variants (FP8_DEFAULT_CFG and
NVFP4_DEFAULT_CFG with moe_grouped_gemm=True) reach the actual
sharded_state_dict logic instead of erroring out at MoE forward.
Surfaced by the OMNIML-5029 cluster smoke run (job 496032) once the
launcher bind-mount let the test get past the Megatron-LM import phase.
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
… (axis=0)
Lift the per-tensor-only restriction on _QuantTEGroupedLinear's weight
quantizer so it can compute one amax per expert. The new path is gated on
axis=0 in the standard QuantizerAttributeConfig — default (axis=None)
keeps the legacy single-amax behavior bit-for-bit.
* _QuantTEGroupedLinear._is_per_expert_weight_quant() — small helper
that returns True when the existing weight_quantizer carries axis=0.
* iter_weights_for_calibration: per-expert mode yields one stacked
[num_gemms, out, in] tensor instead of N separate slices so the
quantizer's axis-0 reduction produces _amax of shape [N, 1, 1].
* te_grouped_quantized_linear_fn: per-expert forward stacks the N
expert weights, fake-quants once, and indexes back; broadcasting
handles the per-expert scale via _amax[:, None, None].
* _QuantMegatronTEGroupedLinear._process_quantizer_amax: accept either
a scalar (per-tensor, legacy) or a num_gemms-element tensor
(per-expert). Anything else raises with a clear message.
Closes the long-standing "support a unique quantizer for each gemm"
TODO in _QuantTEGroupedLinear._setup at the per-expert granularity.
Per-channel-within-expert and per-block per-expert (NVFP4) are
deliberately out of scope and tracked separately.
sharded_state_dict / _get_shard_axis_dict wiring for dist-checkpoint
round-trip of the new per-expert amax is the next commit (AC-2). GPU
smoke tests under tests/gpu_megatron/torch/quantization/plugins/ land
with that commit.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
… amax
Builds on the prior commit (per-expert axis=0 forward + calibration) by
plumbing per-expert weight_quantizer._amax through sharded_state_dict and
load_distributed_checkpoint so each EP rank reconstructs the same local
amax bit-for-bit after a save/restore.
The design avoids encoding EP sharding in the checkpoint metadata (which
make_sharded_tensors_for_checkpoint only knows about TP). Instead each
TEGroupedLinear all-gathers its local [num_gemms] amax across the
expert_model_parallel_group before save, so every rank stores an
identical [num_gemms_global] tensor that the dist-checkpoint framework
can safely replicate across DP/TP. On restore each rank narrows back to
its local slice before the base class's view_as() reshape.
Three pieces on _QuantMegatronTEGroupedLinear:
* _ep_group() — returns the EP group iff it's initialized AND has >1
rank; gives a clean None path for EP=1 and ad-hoc unit tests where
parallel_state isn't wired up.
* _process_quantizer_amax: per-expert branch all-gathers via
torch.distributed.all_gather then concatenates; per-tensor branch
unchanged.
* _load_from_state_dict: per-expert branch narrows the global tensor to
[ep_rank * num_gemms : (ep_rank+1) * num_gemms] before passing to
super(). _extra_state filter preserved.
* _get_shard_axis_dict: strips the TP-axis-0 marker that the
column-parallel parent would otherwise add for any non-None
weight_quantizer.axis — correct for per-channel-along-output, but
wrong for our [num_gemms_global] amax.
Smoke test: test_te_grouped_per_expert_sharded_state_dict under
tests/gpu_megatron/torch/quantization/plugins/. Builds a TEGroupedMLP
GPT with num_moe_experts=4, configures a TE-grouped-only per-expert
config (axis=0 weight quantizer, everything else disabled), runs through
_test_sharded_state_dict on dist_workers_size_4 with tp_size=1 ep_size=2.
Verifies the round-trip helper's bitwise state_dict comparison passes
with the new amax shape. Re-enabling INT4_BLOCKWISE_WEIGHT_ONLY_CFG on
TEGroupedMLP and the per-channel-within-expert / NVFP4 block scales
paths are deliberately out of scope and remain skipped.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…uring The previous version did the EP all-gather inside _process_quantizer_amax, which is called per-key inside _MegatronParallelLinear.sharded_state_dict's state_dict iteration. Megatron's dist-checkpoint save runs its own default-PG ALLGATHER metadata exchanges across the world group from inside that same iteration -- the two interleave at NCCL and one rank ends up stuck in our EP gather while the others have already advanced to the default-PG ALLGATHER, producing a 10-minute watchdog timeout. Move the gather to a single call at the top of sharded_state_dict, before super().sharded_state_dict starts traversing state_dict. The gathered [num_gemms_global] tensor is cached on a temporary attribute that _process_quantizer_amax consumes; the live weight_quantizer._amax buffer stays at its [num_gemms_local, 1, 1] shape so forward keeps broadcasting correctly against the stacked [num_gemms_local, out, in] expert weights. The cache is cleared in a finally block so it doesn't leak after save. Net effect: per-rank flow during save is now 1. EP all-gather (one collective per TEGroupedLinear layer, on EP group) 2. super().sharded_state_dict iterates state_dict (no collectives) 3. Megatron's dist-checkpoint save (default-PG collectives) The three phases no longer interleave, so the NCCL deadlock can't form. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…re layout) Earlier attempts at TP=1 EP=2 hit a 10-minute NCCL watchdog timeout on a default-PG ALLGATHER (SeqNum=5, NumelIn=1 NumelOut=4) regardless of weight_quantizer.axis. Confirmed by an explicit axis=None diagnostic run that reproduced the same deadlock signature, so the hang is in the MCore + TEGroupedMLP + dist-checkpoint stack at TP=1 EP=2 -- not in OMNIML-4998's per-expert plumbing. The TP=1 EP=2 layout has no existing modelopt test that exercises dist-checkpoint with grouped MoE: `test_moe_sharded_state_dict` runs TP=2 EP=2 for grouped_gemm=True, and `test_te_grouped_vs_sequential_quantize` runs TP=1 EP=2 but no dist-checkpoint. So we're not regressing anything; we're choosing a layout the surrounding code actually supports. Switching the AC-2 smoke to TP=2 EP=2 keeps the EP plumbing exercised (still 2 local experts per EP rank -> the EP all-gather in _gather_global_per_expert_amax still actually moves data) while staying on a Megatron/MCore layout that is known to work for grouped MoE dist-checkpoint save+restore. Followup: open an MCore-side issue for the TP=1 EP=2 hang and revisit this test config once it lands. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…exactly The TP=2 EP=2 run was still hitting the same SeqNum=5 ALLGATHER hang as TP=1 EP=2. But test_moe_sharded_state_dict already passes at TP=2 EP=2 with FP8_DEFAULT_CFG / NVFP4_DEFAULT_CFG -- so the test infrastructure works at this layout. Three things differed between our test and the working one: the fixture (dist_workers_size_4 vs dist_workers), hidden_size (32 vs 256), and the quant_cfg. Mirror the working test's fixture and hidden_size so the only varying thing in this run is our axis=0 quant_cfg. If it now passes, our config is fine and the prior hangs were from the size_4 fixture or the tiny hidden_size 32 (which is degenerate -- head_dim=8 with 4 attn heads). If it still hangs, the per-expert config itself is at fault and we narrow further. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
|
Important Review skippedDraft detected. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Enterprise Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughThis PR adds per-expert weight quantization support for Transformer Engine GroupedLinear modules within Megatron, enabling proper distributed checkpoint handling across expert-parallel ranks. It adjusts TP amax synchronization axes, implements EP-aware gather/load logic with amax reshaping, and extends launcher infrastructure to mount Megatron-LM in execution environments. ChangesTE GroupedLinear Per-Expert Quantization and Launcher Infrastructure
Sequence DiagramsequenceDiagram
participant Calib as max_calibrate
participant Module as TEGroupedLinear
participant Megatron as MegatronPlugin
participant SaveExec as sharded_state_dict
participant EPGather as _gather_global_per_expert_amax
participant Checkpoint as Checkpoint
participant LoadExec as _load_from_state_dict
participant TE as TEPlugin
Calib->>Module: detect per-expert (weight0)
Calib->>Calib: _weight_axes_for_sync returns axes including 0
Calib->>Module: sync_quantizer_amax_across_tp with adjusted axes
SaveExec->>EPGather: all-gather local weight._amax
EPGather->>Checkpoint: save global [num_gemms_global] tensor
SaveExec->>Checkpoint: cache gathered amax on module
LoadExec->>Checkpoint: load saved global [num_gemms_global] amax
LoadExec->>LoadExec: narrow to local [num_gemms] slice per EP rank
LoadExec->>TE: delegate to base loader
TE->>Module: _reshape_loaded_amax_to_buffer_shape hook
TE->>Module: restore amax to live buffer shape
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes 🚥 Pre-merge checks | ✅ 6✅ Passed checks (6 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
|
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #1671 +/- ##
==========================================
- Coverage 77.48% 75.56% -1.93%
==========================================
Files 489 511 +22
Lines 54415 60054 +5639
==========================================
+ Hits 42165 45379 +3214
- Misses 12250 14675 +2425
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
The dist-checkpoint LOAD path was building model_test's placeholder _amax from self.weight0 -- the un-stacked 2D [out, in] view of one expert's weight. With axis=0 that yields _amax of shape [out, 1] (e.g. [1024, 1] for fc1 at hidden=256 / tp=2 / ep=2), and _gather_global_per_expert_amax then fails on `amax.view(num_gemms)` because amax.numel() != num_gemms. The SAVE path was fine because mtq.quantize routes through iter_weights_for_calibration, which stacks the N expert weights into a [num_gemms, out, in] tensor before calling weight_quantizer once; axis=0 reduction then sizes _amax as [num_gemms, 1, 1] correctly. Mirror that stacking pattern in modelopt_post_restore so the restore path lands on the same placeholder shape. The per-tensor path stays on the legacy weight0 placeholder (already correct). Surfaced by the OMNIML-4998 AC-2 GPU smoke after the OMNIML-5029 launcher fix unblocked the cluster harness. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…af quantizer
Roots cause: the dist-checkpoint tensor and the _pytorch_state_metadata for
per-expert weight_quantizer._amax record different shapes. The tensor is saved
flat as [num_gemms_global] (the legacy _process_quantizer_amax flatten),
while the metadata records the live buffer shape [num_gemms, 1, 1] (axis=0
calibration with keepdims). On restore, the modelopt-state path registers
_amax as [num_gemms, 1, 1] first; the subsequent load_state_dict then tries
to copy the flat tensor into that buffer and trips the strict size check.
Parent-level approaches (TEGroupedLinear._load_from_state_dict mutating the
state_dict in place, or _process_quantizer_amax emitting a multi-dim save
shape) introduced rank-asymmetric collective ordering — either because the
mutation conditioned on per-rank buffer-registration timing or because the
save-shape change perturbed the ShardedTensor metadata that
determine_global_metadata exchanges via all_gather_object across ranks.
Fix at the leaf instead. _QuantTEGroupedLinear._setup now registers a
load_state_dict pre-hook on weight_quantizer that reshapes the loaded
_amax to the live buffer shape when shapes differ but numel matches. The
hook is:
- Leaf-level (runs at the TensorQuantizer module PyTorch actually iterates).
- Rank-local (touches only local state_dict and buffer shape).
- Side-effect-free (no collectives, no conditional code paths beyond the
presence-of-buffer check, which is symmetric across ranks once
set_extra_state has run on all ranks at the end of the first
load_state_dict pass).
Pairs with the modelopt_post_restore stacking commit (8fd0d53) so the live
buffer shape on a fresh restore is [num_gemms, 1, 1] — which the hook then
matches the loaded flat tensor to.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
For column-parallel TEGroupedLinear in per-expert mode (axis=0), max_calibrate
was NOT syncing _amax across the TP group. The existing axes_for_sync list
[None, -1] was designed for the classic case where axis=0 means "per-output-
channel" (TP-sharded; syncing would corrupt per-channel amax). For TEGrouped
in per-expert mode, axis=0 means "per-expert" (NOT TP-sharded under etp=1),
so it SHOULD be synced.
Symptom this fixes (caught by AC-2 GPU smoke on aws-cmh):
- Each TP rank within an EP group ran calibration on bit-identical-ish
expert weights (BF16/FP16 sums aren't bit-identical across ranks) and
produced slightly divergent per-expert amax.
- dist-checkpoint treats _amax as replicated (per `_get_shard_axis_dict`
override) and saves only one rank's view (effectively tp_rank=0).
- On restore, every rank loaded tp_rank=0's amax. model_ref (still
holding its per-rank-different amax) then mismatched model_test on
every TP rank except tp_rank=0, causing
`assert torch.allclose(model_ref._amax, model_test._amax)` to raise on
tp_rank=1 ranks. Failed ranks entered teardown's barrier; passing
ranks entered the test's terminal barrier; both stuck at op 27 on
different call sites, watchdog killed the job at 10-min mark.
Fix: detect TEGroupedLinear in per-expert mode (module has `weight0` attr
and `weight_quantizer.axis in (0, (0,))`) and extend axes_for_sync for the
column-parallel weight quantizer to [None, -1, 0]. The TEGroupedLinear
detection is by-attribute (weight0) rather than by-class to avoid a circular
import. Row-parallel already syncs axis=0, no change there.
Verified by `services/model-optimizer/test/omniml_4998_per_expert_amax_smoke.yaml`
on aws-cmh against nvcr.io/nvidia/nemo:25.11.nemotron_3_nano with OMNIML-5029's
megatron_install_path bind-mount: test_te_grouped_per_expert_sharded_state_dict
PASSED on all 4 ranks (tp=2 ep=2 etp=1 num_experts=4).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
|
Related to PR #1550 — same scope (per-expert weight quantization on TEGroupedLinear), alternative design:
Happy to discuss merging approaches. Cross-linking so reviewers on either PR have visibility into the other. |
|
/claude review |
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 360-373: The function _weight_axes_for_sync currently checks
quantizer.axis in (0, (0,)) but only appends the integer 0 to base_axes, causing
tuple-form axes like (0,) to be missed by downstream exact-match checks; update
_weight_axes_for_sync to normalize the quantizer.axis (e.g., treat (0,) and 0
equivalently) or append both representations (0 and (0,)) when weight_quantizer
is a TensorQuantizer with axis 0, so that axes_for_sync contains the same form
as quantizer.axis and TP amax sync occurs correctly for per-expert (weight0)
modules.
In `@modelopt/torch/quantization/plugins/megatron.py`:
- Around line 787-796: The cached per-expert amax override in
_process_quantizer_amax is too broad because the check `k.endswith("_amax")`
also matches keys like `weight_quantizer._global_amax`; change the condition so
the cached override only applies to the exact per-weight amax key (e.g., match
"weight_quantizer._amax" or the full qualifier used for per-expert weight amax)
— update the if that tests k and "weight_quantizer" to require the precise
suffix or full key "weight_quantizer._amax" before assigning cached.view(...) to
quantizer_state_dict[k], leaving all other _amax keys to use v.view(v.numel()).
In `@tests/gpu_megatron/torch/quantization/plugins/test_megatron.py`:
- Around line 699-712: The second rule in TE_GROUPED_PER_EXPERT_CFG mistakenly
places "enable": True inside the nested "cfg" so the expert weight quantizers
remain disabled; update the rule matching "*experts.linear_fc*.weight_quantizer"
in TE_GROUPED_PER_EXPERT_CFG to set "enable": True at the top level of that rule
(alongside "quantizer_name" and "cfg") rather than inside "cfg", ensuring the
per-expert axis=0 quantization path in _QuantTEGroupedLinear is actually
enabled.
In `@tools/launcher/core.py`:
- Around line 317-319: The new parameter megatron_src_path was inserted before
experiment_title in run_jobs, breaking existing positional calls that pass
modelopt_src_path, experiment_title; restore the positional contract by moving
megatron_src_path after experiment_title in the run_jobs signature (or
alternatively make megatron_src_path keyword-only by adding a * before it), and
keep default None for modelopt_src_path and megatron_src_path; update any
callers if you choose the keyword-only approach.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 0cb44838-5056-48d6-9cf4-dfa316d3d963
📒 Files selected for processing (10)
modelopt/torch/quantization/model_calib.pymodelopt/torch/quantization/plugins/megatron.pymodelopt/torch/quantization/plugins/transformer_engine.pytests/_test_utils/torch/megatron/models.pytests/gpu_megatron/torch/quantization/plugins/test_megatron.pytools/launcher/core.pytools/launcher/slurm_config.pytools/launcher/tests/test_docker_execution.pytools/launcher/tests/test_slurm_config.pytools/launcher/tests/test_slurm_executor.py
| def _weight_axes_for_sync(module, base_axes): | ||
| # For column-parallel TEGroupedLinear in per-expert mode (axis=0), the | ||
| # expert weights are NOT TP-sharded (etp=1 case) — axis=0 indexes experts, | ||
| # not output channels. Per-rank reductions still produce slightly | ||
| # divergent amax across TP (BF16/FP16 sums are not bit-identical across | ||
| # ranks even with the same input), and dist-checkpoint save treats | ||
| # _amax as replicated and captures only one rank's view. Without this | ||
| # sync, model_ref's per-rank-different amax mismatches the loaded | ||
| # model_test on every TP rank except the one whose value was saved. | ||
| if hasattr(module, "weight0"): | ||
| quantizer = getattr(module, "weight_quantizer", None) | ||
| if isinstance(quantizer, TensorQuantizer) and quantizer.axis in (0, (0,)): | ||
| return list(base_axes) + [0] | ||
| return base_axes |
There was a problem hiding this comment.
Include (0,) in the TP sync allowlist.
_weight_axes_for_sync() detects both 0 and (0,), but it only appends 0. The downstream check is an exact quantizer.axis in axes_for_sync, so tuple-normalized per-expert axes still skip TP amax sync and can diverge across ranks.
Proposed fix
if hasattr(module, "weight0"):
quantizer = getattr(module, "weight_quantizer", None)
if isinstance(quantizer, TensorQuantizer) and quantizer.axis in (0, (0,)):
- return list(base_axes) + [0]
+ return [*base_axes, 0, (0,)]
return base_axes🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/quantization/model_calib.py` around lines 360 - 373, The
function _weight_axes_for_sync currently checks quantizer.axis in (0, (0,)) but
only appends the integer 0 to base_axes, causing tuple-form axes like (0,) to be
missed by downstream exact-match checks; update _weight_axes_for_sync to
normalize the quantizer.axis (e.g., treat (0,) and 0 equivalently) or append
both representations (0 and (0,)) when weight_quantizer is a TensorQuantizer
with axis 0, so that axes_for_sync contains the same form as quantizer.axis and
TP amax sync occurs correctly for per-expert (weight0) modules.
| def _process_quantizer_amax(self, k, v, quantizer_state_dict): | ||
| assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization" | ||
| quantizer_state_dict[k] = v.view(-1) | ||
| # Per-expert weight amax: emit the gathered global tensor cached by | ||
| # sharded_state_dict's pre-pass. No collective fires here, so this | ||
| # call is safe to run inside Megatron's save traversal. | ||
| # Per-tensor (or non-weight) amax: just reshape the local value. | ||
| cached = getattr(self, "_cached_global_per_expert_amax", None) | ||
| if cached is not None and "weight_quantizer" in k and k.endswith("_amax"): | ||
| quantizer_state_dict[k] = cached.view(cached.numel()) | ||
| return | ||
| quantizer_state_dict[k] = v.view(v.numel()) |
There was a problem hiding this comment.
Scope the cached override to weight_quantizer._amax only.
k.endswith("_amax") also matches weight_quantizer._global_amax. In per-expert NVFP4/static configs that would write the gathered per-expert vector into the scalar _global_amax slot, breaking checkpoint round-tripping for that quantizer state.
Proposed fix
- if cached is not None and "weight_quantizer" in k and k.endswith("_amax"):
+ if cached is not None and k == "weight_quantizer._amax":
quantizer_state_dict[k] = cached.view(cached.numel())
return🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/quantization/plugins/megatron.py` around lines 787 - 796, The
cached per-expert amax override in _process_quantizer_amax is too broad because
the check `k.endswith("_amax")` also matches keys like
`weight_quantizer._global_amax`; change the condition so the cached override
only applies to the exact per-weight amax key (e.g., match
"weight_quantizer._amax" or the full qualifier used for per-expert weight amax)
— update the if that tests k and "weight_quantizer" to require the precise
suffix or full key "weight_quantizer._amax" before assigning cached.view(...) to
quantizer_state_dict[k], leaving all other _amax keys to use v.view(v.numel()).
| TE_GROUPED_PER_EXPERT_CFG = { | ||
| "algorithm": "max", | ||
| "quant_cfg": [ | ||
| # Disable everything (attention, embeddings, non-MoE MLP) so the test | ||
| # isolates the per-expert path on TEGroupedMLP experts. | ||
| {"quantizer_name": "*", "enable": False}, | ||
| # Re-enable axis=0 on TEGrouped MoE experts only -- this triggers the | ||
| # per-expert path inside _QuantTEGroupedLinear. | ||
| { | ||
| "quantizer_name": "*experts.linear_fc*.weight_quantizer", | ||
| "cfg": {"num_bits": 8, "axis": 0, "enable": True}, | ||
| }, | ||
| ], | ||
| } |
There was a problem hiding this comment.
Re-enable the expert quantizers at the rule level.
This config disables everything first with {"quantizer_name": "*", "enable": False}, but the follow-up rule puts enable inside cfg. The rest of this file uses top-level enable for toggling, so this can leave the target expert weight quantizers disabled and make the new regression test pass without ever hitting the axis-0 path.
Proposed fix
{
"quantizer_name": "*experts.linear_fc*.weight_quantizer",
- "cfg": {"num_bits": 8, "axis": 0, "enable": True},
+ "enable": True,
+ "cfg": {"num_bits": 8, "axis": 0},
},🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/gpu_megatron/torch/quantization/plugins/test_megatron.py` around lines
699 - 712, The second rule in TE_GROUPED_PER_EXPERT_CFG mistakenly places
"enable": True inside the nested "cfg" so the expert weight quantizers remain
disabled; update the rule matching "*experts.linear_fc*.weight_quantizer" in
TE_GROUPED_PER_EXPERT_CFG to set "enable": True at the top level of that rule
(alongside "quantizer_name" and "cfg") rather than inside "cfg", ensuring the
per-expert axis=0 quantization path in _QuantTEGroupedLinear is actually
enabled.
| modelopt_src_path=None, | ||
| megatron_src_path=None, | ||
| experiment_title="cicd", |
There was a problem hiding this comment.
Signature insertion breaks existing positional call contract in run_jobs.
Adding megatron_src_path before experiment_title changes positional binding. The run_jobs call still passes ..., modelopt_src_path, experiment_title, so experiment_title is now consumed as megatron_src_path, producing an invalid Megatron mount source and ignoring the intended title.
💡 Proposed fix
def build_docker_executor(
hf_local,
slurm_config,
experiment_id,
job_dir,
task_name,
packager,
modelopt_src_path=None,
- megatron_src_path=None,
experiment_title="cicd",
+ megatron_src_path=None,
): if hf_local is not None:
executor = build_docker_executor(
hf_local,
task.slurm_config,
exp._id,
job_dir,
task_name,
packager,
- modelopt_src_path,
- experiment_title,
+ modelopt_src_path=modelopt_src_path,
+ experiment_title=experiment_title,
)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tools/launcher/core.py` around lines 317 - 319, The new parameter
megatron_src_path was inserted before experiment_title in run_jobs, breaking
existing positional calls that pass modelopt_src_path, experiment_title; restore
the positional contract by moving megatron_src_path after experiment_title in
the run_jobs signature (or alternatively make megatron_src_path keyword-only by
adding a * before it), and keep default None for modelopt_src_path and
megatron_src_path; update any callers if you choose the keyword-only approach.
… quantized weights in TEGroupedLinear forward The forward unpack in te_grouped_quantized_linear_fn used `q_stacked[i] for i in range(num_gemms)` to feed the N quantized expert weights into TE's grouped linear. Each select records a SelectBackward autograd node; backward then materializes an intermediate gradient buffer per select before scattering its slice into q_stacked.grad. At high N the intermediate-buffer allocations dominate backward latency by orders of magnitude. unbind(0) records a single UnbindBackward whose backward is a fused torch.stack -- one allocation, one kernel. Same numerical math (STE through the axis-0 fake-quant); only the autograd graph topology changes. Surfaced and measured under the OMNIML-5064 per-MoE-layer microbench comparison study; landing here for the parent OMNIML-5072 unification. Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
… axis-0 fake-quant on TEGroupedLinear Replaces the stack-then-quantize-then-unbind path in te_grouped_quantized_linear_fn with a custom torch.autograd.Function that loops over the N expert weights inside Function.forward, calling fake_quant_impl per expert against amax_vec[i]. Why: the previous path stacked the N expert weights into one [N, out, in] tensor each forward. That stack is intra-HBM memcopy (no NCCL involved) and scales with the per-GPU stacked weight bytes; it dominates per-step latency at low per-GPU N with large per-expert weights. The autograd.Function wrapper also collapses what would otherwise be N per-call Python dispatches into a single Function entry, so the call site avoids both the stack memcopy and the per-call Python overhead the legacy per-expert-module path (weight_quantizer_i) paid. Backward: clipping-aware STE per expert (dw = grad where |w| <= amax_vec[i], else 0), matching FakeTensorQuantFunction's _fake_tensor_quant_backward semantics. Single fused autograd node. Calibration: handled inline as per-expert max accumulation into the [N, 1, 1] _amax buffer (under no_grad). This bypasses modelopt's calibrator interface, which is built around a single stacked input. The path is coupled to max-calibration; _setup asserts the constraint up-front. Per- expert support for other calibrators (histogram, percentile) is a separate ticket. OMNIML-5072 AC5 design; surfaced under the OMNIML-5064 comparison study. Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…nt via modelopt's _fake_tensor_quant_backward OMNIML-5064 microbench surfaced a backward regression in 26307c7: the per- expert STE backward was written as a plain-Python loop of `torch.where( w.abs() <= amax, g, g.new_zeros(1))`, which dispatches three unfused eager- mode kernels per expert plus one fresh tensor allocation per iteration. Measured backward at Ultra mock-EP=16 was ~5x slower than the previous unbind-based path. Fix: delegate the per-expert STE math to modelopt's existing @torch.jit.script-decorated _fake_tensor_quant_backward. It fuses the abs/compare/where into a single kernel per expert and matches what FakeTensorQuantFunction's own backward uses. Same numerical output; collapses per-expert backward cost back to the expected single-fused- kernel level. Behaviour-equivalent change to 26307c7's backward; forward path is unchanged. Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…single-GPU Reverts: 4c164f9 [OMNIML-5072] quantization: fuse STE backward in _GroupedAxis0FakeQuant ... 26307c7 [OMNIML-5072] quantization: no-stack autograd.Function for per-expert ... Measurement under OMNIML-5064's Ultra mock-EP=16 microbench (production-shape per-GPU N=32 on a single B300): Cell fwd_us bwd_us step_us A (per-expert mods) 5,835 547 8,246 <- A baseline B unbind-only 6,380 1,308 9,549 <- d7ccf0a baseline B no-stack Function 6,060 6,155 14,073 <- AC5 regression AC5's forward improvement was real (~5%) and matched A's fwd within tolerance. But the per-expert backward kernels operate on N scattered HBM regions instead of one contiguous stacked tensor, and on B300's high HBM bandwidth the contiguous-tensor kernel access pattern dominates the (small) stack-allocation cost. Net: AC5's backward is ~4.7x slower than the unbind-only path's fused-on-stacked-tensor backward, and step time degrades by ~50%. The unbind fix (d7ccf0a) stays — that's still a strict improvement over the indexed-select pattern at high N. AC5's `_GroupedAxis0FakeQuant` is removed as a failed experiment; the design rationale for revisiting it (a fused per-expert CUDA/Triton kernel that achieves coalesced access across N scattered weights in a single launch) lives in OMNIML-5064 ac2-results.md. Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
What does this PR do?
Type of change: New feature
Per-expert (axis=0) weight quantization on
_QuantTEGroupedLinear, plus the cluster infrastructure to verify it. Consolidates three Jira tickets onto one branch:_QuantMegatronTEGroupedLinear._process_quantizer_amax's per-tensor restriction. Per-expert_amaxof shape[num_gemms, 1, 1]is calibrated viaiter_weights_for_calibration's stacked weight, broadcast through the existing axis-N machinery inTensorQuantizer.forward, and round-tripped viasharded_state_dictsave/restore. Per-tensor (axis=None) path remains bit-for-bit equivalent.megatron_install_pathfield onSlurmConfig/aws_cmh_slurm_factory. Bind-mounts the packaged Megatron-LM somake_tp_sharded_tensor_for_checkpoint(..., allow_shape_mismatch=True)resolves againstnemo:25.11.nemotron_3_nano(previously deadlocked the dist-checkpoint metadata exchange for 10 minutes).sequence_parallel=Truein the MoE+TPtest_moe_sharded_state_dictconfig so the TEGrouped path doesn't trip the post-OMNIML-5029 assertion.Usage
Testing
tests/gpu_megatron/torch/quantization/plugins/test_megatron.py::test_te_grouped_per_expert_sharded_state_dict(new) — per-expert axis=0 calibration + dist-checkpoint save/restore round-trip ontp=2 ep=2 etp=1 num_moe_experts=4. PASSED 4/4 ranks on aws-cmh (SLURM 511920, experimentcicd_1781150929, containernvcr.io/nvidia/nemo:25.11.nemotron_3_nano).Existing tests continue to pass:
test_moe_sharded_state_dict[*-True]with the OMNIML-5030sequence_parallelfix.axis=None) paths — bit-for-bit equivalent (no behavior change).Cluster yaml:
services/model-optimizer/test/omniml_4998_per_expert_amax_smoke.yamlon nmm-sandboxhungyuehc/omniml-5029 @ 24860c3.Before your PR is "Ready for review"
axis=None) path is bit-for-bit equivalent; per-expert is a new opt-in via existing axis-N machinery, not a behavior change for existing users.test_te_grouped_per_expert_sharded_state_dict./claude reviewbefore marking Ready-for-review.Additional Information
Related Jira:
Relationship to PR #1550 (
jennifchen/te_per_expert, WIP): PR #1550 explores the same goal with a different design — opt-in env-varMODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1to create N separateweight_quantizer_iinstances. This PR uses the single-quantizer + axis=0 design instead, which reuses modelopt's existing axis-N machinery without a new env-var or per-gemm state. The two designs serve overlapping but not identical needs:PR #1550 is currently WIP and CONFLICTING vs main. Open to discussion on whether to merge approaches; this PR is the focused common-case implementation.
Summary by CodeRabbit
New Features
Improvements
Tests