-
Notifications
You must be signed in to change notification settings - Fork 6.9k
[WIP] [core] fix group offloading when using torchao #13276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
sayakpaul
wants to merge
1
commit into
main
Choose a base branch
from
fix-torchao-groupoffloading
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+67
−8
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,7 @@ | |
| import safetensors.torch | ||
| import torch | ||
|
|
||
| from ..utils import get_logger, is_accelerate_available | ||
| from ..utils import get_logger, is_accelerate_available, is_torchao_available | ||
| from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS | ||
| from .hooks import HookRegistry, ModelHook | ||
|
|
||
|
|
@@ -35,6 +35,41 @@ | |
| logger = get_logger(__name__) # pylint: disable=invalid-name | ||
|
|
||
|
|
||
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
    """Return ``True`` when ``tensor`` is a TorchAO quantized tensor subclass.

    Returns ``False`` outright if the ``torchao`` package is not installed, so
    callers can use this as a cheap guard without importing torchao themselves.
    """
    if is_torchao_available():
        # Imported lazily so this module stays importable without torchao.
        from torchao.utils import TorchAOBaseTensor

        return isinstance(tensor, TorchAOBaseTensor)
    return False
|
|
||
|
|
||
| def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]: | ||
| """Get names of all internal tensor data attributes from a TorchAO tensor.""" | ||
| cls = type(tensor) | ||
| names = list(getattr(cls, "tensor_data_names", [])) | ||
| for attr_name in getattr(cls, "optional_tensor_data_names", []): | ||
| if getattr(tensor, attr_name, None) is not None: | ||
| names.append(attr_name) | ||
| return names | ||
|
|
||
|
|
||
def _update_torchao_tensor_in_place(param: torch.Tensor, source: torch.Tensor) -> None:
    """Copy every internal tensor attribute of ``source`` onto ``param`` in-place.

    The mutation must target the parameter/buffer object itself rather than
    ``param.data``: ``_make_wrapper_subclass`` returns a fresh wrapper each
    time ``.data`` is accessed, so attribute changes applied to ``.data``
    would be silently discarded.
    """
    inner_names = _get_torchao_inner_tensor_names(source)
    for name in inner_names:
        replacement = getattr(source, name)
        setattr(param, name, replacement)
|
|
||
|
|
||
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
    """Call ``record_stream(stream)`` on each internal tensor of a TorchAO parameter.

    Needed because ``record_stream`` on the wrapper subclass does not reach the
    underlying quantized data tensors — presumably; verify against torchao.
    """
    inner_names = _get_torchao_inner_tensor_names(param)
    for name in inner_names:
        inner_tensor = getattr(param, name)
        inner_tensor.record_stream(stream)
|
|
||
|
|
||
| # fmt: off | ||
| _GROUP_OFFLOADING = "group_offloading" | ||
| _LAYER_EXECUTION_TRACKER = "layer_execution_tracker" | ||
|
|
@@ -157,9 +192,16 @@ def _pinned_memory_tensors(self): | |
| pinned_dict = None | ||
|
|
||
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
    """Move ``source_tensor`` to the onload device and bind the result to ``tensor``.

    TorchAO tensor subclasses cannot be swapped via ``tensor.data`` assignment
    (the wrapper rebuilds ``.data`` on every access), so for those we update
    the internal tensors attribute-by-attribute instead. When
    ``self.record_stream`` is set, the moved storage is registered with
    ``default_stream`` so the allocator keeps it alive across streams.
    """
    # The subclass check cannot change between the two uses below, so hoist it.
    is_torchao = _is_torchao_tensor(tensor)
    moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
    if is_torchao:
        _update_torchao_tensor_in_place(tensor, moved)
    else:
        tensor.data = moved
    if self.record_stream:
        if is_torchao:
            _record_stream_torchao_tensor(tensor, default_stream)
        else:
            tensor.data.record_stream(default_stream)
|
|
||
| def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None): | ||
| for group_module in self.modules: | ||
|
|
@@ -245,18 +287,35 @@ def _offload_to_memory(self): | |
|
|
||
| for group_module in self.modules: | ||
| for param in group_module.parameters(): | ||
| param.data = self.cpu_param_dict[param] | ||
| if _is_torchao_tensor(param): | ||
| _update_torchao_tensor_in_place(param, self.cpu_param_dict[param]) | ||
| else: | ||
| param.data = self.cpu_param_dict[param] | ||
| for param in self.parameters: | ||
| param.data = self.cpu_param_dict[param] | ||
| if _is_torchao_tensor(param): | ||
| _update_torchao_tensor_in_place(param, self.cpu_param_dict[param]) | ||
| else: | ||
| param.data = self.cpu_param_dict[param] | ||
| for buffer in self.buffers: | ||
| buffer.data = self.cpu_param_dict[buffer] | ||
| if _is_torchao_tensor(buffer): | ||
| _update_torchao_tensor_in_place(buffer, self.cpu_param_dict[buffer]) | ||
| else: | ||
| buffer.data = self.cpu_param_dict[buffer] | ||
| else: | ||
| for group_module in self.modules: | ||
| group_module.to(self.offload_device, non_blocking=False) | ||
| for param in self.parameters: | ||
| param.data = param.data.to(self.offload_device, non_blocking=False) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I remember hearing from Brian and Alban before that `param.data` is a private API and we should not rely on it; I think it also does not work with tensor subclasses. |
||
| if _is_torchao_tensor(param): | ||
| moved = param.data.to(self.offload_device, non_blocking=False) | ||
| _update_torchao_tensor_in_place(param, moved) | ||
| else: | ||
| param.data = param.data.to(self.offload_device, non_blocking=False) | ||
| for buffer in self.buffers: | ||
| buffer.data = buffer.data.to(self.offload_device, non_blocking=False) | ||
| if _is_torchao_tensor(buffer): | ||
| moved = buffer.data.to(self.offload_device, non_blocking=False) | ||
| _update_torchao_tensor_in_place(buffer, moved) | ||
| else: | ||
| buffer.data = buffer.data.to(self.offload_device, non_blocking=False) | ||
|
|
||
| @torch.compiler.disable() | ||
| def onload_(self): | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this mean the
`to` op is not implemented properly for torchao tensors? If you have a minimal repro, we might be able to fix it, I think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.