27 changes: 21 additions & 6 deletions modelopt/torch/quantization/calib/mse.py
@@ -15,6 +15,7 @@

"""Calibrator that returns the MSE amax of all collected tensors."""

import math
from collections.abc import Callable

import torch
@@ -33,7 +34,7 @@ def __init__(
self,
amax: torch.Tensor,
axis: int | tuple | list | None = None,
num_steps: int = 10,
step_size: float = 0.1,
start_multiplier: float = 0.25,
stop_multiplier: float = 4.0,
quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
@@ -44,7 +45,8 @@
Args:
amax: Initial amax value (required).
axis: Quantization axis. None means per-tensor quantization.
num_steps: Number of amax candidates to try.
step_size: Step size for amax search. The number of steps is computed as
ceil((stop_multiplier - start_multiplier) / step_size) + 1.
start_multiplier: Starting multiplier for amax search.
stop_multiplier: Ending multiplier for amax search.
quant_func: Function that quantizes input tensor given an amax value.
@@ -54,13 +56,15 @@
"""
super().__init__(num_bits=None, axis=axis, unsigned=None)
self._initial_amax = amax
self._num_steps = num_steps
self._step_size = step_size
self._start_multiplier = start_multiplier
self._stop_multiplier = stop_multiplier
self._num_steps = math.ceil((stop_multiplier - start_multiplier) / step_size) + 1

self._quant_func = quant_func
self._error_func = error_func
self._losses_sum = [None] * num_steps
self._candidate_amaxs = [None] * num_steps
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps

self._amax = None

@@ -82,7 +86,6 @@ def collect(self, x: torch.Tensor):
multipliers = torch.linspace(
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
)

# Get reduce axis for per-channel quantization
reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)

@@ -111,6 +114,18 @@ def reset(self):
self._candidate_amaxs = [None] * self._num_steps
self._amax = None

def clear(self):
"""Clear all cached data to free GPU memory.

Call this after compute_amax() and load_calib_amax() are done.
"""
self._losses_sum = []
self._candidate_amaxs = []

if self._initial_amax is not None:
del self._initial_amax
self._initial_amax = None

@torch.no_grad()
def compute_amax(self, verbose: bool = False):
"""Return the amax value that minimizes quantization error.
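For context, a minimal sketch (not part of the diff) of how the candidate grid is now derived from step_size instead of a fixed num_steps; the defaults and the linspace call mirror MseCalibrator above.

import math

import torch

# New defaults from MseCalibrator.__init__: step_size=0.1, start=0.25, stop=4.0.
step_size, start_multiplier, stop_multiplier = 0.1, 0.25, 4.0

# Number of candidates, as computed in __init__.
num_steps = math.ceil((stop_multiplier - start_multiplier) / step_size) + 1  # 39

# Candidate multipliers, matching the torch.linspace call in collect();
# each candidate amax is initial_amax * multiplier.
multipliers = torch.linspace(start_multiplier, stop_multiplier, steps=num_steps)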
36 changes: 30 additions & 6 deletions modelopt/torch/quantization/config.py
@@ -387,6 +387,29 @@
"algorithm": "max",
}

NVFP4_WEIGHT_ACT_MSE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "mse",
"step_size": 0.25,
"start_multiplier": 0.25,
"stop_multiplier": 2.0,
},
}
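A hedged usage sketch for the new config (not part of the diff). It assumes NVFP4_WEIGHT_ACT_MSE_CFG is exported at the package level like the existing NVFP4 configs, and that calib_loader is a user-supplied calibration dataloader.

import modelopt.torch.quantization as mtq

def forward_loop(model):
    # Push a small amount of calibration data through the model.
    for batch in calib_loader:  # hypothetical dataloader
        model(batch)

# The MSE search parameters (step_size, start/stop multipliers) come from the
# "algorithm" section of NVFP4_WEIGHT_ACT_MSE_CFG above.
model = mtq.quantize(model, mtq.NVFP4_WEIGHT_ACT_MSE_CFG, forward_loop)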

NVFP4_AWQ_LITE_CFG = {
"quant_cfg": {
@@ -734,7 +757,7 @@ def validate_num_bits(self):
raise ValueError(
"Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1)."
)
elif num_bits != (4, 3) and (
elif num_bits not in [(4, 3), (2, 1)] and (
block_sizes is None or block_sizes.get("type", None) != "dynamic"
):
raise ValueError(
@@ -992,11 +1015,12 @@ class MseCalibConfig(QuantizeAlgorithmConfig):

method: Literal["mse"] = ModeloptField("mse")

num_steps: int | None = ModeloptField(
default=10,
ge=1,
title="Number of amax candidates to try.",
description="Number of amax candidates to search over for MSE minimization.",
step_size: float | None = ModeloptField(
default=0.1,
gt=0.0,
title="Step size for amax search.",
description="Step size between amax candidates. The number of candidates is computed as "
"ceil((stop_multiplier - start_multiplier) / step_size) + 1.",
)

start_multiplier: float | None = ModeloptField(
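For reference, a sketch (assuming the "algorithm" dict above is validated by MseCalibConfig) of how the NVFP4 MSE values map onto the new field:

from modelopt.torch.quantization.config import MseCalibConfig

# Values from NVFP4_WEIGHT_ACT_MSE_CFG["algorithm"].
cfg = MseCalibConfig(step_size=0.25, start_multiplier=0.25, stop_multiplier=2.0)
# Yields ceil((2.0 - 0.25) / 0.25) + 1 = 8 amax candidates per quantizer.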
67 changes: 55 additions & 12 deletions modelopt/torch/quantization/model_calib.py
@@ -190,7 +190,7 @@ def mse_calibrate(
model: nn.Module,
forward_loop: ForwardLoop | None = None,
distributed_sync=True,
num_steps: int = 10,
step_size: float = 0.1,
start_multiplier: float = 0.25,
stop_multiplier: float = 4.0,
):
@@ -205,7 +205,7 @@
forward_loop: A callable which takes the model as argument and
forwards calibration data through the model.
distributed_sync: Whether to sync amax across distributed processes.
num_steps: Number of amax candidates to try (default: 10).
step_size: Step size for amax search (default: 0.1).
start_multiplier: Starting multiplier for amax search (default: 0.25).
stop_multiplier: Ending multiplier for amax search (default: 4.0).

@@ -216,14 +216,12 @@
max_calibrate(model, forward_loop, distributed_sync)

# Step 2: Replace calibrators with MseCalibrator for enabled quantizers
# and identify weight quantizers
weight_quantizers = []
seen_modules = set()

for name, module in model.named_modules():
if isinstance(module, TensorQuantizer) and not module._disabled:
# Static block quantization is not supported by MseCalibrator
if module.is_static_block_quant:
raise ValueError(
f"MSE calibration does not support static block quantization. "
f"Found static block quantization at {name}."
)
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
# Get the initial amax from max calibration
initial_amax = module._amax.clone().detach()
@@ -237,7 +235,11 @@ def quant_func(x, amax, quantizer=module):
disable_calib(quantizer),
enable_fake_quant(quantizer),
):
if hasattr(quantizer, "_original_shape"):
x = quantizer._reset_to_original_shape(x)
xq = quantizer(x)
if hasattr(quantizer, "_block_reshape_size"):
xq = xq.reshape(quantizer._block_reshape_size)

if original_amax is not None:
quantizer._amax = original_amax
@@ -250,22 +252,63 @@
module._calibrator = MseCalibrator(
amax=initial_amax,
axis=module._calibrator._axis,
num_steps=num_steps,
step_size=step_size,
start_multiplier=start_multiplier,
stop_multiplier=stop_multiplier,
quant_func=quant_func,
)

# Step 3: Collect data with MSE calibrators
# Identify weight quantizers by checking if they have corresponding weight parameters
for name, parent_module in model.named_modules():
if parent_module in seen_modules:
continue
for weight_name in weight_attr_names(parent_module):
weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
if getattr(weight_quantizer, "_calibrator", None) is not None:
weight_quantizers.append((parent_module, weight_name, weight_quantizer))
seen_modules.add(parent_module)

# Step 3: Calibrate weight quantizers once with MSE calibration
# This ensures weights are only calibrated once, not during every forward pass
for parent_module, weight_name, weight_quantizer in weight_quantizers:
# Enable calibration mode for the weight quantizer
weight_quantizer.disable_quant()
weight_quantizer.enable_calib()

with enable_weight_access_and_writeback(parent_module, model):
weight = getattr(parent_module, weight_name)
weight_quantizer(weight)

# Step 4: Disable weight quantizers during forward loop
for _, _, weight_quantizer in weight_quantizers:
weight_quantizer.disable()

# Step 5: Collect data with MSE calibrators for activation quantizers only
enable_stats_collection(model)
if forward_loop is None:
weight_only_quantize(model)
# If no forward loop, nothing else to do since weights are already calibrated
pass
else:
# Run forward loop - only activation quantizers will collect data
forward_loop(model)

# Step 4: Compute optimal amax and load it
# Step 6: Re-enable weight quantizers before finalizing calibration
# This ensures finish_stats_collection processes them correctly
for _, _, weight_quantizer in weight_quantizers:
weight_quantizer.enable()

# Step 7: Compute optimal amax and load it for all quantizers (weights + activations)
finish_stats_collection(model, method="mse")

# Step 8: Free GPU memory by clearing calibrator data
for name, module in model.named_modules():
if isinstance(module, TensorQuantizer) and not module._disabled:
if hasattr(module, "_calibrator") and getattr(module, "_calibrator", None) is not None:
if hasattr(module._calibrator, "clear"):
module._calibrator.clear()

# TODO: Sync amax across distributed processes


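A hedged sketch of calling the revised entry point directly with the new step_size argument; model and forward_loop are user-supplied, as described in the docstring above.

from modelopt.torch.quantization.model_calib import mse_calibrate

# Max calibration runs first internally; weight quantizers are then MSE-calibrated
# once, and the forward loop only feeds the activation quantizers.
mse_calibrate(
    model,
    forward_loop=forward_loop,
    step_size=0.1,
    start_multiplier=0.25,
    stop_multiplier=4.0,
)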
16 changes: 15 additions & 1 deletion modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -52,7 +52,12 @@
NVFP4QTensor,
QTensorWrapper,
)
from ...tensor_quant import dynamic_block_quant, fake_tensor_quant, scaled_e4m3
from ...tensor_quant import (
dynamic_block_quant,
fake_tensor_quant,
scaled_e4m3,
static_blockwise_fp4_fake_quant,
)
from ...utils import is_torch_export_mode
from ..functional import normalized_hadamard_transform

@@ -653,6 +658,15 @@ def _fake_quantize(self, inputs):
getattr(self, "_onnx_quantizer_type", None),
self._pass_through_bwd,
)
elif self._num_bits == (2, 1) and self.is_static_block_quant:
outputs = static_blockwise_fp4_fake_quant(
inputs,
amax / 6.0,
None, # scale_fp8_quant_amax
False, # skip_scale_quant
inputs.dtype, # out_dtype
self._pass_through_bwd, # pass_through_bwd
)
elif isinstance(self._num_bits, tuple):
# Floating-point quantization, e.g., FP8
E, M = self._num_bits # noqa: N806
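The amax / 6.0 passed above reflects that 6.0 is the largest magnitude representable in FP4 (E2M1). A small illustration of that convention, assuming the usual divide-by-scale-then-round scheme; the values are hypothetical.

import torch

amax = torch.tensor(3.0)   # hypothetical per-block amax from calibration
scale = amax / 6.0         # 0.5: maps the largest observed value onto FP4's max of 6.0
x = torch.tensor([3.0, -1.2, 0.4])
# Conceptually: round(x / scale) onto the E2M1 grid, then multiply back by scale.
print(x / scale)           # now lies within the representable FP4 range [-6, 6]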
30 changes: 30 additions & 0 deletions modelopt/torch/quantization/tensor_quant.py
@@ -571,6 +571,35 @@ def backward(ctx, grad_outputs):
return _fake_quant_backward_function(ctx, grad_outputs, num_args=9)


class StaticBlockwiseFP4FakeQuantFunction(Function):
"""Static blockwise FP4 fake quantization functional."""

@staticmethod
def forward(
ctx,
x,
scale,
scale_fp8_quant_amax,
skip_scale_quant,
out_dtype,
pass_through_bwd=False,
):
"""Forward method."""
_save_for_backward_if_needed(ctx, pass_through_bwd, x, scale)
return triton_kernel.static_blockwise_fp4_fake_quant(
x,
scale,
scale_fp8_quant_amax,
skip_scale_quant,
out_dtype,
)

@staticmethod
def backward(ctx, grad_outputs):
"""Implements straight through estimation with clipping."""
return _fake_quant_backward_function(ctx, grad_outputs, num_args=6)


def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
"""Shared function body between TensorQuantFunction and FakeTensorQuantFunction."""
# Fine scale, per channel scale will be handled by broadcasting, which could be tricky. Pop a warning.
@@ -615,3 +644,4 @@ def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
fake_tensor_quant = FakeTensorQuantFunction.apply
scaled_e4m3 = ScaledE4M3Function.apply
dynamic_block_quant = DynamicBlockQuantizationFunction.apply
static_blockwise_fp4_fake_quant = StaticBlockwiseFP4FakeQuantFunction.apply
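A hedged usage sketch of the new alias, mirroring the call site in _fake_quantize above; the per-tensor scale, the CUDA device, and the block layout expected by the underlying Triton kernel are assumptions, not part of the diff.

import torch

from modelopt.torch.quantization.tensor_quant import static_blockwise_fp4_fake_quant

x = torch.randn(128, 64, device="cuda", dtype=torch.float16)
amax = x.abs().amax()  # per-tensor amax purely for illustration
out = static_blockwise_fp4_fake_quant(
    x,
    amax / 6.0,  # scale
    None,        # scale_fp8_quant_amax
    False,       # skip_scale_quant
    x.dtype,     # out_dtype
    False,       # pass_through_bwd
)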