[OMNIML-4944] peft: add lora_dtype field to PEFTAttributeConfig #1646
Draft
hychiang-git wants to merge 9 commits into
Adds an optional dtype string ('bf16' | 'fp16' | 'fp32', with long-form
aliases) that lets a LoRA adapter pin its factor dtype independent of
the wrapped layer.
Motivation: for low-bit MoE QAD, the base layer's storage dtype can be
fake-quantized (e.g. uint4/int4 stand-ins) while the LoRA sidecar must
stay BF16. Today _register_adapter_with_device copies the parent's dtype
onto LoRA factors, which is wrong in that regime. lora_dtype gives
plugins an explicit override knob; None (default) preserves today's
inherit-from-parent behavior.
Validator normalizes long forms to canonical short forms and rejects
unsupported strings.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
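The normalization step described above could look roughly like the sketch below. It is not the PR's validator: the function name `normalize_lora_dtype` is hypothetical, the long-form aliases are guesses (the description does not list them), and the case-insensitive matching is an added assumption.

```python
from typing import Optional

# Canonical short forms named in the PR description.
_CANONICAL = ("bf16", "fp16", "fp32")

# ASSUMPTION: the PR text does not list the long-form aliases;
# these torch-style names are illustrative guesses only.
_ALIASES = {
    "bfloat16": "bf16",
    "float16": "fp16",
    "float32": "fp32",
}


def normalize_lora_dtype(value: Optional[str]) -> Optional[str]:
    """Map a dtype string to its canonical short form.

    None is passed through unchanged, meaning "inherit the parent's dtype".
    """
    if value is None:
        return None
    if not isinstance(value, str):
        raise ValueError(f"lora_dtype must be a string or None, got {type(value).__name__}")
    # ASSUMPTION: matching is case-insensitive and ignores surrounding whitespace.
    key = value.strip().lower()
    if key in _CANONICAL:
        return key
    if key in _ALIASES:
        return _ALIASES[key]
    raise ValueError(
        f"Unsupported lora_dtype {value!r}; expected one of "
        f"{sorted(_CANONICAL + tuple(_ALIASES))} or None"
    )
```

In a pydantic-style config this logic would typically sit in a field validator, so that downstream plugins only ever see `None` or one of the three short forms.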
f6687c0 to 3d3dd3e
Adds modelopt/torch/peft/lora/plugins/megatron_moe.py:
- _StackedLoRAFactor: tiny nn.Module carrier holding one stacked Parameter
  of shape [num_experts, ...] for per-expert factors.
- _LoRATEGroupedBase: LoRAModule subclass with update_layer_lora() that
  allocates A:[E,r,in] / B:[E,out,r] honoring attr_config.lora_dtype, and
  a forward() that splits the input by tokens_per_expert and dispatches
  each expert's factors before summing into the base output.
- _LoRATEGroupedColumnParallelLinear / _LoRATEGroupedRowParallelLinear:
  thin registration markers binding the base class to TE's ColumnParallel /
  RowParallel grouped linears (the production path when
  moe_grouped_gemm=True, as in Nemotron-3 Hybrid).
forward() bypasses LoRAModule.forward via super(LoRAModule, self) because
LoRAModule's adapter loop calls lora_a(x) directly -- which would invoke
_StackedLoRAFactor.forward and raise. The stacked layout requires per-expert
dispatch keyed by tokens_per_expert.
Wires the new module into modelopt/torch/peft/lora/plugins/__init__.py under
the existing megatron import_plugin guard.
Plumbing only -- sharded_state_dict, SVDQuant init, and quant+LoRA combined
registration land in follow-up commits. Smoke test follows in the next commit.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
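The per-expert dispatch described in this commit can be sketched in plain PyTorch as follows. This is an illustration of the idea, not the plugin's forward(): `grouped_lora_delta` is a hypothetical helper, tokens are assumed to be already sorted by expert, and the cast to the factor dtype is an assumption about how a pinned lora_dtype would be honored.

```python
import torch


def grouped_lora_delta(x, tokens_per_expert, lora_a, lora_b):
    """Per-expert LoRA delta for tokens that are already sorted by expert.

    x:                 [num_tokens, in_features]
    tokens_per_expert: list of ints, one per expert, summing to num_tokens
    lora_a:            [E, rank, in_features]   (stacked A factors)
    lora_b:            [E, out_features, rank]  (stacked B factors)
    """
    chunks = torch.split(x, tokens_per_expert, dim=0)
    outs = []
    for e, x_e in enumerate(chunks):
        # ASSUMPTION: compute in the factor dtype, then cast back to x's dtype.
        h = x_e.to(lora_a.dtype) @ lora_a[e].T   # [n_e, rank]
        outs.append((h @ lora_b[e].T).to(x.dtype))  # [n_e, out_features]
    return torch.cat(outs, dim=0)


# Usage on toy shapes: 3 experts, the second one receives no tokens.
torch.manual_seed(0)
E, rank, f_in, f_out = 3, 2, 4, 5
x = torch.randn(6, f_in)
lora_a = torch.randn(E, rank, f_in)
lora_b = torch.randn(E, f_out, rank)
tokens_per_expert = [2, 0, 4]

delta = grouped_lora_delta(x, tokens_per_expert, lora_a, lora_b)
zero_delta = grouped_lora_delta(x, tokens_per_expert, lora_a, torch.zeros_like(lora_b))
```

The delta would be added to the base grouped linear's output; with B zero-initialized the delta vanishes, which is the zero-init identity that the smoke test checks.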
Adds tests/gpu_megatron/torch/peft/plugins/test_megatron_moe_lora.py
exercising the plugin landed in 3e8e868 on a tiny 2-layer MoE GPT
(num_moe_experts=4, moe_grouped_gemm=True,
transformer_impl="transformer_engine" -- the TEGroupedMLP production path).
Four assertions:
- Registration + zero-init identity: at least one
  _LoRATEGrouped*ParallelLinear exists after update_model; per-expert
  stacked factor shapes match [E, rank, in] / [E, out, rank]; zeros on B
  means lora_output == base_output; disable/enable adapters round-trips.
- Random init perturbs output: Kaiming on both A and B changes the model
  output.
- lora_dtype pin: lora_dtype="bf16" produces bfloat16 LoRA factor tensors,
  independent of the base layer's dtype.
- Gradient flow: with default freeze_base_model=True, LoRA params receive
  non-zero gradients and base params receive none.
Module-level pytest.mark.skipif(not HAVE_TE_GROUPED) lets the file run as a
no-op when the container's Transformer Engine is older than 1.9.0.dev0 --
expected to flag a container bump if all four tests come back SKIPPED.
Test not yet runtime-verified; pending SLURM submission via nmm-sandbox.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The smoke test's _test_gradient_flow was triggering a hard Megatron
assertion at moe_layer.py:616:
    if self.training and self.attn_tp_group.size() > 1 \
            and not self.config.sequence_parallel:
        raise ValueError(
            "During training, performance may degrade if MoE "
            "and tensor parallelism are enabled without also "
            "enabling sequence parallelism."
        )
The MoE test model uses moe_grouped_gemm=True and the dist_workers fixture
defaults to tp_size=4 here; the _test_utils.torch.megatron.models factory
hardcodes sequence_parallel=False and does not surface it via config_kwargs,
so the only knob we can flip is self.training.
PyTorch autograd is mode-independent: loss.backward() in eval mode still
populates .grad on every parameter that participated in the forward, so the
test's actual claim -- LoRA factors receive non-zero gradients while base
weights remain frozen by default -- holds without entering train mode.
Smoke test run on aws-cmh (job 475427) confirms 4/4 green:
  test_registration_and_zero_init_identity  PASSED
  test_random_init_perturbs_output          PASSED
  test_lora_dtype_pin                       PASSED
  test_gradient_flow                        PASSED  <- this commit
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
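The claim that autograd does not depend on train/eval mode can be checked on a toy model. The sketch below uses plain nn.Linear layers as stand-ins for the frozen base layer and the LoRA factors; none of these names come from the PR.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-ins: a frozen "base" layer plus trainable LoRA-style factors.
base = nn.Linear(4, 4)
lora_a = nn.Linear(4, 2, bias=False)
lora_b = nn.Linear(2, 4, bias=False)
for p in base.parameters():
    p.requires_grad_(False)

# Put everything in eval mode; this only flips the `training` flag.
model = nn.ModuleList([base, lora_a, lora_b]).eval()

x = torch.randn(3, 4)
loss = (base(x) + lora_b(lora_a(x))).sum()
loss.backward()

# Frozen base params never get a .grad; the LoRA factors do.
base_has_grad = base.weight.grad is not None
lora_a_grad_norm = lora_a.weight.grad.norm().item()
lora_b_grad_norm = lora_b.weight.grad.norm().item()
```

Note that both factors here use PyTorch's default random init. With B zero-initialized, the gradient reaching A would be exactly zero on the first backward pass, so a gradient-flow check on A needs a non-zero B.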
Adds a `lora_init_method` enum (Literal["kaiming_zeros", "svdquant"]) to
PEFTAttributeConfig and wires the "svdquant" branch into the TE-grouped
plugin's update_layer_lora. When chosen, per-expert factors are seeded from
a rank-r SVD of the quantization residual W_e - quant(W_e) using the existing
model_calib.svd helper:
    lora_a.weight[e] <- vt ([rank, in_features])
    lora_b.weight[e] <- us ([out_features, rank])
so that B_e @ A_e == us @ vt approximates the residual at init time. Falls
back to zero-init on both factors (with a warning) when no enabled
weight_quantizer is attached -- mtq.quantize() must run before
mtpeft.update_model() for SVDQuant init to be meaningful.
Plugin also registers the LoRA grouped class against the quantized TE-grouped
base classes (_MegatronTEGroupedColumnParallelLinear /
_MegatronTEGroupedRowParallelLinear) so the quantize -> LoRA flow finds them
on the grouped path. Mirrors peft/lora/plugins/megatron.py:244-250 for the
non-grouped variant.
Import order in peft/lora/plugins/__init__.py is flipped so megatron_moe runs
before megatron. The dynamic-module registry's class resolution
(modelopt/torch/opt/dynamic.py:_get_registered_nn_class) iterates
registrations in insertion order and picks the first class whose `forward`
identity-matches the target's forward. Both the TE-grouped and non-grouped
quant classes inherit `forward` from _QuantFunctionalMixin, so they'd
otherwise tie and the earlier-registered one (non-grouped) would win --
causing _LoRAMegatronColumnParallelLinear.update_layer_lora to be invoked on
a TE-grouped target, where `self.input_size` doesn't exist. With megatron_moe
first, the grouped registration wins for grouped targets; non-grouped targets
are unaffected because their issubclass check still excludes TE-grouped
classes.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
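The import-order issue comes down to first-match-wins resolution over an insertion-ordered registry. The toy model below reproduces only the forward-identity tie for a grouped target; every class and function name in it is a hypothetical stand-in, and the real resolver's additional issubclass checks (which protect non-grouped targets) are deliberately not modeled.

```python
class QuantFunctionalMixin:
    """Stand-in for a mixin that supplies a shared forward()."""

    def forward(self, x):
        return x


class QuantLinear(QuantFunctionalMixin):
    """Stand-in for the non-grouped quantized linear."""


class QuantGroupedLinear(QuantFunctionalMixin):
    """Stand-in for the TE-grouped quantized linear."""


def resolve(registry, target_cls):
    """Return the wrapper of the first entry whose forward is the target's forward."""
    for base_cls, wrapper_name in registry:
        if base_cls.forward is target_cls.forward:
            return wrapper_name
    return None


# Both quant classes inherit the very same function object, so they tie.
forwards_tie = QuantLinear.forward is QuantGroupedLinear.forward

# Old import order: the non-grouped plugin registers first and wins the tie.
old_order = [(QuantLinear, "LoRALinear"), (QuantGroupedLinear, "LoRAGroupedLinear")]
# Flipped import order: the grouped plugin registers first.
new_order = [(QuantGroupedLinear, "LoRAGroupedLinear"), (QuantLinear, "LoRALinear")]

picked_before_flip = resolve(old_order, QuantGroupedLinear)
picked_after_flip = resolve(new_order, QuantGroupedLinear)
```

In this toy, a grouped target resolves to the non-grouped wrapper before the flip and to the grouped wrapper after it, which matches the failure mode and fix described in the commit.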
Adds two assertions to
tests/gpu_megatron/torch/peft/plugins/test_megatron_moe_lora.py, plus the
per-tensor INT8 quant config they use:
- test_svdquant_init_recovers_residual: runs mtq.quantize then
  mtpeft.update_model(svdquant), and verifies that for each TE-grouped
  expert e, B_e @ A_e is a non-trivial rank-r approximation of the
  quantization residual W_e - quant(W_e) (reconstruction error bounded by
  residual norm; factors not all zero).
- test_svdquant_init_no_quantizer_falls_back: skips mtq.quantize and
  asserts svdquant correctly degenerates to zero-init on both factors.
SVDQUANT_LORA_CFG restricts LoRA application to *experts*linear_fc*
patterns. Matches the ticket's "adapters placed one per up_proj / down_proj
across all MoE expert layers" scope, and sidesteps a pre-existing modelopt
bug where LoRA on TE plain linears (attention, in the quantize -> LoRA flow)
crashes on a missing `self.input_size` attribute. Worth a follow-up ticket
for the non-grouped TE+quant+LoRA case, but out of scope here.
INT8_PER_TENSOR_QUANT_CFG mirrors NVFP4_DEFAULT_CONFIG's list-of-dicts shape
from tests/.../test_megatron_peft.py and stays per-tensor to satisfy the
TEGroupedLinear quantization restriction
(quantization/plugins/megatron.py:696).
All 6 smoke assertions pass on aws-cmh (SLURM job 476191, 4:02).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
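The residual check that test_svdquant_init_recovers_residual describes can be illustrated without modelopt. In the sketch below, `fake_quant_int8_per_tensor` is a hypothetical symmetric per-tensor quantizer standing in for the attached weight_quantizer, and `svd_init` is a plain truncated SVD via torch.linalg.svd; how the real model_calib.svd helper splits singular values between `us` and `vt` is an assumption here.

```python
import torch


def fake_quant_int8_per_tensor(w):
    """Hypothetical symmetric per-tensor INT8 quantize-dequantize."""
    scale = w.abs().max() / 127.0
    return torch.round(w / scale).clamp(-127, 127) * scale


def svd_init(residual, rank):
    """Rank-r factors of a 2-D residual: returns (vt, us)."""
    u, s, vh = torch.linalg.svd(residual, full_matrices=False)
    # ASSUMPTION: singular values are folded into the left factor.
    us = u[:, :rank] * s[:rank]   # [out_features, rank]
    vt = vh[:rank, :]             # [rank, in_features]
    return vt, us


torch.manual_seed(0)
E, rank, f_in, f_out = 4, 2, 8, 16
weights = torch.randn(E, f_out, f_in)

residuals = torch.stack([w - fake_quant_int8_per_tensor(w) for w in weights])
factors = [svd_init(r, rank) for r in residuals]
lora_a = torch.stack([vt for vt, _ in factors])   # [E, rank, in_features]
lora_b = torch.stack([us for _, us in factors])   # [E, out_features, rank]

# Per-expert reconstruction error versus the norm of the residual itself.
errors = [(residuals[e] - lora_b[e] @ lora_a[e]).norm().item() for e in range(E)]
norms = [residuals[e].norm().item() for e in range(E)]
```

A truncated SVD is the best rank-r approximation in Frobenius norm, so each error stays below the corresponding residual norm whenever the residual is non-zero; that is the "bounded by residual norm" property the test asserts.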
Adds tests/gpu_megatron/torch/peft/plugins/test_megatron_moe_lora.py::
test_quantize_then_lora_svdquant_state_dict_roundtrip. After mtq.quantize ->
mtpeft.update_model(svdquant), the test snapshots the per-expert LoRA
factors, saves the state_dict to disk with torch.save, zeroes the live
factors in place, then calls torch.load + load_state_dict(strict=True) and
asserts bitwise equality with the snapshot.
This is the Phase 2 sanity scaffold for OMNIML-4944: it validates the
"LoRA-on-quant-checkpoint" plumbing claim end-to-end at small scale, without
needing Megatron's sharded_state_dict (deferred to PR #3+). Each dist_workers
rank saves to its own tempdir to avoid collisions.
All 7 smoke assertions pass on aws-cmh (SLURM job 479919, 3:07).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
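The snapshot, save, zero, load, compare sequence can be sketched on a toy module. `StackedFactor` below is a hypothetical stand-in for the stacked per-expert carrier; the quantize and update_model steps are omitted, and `weights_only=True` is used on load in line with the repository's security guidance.

```python
import os
import tempfile

import torch
from torch import nn


class StackedFactor(nn.Module):
    """Hypothetical carrier holding one stacked [E, ...] parameter."""

    def __init__(self, shape):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(*shape))


torch.manual_seed(0)
model = nn.ModuleDict(
    {"lora_a": StackedFactor((4, 2, 8)), "lora_b": StackedFactor((4, 16, 2))}
)

# 1) Snapshot the live factors (clone, so later zeroing cannot alias them).
snapshot = {k: v.detach().clone() for k, v in model.state_dict().items()}

with tempfile.TemporaryDirectory() as tmp:  # one private dir per process/rank
    path = os.path.join(tmp, "lora_state.pt")
    # 2) Save, 3) zero the live parameters in place.
    torch.save(model.state_dict(), path)
    with torch.no_grad():
        for p in model.parameters():
            p.zero_()
    zeroed = all(torch.count_nonzero(p).item() == 0 for p in model.parameters())
    # 4) Load back with strict key matching.
    state = torch.load(path, weights_only=True)
    model.load_state_dict(state, strict=True)

# 5) Bitwise comparison against the snapshot.
roundtrip_ok = all(torch.equal(model.state_dict()[k], snapshot[k]) for k in snapshot)
```

Zeroing the parameters before loading matters: without it, a load that silently skipped the LoRA keys would still pass the equality check.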
Megatron dist-checkpoint integration for the per-expert stacked LoRA
factors on _LoRATEGroupedBase. Three additions:
- _factor_tp_axis(factor_name) on each subclass returns the TP-sharded
  axis of the 2-D per-expert slice (or None for replicated):
    column-parallel: lora_a None, lora_b axis 0 (out_features dim)
    row-parallel:    lora_a axis 1 (in_features dim), lora_b None
- sharded_state_dict() splits each stacked [E_local, ...] factor into
  per-global-expert entries keyed weight{global_idx}, with the EP dim
  encoded as a prepended sharded axis and TP via make_*_sharded_*.
  Mirrors TEGroupedLinear._sharded_state_dict_grouped
  (transformer_engine.py:1981).
- _load_from_state_dict() stacks the per-expert weight{i} entries back
  into a single stacked weight that matches the carrier's state_dict
  shape. Short-circuits when the stacked key is already present (plain
  torch.save/load path) so it doesn't pay the EP lookup cost there.
_ep_rank_size_tp_group() helper resolves (ep_rank, ep_size, tp_group)
preferring self._pg_collection / self._tp_group when available, falling
back to parallel_state.get_expert_model_parallel_* when not (the small
test factory builds linears via a path that doesn't reliably set
_pg_collection on every per-rank instance).
Unblocks real-Nemotron-3 checkpoint testing on dist-checkpoint format
(jennifchen's W4A16 dirs) and prepares the training-loop integration
that follows.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
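The key layout described in this commit can be illustrated with plain Python lists standing in for tensors. All three helpers below are hypothetical, Megatron's ShardedTensor wrapping is omitted entirely, and the global index formula `ep_rank * E_local + local_idx` is an assumption (contiguous expert partitioning) that the commit message does not state.

```python
def factor_tp_axis(parallel_mode, factor_name):
    """TP-sharded axis of the 2-D per-expert slice, or None if replicated."""
    table = {
        ("column", "lora_a"): None,  # replicated
        ("column", "lora_b"): 0,     # out_features dim
        ("row", "lora_a"): 1,        # in_features dim
        ("row", "lora_b"): None,     # replicated
    }
    return table[(parallel_mode, factor_name)]


def split_stacked(prefix, stacked, ep_rank):
    """Split a stacked [E_local, ...] factor into weight{global_idx} entries."""
    n_local = len(stacked)
    # ASSUMPTION: experts are partitioned contiguously across EP ranks.
    return {
        f"{prefix}weight{ep_rank * n_local + i}": stacked[i] for i in range(n_local)
    }


def restack(prefix, state, ep_rank, n_local):
    """Rebuild the stacked factor from per-expert entries."""
    stacked_key = f"{prefix}weight"
    if stacked_key in state:
        # Plain torch.save/load path: the stacked key is already there.
        return state[stacked_key]
    return [state[f"{prefix}weight{ep_rank * n_local + i}"] for i in range(n_local)]


# Usage: EP rank 1 owning 2 local experts (global experts 2 and 3).
stacked = [[[1.0, 2.0]], [[3.0, 4.0]]]
entries = split_stacked("lora_a.", stacked, ep_rank=1)
rebuilt = restack("lora_a.", entries, ep_rank=1, n_local=2)
passthrough = restack("lora_a.", {"lora_a.weight": stacked}, ep_rank=1, n_local=2)
```

Keeping one entry per global expert is what lets a checkpoint be reloaded under a different expert-parallel size, since each rank can pick out exactly the global indices it owns.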
… factors
Adds test_quantize_then_lora_svdquant_dist_checkpoint_roundtrip mirroring
test_megatron_peft.py::test_mcore_quantize_then_lora_save_restore but for the
TE-grouped MoE LoRA path. Builds a quantized + LoRA model_ref, saves via
save_distributed_checkpoint + save_sharded_modelopt_state, restores into a
fresh model_test via restore_sharded_modelopt_state +
load_distributed_checkpoint, and asserts the post-LoRA forward output matches
model_ref's within tolerance.
Exercises the new _LoRATEGroupedBase.sharded_state_dict +
_load_from_state_dict hooks end-to-end on Megatron's actual dist-checkpoint
format -- the same format jennifchen's Nemotron-3-Nano-30B-A3B W4A16
checkpoints use.
All 8 smoke assertions pass on aws-cmh.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>