Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 67 additions & 8 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import safetensors.torch
import torch

from ..utils import get_logger, is_accelerate_available
from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook

Expand All @@ -35,6 +35,41 @@
logger = get_logger(__name__) # pylint: disable=invalid-name


def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
    """Return True iff ``tensor`` is an instance of a TorchAO quantized tensor subclass."""
    if is_torchao_available():
        # Imported lazily because torchao is an optional dependency.
        from torchao.utils import TorchAOBaseTensor

        return isinstance(tensor, TorchAOBaseTensor)
    return False


def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
cls = type(tensor)
names = list(getattr(cls, "tensor_data_names", []))
for attr_name in getattr(cls, "optional_tensor_data_names", []):
if getattr(tensor, attr_name, None) is not None:
names.append(attr_name)
return names


def _update_torchao_tensor_in_place(param: torch.Tensor, source: torch.Tensor) -> None:
    """Copy ``source``'s inner tensor attributes onto ``param``, mutating it in place.

    The mutation must target the parameter/buffer object itself rather than
    ``param.data``: ``_make_wrapper_subclass`` builds a fresh wrapper every time
    ``.data`` is accessed, so attributes set on ``param.data`` would be silently lost.
    """
    inner_names = _get_torchao_inner_tensor_names(source)
    for name in inner_names:
        setattr(param, name, getattr(source, name))


def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
    """Call ``record_stream(stream)`` on every inner tensor of a TorchAO parameter.

    Recording only the wrapper object would not cover the inner tensors, whose
    memory the caching allocator could otherwise reuse while ``stream`` still
    operates on it.
    """
    inner_tensors = (getattr(param, name) for name in _get_torchao_inner_tensor_names(param))
    for inner in inner_tensors:
        inner.record_stream(stream)


# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
Expand Down Expand Up @@ -157,9 +192,16 @@ def _pinned_memory_tensors(self):
pinned_dict = None

def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
    """Move ``source_tensor``'s data onto the onload device and install it into ``tensor``.

    TorchAO tensor subclasses must be updated attribute-by-attribute on the
    parameter object itself: plain ``tensor.data = moved`` does not propagate to
    the wrapper's inner tensors (e.g. ``qdata`` stays on the old device), because
    ``_make_wrapper_subclass`` returns a fresh wrapper from ``.data`` each time.

    Args:
        tensor: the parameter/buffer to update in place.
        source_tensor: the tensor whose data is moved to ``self.onload_device``.
        default_stream: stream passed to ``record_stream`` when
            ``self.record_stream`` is set.
    """
    moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
    # Hoisted: the same check gates both the install and the record_stream path.
    is_torchao = _is_torchao_tensor(tensor)
    if is_torchao:
        _update_torchao_tensor_in_place(tensor, moved)
    else:
        tensor.data = moved
    if self.record_stream:
        # For TorchAO wrappers, every inner tensor must record the stream,
        # not just the outer wrapper.
        if is_torchao:
            _record_stream_torchao_tensor(tensor, default_stream)
        else:
            tensor.data.record_stream(default_stream)

def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
for group_module in self.modules:
Expand Down Expand Up @@ -245,18 +287,35 @@ def _offload_to_memory(self):

for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_update_torchao_tensor_in_place(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for param in self.parameters:
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_update_torchao_tensor_in_place(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
if _is_torchao_tensor(buffer):
_update_torchao_tensor_in_place(buffer, self.cpu_param_dict[buffer])
else:
buffer.data = self.cpu_param_dict[buffer]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=False)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember hearing from Brian and Alban that `param.data` is a private API and we should not rely on it; I think it also does not work with tensor subclasses.

if _is_torchao_tensor(param):
moved = param.data.to(self.offload_device, non_blocking=False)
_update_torchao_tensor_in_place(param, moved)
else:
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
if _is_torchao_tensor(buffer):
moved = buffer.data.to(self.offload_device, non_blocking=False)
_update_torchao_tensor_in_place(buffer, moved)
else:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)

@torch.compiler.disable()
def onload_(self):
Expand Down
Loading