Labels: bug
Status: Open
### Describe the bug

When combining torchao quantization (`TorchAoConfig` with `Float8WeightOnlyConfig`) and group offloading with `use_stream=True`, inference fails with a device-mismatch error: the quantized weight remains on CPU while the input tensor is on CUDA.
### Reproduction

```python
import torch
from diffusers import QwenImageEditPlusPipeline  # or any pipeline
from diffusers.hooks import apply_group_offloading
from diffusers import PipelineQuantizationConfig, TorchAoConfig
from torchao.quantization import Float8WeightOnlyConfig

model_path = "path/to/model"
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

pipe = QwenImageEditPlusPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    quantization_config=PipelineQuantizationConfig(
        quant_mapping={"transformer": TorchAoConfig(Float8WeightOnlyConfig())}
    ),
    device_map="cpu",
)

# This will cause a RuntimeError during inference
pipe.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,  # ← Required for performance, but triggers the bug
    non_blocking=True,
)
pipe.vae.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
    non_blocking=True,
)
apply_group_offloading(
    pipe.text_encoder,
    onload_device=onload_device,
    offload_type="leaf_level",
    use_stream=True,
    non_blocking=True,
)

# Triggers the error (image and prompt prepared elsewhere):
pipe(image, prompt, height=2048, width=2048, num_inference_steps=50).images[0]
```

Note that unlike the official examples, which load the pipeline directly to CUDA via `device_map="cuda"`, I am constrained by VRAM size (which is the main intended use case for group offloading). The components (multi-billion-parameter transformer + text encoder + VAE) are too large to fit into VRAM simultaneously, so the pipeline must be initialized on the CPU (`device_map="cpu"`) before applying group offloading.
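As a triage aid (my addition, not part of the original repro), a forward pre-hook can record the first `nn.Linear` whose weight device disagrees with its input device, which locates the stranded module without waiting for the `RuntimeError`. This is a minimal sketch on a toy CPU-only model, since running it against the real pipeline requires the checkpoint; `attach_device_check` is a hypothetical helper name:

```python
import torch
import torch.nn as nn

def attach_device_check(model):
    """Register pre-hooks on every Linear; collect (name, input device, weight device)
    for any module whose weight is not on the same device as its input."""
    mismatches = []

    def make_hook(name):
        def hook(module, args):
            inp = args[0]
            if inp.device != module.weight.device:
                mismatches.append((name, str(inp.device), str(module.weight.device)))
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            module.register_forward_pre_hook(make_hook(name))
    return mismatches

# Toy model: everything on CPU, so no mismatch is recorded.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
log = attach_device_check(model)
model(torch.randn(1, 4))
print(log)  # → []; on the failing pipeline this should name img_in first
```

On the failing transformer, the same hooks attached before the `pipe(...)` call should report the mismatch at `img_in`, matching the traceback below.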
### Traceback

```
File ".../diffusers/models/transformers/transformer_qwenimage.py", line 896, in forward
    hidden_states = self.img_in(hidden_states)
File ".../torch/nn/modules/linear.py", line 134, in forward
    return F.linear(input, self.weight, self.bias)
File ".../torchao/quantization/quantize_/workflows/float8/float8_tensor.py", line 448, in _float8_addmm_impl
    out = torch.matmul(input_tensor, weight_tensor.dequantize())
RuntimeError: Expected all tensors to be on the same device, but got mat2 is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_mm)
```
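To confirm the diagnosis that the quantized weights were left on CPU after the group-offload hooks ran, one can summarize parameter placement directly. This is a hypothetical helper of my own (shown on a stand-in toy module, since the real check would be `device_summary(pipe.transformer)` after the offload setup):

```python
import collections
import torch.nn as nn

def device_summary(model):
    """Count how many parameters of `model` live on each device."""
    return dict(collections.Counter(str(p.device) for p in model.parameters()))

# Stand-in toy module; a freshly constructed Linear has weight + bias on CPU.
model = nn.Linear(8, 8)
print(device_summary(model))  # → {'cpu': 2}
```

On the failing transformer during inference, a mix of `'cpu'` and `'cuda:0'` entries at the moment of the error would confirm that the stream-based prefetch skipped the torchao-quantized weights.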
### System Info
| Component | Version |
|---|---|
| diffusers | 0.37.0 |
| torch | 2.9.1+cu126 |
| torchao | 0.16.0 |
| accelerate | 1.3.0 |
| Python | 3.11 |
| CUDA | 12.6 |
### Who can help?

No response