Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions src/diffusers/quantizers/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.

Args:
quant_type (`str` | AOBaseConfig):
quant_type (`str` | AOBaseConfig | None):
The type of quantization we want to use, currently supporting:
- **Integer quantization:**
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
Expand All @@ -469,9 +469,14 @@ class TorchAoConfig(QuantizationConfigMixin):
- Full function names: `uintx_weight_only`
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
- An AOBaseConfig instance: for more advanced configuration options.
- `None`: when only `attention_backend` is used without weight quantization.
modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision.
attention_backend (`str`, *optional*, defaults to `None`):
Low-precision attention backend to use. Currently supported: `"fp8_fa3"` (FP8 attention using Flash
Attention 3, requires Hopper GPU with SM90+). This is orthogonal to weight quantization — you can use
either or both. When used with `torch.compile`, RoPE fusion is automatically enabled.
kwargs (`dict[str, Any]`, *optional*):
The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization
supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and
Expand All @@ -495,18 +500,28 @@ class TorchAoConfig(QuantizationConfigMixin):
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)

# FP8 attention only (no weight quantization)
quantization_config = TorchAoConfig(attention_backend="fp8_fa3")

# Combined: weight quantization + FP8 attention
quantization_config = TorchAoConfig("int8wo", attention_backend="fp8_fa3")
```
"""

_SUPPORTED_ATTENTION_BACKENDS = {"fp8_fa3"}

def __init__(
self,
quant_type: str | "AOBaseConfig", # noqa: F821
quant_type: str | "AOBaseConfig" | None = None, # noqa: F821
modules_to_not_convert: list[str] | None = None,
attention_backend: str | None = None,
**kwargs,
) -> None:
self.quant_method = QuantizationMethod.TORCHAO
self.quant_type = quant_type
self.modules_to_not_convert = modules_to_not_convert
self.attention_backend = attention_backend

# When we load from serialized config, "quant_type_kwargs" will be the key
if "quant_type_kwargs" in kwargs:
Expand All @@ -517,6 +532,21 @@ def __init__(
self.post_init()

def post_init(self):
if self.quant_type is None and self.attention_backend is None:
raise ValueError(
"At least one of `quant_type` or `attention_backend` must be provided."
)

if self.attention_backend is not None and self.attention_backend not in self._SUPPORTED_ATTENTION_BACKENDS:
raise ValueError(
f"Unsupported attention_backend: {self.attention_backend!r}. "
f"Supported backends: {self._SUPPORTED_ATTENTION_BACKENDS}"
)

# Skip quant_type validation when only attention_backend is used
if self.quant_type is None:
return

if not isinstance(self.quant_type, str):
if is_torchao_version("<=", "0.9.0"):
raise ValueError(
Expand Down Expand Up @@ -570,6 +600,12 @@ def to_dict(self):
"""Convert configuration to a dictionary."""
d = super().to_dict()

if self.attention_backend is not None:
d["attention_backend"] = self.attention_backend

if self.quant_type is None:
return d

if isinstance(self.quant_type, str):
# Handle layout serialization if present
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
Expand Down Expand Up @@ -600,10 +636,11 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
if not is_torchao_version(">", "0.9.0"):
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
config_dict = config_dict.copy()
quant_type = config_dict.pop("quant_type")
quant_type = config_dict.pop("quant_type", None)
attention_backend = config_dict.pop("attention_backend", None)

if isinstance(quant_type, str):
return cls(quant_type=quant_type, **config_dict)
if quant_type is None or isinstance(quant_type, str):
return cls(quant_type=quant_type, attention_backend=attention_backend, **config_dict)
# Check if we only have one key which is "default"
# In the future we may update this
assert len(quant_type) == 1 and "default" in quant_type, (
Expand All @@ -616,7 +653,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):

quant_type = config_from_dict(quant_type)

return cls(quant_type=quant_type, **config_dict)
return cls(quant_type=quant_type, attention_backend=attention_backend, **config_dict)

@classmethod
def _get_torchao_quant_type_to_method(cls):
Expand Down
119 changes: 116 additions & 3 deletions src/diffusers/quantizers/torchao/torchao_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import re
import types
from fnmatch import fnmatch
from functools import partial
from typing import TYPE_CHECKING, Any

from packaging import version
Expand Down Expand Up @@ -198,8 +199,57 @@ def validate_environment(self, *args, **kwargs):
f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
)

attention_backend = getattr(self.quantization_config, "attention_backend", None)
if attention_backend is not None:
self._validate_attention_environment(attention_backend)

def _validate_attention_environment(self, attention_backend):
"""Validate that the environment supports the requested attention backend."""
# Check torchao.prototype.attention is importable
try:
importlib.import_module("torchao.prototype.attention")
except (ImportError, ModuleNotFoundError):
raise ImportError(
f"attention_backend={attention_backend!r} requires `torchao.prototype.attention`. "
"Please install a version of torchao that includes the prototype attention module."
)

# Check PyTorch >= 2.11.0
torch_version_parsed = version.parse(importlib.metadata.version("torch"))
if torch_version_parsed < version.parse("2.11.0"):
raise RuntimeError(
f"attention_backend={attention_backend!r} requires PyTorch >= 2.11.0, "
f"but the current version is {torch_version_parsed}."
)

# Check CUDA available with SM90+ (Hopper)
if not torch.cuda.is_available():
raise RuntimeError(
f"attention_backend={attention_backend!r} requires CUDA."
)
major, minor = torch.cuda.get_device_capability()
if major < 9:
raise RuntimeError(
f"attention_backend={attention_backend!r} requires Hopper GPU (SM90+), "
f"but the current device has SM{major}{minor}."
)

# Check FA3 availability
try:
importlib.import_module("flash_attn_interface")
except (ImportError, ModuleNotFoundError):
raise ImportError(
f"attention_backend={attention_backend!r} requires the flash-attn package with FA3 support. "
"Please install flash-attn with FA3 support."
)

def update_torch_dtype(self, torch_dtype):
quant_type = self.quantization_config.quant_type
if quant_type is None:
if torch_dtype is None:
torch_dtype = torch.bfloat16
return torch_dtype

if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
Expand All @@ -220,6 +270,9 @@ def update_torch_dtype(self, torch_dtype):

def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
quant_type = self.quantization_config.quant_type
if quant_type is None:
return target_dtype

from accelerate.utils import CustomDtype

if isinstance(quant_type, str):
Expand Down Expand Up @@ -283,6 +336,9 @@ def check_if_quantized_param(
state_dict: dict[str, Any],
**kwargs,
) -> bool:
if self.quantization_config.quant_type is None:
return False

param_device = kwargs.pop("param_device", None)
# Check if the param_name is not in self.modules_to_not_convert
if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
Expand Down Expand Up @@ -337,6 +393,9 @@ def get_cuda_warm_up_factor(self):
- Use a division factor of 8 for int4 weights
- Use a division factor of 4 for int8 weights
"""
if self.quantization_config.quant_type is None:
return 4

# Original mapping for non-AOBaseConfig types
# For the uint types, this is a best guess. Once these types become more used
# we can look into their nuances.
Expand Down Expand Up @@ -368,6 +427,13 @@ def _process_model_before_weight_loading(
keep_in_fp32_modules: list[str] = [],
**kwargs,
):
model.config.quantization_config = self.quantization_config

if self.quantization_config.quant_type is None:
# Attention-only mode: no weight quantization setup needed
self.modules_to_not_convert = []
return

self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

if not isinstance(self.modules_to_not_convert, list):
Expand All @@ -386,11 +452,53 @@ def _process_model_before_weight_loading(
# and tied modules are usually kept in FP32.
self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]

model.config.quantization_config = self.quantization_config

def _process_model_after_weight_loading(self, model: "ModelMixin"):
attention_backend = getattr(self.quantization_config, "attention_backend", None)
if attention_backend is not None:
self._apply_low_precision_attention(model, attention_backend)
return model

def _apply_low_precision_attention(self, model, attention_backend):
    """Apply low-precision attention via forward hooks.

    Uses forward pre/post hooks to monkey-patch F.scaled_dot_product_attention with
    the FP8 custom op during model forward, and sets the torch.compile pre-grad
    fusion pass for RoPE fusion.

    Args:
        model: the model whose top-level forward gets the pre/post hooks.
        attention_backend: requested backend name; not branched on below —
            only "fp8_fa3" reaches here (validated in validate_environment).
    """
    import torch._inductor.config as inductor_config
    import torch.nn.functional as F
    from torch.nn.attention import activate_flash_attention_impl, restore_flash_attention_impl

    from torchao.prototype.attention.fp8_fa3.attention import _ops
    from torchao.prototype.attention.shared_utils.fusion_utils import rope_sdpa_fusion_pass
    from torchao.prototype.attention.shared_utils.wrapper import _make_causal_aware_sdpa

    # Diffusion models don't use causal masks
    # NOTE(review): the comment above says diffusion models don't use causal
    # masks, yet strip_causal_mask=False is passed — confirm against
    # `_make_causal_aware_sdpa` that False is the intended value here.
    sdpa_patch_fn = _make_causal_aware_sdpa(_ops.fp8_sdpa_op, strip_causal_mask=False)

    # Set the torch.compile fusion pass for RoPE fusion
    # NOTE(review): this mutates process-global inductor config, so it affects
    # every compiled model in the process, not just `model`.
    inductor_config.pre_grad_custom_pass = partial(
        rope_sdpa_fusion_pass,
        rope_sdpa_op=_ops.rope_sdpa_op,
        fp8_sdpa_op=_ops.fp8_sdpa_op,
        backend_name="FA3",
    )

    flash_impl_name = "FA3"

    # Before each forward: activate the FA3 impl and swap the FP8 op into the
    # global F.scaled_dot_product_attention, stashing the original on `module`.
    def _pre_hook(module, args, kwargs=None):
        activate_flash_attention_impl(flash_impl_name)
        module._original_sdpa = F.scaled_dot_product_attention
        F.scaled_dot_product_attention = sdpa_patch_fn

    # After each forward: undo the patch.
    # NOTE(review): if forward raises, this post hook never runs and the global
    # F.scaled_dot_product_attention stays patched; hooks also accumulate if
    # this method is invoked more than once on the same model — verify callers.
    def _post_hook(module, args, output, kwargs=None):
        F.scaled_dot_product_attention = module._original_sdpa
        del module._original_sdpa
        restore_flash_attention_impl()

    model.register_forward_pre_hook(_pre_hook, with_kwargs=True)
    model.register_forward_hook(_post_hook, with_kwargs=True)

def is_serializable(self, safe_serialization=None):
# TODO(aryan): needs to be tested
if safe_serialization:
Expand All @@ -417,7 +525,12 @@ def is_serializable(self, safe_serialization=None):

@property
def is_trainable(self):
    """True only when the configured quant type is a string starting with
    "int8"; `None`, AOBaseConfig instances, and other strings yield False."""
    quant_type = self.quantization_config.quant_type
    return isinstance(quant_type, str) and quant_type.startswith("int8")

@property
def is_compileable(self) -> bool:
Expand Down