Commit 082ba85 (1 parent: 094ace5)

Fix bugs for bf16 model, spec_dec data flow, hanging for nano_v3_nvfp4

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>

1 file changed: tensorrt_llm/_torch/models/modeling_nemotron_h.py
15 additions & 3 deletions
@@ -24,6 +24,7 @@
     BaseWeightMapper
 from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
 from tensorrt_llm._torch.utils import ActivationType, relu2
+from tensorrt_llm.logger import logger

 from ..attention_backend import AttentionMetadata
 from ..distributed import AllReduce
@@ -323,6 +324,17 @@ def __init__(
         self.is_nvfp4 = (model_config.quant_config is not None
                          and model_config.quant_config.quant_mode is not None
                          and model_config.quant_config.quant_mode.has_nvfp4())
+        # The fused RMSNorm+NVFP4 CUDA kernel requires hidden_size to be
+        # a supported tile size. Non-power-of-2 hidden sizes within tile
+        # ranges may cause kernel hangs. Disable fused NVFP4 for such cases.
+        # Supported tile sizes: 2048, 4096, 8192, 16384
+        _SUPPORTED_NVFP4_HIDDEN_SIZES = {2048, 4096, 8192, 16384}
+        if self.is_nvfp4 and config.hidden_size not in _SUPPORTED_NVFP4_HIDDEN_SIZES:
+            logger.warning_once(
+                f"Layer {layer_idx}: Disabling fused NVFP4 RMSNorm for hidden_size={config.hidden_size}. "
+                f"Supported sizes: {_SUPPORTED_NVFP4_HIDDEN_SIZES}. Using non-fused path.",
+                key=f"disable_nvfp4_rmsnorm_with_{config.hidden_size}")
+            self.is_nvfp4 = False

         self.norm = RMSNorm(
             hidden_size=config.hidden_size,
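For context on the hang fix: the guard above can be read as a standalone predicate. A minimal sketch follows; the tile-size set mirrors the hunk, while the helper name `use_fused_nvfp4_rmsnorm` is hypothetical and not part of the codebase.

# Hypothetical standalone version of the guard introduced in this hunk.
_SUPPORTED_NVFP4_HIDDEN_SIZES = {2048, 4096, 8192, 16384}

def use_fused_nvfp4_rmsnorm(is_nvfp4: bool, hidden_size: int) -> bool:
    # The fused RMSNorm+NVFP4 kernel only supports these tile sizes;
    # other hidden sizes may hang the kernel, so fall back to the
    # non-fused path instead.
    return is_nvfp4 and hidden_size in _SUPPORTED_NVFP4_HIDDEN_SIZES

assert use_fused_nvfp4_rmsnorm(True, 4096)        # fused path is safe
assert not use_fused_nvfp4_rmsnorm(True, 5120)    # unsupported size: fall back
assert not use_fused_nvfp4_rmsnorm(False, 4096)   # not NVFP4-quantized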
@@ -331,9 +343,9 @@ def __init__(
             # Enable fused NVFP4 quantization if possible.
             # It might be overridden in `_try_attach_nvfp4_scale` function.
             quantize_type="nvfp4" if self.is_nvfp4 else None,
-            # Enable high precision output for MoE layer.
+            # Enable high precision output for MoE layer (only with NVFP4).
             # It might be overridden in `_try_attach_nvfp4_scale` function.
-            return_hp_output=layer_type == "E",
+            return_hp_output=layer_type == "E" and self.is_nvfp4,
         )

         if layer_type == "M":
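This hunk is the bf16 fix: `return_hp_output` used to be True for every MoE layer (`layer_type == "E"`), including unquantized bf16 models where the high-precision side output does not apply. A sketch of the old versus new condition, in plain Python with illustrative function names:

def return_hp_output_old(layer_type: str, is_nvfp4: bool) -> bool:
    return layer_type == "E"                 # bug: also True for bf16 models

def return_hp_output_new(layer_type: str, is_nvfp4: bool) -> bool:
    return layer_type == "E" and is_nvfp4    # hp output only under NVFP4

# A bf16 (non-NVFP4) MoE layer no longer requests the hp-output path.
assert return_hp_output_old("E", False) is True
assert return_hp_output_new("E", False) is False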
@@ -424,7 +436,7 @@ def forward(
         if spec_metadata is not None and spec_metadata.is_layer_capture(
                 self.layer_idx):
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
-                                                      hidden_states, None)
+                                                      hidden_states, residual)

         return hidden_states, residual
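The spec_dec data-flow fix passes the live `residual` instead of `None` when capturing hidden states for speculative decoding. A minimal sketch of why that matters, assuming the usual pre-norm convention where the full residual-stream value is `hidden_states + residual`; all names below are illustrative, not the TensorRT-LLM API.

import torch

def reconstruct_stream(hidden_states: torch.Tensor,
                       residual: torch.Tensor | None) -> torch.Tensor:
    # With residual=None the consumer of the captured states only sees
    # the post-norm branch; the real residual is needed to recover the
    # full residual-stream value.
    return hidden_states if residual is None else hidden_states + residual

h, r = torch.randn(2, 8), torch.randn(2, 8)
assert not torch.equal(reconstruct_stream(h, None), reconstruct_stream(h, r))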
