Skip to content

Conversation

@pggPL
Copy link
Collaborator

@pggPL pggPL commented Dec 19, 2025

Description

This PR addresses several issues related to CPU offloading performance and compatibility.

1. CPU Overhead Reduction

This PR reduces CPU overhead through multiple optimizations:

  • Skip processing for non-offloaded layers: Tensors are no longer processed when the layer is known to be non-offloaded (the case for non-manual synchronization when it is known which layers are offloaded in advance). Manual synchronization overhead may be addressed in future work.
  • Remove expensive checks in __torch_function__ hook: Previously costly validation checks have been eliminated.
  • Skip offloading small tensors: Small tensors are now excluded from offloading to avoid overhead.

2. Out of Memory Error with Fused Optimizer and DTensor

PyTorch introduced JAX-like DTensor, and some workloads use our fused optimizer with this tensor type. The previous implementation used .empty_like, which works correctly for standard tensors but does not respect sharding for DTensor—resulting in full tensors being created on each device. This has been fixed by switching to .empty with explicit shape specification.

3. Synchronization Issues When Offloading Small Tensors

For grouped tensors, allocation is performed in bulk, requiring an all-or-nothing offloading approach. This meant small tensors like scales were also offloaded, which caused issues with comm-gemm overlap when CUDA_DEVICE_MAX_CONNECTIONS=1 was set. In these cases, tensors were small enough that SMs were used for copying instead of copy engines, leading to synchronization problems.

Fixes:

  • Added a minimum tensor size threshold for offloading to mitigate this issue.
  • Added an option to disable bulk allocation for grouped tensor offloading (enabled automatically when offloading is active).

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pre-commit-ci bot and others added 4 commits December 19, 2025 14:22
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Comment on lines 453 to 455
# Only offload tensors with at least 256k elements (~1MB for float32)
if t.numel() < 256 * 1024:
return False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand, this is the reason we need to expose an option to disable bulk allocation in split_quantize? Bulk-allocated tensors hold on to memory untill all are deallocated, but this condition means that some small tensor might keep a large memory block alive.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. And we cannot offload small tensors, because it causes the synchronization of compute/communication operations when CUDA_DEVICE_MAX_CONNECTIONS=1 is set - which is needed by the comm/gemm overlap.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL and others added 6 commits January 8, 2026 13:50
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL marked this pull request as ready for review January 9, 2026 15:57
@pggPL
Copy link
Collaborator Author

pggPL commented Jan 9, 2026

/te-ci pytorch

timmoon10
timmoon10 previously approved these changes Jan 9, 2026
Copy link
Collaborator

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

Summary

This PR addresses CPU offloading performance and compatibility issues through multiple coordinated changes:

Key Changes

  1. CPU Overhead Reduction: Added layer-level offloading skipping in DefaultOffloadSynchronizer.push_tensor() to avoid processing tensors when a layer won't be offloaded. Also conditionally guards mark_not_offload() calls.

  2. QuantizedTensor Offloading Support: Extended CPU offloading to handle QuantizedTensor types by decomposing them into component tensors, offloading each component recursively, and reconstructing them during reload.

  3. DTensor Compatibility: Changed from torch.empty(shape, device) to torch.empty_like(tensor, dtype) in FusedAdam to properly respect DTensor sharding annotations.

  4. Small Tensor Offloading Threshold: Added 256K element minimum threshold to prevent offloading of tiny tensors that would cause synchronization issues with CUDA_DEVICE_MAX_CONNECTIONS=1.

  5. Bulk Allocation Control: Added disable_bulk_allocation parameter to split_quantize() C++ function, enabled when CPU offloading is active to avoid grouping small tensors with large ones.

Files Modified

  • transformer_engine/pytorch/cpu_offload.py: Core offloading logic with QuantizedTensor support
  • transformer_engine/pytorch/optimizers/fused_adam.py: DTensor-aware state initialization
  • transformer_engine/pytorch/module/linear.py: Conditional mark_not_offload() guarding
  • transformer_engine/pytorch/module/grouped_linear.py: disable_bulk_allocation parameter passing
  • transformer_engine/pytorch/quantized_tensor.py: Removed CPU operation validation that blocked offloading
  • C++ files: Added disable_bulk_allocation parameter and logic
  • Tests: Updated tensor sizes to ensure components exceed 256K threshold

Issues Found

Critical Issue: The FusedAdam DTensor fix is incomplete. When a QuantizedTensor parameter wraps a DTensor, calling dequantize() creates a new plain tensor that loses DTensor sharding metadata. The fix should use the original parameter directly with .empty_like().

Type Annotation Issue: DefaultOffloadSynchronizer.push_tensor() return type annotation doesn't reflect actual return type (missing tuple[list, list]).

Behavioral Change: Default for retain_pinned_cpu_buffers changed from False to True, affecting memory usage patterns and performance characteristics. This change is not documented in the PR description.

Confidence Score: 2/5

  • This PR has a critical bug that breaks DTensor parameter handling in FusedAdam, and incomplete type annotations. The DTensor fix is fundamentally broken for QuantizedTensor parameters.
  • The PR contains one critical logic bug that makes the DTensor fix incomplete/incorrect. The FusedAdam change dequantizes QuantizedTensor parameters, which destroys DTensor sharding information that the empty_like() call is meant to preserve. Additionally, return type annotations are incomplete, and an undocumented behavioral default change (retain_pinned_cpu_buffers) could affect existing users. While the core CPU offloading improvements are sound, these issues need resolution before merging.
  • transformer_engine/pytorch/optimizers/fused_adam.py (critical DTensor bug), transformer_engine/pytorch/cpu_offload.py (type annotation and default value)

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/optimizers/fused_adam.py 2/5 FusedAdam state initialization broken for QuantizedTensor parameters with DTensor sharding. Calling dequantize() loses DTensor metadata that should be preserved with empty_like().
transformer_engine/pytorch/cpu_offload.py 3/5 Multiple changes: QuantizedTensor offloading support added, CPU overhead reduced with layer skipping optimization, but return type annotation mismatch and behavioral default change. Default retain_pinned_cpu_buffers changed from False to True.
transformer_engine/pytorch/quantized_tensor.py 4/5 Removed CPU operation validation checks. This change is safe as it enables QuantizedTensor component tensors to be offloaded to CPU while preserving the original tensor type checks elsewhere.
transformer_engine/pytorch/module/grouped_linear.py 5/5 Disables bulk allocation in split_quantize when CPU offloading is active, addressing synchronization issues with small tensors and CUDA_DEVICE_MAX_CONNECTIONS=1.
transformer_engine/pytorch/csrc/extensions/cast.cpp 5/5 Correctly implements disable_bulk_allocation feature by wrapping bulk allocation logic in conditional, allowing fallback to unfused allocation when needed.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Linear as Linear Module
    participant CPU_Offload as CPU Offload System
    participant QuantTensor as QuantizedTensor
    participant OffloadState as OffloadableLayerState
    
    User->>Linear: Forward pass with QuantizedTensor
    Linear->>OffloadState: push_tensor(quantized_tensor)
    alt QuantizedTensor Case
        OffloadState->>QuantTensor: detach()
        QuantTensor->>QuantTensor: prepare_for_saving()
        Note over QuantTensor: Decompose into [data, transpose, scale_inv]
        loop For each component
            OffloadState->>OffloadState: push_tensor(component)
            OffloadState->>OffloadState: Check if numel() > 256K
        end
        OffloadState-->>Linear: Return (push_results, tensor_objs)
    else Regular Tensor Case
        OffloadState->>OffloadState: Check if offload needed
        OffloadState-->>Linear: Return tensor_id or tensor
    end
    
    Note over OffloadState: Backward pass
    Linear->>OffloadState: pop_tensor(tensor_id_or_tuple)
    alt QuantizedTensor Tuple Case
        OffloadState->>OffloadState: pop_tensor(each_component)
        OffloadState->>QuantTensor: restore_from_saved(components)
        OffloadState-->>Linear: Return reconstructed QuantizedTensor
    else Regular Case
        OffloadState-->>Linear: Return tensor
    end
Loading

Comment on lines 375 to +381
dtype = self.name_to_dtype_map[state_name]
# Handle QuantizedTensor by dequantizing first
param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
if store_param_remainders:
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
data = torch.zeros_like(param_for_empty, dtype=torch.int16)
else:
data = torch.empty(param.shape, dtype=dtype, device=param.device)
data = torch.empty_like(param_for_empty, dtype=dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dequantization of QuantizedTensor parameters here breaks DTensor sharding preservation. When param is a QuantizedTensor wrapping a DTensor, calling dequantize() creates a new plain tensor that loses the DTensor sharding metadata. This defeats the purpose of using .empty_like() to preserve DTensor sharding.

The fix should use the original parameter directly without dequantization, since .empty_like() respects the sharding annotations of the input tensor regardless of whether it's quantized:

Suggested change
dtype = self.name_to_dtype_map[state_name]
# Handle QuantizedTensor by dequantizing first
param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
if store_param_remainders:
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
data = torch.zeros_like(param_for_empty, dtype=torch.int16)
else:
data = torch.empty(param.shape, dtype=dtype, device=param.device)
data = torch.empty_like(param_for_empty, dtype=dtype)
data = torch.zeros_like(param, dtype=torch.int16)
...
data = torch.empty_like(param, dtype=dtype)

Alternatively, if dequantization is necessary for some reason, the sharding information from the original parameter should be explicitly preserved.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WE have never supported quantized dtensors in the optimizer and we will not support it after my change.

Comment on lines 631 to 636
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
"""Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
if not self.offload_layer_map.get(self.num_of_fwds, False):
return tensor
return self.layer_states[self.num_of_fwds].push_tensor(tensor)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type annotation for DefaultOffloadSynchronizer.push_tensor() is incomplete and doesn't match the base class signature. This method can return a tuple (when processing QuantizedTensors), but the annotation only specifies int | torch.Tensor.

Suggested change
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
"""Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
if not self.offload_layer_map.get(self.num_of_fwds, False):
return tensor
return self.layer_states[self.num_of_fwds].push_tensor(tensor)
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
"""Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
if not self.offload_layer_map.get(self.num_of_fwds, False):
return tensor
return self.layer_states[self.num_of_fwds].push_tensor(tensor)

num_layers: Optional[int] = 1,
model_layers: int = 1,
offload_activations: bool = True,
offload_weights: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default value for retain_pinned_cpu_buffers has been changed from False to True. While this may improve performance by reusing CPU buffers across iterations, this is a significant behavioral change that affects memory usage patterns and is not mentioned in the PR description. Consider documenting this change in the commit message and PR description, as it could impact existing users' performance characteristics.

If backward compatibility is important, consider keeping the default as False or provide a migration path for existing code.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Copy link
Collaborator Author

pggPL commented Jan 12, 2026

/te-ci pytorch

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Greptile Overview

Greptile Summary

This PR implements several optimizations and bug fixes for CPU offloading performance and compatibility:

Key Changes

1. CPU Overhead Reduction

  • Adds minimum tensor size threshold (256K elements) to skip offloading small tensors that would cause overhead
  • Optimizes DefaultOffloadSynchronizer.push_tensor() to skip processing when layer is known to be non-offloaded
  • Removes expensive CPU operation validation checks from QuantizedTensor.__torch_function__ hook
  • Conditionally calls mark_not_offload() only when CPU offloading is enabled in linear.py

2. QuantizedTensor Support in CPU Offloading

  • Implements recursive decomposition/reconstruction of QuantizedTensors during offload/reload cycle
  • push_tensor() now decomposes QuantizedTensors into component tensors using prepare_for_saving()
  • pop_tensor() reconstructs QuantizedTensors from components using restore_from_saved()
  • Properly handles the tuple return type (list[push_results], list[tensor_objs]) for QuantizedTensors

3. DTensor OOM Fix in Fused Optimizer

  • Switches from .empty(shape) to .empty_like() to properly respect DTensor sharding
  • Adds QuantizedTensor handling by dequantizing before creating optimizer state buffers
  • Previous implementation created full tensors on each device instead of sharded tensors

4. Bulk Allocation Control for Grouped Tensors

  • Adds disable_bulk_allocation parameter to split_quantize() C++ function
  • Automatically disables bulk allocation when CPU offloading is active via grouped_linear.py
  • Prevents synchronization issues when small tensors use SMs instead of copy engines with CUDA_DEVICE_MAX_CONNECTIONS=1

Architecture

The changes maintain backward compatibility with all parameters defaulting to existing behavior. The 256K element threshold ensures only tensors large enough (~1MB for float32) are offloaded, avoiding overhead from processing small tensors like scaling factors.

Testing

Test tensors updated to 128x512 (from 64x256) to ensure both data and scaling factor tensors exceed the 256K threshold and are properly offloaded during testing.

Confidence Score: 4/5

  • This PR is safe to merge with minor documentation improvements recommended
  • The implementation is sound with well-thought-out optimizations. The QuantizedTensor decomposition logic is correctly implemented, the DTensor fix properly uses empty_like, and the bulk allocation control addresses the synchronization issue. All changes maintain backward compatibility with sensible defaults. Only minor style issue found (typo in test comment).
  • No files require special attention - all changes are well-implemented and tested

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/cpu_offload.py 4/5 Adds QuantizedTensor support and optimizations: decompose/reconstruct QuantizedTensors during offload, adds 256K element minimum threshold, optimizes push_tensor to skip processing for non-offloaded layers
transformer_engine/pytorch/optimizers/fused_adam.py 4/5 Fixes DTensor OOM by using empty_like instead of empty with explicit shape, adds QuantizedTensor support by dequantizing before creating optimizer state buffers
transformer_engine/pytorch/quantized_tensor.py 5/5 Removes CPU operation checks from torch_function hook, eliminating expensive validation overhead mentioned in PR description
transformer_engine/pytorch/module/linear.py 5/5 Optimizes mark_not_offload to only run when CPU offloading is enabled, reducing overhead when offloading is disabled
transformer_engine/pytorch/module/grouped_linear.py 5/5 Adds disable_bulk_allocation parameter to split_quantize call when CPU offloading is active to avoid small tensor offload issues with CUDA_DEVICE_MAX_CONNECTIONS=1
transformer_engine/pytorch/csrc/extensions/cast.cpp 5/5 Implements disable_bulk_allocation logic to conditionally skip bulk allocation methods based on parameter, allowing fine-grained control over tensor allocation strategy

Sequence Diagram

sequenceDiagram
    participant User as Training Loop
    participant Layer as TE Linear Layer
    participant Offload as CPU Offload System
    participant Optimizer as Fused Adam
    
    Note over User,Optimizer: Forward Pass with CPU Offloading
    User->>Layer: forward(input)
    Layer->>Layer: Check if offloading enabled
    alt CPU Offloading Enabled
        Layer->>Layer: Check tensor size >= 256K elements
        alt Tensor is QuantizedTensor
            Layer->>Offload: push_tensor(quantized_tensor)
            Offload->>Offload: decompose into components
            Offload->>Offload: push each component recursively
        else Regular Tensor
            Layer->>Offload: push_tensor(tensor)
            Offload->>Offload: store tensor index
        end
        Layer->>Layer: start_offload()
        Layer->>Offload: Async GPU→CPU copy on offload stream
    end
    Layer->>User: return output
    
    Note over User,Optimizer: Backward Pass
    User->>Layer: backward(grad_output)
    alt CPU Offloading Enabled
        Layer->>Offload: start_reload()
        Offload->>Offload: Async CPU→GPU copy on offload stream
        Layer->>Offload: pop_tensor(tensor_id)
        alt QuantizedTensor
            Offload->>Offload: pop each component recursively
            Offload->>Offload: reconstruct QuantizedTensor
        end
        Offload->>Layer: return reloaded tensor
    end
    Layer->>User: return gradients
    
    Note over User,Optimizer: Optimizer Step with DTensor Fix
    User->>Optimizer: step()
    loop For each parameter
        alt Parameter is QuantizedTensor
            Optimizer->>Optimizer: dequantize() first
        end
        alt Parameter is DTensor
            Optimizer->>Optimizer: Use empty_like (respects sharding)
        else Regular Tensor
            Optimizer->>Optimizer: Use empty_like
        end
        Optimizer->>Optimizer: Create/update optimizer state
    end
    Optimizer->>User: optimization complete
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile



class Utils:
# Tensor big engough that both data and scaling factor tensor are bigger than 256 * 1024 elements,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in comment: "Tensor big engough" should be "Tensor big enough"

Suggested change
# Tensor big engough that both data and scaling factor tensor are bigger than 256 * 1024 elements,
# Tensor big enough that both data and scaling factor tensor are bigger than 256 * 1024 elements,

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Greptile Overview

Greptile Summary

This PR addresses CPU offloading performance and compatibility issues through multiple optimizations and bug fixes. The changes span 9 files with both performance improvements and functional fixes.

Key Changes

Performance Optimizations

  • Skip processing for non-offloaded layers: Added early return in DefaultOffloadSynchronizer.push_tensor() when the current layer is known to not be offloaded, reducing CPU overhead
  • Minimum tensor size threshold: Introduced 256K element threshold (≈1MB for float32) to avoid offloading small tensors, which can cause synchronization issues when CUDA_DEVICE_MAX_CONNECTIONS=1
  • Conditional mark_not_offload: Only calls mark_not_offload() when CPU offloading is actually enabled, avoiding unnecessary overhead from prepare_for_saving/restore_from_saved calls

Bug Fixes

  • DTensor OOM fix: Changed from torch.empty(param.shape, ...) to torch.empty_like() in fused optimizer to respect DTensor sharding
  • Grouped tensor offloading: Added disable_bulk_allocation option for grouped quantization to prevent small tensors (like scales) from being offloaded together with large tensors

QuantizedTensor Support

  • Added support for offloading QuantizedTensor by decomposing into component tensors during push_tensor() and reconstructing during pop_tensor()
  • Removed CPU operation validation checks from QuantizedTensor to enable CPU operations

Critical Issue Found

Data corruption bug in cpu_offload.py (lines 366-375): The code uses tensor.detach() to "make a copy" before calling prepare_for_saving(), but detach() only creates a shallow copy that shares the underlying storage with the original tensor. When prepare_for_saving() sets internal fields to None, it corrupts the original tensor still being used in the forward pass. This will cause crashes or incorrect results when QuantizedTensors are offloaded.

Recommendation: Replace detach() with clone() or restructure to avoid mutating tensors during push.

Other Concerns

  1. In-place modification during backward: pop_tensor() modifies saved tensor objects in-place, which may violate PyTorch autograd assumptions about tensor immutability
  2. Removed CPU safeguards: Validation checks for QuantizedTensor operations on CPU were removed without clear verification that all operations work correctly
  3. Type annotation inconsistency: Parent class OffloadSynchronizer.push_tensor() return type doesn't include tuple[list, list] which the child class now returns

Confidence Score: 1/5

  • This PR contains a critical data corruption bug that will cause crashes or incorrect results when offloading QuantizedTensors
  • Score of 1 reflects the critical bug in cpu_offload.py where detach() creates a shallow copy but prepare_for_saving() mutates the shared storage, corrupting the original tensor. This bug will manifest when QuantizedTensors are offloaded, leading to None pointer accesses in the forward pass. Additionally, the removal of CPU operation safeguards and in-place backward modifications raise concerns about correctness. The other changes (DTensor fix, performance optimizations) are sound, but the critical bug makes this PR unsafe to merge without fixes.
  • transformer_engine/pytorch/cpu_offload.py requires immediate attention for the shallow copy bug in push_tensor(). Also review quantized_tensor.py to ensure removed CPU checks don't cause silent failures.

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/cpu_offload.py 1/5 Critical bug in QuantizedTensor handling: detach() creates shallow copy causing data corruption when prepare_for_saving() sets fields to None. Also adds optimization to skip processing for non-offloaded layers and 256K element threshold.
transformer_engine/pytorch/optimizers/fused_adam.py 4/5 Good fix for DTensor sharding issue by using empty_like instead of empty with explicit shape. Properly handles QuantizedTensor by dequantizing first. Change is straightforward and addresses the OOM issue described in PR.
transformer_engine/pytorch/quantized_tensor.py 2/5 Removes CPU operation validation check for QuantizedTensor. While this enables CPU offloading, it removes safeguards against unsupported operations. Need verification that all operations work correctly on CPU or that the risk is acceptable.
transformer_engine/pytorch/module/linear.py 5/5 Adds conditional to only call mark_not_offload when cpu_offloading is enabled, reducing unnecessary overhead. Clean optimization with no functional impact when offloading is disabled.
transformer_engine/pytorch/module/grouped_linear.py 4/5 Disables bulk allocation when CPU offloading is active to prevent small tensors from being offloaded together with large tensors, avoiding synchronization issues with CUDA_DEVICE_MAX_CONNECTIONS=1. Reasonable workaround.

Sequence Diagram

sequenceDiagram
    participant User as Training Loop
    participant Linear as Linear Module
    participant CPUOffload as CPU Offload Context
    participant LayerState as OffloadableLayerState
    participant GPU as GPU Memory
    participant CPU as CPU Memory (Pinned)

    Note over User,CPU: Forward Pass with CPU Offloading

    User->>Linear: forward(input)
    activate Linear
    
    Linear->>CPUOffload: __enter__() [saved_tensors_hooks]
    activate CPUOffload
    CPUOffload->>CPUOffload: fwd_step() - increment layer counter
    
    Linear->>Linear: Compute: output = input @ weight^T + bias
    
    Note over Linear,LayerState: Save activations for backward
    Linear->>LayerState: push_tensor(activation)
    alt QuantizedTensor
        LayerState->>LayerState: detach() [SHALLOW COPY - BUG!]
        LayerState->>LayerState: prepare_for_saving() [sets fields to None]
        LayerState->>LayerState: Recursively push components
        LayerState-->>Linear: Return (push_results, tensor_objs)
    else Regular Tensor
        alt Tensor size >= 256K elements
            LayerState->>LayerState: Mark for offloading
            LayerState-->>Linear: Return tensor_id (int)
        else Small Tensor
            LayerState-->>Linear: Return tensor itself
        end
    end
    
    Linear->>CPUOffload: __exit__()
    CPUOffload->>LayerState: start_offload()
    activate LayerState
    
    loop For each tensor in offload group
        LayerState->>GPU: Record event: tensor ready
        LayerState->>CPU: Async copy (GPU->CPU on offload_stream)
    end
    
    LayerState->>LayerState: Record finish_offload_event
    deactivate LayerState
    
    Note over LayerState: Offload happens asynchronously
    
    Linear-->>User: return output
    deactivate Linear
    
    CPUOffload->>LayerState: release_activation_forward_gpu_memory()
    LayerState->>LayerState: Wait for offload_stream
    LayerState->>GPU: Free GPU tensors
    
    deactivate CPUOffload

    Note over User,CPU: Backward Pass

    User->>Linear: backward(grad_output)
    activate Linear
    
    Linear->>CPUOffload: bwd_step()
    activate CPUOffload
    CPUOffload->>LayerState: start_reload()
    activate LayerState
    
    loop For each CPU tensor
        LayerState->>GPU: Allocate new GPU tensor
        LayerState->>GPU: Async copy (CPU->GPU on offload_stream)
        LayerState->>LayerState: Record reload_event
    end
    
    deactivate LayerState
    deactivate CPUOffload
    
    Note over Linear,LayerState: Retrieve saved tensors
    Linear->>LayerState: pop_tensor(tensor_id)
    LayerState->>LayerState: Wait for reload_event
    alt QuantizedTensor (tuple input)
        LayerState->>LayerState: Recursively pop components
        LayerState->>LayerState: restore_from_saved() [in-place modification]
        LayerState-->>Linear: Return reconstructed QuantizedTensor
    else Regular Tensor
        LayerState-->>Linear: Return reloaded tensor
    end
    
    Linear->>Linear: Compute gradients
    Linear-->>User: return grad_input
    deactivate Linear

Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

7 files reviewed, 7 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +366 to +375
if isinstance(tensor, QuantizedTensor):
# Make a copy because prepare_for_saving modifies the object (sets fields to None)
tensor_copy = tensor.detach()
# Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass,
# so the generic prepare_for_saving would not call tensor.prepare_for_saving()
saved_tensors, tensor_obj = tensor_copy.prepare_for_saving()
push_results = [
self.push_tensor(t) if t is not None else None for t in saved_tensors
]
return (push_results, [tensor_obj])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical Bug: detach() creates shallow copy, causing data corruption

The comment says "Make a copy because prepare_for_saving modifies the object", but tensor.detach() only creates a shallow copy that shares the underlying storage with the original tensor.

Looking at how detach() is implemented (e.g., in Float8Tensor):

def detach(self) -> Float8Tensor:
    return Float8Tensor.make_like(self)

And make_like uses get_metadata() which returns references to the internal storage tensors:

def get_metadata(self) -> Dict[str, Any]:
    return {
        "rowwise_data": self._rowwise_data,  # Reference, not copy!
        ...
    }

Problem: When tensor_copy.prepare_for_saving() sets self._rowwise_data = None, it modifies the shared storage object, corrupting the original tensor that's still being used in the forward pass.

Impact: This will cause the forward pass to access None data, leading to crashes or incorrect results.

Solution: Either use clone() instead of detach() to create a deep copy, or restructure the code to avoid mutating the tensor during push_tensor.

Comment on lines +407 to +416
if isinstance(tensor_or_tensor_id, tuple):
push_results, tensor_objs = tensor_or_tensor_id
# Recursively pop each component
reloaded_tensors = [
self.pop_tensor(pr) if pr is not None else None for pr in push_results
]
# Inline restore_from_saved - tensor_objs[0] is the QuantizedTensor copy
tensor_obj = tensor_objs[0]
tensor_obj.restore_from_saved(reloaded_tensors)
return tensor_obj
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential Issue: In-place modification during backward pass

The pop_tensor method for QuantizedTensor performs in-place restoration:

tensor_obj.restore_from_saved(reloaded_tensors)
return tensor_obj

This modifies tensor_obj (which was created during push_tensor and stored in the tuple) in-place and returns it. However, PyTorch's autograd expects saved tensors to be immutable during backward.

Concern: If tensor_obj is the same object that was corrupted in push_tensor (due to the shallow copy bug), it may already be in an inconsistent state. Additionally, modifying saved objects in-place during backward could cause issues if PyTorch makes assumptions about tensor immutability.

Recommendation: Verify that this pattern is safe with PyTorch's autograd system, or consider creating a fresh tensor object rather than mutating the saved one.

Comment on lines +631 to +635
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
"""Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
if not self.offload_layer_map.get(self.num_of_fwds, False):
return tensor
return self.layer_states[self.num_of_fwds].push_tensor(tensor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing return type annotation in override

The push_tensor method in DefaultOffloadSynchronizer overrides the parent class method but has an inconsistent return type annotation.

Parent class (OffloadSynchronizer line 527):

def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:

This override (line 631):

def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:

The child class correctly returns the extended type (including tuple[list, list] for QuantizedTensor), but the parent class method should also have this return type to maintain type safety across polymorphic calls.

Recommendation: Update the parent class OffloadSynchronizer.push_tensor method signature to include tuple[list, list] in the return type annotation (line 527).

Comment on lines +376 to +381
# Handle QuantizedTensor by dequantizing first
param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
if store_param_remainders:
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
data = torch.zeros_like(param_for_empty, dtype=torch.int16)
else:
data = torch.empty(param.shape, dtype=dtype, device=param.device)
data = torch.empty_like(param_for_empty, dtype=dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good fix for DTensor and QuantizedTensor compatibility

This change correctly addresses the issue where torch.empty(param.shape, ...) doesn't respect DTensor sharding (would create full tensors on each device instead of sharded ones).

The fix:

  1. Dequantizes QuantizedTensor first to get a standard tensor
  2. Uses torch.empty_like() which properly handles DTensor sharding
  3. Then applies the dtype override

This ensures both DTensor sharding and QuantizedTensor are handled correctly.

Verification needed: Confirm that dequantize() for QuantizedTensor doesn't trigger expensive computation if the tensor is already on the correct device, or if it's acceptable overhead for this initialization path.

Comment on lines +57 to +61
# Tensor big engough that both data and scaling factor tensor are bigger than 256 * 1024 elements,
# so that they are offloaded to GPU.
tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16)
_B = 64
_S = 256
_B = 128
_S = 512
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test configuration updated to match new offload threshold

The batch size and sequence length are increased from _B=64, _S=256 to _B=128, _S=512 to ensure test tensors exceed the new 256K element threshold for offloading (introduced in cpu_offload.py line 443).

Calculation:

  • Old: 64 × 256 = 16,384 elements (would NOT be offloaded with new threshold)
  • New: 128 × 512 = 65,536 elements base, but comment says "both data and scaling factor tensor are bigger than 256 * 1024 elements"

Concern: The math doesn't add up. With hidden_size=512 and num_heads=8, the actual tensor sizes in tests would be around 128×512×4×256 = 67,108,864 elements, which is well above 256K. However, the comment specifically mentions "scaling factor tensor" which is typically much smaller (e.g., per-block scales).

Recommendation: Verify that all relevant tensors (including scaling factors for quantized tensors) actually exceed 256K elements, or update the comment to be more precise about which tensors are being tested.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines 145 to +148
if fp8 and not debug:
inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
inputmats = tex.split_quantize(
inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Disabling bulk allocation when CPU offloading is active

This change disables bulk tensor allocation for grouped quantization when cpu_offloading=True by passing disable_bulk_allocation=cpu_offloading to split_quantize.

Rationale (from PR description): With bulk allocation, all tensors in a group are offloaded together (all-or-nothing), which means small tensors like scales get offloaded even though they're below the 256K threshold. This causes issues when CUDA_DEVICE_MAX_CONNECTIONS=1 because small tensors use SMs for copying instead of copy engines, leading to synchronization problems.

This is a reasonable workaround, but note:

  1. Disabling bulk allocation may have performance implications for the quantization itself (not just offloading)
  2. The fix is automatic when offloading is enabled, which is good for usability
  3. Consider if there's a more targeted fix (e.g., filtering small tensors from offload groups) that wouldn't require disabling bulk allocation entirely

Recommendation: Document any performance trade-offs of disabling bulk allocation in the function's docstring or PR description.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (1)

transformer_engine/pytorch/quantized_tensor.py
Removed CPU support check - verify this is intentional

The previous code had a safeguard that prevented certain operations on QuantizedTensor when on CPU:

_quantized_tensor_cpu_supported_ops = (
    torch.ops.aten.empty_like.default,
    torch.ops.aten.copy_.default,
)

And checked:

assert func in _quantized_tensor_cpu_supported_ops, \
    f"QuantizedTensor on CPU does not support this operation: {func}"

This PR removes this check entirely, which means QuantizedTensor operations on CPU are now allowed without validation.

Questions:

  1. Is this intentional to enable CPU offloading?
  2. Were the previously unsupported operations actually tested on CPU?
  3. Could this lead to silent failures or incorrect behavior for operations that don't work correctly on CPU?

Recommendation: If CPU operations are now fully supported, add a comment explaining the change. If only specific operations are needed for offloading, consider keeping the safeguard but expanding the whitelist.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants