add FP8 sweep option for static NVFP4 MSE #758
base: main
```diff
@@ -39,20 +39,24 @@ def __init__(
         stop_multiplier: float = 4.0,
         quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+        fp8_scale_sweep: bool = False,
     ):
         """Initialize MSE calibrator.

         Args:
             amax: Initial amax value (required).
             axis: Quantization axis. None means per-tensor quantization.
             step_size: Step size for amax search. The number of steps is computed as
                 ceil((stop_multiplier - start_multiplier) / step_size) + 1.
             start_multiplier: Starting multiplier for amax search.
             stop_multiplier: Ending multiplier for amax search.
             quant_func: Function that quantizes input tensor given an amax value.
                 Should have signature: quant_func(x, amax) -> quantized_x.
             error_func: Function to compute error between x and xq.
                 Default is F.mse_loss(x, xq, reduction='none').
+            fp8_scale_sweep: If True, sweep the 126 valid FP8 E4M3 scale values
+                instead of using multipliers. This is specifically for NVFP4
+                per-block quantization where scales are stored in FP8 format.
         """
         super().__init__(num_bits=None, axis=axis, unsigned=None)
         self._initial_amax = amax
```
```diff
@@ -65,6 +69,13 @@ def __init__(
         self._error_func = error_func
         self._losses_sum = [None] * self._num_steps
         self._candidate_amaxs = [None] * self._num_steps
+        self._fp8_scale_sweep = fp8_scale_sweep
+        if fp8_scale_sweep:
+            # For FP8 scale sweep, we always have exactly 126 valid FP8 E4M3 values
+            # (128 total - 2 invalid: byte 0 = zero, byte 127 = NaN)
+            self._num_steps = 126
+            self._losses_sum = [None] * self._num_steps
+            self._candidate_amaxs = [None] * self._num_steps

         self._amax = None
```
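Why exactly 126: of the 128 byte patterns 0–127, `float8_e4m3fn` decodes byte 0 to zero and byte 127 to NaN (the E4M3FN format has no infinities). A quick standalone check, separate from the PR code (assumes PyTorch ≥ 2.1 for the `float8_e4m3fn` dtype):

```python
import torch

# Reinterpret byte patterns 0-127 as float8_e4m3fn, then decode to float32.
fp8 = torch.arange(0, 128, dtype=torch.uint8).view(torch.float8_e4m3fn).float()
valid = fp8[torch.isfinite(fp8) & (fp8 > 0)]
print(valid.numel())  # 126
print(valid.max())    # tensor(448.) -- the E4M3 maximum, hence the /448.0 normalization below
```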
```diff
@@ -83,14 +94,33 @@ def collect(self, x: torch.Tensor):
         x = x.detach().to(dtype=torch.float32)

         device = x.device
-        multipliers = torch.linspace(
-            self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
-        )
+        if self._fp8_scale_sweep:
+            global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)
+            global_amax_expanded = global_amax * torch.ones_like(self._initial_amax)
+
+            # Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn):
+            # create a uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32
+            uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
+            fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
+
+            # Filter out invalid values (NaN, inf, and zero) which aren't useful as multipliers
+            valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
+            fp8_values_valid = fp8_values[valid_mask]
+
+            candidates = fp8_values_valid / 448.0
+        else:
+            candidates = torch.linspace(
+                self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
+            )
         # Get reduce axis for per-channel quantization
         reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)

-        for step, multiplier in enumerate(multipliers):
-            candidate_amax = self._initial_amax * multiplier
+        for step, candidate in enumerate(candidates):
+            if self._fp8_scale_sweep:
+                candidate_amax = global_amax_expanded * candidate
```
Review thread on `candidate_amax = global_amax_expanded * candidate`:

**Contributor:** Where is the /6.0?

**Contributor:** My understanding is that this is handled somewhere else. Is that correct?

**Contributor:** Found it. Should we support `static_blockwise_fp4_fake_quant` accepting either the scale or the amax? Then we don't need to handle the amax/6.0 detail in the tensor quantizer.
```diff
+            else:
+                candidate_amax = self._initial_amax * candidate
             xq = self._quant_func(x, candidate_amax)

             if self._error_func is not None:
```
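For readers outside the thread above: the /6.0 in question comes from NVFP4's two-level scaling, where FP4 E2M1 represents magnitudes up to 6.0, so a block with absolute maximum `amax` needs a decode scale of roughly `amax / 6.0`, itself stored in FP8 E4M3. A hedged sketch of that conversion, which per the thread lives in the tensor quantizer rather than this calibrator (the function name is hypothetical, and the NVFP4 global per-tensor scale is omitted):

```python
# Hypothetical illustration only -- not the PR's code path.
import torch

E2M1_MAX = 6.0    # largest magnitude representable in FP4 E2M1
E4M3_MAX = 448.0  # largest magnitude representable in FP8 E4M3

def block_scale_from_amax(block_amax: torch.Tensor) -> torch.Tensor:
    """Per-block decode scale: amax / 6.0, round-tripped through FP8 E4M3 storage."""
    scale = (block_amax / E2M1_MAX).clamp(max=E4M3_MAX)
    return scale.to(torch.float8_e4m3fn).float()
```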
```diff
@@ -387,30 +387,6 @@
     "algorithm": "max",
 }

-NVFP4_WEIGHT_ACT_MSE_CFG = {
-    "quant_cfg": {
-        "*weight_quantizer": {
-            "num_bits": (2, 1),
-            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
-            "axis": None,
-            "enable": True,
-        },
-        "*input_quantizer": {
-            "num_bits": (2, 1),
-            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
-            "axis": None,
-            "enable": True,
-        },
-        **_default_disabled_quantizer_cfg,
-    },
-    "algorithm": {
-        "method": "mse",
-        "step_size": 0.25,
-        "start_multiplier": 0.25,
-        "stop_multiplier": 2.0,
-    },
-}
-
 NVFP4_AWQ_LITE_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {
```
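With the dedicated `NVFP4_WEIGHT_ACT_MSE_CFG` preset removed, the presumable replacement is to pass an `algorithm` dict alongside an existing NVFP4 preset. A sketch under that assumption (`NVFP4_DEFAULT_CFG` and `mtq.quantize` exist in modelopt, but this exact combination is not shown in the PR):

```python
# Sketch, not taken from the diff: request MSE calibration with the new
# FP8 scale sweep on top of an existing NVFP4 preset.
import copy
import modelopt.torch.quantization as mtq

config = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
config["algorithm"] = {"method": "mse", "fp8_scale_sweep": True}

# model = mtq.quantize(model, config, forward_loop=calibration_loop)
```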
```diff
@@ -1040,6 +1016,8 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
     reconstruction error of a tensor after uniform Q→DQ:

         s* = argmin_s E[(X - DQ(Q(X; s)))^2],   X ∈ {weights | activations}

+    When fp8_scale_sweep is enabled, step_size is ignored.
+
     """

     method: Literal["mse"] = ModeloptField("mse")
```
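As a concrete illustration of the objective above, here is a minimal per-tensor version of the multiplier search (my own sketch; the calibrator's real `quant_func` is pluggable, so plain symmetric int8 fake-quant stands in for it):

```python
# Toy amax search: evaluate candidate amaxes and keep the one with lowest MSE.
import torch

def fake_quant(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    # Symmetric int8 fake quantization: scale from amax, round, clamp, dequantize.
    scale = amax / 127.0
    return (x / scale).round().clamp(-127, 127) * scale

x = torch.randn(4096)
initial_amax = x.abs().max()
multipliers = torch.linspace(0.25, 2.0, steps=8)  # start/stop multipliers, num_steps
losses = torch.stack(
    [(x - fake_quant(x, initial_amax * m)).square().mean() for m in multipliers]
)
best_amax = initial_amax * multipliers[losses.argmin()]
```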
```diff
@@ -1066,6 +1044,14 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
         description="Ending multiplier for amax search range (multiplies initial amax).",
     )

+    fp8_scale_sweep: bool | None = ModeloptField(
+        default=False,
+        title="Enable FP8 scale sweep for NVFP4 per-block quantization.",
+        description="If True, sweep the 126 valid FP8 E4M3 scale values instead of using "
+        "multipliers. Only applies to NVFP4 weight quantization. When enabled, num_steps, "
+        "step_size, start_multiplier, and stop_multiplier are ignored.",
+    )
+
     distributed_sync: bool | None = ModeloptField(
         default=True,
         title="Whether to sync the amax across the distributed processes.",
```
Review thread on `global_amax_expanded = global_amax * torch.ones_like(self._initial_amax)`:

**Reviewer:** nit: this should not be needed (since `global_amax` is a scalar).

**Contributor:** I believe this is needed: `candidate` is also a scalar, so without the expansion `candidate_amax = global_amax * candidate` would be a scalar, but we need `candidate_amax` to have the same shape as `initial_amax`.
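A small standalone demonstration of the shape argument above (the sizes are invented for illustration):

```python
import torch

initial_amax = torch.ones(64, 1)        # e.g., one amax per weight-block row
global_amax = torch.tensor(3.5)         # scalar, as the reviewer notes
candidate = torch.tensor(0.25)          # scalar FP8-derived multiplier

print((global_amax * candidate).shape)  # torch.Size([]) -- scalar, wrong shape
expanded = global_amax * torch.ones_like(initial_amax)
print((expanded * candidate).shape)     # torch.Size([64, 1]) -- matches initial_amax
```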