diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index de90c3006e8f..800d8a862ff3 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -29,24 +29,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf from torchao.quantization import Int8WeightOnlyConfig pipeline_quant_config = PipelineQuantizationConfig( - quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128)))} -) -pipeline = DiffusionPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", - quantization_config=pipeline_quant_config, - torch_dtype=torch.bfloat16, - device_map="cuda" -) -``` - -For simple use cases, you could also provide a string identifier in [`TorchAo`] as shown below. - -```py -import torch -from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig - -pipeline_quant_config = PipelineQuantizationConfig( - quant_mapping={"transformer": TorchAoConfig("int8wo")} + quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128))} ) pipeline = DiffusionPipeline.from_pretrained( "black-forest-labs/FLUX.1-dev", @@ -91,17 +74,6 @@ Weight-only quantization stores the model weights in a specific low-bit data typ Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly. -The quantization methods supported are as follows: - -| **Category** | **Full Function Names** | **Shorthands** | -|--------------|-------------------------|----------------| -| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` | -| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8dq_e4m3_tensor`, `float8dq_e4m3_row` | -| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` | -| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` | - -Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations. - Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods and the exhaustive list of configuration options available. ## Serializing and Deserializing quantized models @@ -111,8 +83,9 @@ To serialize a quantized model in a given dtype, first load the model with the d ```python import torch from diffusers import AutoModel, TorchAoConfig +from torchao.quantization import Int8WeightOnlyConfig -quantization_config = TorchAoConfig("int8wo") +quantization_config = TorchAoConfig(Int8WeightOnlyConfig()) transformer = AutoModel.from_pretrained( "black-forest-labs/Flux.1-Dev", subfolder="transformer", @@ -137,18 +110,19 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0] image.save("output.png") ``` -If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source. +If you are using `torch<=2.6.0`, some quantization methods, such as `uint4` weight-only, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source. ```python import torch from accelerate import init_empty_weights from diffusers import FluxPipeline, AutoModel, TorchAoConfig +from torchao.quantization import IntxWeightOnlyConfig # Serialize the model transformer = AutoModel.from_pretrained( "black-forest-labs/Flux.1-Dev", subfolder="transformer", - quantization_config=TorchAoConfig("uint4wo"), + quantization_config=TorchAoConfig(IntxWeightOnlyConfig(dtype=torch.uint4)), torch_dtype=torch.bfloat16, ) transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB") diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 9a467e6b21ee..6f5e0c007294 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -23,20 +23,17 @@ from __future__ import annotations import copy -import dataclasses import importlib.metadata -import inspect import json import os import warnings -from dataclasses import dataclass, is_dataclass +from dataclasses import dataclass from enum import Enum -from functools import partial from typing import Any, Callable from packaging import version -from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging +from ..utils import deprecate, is_torch_available, is_torchao_version, logging if is_torch_available(): @@ -53,16 +50,6 @@ class QuantizationMethod(str, Enum): MODELOPT = "modelopt" -if is_torchao_available(): - from torchao.quantization.quant_primitives import MappingType - - class TorchAoJSONEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, MappingType): - return obj.name - return super().default(obj) - - @dataclass class QuantizationConfigMixin: """ @@ -446,49 +433,21 @@ class TorchAoConfig(QuantizationConfigMixin): """This is a config class for torchao quantization/sparsity techniques. Args: - quant_type (`str` | AOBaseConfig): - The type of quantization we want to use, currently supporting: - - **Integer quantization:** - - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`, - `int8_weight_only`, `int8_dynamic_activation_int8_weight` - - Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq` - - - **Floating point 8-bit quantization:** - - Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`, - `float8_static_activation_float8_weight` - - Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, - `float8_e4m3_tensor`, `float8_e4m3_row`, - - - **Floating point X-bit quantization:** (in torchao <= 0.14.1, not supported in torchao >= 0.15.0) - - Full function names: `fpx_weight_only` - - Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number - of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must - be satisfied for a given shorthand notation. - - - **Unsigned Integer quantization:** - - Full function names: `uintx_weight_only` - - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` - - An AOBaseConfig instance: for more advanced configuration options. + quant_type (`AOBaseConfig`): + An `AOBaseConfig` subclass instance specifying the quantization type. See the [torchao + documentation](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) for + available config classes (e.g. `Int4WeightOnlyConfig`, `Int8WeightOnlyConfig`, `Float8WeightOnlyConfig`, + `Float8DynamicActivationFloat8WeightConfig`, etc.). modules_to_not_convert (`list[str]`, *optional*, default to `None`): The list of modules to not quantize, useful for quantizing models that explicitly require to have some modules left in their original precision. - kwargs (`dict[str, Any]`, *optional*): - The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization - supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and - documentation of arguments can be found in - https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques Example: ```python from diffusers import FluxTransformer2DModel, TorchAoConfig - - # AOBaseConfig-based configuration from torchao.quantization import Int8WeightOnlyConfig quantization_config = TorchAoConfig(Int8WeightOnlyConfig()) - - # String-based config - quantization_config = TorchAoConfig("int8wo") transformer = FluxTransformer2DModel.from_pretrained( "black-forest-labs/Flux.1-Dev", subfolder="transformer", @@ -500,7 +459,7 @@ class TorchAoConfig(QuantizationConfigMixin): def __init__( self, - quant_type: str | "AOBaseConfig", # noqa: F821 + quant_type: "AOBaseConfig", # noqa: F821 modules_to_not_convert: list[str] | None = None, **kwargs, ) -> None: @@ -508,89 +467,28 @@ def __init__( self.quant_type = quant_type self.modules_to_not_convert = modules_to_not_convert - # When we load from serialized config, "quant_type_kwargs" will be the key - if "quant_type_kwargs" in kwargs: - self.quant_type_kwargs = kwargs["quant_type_kwargs"] - else: - self.quant_type_kwargs = kwargs - self.post_init() def post_init(self): - if not isinstance(self.quant_type, str): - if is_torchao_version("<=", "0.9.0"): - raise ValueError( - f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. " - f"Upgrade to torchao > 0.9.0 to use AOBaseConfig." - ) - - from torchao.quantization.quant_api import AOBaseConfig + if is_torchao_version("<=", "0.9.0"): + raise ValueError("TorchAoConfig requires torchao > 0.9.0. Please upgrade with `pip install -U torchao`.") - if not isinstance(self.quant_type, AOBaseConfig): - raise TypeError(f"quant_type must be a AOBaseConfig instance, got {type(self.quant_type).__name__}") + from torchao.quantization.quant_api import AOBaseConfig - elif isinstance(self.quant_type, str): - TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() - - if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys(): - is_floatx_quant_type = self.quant_type.startswith("fp") - is_float_quant_type = self.quant_type.startswith("float") or is_floatx_quant_type - if is_float_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9(): - raise ValueError( - f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You " - f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`." - ) - elif is_floatx_quant_type and not is_torchao_version("<=", "0.14.1"): - raise ValueError( - f"Requested quantization type: {self.quant_type} is only supported in torchao <= 0.14.1. " - f"Please downgrade to torchao <= 0.14.1 to use this quantization type." - ) - - raise ValueError( - f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the " - f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues." - ) - - method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type] - signature = inspect.signature(method) - all_kwargs = { - param.name - for param in signature.parameters.values() - if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD] - } - unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs) - - if len(unsupported_kwargs) > 0: - raise ValueError( - f'The quantization method "{self.quant_type}" does not support the following keyword arguments: ' - f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}." - ) + if not isinstance(self.quant_type, AOBaseConfig): + raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type).__name__}") def to_dict(self): """Convert configuration to a dictionary.""" d = super().to_dict() - if isinstance(self.quant_type, str): - # Handle layout serialization if present - if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]: - if is_dataclass(d["quant_type_kwargs"]["layout"]): - d["quant_type_kwargs"]["layout"] = [ - d["quant_type_kwargs"]["layout"].__class__.__name__, - dataclasses.asdict(d["quant_type_kwargs"]["layout"]), - ] - if isinstance(d["quant_type_kwargs"]["layout"], list): - assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs" - assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string" - assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict" - else: - raise ValueError("layout must be a list") - else: - # Handle AOBaseConfig serialization - from torchao.core.config import config_to_dict + # Handle AOBaseConfig serialization + from torchao.core.config import config_to_dict - # For now we assume there is 1 config per Transformer, however in the future - # We may want to support a config per fqn. - d["quant_type"] = {"default": config_to_dict(self.quant_type)} + # For now we assume there is 1 config per Transformer, however in the future + # we may want to support a config per fqn. + # See: https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.quantize_.html + d["quant_type"] = {"default": config_to_dict(self.quant_type)} return d @@ -602,8 +500,6 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): config_dict = config_dict.copy() quant_type = config_dict.pop("quant_type") - if isinstance(quant_type, str): - return cls(quant_type=quant_type, **config_dict) # Check if we only have one key which is "default" # In the future we may update this assert len(quant_type) == 1 and "default" in quant_type, ( @@ -618,210 +514,13 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): return cls(quant_type=quant_type, **config_dict) - @classmethod - def _get_torchao_quant_type_to_method(cls): - r""" - Returns supported torchao quantization types with all commonly used notations. - """ - - if is_torchao_available(): - # TODO(aryan): Support sparsify - from torchao.quantization import ( - float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, - float8_weight_only, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, - uintx_weight_only, - ) - - if is_torchao_version("<=", "0.14.1"): - from torchao.quantization import fpx_weight_only - # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers - from torchao.quantization.observer import PerRow, PerTensor - - def generate_float8dq_types(dtype: torch.dtype): - name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3" - types = {} - - for granularity_cls in [PerTensor, PerRow]: - # Note: Activation and Weights cannot have different granularities - granularity_name = "tensor" if granularity_cls is PerTensor else "row" - types[f"float8dq_{name}_{granularity_name}"] = partial( - float8_dynamic_activation_float8_weight, - activation_dtype=dtype, - weight_dtype=dtype, - granularity=(granularity_cls(), granularity_cls()), - ) - - return types - - def generate_fpx_quantization_types(bits: int): - if is_torchao_version("<=", "0.14.1"): - types = {} - - for ebits in range(1, bits): - mbits = bits - ebits - 1 - types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits) - - non_sign_bits = bits - 1 - default_ebits = (non_sign_bits + 1) // 2 - default_mbits = non_sign_bits - default_ebits - types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits) - - return types - else: - raise ValueError("Floating point X-bit quantization is not supported in torchao >= 0.15.0") - - INT4_QUANTIZATION_TYPES = { - # int4 weight + bfloat16/float16 activation - "int4wo": int4_weight_only, - "int4_weight_only": int4_weight_only, - # int4 weight + int8 activation - "int4dq": int8_dynamic_activation_int4_weight, - "int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight, - } - - INT8_QUANTIZATION_TYPES = { - # int8 weight + bfloat16/float16 activation - "int8wo": int8_weight_only, - "int8_weight_only": int8_weight_only, - # int8 weight + int8 activation - "int8dq": int8_dynamic_activation_int8_weight, - "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight, - } - - # TODO(aryan): handle torch 2.2/2.3 - FLOATX_QUANTIZATION_TYPES = { - # float8_e5m2 weight + bfloat16/float16 activation - "float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), - "float8_weight_only": float8_weight_only, - "float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), - # float8_e4m3 weight + bfloat16/float16 activation - "float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), - # float8_e5m2 weight + float8 activation (dynamic) - "float8dq": float8_dynamic_activation_float8_weight, - "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight, - # ===== Matrix multiplication is not supported in float8_e5m2 so the following errors out. - # However, changing activation_dtype=torch.float8_e4m3 might work here ===== - # "float8dq_e5m2": partial( - # float8_dynamic_activation_float8_weight, - # activation_dtype=torch.float8_e5m2, - # weight_dtype=torch.float8_e5m2, - # ), - # **generate_float8dq_types(torch.float8_e5m2), - # ===== ===== - # float8_e4m3 weight + float8 activation (dynamic) - "float8dq_e4m3": partial( - float8_dynamic_activation_float8_weight, - activation_dtype=torch.float8_e4m3fn, - weight_dtype=torch.float8_e4m3fn, - ), - **generate_float8dq_types(torch.float8_e4m3fn), - # float8 weight + float8 activation (static) - "float8_static_activation_float8_weight": float8_static_activation_float8_weight, - } - - if is_torchao_version("<=", "0.14.1"): - FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(3)) - FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(4)) - FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(5)) - FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(6)) - FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(7)) - - UINTX_QUANTIZATION_DTYPES = { - "uintx_weight_only": uintx_weight_only, - "uint1wo": partial(uintx_weight_only, dtype=torch.uint1), - "uint2wo": partial(uintx_weight_only, dtype=torch.uint2), - "uint3wo": partial(uintx_weight_only, dtype=torch.uint3), - "uint4wo": partial(uintx_weight_only, dtype=torch.uint4), - "uint5wo": partial(uintx_weight_only, dtype=torch.uint5), - "uint6wo": partial(uintx_weight_only, dtype=torch.uint6), - "uint7wo": partial(uintx_weight_only, dtype=torch.uint7), - # "uint8wo": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported - } - - QUANTIZATION_TYPES = {} - QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES) - QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES) - QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES) - - if cls._is_xpu_or_cuda_capability_atleast_8_9(): - QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES) - - return QUANTIZATION_TYPES - else: - raise ValueError( - "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`" - ) - - @staticmethod - def _is_xpu_or_cuda_capability_atleast_8_9() -> bool: - if torch.cuda.is_available(): - major, minor = torch.cuda.get_device_capability() - if major == 8: - return minor >= 9 - return major >= 9 - elif torch.xpu.is_available(): - return True - else: - raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.") - def get_apply_tensor_subclass(self): """Create the appropriate quantization method based on configuration.""" - if not isinstance(self.quant_type, str): - return self.quant_type - else: - methods = self._get_torchao_quant_type_to_method() - quant_type_kwargs = self.quant_type_kwargs.copy() - if ( - not torch.cuda.is_available() - and is_torchao_available() - and self.quant_type == "int4_weight_only" - and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0") - and quant_type_kwargs.get("layout", None) is None - ): - if torch.xpu.is_available(): - if version.parse(importlib.metadata.version("torchao")) >= version.parse( - "0.11.0" - ) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"): - from torchao.dtypes import Int4XPULayout - from torchao.quantization.quant_primitives import ZeroPointDomain - - quant_type_kwargs["layout"] = Int4XPULayout() - quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT - else: - raise ValueError( - "TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or use run on CPU with the cpu version pytorch." - ) - else: - from torchao.dtypes import Int4CPULayout - - quant_type_kwargs["layout"] = Int4CPULayout() - - return methods[self.quant_type](**quant_type_kwargs) + return self.quant_type def __repr__(self): - r""" - Example of how this looks for `TorchAoConfig("uint4wo", group_size=32)`: - - ``` - TorchAoConfig { - "modules_to_not_convert": null, - "quant_method": "torchao", - "quant_type": "uint4wo", - "quant_type_kwargs": { - "group_size": 32 - } - } - ``` - """ config_dict = self.to_dict() - return ( - f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n" - ) + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" @dataclass diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 1679ed26a104..da52716999d8 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -20,7 +20,6 @@ import importlib import re import types -from fnmatch import fnmatch from typing import TYPE_CHECKING, Any from packaging import version @@ -199,14 +198,6 @@ def validate_environment(self, *args, **kwargs): ) def update_torch_dtype(self, torch_dtype): - quant_type = self.quantization_config.quant_type - if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")): - if torch_dtype is not None and torch_dtype != torch.bfloat16: - logger.warning( - f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but " - f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`." - ) - if torch_dtype is None: # We need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op logger.warning( @@ -219,45 +210,16 @@ def update_torch_dtype(self, torch_dtype): return torch_dtype def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": - quant_type = self.quantization_config.quant_type from accelerate.utils import CustomDtype - if isinstance(quant_type, str): - if quant_type.startswith("int8"): - # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8 - return torch.int8 - elif quant_type.startswith("int4"): - return CustomDtype.INT4 - elif quant_type == "uintx_weight_only": - return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8) - elif quant_type.startswith("uint"): - return { - 1: torch.uint1, - 2: torch.uint2, - 3: torch.uint3, - 4: torch.uint4, - 5: torch.uint5, - 6: torch.uint6, - 7: torch.uint7, - }[int(quant_type[4])] - elif quant_type.startswith("float") or quant_type.startswith("fp"): - return torch.bfloat16 - - elif is_torchao_version(">", "0.9.0"): - from torchao.core.config import AOBaseConfig - - quant_type = self.quantization_config.quant_type - if isinstance(quant_type, AOBaseConfig): - # Extract size digit using fuzzy match on the class name - config_name = quant_type.__class__.__name__ - size_digit = fuzzy_match_size(config_name) - - # Map the extracted digit to appropriate dtype - if size_digit == "4": - return CustomDtype.INT4 - else: - # Default to int8 - return torch.int8 + quant_type = self.quantization_config.quant_type + config_name = quant_type.__class__.__name__ + size_digit = fuzzy_match_size(config_name) + + if size_digit == "4": + return CustomDtype.INT4 + else: + return torch.int8 if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION): return target_dtype @@ -337,29 +299,14 @@ def get_cuda_warm_up_factor(self): - Use a division factor of 8 for int4 weights - Use a division factor of 4 for int8 weights """ - # Original mapping for non-AOBaseConfig types - # For the uint types, this is a best guess. Once these types become more used - # we can look into their nuances. - if is_torchao_version(">", "0.9.0"): - from torchao.core.config import AOBaseConfig - - quant_type = self.quantization_config.quant_type - if isinstance(quant_type, AOBaseConfig): - # Extract size digit using fuzzy match on the class name - config_name = quant_type.__class__.__name__ - size_digit = fuzzy_match_size(config_name) - - if size_digit == "4": - return 8 - else: - return 4 - - map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4} quant_type = self.quantization_config.quant_type - for pattern, target_dtype in map_to_target_dtype.items(): - if fnmatch(quant_type, pattern): - return target_dtype - raise ValueError(f"Unsupported quant_type: {quant_type!r}") + config_name = quant_type.__class__.__name__ + size_digit = fuzzy_match_size(config_name) + + if size_digit == "4": + return 8 + else: + return 4 def _process_model_before_weight_loading( self, @@ -415,9 +362,17 @@ def is_serializable(self, safe_serialization=None): return _is_torchao_serializable + _TRAINABLE_QUANTIZATION_CONFIGS = ( + "Int8WeightOnlyConfig", + "Int8DynamicActivationInt8WeightConfig", + "Int8StaticActivationInt8WeightConfig", + "Float8WeightOnlyConfig", + "Float8DynamicActivationFloat8WeightConfig", + ) + @property def is_trainable(self): - return self.quantization_config.quant_type.startswith("int8") + return self.quantization_config.quant_type.__class__.__name__ in self._TRAINABLE_QUANTIZATION_CONFIGS @property def is_compileable(self) -> bool: diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index a722eaece4d1..686c469e305c 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -55,6 +55,20 @@ enable_full_determinism() +def _is_xpu_or_cuda_capability_atleast_8_9() -> bool: + if is_torch_available(): + import torch + + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability() + if major == 8: + return minor >= 9 + return major >= 9 + elif torch.xpu.is_available(): + return True + return False + + if is_torch_available(): import torch import torch.nn as nn @@ -64,12 +78,17 @@ if is_torchao_available(): from torchao.dtypes import AffineQuantizedTensor + from torchao.quantization import ( + Float8WeightOnlyConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + ) from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor - from torchao.quantization.quant_primitives import MappingType from torchao.utils import get_model_size_in_bytes - if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.9.0"): - from torchao.quantization import Int8WeightOnlyConfig + if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.10.0"): + from torchao.quantization import Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig @require_torch @@ -80,53 +99,30 @@ def test_to_dict(self): """ Makes sure the config format is properly set """ - quantization_config = TorchAoConfig("int4_weight_only") + quantization_config = TorchAoConfig(Int4WeightOnlyConfig()) torchao_orig_config = quantization_config.to_dict() - - for key in torchao_orig_config: - self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key]) + self.assertIn("quant_type", torchao_orig_config) + self.assertIn("quant_method", torchao_orig_config) def test_post_init_check(self): """ - Test kwargs validations in TorchAoConfig + Test that non-AOBaseConfig types are rejected """ - _ = TorchAoConfig("int4_weight_only") - with self.assertRaisesRegex(ValueError, "is not supported"): - _ = TorchAoConfig("uint8") + _ = TorchAoConfig(Int4WeightOnlyConfig()) + with self.assertRaises(TypeError): + _ = TorchAoConfig("int4_weight_only") - with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"): - _ = TorchAoConfig("int4_weight_only", group_size1=32) + with self.assertRaises(TypeError): + _ = TorchAoConfig(42) def test_repr(self): """ Check that there is no error in the repr """ - quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8) - expected_repr = """TorchAoConfig { - "modules_to_not_convert": [ - "conv" - ], - "quant_method": "torchao", - "quant_type": "int4_weight_only", - "quant_type_kwargs": { - "group_size": 8 - } - }""".replace(" ", "").replace("\n", "") - quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "") - self.assertEqual(quantization_repr, expected_repr) - - quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC) - expected_repr = """TorchAoConfig { - "modules_to_not_convert": null, - "quant_method": "torchao", - "quant_type": "int4dq", - "quant_type_kwargs": { - "act_mapping_type": "SYMMETRIC", - "group_size": 64 - } - }""".replace(" ", "").replace("\n", "") - quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "") - self.assertEqual(quantization_repr, expected_repr) + quantization_config = TorchAoConfig(Int8WeightOnlyConfig(), modules_to_not_convert=["conv"]) + quantization_repr = repr(quantization_config) + self.assertIn("TorchAoConfig", quantization_repr) + self.assertIn("torchao", quantization_repr) # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners @@ -234,79 +230,30 @@ def test_quantization(self): for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: # fmt: off QUANTIZATION_TYPES_TO_TEST = [ - ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])), - ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])), - ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), - ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), - ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])), - ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + (Int4WeightOnlyConfig(), np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])), + (Int8DynamicActivationIntxWeightConfig(), np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])), + (Int8WeightOnlyConfig(), np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + (Int8DynamicActivationInt8WeightConfig(), np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + (IntxWeightOnlyConfig(dtype=torch.uint4, group_size=16), np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])), + (IntxWeightOnlyConfig(dtype=torch.uint7, group_size=16), np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), ] - if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9(): + if _is_xpu_or_cuda_capability_atleast_8_9(): QUANTIZATION_TYPES_TO_TEST.extend([ - ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])), - ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])), - # ===== - # The following lead to an internal torch error: - # RuntimeError: mat2 shape (32x4 must be divisible by 16 - # Skip these for now; TODO(aryan): investigate later - # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - # ===== - # Cutlass fails to initialize for below - # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - # ===== + (Float8WeightOnlyConfig(weight_dtype=torch.float8_e5m2), np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])), + (Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn), np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])), ]) - if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"): - QUANTIZATION_TYPES_TO_TEST.extend([ - ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), - ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), - ]) # fmt: on - for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: - quant_kwargs = {} - if quantization_name in ["uint4wo", "uint7wo"]: - # The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here - quant_kwargs.update({"group_size": 16}) - quantization_config = TorchAoConfig( - quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs - ) + for quant_config, expected_slice in QUANTIZATION_TYPES_TO_TEST: + quantization_config = TorchAoConfig(quant_type=quant_config, modules_to_not_convert=["x_embedder"]) self._test_quant_type(quantization_config, expected_slice, model_id) - @unittest.skip("Skipping floatx quantization tests") - def test_floatx_quantization(self): - for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: - if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9(): - if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"): - quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"]) - self._test_quant_type( - quantization_config, - np.array( - [ - 0.4648, - 0.5195, - 0.5547, - 0.4180, - 0.4434, - 0.6445, - 0.4316, - 0.4531, - 0.5625, - ] - ), - model_id, - ) - else: - # Make sure the correct error is thrown - with self.assertRaisesRegex(ValueError, "Please downgrade"): - quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"]) - def test_int4wo_quant_bfloat16_conversion(self): """ Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization. """ - quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64)) quantized_model = FluxTransformer2DModel.from_pretrained( "hf-internal-testing/tiny-flux-pipe", subfolder="transformer", @@ -361,7 +308,7 @@ def test_device_map(self): else: expected_slice = expected_slice_offload with tempfile.TemporaryDirectory() as offload_folder: - quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64)) quantized_model = FluxTransformer2DModel.from_pretrained( "hf-internal-testing/tiny-flux-pipe", subfolder="transformer", @@ -385,7 +332,7 @@ def test_device_map(self): self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3) with tempfile.TemporaryDirectory() as offload_folder: - quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64)) quantized_model = FluxTransformer2DModel.from_pretrained( "hf-internal-testing/tiny-flux-sharded", subfolder="transformer", @@ -406,7 +353,7 @@ def test_device_map(self): self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3) def test_modules_to_not_convert(self): - quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"]) + quantization_config = TorchAoConfig(Int8WeightOnlyConfig(), modules_to_not_convert=["transformer_blocks.0"]) quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained( "hf-internal-testing/tiny-flux-pipe", subfolder="transformer", @@ -422,7 +369,7 @@ def test_modules_to_not_convert(self): quantized_layer = quantized_model_with_not_convert.proj_out self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor)) - quantization_config = TorchAoConfig("int8_weight_only") + quantization_config = TorchAoConfig(Int8WeightOnlyConfig()) quantized_model = FluxTransformer2DModel.from_pretrained( "hf-internal-testing/tiny-flux-pipe", subfolder="transformer", @@ -436,7 +383,7 @@ def test_modules_to_not_convert(self): self.assertTrue(size_quantized < size_quantized_with_not_convert) def test_training(self): - quantization_config = TorchAoConfig("int8_weight_only") + quantization_config = TorchAoConfig(Int8WeightOnlyConfig()) quantized_model = FluxTransformer2DModel.from_pretrained( "hf-internal-testing/tiny-flux-pipe", subfolder="transformer", @@ -470,7 +417,7 @@ def test_training(self): def test_torch_compile(self): r"""Test that verifies if torch.compile works with torchao quantization.""" for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: - quantization_config = TorchAoConfig("int8_weight_only") + quantization_config = TorchAoConfig(Int8WeightOnlyConfig()) components = self.get_dummy_components(quantization_config, model_id=model_id) pipe = FluxPipeline(**components) pipe.to(device=torch_device) @@ -491,11 +438,15 @@ def test_memory_footprint(self): memory footprint of the converted model and the class type of the linear layers of the converted models """ for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: - transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"] + transformer_int4wo = self.get_dummy_components(TorchAoConfig(Int4WeightOnlyConfig()), model_id=model_id)[ + "transformer" + ] transformer_int4wo_gs32 = self.get_dummy_components( - TorchAoConfig("int4wo", group_size=32), model_id=model_id + TorchAoConfig(Int4WeightOnlyConfig(group_size=32)), model_id=model_id )["transformer"] - transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"] + transformer_int8wo = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()), model_id=model_id)[ + "transformer" + ] transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"] # Will not quantized all the layers by default due to the model weights shapes not being divisible by group_size=64 @@ -553,20 +504,22 @@ def test_model_memory_usage(self): unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs) del transformer_bf16 - transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"] + transformer_int8wo = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()), model_id=model_id)[ + "transformer" + ] transformer_int8wo.to(torch_device) quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs) assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio def test_wrong_config(self): - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): self.get_dummy_components(TorchAoConfig("int42")) def test_sequential_cpu_offload(self): r""" A test that checks if inference runs as expected when sequential cpu offloading is enabled. """ - quantization_config = TorchAoConfig("int8wo") + quantization_config = TorchAoConfig(Int8WeightOnlyConfig()) components = self.get_dummy_components(quantization_config) pipe = FluxPipeline(**components) pipe.enable_sequential_cpu_offload() @@ -595,8 +548,8 @@ def tearDown(self): gc.collect() backend_empty_cache(torch_device) - def get_dummy_model(self, quant_method, quant_method_kwargs, device=None): - quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs) + def get_dummy_model(self, quant_type, device=None): + quantization_config = TorchAoConfig(quant_type) quantized_model = FluxTransformer2DModel.from_pretrained( self.model_name, subfolder="transformer", @@ -632,8 +585,8 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0): "timestep": timestep, } - def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice): - quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device) + def _test_original_model_expected_slice(self, quant_type, expected_slice): + quantized_model = self.get_dummy_model(quant_type, torch_device) inputs = self.get_dummy_tensor_inputs(torch_device) output = quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() @@ -641,8 +594,8 @@ def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))) self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) - def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device): - quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device) + def _check_serialization_expected_slice(self, quant_type, expected_slice, device): + quantized_model = self.get_dummy_model(quant_type, device) with tempfile.TemporaryDirectory() as tmp_dir: quantized_model.save_pretrained(tmp_dir, safe_serialization=False) @@ -662,40 +615,39 @@ def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) def test_int_a8w8_accelerator(self): - quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} + quant_type = Int8DynamicActivationInt8WeightConfig() expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) device = torch_device - self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) - self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) + self._test_original_model_expected_slice(quant_type, expected_slice) + self._check_serialization_expected_slice(quant_type, expected_slice, device) def test_int_a16w8_accelerator(self): - quant_method, quant_method_kwargs = "int8_weight_only", {} + quant_type = Int8WeightOnlyConfig() expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) device = torch_device - self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) - self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) + self._test_original_model_expected_slice(quant_type, expected_slice) + self._check_serialization_expected_slice(quant_type, expected_slice, device) def test_int_a8w8_cpu(self): - quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} + quant_type = Int8DynamicActivationInt8WeightConfig() expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) device = "cpu" - self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) - self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) + self._test_original_model_expected_slice(quant_type, expected_slice) + self._check_serialization_expected_slice(quant_type, expected_slice, device) def test_int_a16w8_cpu(self): - quant_method, quant_method_kwargs = "int8_weight_only", {} + quant_type = Int8WeightOnlyConfig() expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) device = "cpu" - self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) - self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) + self._test_original_model_expected_slice(quant_type, expected_slice) + self._check_serialization_expected_slice(quant_type, expected_slice, device) - @require_torchao_version_greater_or_equal("0.9.0") def test_aobase_config(self): - quant_method, quant_method_kwargs = Int8WeightOnlyConfig(), {} + quant_type = Int8WeightOnlyConfig() expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) device = torch_device - self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) - self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) + self._test_original_model_expected_slice(quant_type, expected_slice) + self._check_serialization_expected_slice(quant_type, expected_slice, device) @require_torchao_version_greater_or_equal("0.14.0") @@ -817,29 +769,25 @@ def _test_quant_type(self, quantization_config, expected_slice): def test_quantization(self): # fmt: off QUANTIZATION_TYPES_TO_TEST = [ - ("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])), - ("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])), + (Int8WeightOnlyConfig(), np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])), + (Int8DynamicActivationInt8WeightConfig(), np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])), ] - if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9(): + if _is_xpu_or_cuda_capability_atleast_8_9(): QUANTIZATION_TYPES_TO_TEST.extend([ - ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])), + (Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn), np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])), ]) - if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"): - QUANTIZATION_TYPES_TO_TEST.extend([ - ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])), - ]) # fmt: on - for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: - quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"]) + for quant_config, expected_slice in QUANTIZATION_TYPES_TO_TEST: + quantization_config = TorchAoConfig(quant_type=quant_config, modules_to_not_convert=["x_embedder"]) self._test_quant_type(quantization_config, expected_slice) gc.collect() backend_empty_cache(torch_device) backend_synchronize(torch_device) def test_serialization_int8wo(self): - quantization_config = TorchAoConfig("int8wo") + quantization_config = TorchAoConfig(Int8WeightOnlyConfig()) components = self.get_dummy_components(quantization_config) pipe = FluxPipeline(**components) pipe.enable_model_cpu_offload() @@ -876,7 +824,7 @@ def test_serialization_int8wo(self): def test_memory_footprint_int4wo(self): # The original checkpoints are in bf16 and about 24 GB expected_memory_in_gb = 6.0 - quantization_config = TorchAoConfig("int4wo") + quantization_config = TorchAoConfig(Int4WeightOnlyConfig()) cache_dir = None transformer = FluxTransformer2DModel.from_pretrained( "black-forest-labs/FLUX.1-dev", @@ -891,7 +839,7 @@ def test_memory_footprint_int4wo(self): def test_memory_footprint_int8wo(self): # The original checkpoints are in bf16 and about 24 GB expected_memory_in_gb = 12.0 - quantization_config = TorchAoConfig("int8wo") + quantization_config = TorchAoConfig(Int8WeightOnlyConfig()) cache_dir = None transformer = FluxTransformer2DModel.from_pretrained( "black-forest-labs/FLUX.1-dev",