# [PyTorch] Bunch of fixes for cpu offloading #2535
Changes from all commits: 791f183, 6d2f43b, e28c581, 056f06e, c8e0984, 3f45dcd, 04912bf, 04b104d, 3ddab74, 14fbfae, 5f7675c, 7df9094, ccf54b9, c7b01a6, b52fa23, 91928b4
```diff
@@ -19,6 +19,7 @@
 from .quantized_tensor import (
     restore_from_saved,
     prepare_for_saving,
+    QuantizedTensor,
 )
```
```diff
@@ -255,6 +256,8 @@ def start_offload(self):
         Start offloading of tensors. Puts copy from GPU to CPU tasks on offload stream.
         Before each copy event, the offload stream waits for the event signalling that the tensor is ready to be offloaded.
         This event is recorded in the start_offload or push_tensor call.
+
+        Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor).
         """
         self._validate_state(func_name="start_offload", allowed_states=["not_offloaded"])
         self.state = "offload_started"
```
```diff
@@ -275,19 +278,18 @@ def start_offload(self):

         with torch.cuda.stream(self.offload_stream):
             if allocate_cpu_buffers:
                 # empty_like is defined also for QuantizedTensors
                 offloaded_tensor = torch.empty_like(
                     tensor, device=torch.device("cpu"), pin_memory=True
                 )
                 self.cpu_tensor_group.tensor_list.append(offloaded_tensor)
             else:
-                assert self.cpu_tensor_group.tensor_list[tensor_id].shape == tensor.shape, (
+                offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
+                assert offloaded_tensor.shape == tensor.shape, (
                     "CPU buffer shape does not match the offloaded tensor shape:"
-                    f" {self.cpu_tensor_group.tensor_list[tensor_id].shape} != {tensor.shape} "
-                    " Make sure that tensor shaped do not change between"
+                    f" {offloaded_tensor.shape} != {tensor.shape} "
+                    "Make sure that tensor shapes do not change between"
                     " iterations if retain_pinned_cpu_buffers is True."
                 )
-                offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
             offloaded_tensor.copy_(tensor, non_blocking=True)

     # aux is a dictionary that contains auxiliary data like information which tensors were deduplicated,
```
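As background on the copy path above: pinning the CPU destination is what lets `copy_(..., non_blocking=True)` on a dedicated stream overlap the device-to-host transfer with compute. A minimal standalone sketch of that pattern (stream and tensor names are illustrative, and a CUDA device is assumed):

```python
import torch

# Minimal sketch of the offload copy pattern; assumes a CUDA device is available.
offload_stream = torch.cuda.Stream()
src = torch.randn(1024, 1024, device="cuda")

# Pinned (page-locked) CPU buffer so the device-to-host copy can run asynchronously.
cpu_buf = torch.empty_like(src, device="cpu", pin_memory=True)

# Ensure the producer stream has finished writing `src` before the copy starts.
offload_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(offload_stream):
    cpu_buf.copy_(src, non_blocking=True)

# Before reading `cpu_buf` (or freeing `src`), wait for the copy to complete.
offload_stream.synchronize()
```

The PR's code replaces the final host-side `synchronize()` with event-based waits so the main stream is never blocked.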
```diff
@@ -318,6 +320,9 @@ def start_reload(self):
         """
         Start reloading of tensors.
         It allocates new tensors on GPU and puts copy from CPU tasks on offload stream.
+
+        Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor
+        and reconstructed in pop_tensor).
         """
         self._validate_state(func_name="start_reload", allowed_states=["offload_finished"])
         self.state = "reload_started"
```
```diff
@@ -330,7 +335,6 @@ def start_reload(self):
             # cannot move tensors from pool of one stream to another without
             # calling cudaFree and cudaMalloc again.

             # empty_like is defined also for QuantizedTensors.
             reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda"))
             self.offload_stream.wait_stream(torch.cuda.current_stream())
```
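The allocator comment in this hunk ("cannot move tensors from pool of one stream to another...") refers to PyTorch's per-stream caching allocator: a tensor allocated under one stream's memory pool that is later used on another stream must either be allocated on that stream or be flagged with `Tensor.record_stream`. A small illustration of the `record_stream` variant (not the exact code path above; assumes a CUDA device):

```python
import torch

# Assumes a CUDA device; names are illustrative.
side_stream = torch.cuda.Stream()
x = torch.empty(1024, 1024, device="cuda")  # allocated under the current stream's pool

side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    x.normal_()                   # tensor is consumed on another stream
    x.record_stream(side_stream)  # inform the caching allocator of the cross-stream use

# Order subsequent work on the default stream after the side-stream work.
torch.cuda.current_stream().wait_stream(side_stream)
```

The code in the PR takes the other option: it allocates `reloaded_tensor` on the current stream and only runs the copy on `offload_stream`, so the buffer stays in the consumer stream's pool.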
```diff
@@ -347,16 +351,29 @@ def start_reload(self):
             self.bwd_gpu_tensor_group
         )

-    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
+    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
         """
         It is called when a tensor is saved for backward pass.

         If tensor is offloaded, returns int representing the index of the tensor in the offloaded tensor group.
         If tensor is not offloaded, returns the tensor itself.
+        For QuantizedTensor, returns (list of push results for each component, tensor_objs) tuple.
         """
         self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"])

         if self._check_if_offload(tensor):
+            # For QuantizedTensor: decompose into component tensors, push each one recursively
+            if isinstance(tensor, QuantizedTensor):
+                # Make a copy because prepare_for_saving modifies the object (sets fields to None)
+                tensor_copy = tensor.detach()
+                # Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass,
+                # so the generic prepare_for_saving would not call tensor.prepare_for_saving()
+                saved_tensors, tensor_obj = tensor_copy.prepare_for_saving()
+                push_results = [
+                    self.push_tensor(t) if t is not None else None for t in saved_tensors
+                ]
+                return (push_results, [tensor_obj])
+
```
**Comment on lines +366 to +375**

**Contributor:** Critical Bug: The comment says "Make a copy because prepare_for_saving modifies the object", but […]. Looking at how `detach` is implemented:

```python
def detach(self) -> Float8Tensor:
    return Float8Tensor.make_like(self)
```

And `get_metadata`:

```python
def get_metadata(self) -> Dict[str, Any]:
    return {
        "rowwise_data": self._rowwise_data,  # Reference, not copy!
        ...
    }
```

Problem: When […]. Impact: This will cause the forward pass to access None data, leading to crashes or incorrect results. Solution: Either use […].

**Collaborator (Author):** We have a test for that case; we want to have references.
```diff
             self.fwd_gpu_tensor_group.tensor_list.append(tensor)
             # The group is processed and offloaded at the end of the forward pass of current layer.
             # To enable offloading of tensors faster we use self.offload_stream and record
```
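Conceptually, the QuantizedTensor branch above is a decompose-on-save / reassemble-on-load scheme: the composite tensor hands out its plain component tensors (which the existing offload machinery can handle) plus a lightweight shell object that knows how to put them back together in `pop_tensor`. A toy illustration of the same idea with a hypothetical `CompositeTensor` class, not TE's actual Float8Tensor API:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class CompositeTensor:
    """Toy stand-in for a quantized tensor built from several plain tensors."""
    data: Optional[torch.Tensor]
    scale: Optional[torch.Tensor]

    def prepare_for_saving(self):
        # Hand out the components; clear the references held by this shell.
        saved = [self.data, self.scale]
        self.data, self.scale = None, None
        return saved, self

    def restore_from_saved(self, tensors):
        self.data, self.scale = tensors
        return self


# "push": the components go through the regular offload path, the shell is kept aside.
t = CompositeTensor(torch.randn(4, 4), torch.ones(1))
saved_tensors, shell = t.prepare_for_saving()

# "pop": once the components are reloaded, the shell is reassembled and returned.
restored = shell.restore_from_saved(saved_tensors)
assert restored.data is saved_tensors[0] and restored is t
```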
```diff
@@ -370,23 +387,39 @@ def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
             return len(self.fwd_gpu_tensor_group.tensor_list) - 1
         return tensor

-    def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor:
+    def pop_tensor(
+        self, tensor_or_tensor_id: torch.Tensor | int | tuple[list, list]
+    ) -> torch.Tensor:
         """
         It is called when a tensor is used in backward pass.
         Returns the tensor. If tensor was offloaded/reloaded, wait for the reload of a tensor to finish.
+        For QuantizedTensor (tuple input), reconstructs from component tensors.
         """
         self._validate_state(
             func_name="pop_tensor", allowed_states=["not_offloaded", "reload_started"]
         )

-        # 1. tensor not offloaded
+        # 1. tensor not offloaded (regular tensor returned as-is from push)
         if isinstance(tensor_or_tensor_id, torch.Tensor):
             return tensor_or_tensor_id
-        # 2. the layer was not offloaded at all
+
+        # 2. QuantizedTensor case: tuple of (push_results, tensor_objs)
+        if isinstance(tensor_or_tensor_id, tuple):
+            push_results, tensor_objs = tensor_or_tensor_id
+            # Recursively pop each component
+            reloaded_tensors = [
+                self.pop_tensor(pr) if pr is not None else None for pr in push_results
+            ]
+            # Inline restore_from_saved - tensor_objs[0] is the QuantizedTensor copy
+            tensor_obj = tensor_objs[0]
+            tensor_obj.restore_from_saved(reloaded_tensors)
+            return tensor_obj
+
```
**Comment on lines +407 to +416**

**Contributor:** Potential Issue: In-place modification during backward pass.

```python
tensor_obj.restore_from_saved(reloaded_tensors)
return tensor_obj
```

This modifies the saved `tensor_obj` in place rather than building a fresh object. Concern: If […]. Recommendation: Verify that this pattern is safe with PyTorch's autograd system, or consider creating a fresh tensor object rather than mutating the saved one.
```diff
+        # 3. Regular tensor index case
         if self.state == "not_offloaded":
             return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id]

-        # 3. the layer was offloaded
+        # 4. the layer was offloaded
         assert self.state == "reload_started"
         # wait for the tensor to be reloaded
         torch.cuda.current_stream().wait_event(
```
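The `wait_event` call at the end of this hunk is the standard CUDA event handshake: the reload copy records an event on the side stream, and the consumer stream waits on that event on the GPU (no host synchronization) before reading the reloaded tensor. A minimal sketch of the pattern (names are illustrative; assumes a CUDA device):

```python
import torch

# Assumes a CUDA device; names are illustrative.
reload_stream = torch.cuda.Stream()
cpu_buf = torch.randn(1024, 1024, pin_memory=True)

with torch.cuda.stream(reload_stream):
    gpu_tensor = torch.empty_like(cpu_buf, device="cuda")
    gpu_tensor.copy_(cpu_buf, non_blocking=True)
    reload_done = torch.cuda.Event()
    reload_done.record(reload_stream)  # marks the point where the copy is finished

# GPU-side wait: the compute stream will not run past this point before the copy ends.
torch.cuda.current_stream().wait_event(reload_done)
out = gpu_tensor * 2  # safe to consume the reloaded tensor here
```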
```diff
@@ -406,6 +439,10 @@ def _check_if_offload(self, t: torch.Tensor) -> bool:
         """
         Check if tensor needs to be offloaded.
         """
+        # Only offload tensors with at least 256k elements (~1MB for float32)
+        if t.numel() < 256 * 1024:
+            return False
+
         if (
             not isinstance(t, torch.nn.Parameter)
             and not getattr(t, "_TE_do_not_offload", False)
```
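For reference on the new cut-off: 256 * 1024 elements corresponds to 1 MiB only at 4 bytes per element, so the effective byte threshold shrinks for narrower dtypes. A quick check:

```python
import torch

THRESHOLD_ELEMS = 256 * 1024  # the cut-off introduced in _check_if_offload

for dtype in (torch.float32, torch.bfloat16, torch.uint8):
    t = torch.empty(THRESHOLD_ELEMS, dtype=dtype)
    size_mib = t.numel() * t.element_size() / 2**20
    print(f"{dtype}: {size_mib:.2f} MiB at the element threshold")
# float32 -> 1.00 MiB, bfloat16 -> 0.50 MiB, uint8 -> 0.25 MiB
```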
```diff
@@ -418,7 +455,6 @@ def _check_if_offload(self, t: torch.Tensor) -> bool:
                     " this tensor will be skipped."
                 )
                 return False
-
             return True
         return False
```
```diff
@@ -488,11 +524,13 @@ def bwd_step(self, layer_num: int):
         self.previous_bwd_layer_id = layer_num
         self.current_layer_id = layer_num

-    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
+    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
        """Default push tensor method"""
         return self.layer_states[self.num_of_fwds].push_tensor(tensor)

-    def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor:
+    def pop_tensor(
+        self, tensor_or_tensor_id: torch.Tensor | int | tuple[list, list]
+    ) -> torch.Tensor:
         """Default pop tensor method"""
         return self.layer_states[self.current_layer_id].pop_tensor(tensor_or_tensor_id)
```
```diff
@@ -592,6 +630,12 @@ def bwd_step(self, layer_num: int):
             for layer in self.start_reload_map[layer_num]:
                 self.layer_states[layer].start_reload()

+    def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
+        """Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
+        if not self.offload_layer_map.get(self.num_of_fwds, False):
+            return tensor
+        return self.layer_states[self.num_of_fwds].push_tensor(tensor)
+
```
**Comment on lines +633 to +637**

**Contributor:** Missing return type annotation in override. The parent class declares:

```python
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
```

This override (line 631):

```python
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
```

The child class correctly returns the extended type (including `tuple[list, list]`). Recommendation: Update the parent class annotation to match.
|
|
||
|
|
||
| class ManualOffloadSynchronizer(OffloadSynchronizer): | ||
| """ | ||
|
|
||
Second file in the diff:
```diff
@@ -14,6 +14,7 @@
 from torch.distributed._tensor import DTensor
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
+from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
 from .multi_tensor_apply import multi_tensor_applier
```
```diff
@@ -372,10 +373,12 @@ def _initialize_state(
             store_param_remainders (bool): Store only trailing remainder bits.
         """
         dtype = self.name_to_dtype_map[state_name]
+        # Handle QuantizedTensor by dequantizing first
+        param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
         if store_param_remainders:
-            data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
+            data = torch.zeros_like(param_for_empty, dtype=torch.int16)
         else:
-            data = torch.empty(param.shape, dtype=dtype, device=param.device)
+            data = torch.empty_like(param_for_empty, dtype=dtype)
```
**Comment on lines 375 to +381**

**Contributor:** The dequantization of QuantizedTensor parameters here breaks DTensor sharding preservation. When […]. The fix should use the original parameter directly without dequantization, since […].

Suggested change: […]

Alternatively, if dequantization is necessary for some reason, the sharding information from the original parameter should be explicitly preserved.

**Collaborator (Author):** We have never supported quantized DTensors in the optimizer and we will not support it after my change.
**Comment on lines +376 to +381**

**Contributor:** Good fix for DTensor and QuantizedTensor compatibility. This change correctly addresses the issue where […]. The fix: […]. This ensures both DTensor sharding and QuantizedTensor are handled correctly. Verification needed: Confirm that […].
```diff

         if zero_buffer:
             data.zero_()
```
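On why the `*_like` factories matter in this hunk: `torch.zeros_like` and `torch.empty_like` take shape, device, and wrapper type (e.g. DTensor) from the prototype tensor, whereas the old `torch.zeros(param.shape, device=param.device)` always produced a plain dense tensor, which is what the reviewer comments above are discussing. A reduced sketch of the new allocation logic (the `_is_te_quantized` flag is a hypothetical stand-in for the `isinstance(param, QuantizedTensor)` check):

```python
import torch


def make_state_buffer(param: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Sketch: allocate an optimizer-state buffer that matches `param`'s layout."""
    # Hypothetical guard standing in for isinstance(param, QuantizedTensor):
    # a quantized parameter is dequantized first so empty_like yields a plain buffer.
    proto = param.dequantize() if getattr(param, "_is_te_quantized", False) else param
    # empty_like inherits shape and device from the prototype, and wrapper tensors
    # (such as DTensor) keep their wrapper type because the op goes through dispatch.
    return torch.empty_like(proto, dtype=dtype)


p = torch.randn(8, 8)
state = make_state_buffer(p, torch.float32)
assert state.shape == p.shape and state.device == p.device
```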
**Contributor:** Critical Bug: `QuantizedTensor.prepare_for_saving()` may not exist.

The code assumes that `QuantizedTensor` has a `prepare_for_saving()` method, but this method is only defined in the abstract base class `QuantizedTensorStorage`, not in `QuantizedTensor` itself.

Looking at `quantized_tensor.py`:

- `QuantizedTensorStorage` (lines 72-76) declares `prepare_for_saving()` as abstract
- `QuantizedTensor` (line 328+) inherits from `torch.Tensor`, NOT from `QuantizedTensorStorage`

This will cause an AttributeError at runtime when a `QuantizedTensor` instance is offloaded, since `QuantizedTensor` doesn't have the `prepare_for_saving()` method.

The comment on lines 369-370 mentions that "QuantizedTensor is a torch.Tensor subclass", but this is misleading - the real issue is that `QuantizedTensor` doesn't implement the required methods.

Suggested fix: Either:

- make `QuantizedTensor` also inherit from `QuantizedTensorStorage` and implement `prepare_for_saving()`, OR
- use the module-level `prepare_for_saving()` function, which properly handles torch.Tensor instances
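On the reviewer's second suggestion: a module-level `prepare_for_saving` helper is indeed imported at the top of the first file, and the usual shape of such a helper is duck-typed dispatch on the saved object. The sketch below only illustrates that dispatch pattern and is not TE's actual implementation:

```python
import torch


def save_for_offload(obj):
    """Illustrative fallback, not TE's API: use the object's own hook when present."""
    if hasattr(obj, "prepare_for_saving"):
        # Composite/quantized tensor: returns (component tensors, shell to restore later).
        return obj.prepare_for_saving()
    # Plain torch.Tensor: nothing to decompose, no shell needed.
    return [obj], None


plain = torch.randn(2, 2)
tensors, shell = save_for_offload(plain)
assert tensors[0] is plain and shell is None
```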