[NVFP4][MOE] Add unfused quantization fallback when input shape is not aligned #2747
[NVFP4][MOE] Add unfused quantization fallback when input shape is not aligned #2747zhongbozhu wants to merge 2 commits intoNVIDIA:mainfrom
Conversation
Greptile SummaryThis PR adds a fallback to the unfused NVFP4 quantization path inside Key changes:
Confidence Score: 4/5
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[split_quantize called] --> B{All NVFP4\nquantizers?}
B -- No --> C[AllocationMethod::UNFUSED\nQuantizationMethod::UNFUSED]
B -- Yes --> D[AllocationMethod::BULK_NVFP4\nQuantizationMethod::FUSED_NVFP4]
D --> E[bulk_allocate_nvfp4_tensors]
E --> F{input_shape.back\n% 128 != 0?}
F -- Yes --> G[Emit once_flag warning\nQuantizationMethod = UNFUSED]
F -- No --> H{contiguous_data\n_and_scale?}
G --> H
H -- No --> I[QuantizationMethod = UNFUSED]
H -- Yes --> J[QuantizationMethod stays\nFUSED_NVFP4]
I --> K[multi_tensor_quantize_impl\nunfused path]
C --> K
J --> L[split_quantize_nvfp4_impl\nfused kernel]
Last reviewed commit: e145d18 |
| NVTE_WARN( | ||
| "Unfused NVFP4 quantization fallback is triggered because the input tensor inner " | ||
| "dimension is not a multiple of 128, disabling NVFP4 grouped kernel fusion. " | ||
| "NVFP4 might bring performance regressions for this input tensor shape."); | ||
| quantization_method = QuantizationMethod::UNFUSED; | ||
| } |
There was a problem hiding this comment.
Potentially chatty warning on every call
NVTE_WARN is emitted unconditionally on every invocation of split_quantize when the inner dimension is not 128-aligned. In a MoE model this function can be called hundreds or thousands of times per forward pass, flooding the log. Consider using a static/once flag or a rate-limited warning so the message fires only on the first occurrence.
| NVTE_WARN( | |
| "Unfused NVFP4 quantization fallback is triggered because the input tensor inner " | |
| "dimension is not a multiple of 128, disabling NVFP4 grouped kernel fusion. " | |
| "NVFP4 might bring performance regressions for this input tensor shape."); | |
| quantization_method = QuantizationMethod::UNFUSED; | |
| } | |
| static std::once_flag _nvfp4_align_warn_flag; | |
| std::call_once(_nvfp4_align_warn_flag, []() { | |
| NVTE_WARN( | |
| "Unfused NVFP4 quantization fallback is triggered because the input tensor inner " | |
| "dimension is not a multiple of 128, disabling NVFP4 grouped kernel fusion. " | |
| "NVFP4 might bring performance regressions for this input tensor shape."); | |
| }); | |
| quantization_method = QuantizationMethod::UNFUSED; |
There was a problem hiding this comment.
Might be good idea, but if we just warn it once, people might just ignore it?
There was a problem hiding this comment.
Fair point. A better middle ground: warn once per unique shape instead of globally once.
This way:
- Each problematic input shape gets flagged (so nothing is ignored)
- Repeated calls with the same shape don't spam logs
- Users can identify all non-aligned shapes in their workload
You could use a static std::unordered_set<size_t> to track shapes you've already warned about:
static std::unordered_set<size_t> warned_shapes;
size_t inner_dim = input_shape.back();
if (inner_dim % 128 != 0 && warned_shapes.insert(inner_dim).second) {
NVTE_WARN("Unfused NVFP4 quantization fallback triggered for inner dim %zu (not multiple of 128)...", inner_dim);
}This strikes a balance between visibility and log cleanliness.
There was a problem hiding this comment.
changed to warn it once, can revert it if we do want to leave a lot of warnings in the log to remind people of the issue
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: