From c025df72b712641230db856ab26f76cd6e93d062 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Thu, 27 Nov 2025 00:10:29 +0000 Subject: [PATCH 01/10] tmp update for per block mse NVFP4 and INT4 Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 12 +- .../nn/modules/tensor_quantizer.py | 19 ++- .../torch/quantization/triton/fp4_kernel.py | 154 ++++++++++++++++++ 3 files changed, 174 insertions(+), 11 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index c8e2b044c..e22dc5d34 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -219,11 +219,11 @@ def mse_calibrate( for name, module in model.named_modules(): if isinstance(module, TensorQuantizer) and not module._disabled: # Static block quantization is not supported by MseCalibrator - if module.is_static_block_quant: - raise ValueError( - f"MSE calibration does not support static block quantization. " - f"Found static block quantization at {name}." - ) + # if module.is_static_block_quant: + # raise ValueError( + # f"MSE calibration does not support static block quantization. " + # f"Found static block quantization at {name}." + # ) if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): # Get the initial amax from max calibration initial_amax = module._amax.clone().detach() @@ -237,7 +237,9 @@ def quant_func(x, amax, quantizer=module): disable_calib(quantizer), enable_fake_quant(quantizer), ): + quantizer._keep_shape = True xq = quantizer(x) + quantizer._keep_shape = False if original_amax is not None: quantizer._amax = original_amax diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 1688c7fa7..59261a43a 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -128,6 +128,7 @@ def __init__( self._enable_pre_quant_scale = True self._dequantize = False self._input_dtype = None + self._keep_shape = False # Lazy initialize the bias calibrator for KV cache quantization self._bias_calibrator = None @@ -653,6 +654,12 @@ def _fake_quantize(self, inputs): getattr(self, "_onnx_quantizer_type", None), self._pass_through_bwd, ) + elif self._num_bits == (2, 1): + from modelopt.torch.quantization.triton.fp4_kernel import ( + launch_blockwise_fp4_fake_quant, + ) + + outputs = launch_blockwise_fp4_fake_quant(inputs, amax / 6.0, out_dtype=inputs.dtype) elif isinstance(self._num_bits, tuple): # Float-point quantization, e.g., FP8 E, M = self._num_bits # noqa: N806 @@ -783,11 +790,11 @@ def _process_for_blockquant(self, inputs: torch.Tensor): if hasattr(self, "_padding"): inputs = F.pad(inputs, self._padding, "constant", 0) - if inputs.shape != self._original_shape: - raise ValueError( - f"Input shape has changed from {self._original_shape} to {inputs.shape}." - " Block-quantization requires a fixed input shape." - ) + # if inputs.shape != self._original_shape: + # print( + # f"Input shape has changed from {self._original_shape} to {inputs.shape}." + # " Block-quantization requires a fixed input shape." + # ) inputs = inputs.reshape(self._block_reshape_size) return inputs @@ -941,7 +948,7 @@ def forward(self, inputs): "This case should have been handled." 
) - if self.is_static_block_quant: + if self.is_static_block_quant and not self._keep_shape: outputs = self._reset_to_original_shape(outputs) return outputs diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py index f2f9bd077..61d66abe7 100644 --- a/modelopt/torch/quantization/triton/fp4_kernel.py +++ b/modelopt/torch/quantization/triton/fp4_kernel.py @@ -345,3 +345,157 @@ def fp4_dequantize( ) return output + + +@triton.jit +def blockwise_fp4_fake_quant_kernel( + x_ptr, # [NUM_FP4_BLOCKS * BLOCK_SIZE] + y_ptr, # [NUM_FP4_BLOCKS * BLOCK_SIZE] + scale_ptr, # [NUM_FP4_BLOCKS] + NUM_FP4_BLOCKS, + BLOCK_SIZE: tl.constexpr, + OUT_DTYPE: tl.constexpr, +): + pid = tl.program_id(axis=0) + if pid >= NUM_FP4_BLOCKS: + return + + block_offset = pid * BLOCK_SIZE + idx = block_offset + tl.arange(0, BLOCK_SIZE) + + scale = tl.load(scale_ptr + pid).to(tl.float32) + + x = tl.load(x_ptr + idx).to(tl.float32) + + x_abs = tl.abs(x) + scale_safe = tl.where(scale >= 1e-5, scale, 1.0) + abs_scaled = x_abs / scale_safe + + # FP4 values: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0 + q_val = tl.where( + abs_scaled <= 0.25, + 0.0, + tl.where( + abs_scaled < 0.75, + 0.5, + tl.where( + abs_scaled <= 1.25, + 1.0, + tl.where( + abs_scaled < 1.75, + 1.5, + tl.where( + abs_scaled <= 2.5, + 2.0, + tl.where( + abs_scaled < 3.5, + 3.0, + tl.where(abs_scaled <= 5.0, 4.0, 6.0), + ), + ), + ), + ), + ), + ) + + x_rescaled = q_val * scale_safe + x_dequant = tl.where(x >= 0, x_rescaled, -x_rescaled) + + tl.store(y_ptr + idx, x_dequant.to(OUT_DTYPE)) + + +def launch_blockwise_fp4_fake_quant( + x: torch.Tensor, + scale: torch.Tensor, + out_dtype: torch.dtype = torch.float16, +): + """Launch Triton kernel for blockwise FP4 fake quantization. + + x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA. + scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA. + """ + assert x.ndim == 2 + NUM_FP4_BLOCKS, BLOCK_SIZE = x.shape + + x_flat = x.contiguous().view(-1) + y_flat = torch.empty_like(x_flat, dtype=out_dtype) + scale_flat = scale.view(NUM_FP4_BLOCKS).contiguous() + + tl_out_dtype = _torch_dtype_to_tl(out_dtype) + + grid = (NUM_FP4_BLOCKS,) + + # Ensure we're running on the correct CUDA device + with torch.cuda.device(x.device): + blockwise_fp4_fake_quant_kernel[grid]( + x_flat, + y_flat, + scale_flat, + NUM_FP4_BLOCKS, + BLOCK_SIZE, + OUT_DTYPE=tl_out_dtype, + ) + + return y_flat.view_as(x) + + +def blockwise_fp4_fake_quant_reference( + x: torch.Tensor, + scale: torch.Tensor, + out_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Reference implementation of blockwise FP4 fake quantization. + + x: [NUM_FP4_BLOCKS, BLOCK_SIZE]. + scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1]. + + Uses FP4 quantization levels: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0. 
+ """ + assert x.ndim == 2 + num_blocks, block_size = x.shape + + if scale.ndim == 1: + scale = scale.view(num_blocks, 1) + assert scale.shape == (num_blocks, 1) + + x_f = x.to(torch.float32) + s_f = scale.to(torch.float32) + + s_f = torch.where(s_f >= 1e-5, s_f, torch.ones_like(s_f)) + + x_abs = torch.abs(x_f) + abs_scaled = x_abs / s_f + + q_val = torch.where( + abs_scaled <= 0.25, + torch.zeros_like(abs_scaled), + torch.where( + abs_scaled < 0.75, + torch.full_like(abs_scaled, 0.5), + torch.where( + abs_scaled <= 1.25, + torch.ones_like(abs_scaled), + torch.where( + abs_scaled < 1.75, + torch.full_like(abs_scaled, 1.5), + torch.where( + abs_scaled <= 2.5, + torch.full_like(abs_scaled, 2.0), + torch.where( + abs_scaled < 3.5, + torch.full_like(abs_scaled, 3.0), + torch.where( + abs_scaled <= 5.0, + torch.full_like(abs_scaled, 4.0), + torch.full_like(abs_scaled, 6.0), + ), + ), + ), + ), + ), + ), + ) + + x_rescaled = q_val * s_f + x_dequant = torch.where(x_f >= 0, x_rescaled, -x_rescaled) + return x_dequant.to(out_dtype) From 0bcba00c6c39abb0ef00077402b2eb88a7d4206d Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 5 Dec 2025 02:32:37 +0000 Subject: [PATCH 02/10] improvements: even steps for mse amax search;calibrate weight quant once; quant scale to FP8; rename static kernel Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/calib/mse.py | 13 ++++- modelopt/torch/quantization/model_calib.py | 57 ++++++++++++++++++- .../nn/modules/tensor_quantizer.py | 8 ++- .../torch/quantization/triton/fp4_kernel.py | 6 +- 4 files changed, 73 insertions(+), 11 deletions(-) diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 9710d3a4b..37f4540af 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -79,8 +79,17 @@ def collect(self, x: torch.Tensor): x = x.detach().to(dtype=torch.float32) device = x.device - multipliers = torch.linspace( - self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device + # Split steps between _start_multiplier to 1.0 and 1.0 to _stop_multiplier + # to ensure balanced exploration on both sides of the original amax (1.0) + steps_first_half = self._num_steps // 2 + 1 # Include 1.0 + steps_second_half = self._num_steps - self._num_steps // 2 # For second range + multipliers = torch.cat( + [ + torch.linspace(self._start_multiplier, 1.0, steps=steps_first_half, device=device), + torch.linspace(1.0, self._stop_multiplier, steps=steps_second_half, device=device)[ + 1: + ], # Skip duplicate 1.0 + ] ) # Get reduce axis for per-channel quantization diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index e22dc5d34..5ea7b9599 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -32,6 +32,7 @@ from .calib import MseCalibrator from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context from .nn import QuantModule, SequentialQuantizer, TensorQuantizer +from .tensor_quant import scaled_e4m3_impl from .utils import ( disable_calib, enable_fake_quant, @@ -41,6 +42,7 @@ is_quantized_linear, is_quantized_row_parallel_linear, quantizer_attr_names, + reduce_amax, weight_attr_names, ) @@ -216,6 +218,10 @@ def mse_calibrate( max_calibrate(model, forward_loop, distributed_sync) # Step 2: Replace calibrators with MseCalibrator for 
enabled quantizers + # and identify weight quantizers + weight_quantizers = [] + seen_modules = set() + for name, module in model.named_modules(): if isinstance(module, TensorQuantizer) and not module._disabled: # Static block quantization is not supported by MseCalibrator @@ -241,6 +247,17 @@ def quant_func(x, amax, quantizer=module): xq = quantizer(x) quantizer._keep_shape = False + # FP8 quantization of NVFP4 static per-block scales + if ( + quantizer.is_static_block_quant + and quantizer._num_bits == (2, 1) + and quantizer._block_sizes.get("scale_bits") == (4, 3) + ): + weight_amax = reduce_amax( + x, axis=None, keepdims=False, squeeze_scalar=True + ) + quantizer._amax = scaled_e4m3_impl(amax, weight_amax) + if original_amax is not None: quantizer._amax = original_amax else: @@ -258,14 +275,48 @@ def quant_func(x, amax, quantizer=module): quant_func=quant_func, ) - # Step 3: Collect data with MSE calibrators + # Identify weight quantizers by checking if they have corresponding weight parameters + for name, parent_module in model.named_modules(): + if parent_module in seen_modules: + continue + for weight_name in weight_attr_names(parent_module): + weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer + weight_quantizer = getattr(parent_module, weight_quantizer_name, None) + if isinstance(weight_quantizer, TensorQuantizer) and not weight_quantizer._disabled: + if weight_quantizer._calibrator is not None: + weight_quantizers.append((parent_module, weight_name, weight_quantizer)) + seen_modules.add(parent_module) + + # Step 3: Calibrate weight quantizers once with MSE calibration + # This ensures weights are only calibrated once, not during every forward pass + for parent_module, weight_name, weight_quantizer in weight_quantizers: + # Enable calibration mode for the weight quantizer + weight_quantizer.disable_quant() + weight_quantizer.enable_calib() + + with enable_weight_access_and_writeback(parent_module, model): + weight = getattr(parent_module, weight_name) + weight_quantizer(weight) + + # Step 4: Disable weight quantizers during forward loop + for _, _, weight_quantizer in weight_quantizers: + weight_quantizer.disable() + + # Step 5: Collect data with MSE calibrators for activation quantizers only enable_stats_collection(model) if forward_loop is None: - weight_only_quantize(model) + # If no forward loop, nothing else to do since weights are already calibrated + pass else: + # Run forward loop - only activation quantizers will collect data forward_loop(model) - # Step 4: Compute optimal amax and load it + # Step 6: Re-enable weight quantizers before finalizing calibration + # This ensures finish_stats_collection processes them correctly + for _, _, weight_quantizer in weight_quantizers: + weight_quantizer.enable() + + # Step 7: Compute optimal amax and load it for all quantizers (weights + activations) finish_stats_collection(model, method="mse") # TODO: Sync amax across distributed processes diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 59261a43a..71695a634 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -654,12 +654,14 @@ def _fake_quantize(self, inputs): getattr(self, "_onnx_quantizer_type", None), self._pass_through_bwd, ) - elif self._num_bits == (2, 1): + elif self._num_bits == (2, 1) and self.is_static_block_quant: from modelopt.torch.quantization.triton.fp4_kernel import ( 
- launch_blockwise_fp4_fake_quant, + launch_static_blockwise_fp4_fake_quant, ) - outputs = launch_blockwise_fp4_fake_quant(inputs, amax / 6.0, out_dtype=inputs.dtype) + outputs = launch_static_blockwise_fp4_fake_quant( + inputs, amax / 6.0, out_dtype=inputs.dtype + ) elif isinstance(self._num_bits, tuple): # Float-point quantization, e.g., FP8 E, M = self._num_bits # noqa: N806 diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py index 61d66abe7..6049882e4 100644 --- a/modelopt/torch/quantization/triton/fp4_kernel.py +++ b/modelopt/torch/quantization/triton/fp4_kernel.py @@ -348,7 +348,7 @@ def fp4_dequantize( @triton.jit -def blockwise_fp4_fake_quant_kernel( +def static_blockwise_fp4_fake_quant_kernel( x_ptr, # [NUM_FP4_BLOCKS * BLOCK_SIZE] y_ptr, # [NUM_FP4_BLOCKS * BLOCK_SIZE] scale_ptr, # [NUM_FP4_BLOCKS] @@ -404,7 +404,7 @@ def blockwise_fp4_fake_quant_kernel( tl.store(y_ptr + idx, x_dequant.to(OUT_DTYPE)) -def launch_blockwise_fp4_fake_quant( +def launch_static_blockwise_fp4_fake_quant( x: torch.Tensor, scale: torch.Tensor, out_dtype: torch.dtype = torch.float16, @@ -427,7 +427,7 @@ def launch_blockwise_fp4_fake_quant( # Ensure we're running on the correct CUDA device with torch.cuda.device(x.device): - blockwise_fp4_fake_quant_kernel[grid]( + static_blockwise_fp4_fake_quant_kernel[grid]( x_flat, y_flat, scale_flat, From bb23fe2952e7a3319a59b66187d11e456cbf5353 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Wed, 10 Dec 2025 01:12:20 +0000 Subject: [PATCH 03/10] tmp: config update for experiments Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/calib/mse.py | 14 +-- modelopt/torch/quantization/config.py | 122 +++++++++++++++++++-- modelopt/torch/quantization/model_calib.py | 2 +- 3 files changed, 118 insertions(+), 20 deletions(-) diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 37f4540af..46cfaa670 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -79,18 +79,10 @@ def collect(self, x: torch.Tensor): x = x.detach().to(dtype=torch.float32) device = x.device - # Split steps between _start_multiplier to 1.0 and 1.0 to _stop_multiplier - # to ensure balanced exploration on both sides of the original amax (1.0) - steps_first_half = self._num_steps // 2 + 1 # Include 1.0 - steps_second_half = self._num_steps - self._num_steps // 2 # For second range - multipliers = torch.cat( - [ - torch.linspace(self._start_multiplier, 1.0, steps=steps_first_half, device=device), - torch.linspace(1.0, self._stop_multiplier, steps=steps_second_half, device=device)[ - 1: - ], # Skip duplicate 1.0 - ] + multipliers = torch.linspace( + self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device ) + print(f"Multipliers: {multipliers}") # Get reduce axis for per-channel quantization reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 3a43113df..54a3b2652 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -387,6 +387,112 @@ "algorithm": "max", } +NVFP4_WEIGHT_MAX_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + 
"*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": "max", +} + +NVFP4_WEIGHT_MSE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "mse", + "num_steps": 8, + "start_multiplier": 0.25, + "stop_multiplier": 2.0, + }, +} + +NVFP4_WEIGHT_MSE_4_6_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "enable": False, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "mse", + "num_steps": 8, + "start_multiplier": 0.375, + "stop_multiplier": 3.0, + }, +} + +NVFP4_WEIGHT_ACT_MSE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "mse", + "num_steps": 8, + "start_multiplier": 0.25, + "stop_multiplier": 2.0, + }, +} + +NVFP4_WEIGHT_ACT_MSE_4_6_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "mse", + "num_steps": 8, + "start_multiplier": 0.375, + "stop_multiplier": 3.0, + }, +} + NVFP4_AWQ_LITE_CFG = { "quant_cfg": { @@ -720,7 +826,7 @@ def validate_num_bits(self): if not all(x > 0 for x in num_bits): raise ValueError("num_bits must be a positive integer or a tuple of positive integers.") - block_sizes = self.block_sizes + # block_sizes = self.block_sizes if num_bits not in [ (4, 3), (5, 2), @@ -734,13 +840,13 @@ def validate_num_bits(self): raise ValueError( "Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1)." ) - elif num_bits != (4, 3) and ( - block_sizes is None or block_sizes.get("type", None) != "dynamic" - ): - raise ValueError( - "Only blockwise dynamic quantization is supported with quantization " - "formats E{num_bis[0]}M{num_bits[1]}." - ) + # elif num_bits != (4, 3) and ( + # block_sizes is None or block_sizes.get("type", None) != "dynamic" + # ): + # raise ValueError( + # "Only blockwise dynamic quantization is supported with quantization " + # "formats E{num_bis[0]}M{num_bits[1]}." + # ) return self axis: int | tuple[int, ...] 
| None = ModeloptField( diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 5ea7b9599..e71e68373 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -256,7 +256,7 @@ def quant_func(x, amax, quantizer=module): weight_amax = reduce_amax( x, axis=None, keepdims=False, squeeze_scalar=True ) - quantizer._amax = scaled_e4m3_impl(amax, weight_amax) + quantizer._amax = scaled_e4m3_impl(amax / 6.0, weight_amax / 6.0) * 6.0 if original_amax is not None: quantizer._amax = original_amax From 6bcf535f8f704038b22c1c3b5f67c1e71c393aaa Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Thu, 18 Dec 2025 00:23:51 +0000 Subject: [PATCH 04/10] clear up cached data after calibration, remove tmp configs Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/calib/mse.py | 26 ++++- modelopt/torch/quantization/config.py | 115 --------------------- modelopt/torch/quantization/model_calib.py | 7 ++ 3 files changed, 30 insertions(+), 118 deletions(-) diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 46cfaa670..4292f1ff4 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -79,10 +79,18 @@ def collect(self, x: torch.Tensor): x = x.detach().to(dtype=torch.float32) device = x.device - multipliers = torch.linspace( - self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device + # Split steps between _start_multiplier to 1.0 and 1.0 to _stop_multiplier + # to ensure balanced exploration on both sides of the original amax (1.0) + steps_first_half = self._num_steps // 2 + 1 # Include 1.0 + steps_second_half = self._num_steps - self._num_steps // 2 # For second range + multipliers = torch.cat( + [ + torch.linspace(self._start_multiplier, 1.0, steps=steps_first_half, device=device), + torch.linspace(1.0, self._stop_multiplier, steps=steps_second_half, device=device)[ + 1: + ], # Skip duplicate 1.0 + ] ) - print(f"Multipliers: {multipliers}") # Get reduce axis for per-channel quantization reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis) @@ -112,6 +120,18 @@ def reset(self): self._candidate_amaxs = [None] * self._num_steps self._amax = None + def clear(self): + """Clear all cached data to free GPU memory. + + Call this after compute_amax() and load_calib_amax() are done. + """ + self._losses_sum = [] + self._candidate_amaxs = [] + + if self._initial_amax is not None: + del self._initial_amax + self._initial_amax = None + @torch.no_grad() def compute_amax(self, verbose: bool = False): """Return the amax value that minimizes quantization error. 
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 54a3b2652..ad30d213e 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -387,113 +387,6 @@ "algorithm": "max", } -NVFP4_WEIGHT_MAX_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": "max", -} - -NVFP4_WEIGHT_MSE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "mse", - "num_steps": 8, - "start_multiplier": 0.25, - "stop_multiplier": 2.0, - }, -} - -NVFP4_WEIGHT_MSE_4_6_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "enable": False, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "mse", - "num_steps": 8, - "start_multiplier": 0.375, - "stop_multiplier": 3.0, - }, -} - -NVFP4_WEIGHT_ACT_MSE_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "mse", - "num_steps": 8, - "start_multiplier": 0.25, - "stop_multiplier": 2.0, - }, -} - -NVFP4_WEIGHT_ACT_MSE_4_6_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - **_default_disabled_quantizer_cfg, - }, - "algorithm": { - "method": "mse", - "num_steps": 8, - "start_multiplier": 0.375, - "stop_multiplier": 3.0, - }, -} - - NVFP4_AWQ_LITE_CFG = { "quant_cfg": { "*weight_quantizer": { @@ -826,7 +719,6 @@ def validate_num_bits(self): if not all(x > 0 for x in num_bits): raise ValueError("num_bits must be a positive integer or a tuple of positive integers.") - # block_sizes = self.block_sizes if num_bits not in [ (4, 3), (5, 2), @@ -840,13 +732,6 @@ def validate_num_bits(self): raise ValueError( "Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1)." ) - # elif num_bits != (4, 3) and ( - # block_sizes is None or block_sizes.get("type", None) != "dynamic" - # ): - # raise ValueError( - # "Only blockwise dynamic quantization is supported with quantization " - # "formats E{num_bis[0]}M{num_bits[1]}." - # ) return self axis: int | tuple[int, ...] 
| None = ModeloptField( diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index e71e68373..172107160 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -319,6 +319,13 @@ def quant_func(x, amax, quantizer=module): # Step 7: Compute optimal amax and load it for all quantizers (weights + activations) finish_stats_collection(model, method="mse") + # Step 8: Free GPU memory by clearing calibrator data + for name, module in model.named_modules(): + if isinstance(module, TensorQuantizer) and not module._disabled: + if hasattr(module, "_calibrator") and module._calibrator is not None: + if hasattr(module._calibrator, "clear"): + module._calibrator.clear() + # TODO: Sync amax across distributed processes From f19a829c1536c69b1c6408ac4de536c02e04e549 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Tue, 6 Jan 2026 22:34:27 +0000 Subject: [PATCH 05/10] add unit tests, address reviewer comments Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 6 -- .../nn/modules/tensor_quantizer.py | 11 +-- .../torch/quantization/triton/fp4_kernel.py | 78 +++---------------- .../torch/quantization/test_quantize_cuda.py | 20 +++++ .../quantization/test_tensor_quant_cuda.py | 42 ++++++++++ .../torch/quantization/test_quantize_cpu.py | 1 + 6 files changed, 75 insertions(+), 83 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 172107160..c1d787956 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -224,12 +224,6 @@ def mse_calibrate( for name, module in model.named_modules(): if isinstance(module, TensorQuantizer) and not module._disabled: - # Static block quantization is not supported by MseCalibrator - # if module.is_static_block_quant: - # raise ValueError( - # f"MSE calibration does not support static block quantization. " - # f"Found static block quantization at {name}." - # ) if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): # Get the initial amax from max calibration initial_amax = module._amax.clone().detach() diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 71695a634..ec6d44225 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -656,12 +656,10 @@ def _fake_quantize(self, inputs): ) elif self._num_bits == (2, 1) and self.is_static_block_quant: from modelopt.torch.quantization.triton.fp4_kernel import ( - launch_static_blockwise_fp4_fake_quant, + static_blockwise_fp4_fake_quant, ) - outputs = launch_static_blockwise_fp4_fake_quant( - inputs, amax / 6.0, out_dtype=inputs.dtype - ) + outputs = static_blockwise_fp4_fake_quant(inputs, amax / 6.0, out_dtype=inputs.dtype) elif isinstance(self._num_bits, tuple): # Float-point quantization, e.g., FP8 E, M = self._num_bits # noqa: N806 @@ -792,11 +790,6 @@ def _process_for_blockquant(self, inputs: torch.Tensor): if hasattr(self, "_padding"): inputs = F.pad(inputs, self._padding, "constant", 0) - # if inputs.shape != self._original_shape: - # print( - # f"Input shape has changed from {self._original_shape} to {inputs.shape}." - # " Block-quantization requires a fixed input shape." 
- # ) inputs = inputs.reshape(self._block_reshape_size) return inputs diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py index 6049882e4..8977341f0 100644 --- a/modelopt/torch/quantization/triton/fp4_kernel.py +++ b/modelopt/torch/quantization/triton/fp4_kernel.py @@ -24,7 +24,7 @@ import triton import triton.language as tl -__all__ = ["fp4_fake_quant_block"] +__all__ = ["fp4_fake_quant_block", "static_blockwise_fp4_fake_quant"] _TORCH_TO_TL_DTYPE = { @@ -399,24 +399,28 @@ def static_blockwise_fp4_fake_quant_kernel( ) x_rescaled = q_val * scale_safe - x_dequant = tl.where(x >= 0, x_rescaled, -x_rescaled) + x_quant = tl.where(x >= 0, x_rescaled, -x_rescaled) - tl.store(y_ptr + idx, x_dequant.to(OUT_DTYPE)) + tl.store(y_ptr + idx, x_quant.to(OUT_DTYPE)) -def launch_static_blockwise_fp4_fake_quant( +def static_blockwise_fp4_fake_quant( x: torch.Tensor, scale: torch.Tensor, - out_dtype: torch.dtype = torch.float16, + out_dtype: torch.dtype | None = None, ): - """Launch Triton kernel for blockwise FP4 fake quantization. + """Static blockwise FP4 fake quantization using Triton kernel. x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA. scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA. + out_dtype: Output dtype. Defaults to x.dtype if None. """ assert x.ndim == 2 NUM_FP4_BLOCKS, BLOCK_SIZE = x.shape + if out_dtype is None: + out_dtype = x.dtype + x_flat = x.contiguous().view(-1) y_flat = torch.empty_like(x_flat, dtype=out_dtype) scale_flat = scale.view(NUM_FP4_BLOCKS).contiguous() @@ -437,65 +441,3 @@ def launch_static_blockwise_fp4_fake_quant( ) return y_flat.view_as(x) - - -def blockwise_fp4_fake_quant_reference( - x: torch.Tensor, - scale: torch.Tensor, - out_dtype: torch.dtype = torch.bfloat16, -) -> torch.Tensor: - """Reference implementation of blockwise FP4 fake quantization. - - x: [NUM_FP4_BLOCKS, BLOCK_SIZE]. - scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1]. - - Uses FP4 quantization levels: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0. 
- """ - assert x.ndim == 2 - num_blocks, block_size = x.shape - - if scale.ndim == 1: - scale = scale.view(num_blocks, 1) - assert scale.shape == (num_blocks, 1) - - x_f = x.to(torch.float32) - s_f = scale.to(torch.float32) - - s_f = torch.where(s_f >= 1e-5, s_f, torch.ones_like(s_f)) - - x_abs = torch.abs(x_f) - abs_scaled = x_abs / s_f - - q_val = torch.where( - abs_scaled <= 0.25, - torch.zeros_like(abs_scaled), - torch.where( - abs_scaled < 0.75, - torch.full_like(abs_scaled, 0.5), - torch.where( - abs_scaled <= 1.25, - torch.ones_like(abs_scaled), - torch.where( - abs_scaled < 1.75, - torch.full_like(abs_scaled, 1.5), - torch.where( - abs_scaled <= 2.5, - torch.full_like(abs_scaled, 2.0), - torch.where( - abs_scaled < 3.5, - torch.full_like(abs_scaled, 3.0), - torch.where( - abs_scaled <= 5.0, - torch.full_like(abs_scaled, 4.0), - torch.full_like(abs_scaled, 6.0), - ), - ), - ), - ), - ), - ), - ) - - x_rescaled = q_val * s_f - x_dequant = torch.where(x_f >= 0, x_rescaled, -x_rescaled) - return x_dequant.to(out_dtype) diff --git a/tests/gpu/torch/quantization/test_quantize_cuda.py b/tests/gpu/torch/quantization/test_quantize_cuda.py index 811e0be81..9d82c1082 100644 --- a/tests/gpu/torch/quantization/test_quantize_cuda.py +++ b/tests/gpu/torch/quantization/test_quantize_cuda.py @@ -28,6 +28,24 @@ import modelopt.torch.quantization as mtq from modelopt.torch.quantization.extensions import get_cuda_ext_mx +NVFP4_WEIGHT_ACT_MSE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + }, + "algorithm": "mse", +} + @pytest.mark.parametrize("model_cls", [SimpleLinear, SimpleConv, SimpleConvLinear]) @pytest.mark.parametrize( @@ -52,6 +70,7 @@ mtq.MXINT8_DEFAULT_CFG, mtq.NVFP4_KV_ROTATE_CFG, mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, + NVFP4_WEIGHT_ACT_MSE_CFG, ], ) def test_quantize(model_cls, config): @@ -68,6 +87,7 @@ def test_quantize(model_cls, config): mtq.MXINT8_DEFAULT_CFG, mtq.NVFP4_KV_ROTATE_CFG, mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, + NVFP4_WEIGHT_ACT_MSE_CFG, ]: if get_cuda_ext_mx() is None: pytest.skip("cuda_ext_mx is not available") diff --git a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py index 83becee41..af57e7bb5 100644 --- a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py +++ b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py @@ -221,3 +221,45 @@ def _test_fp4_kernel(test_in, test_out, skip_triton=False): test_in *= sign test_out = torch.tensor([[0.5, 1, 1.5, 2, 3, 4, 6, 6]]).cuda() * sign _test_fp4_kernel(test_in, test_out) + + @pytest.mark.skipif(not triton_kernel.IS_AVAILABLE, reason="triton kernel is not available") + @pytest.mark.parametrize( + "set_torch_dtype", [torch.float, torch.float16, torch.bfloat16], indirect=True + ) + @pytest.mark.parametrize("block_size", [8, 16, 32]) + def test_static_blockwise_fp4(self, set_torch_dtype, block_size): + # Test with e2m1 table values + sign = torch.randint(0, 2, (1, 8)).cuda() * 2 - 1 + + def _get_test_inputs_outputs(test_in, test_out, num_blocks=4): + return torch.concat((test_in,) * (block_size // 8), dim=-1).repeat( + num_blocks, 1 + ), torch.concat((test_out,) * (block_size // 8), dim=-1).repeat(num_blocks, 1) + + def _test_static_fp4_kernel(test_in, test_out, scale_value=1.0): 
+ inputs, expected_outputs = _get_test_inputs_outputs(test_in, test_out) + num_blocks = inputs.shape[0] + scales = torch.full((num_blocks,), scale_value, device=inputs.device) + + quantized_outputs_triton = triton_kernel.static_blockwise_fp4_fake_quant(inputs, scales) + assert torch.allclose(quantized_outputs_triton, expected_outputs) + + test_in = torch.tensor([[0, 0.5, 1, 1.5, 2, 3, 4, 6]]).cuda() * sign + test_out = torch.tensor([[0, 0.5, 1, 1.5, 2, 3, 4, 6]]).cuda() * sign + _test_static_fp4_kernel(test_in, test_out) + + # Test slightly below the e2m1 boundary values. + # Numbers should be quantized down to the corresponding e2m1 value. + test_in = torch.tensor([[0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5, 6]]).cuda() + test_in[:, :-1] -= 0.1 + test_in *= sign + test_out = torch.tensor([[0.0, 0.5, 1, 1.5, 2, 3, 4, 6]]).cuda() * sign + _test_static_fp4_kernel(test_in, test_out) + + # Test slightly above the e2m1 boundary values. + # Numbers should be quantized up to the corresponding e2m1 value. + test_in = torch.tensor([[0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5, 6]]).cuda() + test_in[:, :-1] += 0.1 + test_in *= sign + test_out = torch.tensor([[0.5, 1, 1.5, 2, 3, 4, 6, 6]]).cuda() * sign + _test_static_fp4_kernel(test_in, test_out) diff --git a/tests/unit/torch/quantization/test_quantize_cpu.py b/tests/unit/torch/quantization/test_quantize_cpu.py index 85fa07fa4..f5cd52c84 100644 --- a/tests/unit/torch/quantization/test_quantize_cpu.py +++ b/tests/unit/torch/quantization/test_quantize_cpu.py @@ -45,6 +45,7 @@ "algorithm": "awq_lite", } +# Test configs for per channel MSE calibration INT8_MSE_CFG = { "quant_cfg": { "*weight_quantizer": {"num_bits": 8, "axis": 0}, From e1d7ff57cf7898ae1d8187738cae5d026e9aba44 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Wed, 7 Jan 2026 00:54:09 +0000 Subject: [PATCH 06/10] fix gpu unit test Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/config.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index ad30d213e..a8261fc06 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -719,6 +719,7 @@ def validate_num_bits(self): if not all(x > 0 for x in num_bits): raise ValueError("num_bits must be a positive integer or a tuple of positive integers.") + block_sizes = self.block_sizes if num_bits not in [ (4, 3), (5, 2), @@ -732,6 +733,13 @@ def validate_num_bits(self): raise ValueError( "Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1)." ) + elif num_bits not in [(4, 3), (2, 1)] and ( + block_sizes is None or block_sizes.get("type", None) != "dynamic" + ): + raise ValueError( + "Only blockwise dynamic quantization is supported with quantization " + "formats E{num_bis[0]}M{num_bits[1]}." + ) return self axis: int | tuple[int, ...] 
| None = ModeloptField( From 56f31df4707e91d903e8c11d71c949cf74f1afe4 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 9 Jan 2026 23:51:38 +0000 Subject: [PATCH 07/10] use step_size instead of num_steps; move FP8 scale quant into FP4 kernel launch func Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/calib/mse.py | 28 +++++------ modelopt/torch/quantization/config.py | 35 ++++++++++++-- modelopt/torch/quantization/model_calib.py | 19 ++------ .../nn/modules/tensor_quantizer.py | 8 +++- .../torch/quantization/triton/fp4_kernel.py | 20 ++++++-- .../quantization/test_tensor_quant_cuda.py | 46 +++++++++++-------- 6 files changed, 96 insertions(+), 60 deletions(-) diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 4292f1ff4..c94b7d716 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -15,6 +15,7 @@ """Calibrator that returns the MSE amax of all collected tensors.""" +import math from collections.abc import Callable import torch @@ -33,7 +34,7 @@ def __init__( self, amax: torch.Tensor, axis: int | tuple | list | None = None, - num_steps: int = 10, + step_size: float = 0.1, start_multiplier: float = 0.25, stop_multiplier: float = 4.0, quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, @@ -44,7 +45,8 @@ def __init__( Args: amax: Initial amax value (required). axis: Quantization axis. None means per-tensor quantization. - num_steps: Number of amax candidates to try. + step_size: Step size for amax search. The number of steps is computed as + ceil((stop_multiplier - start_multiplier) / step_size) + 1. start_multiplier: Starting multiplier for amax search. stop_multiplier: Ending multiplier for amax search. quant_func: Function that quantizes input tensor given an amax value. 
@@ -54,13 +56,15 @@ def __init__( """ super().__init__(num_bits=None, axis=axis, unsigned=None) self._initial_amax = amax - self._num_steps = num_steps + self._step_size = step_size self._start_multiplier = start_multiplier self._stop_multiplier = stop_multiplier + self._num_steps = math.ceil((stop_multiplier - start_multiplier) / step_size) + 1 + self._quant_func = quant_func self._error_func = error_func - self._losses_sum = [None] * num_steps - self._candidate_amaxs = [None] * num_steps + self._losses_sum = [None] * self._num_steps + self._candidate_amaxs = [None] * self._num_steps self._amax = None @@ -79,19 +83,9 @@ def collect(self, x: torch.Tensor): x = x.detach().to(dtype=torch.float32) device = x.device - # Split steps between _start_multiplier to 1.0 and 1.0 to _stop_multiplier - # to ensure balanced exploration on both sides of the original amax (1.0) - steps_first_half = self._num_steps // 2 + 1 # Include 1.0 - steps_second_half = self._num_steps - self._num_steps // 2 # For second range - multipliers = torch.cat( - [ - torch.linspace(self._start_multiplier, 1.0, steps=steps_first_half, device=device), - torch.linspace(1.0, self._stop_multiplier, steps=steps_second_half, device=device)[ - 1: - ], # Skip duplicate 1.0 - ] + multipliers = torch.linspace( + self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device ) - # Get reduce axis for per-channel quantization reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index a8261fc06..3c9f937a7 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -387,6 +387,30 @@ "algorithm": "max", } +NVFP4_WEIGHT_ACT_MSE_CFG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + "*input_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "axis": None, + "enable": True, + }, + **_default_disabled_quantizer_cfg, + }, + "algorithm": { + "method": "mse", + "step_size": 0.25, + "start_multiplier": 0.25, + "stop_multiplier": 2.0, + }, +} + NVFP4_AWQ_LITE_CFG = { "quant_cfg": { "*weight_quantizer": { @@ -991,11 +1015,12 @@ class MseCalibConfig(QuantizeAlgorithmConfig): method: Literal["mse"] = ModeloptField("mse") - num_steps: int | None = ModeloptField( - default=10, - ge=1, - title="Number of amax candidates to try.", - description="Number of amax candidates to search over for MSE minimization.", + step_size: float | None = ModeloptField( + default=0.1, + gt=0.0, + title="Step size for amax search.", + description="Step size between amax candidates. 
The number of candidates is computed as " + "ceil((stop_multiplier - start_multiplier) / step_size) + 1.", ) start_multiplier: float | None = ModeloptField( diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index c1d787956..09ee09925 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -32,7 +32,6 @@ from .calib import MseCalibrator from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context from .nn import QuantModule, SequentialQuantizer, TensorQuantizer -from .tensor_quant import scaled_e4m3_impl from .utils import ( disable_calib, enable_fake_quant, @@ -42,7 +41,6 @@ is_quantized_linear, is_quantized_row_parallel_linear, quantizer_attr_names, - reduce_amax, weight_attr_names, ) @@ -192,7 +190,7 @@ def mse_calibrate( model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True, - num_steps: int = 10, + step_size: float = 0.1, start_multiplier: float = 0.25, stop_multiplier: float = 4.0, ): @@ -207,7 +205,7 @@ def mse_calibrate( forward_loop: A callable which takes the model as argument and forwards calibration data through the model. distributed_sync: Whether to sync amax across distributed processes. - num_steps: Number of amax candidates to try (default: 10). + step_size: Step size for amax search (default: 0.1). start_multiplier: Starting multiplier for amax search (default: 0.25). stop_multiplier: Ending multiplier for amax search (default: 4.0). @@ -241,17 +239,6 @@ def quant_func(x, amax, quantizer=module): xq = quantizer(x) quantizer._keep_shape = False - # FP8 quantization of NVFP4 static per-block scales - if ( - quantizer.is_static_block_quant - and quantizer._num_bits == (2, 1) - and quantizer._block_sizes.get("scale_bits") == (4, 3) - ): - weight_amax = reduce_amax( - x, axis=None, keepdims=False, squeeze_scalar=True - ) - quantizer._amax = scaled_e4m3_impl(amax / 6.0, weight_amax / 6.0) * 6.0 - if original_amax is not None: quantizer._amax = original_amax else: @@ -263,7 +250,7 @@ def quant_func(x, amax, quantizer=module): module._calibrator = MseCalibrator( amax=initial_amax, axis=module._calibrator._axis, - num_steps=num_steps, + step_size=step_size, start_multiplier=start_multiplier, stop_multiplier=stop_multiplier, quant_func=quant_func, diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index ec6d44225..3cdc6a220 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -659,7 +659,13 @@ def _fake_quantize(self, inputs): static_blockwise_fp4_fake_quant, ) - outputs = static_blockwise_fp4_fake_quant(inputs, amax / 6.0, out_dtype=inputs.dtype) + outputs = static_blockwise_fp4_fake_quant( + inputs, + amax / 6.0, + scale_quant_amax=None, + skip_scale_quant=False, + out_dtype=inputs.dtype, + ) elif isinstance(self._num_bits, tuple): # Float-point quantization, e.g., FP8 E, M = self._num_bits # noqa: N806 diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py index 8977341f0..6d83f2f48 100644 --- a/modelopt/torch/quantization/triton/fp4_kernel.py +++ b/modelopt/torch/quantization/triton/fp4_kernel.py @@ -407,13 +407,18 @@ def static_blockwise_fp4_fake_quant_kernel( def static_blockwise_fp4_fake_quant( x: torch.Tensor, scale: torch.Tensor, + scale_quant_amax: torch.Tensor | None = None, + 
skip_scale_quant: bool = False, out_dtype: torch.dtype | None = None, ): """Static blockwise FP4 fake quantization using Triton kernel. - x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA. - scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA. - out_dtype: Output dtype. Defaults to x.dtype if None. + Args: + x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA. + scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA. + scale_quant_amax: Absolute max range for FP8 quantization of scale. If None, computed from scale. + skip_scale_quant: If True, skip FP8 quantization of scale. + out_dtype: Output dtype. Defaults to x.dtype if None. """ assert x.ndim == 2 NUM_FP4_BLOCKS, BLOCK_SIZE = x.shape @@ -421,6 +426,15 @@ def static_blockwise_fp4_fake_quant( if out_dtype is None: out_dtype = x.dtype + if not skip_scale_quant: + from modelopt.torch.quantization.tensor_quant import scaled_e4m3_impl + from modelopt.torch.quantization.utils import reduce_amax + + if scale_quant_amax is None: + scale_quant_amax = reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True) + + scale = scaled_e4m3_impl(scale, scale_quant_amax) + x_flat = x.contiguous().view(-1) y_flat = torch.empty_like(x_flat, dtype=out_dtype) scale_flat = scale.view(NUM_FP4_BLOCKS).contiguous() diff --git a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py index af57e7bb5..cc66c93f1 100644 --- a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py +++ b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py @@ -227,7 +227,8 @@ def _test_fp4_kernel(test_in, test_out, skip_triton=False): "set_torch_dtype", [torch.float, torch.float16, torch.bfloat16], indirect=True ) @pytest.mark.parametrize("block_size", [8, 16, 32]) - def test_static_blockwise_fp4(self, set_torch_dtype, block_size): + @pytest.mark.parametrize("skip_scale_quant", [True, False]) + def test_static_blockwise_fp4(self, set_torch_dtype, block_size, skip_scale_quant): # Test with e2m1 table values sign = torch.randint(0, 2, (1, 8)).cuda() * 2 - 1 @@ -241,25 +242,34 @@ def _test_static_fp4_kernel(test_in, test_out, scale_value=1.0): num_blocks = inputs.shape[0] scales = torch.full((num_blocks,), scale_value, device=inputs.device) - quantized_outputs_triton = triton_kernel.static_blockwise_fp4_fake_quant(inputs, scales) - assert torch.allclose(quantized_outputs_triton, expected_outputs) + quantized_outputs_triton = triton_kernel.static_blockwise_fp4_fake_quant( + inputs, scales, skip_scale_quant=skip_scale_quant + ) + + # Only check exact values when skip_scale_quant=True + # When scale quantization is enabled, the scale changes slightly, affecting outputs + if skip_scale_quant: + assert torch.allclose(quantized_outputs_triton, expected_outputs, atol=1e-6) + else: + assert quantized_outputs_triton.shape == expected_outputs.shape test_in = torch.tensor([[0, 0.5, 1, 1.5, 2, 3, 4, 6]]).cuda() * sign test_out = torch.tensor([[0, 0.5, 1, 1.5, 2, 3, 4, 6]]).cuda() * sign _test_static_fp4_kernel(test_in, test_out) - # Test slightly below the e2m1 boundary values. - # Numbers should be quantized down to the corresponding e2m1 value. - test_in = torch.tensor([[0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5, 6]]).cuda() - test_in[:, :-1] -= 0.1 - test_in *= sign - test_out = torch.tensor([[0.0, 0.5, 1, 1.5, 2, 3, 4, 6]]).cuda() * sign - _test_static_fp4_kernel(test_in, test_out) - - # Test slightly above the e2m1 boundary values. - # Numbers should be quantized up to the corresponding e2m1 value. 
- test_in = torch.tensor([[0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5, 6]]).cuda() - test_in[:, :-1] += 0.1 - test_in *= sign - test_out = torch.tensor([[0.5, 1, 1.5, 2, 3, 4, 6, 6]]).cuda() * sign - _test_static_fp4_kernel(test_in, test_out) + if skip_scale_quant: + # Test slightly below the e2m1 boundary values. + # Numbers should be quantized down to the corresponding e2m1 value. + test_in = torch.tensor([[0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5, 6]]).cuda() + test_in[:, :-1] -= 0.1 + test_in *= sign + test_out = torch.tensor([[0.0, 0.5, 1, 1.5, 2, 3, 4, 6]]).cuda() * sign + _test_static_fp4_kernel(test_in, test_out) + + # Test slightly above the e2m1 boundary values. + # Numbers should be quantized up to the corresponding e2m1 value. + test_in = torch.tensor([[0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5, 6]]).cuda() + test_in[:, :-1] += 0.1 + test_in *= sign + test_out = torch.tensor([[0.5, 1, 1.5, 2, 3, 4, 6, 6]]).cuda() * sign + _test_static_fp4_kernel(test_in, test_out) From 918e8e4aafcf57ce34ffc1226bf77bbdd45b4666 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Sat, 10 Jan 2026 00:11:20 +0000 Subject: [PATCH 08/10] minor: fix calibrator test Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 2 +- .../torch/quantization/test_mse_calibrator.py | 20 +++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 09ee09925..5b522920c 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -264,7 +264,7 @@ def quant_func(x, amax, quantizer=module): weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer weight_quantizer = getattr(parent_module, weight_quantizer_name, None) if isinstance(weight_quantizer, TensorQuantizer) and not weight_quantizer._disabled: - if weight_quantizer._calibrator is not None: + if getattr(weight_quantizer, "_calibrator", None) is not None: weight_quantizers.append((parent_module, weight_name, weight_quantizer)) seen_modules.add(parent_module) diff --git a/tests/unit/torch/quantization/test_mse_calibrator.py b/tests/unit/torch/quantization/test_mse_calibrator.py index 26e7d52da..5e5546512 100644 --- a/tests/unit/torch/quantization/test_mse_calibrator.py +++ b/tests/unit/torch/quantization/test_mse_calibrator.py @@ -68,7 +68,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=20, + step_size=0.075, start_multiplier=0.1, stop_multiplier=1.5, quant_func=quant_func, @@ -115,7 +115,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=25, + step_size=0.045, start_multiplier=0.1, stop_multiplier=1.2, quant_func=quant_func, @@ -162,7 +162,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=50, + step_size=0.008, start_multiplier=0.8, stop_multiplier=1.2, quant_func=quant_func, @@ -214,7 +214,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=20, + step_size=0.075, start_multiplier=0.1, stop_multiplier=1.5, quant_func=quant_func, @@ -265,7 +265,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, - num_steps=15, + step_size=0.07, start_multiplier=0.5, stop_multiplier=1.5, quant_func=quant_func, @@ -307,7 +307,7 @@ def quant_func(x, amax): tq._if_calib = was_calib_enabled return xq - cal = calib.MseCalibrator(amax=initial_amax, 
num_steps=10, quant_func=quant_func) + cal = calib.MseCalibrator(amax=initial_amax, step_size=0.4, quant_func=quant_func) cal.collect(x) @@ -352,7 +352,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, axis=0, - num_steps=10, + step_size=0.15, start_multiplier=0.5, stop_multiplier=2.0, quant_func=quant_func, @@ -398,7 +398,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, axis=0, - num_steps=15, + step_size=0.1, start_multiplier=0.5, stop_multiplier=2.0, quant_func=quant_func, @@ -458,7 +458,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, axis=0, - num_steps=20, + step_size=0.05, start_multiplier=0.5, stop_multiplier=1.5, quant_func=quant_func, @@ -511,7 +511,7 @@ def quant_func(x, amax): cal = calib.MseCalibrator( amax=initial_amax, axis=0, - num_steps=10, + step_size=0.15, start_multiplier=0.5, stop_multiplier=2.0, quant_func=quant_func, From 6ef3dcc88378cd49799b299d8c057fe8c0f8e6ea Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Tue, 13 Jan 2026 05:45:09 +0000 Subject: [PATCH 09/10] wrap FP4 kernel with torch autograd; move shape convert logic to mse_calibrate Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 10 ++++--- .../nn/modules/tensor_quantizer.py | 26 +++++++++------- modelopt/torch/quantization/tensor_quant.py | 30 +++++++++++++++++++ .../torch/quantization/triton/fp4_kernel.py | 10 +++---- 4 files changed, 57 insertions(+), 19 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 5b522920c..69a6776b1 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -235,9 +235,11 @@ def quant_func(x, amax, quantizer=module): disable_calib(quantizer), enable_fake_quant(quantizer), ): - quantizer._keep_shape = True + if hasattr(quantizer, "_original_shape"): + x = quantizer._reset_to_original_shape(x) xq = quantizer(x) - quantizer._keep_shape = False + if hasattr(quantizer, "_block_reshape_size"): + xq = xq.reshape(quantizer._block_reshape_size) if original_amax is not None: quantizer._amax = original_amax @@ -263,7 +265,7 @@ def quant_func(x, amax, quantizer=module): for weight_name in weight_attr_names(parent_module): weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer weight_quantizer = getattr(parent_module, weight_quantizer_name, None) - if isinstance(weight_quantizer, TensorQuantizer) and not weight_quantizer._disabled: + if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled: if getattr(weight_quantizer, "_calibrator", None) is not None: weight_quantizers.append((parent_module, weight_name, weight_quantizer)) seen_modules.add(parent_module) @@ -303,7 +305,7 @@ def quant_func(x, amax, quantizer=module): # Step 8: Free GPU memory by clearing calibrator data for name, module in model.named_modules(): if isinstance(module, TensorQuantizer) and not module._disabled: - if hasattr(module, "_calibrator") and module._calibrator is not None: + if hasattr(module, "_calibrator") and getattr(module, "_calibrator", None) is not None: if hasattr(module._calibrator, "clear"): module._calibrator.clear() diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 3cdc6a220..6d0debe41 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ 
b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -52,7 +52,12 @@ NVFP4QTensor, QTensorWrapper, ) -from ...tensor_quant import dynamic_block_quant, fake_tensor_quant, scaled_e4m3 +from ...tensor_quant import ( + dynamic_block_quant, + fake_tensor_quant, + scaled_e4m3, + static_blockwise_fp4_fake_quant, +) from ...utils import is_torch_export_mode from ..functional import normalized_hadamard_transform @@ -128,7 +133,6 @@ def __init__( self._enable_pre_quant_scale = True self._dequantize = False self._input_dtype = None - self._keep_shape = False # Lazy initialize the bias calibrator for KV cache quantization self._bias_calibrator = None @@ -655,16 +659,13 @@ def _fake_quantize(self, inputs): self._pass_through_bwd, ) elif self._num_bits == (2, 1) and self.is_static_block_quant: - from modelopt.torch.quantization.triton.fp4_kernel import ( - static_blockwise_fp4_fake_quant, - ) - outputs = static_blockwise_fp4_fake_quant( inputs, amax / 6.0, - scale_quant_amax=None, - skip_scale_quant=False, - out_dtype=inputs.dtype, + None, # scale_fp8_quant_amax + False, # skip_scale_quant + inputs.dtype, # out_dtype + self._pass_through_bwd, # pass_through_bwd ) elif isinstance(self._num_bits, tuple): # Float-point quantization, e.g., FP8 @@ -796,6 +797,11 @@ def _process_for_blockquant(self, inputs: torch.Tensor): if hasattr(self, "_padding"): inputs = F.pad(inputs, self._padding, "constant", 0) + if inputs.shape != self._original_shape: + raise ValueError( + f"Input shape has changed from {self._original_shape} to {inputs.shape}." + " Block-quantization requires a fixed input shape." + ) inputs = inputs.reshape(self._block_reshape_size) return inputs @@ -949,7 +955,7 @@ def forward(self, inputs): "This case should have been handled." ) - if self.is_static_block_quant and not self._keep_shape: + if self.is_static_block_quant: outputs = self._reset_to_original_shape(outputs) return outputs diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 5f69e3999..0c6f52470 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -571,6 +571,35 @@ def backward(ctx, grad_outputs): return _fake_quant_backward_function(ctx, grad_outputs, num_args=9) +class StaticBlockwiseFP4FakeQuantFunction(Function): + """Static blockwise FP4 fake quantization functional.""" + + @staticmethod + def forward( + ctx, + x, + scale, + scale_fp8_quant_amax, + skip_scale_quant, + out_dtype, + pass_through_bwd=False, + ): + """Forward method.""" + _save_for_backward_if_needed(ctx, pass_through_bwd, x, scale) + return triton_kernel.static_blockwise_fp4_fake_quant( + x, + scale, + scale_fp8_quant_amax, + skip_scale_quant, + out_dtype, + ) + + @staticmethod + def backward(ctx, grad_outputs): + """Implements straight through estimation with clipping.""" + return _fake_quant_backward_function(ctx, grad_outputs, num_args=6) + + def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True): """Shared function body between TensorQuantFunction and FakeTensorQuantFunction.""" # Fine scale, per channel scale will be handled by broadcasting, which could be tricky. Pop a warning. 
@@ -615,3 +644,4 @@ def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True): fake_tensor_quant = FakeTensorQuantFunction.apply scaled_e4m3 = ScaledE4M3Function.apply dynamic_block_quant = DynamicBlockQuantizationFunction.apply +static_blockwise_fp4_fake_quant = StaticBlockwiseFP4FakeQuantFunction.apply diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py index 6d83f2f48..55ceedf3b 100644 --- a/modelopt/torch/quantization/triton/fp4_kernel.py +++ b/modelopt/torch/quantization/triton/fp4_kernel.py @@ -407,7 +407,7 @@ def static_blockwise_fp4_fake_quant_kernel( def static_blockwise_fp4_fake_quant( x: torch.Tensor, scale: torch.Tensor, - scale_quant_amax: torch.Tensor | None = None, + scale_fp8_quant_amax: torch.Tensor | None = None, skip_scale_quant: bool = False, out_dtype: torch.dtype | None = None, ): @@ -416,7 +416,7 @@ def static_blockwise_fp4_fake_quant( Args: x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA. scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA. - scale_quant_amax: Absolute max range for FP8 quantization of scale. If None, computed from scale. + scale_fp8_quant_amax: Absolute max range for FP8 quantization of scale. If None, computed from scale. skip_scale_quant: If True, skip FP8 quantization of scale. out_dtype: Output dtype. Defaults to x.dtype if None. """ @@ -430,10 +430,10 @@ def static_blockwise_fp4_fake_quant( from modelopt.torch.quantization.tensor_quant import scaled_e4m3_impl from modelopt.torch.quantization.utils import reduce_amax - if scale_quant_amax is None: - scale_quant_amax = reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True) + if scale_fp8_quant_amax is None: + scale_fp8_quant_amax = reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True) - scale = scaled_e4m3_impl(scale, scale_quant_amax) + scale = scaled_e4m3_impl(scale, scale_fp8_quant_amax) x_flat = x.contiguous().view(-1) y_flat = torch.empty_like(x_flat, dtype=out_dtype) From 76bd9b9b5875b2a3c4f9dd0358c7b7de7b38f430 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Tue, 13 Jan 2026 20:05:51 +0000 Subject: [PATCH 10/10] add unit test to compare between static and dynamic fp4 kernels Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- .../torch/quantization/triton/fp4_kernel.py | 4 +- .../quantization/test_tensor_quant_cuda.py | 48 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/quantization/triton/fp4_kernel.py index 55ceedf3b..6d97d984c 100644 --- a/modelopt/torch/quantization/triton/fp4_kernel.py +++ b/modelopt/torch/quantization/triton/fp4_kernel.py @@ -431,7 +431,9 @@ def static_blockwise_fp4_fake_quant( from modelopt.torch.quantization.utils import reduce_amax if scale_fp8_quant_amax is None: - scale_fp8_quant_amax = reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True) + scale_fp8_quant_amax = reduce_amax( + scale, axis=None, keepdims=False, squeeze_scalar=True + ) scale = scaled_e4m3_impl(scale, scale_fp8_quant_amax) diff --git a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py index cc66c93f1..98b498bb6 100644 --- a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py +++ b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py @@ -273,3 +273,51 @@ def _test_static_fp4_kernel(test_in, test_out, scale_value=1.0): test_in *= sign test_out = 
torch.tensor([[0.5, 1, 1.5, 2, 3, 4, 6, 6]]).cuda() * sign _test_static_fp4_kernel(test_in, test_out) + + @pytest.mark.skipif(not triton_kernel.IS_AVAILABLE, reason="triton kernel is not available") + @pytest.mark.parametrize( + "set_torch_dtype", [torch.float, torch.float16, torch.bfloat16], indirect=True + ) + @pytest.mark.parametrize("block_size", [16, 32, 64]) + @pytest.mark.parametrize("num_blocks", [4, 8, 16]) + @pytest.mark.parametrize("use_explicit_amax", [False, True]) + def test_static_vs_dynamic_fp4_kernels( + self, set_torch_dtype, block_size, num_blocks, use_explicit_amax + ): + """Test that static kernel with computed scales matches dynamic kernel behavior. + + The dynamic kernel computes scales dynamically from block-wise max values with FP8 quantization. + This test verifies that the static kernel with pre-computed scales (matching dynamic kernel's logic) + produces the same results as the dynamic kernel. + + """ + torch.manual_seed(42) + + x = torch.randn(num_blocks, block_size, dtype=torch.float32).cuda() * 10 + block_amax = x.abs().max(dim=1, keepdim=False)[0] + global_amax = block_amax.max() + scales = block_amax / 6.0 + + if use_explicit_amax: + scale_fp8_quant_amax = global_amax / 6.0 + else: + scale_fp8_quant_amax = None + + output_static = triton_kernel.static_blockwise_fp4_fake_quant( + x, scales, scale_fp8_quant_amax=scale_fp8_quant_amax, skip_scale_quant=False + ) + output_dynamic = triton_kernel.fp4_fake_quant_block( + x, + global_amax=global_amax, + block_size=block_size, + tile_rows=num_blocks, + tile_cols=block_size, + ) + + amax_mode = "explicit" if use_explicit_amax else "automatic" + assert torch.allclose(output_static, output_dynamic, rtol=1e-3, atol=1e-5), ( + f"Static and dynamic kernels produced different outputs (scale_fp8_quant_amax={amax_mode}).\n" + f"Max abs diff: {(output_static - output_dynamic).abs().max()}\n" + f"Mean abs diff: {(output_static - output_dynamic).abs().mean()}\n" + f"Max relative diff: {((output_static - output_dynamic).abs() / (output_dynamic.abs() + 1e-8)).max()}" + )
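
Note from review (illustrative, not part of the patch series): the snippet below sketches how the autograd-wrapped static_blockwise_fp4_fake_quant exported from modelopt.torch.quantization.tensor_quant in PATCH 09 might be driven outside the test suite, using the same block_amax / 6.0 scale convention as the new unit test above. The example sizes, the gradient check, and the comments are assumptions for illustration only; it presumes a CUDA device with the Triton kernel available (triton_kernel.IS_AVAILABLE).

    import torch

    from modelopt.torch.quantization.tensor_quant import static_blockwise_fp4_fake_quant

    # Hypothetical sizes: inputs are expected as [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA.
    num_blocks, block_size = 8, 16
    x = torch.randn(num_blocks, block_size, device="cuda", requires_grad=True)

    # Per-block scales follow the block_amax / 6.0 convention used in the test above
    # (6.0 being the largest magnitude in the FP4 quantization level set).
    block_amax = x.detach().abs().amax(dim=1)
    scales = block_amax / 6.0

    # Positional arguments mirror StaticBlockwiseFP4FakeQuantFunction.forward:
    # (x, scale, scale_fp8_quant_amax, skip_scale_quant, out_dtype, pass_through_bwd).
    # Passing None for scale_fp8_quant_amax lets the kernel wrapper derive the FP8
    # amax from the scales themselves, as in the "automatic" branch of the test.
    xq = static_blockwise_fp4_fake_quant(x, scales, None, False, x.dtype, False)

    # Gradients flow through the Function's backward (straight-through estimation,
    # per its docstring), so the fake-quantized output can feed training code.
    xq.sum().backward()
    assert x.grad is not None and x.grad.shape == x.shape

Design note: the call site added in tensor_quantizer.py passes all six arguments positionally with inline comments ("# scale_fp8_quant_amax", "# skip_scale_quant", ...) documenting the mapping; the sketch above keeps the same ordering.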