[PyTorch] Bunch of fixes for cpu offloading #2535
base: main
Conversation
Force-pushed from d617edf to 6d2f43b
```python
# Only offload tensors with at least 256k elements (~1MB for float32)
if t.numel() < 256 * 1024:
    return False
```
If I understand, this is the reason we need to expose an option to disable bulk allocation in split_quantize? Bulk-allocated tensors hold on to memory until all are deallocated, but this condition means that some small tensor might keep a large memory block alive.
Yes. And we cannot offload small tensors, because doing so forces synchronization of compute and communication operations when CUDA_DEVICE_MAX_CONNECTIONS=1 is set, which is required for comm/GEMM overlap.
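To make the memory-retention concern concrete, here is a small standalone illustration in plain PyTorch (it assumes a CUDA device and is not TE's actual bulk allocator): every tensor carved out of a bulk buffer is a view, so one small tensor left on the GPU keeps the entire bulk storage alive.

```python
import torch

bulk = torch.empty(1024 * 1024, device="cuda")   # one 4 MiB bulk allocation
large = bulk[: 1020 * 1024]                      # above the 256 * 1024 element threshold
small = bulk[1020 * 1024:]                       # below the threshold, stays on the GPU

large_cpu = large.to("cpu")                      # "offload" only the large view
del large, bulk

# `small` still references the original storage, so the full 4 MiB remains
# allocated on the GPU until `small` is freed as well.
print(small.untyped_storage().nbytes())          # 4194304, not 16384
```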
/te-ci pytorch
timmoon10 left a comment
LGTM
Greptile Summary
This PR addresses CPU offloading performance and compatibility issues through multiple coordinated changes:
Key Changes
- CPU Overhead Reduction: Added layer-level offloading skipping in `DefaultOffloadSynchronizer.push_tensor()` to avoid processing tensors when a layer won't be offloaded. Also conditionally guards `mark_not_offload()` calls.
- QuantizedTensor Offloading Support: Extended CPU offloading to handle QuantizedTensor types by decomposing them into component tensors, offloading each component recursively, and reconstructing them during reload (see the sketch after this list).
- DTensor Compatibility: Changed from `torch.empty(shape, device)` to `torch.empty_like(tensor, dtype)` in FusedAdam to properly respect DTensor sharding annotations.
- Small Tensor Offloading Threshold: Added a 256K-element minimum threshold to prevent offloading of tiny tensors that would cause synchronization issues with `CUDA_DEVICE_MAX_CONNECTIONS=1`.
- Bulk Allocation Control: Added a `disable_bulk_allocation` parameter to the `split_quantize()` C++ function, enabled when CPU offloading is active to avoid grouping small tensors with large ones.
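As a rough sketch of the decomposition pattern described above (hypothetical helper names, assuming a CUDA device; the real logic lives in `push_tensor`/`pop_tensor` in `cpu_offload.py`): each component of a quantized tensor is offloaded independently, and only components above the size threshold leave the GPU.

```python
import torch

MIN_OFFLOAD_NUMEL = 256 * 1024  # ~1 MB for float32

def offload_component(t):
    """Copy one component to pinned CPU memory if it is large enough."""
    if t is None or t.numel() < MIN_OFFLOAD_NUMEL:
        return t  # small components (e.g. scale_inv) stay on the GPU
    cpu_buf = torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
    cpu_buf.copy_(t)
    return cpu_buf

def reload_component(t, device="cuda"):
    """Bring a component back to the GPU if it was offloaded."""
    return t if t is None or t.device.type != "cpu" else t.to(device, non_blocking=True)

# A quantized tensor is treated as its component tensors: offload each one,
# then reassemble on reload.
components = {
    "data": torch.randint(0, 255, (4096, 4096), dtype=torch.uint8, device="cuda"),
    "scale_inv": torch.ones(1, device="cuda"),
}
offloaded = {name: offload_component(t) for name, t in components.items()}
reloaded = {name: reload_component(t) for name, t in offloaded.items()}
```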
Files Modified
- `transformer_engine/pytorch/cpu_offload.py`: Core offloading logic with QuantizedTensor support
- `transformer_engine/pytorch/optimizers/fused_adam.py`: DTensor-aware state initialization
- `transformer_engine/pytorch/module/linear.py`: Conditional `mark_not_offload()` guarding
- `transformer_engine/pytorch/module/grouped_linear.py`: `disable_bulk_allocation` parameter passing
- `transformer_engine/pytorch/quantized_tensor.py`: Removed CPU operation validation that blocked offloading
- C++ files: Added `disable_bulk_allocation` parameter and logic
- Tests: Updated tensor sizes to ensure components exceed the 256K threshold
Issues Found
- Critical Issue: The FusedAdam DTensor fix is incomplete. When a QuantizedTensor parameter wraps a DTensor, calling `dequantize()` creates a new plain tensor that loses the DTensor sharding metadata. The fix should use the original parameter directly with `.empty_like()`.
- Type Annotation Issue: The `DefaultOffloadSynchronizer.push_tensor()` return type annotation doesn't reflect the actual return type (missing `tuple[list, list]`).
- Behavioral Change: The default for `retain_pinned_cpu_buffers` changed from False to True, affecting memory usage patterns and performance characteristics. This change is not documented in the PR description.
Confidence Score: 2/5
- This PR has a critical bug that breaks DTensor parameter handling in FusedAdam, and incomplete type annotations. The DTensor fix is fundamentally broken for QuantizedTensor parameters.
- The PR contains one critical logic bug that makes the DTensor fix incomplete/incorrect. The FusedAdam change dequantizes QuantizedTensor parameters, which destroys DTensor sharding information that the empty_like() call is meant to preserve. Additionally, return type annotations are incomplete, and an undocumented behavioral default change (retain_pinned_cpu_buffers) could affect existing users. While the core CPU offloading improvements are sound, these issues need resolution before merging.
- transformer_engine/pytorch/optimizers/fused_adam.py (critical DTensor bug), transformer_engine/pytorch/cpu_offload.py (type annotation and default value)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/optimizers/fused_adam.py | 2/5 | FusedAdam state initialization broken for QuantizedTensor parameters with DTensor sharding. Calling dequantize() loses DTensor metadata that should be preserved with empty_like(). |
| transformer_engine/pytorch/cpu_offload.py | 3/5 | Multiple changes: QuantizedTensor offloading support added, CPU overhead reduced with layer skipping optimization, but return type annotation mismatch and behavioral default change. Default retain_pinned_cpu_buffers changed from False to True. |
| transformer_engine/pytorch/quantized_tensor.py | 4/5 | Removed CPU operation validation checks. This change is safe as it enables QuantizedTensor component tensors to be offloaded to CPU while preserving the original tensor type checks elsewhere. |
| transformer_engine/pytorch/module/grouped_linear.py | 5/5 | Disables bulk allocation in split_quantize when CPU offloading is active, addressing synchronization issues with small tensors and CUDA_DEVICE_MAX_CONNECTIONS=1. |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | 5/5 | Correctly implements disable_bulk_allocation feature by wrapping bulk allocation logic in conditional, allowing fallback to unfused allocation when needed. |
Sequence Diagram
```mermaid
sequenceDiagram
participant User as User Code
participant Linear as Linear Module
participant CPU_Offload as CPU Offload System
participant QuantTensor as QuantizedTensor
participant OffloadState as OffloadableLayerState
User->>Linear: Forward pass with QuantizedTensor
Linear->>OffloadState: push_tensor(quantized_tensor)
alt QuantizedTensor Case
OffloadState->>QuantTensor: detach()
QuantTensor->>QuantTensor: prepare_for_saving()
Note over QuantTensor: Decompose into [data, transpose, scale_inv]
loop For each component
OffloadState->>OffloadState: push_tensor(component)
OffloadState->>OffloadState: Check if numel() > 256K
end
OffloadState-->>Linear: Return (push_results, tensor_objs)
else Regular Tensor Case
OffloadState->>OffloadState: Check if offload needed
OffloadState-->>Linear: Return tensor_id or tensor
end
Note over OffloadState: Backward pass
Linear->>OffloadState: pop_tensor(tensor_id_or_tuple)
alt QuantizedTensor Tuple Case
OffloadState->>OffloadState: pop_tensor(each_component)
OffloadState->>QuantTensor: restore_from_saved(components)
OffloadState-->>Linear: Return reconstructed QuantizedTensor
else Regular Case
OffloadState-->>Linear: Return tensor
end
```
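The diagram above follows the pack/unpack flow of activations. As a hedged sketch of that general pattern (using PyTorch's public `torch.autograd.graph.saved_tensors_hooks` API rather than TE's actual synchronizer, and assuming a CUDA device), a size-thresholded offload hook might look like this; the real implementation overlaps the copies on a separate stream with CUDA events instead of blocking.

```python
import torch

MIN_OFFLOAD_NUMEL = 256 * 1024

def pack_hook(t):
    """Move large CUDA activations to pinned CPU memory when they are saved for backward."""
    if t.device.type != "cuda" or t.numel() < MIN_OFFLOAD_NUMEL:
        return t
    cpu_buf = torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
    cpu_buf.copy_(t)  # blocking copy for simplicity; TE uses events on an offload stream
    return ("offloaded", cpu_buf, t.device)

def unpack_hook(packed):
    """Bring offloaded activations back to the GPU when backward needs them."""
    if isinstance(packed, torch.Tensor):
        return packed
    _, cpu_buf, device = packed
    return cpu_buf.to(device, non_blocking=True)

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    loss = (x @ x).sum()
loss.backward()
```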
```diff
     dtype = self.name_to_dtype_map[state_name]
+    # Handle QuantizedTensor by dequantizing first
+    param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
     if store_param_remainders:
-        data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
+        data = torch.zeros_like(param_for_empty, dtype=torch.int16)
     else:
-        data = torch.empty(param.shape, dtype=dtype, device=param.device)
+        data = torch.empty_like(param_for_empty, dtype=dtype)
```
The dequantization of QuantizedTensor parameters here breaks DTensor sharding preservation. When param is a QuantizedTensor wrapping a DTensor, calling dequantize() creates a new plain tensor that loses the DTensor sharding metadata. This defeats the purpose of using .empty_like() to preserve DTensor sharding.
The fix should use the original parameter directly without dequantization, since .empty_like() respects the sharding annotations of the input tensor regardless of whether it's quantized:
Suggested change:

```diff
     dtype = self.name_to_dtype_map[state_name]
-    # Handle QuantizedTensor by dequantizing first
-    param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
     if store_param_remainders:
-        data = torch.zeros_like(param_for_empty, dtype=torch.int16)
+        data = torch.zeros_like(param, dtype=torch.int16)
     else:
-        data = torch.empty_like(param_for_empty, dtype=dtype)
+        data = torch.empty_like(param, dtype=dtype)
```
Alternatively, if dequantization is necessary for some reason, the sharding information from the original parameter should be explicitly preserved.
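For reference, a hedged illustration of the sharding point (not code from this PR; it assumes a recent PyTorch where `torch.distributed.tensor` is public and is launched with `torchrun`, one GPU per rank): `empty_like` on a DTensor allocates only the local shard, while `empty(param.shape)` allocates the full global shape on every rank.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
param = distribute_tensor(torch.randn(4096, 4096), mesh, [Shard(0)])

sharded_state = torch.empty_like(param)               # DTensor, keeps the Shard(0) placement
full_state = torch.empty(param.shape, device="cuda")  # plain tensor, full 4096x4096 on every rank

print(type(sharded_state).__name__, sharded_state.to_local().shape)  # local shard only
print(type(full_state).__name__, full_state.shape)                   # full global shape
dist.destroy_process_group()
```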
We have never supported quantized DTensors in the optimizer, and we will not support them after my change.
```python
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
    """Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
    if not self.offload_layer_map.get(self.num_of_fwds, False):
        return tensor
    return self.layer_states[self.num_of_fwds].push_tensor(tensor)
```
The return type annotation for DefaultOffloadSynchronizer.push_tensor() is incomplete and doesn't match the base class signature. This method can return a tuple (when processing QuantizedTensors), but the annotation only specifies int | torch.Tensor.
Suggested change:

```diff
-def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
+def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
     """Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
     if not self.offload_layer_map.get(self.num_of_fwds, False):
         return tensor
     return self.layer_states[self.num_of_fwds].push_tensor(tensor)
```
```python
num_layers: Optional[int] = 1,
model_layers: int = 1,
offload_activations: bool = True,
offload_weights: bool = False,
```
The default value for retain_pinned_cpu_buffers has been changed from False to True. While this may improve performance by reusing CPU buffers across iterations, this is a significant behavioral change that affects memory usage patterns and is not mentioned in the PR description. Consider documenting this change in the commit message and PR description, as it could impact existing users' performance characteristics.
If backward compatibility is important, consider keeping the default as False or provide a migration path for existing code.
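For context on what this default controls, a minimal sketch of the retain-pinned-buffers idea (hypothetical helper, not TE's implementation): with `retain=True` the same pinned CPU staging buffer is reused across iterations instead of being re-allocated each step, trading higher steady-state host memory use for lower allocation overhead.

```python
import torch

_pinned_cache = {}

def get_pinned_buffer(t: torch.Tensor, retain: bool = True) -> torch.Tensor:
    """Return a pinned CPU buffer matching t, reusing a cached one when retain=True."""
    key = (tuple(t.shape), t.dtype)
    if retain and key in _pinned_cache:
        return _pinned_cache[key]
    buf = torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
    if retain:
        _pinned_cache[key] = buf
    return buf
```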
/te-ci pytorch
Greptile Summary

This PR implements several optimizations and bug fixes for CPU offloading performance and compatibility:

Key Changes
1. CPU Overhead Reduction
2. QuantizedTensor Support in CPU Offloading
3. DTensor OOM Fix in Fused Optimizer
4. Bulk Allocation Control for Grouped Tensors

Architecture

The changes maintain backward compatibility, with all parameters defaulting to existing behavior. The 256K-element threshold ensures only tensors large enough (~1 MB for float32) are offloaded, avoiding overhead from processing small tensors like scaling factors.

Testing

Test tensors were updated to 128x512 (from 64x256) to ensure both data and scaling-factor tensors exceed the 256K threshold and are properly offloaded during testing.

Confidence Score: 4/5

Important Files Changed

File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
participant User as Training Loop
participant Layer as TE Linear Layer
participant Offload as CPU Offload System
participant Optimizer as Fused Adam
Note over User,Optimizer: Forward Pass with CPU Offloading
User->>Layer: forward(input)
Layer->>Layer: Check if offloading enabled
alt CPU Offloading Enabled
Layer->>Layer: Check tensor size >= 256K elements
alt Tensor is QuantizedTensor
Layer->>Offload: push_tensor(quantized_tensor)
Offload->>Offload: decompose into components
Offload->>Offload: push each component recursively
else Regular Tensor
Layer->>Offload: push_tensor(tensor)
Offload->>Offload: store tensor index
end
Layer->>Layer: start_offload()
Layer->>Offload: Async GPU→CPU copy on offload stream
end
Layer->>User: return output
Note over User,Optimizer: Backward Pass
User->>Layer: backward(grad_output)
alt CPU Offloading Enabled
Layer->>Offload: start_reload()
Offload->>Offload: Async CPU→GPU copy on offload stream
Layer->>Offload: pop_tensor(tensor_id)
alt QuantizedTensor
Offload->>Offload: pop each component recursively
Offload->>Offload: reconstruct QuantizedTensor
end
Offload->>Layer: return reloaded tensor
end
Layer->>User: return gradients
Note over User,Optimizer: Optimizer Step with DTensor Fix
User->>Optimizer: step()
loop For each parameter
alt Parameter is QuantizedTensor
Optimizer->>Optimizer: dequantize() first
end
alt Parameter is DTensor
Optimizer->>Optimizer: Use empty_like (respects sharding)
else Regular Tensor
Optimizer->>Optimizer: Use empty_like
end
Optimizer->>Optimizer: Create/update optimizer state
end
Optimizer->>User: optimization complete
```
1 file reviewed, 1 comment
```python
class Utils:
    # Tensor big engough that both data and scaling factor tensor are bigger than 256 * 1024 elements,
```
Typo in comment: "Tensor big engough" should be "Tensor big enough"
Suggested change:

```diff
-    # Tensor big engough that both data and scaling factor tensor are bigger than 256 * 1024 elements,
+    # Tensor big enough that both data and scaling factor tensor are bigger than 256 * 1024 elements,
```
Greptile Summary

This PR addresses CPU offloading performance and compatibility issues through multiple optimizations and bug fixes. The changes span 9 files with both performance improvements and functional fixes.

Key Changes
- Performance Optimizations
- Bug Fixes
- QuantizedTensor Support

Critical Issue Found

Data corruption bug in cpu_offload.py (lines 366-375): the code uses `detach()` to copy a QuantizedTensor before calling `prepare_for_saving()`, but `detach()` only makes a shallow copy that shares the underlying component tensors. Recommendation: replace `detach()` with `clone()` or restructure the code to avoid mutating the tensor during `push_tensor()`.

Other Concerns

Confidence Score: 1/5

Important Files Changed

File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
participant User as Training Loop
participant Linear as Linear Module
participant CPUOffload as CPU Offload Context
participant LayerState as OffloadableLayerState
participant GPU as GPU Memory
participant CPU as CPU Memory (Pinned)
Note over User,CPU: Forward Pass with CPU Offloading
User->>Linear: forward(input)
activate Linear
Linear->>CPUOffload: __enter__() [saved_tensors_hooks]
activate CPUOffload
CPUOffload->>CPUOffload: fwd_step() - increment layer counter
Linear->>Linear: Compute: output = input @ weight^T + bias
Note over Linear,LayerState: Save activations for backward
Linear->>LayerState: push_tensor(activation)
alt QuantizedTensor
LayerState->>LayerState: detach() [SHALLOW COPY - BUG!]
LayerState->>LayerState: prepare_for_saving() [sets fields to None]
LayerState->>LayerState: Recursively push components
LayerState-->>Linear: Return (push_results, tensor_objs)
else Regular Tensor
alt Tensor size >= 256K elements
LayerState->>LayerState: Mark for offloading
LayerState-->>Linear: Return tensor_id (int)
else Small Tensor
LayerState-->>Linear: Return tensor itself
end
end
Linear->>CPUOffload: __exit__()
CPUOffload->>LayerState: start_offload()
activate LayerState
loop For each tensor in offload group
LayerState->>GPU: Record event: tensor ready
LayerState->>CPU: Async copy (GPU->CPU on offload_stream)
end
LayerState->>LayerState: Record finish_offload_event
deactivate LayerState
Note over LayerState: Offload happens asynchronously
Linear-->>User: return output
deactivate Linear
CPUOffload->>LayerState: release_activation_forward_gpu_memory()
LayerState->>LayerState: Wait for offload_stream
LayerState->>GPU: Free GPU tensors
deactivate CPUOffload
Note over User,CPU: Backward Pass
User->>Linear: backward(grad_output)
activate Linear
Linear->>CPUOffload: bwd_step()
activate CPUOffload
CPUOffload->>LayerState: start_reload()
activate LayerState
loop For each CPU tensor
LayerState->>GPU: Allocate new GPU tensor
LayerState->>GPU: Async copy (CPU->GPU on offload_stream)
LayerState->>LayerState: Record reload_event
end
deactivate LayerState
deactivate CPUOffload
Note over Linear,LayerState: Retrieve saved tensors
Linear->>LayerState: pop_tensor(tensor_id)
LayerState->>LayerState: Wait for reload_event
alt QuantizedTensor (tuple input)
LayerState->>LayerState: Recursively pop components
LayerState->>LayerState: restore_from_saved() [in-place modification]
LayerState-->>Linear: Return reconstructed QuantizedTensor
else Regular Tensor
LayerState-->>Linear: Return reloaded tensor
end
Linear->>Linear: Compute gradients
Linear-->>User: return grad_input
deactivate Linear
```
7 files reviewed, 7 comments
```python
if isinstance(tensor, QuantizedTensor):
    # Make a copy because prepare_for_saving modifies the object (sets fields to None)
    tensor_copy = tensor.detach()
    # Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass,
    # so the generic prepare_for_saving would not call tensor.prepare_for_saving()
    saved_tensors, tensor_obj = tensor_copy.prepare_for_saving()
    push_results = [
        self.push_tensor(t) if t is not None else None for t in saved_tensors
    ]
    return (push_results, [tensor_obj])
```
Critical Bug: detach() creates shallow copy, causing data corruption
The comment says "Make a copy because prepare_for_saving modifies the object", but tensor.detach() only creates a shallow copy that shares the underlying storage with the original tensor.
Looking at how detach() is implemented (e.g., in Float8Tensor):
```python
def detach(self) -> Float8Tensor:
    return Float8Tensor.make_like(self)
```

And make_like uses get_metadata(), which returns references to the internal storage tensors:

```python
def get_metadata(self) -> Dict[str, Any]:
    return {
        "rowwise_data": self._rowwise_data,  # Reference, not copy!
        ...
    }
```

Problem: When tensor_copy.prepare_for_saving() sets self._rowwise_data = None, it modifies the shared storage object, corrupting the original tensor that's still being used in the forward pass.
Impact: This will cause the forward pass to access None data, leading to crashes or incorrect results.
Solution: Either use clone() instead of detach() to create a deep copy, or restructure the code to avoid mutating the tensor during push_tensor.
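A small plain-tensor illustration of the shallow-copy point (ordinary `torch.Tensor`, not `QuantizedTensor`): `detach()` shares storage with the original, while `clone()` does not.

```python
import torch

a = torch.arange(4.0)
shallow = a.detach()   # shares storage with `a`
deep = a.clone()       # independent storage

shallow[0] = 99.0      # visible through `a` as well
deep[1] = -1.0         # `a` is unaffected
print(a)               # tensor([99.,  1.,  2.,  3.])
```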
```python
if isinstance(tensor_or_tensor_id, tuple):
    push_results, tensor_objs = tensor_or_tensor_id
    # Recursively pop each component
    reloaded_tensors = [
        self.pop_tensor(pr) if pr is not None else None for pr in push_results
    ]
    # Inline restore_from_saved - tensor_objs[0] is the QuantizedTensor copy
    tensor_obj = tensor_objs[0]
    tensor_obj.restore_from_saved(reloaded_tensors)
    return tensor_obj
```
Potential Issue: In-place modification during backward pass
The pop_tensor method for QuantizedTensor performs in-place restoration:
```python
tensor_obj.restore_from_saved(reloaded_tensors)
return tensor_obj
```

This modifies tensor_obj (which was created during push_tensor and stored in the tuple) in-place and returns it. However, PyTorch's autograd expects saved tensors to be immutable during backward.
Concern: If tensor_obj is the same object that was corrupted in push_tensor (due to the shallow copy bug), it may already be in an inconsistent state. Additionally, modifying saved objects in-place during backward could cause issues if PyTorch makes assumptions about tensor immutability.
Recommendation: Verify that this pattern is safe with PyTorch's autograd system, or consider creating a fresh tensor object rather than mutating the saved one.
```python
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
    """Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
    if not self.offload_layer_map.get(self.num_of_fwds, False):
        return tensor
    return self.layer_states[self.num_of_fwds].push_tensor(tensor)
```
Missing return type annotation in override
The push_tensor method in DefaultOffloadSynchronizer overrides the parent class method but has an inconsistent return type annotation.
Parent class (OffloadSynchronizer, line 527):

```python
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
```

This override (line 631):

```python
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
```

The child class correctly returns the extended type (including tuple[list, list] for QuantizedTensor), but the parent class method should also have this return type to maintain type safety across polymorphic calls.
Recommendation: Update the parent class OffloadSynchronizer.push_tensor method signature to include tuple[list, list] in the return type annotation (line 527).
```diff
+    # Handle QuantizedTensor by dequantizing first
+    param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
     if store_param_remainders:
-        data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
+        data = torch.zeros_like(param_for_empty, dtype=torch.int16)
     else:
-        data = torch.empty(param.shape, dtype=dtype, device=param.device)
+        data = torch.empty_like(param_for_empty, dtype=dtype)
```
Good fix for DTensor and QuantizedTensor compatibility
This change correctly addresses the issue where torch.empty(param.shape, ...) doesn't respect DTensor sharding (would create full tensors on each device instead of sharded ones).
The fix:
- Dequantizes QuantizedTensor first to get a standard tensor
- Uses `torch.empty_like()`, which properly handles DTensor sharding
- Then applies the dtype override
This ensures both DTensor sharding and QuantizedTensor are handled correctly.
Verification needed: Confirm that dequantize() for QuantizedTensor doesn't trigger expensive computation if the tensor is already on the correct device, or if it's acceptable overhead for this initialization path.
```diff
     # Tensor big engough that both data and scaling factor tensor are bigger than 256 * 1024 elements,
     # so that they are offloaded to GPU.
     tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16)
-    _B = 64
-    _S = 256
+    _B = 128
+    _S = 512
```
Test configuration updated to match new offload threshold
The batch size and sequence length are increased from _B=64, _S=256 to _B=128, _S=512 to ensure test tensors exceed the new 256K element threshold for offloading (introduced in cpu_offload.py line 443).
Calculation:
- Old: 64 × 256 = 16,384 elements (would NOT be offloaded with new threshold)
- New: 128 × 512 = 65,536 elements base, but comment says "both data and scaling factor tensor are bigger than 256 * 1024 elements"
Concern: The math doesn't add up. With hidden_size=512 and num_heads=8, the actual tensor sizes in tests would be around 128×512×4×256 = 67,108,864 elements, which is well above 256K. However, the comment specifically mentions "scaling factor tensor" which is typically much smaller (e.g., per-block scales).
Recommendation: Verify that all relevant tensors (including scaling factors for quantized tensors) actually exceed 256K elements, or update the comment to be more precise about which tensors are being tested.
```diff
 if fp8 and not debug:
-    inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
+    inputmats = tex.split_quantize(
+        inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading
+    )
```
Disabling bulk allocation when CPU offloading is active
This change disables bulk tensor allocation for grouped quantization when cpu_offloading=True by passing disable_bulk_allocation=cpu_offloading to split_quantize.
Rationale (from PR description): With bulk allocation, all tensors in a group are offloaded together (all-or-nothing), which means small tensors like scales get offloaded even though they're below the 256K threshold. This causes issues when CUDA_DEVICE_MAX_CONNECTIONS=1 because small tensors use SMs for copying instead of copy engines, leading to synchronization problems.
This is a reasonable workaround, but note:
- Disabling bulk allocation may have performance implications for the quantization itself (not just offloading)
- The fix is automatic when offloading is enabled, which is good for usability
- Consider if there's a more targeted fix (e.g., filtering small tensors from offload groups) that wouldn't require disabling bulk allocation entirely
Recommendation: Document any performance trade-offs of disabling bulk allocation in the function's docstring or PR description.
Additional Comments (1)
The previous code had a safeguard that prevented certain operations on QuantizedTensor when on CPU:

```python
_quantized_tensor_cpu_supported_ops = (
    torch.ops.aten.empty_like.default,
    torch.ops.aten.copy_.default,
)
```

And checked:

```python
assert func in _quantized_tensor_cpu_supported_ops, \
    f"QuantizedTensor on CPU does not support this operation: {func}"
```

This PR removes this check entirely, which means QuantizedTensor operations on CPU are now allowed without validation.

Recommendation: If CPU operations are now fully supported, add a comment explaining the change. If only specific operations are needed for offloading, consider keeping the safeguard but expanding the whitelist.
Description
This PR addresses several issues related to CPU offloading performance and compatibility.
1. CPU Overhead Reduction
This PR reduces CPU overhead through multiple optimizations:
- `__torch_function__` hook: previously costly validation checks have been eliminated.

2. Out of Memory Error with Fused Optimizer and DTensor
PyTorch introduced the JAX-like DTensor, and some workloads use our fused optimizer with this tensor type. The previous implementation used `.empty` with an explicit shape, which works correctly for standard tensors but does not respect sharding for DTensor, resulting in full tensors being created on each device. This has been fixed by switching to `.empty_like`, which preserves the sharding of the parameter.

3. Synchronization Issues When Offloading Small Tensors
For grouped tensors, allocation is performed in bulk, requiring an all-or-nothing offloading approach. This meant small tensors like scales were also offloaded, which caused issues with comm-gemm overlap when `CUDA_DEVICE_MAX_CONNECTIONS=1` was set. In these cases, tensors were small enough that SMs were used for copying instead of copy engines, leading to synchronization problems.

Fixes:
- Tensors with fewer than 256 * 1024 elements are no longer offloaded.
- Bulk allocation in `split_quantize()` is disabled when CPU offloading is active, so small tensors are not grouped with large ones.
Type of change
Checklist: