
torchao >= 0.16.0 quantization not supported #13286

@zzlol63


Describe the bug

The sample code below (taken from https://huggingface.co/blog/lora-fast) fails because torchao renamed these APIs. The rename shipped in 0.15.0 as a breaking change (with a deprecation warning), as per the release notes:
https://github.com/pytorch/ao/releases/tag/v0.15.0

Before:

```python
from torchao.quantization import (
    float8_dynamic_activation_float8_weight,
    float8_static_activation_float8_weight,
    float8_weight_only,
    fpx_weight_only,
    gemlite_uintx_weight_only,
    int4_dynamic_activation_int4_weight,
    int4_weight_only,
    int8_dynamic_activation_int4_weight,
    int8_dynamic_activation_int8_weight,
    int8_weight_only,
    quantize_,
    uintx_weight_only,
)
```

After:

```python
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8StaticActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    FPXWeightOnlyConfig,
    GemliteUIntXWeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    quantize_,
    UIntXWeightOnlyConfig,
)
```

In 0.16.0, the old names are removed entirely, so the deprecation path no longer works and the import fails outright.
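The rename itself is mechanical: each old function-style name maps to a CamelCase `*Config` class (only `quantize_` is unchanged). A minimal sketch of that mapping, built purely from the before/after lists above, so the affected names are easy to check programmatically (this is illustrative, not diffusers or torchao code):

```python
# Old function-style torchao names -> new Config-class names.
# Pairs copied from the before/after import lists in this issue.
OLD_TO_NEW = {
    "float8_dynamic_activation_float8_weight": "Float8DynamicActivationFloat8WeightConfig",
    "float8_static_activation_float8_weight": "Float8StaticActivationFloat8WeightConfig",
    "float8_weight_only": "Float8WeightOnlyConfig",
    "fpx_weight_only": "FPXWeightOnlyConfig",
    "gemlite_uintx_weight_only": "GemliteUIntXWeightOnlyConfig",
    "int4_dynamic_activation_int4_weight": "Int4DynamicActivationInt4WeightConfig",
    "int4_weight_only": "Int4WeightOnlyConfig",
    "int8_dynamic_activation_int4_weight": "Int8DynamicActivationInt4WeightConfig",
    "int8_dynamic_activation_int8_weight": "Int8DynamicActivationInt8WeightConfig",
    "int8_weight_only": "Int8WeightOnlyConfig",
    "uintx_weight_only": "UIntXWeightOnlyConfig",
}

def new_name(old: str) -> str:
    """Return the post-0.15 name for an old torchao quant API name.

    Names not in the mapping (e.g. quantize_) are returned unchanged.
    """
    return OLD_TO_NEW.get(old, old)
```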

Reproduction

```python
from diffusers import DiffusionPipeline, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig
import torch

# quantize the Flux transformer with FP8
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    quantization_config=PipelineQuantizationConfig(
        quant_mapping={"transformer": TorchAoConfig("float8dq_e4m3_row")}
    )
).to("cuda")

# use torch.compile()
pipe.transformer.compile(fullgraph=True, mode="max-autotune")

# perform inference
pipe_kwargs = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 28,
    "max_sequence_length": 512,
}

# first time will be slower, subsequent runs will be faster
image = pipe(**pipe_kwargs).images[0]
```

Logs

```
File "C:\test\test.py", line 485, in main
    quant_mapping={"transformer": TorchAoConfig("float8dq_e4m3_row")}
  File "C:\Users\Home\anaconda3\envs\test\lib\site-packages\diffusers\quantizers\quantization_config.py", line 517, in __init__
    self.post_init()
  File "C:\Users\Home\anaconda3\envs\test\lib\site-packages\diffusers\quantizers\quantization_config.py", line 533, in post_init
    TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
  File "C:\Users\Home\anaconda3\envs\test\lib\site-packages\diffusers\quantizers\quantization_config.py", line 629, in _get_torchao_quant_type_to_method
    from torchao.quantization import (
ImportError: cannot import name 'float8_dynamic_activation_float8_weight' from 'torchao.quantization'
```
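One way diffusers could stay compatible with torchao releases on both sides of the rename is a try-the-new-name, fall-back-to-the-old import. A generic sketch of that pattern, with a hypothetical helper name (not actual diffusers code), demonstrated against a stdlib module so it runs without torchao installed:

```python
import importlib

def import_with_fallback(module_name: str, new_name: str, old_name: str):
    """Fetch `new_name` from `module_name`, falling back to `old_name`.

    Hypothetical helper: on torchao >= 0.15 the new Config-class name
    (e.g. Float8DynamicActivationFloat8WeightConfig) resolves; on older
    releases we fall back to the function-style name
    (e.g. float8_dynamic_activation_float8_weight).
    """
    mod = importlib.import_module(module_name)
    try:
        return getattr(mod, new_name)
    except AttributeError:
        return getattr(mod, old_name)
```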

System Info

  • 🤗 Diffusers version: 0.37.0
  • Platform: Windows-10-10.0.26200-SP0
  • Running on Google Colab?: No
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.10.0+cu130 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 1.7.1
  • Transformers version: 5.3.0
  • Accelerate version: 1.10.1
  • PEFT version: 0.18.1
  • Bitsandbytes version: 0.49.2
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 5090, 32607 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

No response
