Commit d49eb2e

remove str option for quantization config in torchao
1 parent 6761336 · commit d49eb2e

4 files changed: 153 additions & 567 deletions

File tree

docs/source/en/quantization/torchao.md

Lines changed: 6 additions & 32 deletions
@@ -29,24 +29,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf
 from torchao.quantization import Int8WeightOnlyConfig
 
 pipeline_quant_config = PipelineQuantizationConfig(
-    quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128)))}
-)
-pipeline = DiffusionPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-dev",
-    quantization_config=pipeline_quant_config,
-    torch_dtype=torch.bfloat16,
-    device_map="cuda"
-)
-```
-
-For simple use cases, you could also provide a string identifier in [`TorchAo`] as shown below.
-
-```py
-import torch
-from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
-
-pipeline_quant_config = PipelineQuantizationConfig(
-    quant_mapping={"transformer": TorchAoConfig("int8wo")}
+    quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128))}
 )
 pipeline = DiffusionPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
@@ -91,17 +74,6 @@ Weight-only quantization stores the model weights in a specific low-bit data typ
 
 Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.
 
-The quantization methods supported are as follows:
-
-| **Category** | **Full Function Names** | **Shorthands** |
-|--------------|-------------------------|----------------|
-| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
-| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8dq_e4m3_tensor`, `float8dq_e4m3_row` |
-| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
-| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
-
-Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
-
 Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
 
 ## Serializing and Deserializing quantized models
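
The dynamic activation quantization described in the retained paragraph maps onto the same config-object pattern. A minimal sketch, assuming torchao exposes `Int8DynamicActivationInt8WeightConfig` as the class counterpart of the removed `int8dq` shorthand:

```py
import torch
from diffusers import AutoModel, TorchAoConfig

# Assumed torchao config class for the former "int8dq" shorthand
from torchao.quantization import Int8DynamicActivationInt8WeightConfig

# Weights are stored in int8; activations are quantized to int8 on the fly
quantization_config = TorchAoConfig(Int8DynamicActivationInt8WeightConfig())
transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```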
@@ -111,8 +83,9 @@ To serialize a quantized model in a given dtype, first load the model with the d
 ```python
 import torch
 from diffusers import AutoModel, TorchAoConfig
+from torchao.quantization import Int8WeightOnlyConfig
 
-quantization_config = TorchAoConfig("int8wo")
+quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
 transformer = AutoModel.from_pretrained(
     "black-forest-labs/Flux.1-Dev",
     subfolder="transformer",
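
This hunk only covers loading with the new config object. To complete the round trip the section describes, a hedged sketch continuing from the `transformer` loaded above (the `/path/to/flux_int8wo` path is illustrative; `use_safetensors=False` matches the non-safetensors save used later in this commit):

```py
import torch
from diffusers import AutoModel, FluxPipeline

# Save the int8 weight-only transformer, then reload it into a Flux pipeline
transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
transformer = AutoModel.from_pretrained(
    "/path/to/flux_int8wo",
    torch_dtype=torch.bfloat16,
    use_safetensors=False,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/Flux.1-Dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
```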
@@ -137,18 +110,19 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
 image.save("output.png")
 ```
 
-If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
+If you are using `torch<=2.6.0`, some quantization methods, such as `uint4` weight-only, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
 
 ```python
 import torch
 from accelerate import init_empty_weights
 from diffusers import FluxPipeline, AutoModel, TorchAoConfig
+from torchao.quantization import IntxWeightOnlyConfig
 
 # Serialize the model
 transformer = AutoModel.from_pretrained(
     "black-forest-labs/Flux.1-Dev",
     subfolder="transformer",
-    quantization_config=TorchAoConfig("uint4wo"),
+    quantization_config=TorchAoConfig(IntxWeightOnlyConfig(dtype=torch.uint4)),
     torch_dtype=torch.bfloat16,
 )
 transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
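
The workaround paragraph in this hunk stops short of the actual load. A minimal sketch of the manual state-dict load it describes; the checkpoint filename and the use of `FluxTransformer2DModel` are assumptions not shown in this diff, and `weights_only=False` unpickles arbitrary code, so use it only on checkpoints you trust:

```py
import torch
from accelerate import init_empty_weights
from diffusers import FluxTransformer2DModel

# Rebuild an empty model skeleton from the saved config, then load the
# pickled state dict manually (needed for torch<=2.6.0 with uint4 weights)
config = FluxTransformer2DModel.load_config("/path/to/flux_uint4wo")
with init_empty_weights():
    transformer = FluxTransformer2DModel.from_config(config)
state_dict = torch.load(
    "/path/to/flux_uint4wo/diffusion_pytorch_model.bin",  # assumed shard name
    weights_only=False,
    map_location="cpu",
)
transformer.load_state_dict(state_dict, strict=True, assign=True)
```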
