Skip to content

[WIP] [core] fix group offloading when using torchao#13276

Draft
sayakpaul wants to merge 1 commit intomainfrom
fix-torchao-groupoffloading
Draft

[WIP] [core] fix group offloading when using torchao#13276
sayakpaul wants to merge 1 commit intomainfrom
fix-torchao-groupoffloading

Conversation

@sayakpaul
Copy link
Member

What does this PR do?

This PR is opened to discuss if these changes should rather be made to TorchAO or they should live in Diffusers.

The benefit of this is that many new releases benefit from quantization schemes robustly implemented and tested in TorchAO. But quantization alone rarely helps, we need offloading too. Many large models need group offloading (overlapping compute with data transfer).

Problem

Group offloading moves parameters between CPU and GPU by reassigning param.data:

param.data = source_tensor.to(device) 

This works for regular tensors but breaks for TorchAO quantized tensors.

TorchAO tensors are special instances that store their actual data in internal attributes (e.g., .qdata, .scale), not in the standard tensor storage. The .data assignment replaces the
outer wrapper storage but leaves these internal attributes on the original device, causing a device mismatch at compute time.

A further subtlety: accessing .data on a wrapper subclass parameter returns a new wrapper object each time, so mutating attributes on param.data doesn't persist either.

This PR

For TorchAO tensors, instead of reassigning data, we update the internal tensor attributes directly on the parameter object itself:

# Before (broken for TorchAO tensors)                                                                                                    
param.data = source_tensor.to(device)                                                                                                    
                                                                                                                                         
# After                                                                                                                                  
moved = source_tensor.to(device)                                                                                                         
if _is_torchao_tensor(param):                                                                                                            
    for attr in tensor_data_names:  # e.g. ["qdata", "scale"]                                                                            
        setattr(param, attr, getattr(moved, attr))                                                                                       
else:                                                                                                                                    
    param.data = moved     

Related issue: pytorch/ao#4088.

Happens with nightlies as well.

Code to test: https://gist.github.com/sayakpaul/929678132809874c5dbf9c5215460d33#file-check_torchao_offload_compile-py (run with --quantize, --group-offload; and potentially with --full-compile).

Nice results (with quantization + group offloading + full compile):

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:32<00:00,  8.06s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.18s/it]

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

pinned_dict = None

def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this mean the to op is not implemented properly for torchao tensors?

if you have a minimal repro, we might be able to fix I think

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

linear = torch.nn.Linear(64, 64, dtype=torch.bfloat16)
quantize_(linear, Int8WeightOnlyConfig(version=2))
p = linear.weight

# Move a copy to CUDA and assign via .data
cpu_copy = p.data.cpu()
cuda_copy = cpu_copy.to("cuda")
p.data = cuda_copy

print(f"p.qdata.device = {p.qdata.device}")  # cpu
print(f"cuda_copy.qdata.device = {cuda_copy.qdata.device}")  # cuda:0

# Forward fails: input on cuda, weight internals still on cpu
linear.bias.data = linear.bias.data.to("cuda")
x = torch.randn(1, 64, device="cuda", dtype=torch.bfloat16)
linear(x)  # RuntimeError: mat2 is on cpu

for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=False)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember hearing from Brian and Alban before that param.data is a private API and we should not rely on it, I think it also does not work with tensor subclasses

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants