diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py
index c94b7d716..4381d54f7 100644
--- a/modelopt/torch/quantization/calib/mse.py
+++ b/modelopt/torch/quantization/calib/mse.py
@@ -39,6 +39,7 @@ def __init__(
         stop_multiplier: float = 4.0,
         quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+        fp8_scale_sweep: bool = False,
     ):
         """Initialize MSE calibrator.

@@ -46,13 +47,16 @@ def __init__(
             amax: Initial amax value (required).
             axis: Quantization axis. None means per-tensor quantization.
             step_size: Step size for amax search. The number of steps is computed as
-            ceil((stop_multiplier - start_multiplier) / step_size) + 1.
+                ceil((stop_multiplier - start_multiplier) / step_size) + 1.
             start_multiplier: Starting multiplier for amax search.
             stop_multiplier: Ending multiplier for amax search.
             quant_func: Function that quantizes input tensor given an amax value.
-            Should have signature: quant_func(x, amax) -> quantized_x.
+                Should have signature: quant_func(x, amax) -> quantized_x.
             error_func: Function to compute error between x and xq.
-            Default is F.mse_loss(x, xq, reduction='none').
+                Default is F.mse_loss(x, xq, reduction='none').
+            fp8_scale_sweep: If True, sweep over the 126 valid FP8 E4M3 scale values
+                instead of using multipliers. This is specifically for NVFP4
+                per-block quantization where scales are stored in FP8 format.
         """
         super().__init__(num_bits=None, axis=axis, unsigned=None)
         self._initial_amax = amax
@@ -65,6 +69,13 @@ def __init__(
         self._error_func = error_func
         self._losses_sum = [None] * self._num_steps
         self._candidate_amaxs = [None] * self._num_steps
+        self._fp8_scale_sweep = fp8_scale_sweep
+        if fp8_scale_sweep:
+            # For FP8 scale sweep, we always have exactly 126 valid FP8 E4M3 values
+            # (128 total - 2 invalid: byte 0 = zero, byte 127 = NaN)
+            self._num_steps = 126
+            self._losses_sum = [None] * self._num_steps
+            self._candidate_amaxs = [None] * self._num_steps

         self._amax = None

@@ -83,14 +94,33 @@ def collect(self, x: torch.Tensor):

         x = x.detach().to(dtype=torch.float32)
         device = x.device
-        multipliers = torch.linspace(
-            self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
-        )
+
+        if self._fp8_scale_sweep:
+            global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)
+            global_amax_expanded = global_amax * torch.ones_like(self._initial_amax)
+
+            # Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn)
+            # Create uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32
+            uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
+            fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
+
+            # Filter out invalid values (NaN, inf, and zero) which aren't useful as multipliers
+            valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
+            fp8_values_valid = fp8_values[valid_mask]
+
+            candidates = fp8_values_valid / 448.0
+        else:
+            candidates = torch.linspace(
+                self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
+            )

         # Get reduce axis for per-channel quantization
         reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)
-        for step, multiplier in enumerate(multipliers):
-            candidate_amax = self._initial_amax * multiplier
+        for step, candidate in enumerate(candidates):
+            if self._fp8_scale_sweep:
+                candidate_amax = global_amax_expanded * candidate
+            else:
+                candidate_amax = self._initial_amax * candidate

             xq = self._quant_func(x, candidate_amax)

             if self._error_func is not None:
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 2772e8138..9836648d0 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -387,30 +387,6 @@
     "algorithm": "max",
 }

-NVFP4_WEIGHT_ACT_MSE_CFG = {
-    "quant_cfg": {
-        "*weight_quantizer": {
-            "num_bits": (2, 1),
-            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
-            "axis": None,
-            "enable": True,
-        },
-        "*input_quantizer": {
-            "num_bits": (2, 1),
-            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
-            "axis": None,
-            "enable": True,
-        },
-        **_default_disabled_quantizer_cfg,
-    },
-    "algorithm": {
-        "method": "mse",
-        "step_size": 0.25,
-        "start_multiplier": 0.25,
-        "stop_multiplier": 2.0,
-    },
-}
-
 NVFP4_AWQ_LITE_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {
@@ -1040,6 +1016,8 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
     reconstruction error of a tensor after uniform Q→DQ:

         s* = argmin_s E[(X - DQ(Q(X; s)))^2], X ∈ {weights | activations}
+
+    When fp8_scale_sweep is enabled, step_size and the multiplier range are ignored.
     """

     method: Literal["mse"] = ModeloptField("mse")
@@ -1066,6 +1044,14 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
         description="Ending multiplier for amax search range (multiplies initial amax).",
     )

+    fp8_scale_sweep: bool | None = ModeloptField(
+        default=False,
+        title="Enable FP8 scale sweep for NVFP4 per-block quantization.",
+        description="If True, sweep the 126 valid FP8 E4M3 scales instead of using multipliers. "
+        "Only applies to NVFP4 weight quantization. When enabled, num_steps, step_size, "
+        "start_multiplier, and stop_multiplier are ignored.",
+    )
+
     distributed_sync: bool | None = ModeloptField(
         default=True,
         title="Whether to sync the amax across the distributed processes.",
diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index b8461a080..835b70fbf 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -197,6 +197,30 @@ def sync_quantizer_amax_across_tp(
         )


+def _mse_quant_func(x, amax, quantizer):
+    """Quantization function for MSE calibration."""
+    original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None
+    quantizer._amax = amax
+
+    with (
+        enable_quant(quantizer),
+        disable_calib(quantizer),
+        enable_fake_quant(quantizer),
+    ):
+        if hasattr(quantizer, "_original_shape"):
+            x = quantizer._reset_to_original_shape(x)
+        xq = quantizer(x)
+        if hasattr(quantizer, "_block_reshape_size"):
+            xq = xq.reshape(quantizer._block_reshape_size)
+
+    if original_amax is not None:
+        quantizer._amax = original_amax
+    else:
+        delattr(quantizer, "_amax")
+
+    return xq
+
+
 @torch.no_grad()
 def mse_calibrate(
     model: nn.Module,
@@ -205,6 +229,7 @@ def mse_calibrate(
     step_size: float = 0.1,
     start_multiplier: float = 0.25,
     stop_multiplier: float = 4.0,
+    fp8_scale_sweep: bool = False,
 ):
     """Calibrate the model using MSE-based amax search.

@@ -220,6 +245,10 @@ def mse_calibrate(
         step_size: Step size for amax search (default: 0.1).
         start_multiplier: Starting multiplier for amax search (default: 0.25).
         stop_multiplier: Ending multiplier for amax search (default: 4.0).
+        fp8_scale_sweep: If True, sweep over the 126 valid FP8 E4M3 scale values
+            for NVFP4 per-block quantization instead of using multipliers.
+            This is specifically designed for optimizing the FP8-quantized
+            per-block scales in NVFP4 format (default: False).

    See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
    details on the remaining arguments.
@@ -238,27 +267,13 @@ def mse_calibrate(
             # Get the initial amax from max calibration
             initial_amax = module._amax.clone().detach()

-            def quant_func(x, amax, quantizer=module):
-                original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None
-                quantizer._amax = amax
-
-                with (
-                    enable_quant(quantizer),
-                    disable_calib(quantizer),
-                    enable_fake_quant(quantizer),
-                ):
-                    if hasattr(quantizer, "_original_shape"):
-                        x = quantizer._reset_to_original_shape(x)
-                    xq = quantizer(x)
-                    if hasattr(quantizer, "_block_reshape_size"):
-                        xq = xq.reshape(quantizer._block_reshape_size)
-
-                if original_amax is not None:
-                    quantizer._amax = original_amax
-                else:
-                    delattr(quantizer, "_amax")
-
-                return xq
+            is_nvfp4_static = (
+                fp8_scale_sweep
+                and module.is_static_block_quant
+                and module._num_bits == (2, 1)
+                and module._block_sizes is not None
+                and module._block_sizes.get("scale_bits") == (4, 3)
+            )

             # Create MSE calibrator with quant_func
             module._calibrator = MseCalibrator(
@@ -267,7 +282,8 @@ def quant_func(x, amax, quantizer=module):
                 step_size=step_size,
                 start_multiplier=start_multiplier,
                 stop_multiplier=stop_multiplier,
-                quant_func=quant_func,
+                quant_func=partial(_mse_quant_func, quantizer=module),
+                fp8_scale_sweep=is_nvfp4_static,
             )

     # Identify weight quantizers by checking if they have corresponding weight parameters
diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index 7d3fa1251..6d580d4a7 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -753,11 +753,12 @@ def _fake_quantize(self, inputs):
         elif self._num_bits == (2, 1) and self.is_static_block_quant:
             outputs = static_blockwise_fp4_fake_quant(
                 inputs,
-                amax / 6.0,
+                None,  # scale
                 None,  # scale_fp8_quant_amax
                 False,  # skip_scale_quant
                 inputs.dtype,  # out_dtype
                 self._pass_through_bwd,  # pass_through_bwd
+                amax,  # amax
             )
         elif isinstance(self._num_bits, tuple):
             # Float-point quantization, e.g., FP8
diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py
index b3e6edc3a..0a95d9916 100644
--- a/modelopt/torch/quantization/tensor_quant.py
+++ b/modelopt/torch/quantization/tensor_quant.py
@@ -574,21 +574,23 @@ def forward(
         skip_scale_quant,
         out_dtype,
         pass_through_bwd=False,
+        amax=None,
     ):
         """Forward method."""
-        _save_for_backward_if_needed(ctx, pass_through_bwd, x, scale)
+        _save_for_backward_if_needed(ctx, pass_through_bwd, x, scale if scale is not None else amax)
         return triton_kernel.static_blockwise_fp4_fake_quant(
             x,
             scale,
             scale_fp8_quant_amax,
             skip_scale_quant,
             out_dtype,
+            amax,
         )

     @staticmethod
     def backward(ctx, grad_outputs):
         """Implements straight through estimation with clipping."""
-        return _fake_quant_backward_function(ctx, grad_outputs, num_args=6)
+        return _fake_quant_backward_function(ctx, grad_outputs, num_args=7)


 def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py
index 6d97d984c..4735fb5ee 100644
--- a/modelopt/torch/quantization/triton/fp4_kernel.py
+++ b/modelopt/torch/quantization/triton/fp4_kernel.py
@@ -406,20 +406,32 @@ def static_blockwise_fp4_fake_quant_kernel(
 def static_blockwise_fp4_fake_quant(
     x: torch.Tensor,
-    scale: torch.Tensor,
+    scale: torch.Tensor | None = None,
     scale_fp8_quant_amax: torch.Tensor | None = None,
     skip_scale_quant: bool = False,
     out_dtype: torch.dtype | None = None,
+    amax: torch.Tensor | None = None,
 ):
     """Static blockwise FP4 fake quantization using Triton kernel.

     Args:
         x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA.
-        scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA.
+        scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA. Mutually exclusive with amax.
         scale_fp8_quant_amax: Absolute max range for FP8 quantization of scale.
             If None, computed from scale.
         skip_scale_quant: If True, skip FP8 quantization of scale.
         out_dtype: Output dtype. Defaults to x.dtype if None.
+        amax: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA. If provided, scale = amax / 6.0.
+            Mutually exclusive with scale.
     """
+    if scale is None and amax is None:
+        raise ValueError("Either scale or amax must be provided")
+    if scale is not None and amax is not None:
+        raise ValueError("Cannot provide both scale and amax")
+
+    if amax is not None:
+        scale = amax / 6.0  # FP4 max representable value is 6.0
+
+    assert scale is not None  # Guaranteed by validation above
     assert x.ndim == 2
     NUM_FP4_BLOCKS, BLOCK_SIZE = x.shape

diff --git a/tests/gpu/torch/quantization/test_quantize_cuda.py b/tests/gpu/torch/quantization/test_quantize_cuda.py
index 9d82c1082..3d1de84d3 100644
--- a/tests/gpu/torch/quantization/test_quantize_cuda.py
+++ b/tests/gpu/torch/quantization/test_quantize_cuda.py
@@ -43,7 +43,30 @@
             "enable": True,
         },
     },
-    "algorithm": "mse",
+    "algorithm": {
+        "method": "mse",
+        "step_size": 0.25,
+        "start_multiplier": 0.25,
+        "stop_multiplier": 2.0,
+    },
+}
+
+NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "enable": False,
+        },
+    },
+    "algorithm": {
+        "method": "mse",
+        "fp8_scale_sweep": True,
+    },
 }


@@ -71,6 +94,7 @@
         mtq.NVFP4_KV_ROTATE_CFG,
         mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
         NVFP4_WEIGHT_ACT_MSE_CFG,
+        NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG,
     ],
 )
 def test_quantize(model_cls, config):
@@ -88,6 +112,7 @@ def test_quantize(model_cls, config):
         mtq.NVFP4_KV_ROTATE_CFG,
         mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
         NVFP4_WEIGHT_ACT_MSE_CFG,
+        NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG,
     ]:
         if get_cuda_ext_mx() is None:
             pytest.skip("cuda_ext_mx is not available")
diff --git a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py
index e9d4ef215..509e1b4e7 100644
--- a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py
+++ b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py
@@ -226,7 +226,7 @@ def _test_static_fp4_kernel(test_in, test_out, scale_value=1.0):
             scales = torch.full((num_blocks,), scale_value, device=inputs.device)

             quantized_outputs_triton = triton_kernel.static_blockwise_fp4_fake_quant(
-                inputs, scales, skip_scale_quant=skip_scale_quant
+                inputs, scale=scales, skip_scale_quant=skip_scale_quant
             )

             # Only check exact values when skip_scale_quant=True
@@ -264,8 +264,9 @@
     @pytest.mark.parametrize("block_size", [16, 32, 64])
     @pytest.mark.parametrize("num_blocks", [4, 8, 16])
     @pytest.mark.parametrize("use_explicit_amax", [False, True])
+    @pytest.mark.parametrize("use_amax_param", [False, True])
     def test_static_vs_dynamic_fp4_kernels(
-        self, set_torch_dtype, block_size, num_blocks, use_explicit_amax
+        self, set_torch_dtype, block_size, num_blocks, use_explicit_amax, use_amax_param
     ):
         """Test that static kernel with computed scales matches dynamic kernel behavior.

@@ -273,6 +274,8 @@ def test_static_vs_dynamic_fp4_kernels(
         This test verifies that the static kernel with pre-computed scales
         (matching dynamic kernel's logic) produces the same results as the dynamic kernel.

+        Args:
+            use_amax_param: If True, use the amax parameter instead of scale parameter.
         """
         torch.manual_seed(42)

@@ -286,9 +289,17 @@ def test_static_vs_dynamic_fp4_kernels(
         else:
             scale_fp8_quant_amax = None

-        output_static = triton_kernel.static_blockwise_fp4_fake_quant(
-            x, scales, scale_fp8_quant_amax=scale_fp8_quant_amax, skip_scale_quant=False
-        )
+        if use_amax_param:
+            output_static = triton_kernel.static_blockwise_fp4_fake_quant(
+                x,
+                amax=block_amax,
+                scale_fp8_quant_amax=scale_fp8_quant_amax,
+                skip_scale_quant=False,
+            )
+        else:
+            output_static = triton_kernel.static_blockwise_fp4_fake_quant(
+                x, scale=scales, scale_fp8_quant_amax=scale_fp8_quant_amax, skip_scale_quant=False
+            )
         output_dynamic = triton_kernel.fp4_fake_quant_block(
             x,
             global_amax=global_amax,
@@ -298,8 +309,10 @@ def test_static_vs_dynamic_fp4_kernels(
         )

         amax_mode = "explicit" if use_explicit_amax else "automatic"
+        param_mode = "amax" if use_amax_param else "scale"
         assert torch.allclose(output_static, output_dynamic, rtol=1e-3, atol=1e-5), (
-            f"Static and dynamic kernels produced different outputs (scale_fp8_quant_amax={amax_mode}).\n"
+            f"Static and dynamic kernels produced different outputs "
+            f"(scale_fp8_quant_amax={amax_mode}, param={param_mode}).\n"
             f"Max abs diff: {(output_static - output_dynamic).abs().max()}\n"
             f"Mean abs diff: {(output_static - output_dynamic).abs().mean()}\n"
             f"Max relative diff: {((output_static - output_dynamic).abs() / (output_dynamic.abs() + 1e-8)).max()}"
diff --git a/tests/unit/torch/quantization/test_mse_calibrator.py b/tests/unit/torch/quantization/test_mse_calibrator.py
index 5e5546512..efccec4c4 100644
--- a/tests/unit/torch/quantization/test_mse_calibrator.py
+++ b/tests/unit/torch/quantization/test_mse_calibrator.py
@@ -526,3 +526,70 @@ def quant_func(x, amax):
         assert a_best.numel() == 2
         assert torch.all(torch.isfinite(a_best))
         assert torch.all(a_best > 0)
+
+    def test_fp8_scale_sweep_with_fixed_values_and_reset(self):
+        """Test FP8 scale sweep with fixed hand-written values and reset functionality."""
+        x = torch.full((100,), 2.0, dtype=torch.float32)
+        x[0] = 20.0
+
+        initial_amax = torch.tensor(20.0)
+
+        quant_cfg = QuantizerAttributeConfig(num_bits=(4, 3), axis=None, unsigned=False)
+        tq = TensorQuantizer(quant_attribute_cfg=quant_cfg, amax=initial_amax)
+
+        def quant_func(x, amax):
+            original_amax = tq._amax.clone() if hasattr(tq, "_amax") else None
+            was_quant_enabled = tq._if_quant
+            was_calib_enabled = tq._if_calib
+
+            tq._amax = amax
+            tq._if_quant = True
+            tq._if_calib = False
+
+            with enable_fake_quant(tq):
+                xq = tq(x)
+
+            if original_amax is not None:
+                tq._amax = original_amax
+            tq._if_quant = was_quant_enabled
+            tq._if_calib = was_calib_enabled
+            return xq
+
+        cal = calib.MseCalibrator(
+            amax=initial_amax,
+            quant_func=quant_func,
+            fp8_scale_sweep=True,
+        )
+
+        assert cal._num_steps == 126
+
+        cal.collect(x)
+
+        a_best = cal.compute_amax()
+
+        assert torch.isfinite(a_best), "Optimal amax should be finite"
+        assert a_best > 0, "Optimal amax should be positive"
+        assert a_best <= initial_amax, "Optimal amax should not exceed initial amax"
+
+        # FP8 scale sweep uses global_amax * fp8_multiplier where fp8_multiplier
+        # ranges from ~4.36e-06 to 1.0. For mostly 2.0 values with one 20.0 outlier,
+        # the optimal amax should be somewhere between these extremes
+        assert a_best >= initial_amax * 1e-6, "Optimal amax should not be unreasonably small"
+
+        a_best_value = a_best.item()
+
+        cal.reset()
+
+        a_after_reset = cal.compute_amax()
+        assert a_after_reset is None, "After reset, compute_amax should return None"
+
+        assert cal._num_steps == 126, "After reset, num_steps should still be 126"
+
+        cal.collect(x)
+        a_best_after_reset = cal.compute_amax()
+
+        assert torch.isfinite(a_best_after_reset), "Should be able to compute amax after reset"
+        assert a_best_after_reset > 0, "Amax after reset should be positive"
+        assert abs(a_best_after_reset.item() - a_best_value) < 1e-6, (
+            "Amax after reset should match original value with same data"
+        )
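Note on the sweep itself: when fp8_scale_sweep=True, MseCalibrator.collect replaces the linspace of multipliers with one candidate per value that an FP8 E4M3-encoded scale can actually take (126 finite, positive encodings out of 128), normalized by the E4M3 maximum of 448 and applied to the tensor's global amax. The snippet below is a minimal standalone sketch of that candidate enumeration, for illustration only; the helper name `fp8_e4m3_scale_candidates` is hypothetical and not part of this patch or of modelopt's API, and it assumes a PyTorch build that provides `torch.float8_e4m3fn`.

```python
import torch


def fp8_e4m3_scale_candidates(device: str = "cpu") -> torch.Tensor:
    """Return the 126 valid (finite, positive) FP8 E4M3 values, normalized by 448."""
    # Bytes 0..127 cover all non-negative E4M3 encodings: byte 0 is +0.0 and
    # byte 127 (S=0, E=1111, M=111) is NaN in the float8_e4m3fn format.
    uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
    fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
    valid = fp8_values[torch.isfinite(fp8_values) & (fp8_values > 0)]
    assert valid.numel() == 126
    # Normalize by the E4M3 maximum (448) so the candidates are multipliers in (0, 1].
    return valid / 448.0


# Candidate amax values for a tensor x: every representable FP8 scale position,
# expressed as global_amax * multiplier (mirroring the fp8_scale_sweep branch).
x = torch.randn(8, 16)
global_amax = x.abs().amax()
candidate_amaxs = global_amax * fp8_e4m3_scale_candidates()  # shape [126]
```

Because these candidates are exactly the values an E4M3-stored per-block scale can represent, the step_size, start_multiplier, and stop_multiplier settings have no effect in this mode, which is why the calibrator pins _num_steps to 126 and the config documents those fields as ignored.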