46 changes: 38 additions & 8 deletions modelopt/torch/quantization/calib/mse.py
@@ -39,20 +39,24 @@ def __init__(
stop_multiplier: float = 4.0,
quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
fp8_scale_sweep: bool = False,
):
"""Initialize MSE calibrator.

Args:
amax: Initial amax value (required).
axis: Quantization axis. None means per-tensor quantization.
step_size: Step size for amax search. The number of steps is computed as
ceil((stop_multiplier - start_multiplier) / step_size) + 1.
ceil((stop_multiplier - start_multiplier) / step_size) + 1.
start_multiplier: Starting multiplier for amax search.
stop_multiplier: Ending multiplier for amax search.
quant_func: Function that quantizes input tensor given an amax value.
Should have signature: quant_func(x, amax) -> quantized_x.
Should have signature: quant_func(x, amax) -> quantized_x.
error_func: Function to compute error between x and xq.
Default is F.mse_loss(x, xq, reduction='none').
Default is F.mse_loss(x, xq, reduction='none').
fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values
instead of using multipliers. This is specifically for NVFP4
per-block quantization where scales are stored in FP8 format.
"""
super().__init__(num_bits=None, axis=axis, unsigned=None)
self._initial_amax = amax
@@ -65,6 +69,13 @@ def __init__(
self._error_func = error_func
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps
self._fp8_scale_sweep = fp8_scale_sweep
if fp8_scale_sweep:
# For FP8 scale sweep, we always have exactly 126 valid FP8 E4M3 values
# (128 total - 2 invalid: byte 0 = zero, byte 127 = NaN)
self._num_steps = 126
self._losses_sum = [None] * self._num_steps
self._candidate_amaxs = [None] * self._num_steps

self._amax = None

@@ -83,14 +94,33 @@ def collect(self, x: torch.Tensor):
x = x.detach().to(dtype=torch.float32)

device = x.device
multipliers = torch.linspace(
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
)

if self._fp8_scale_sweep:
global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)
global_amax_expanded = global_amax * torch.ones_like(self._initial_amax)
Contributor comment:
nit: this should not be needed (since global_amax is a scalar).
Suggested change: remove
    global_amax_expanded = global_amax * torch.ones_like(self._initial_amax)

Author (@Fridah-nv) reply, Jan 16, 2026:
I believe this is needed: the candidate in the loop below is also a scalar, so candidate_amax = global_amax_expanded * candidate would be a scalar, but we need candidate_amax to have the same shape as initial_amax.

# Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn)
# Create uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()

# Filter out invalid values (NaN, inf, and zero) which aren't useful as multipliers
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
fp8_values_valid = fp8_values[valid_mask]

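# Note (added for clarity): 448.0 is the largest representable FP8 E4M3 value, so each
# candidate below is a valid FP8 scale normalized into (0, 1] before being multiplied
# by the global amax.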
candidates = fp8_values_valid / 448.0
else:
candidates = torch.linspace(
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
)
# Get reduce axis for per-channel quantization
reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)

for step, multiplier in enumerate(multipliers):
candidate_amax = self._initial_amax * multiplier
for step, candidate in enumerate(candidates):
if self._fp8_scale_sweep:
candidate_amax = global_amax_expanded * candidate
Contributor comment:
Where is the /6.0? Shouldn't this be:
Suggested change: replace
    candidate_amax = global_amax_expanded * candidate
with
    candidate_amax = (global_amax / 6.0) * candidate_by_448

Contributor follow-up:
My understanding is that this is handled somewhere else. Is that correct?

Contributor follow-up:
Found it:
    outputs = static_blockwise_fp4_fake_quant(
Should we support static_blockwise_fp4_fake_quant taking either the scale or the amax? Then we wouldn't need to handle the amax/6.0 detail in the tensor quantizer.

else:
candidate_amax = self._initial_amax * candidate
xq = self._quant_func(x, candidate_amax)

if self._error_func is not None:
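For readers following the new sweep end to end, here is a minimal, self-contained sketch of what the fp8_scale_sweep path computes for a single tensor. It assumes torch >= 2.1 (for torch.float8_e4m3fn); the fp8_sweep_amax helper and its quant_func argument are illustrative names only, and the real MseCalibrator additionally accumulates losses across calibration batches and supports per-channel reduction.

import torch
import torch.nn.functional as F

def fp8_sweep_amax(x: torch.Tensor, quant_func) -> torch.Tensor:
    """Return the candidate amax that minimizes MSE after fake quantization."""
    global_amax = x.abs().amax()
    # 126 valid positive FP8 E4M3 bit patterns (skip byte 0 = zero and byte 127 = NaN).
    uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=x.device)
    fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
    valid = torch.isfinite(fp8_values) & (fp8_values > 0)
    candidates = fp8_values[valid] / 448.0  # normalize by the FP8 E4M3 max

    best_amax, best_loss = None, float("inf")
    for c in candidates:
        candidate_amax = global_amax * c
        xq = quant_func(x, candidate_amax)   # fake quant/dequant at this amax
        loss = F.mse_loss(x, xq).item()      # s* = argmin_s E[(X - DQ(Q(X; s)))^2]
        if loss < best_loss:
            best_loss, best_amax = loss, candidate_amax
    return best_amax

The selection criterion is the same argmin over reconstruction error documented in MseCalibConfig in config.py below.
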
34 changes: 10 additions & 24 deletions modelopt/torch/quantization/config.py
@@ -387,30 +387,6 @@
"algorithm": "max",
}

NVFP4_WEIGHT_ACT_MSE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "mse",
"step_size": 0.25,
"start_multiplier": 0.25,
"stop_multiplier": 2.0,
},
}

NVFP4_AWQ_LITE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
@@ -1040,6 +1016,8 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
reconstruction error of a tensor after uniform Q→DQ:

s* = argmin_s E[(X - DQ(Q(X; s)))^2], X ∈ {weights | activations}

When fp8_scale_sweep is enabled, step_size, start_multiplier, and stop_multiplier are ignored.
"""

method: Literal["mse"] = ModeloptField("mse")
@@ -1066,6 +1044,14 @@
description="Ending multiplier for amax search range (multiplies initial amax).",
)

fp8_scale_sweep: bool | None = ModeloptField(
default=False,
title="Enable FP8 scale sweep for NVFP4 per-block quantization.",
description="If True, sweep all 128 FP8 E4M3 scale values instead of using multipliers. "
"Only applies to NVFP4 weight quantization. When enabled, num_steps, step_size, "
"start_multiplier, and stop_multiplier are ignored.",
)
Contributor comment:
Suggested change: append to the description before the closing parenthesis:
    "This is ignored for all quantizations except NVFP4 weight quantization.")


distributed_sync: bool | None = ModeloptField(
default=True,
title="Whether to sync the amax across the distributed processes.",
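As a usage sketch, the new field slots into a quantize config like the NVFP4 weight-only example below. It mirrors the NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG added in the test file at the end of this diff; the wildcard keys and values are taken from there, with explanatory comments added here.

# Sketch: MSE calibration with the FP8 scale sweep enabled for NVFP4 weights.
NVFP4_WEIGHT_MSE_SWEEP_EXAMPLE = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),                                               # FP4 (E2M1) values
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},  # FP8 (E4M3) block scales
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {"enable": False},
    },
    "algorithm": {
        "method": "mse",
        "fp8_scale_sweep": True,  # step_size / start_multiplier / stop_multiplier are ignored
    },
}
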
60 changes: 38 additions & 22 deletions modelopt/torch/quantization/model_calib.py
@@ -197,6 +197,30 @@ def sync_quantizer_amax_across_tp(
)


def _mse_quant_func(x, amax, quantizer):
"""Quantization function for MSE calibration."""
original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None
quantizer._amax = amax

with (
enable_quant(quantizer),
disable_calib(quantizer),
enable_fake_quant(quantizer),
):
if hasattr(quantizer, "_original_shape"):
x = quantizer._reset_to_original_shape(x)
xq = quantizer(x)
if hasattr(quantizer, "_block_reshape_size"):
xq = xq.reshape(quantizer._block_reshape_size)

if original_amax is not None:
quantizer._amax = original_amax
else:
delattr(quantizer, "_amax")

return xq


@torch.no_grad()
def mse_calibrate(
model: nn.Module,
@@ -205,6 +229,7 @@
step_size: float = 0.1,
start_multiplier: float = 0.25,
stop_multiplier: float = 4.0,
fp8_scale_sweep: bool = False,
):
"""Calibrate the model using MSE-based amax search.

@@ -220,6 +245,10 @@
step_size: Step size for amax search (default: 0.1).
start_multiplier: Starting multiplier for amax search (default: 0.25).
stop_multiplier: Ending multiplier for amax search (default: 4.0).
fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values
for NVFP4 per-block quantization instead of using multipliers.
This is specifically designed for optimizing the FP8-quantized
per-block scales in NVFP4 format (default: False).

See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
details on the remaining arguments.
@@ -238,27 +267,13 @@
# Get the initial amax from max calibration
initial_amax = module._amax.clone().detach()

def quant_func(x, amax, quantizer=module):
original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None
quantizer._amax = amax

with (
enable_quant(quantizer),
disable_calib(quantizer),
enable_fake_quant(quantizer),
):
if hasattr(quantizer, "_original_shape"):
x = quantizer._reset_to_original_shape(x)
xq = quantizer(x)
if hasattr(quantizer, "_block_reshape_size"):
xq = xq.reshape(quantizer._block_reshape_size)

if original_amax is not None:
quantizer._amax = original_amax
else:
delattr(quantizer, "_amax")

return xq
is_nvfp4_static = (
fp8_scale_sweep
and module.is_static_block_quant
and module._num_bits == (2, 1)
and module._block_sizes is not None
and module._block_sizes.get("scale_bits") == (4, 3)
)

# Create MSE calibrator with quant_func
module._calibrator = MseCalibrator(
@@ -267,7 +282,8 @@ def quant_func(x, amax, quantizer=module):
step_size=step_size,
start_multiplier=start_multiplier,
stop_multiplier=stop_multiplier,
quant_func=quant_func,
quant_func=partial(_mse_quant_func, quantizer=module),
fp8_scale_sweep=is_nvfp4_static,
)

# Identify weight quantizers by checking if they have corresponding weight parameters
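The gating condition added above can be read as a small predicate; the sketch below restates it with comments. It uses the same private TensorQuantizer attributes the diff touches, so it is illustrative rather than a public API.

def _wants_fp8_scale_sweep(module, fp8_scale_sweep: bool) -> bool:
    """Sketch: the sweep only applies to NVFP4 static per-block quantizers."""
    return (
        fp8_scale_sweep                                        # user opted in
        and module.is_static_block_quant                       # static per-block quantization
        and module._num_bits == (2, 1)                         # FP4 (E2M1) values
        and module._block_sizes is not None
        and module._block_sizes.get("scale_bits") == (4, 3)    # FP8 (E4M3) block scales
    )

Binding the quantizer with functools.partial, as the diff does for _mse_quant_func, keeps the (x, amax) signature the calibrator expects while replacing the old per-module closure.
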
@@ -753,11 +753,12 @@ def _fake_quantize(self, inputs):
elif self._num_bits == (2, 1) and self.is_static_block_quant:
outputs = static_blockwise_fp4_fake_quant(
inputs,
amax / 6.0,
None, # scale
None, # scale_fp8_quant_amax
False, # skip_scale_quant
inputs.dtype, # out_dtype
self._pass_through_bwd, # pass_through_bwd
amax, # amax
)
elif isinstance(self._num_bits, tuple):
# Float-point quantization, e.g., FP8
6 changes: 4 additions & 2 deletions modelopt/torch/quantization/tensor_quant.py
@@ -574,21 +574,23 @@ def forward(
skip_scale_quant,
out_dtype,
pass_through_bwd=False,
amax=None,
):
"""Forward method."""
_save_for_backward_if_needed(ctx, pass_through_bwd, x, scale)
_save_for_backward_if_needed(ctx, pass_through_bwd, x, scale if scale is not None else amax)
return triton_kernel.static_blockwise_fp4_fake_quant(
x,
scale,
scale_fp8_quant_amax,
skip_scale_quant,
out_dtype,
amax,
)

@staticmethod
def backward(ctx, grad_outputs):
"""Implements straight through estimation with clipping."""
return _fake_quant_backward_function(ctx, grad_outputs, num_args=6)
return _fake_quant_backward_function(ctx, grad_outputs, num_args=7)


def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
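The num_args bump from 6 to 7 follows from the autograd contract: backward must return one gradient (or None) per forward input, and forward now takes an extra amax argument. A minimal stand-in illustrating the pattern (the class and call below are made up for illustration, not the PR's helper):

import torch

class _PassThroughExample(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, scale_fp8_quant_amax, skip_scale_quant, out_dtype,
                pass_through_bwd=False, amax=None):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Seven slots: a gradient for x, then None for the six non-differentiable inputs.
        return grad_output, None, None, None, None, None, None

y = _PassThroughExample.apply(torch.randn(3, requires_grad=True), None, None, False, None, False, None)
y.sum().backward()  # the gradient flows straight through to x
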
16 changes: 14 additions & 2 deletions modelopt/torch/quantization/triton/fp4_kernel.py
@@ -406,20 +406,32 @@ def static_blockwise_fp4_fake_quant_kernel(

def static_blockwise_fp4_fake_quant(
x: torch.Tensor,
scale: torch.Tensor,
scale: torch.Tensor | None = None,
scale_fp8_quant_amax: torch.Tensor | None = None,
skip_scale_quant: bool = False,
out_dtype: torch.dtype | None = None,
amax: torch.Tensor | None = None,
):
"""Static blockwise FP4 fake quantization using Triton kernel.

Args:
x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA.
scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA.
scale: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA. Mutually exclusive with amax.
scale_fp8_quant_amax: Absolute max range for FP8 quantization of scale. If None, computed from scale.
skip_scale_quant: If True, skip FP8 quantization of scale.
out_dtype: Output dtype. Defaults to x.dtype if None.
amax: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] on CUDA. If provided, scale = amax / 6.0.
Mutually exclusive with scale.
"""
if scale is None and amax is None:
raise ValueError("Either scale or amax must be provided")
if scale is not None and amax is not None:
raise ValueError("Cannot provide both scale and amax")

if amax is not None:
scale = amax / 6.0 # FP4 max representable value is 6.0

assert scale is not None # Guaranteed by validation above
assert x.ndim == 2
NUM_FP4_BLOCKS, BLOCK_SIZE = x.shape

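A short usage sketch of the two call modes added above. Shapes follow the docstring; it needs a CUDA device and the Triton kernel from this module, and the import path simply mirrors the file being modified.

import torch
from modelopt.torch.quantization.triton.fp4_kernel import static_blockwise_fp4_fake_quant

x = torch.randn(1024, 16, device="cuda")            # [NUM_FP4_BLOCKS, BLOCK_SIZE]
amax = x.abs().amax(dim=-1, keepdim=True)           # [NUM_FP4_BLOCKS, 1] per-block amax

xq_from_amax = static_blockwise_fp4_fake_quant(x, amax=amax)           # scale = amax / 6.0 internally
xq_from_scale = static_blockwise_fp4_fake_quant(x, scale=amax / 6.0)   # equivalent explicit scale
# Passing both scale and amax, or neither, raises ValueError.
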
27 changes: 26 additions & 1 deletion tests/gpu/torch/quantization/test_quantize_cuda.py
@@ -43,7 +43,30 @@
"enable": True,
},
},
"algorithm": "mse",
"algorithm": {
"method": "mse",
"step_size": 0.25,
"start_multiplier": 0.25,
"stop_multiplier": 2.0,
},
}

NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
},
"algorithm": {
"method": "mse",
"fp8_scale_sweep": True,
},
}


@@ -71,6 +94,7 @@
mtq.NVFP4_KV_ROTATE_CFG,
mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
NVFP4_WEIGHT_ACT_MSE_CFG,
NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG,
],
)
def test_quantize(model_cls, config):
@@ -88,6 +112,7 @@ def test_quantize(model_cls, config):
mtq.NVFP4_KV_ROTATE_CFG,
mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
NVFP4_WEIGHT_ACT_MSE_CFG,
NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG,
]:
if get_cuda_ext_mx() is None:
pytest.skip("cuda_ext_mx is not available")
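Finally, a sketch of applying the sweep config outside the test harness. The model and calib_dataloader are assumed to exist; mtq.quantize with a forward loop is the standard ModelOpt calibration entry point.

import modelopt.torch.quantization as mtq

def forward_loop(model):
    # Run a few calibration batches so MSE calibration can collect statistics.
    for batch in calib_dataloader:
        model(batch)

model = mtq.quantize(model, NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG, forward_loop)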