diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 9710d3a4b..c94b7d716 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -15,6 +15,7 @@ """Calibrator that returns the MSE amax of all collected tensors.""" +import math from collections.abc import Callable import torch @@ -33,7 +34,7 @@ def __init__( self, amax: torch.Tensor, axis: int | tuple | list | None = None, - num_steps: int = 10, + step_size: float = 0.1, start_multiplier: float = 0.25, stop_multiplier: float = 4.0, quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, @@ -44,7 +45,8 @@ def __init__( Args: amax: Initial amax value (required). axis: Quantization axis. None means per-tensor quantization. - num_steps: Number of amax candidates to try. + step_size: Step size for amax search. The number of steps is computed as + ceil((stop_multiplier - start_multiplier) / step_size) + 1. start_multiplier: Starting multiplier for amax search. stop_multiplier: Ending multiplier for amax search. quant_func: Function that quantizes input tensor given an amax value. @@ -54,13 +56,15 @@ def __init__( """ super().__init__(num_bits=None, axis=axis, unsigned=None) self._initial_amax = amax - self._num_steps = num_steps + self._step_size = step_size self._start_multiplier = start_multiplier self._stop_multiplier = stop_multiplier + self._num_steps = math.ceil((stop_multiplier - start_multiplier) / step_size) + 1 + self._quant_func = quant_func self._error_func = error_func - self._losses_sum = [None] * num_steps - self._candidate_amaxs = [None] * num_steps + self._losses_sum = [None] * self._num_steps + self._candidate_amaxs = [None] * self._num_steps self._amax = None @@ -82,7 +86,6 @@ def collect(self, x: torch.Tensor): multipliers = torch.linspace( self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device ) - # Get reduce axis for per-channel quantization reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis) @@ -111,6 +114,18 @@ def reset(self): self._candidate_amaxs = [None] * self._num_steps self._amax = None + def clear(self): + """Clear all cached data to free GPU memory. + + Call this after compute_amax() and load_calib_amax() are done. + """ + self._losses_sum = [] + self._candidate_amaxs = [] + + if self._initial_amax is not None: + del self._initial_amax + self._initial_amax = None + @torch.no_grad() def compute_amax(self, verbose: bool = False): """Return the amax value that minimizes quantization error. diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 3a43113df..3c9f937a7 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -387,6 +387,29 @@ "algorithm": "max", } +NVFP4_WEIGHT_ACT_MSE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "mse", + "step_size": 0.25, + "start_multiplier": 0.25, + "stop_multiplier": 2.0, + }, +} NVFP4_AWQ_LITE_CFG = { "quant_cfg": { @@ -734,7 +757,7 @@ def validate_num_bits(self): raise ValueError( "Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1)." 
) - elif num_bits != (4, 3) and ( + elif num_bits not in [(4, 3), (2, 1)] and ( block_sizes is None or block_sizes.get("type", None) != "dynamic" ): raise ValueError( @@ -992,11 +1015,12 @@ class MseCalibConfig(QuantizeAlgorithmConfig): method: Literal["mse"] = ModeloptField("mse") - num_steps: int | None = ModeloptField( - default=10, - ge=1, - title="Number of amax candidates to try.", - description="Number of amax candidates to search over for MSE minimization.", + step_size: float | None = ModeloptField( + default=0.1, + gt=0.0, + title="Step size for amax search.", + description="Step size between amax candidates. The number of candidates is computed as " + "ceil((stop_multiplier - start_multiplier) / step_size) + 1.", ) start_multiplier: float | None = ModeloptField( diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index c8e2b044c..69a6776b1 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -190,7 +190,7 @@ def mse_calibrate( model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True, - num_steps: int = 10, + step_size: float = 0.1, start_multiplier: float = 0.25, stop_multiplier: float = 4.0, ): @@ -205,7 +205,7 @@ def mse_calibrate( forward_loop: A callable which takes the model as argument and forwards calibration data through the model. distributed_sync: Whether to sync amax across distributed processes. - num_steps: Number of amax candidates to try (default: 10). + step_size: Step size for amax search (default: 0.1). start_multiplier: Starting multiplier for amax search (default: 0.25). stop_multiplier: Ending multiplier for amax search (default: 4.0). @@ -216,14 +216,12 @@ def mse_calibrate( max_calibrate(model, forward_loop, distributed_sync) # Step 2: Replace calibrators with MseCalibrator for enabled quantizers + # and identify weight quantizers + weight_quantizers = [] + seen_modules = set() + for name, module in model.named_modules(): if isinstance(module, TensorQuantizer) and not module._disabled: - # Static block quantization is not supported by MseCalibrator - if module.is_static_block_quant: - raise ValueError( - f"MSE calibration does not support static block quantization. " - f"Found static block quantization at {name}." 
- ) if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): # Get the initial amax from max calibration initial_amax = module._amax.clone().detach() @@ -237,7 +235,11 @@ def quant_func(x, amax, quantizer=module): disable_calib(quantizer), enable_fake_quant(quantizer), ): + if hasattr(quantizer, "_original_shape"): + x = quantizer._reset_to_original_shape(x) xq = quantizer(x) + if hasattr(quantizer, "_block_reshape_size"): + xq = xq.reshape(quantizer._block_reshape_size) if original_amax is not None: quantizer._amax = original_amax @@ -250,22 +252,63 @@ def quant_func(x, amax, quantizer=module): module._calibrator = MseCalibrator( amax=initial_amax, axis=module._calibrator._axis, - num_steps=num_steps, + step_size=step_size, start_multiplier=start_multiplier, stop_multiplier=stop_multiplier, quant_func=quant_func, ) - # Step 3: Collect data with MSE calibrators + # Identify weight quantizers by checking if they have corresponding weight parameters + for name, parent_module in model.named_modules(): + if parent_module in seen_modules: + continue + for weight_name in weight_attr_names(parent_module): + weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer + weight_quantizer = getattr(parent_module, weight_quantizer_name, None) + if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled: + if getattr(weight_quantizer, "_calibrator", None) is not None: + weight_quantizers.append((parent_module, weight_name, weight_quantizer)) + seen_modules.add(parent_module) + + # Step 3: Calibrate weight quantizers once with MSE calibration + # This ensures weights are only calibrated once, not during every forward pass + for parent_module, weight_name, weight_quantizer in weight_quantizers: + # Enable calibration mode for the weight quantizer + weight_quantizer.disable_quant() + weight_quantizer.enable_calib() + + with enable_weight_access_and_writeback(parent_module, model): + weight = getattr(parent_module, weight_name) + weight_quantizer(weight) + + # Step 4: Disable weight quantizers during forward loop + for _, _, weight_quantizer in weight_quantizers: + weight_quantizer.disable() + + # Step 5: Collect data with MSE calibrators for activation quantizers only enable_stats_collection(model) if forward_loop is None: - weight_only_quantize(model) + # If no forward loop, nothing else to do since weights are already calibrated + pass else: + # Run forward loop - only activation quantizers will collect data forward_loop(model) - # Step 4: Compute optimal amax and load it + # Step 6: Re-enable weight quantizers before finalizing calibration + # This ensures finish_stats_collection processes them correctly + for _, _, weight_quantizer in weight_quantizers: + weight_quantizer.enable() + + # Step 7: Compute optimal amax and load it for all quantizers (weights + activations) finish_stats_collection(model, method="mse") + # Step 8: Free GPU memory by clearing calibrator data + for name, module in model.named_modules(): + if isinstance(module, TensorQuantizer) and not module._disabled: + if hasattr(module, "_calibrator") and getattr(module, "_calibrator", None) is not None: + if hasattr(module._calibrator, "clear"): + module._calibrator.clear() + # TODO: Sync amax across distributed processes diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 1688c7fa7..6d0debe41 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ 
b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -52,7 +52,12 @@ NVFP4QTensor, QTensorWrapper, ) -from ...tensor_quant import dynamic_block_quant, fake_tensor_quant, scaled_e4m3 +from ...tensor_quant import ( + dynamic_block_quant, + fake_tensor_quant, + scaled_e4m3, + static_blockwise_fp4_fake_quant, +) from ...utils import is_torch_export_mode from ..functional import normalized_hadamard_transform @@ -653,6 +658,15 @@ def _fake_quantize(self, inputs): getattr(self, "_onnx_quantizer_type", None), self._pass_through_bwd, ) + elif self._num_bits == (2, 1) and self.is_static_block_quant: + outputs = static_blockwise_fp4_fake_quant( + inputs, + amax / 6.0, + None, # scale_fp8_quant_amax + False, # skip_scale_quant + inputs.dtype, # out_dtype + self._pass_through_bwd, # pass_through_bwd + ) elif isinstance(self._num_bits, tuple): # Float-point quantization, e.g., FP8 E, M = self._num_bits # noqa: N806 diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 5f69e3999..0c6f52470 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -571,6 +571,35 @@ def backward(ctx, grad_outputs): return _fake_quant_backward_function(ctx, grad_outputs, num_args=9) +class StaticBlockwiseFP4FakeQuantFunction(Function): + """Static blockwise FP4 fake quantization functional.""" + + @staticmethod + def forward( + ctx, + x, + scale, + scale_fp8_quant_amax, + skip_scale_quant, + out_dtype, + pass_through_bwd=False, + ): + """Forward method.""" + _save_for_backward_if_needed(ctx, pass_through_bwd, x, scale) + return triton_kernel.static_blockwise_fp4_fake_quant( + x, + scale, + scale_fp8_quant_amax, + skip_scale_quant, + out_dtype, + ) + + @staticmethod + def backward(ctx, grad_outputs): + """Implements straight through estimation with clipping.""" + return _fake_quant_backward_function(ctx, grad_outputs, num_args=6) + + def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True): """Shared function body between TensorQuantFunction and FakeTensorQuantFunction.""" # Fine scale, per channel scale will be handled by broadcasting, which could be tricky. Pop a warning. 
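Note (illustrative, not part of the patch): the new `static_blockwise_fp4_fake_quant` op introduced above consumes a pre-blocked tensor plus one scale per block, and the tensor_quantizer dispatch passes `amax / 6.0` because 6.0 is the largest e2m1 magnitude. A minimal sketch of driving the op directly, assuming the Triton kernel is available and the tensor's last dimension is divisible by the block size:

    import torch
    from modelopt.torch.quantization.tensor_quant import static_blockwise_fp4_fake_quant

    block_size = 16
    w = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)
    w_blocked = w.reshape(-1, block_size)                     # [num_blocks, block_size]
    scale = w_blocked.abs().amax(dim=-1, keepdim=True) / 6.0  # per-block amax / e2m1 max
    w_fq = static_blockwise_fp4_fake_quant(
        w_blocked,
        scale,
        None,     # scale_fp8_quant_amax: derive the FP8 range from `scale`
        False,    # skip_scale_quant: quantize per-block scales to FP8 (E4M3)
        w.dtype,  # out_dtype
        False,    # pass_through_bwd
    ).reshape_as(w)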
@@ -615,3 +644,4 @@ def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True): fake_tensor_quant = FakeTensorQuantFunction.apply scaled_e4m3 = ScaledE4M3Function.apply dynamic_block_quant = DynamicBlockQuantizationFunction.apply +static_blockwise_fp4_fake_quant = StaticBlockwiseFP4FakeQuantFunction.apply diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py index f2f9bd077..6d97d984c 100644 --- a/modelopt/torch/quantization/triton/fp4_kernel.py +++ b/modelopt/torch/quantization/triton/fp4_kernel.py @@ -24,7 +24,7 @@ import triton import triton.language as tl -__all__ = ["fp4_fake_quant_block"] +__all__ = ["fp4_fake_quant_block", "static_blockwise_fp4_fake_quant"] _TORCH_TO_TL_DTYPE = { @@ -345,3 +345,115 @@ def fp4_dequantize( ) return output + + +@triton.jit +def static_blockwise_fp4_fake_quant_kernel( + x_ptr, # [NUM_FP4_BLOCKS * BLOCK_SIZE] + y_ptr, # [NUM_FP4_BLOCKS * BLOCK_SIZE] + scale_ptr, # [NUM_FP4_BLOCKS] + NUM_FP4_BLOCKS, + BLOCK_SIZE: tl.constexpr, + OUT_DTYPE: tl.constexpr, +): + pid = tl.program_id(axis=0) + if pid >= NUM_FP4_BLOCKS: + return + + block_offset = pid * BLOCK_SIZE + idx = block_offset + tl.arange(0, BLOCK_SIZE) + + scale = tl.load(scale_ptr + pid).to(tl.float32) + + x = tl.load(x_ptr + idx).to(tl.float32) + + x_abs = tl.abs(x) + scale_safe = tl.where(scale >= 1e-5, scale, 1.0) + abs_scaled = x_abs / scale_safe + + # FP4 values: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0 + q_val = tl.where( + abs_scaled <= 0.25, + 0.0, + tl.where( + abs_scaled < 0.75, + 0.5, + tl.where( + abs_scaled <= 1.25, + 1.0, + tl.where( + abs_scaled < 1.75, + 1.5, + tl.where( + abs_scaled <= 2.5, + 2.0, + tl.where( + abs_scaled < 3.5, + 3.0, + tl.where(abs_scaled <= 5.0, 4.0, 6.0), + ), + ), + ), + ), + ), + ) + + x_rescaled = q_val * scale_safe + x_quant = tl.where(x >= 0, x_rescaled, -x_rescaled) + + tl.store(y_ptr + idx, x_quant.to(OUT_DTYPE)) + + +def static_blockwise_fp4_fake_quant( + x: torch.Tensor, + scale: torch.Tensor, + scale_fp8_quant_amax: torch.Tensor | None = None, + skip_scale_quant: bool = False, + out_dtype: torch.dtype | None = None, +): + """Static blockwise FP4 fake quantization using Triton kernel. + + Args: + x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA. + scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA. + scale_fp8_quant_amax: Absolute max range for FP8 quantization of scale. If None, computed from scale. + skip_scale_quant: If True, skip FP8 quantization of scale. + out_dtype: Output dtype. Defaults to x.dtype if None. 
+ """ + assert x.ndim == 2 + NUM_FP4_BLOCKS, BLOCK_SIZE = x.shape + + if out_dtype is None: + out_dtype = x.dtype + + if not skip_scale_quant: + from modelopt.torch.quantization.tensor_quant import scaled_e4m3_impl + from modelopt.torch.quantization.utils import reduce_amax + + if scale_fp8_quant_amax is None: + scale_fp8_quant_amax = reduce_amax( + scale, axis=None, keepdims=False, squeeze_scalar=True + ) + + scale = scaled_e4m3_impl(scale, scale_fp8_quant_amax) + + x_flat = x.contiguous().view(-1) + y_flat = torch.empty_like(x_flat, dtype=out_dtype) + scale_flat = scale.view(NUM_FP4_BLOCKS).contiguous() + + tl_out_dtype = _torch_dtype_to_tl(out_dtype) + + grid = (NUM_FP4_BLOCKS,) + + # Ensure we're running on the correct CUDA device + with torch.cuda.device(x.device): + static_blockwise_fp4_fake_quant_kernel[grid]( + x_flat, + y_flat, + scale_flat, + NUM_FP4_BLOCKS, + BLOCK_SIZE, + OUT_DTYPE=tl_out_dtype, + ) + + return y_flat.view_as(x) diff --git a/tests/gpu/torch/quantization/test_quantize_cuda.py b/tests/gpu/torch/quantization/test_quantize_cuda.py index 811e0be81..9d82c1082 100644 --- a/tests/gpu/torch/quantization/test_quantize_cuda.py +++ b/tests/gpu/torch/quantization/test_quantize_cuda.py @@ -28,6 +28,24 @@ import modelopt.torch.quantization as mtq from modelopt.torch.quantization.extensions import get_cuda_ext_mx +NVFP4_WEIGHT_ACT_MSE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + }, + "algorithm": "mse", +} + @pytest.mark.parametrize("model_cls", [SimpleLinear, SimpleConv, SimpleConvLinear]) @pytest.mark.parametrize( @@ -52,6 +70,7 @@ mtq.MXINT8_DEFAULT_CFG, mtq.NVFP4_KV_ROTATE_CFG, mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, + NVFP4_WEIGHT_ACT_MSE_CFG, ], ) def test_quantize(model_cls, config): @@ -68,6 +87,7 @@ def test_quantize(model_cls, config): mtq.MXINT8_DEFAULT_CFG, mtq.NVFP4_KV_ROTATE_CFG, mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, + NVFP4_WEIGHT_ACT_MSE_CFG, ]: if get_cuda_ext_mx() is None: pytest.skip("cuda_ext_mx is not available") diff --git a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py index 83becee41..98b498bb6 100644 --- a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py +++ b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py @@ -221,3 +221,103 @@ def _test_fp4_kernel(test_in, test_out, skip_triton=False): test_in *= sign test_out = torch.tensor([[0.5, 1, 1.5, 2, 3, 4, 6, 6]]).cuda() * sign _test_fp4_kernel(test_in, test_out) + + @pytest.mark.skipif(not triton_kernel.IS_AVAILABLE, reason="triton kernel is not available") + @pytest.mark.parametrize( + "set_torch_dtype", [torch.float, torch.float16, torch.bfloat16], indirect=True + ) + @pytest.mark.parametrize("block_size", [8, 16, 32]) + @pytest.mark.parametrize("skip_scale_quant", [True, False]) + def test_static_blockwise_fp4(self, set_torch_dtype, block_size, skip_scale_quant): + # Test with e2m1 table values + sign = torch.randint(0, 2, (1, 8)).cuda() * 2 - 1 + + def _get_test_inputs_outputs(test_in, test_out, num_blocks=4): + return torch.concat((test_in,) * (block_size // 8), dim=-1).repeat( + num_blocks, 1 + ), torch.concat((test_out,) * (block_size // 8), dim=-1).repeat(num_blocks, 1) + + def _test_static_fp4_kernel(test_in, test_out, scale_value=1.0): + 
inputs, expected_outputs = _get_test_inputs_outputs(test_in, test_out) + num_blocks = inputs.shape[0] + scales = torch.full((num_blocks,), scale_value, device=inputs.device) + + quantized_outputs_triton = triton_kernel.static_blockwise_fp4_fake_quant( + inputs, scales, skip_scale_quant=skip_scale_quant + ) + + # Only check exact values when skip_scale_quant=True + # When scale quantization is enabled, the scale changes slightly, affecting outputs + if skip_scale_quant: + assert torch.allclose(quantized_outputs_triton, expected_outputs, atol=1e-6) + else: + assert quantized_outputs_triton.shape == expected_outputs.shape + + test_in = torch.tensor([[0, 0.5, 1, 1.5, 2, 3, 4, 6]]).cuda() * sign + test_out = torch.tensor([[0, 0.5, 1, 1.5, 2, 3, 4, 6]]).cuda() * sign + _test_static_fp4_kernel(test_in, test_out) + + if skip_scale_quant: + # Test slightly below the e2m1 boundary values. + # Numbers should be quantized down to the corresponding e2m1 value. + test_in = torch.tensor([[0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5, 6]]).cuda() + test_in[:, :-1] -= 0.1 + test_in *= sign + test_out = torch.tensor([[0.0, 0.5, 1, 1.5, 2, 3, 4, 6]]).cuda() * sign + _test_static_fp4_kernel(test_in, test_out) + + # Test slightly above the e2m1 boundary values. + # Numbers should be quantized up to the corresponding e2m1 value. + test_in = torch.tensor([[0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5, 6]]).cuda() + test_in[:, :-1] += 0.1 + test_in *= sign + test_out = torch.tensor([[0.5, 1, 1.5, 2, 3, 4, 6, 6]]).cuda() * sign + _test_static_fp4_kernel(test_in, test_out) + + @pytest.mark.skipif(not triton_kernel.IS_AVAILABLE, reason="triton kernel is not available") + @pytest.mark.parametrize( + "set_torch_dtype", [torch.float, torch.float16, torch.bfloat16], indirect=True + ) + @pytest.mark.parametrize("block_size", [16, 32, 64]) + @pytest.mark.parametrize("num_blocks", [4, 8, 16]) + @pytest.mark.parametrize("use_explicit_amax", [False, True]) + def test_static_vs_dynamic_fp4_kernels( + self, set_torch_dtype, block_size, num_blocks, use_explicit_amax + ): + """Test that static kernel with computed scales matches dynamic kernel behavior. + + The dynamic kernel computes scales dynamically from block-wise max values with FP8 quantization. + This test verifies that the static kernel with pre-computed scales (matching dynamic kernel's logic) + produces the same results as the dynamic kernel. 
+ + """ + torch.manual_seed(42) + + x = torch.randn(num_blocks, block_size, dtype=torch.float32).cuda() * 10 + block_amax = x.abs().max(dim=1, keepdim=False)[0] + global_amax = block_amax.max() + scales = block_amax / 6.0 + + if use_explicit_amax: + scale_fp8_quant_amax = global_amax / 6.0 + else: + scale_fp8_quant_amax = None + + output_static = triton_kernel.static_blockwise_fp4_fake_quant( + x, scales, scale_fp8_quant_amax=scale_fp8_quant_amax, skip_scale_quant=False + ) + output_dynamic = triton_kernel.fp4_fake_quant_block( + x, + global_amax=global_amax, + block_size=block_size, + tile_rows=num_blocks, + tile_cols=block_size, + ) + + amax_mode = "explicit" if use_explicit_amax else "automatic" + assert torch.allclose(output_static, output_dynamic, rtol=1e-3, atol=1e-5), ( + f"Static and dynamic kernels produced different outputs (scale_fp8_quant_amax={amax_mode}).\n" + f"Max abs diff: {(output_static - output_dynamic).abs().max()}\n" + f"Mean abs diff: {(output_static - output_dynamic).abs().mean()}\n" + f"Max relative diff: {((output_static - output_dynamic).abs() / (output_dynamic.abs() + 1e-8)).max()}" + ) diff --git a/tests/unit/torch/quantization/test_mse_calibrator.py b/tests/unit/torch/quantization/test_mse_calibrator.py index 26e7d52da..5e5546512 100644 --- a/tests/unit/torch/quantization/test_mse_calibrator.py +++ b/tests/unit/torch/quantization/test_mse_calibrator.py @@ -68,7 +68,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=20, + step_size=0.075, start_multiplier=0.1, stop_multiplier=1.5, quant_func=quant_func, @@ -115,7 +115,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=25, + step_size=0.045, start_multiplier=0.1, stop_multiplier=1.2, quant_func=quant_func, @@ -162,7 +162,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=50, + step_size=0.008, start_multiplier=0.8, stop_multiplier=1.2, quant_func=quant_func, @@ -214,7 +214,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=20, + step_size=0.075, start_multiplier=0.1, stop_multiplier=1.5, quant_func=quant_func, @@ -265,7 +265,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=15, + step_size=0.07, start_multiplier=0.5, stop_multiplier=1.5, quant_func=quant_func, @@ -307,7 +307,7 @@ def quant_func(x, amax): tq._if_calib = was_calib_enabled return xq - cal = calib.MseCalibrator(amax=initial_amax, num_steps=10, quant_func=quant_func) + cal = calib.MseCalibrator(amax=initial_amax, step_size=0.4, quant_func=quant_func) cal.collect(x) @@ -352,7 +352,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, axis=0, - num_steps=10, + step_size=0.15, start_multiplier=0.5, stop_multiplier=2.0, quant_func=quant_func, @@ -398,7 +398,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, axis=0, - num_steps=15, + step_size=0.1, start_multiplier=0.5, stop_multiplier=2.0, quant_func=quant_func, @@ -458,7 +458,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, axis=0, - num_steps=20, + step_size=0.05, start_multiplier=0.5, stop_multiplier=1.5, quant_func=quant_func, @@ -511,7 +511,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, axis=0, - num_steps=10, + step_size=0.15, start_multiplier=0.5, stop_multiplier=2.0, quant_func=quant_func, diff --git a/tests/unit/torch/quantization/test_quantize_cpu.py b/tests/unit/torch/quantization/test_quantize_cpu.py index 
85fa07fa4..f5cd52c84 100644
--- a/tests/unit/torch/quantization/test_quantize_cpu.py
+++ b/tests/unit/torch/quantization/test_quantize_cpu.py
@@ -45,6 +45,7 @@
     "algorithm": "awq_lite",
 }
 
+# Test configs for per-channel MSE calibration
 INT8_MSE_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {"num_bits": 8, "axis": 0},