From 125a7a81b94e8cc23ab268ca661645a04415c250 Mon Sep 17 00:00:00 2001 From: Howard Zhang Date: Wed, 18 Mar 2026 17:35:32 -0700 Subject: [PATCH] Add low precision attention API from torchao to TorchAoConfig --- .../quantizers/quantization_config.py | 49 +++++++- .../quantizers/torchao/torchao_quantizer.py | 119 +++++++++++++++++- 2 files changed, 159 insertions(+), 9 deletions(-) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 9a467e6b21ee..62cc810642f2 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -446,7 +446,7 @@ class TorchAoConfig(QuantizationConfigMixin): """This is a config class for torchao quantization/sparsity techniques. Args: - quant_type (`str` | AOBaseConfig): + quant_type (`str` | AOBaseConfig | None): The type of quantization we want to use, currently supporting: - **Integer quantization:** - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`, @@ -469,9 +469,14 @@ class TorchAoConfig(QuantizationConfigMixin): - Full function names: `uintx_weight_only` - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` - An AOBaseConfig instance: for more advanced configuration options. + - `None`: when only `attention_backend` is used without weight quantization. modules_to_not_convert (`list[str]`, *optional*, default to `None`): The list of modules to not quantize, useful for quantizing models that explicitly require to have some modules left in their original precision. + attention_backend (`str`, *optional*, default to `None`): + Low-precision attention backend to use. Currently supported: `"fp8_fa3"` (FP8 attention using Flash + Attention 3, requires Hopper GPU with SM90+). This is orthogonal to weight quantization — you can use + either or both. When used with `torch.compile`, RoPE fusion is automatically enabled. kwargs (`dict[str, Any]`, *optional*): The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and @@ -495,18 +500,28 @@ class TorchAoConfig(QuantizationConfigMixin): quantization_config=quantization_config, torch_dtype=torch.bfloat16, ) + + # FP8 attention only (no weight quantization) + quantization_config = TorchAoConfig(attention_backend="fp8_fa3") + + # Combined: weight quantization + FP8 attention + quantization_config = TorchAoConfig("int8wo", attention_backend="fp8_fa3") ``` """ + _SUPPORTED_ATTENTION_BACKENDS = {"fp8_fa3"} + def __init__( self, - quant_type: str | "AOBaseConfig", # noqa: F821 + quant_type: str | "AOBaseConfig" | None = None, # noqa: F821 modules_to_not_convert: list[str] | None = None, + attention_backend: str | None = None, **kwargs, ) -> None: self.quant_method = QuantizationMethod.TORCHAO self.quant_type = quant_type self.modules_to_not_convert = modules_to_not_convert + self.attention_backend = attention_backend # When we load from serialized config, "quant_type_kwargs" will be the key if "quant_type_kwargs" in kwargs: @@ -517,6 +532,21 @@ def __init__( self.post_init() def post_init(self): + if self.quant_type is None and self.attention_backend is None: + raise ValueError( + "At least one of `quant_type` or `attention_backend` must be provided." + ) + + if self.attention_backend is not None and self.attention_backend not in self._SUPPORTED_ATTENTION_BACKENDS: + raise ValueError( + f"Unsupported attention_backend: {self.attention_backend!r}. " + f"Supported backends: {self._SUPPORTED_ATTENTION_BACKENDS}" + ) + + # Skip quant_type validation when only attention_backend is used + if self.quant_type is None: + return + if not isinstance(self.quant_type, str): if is_torchao_version("<=", "0.9.0"): raise ValueError( @@ -570,6 +600,12 @@ def to_dict(self): """Convert configuration to a dictionary.""" d = super().to_dict() + if self.attention_backend is not None: + d["attention_backend"] = self.attention_backend + + if self.quant_type is None: + return d + if isinstance(self.quant_type, str): # Handle layout serialization if present if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]: @@ -600,10 +636,11 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): if not is_torchao_version(">", "0.9.0"): raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict") config_dict = config_dict.copy() - quant_type = config_dict.pop("quant_type") + quant_type = config_dict.pop("quant_type", None) + attention_backend = config_dict.pop("attention_backend", None) - if isinstance(quant_type, str): - return cls(quant_type=quant_type, **config_dict) + if quant_type is None or isinstance(quant_type, str): + return cls(quant_type=quant_type, attention_backend=attention_backend, **config_dict) # Check if we only have one key which is "default" # In the future we may update this assert len(quant_type) == 1 and "default" in quant_type, ( @@ -616,7 +653,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): quant_type = config_from_dict(quant_type) - return cls(quant_type=quant_type, **config_dict) + return cls(quant_type=quant_type, attention_backend=attention_backend, **config_dict) @classmethod def _get_torchao_quant_type_to_method(cls): diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 1679ed26a104..ccfffb711725 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -21,6 +21,7 @@ import re import types from fnmatch import fnmatch +from functools import partial from typing import TYPE_CHECKING, Any from packaging import version @@ -198,8 +199,57 @@ def validate_environment(self, *args, **kwargs): f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}." ) + attention_backend = getattr(self.quantization_config, "attention_backend", None) + if attention_backend is not None: + self._validate_attention_environment(attention_backend) + + def _validate_attention_environment(self, attention_backend): + """Validate that the environment supports the requested attention backend.""" + # Check torchao.prototype.attention is importable + try: + importlib.import_module("torchao.prototype.attention") + except (ImportError, ModuleNotFoundError): + raise ImportError( + f"attention_backend={attention_backend!r} requires `torchao.prototype.attention`. " + "Please install a version of torchao that includes the prototype attention module." + ) + + # Check PyTorch >= 2.11.0 + torch_version_parsed = version.parse(importlib.metadata.version("torch")) + if torch_version_parsed < version.parse("2.11.0"): + raise RuntimeError( + f"attention_backend={attention_backend!r} requires PyTorch >= 2.11.0, " + f"but the current version is {torch_version_parsed}." + ) + + # Check CUDA available with SM90+ (Hopper) + if not torch.cuda.is_available(): + raise RuntimeError( + f"attention_backend={attention_backend!r} requires CUDA." + ) + major, minor = torch.cuda.get_device_capability() + if major < 9: + raise RuntimeError( + f"attention_backend={attention_backend!r} requires Hopper GPU (SM90+), " + f"but the current device has SM{major}{minor}." + ) + + # Check FA3 availability + try: + importlib.import_module("flash_attn_interface") + except (ImportError, ModuleNotFoundError): + raise ImportError( + f"attention_backend={attention_backend!r} requires the flash-attn package with FA3 support. " + "Please install flash-attn with FA3 support." + ) + def update_torch_dtype(self, torch_dtype): quant_type = self.quantization_config.quant_type + if quant_type is None: + if torch_dtype is None: + torch_dtype = torch.bfloat16 + return torch_dtype + if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")): if torch_dtype is not None and torch_dtype != torch.bfloat16: logger.warning( @@ -220,6 +270,9 @@ def update_torch_dtype(self, torch_dtype): def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": quant_type = self.quantization_config.quant_type + if quant_type is None: + return target_dtype + from accelerate.utils import CustomDtype if isinstance(quant_type, str): @@ -283,6 +336,9 @@ def check_if_quantized_param( state_dict: dict[str, Any], **kwargs, ) -> bool: + if self.quantization_config.quant_type is None: + return False + param_device = kwargs.pop("param_device", None) # Check if the param_name is not in self.modules_to_not_convert if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert): @@ -337,6 +393,9 @@ def get_cuda_warm_up_factor(self): - Use a division factor of 8 for int4 weights - Use a division factor of 4 for int8 weights """ + if self.quantization_config.quant_type is None: + return 4 + # Original mapping for non-AOBaseConfig types # For the uint types, this is a best guess. Once these types become more used # we can look into their nuances. @@ -368,6 +427,13 @@ def _process_model_before_weight_loading( keep_in_fp32_modules: list[str] = [], **kwargs, ): + model.config.quantization_config = self.quantization_config + + if self.quantization_config.quant_type is None: + # Attention-only mode: no weight quantization setup needed + self.modules_to_not_convert = [] + return + self.modules_to_not_convert = self.quantization_config.modules_to_not_convert if not isinstance(self.modules_to_not_convert, list): @@ -386,11 +452,53 @@ def _process_model_before_weight_loading( # and tied modules are usually kept in FP32. self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] - model.config.quantization_config = self.quantization_config - def _process_model_after_weight_loading(self, model: "ModelMixin"): + attention_backend = getattr(self.quantization_config, "attention_backend", None) + if attention_backend is not None: + self._apply_low_precision_attention(model, attention_backend) return model + def _apply_low_precision_attention(self, model, attention_backend): + """Apply low-precision attention via forward hooks. + + Uses forward pre/post hooks to monkey-patch F.scaled_dot_product_attention with + the FP8 custom op during model forward, and sets the torch.compile pre-grad + fusion pass for RoPE fusion. + """ + import torch._inductor.config as inductor_config + import torch.nn.functional as F + from torch.nn.attention import activate_flash_attention_impl, restore_flash_attention_impl + + from torchao.prototype.attention.fp8_fa3.attention import _ops + from torchao.prototype.attention.shared_utils.fusion_utils import rope_sdpa_fusion_pass + from torchao.prototype.attention.shared_utils.wrapper import _make_causal_aware_sdpa + + # Diffusion models don't use causal masks + sdpa_patch_fn = _make_causal_aware_sdpa(_ops.fp8_sdpa_op, strip_causal_mask=False) + + # Set the torch.compile fusion pass for RoPE fusion + inductor_config.pre_grad_custom_pass = partial( + rope_sdpa_fusion_pass, + rope_sdpa_op=_ops.rope_sdpa_op, + fp8_sdpa_op=_ops.fp8_sdpa_op, + backend_name="FA3", + ) + + flash_impl_name = "FA3" + + def _pre_hook(module, args, kwargs=None): + activate_flash_attention_impl(flash_impl_name) + module._original_sdpa = F.scaled_dot_product_attention + F.scaled_dot_product_attention = sdpa_patch_fn + + def _post_hook(module, args, output, kwargs=None): + F.scaled_dot_product_attention = module._original_sdpa + del module._original_sdpa + restore_flash_attention_impl() + + model.register_forward_pre_hook(_pre_hook, with_kwargs=True) + model.register_forward_hook(_post_hook, with_kwargs=True) + def is_serializable(self, safe_serialization=None): # TODO(aryan): needs to be tested if safe_serialization: @@ -417,7 +525,12 @@ def is_serializable(self, safe_serialization=None): @property def is_trainable(self): - return self.quantization_config.quant_type.startswith("int8") + quant_type = self.quantization_config.quant_type + if quant_type is None: + return False + if isinstance(quant_type, str): + return quant_type.startswith("int8") + return False @property def is_compileable(self) -> bool: