
[NVFP4][MOE] Add unfused quantization fallback when input shape is not aligned #2747

Open
zhongbozhu wants to merge 2 commits into NVIDIA:main from zhongbozhu:add_fallback_to_nvfp4_split_quantize

Conversation

@zhongbozhu
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@zhongbozhu zhongbozhu requested a review from ptrendx March 9, 2026 22:26
@greptile-apps
Contributor

greptile-apps bot commented Mar 9, 2026

Greptile Summary

This PR adds a fallback to the unfused NVFP4 quantization path inside split_quantize when the input tensor's inner dimension is not a multiple of 128, preventing the fused grouped kernel from being invoked on unsupported shapes (common in MoE layers with non-standard hidden sizes). A global-once warning is emitted via std::once_flag / std::call_once, and a test case (M=1024, N=320) is added to validate correctness of the fallback.

Key changes:

  • cast.cpp: After bulk_allocate_nvfp4_tensors, checks input_shape.back() % 128 != 0 and overrides quantization_method to UNFUSED when true. A static std::once_flag ensures the warning fires at most once per process lifetime.
  • cast.cpp: contiguous_data_and_scale is now defensively initialised to false before std::tie, eliminating any theoretical UB if the function were to throw.
  • test_nvfp4_group_quantize.py: Adds (1024, 320) (N not divisible by 128) to the M/N parametrize list, exercising the new fallback across all existing edge-case and quantize-mode combinations.

Confidence Score: 4/5

  • Safe to merge; logic is correct and the new fallback is verified by an added test case covering the non-aligned shape.
  • The control-flow change is simple and correct: bulk allocation already handles non-128-aligned shapes (the returned tensors are valid), and the existing unfused path (multi_tensor_quantize_impl) was already exercised via the !contiguous_data_and_scale branch. The new code just adds a second gate into that same path. Thread safety is handled by std::once_flag. One point is held back because the single-global-once warning fires without exposing the actual offending dimension value in its message, which slightly reduces debuggability for users who encounter this fallback.
  • No files require special attention; both changed files are straightforward.

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[split_quantize called] --> B{All NVFP4\nquantizers?}
    B -- No --> C[AllocationMethod::UNFUSED\nQuantizationMethod::UNFUSED]
    B -- Yes --> D[AllocationMethod::BULK_NVFP4\nQuantizationMethod::FUSED_NVFP4]
    D --> E[bulk_allocate_nvfp4_tensors]
    E --> F{input_shape.back\n% 128 != 0?}
    F -- Yes --> G[Emit once_flag warning\nQuantizationMethod = UNFUSED]
    F -- No --> H{contiguous_data\n_and_scale?}
    G --> H
    H -- No --> I[QuantizationMethod = UNFUSED]
    H -- Yes --> J[QuantizationMethod stays\nFUSED_NVFP4]
    I --> K[multi_tensor_quantize_impl\nunfused path]
    C --> K
    J --> L[split_quantize_nvfp4_impl\nfused kernel]
```

Last reviewed commit: e145d18

Comment on lines +1362 to +1367
```cpp
NVTE_WARN(
    "Unfused NVFP4 quantization fallback is triggered because the input tensor inner "
    "dimension is not a multiple of 128, disabling NVFP4 grouped kernel fusion. "
    "NVFP4 might bring performance regressions for this input tensor shape.");
quantization_method = QuantizationMethod::UNFUSED;
}
```
Contributor


Potentially chatty warning on every call

NVTE_WARN is emitted unconditionally on every invocation of split_quantize when the inner dimension is not 128-aligned. In a MoE model this function can be called hundreds or thousands of times per forward pass, flooding the log. Consider using a static/once flag or a rate-limited warning so the message fires only on the first occurrence.

Suggested change
```diff
-NVTE_WARN(
-    "Unfused NVFP4 quantization fallback is triggered because the input tensor inner "
-    "dimension is not a multiple of 128, disabling NVFP4 grouped kernel fusion. "
-    "NVFP4 might bring performance regressions for this input tensor shape.");
-quantization_method = QuantizationMethod::UNFUSED;
-}
+static std::once_flag _nvfp4_align_warn_flag;
+std::call_once(_nvfp4_align_warn_flag, []() {
+  NVTE_WARN(
+      "Unfused NVFP4 quantization fallback is triggered because the input tensor inner "
+      "dimension is not a multiple of 128, disabling NVFP4 grouped kernel fusion. "
+      "NVFP4 might bring performance regressions for this input tensor shape.");
+});
+quantization_method = QuantizationMethod::UNFUSED;
+}
```

Collaborator Author


Might be a good idea, but if we only warn once, people might just ignore it?

Contributor


Fair point. A better middle ground: warn once per unique shape instead of globally once.

This way:

  • Each problematic input shape gets flagged (so nothing is ignored)
  • Repeated calls with the same shape don't spam logs
  • Users can identify all non-aligned shapes in their workload

You could use a static std::unordered_set<size_t> to track shapes you've already warned about:

```cpp
static std::unordered_set<size_t> warned_shapes;
size_t inner_dim = input_shape.back();
if (inner_dim % 128 != 0 && warned_shapes.insert(inner_dim).second) {
  NVTE_WARN("Unfused NVFP4 quantization fallback triggered for inner dim %zu (not multiple of 128)...", inner_dim);
}
```

This strikes a balance between visibility and log cleanliness.

Collaborator Author


Changed to warn only once; we can revert it if we'd rather keep repeated warnings in the log to remind people of the issue.

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
