2424 BaseWeightMapper
2525from tensorrt_llm ._torch .modules .mamba .mamba2_metadata import Mamba2Metadata
2626from tensorrt_llm ._torch .utils import ActivationType , relu2
27+ from tensorrt_llm .logger import logger
2728
2829from ..attention_backend import AttentionMetadata
2930from ..distributed import AllReduce
@@ -323,6 +324,17 @@ def __init__(
323324 self .is_nvfp4 = (model_config .quant_config is not None
324325 and model_config .quant_config .quant_mode is not None
325326 and model_config .quant_config .quant_mode .has_nvfp4 ())
327+ # The fused RMSNorm+NVFP4 CUDA kernel is only enabled when hidden_size
328+ # is one of the supported tile sizes; other sizes (e.g. non-power-of-2
329+ # values) may cause kernel hangs, so fall back to the non-fused path.
330+ # Supported tile sizes: 2048, 4096, 8192, 16384
331+ _SUPPORTED_NVFP4_HIDDEN_SIZES = {2048 , 4096 , 8192 , 16384 }
332+ if self .is_nvfp4 and config .hidden_size not in _SUPPORTED_NVFP4_HIDDEN_SIZES :
333+ logger .warning_once (
334+ f"Layer { layer_idx } : Disabling fused NVFP4 RMSNorm for hidden_size={ config .hidden_size } . "
335+ f"Supported sizes: { _SUPPORTED_NVFP4_HIDDEN_SIZES } . Using non-fused path." ,
336+ key = f"disable_nvfp4_rmsnorm_with_{ config .hidden_size } " )
337+ self .is_nvfp4 = False
326338
327339 self .norm = RMSNorm (
328340 hidden_size = config .hidden_size ,
@@ -331,9 +343,9 @@ def __init__(
331343 # Enable fused NVFP4 quantization if possible.
332344 # It might be overridden in `_try_attach_nvfp4_scale` function.
333345 quantize_type = "nvfp4" if self .is_nvfp4 else None ,
334- # Enable high precision output for MoE layer.
346+ # Enable high precision output for MoE layer (only with NVFP4).
335347 # It might be overridden in `_try_attach_nvfp4_scale` function.
336- return_hp_output = layer_type == "E" ,
348+ return_hp_output = layer_type == "E" and self . is_nvfp4 ,
337349 )
338350
339351 if layer_type == "M" :
@@ -424,7 +436,7 @@ def forward(
424436 if spec_metadata is not None and spec_metadata .is_layer_capture (
425437 self .layer_idx ):
426438 spec_metadata .maybe_capture_hidden_states (self .layer_idx ,
427- hidden_states , None )
439+ hidden_states , residual )
428440
429441 return hidden_states , residual
430442
0 commit comments