diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 891ac28455af..04be39056656 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -22,7 +22,7 @@
 import safetensors.torch
 import torch
 
-from ..utils import get_logger, is_accelerate_available
+from ..utils import get_logger, is_accelerate_available, is_torchao_available
 from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
@@ -35,6 +35,41 @@
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
+def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
+    """Check if a tensor is a TorchAO quantized tensor subclass."""
+    if not is_torchao_available():
+        return False
+    from torchao.utils import TorchAOBaseTensor
+
+    return isinstance(tensor, TorchAOBaseTensor)
+
+
+def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
+    """Get names of all internal tensor data attributes from a TorchAO tensor."""
+    cls = type(tensor)
+    names = list(getattr(cls, "tensor_data_names", []))
+    for attr_name in getattr(cls, "optional_tensor_data_names", []):
+        if getattr(tensor, attr_name, None) is not None:
+            names.append(attr_name)
+    return names
+
+
+def _update_torchao_tensor_in_place(param: torch.Tensor, source: torch.Tensor) -> None:
+    """Update internal tensor data of a TorchAO parameter in-place from source.
+
+    Must operate on the parameter/buffer object directly (not ``param.data``) because ``_make_wrapper_subclass``
+    returns a fresh wrapper from ``.data`` each time, so attribute mutations on ``.data`` are lost.
+    """
+    for attr_name in _get_torchao_inner_tensor_names(source):
+        setattr(param, attr_name, getattr(source, attr_name))
+
+
+def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
+    """Record stream for all internal tensors of a TorchAO parameter."""
+    for attr_name in _get_torchao_inner_tensor_names(param):
+        getattr(param, attr_name).record_stream(stream)
+
+
 # fmt: off
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -157,9 +192,16 @@ def _pinned_memory_tensors(self):
             pinned_dict = None
 
     def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
-        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        if _is_torchao_tensor(tensor):
+            _update_torchao_tensor_in_place(tensor, moved)
+        else:
+            tensor.data = moved
         if self.record_stream:
-            tensor.data.record_stream(default_stream)
+            if _is_torchao_tensor(tensor):
+                _record_stream_torchao_tensor(tensor, default_stream)
+            else:
+                tensor.data.record_stream(default_stream)
 
     def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
         for group_module in self.modules:
@@ -245,18 +287,35 @@ def _offload_to_memory(self):
 
             for group_module in self.modules:
                 for param in group_module.parameters():
-                    param.data = self.cpu_param_dict[param]
+                    if _is_torchao_tensor(param):
+                        _update_torchao_tensor_in_place(param, self.cpu_param_dict[param])
+                    else:
+                        param.data = self.cpu_param_dict[param]
             for param in self.parameters:
-                param.data = self.cpu_param_dict[param]
+                if _is_torchao_tensor(param):
+                    _update_torchao_tensor_in_place(param, self.cpu_param_dict[param])
+                else:
+                    param.data = self.cpu_param_dict[param]
             for buffer in self.buffers:
-                buffer.data = self.cpu_param_dict[buffer]
+                if _is_torchao_tensor(buffer):
+                    _update_torchao_tensor_in_place(buffer, self.cpu_param_dict[buffer])
+                else:
+                    buffer.data = self.cpu_param_dict[buffer]
         else:
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=False)
+                if _is_torchao_tensor(param):
+                    moved = param.data.to(self.offload_device, non_blocking=False)
+                    _update_torchao_tensor_in_place(param, moved)
+                else:
+                    param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+                if _is_torchao_tensor(buffer):
+                    moved = buffer.data.to(self.offload_device, non_blocking=False)
+                    _update_torchao_tensor_in_place(buffer, moved)
+                else:
+                    buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
 
     @torch.compiler.disable()
     def onload_(self):