Merged
11 changes: 9 additions & 2 deletions tests/pytorch/test_cpu_offloading.py
@@ -54,9 +54,13 @@


class Utils:
# Tensor used for simulating long-running GPU work in long_job()
tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16)
_B = 64
_S = 256
# Test tensor dimensions: _B x _S x _D = 128 x 512 x 256 = 16,777,216 elements
# This exceeds the 256K element threshold for offloading (cpu_offload.py line 443).
# For quantized tensors, scale_inv tensors (~524K elements for block scaling) also exceed threshold.
_B = 128
_S = 512
_H = 4
_D = 256

@@ -395,6 +399,9 @@ def test_multiple_tensor_offload(self, recipe):
offload_synchronizer.push_tensor(x1)
offload_synchronizer.push_tensor(x1)
offload_synchronizer.push_tensor(x1)
# Verify x1 is not corrupted after pushing (important for QuantizedTensor)
if recipe is not None:
x1.dequantize() # Should not raise - tensor should still be valid
offload_synchronizer.fwd_step()
# Only one copy of tensor on cpu is allocated.
assert Utils.get_cpu_memory_mb() == pytest.approx(init_cpu_memory + 1 * x_size, 0.1)
72 changes: 58 additions & 14 deletions transformer_engine/pytorch/cpu_offload.py
@@ -19,6 +19,7 @@
from .quantized_tensor import (
restore_from_saved,
prepare_for_saving,
QuantizedTensor,
)


@@ -255,6 +256,8 @@ def start_offload(self):
Start offloading of tensors. Puts copy from GPU to CPU tasks on offload stream.
Before each copy event, the offload stream waits for the event signalling that the tensor is ready to be offloaded.
This event is recorded in the start_offload or push_tensor call.

Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor).
"""
self._validate_state(func_name="start_offload", allowed_states=["not_offloaded"])
self.state = "offload_started"
@@ -275,19 +278,18 @@

with torch.cuda.stream(self.offload_stream):
if allocate_cpu_buffers:
# empty_like is defined also for QuantizedTensors
offloaded_tensor = torch.empty_like(
tensor, device=torch.device("cpu"), pin_memory=True
)
self.cpu_tensor_group.tensor_list.append(offloaded_tensor)
else:
assert self.cpu_tensor_group.tensor_list[tensor_id].shape == tensor.shape, (
offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
assert offloaded_tensor.shape == tensor.shape, (
"CPU buffer shape does not match the offloaded tensor shape:"
f" {self.cpu_tensor_group.tensor_list[tensor_id].shape} != {tensor.shape} "
" Make sure that tensor shaped do not change between"
f" {offloaded_tensor.shape} != {tensor.shape} "
"Make sure that tensor shapes do not change between"
" iterations if retain_pinned_cpu_buffers is True."
)
offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
offloaded_tensor.copy_(tensor, non_blocking=True)

# aux is a dictionary that contains auxiliary data like information which tensors were deduplicated,
@@ -318,6 +320,9 @@ def start_reload(self):
"""
Start reloading of tensors.
It allocates new tensors on GPU and puts copy from CPU tasks on offload stream.

Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor
and reconstructed in pop_tensor).
"""
self._validate_state(func_name="start_reload", allowed_states=["offload_finished"])
self.state = "reload_started"
@@ -330,7 +335,6 @@
# cannot move tensors from pool of one stream to another without
# calling cudaFree and cudaMalloc again.

# empty_like is defined also for QuantizedTensors.
reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda"))
self.offload_stream.wait_stream(torch.cuda.current_stream())

@@ -347,16 +351,29 @@
self.bwd_gpu_tensor_group
)

def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
"""
It is called when a tensor is saved for backward pass.

If tensor is offloaded, returns int representing the index of the tensor in the offloaded tensor group.
If tensor is not offloaded, returns the tensor itself.
For QuantizedTensor, returns (list of push results for each component, tensor_objs) tuple.
"""
self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"])

if self._check_if_offload(tensor):
# For QuantizedTensor: decompose into component tensors, push each one recursively
if isinstance(tensor, QuantizedTensor):
# Make a copy because prepare_for_saving modifies the object (sets fields to None)
tensor_copy = tensor.detach()
# Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass,
# so the generic prepare_for_saving would not call tensor.prepare_for_saving()
saved_tensors, tensor_obj = tensor_copy.prepare_for_saving()
Comment on lines +366 to +371 (Contributor):

Critical Bug: QuantizedTensor.prepare_for_saving() may not exist

The code assumes that QuantizedTensor has a prepare_for_saving() method, but this method is only defined in the abstract base class QuantizedTensorStorage, not in QuantizedTensor itself.

Looking at quantized_tensor.py:

  • QuantizedTensorStorage (lines 72-76) declares prepare_for_saving() as abstract
  • QuantizedTensor (line 328+) inherits from torch.Tensor, NOT from QuantizedTensorStorage

This will cause an AttributeError at runtime when a QuantizedTensor instance is offloaded, since QuantizedTensor doesn't have the prepare_for_saving() method.

The comment on lines 369-370 mentions that "QuantizedTensor is a torch.Tensor subclass", but this is misleading; the real issue is that QuantizedTensor doesn't implement the required methods.

Suggested fix: Either:

  1. Have QuantizedTensor also inherit from QuantizedTensorStorage and implement prepare_for_saving(), OR
  2. Use the generic prepare_for_saving() function which properly handles torch.Tensor instances

push_results = [
self.push_tensor(t) if t is not None else None for t in saved_tensors
]
return (push_results, [tensor_obj])
Comment on lines +366 to +375 (Contributor):

Critical Bug: detach() creates shallow copy, causing data corruption

The comment says "Make a copy because prepare_for_saving modifies the object", but tensor.detach() only creates a shallow copy that shares the underlying storage with the original tensor.

Looking at how detach() is implemented (e.g., in Float8Tensor):

def detach(self) -> Float8Tensor:
    return Float8Tensor.make_like(self)

And make_like uses get_metadata() which returns references to the internal storage tensors:

def get_metadata(self) -> Dict[str, Any]:
    return {
        "rowwise_data": self._rowwise_data,  # Reference, not copy!
        ...
    }

Problem: When tensor_copy.prepare_for_saving() sets self._rowwise_data = None, it modifies the shared storage object, corrupting the original tensor that's still being used in the forward pass.

Impact: This will cause the forward pass to access None data, leading to crashes or incorrect results.

Solution: Either use clone() instead of detach() to create a deep copy, or restructure the code to avoid mutating the tensor during push_tensor.

Reply (Collaborator, Author):

We have a test for that case; we want to have references.
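
For context on this thread, a minimal sketch with plain torch tensors (not TE's QuantizedTensor) of the detach-versus-clone distinction under discussion; whether a particular TE detach() implementation shares component storage is as described in the comment above, not something this snippet proves.

import torch

x = torch.ones(4)
shared = x.detach()   # new tensor object, but it shares x's underlying storage
copied = x.clone()    # independent storage

shared[0] = 0.0       # mutation is visible through x as well
assert x[0].item() == 0.0
copied[1] = 0.0       # does not affect x
assert x[1].item() == 1.0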


self.fwd_gpu_tensor_group.tensor_list.append(tensor)
# The group is processed and offloaded at the end of the forward pass of current layer.
# To enable offloading of tensors faster we use self.offload_stream and record
@@ -370,23 +387,39 @@ def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
return len(self.fwd_gpu_tensor_group.tensor_list) - 1
return tensor

def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor:
def pop_tensor(
self, tensor_or_tensor_id: torch.Tensor | int | tuple[list, list]
) -> torch.Tensor:
"""
It is called when a tensor is used in backward pass.
Returns the tensor. If tensor was offloaded/reloaded, wait for the reload of a tensor to finish.
For QuantizedTensor (tuple input), reconstructs from component tensors.
"""
self._validate_state(
func_name="pop_tensor", allowed_states=["not_offloaded", "reload_started"]
)

# 1. tensor not offloaded
# 1. tensor not offloaded (regular tensor returned as-is from push)
if isinstance(tensor_or_tensor_id, torch.Tensor):
return tensor_or_tensor_id
# 2. the layer was not offloaded at all

# 2. QuantizedTensor case: tuple of (push_results, tensor_objs)
if isinstance(tensor_or_tensor_id, tuple):
push_results, tensor_objs = tensor_or_tensor_id
# Recursively pop each component
reloaded_tensors = [
self.pop_tensor(pr) if pr is not None else None for pr in push_results
]
# Inline restore_from_saved - tensor_objs[0] is the QuantizedTensor copy
tensor_obj = tensor_objs[0]
tensor_obj.restore_from_saved(reloaded_tensors)
return tensor_obj
Comment on lines +407 to +416 (Contributor):

Potential Issue: In-place modification during backward pass

The pop_tensor method for QuantizedTensor performs in-place restoration:

tensor_obj.restore_from_saved(reloaded_tensors)
return tensor_obj

This modifies tensor_obj (which was created during push_tensor and stored in the tuple) in-place and returns it. However, PyTorch's autograd expects saved tensors to be immutable during backward.

Concern: If tensor_obj is the same object that was corrupted in push_tensor (due to the shallow copy bug), it may already be in an inconsistent state. Additionally, modifying saved objects in-place during backward could cause issues if PyTorch makes assumptions about tensor immutability.

Recommendation: Verify that this pattern is safe with PyTorch's autograd system, or consider creating a fresh tensor object rather than mutating the saved one.
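
Editorial aside: for readers unfamiliar with the saved-tensor machinery that push_tensor/pop_tensor plug into, here is a hedged, standalone sketch of the same push/pop idea built on PyTorch's public torch.autograd.graph.saved_tensors_hooks API. The dict-based store and hook names are illustrative assumptions, not TE's implementation.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
_cpu_store = {}

def pack_hook(t):
    # "push": stash the saved activation on CPU and return a handle
    key = len(_cpu_store)
    _cpu_store[key] = t.to("cpu")
    return key

def unpack_hook(key):
    # "pop": bring the activation back for the backward pass
    return _cpu_store.pop(key).to(device)

x = torch.randn(32, 32, device=device, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = (x * x).sum()
y.backward()                      # triggers unpack_hook for the saved tensors
assert torch.allclose(x.grad, 2 * x)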


# 3. Regular tensor index case
if self.state == "not_offloaded":
return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id]

# 3. the layer was offloaded
# 4. the layer was offloaded
assert self.state == "reload_started"
# wait for the tensor to be reloaded
torch.cuda.current_stream().wait_event(
@@ -406,6 +439,10 @@ def _check_if_offload(self, t: torch.Tensor) -> bool:
"""
Check if tensor needs to be offloaded.
"""
# Only offload tensors with at least 256k elements (~1MB for float32)
if t.numel() < 256 * 1024:
return False

if (
not isinstance(t, torch.nn.Parameter)
and not getattr(t, "_TE_do_not_offload", False)
@@ -418,7 +455,6 @@ def _check_if_offload(self, t: torch.Tensor) -> bool:
" this tensor will be skipped."
)
return False

return True
return False
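
Editorial check of the numbers behind this threshold and the test dimensions chosen in test_cpu_offloading.py above (not part of the diff):

threshold_elems = 256 * 1024      # 262,144 elements, roughly 1 MiB as float32
test_elems = 128 * 512 * 256      # _B * _S * _D = 16,777,216 elements
assert test_elems > threshold_elems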

@@ -488,11 +524,13 @@ def bwd_step(self, layer_num: int):
self.previous_bwd_layer_id = layer_num
self.current_layer_id = layer_num

def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
"""Default push tensor method"""
return self.layer_states[self.num_of_fwds].push_tensor(tensor)

def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor:
def pop_tensor(
self, tensor_or_tensor_id: torch.Tensor | int | tuple[list, list]
) -> torch.Tensor:
"""Default pop tensor method"""
return self.layer_states[self.current_layer_id].pop_tensor(tensor_or_tensor_id)

@@ -592,6 +630,12 @@ def bwd_step(self, layer_num: int):
for layer in self.start_reload_map[layer_num]:
self.layer_states[layer].start_reload()

def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
"""Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
if not self.offload_layer_map.get(self.num_of_fwds, False):
return tensor
return self.layer_states[self.num_of_fwds].push_tensor(tensor)
Comment on lines +633 to +637 (Contributor):

Missing return type annotation in override

The push_tensor method in DefaultOffloadSynchronizer overrides the parent class method but has an inconsistent return type annotation.

Parent class (OffloadSynchronizer line 527):

def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:

This override (line 631):

def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:

The child class correctly returns the extended type (including tuple[list, list] for QuantizedTensor), but the parent class method should also have this return type to maintain type safety across polymorphic calls.

Recommendation: Update the parent class OffloadSynchronizer.push_tensor method signature to include tuple[list, list] in the return type annotation (line 527).



class ManualOffloadSynchronizer(OffloadSynchronizer):
"""
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/csrc/extensions.h
@@ -254,7 +254,8 @@ std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &ten

std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list);
std::vector<py::handle> quantizer_list,
bool disable_bulk_allocation = false);

/***************************************************************************************************
* Bias gradient fusions
37 changes: 20 additions & 17 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -1095,7 +1095,8 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,

std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list) {
std::vector<py::handle> quantizer_list,
bool disable_bulk_allocation) {
init_extension();

// Check number of tensors
@@ -1147,22 +1148,24 @@
enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 };
AllocationMethod allocation_method = AllocationMethod::UNFUSED;
QuantizationMethod quantization_method = QuantizationMethod::UNFUSED;
if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE;
} else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsMXFP8Quantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_MXFP8;
} else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsNVFP4Quantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_NVFP4;
quantization_method = QuantizationMethod::FUSED_NVFP4;
if (!disable_bulk_allocation) {
if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE;
} else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsMXFP8Quantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_MXFP8;
} else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
[](const py::handle &quantizer) -> bool {
return detail::IsNVFP4Quantizers(quantizer.ptr());
})) {
allocation_method = AllocationMethod::BULK_NVFP4;
quantization_method = QuantizationMethod::FUSED_NVFP4;
}
}

// Allocate output tensors
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -248,7 +248,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
"Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"),
py::arg("quantizer_list"));
py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false);
m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
"Grouped GEMM");
m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
7 changes: 6 additions & 1 deletion transformer_engine/pytorch/module/grouped_linear.py
@@ -143,7 +143,12 @@ def forward(
inp_view = inp.reshape(-1, in_features)
inputmats: list
if fp8 and not debug:
inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
# Disable bulk allocation when CPU offloading is active: offloading skips small
# tensors (like scales), but bulk allocation shares storage across all tensors,
# so if scales can't be offloaded, nothing in the group can be offloaded.
inputmats = tex.split_quantize(
inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading
)
elif debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view, input_quantizers, m_splits, activation_dtype
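The comment added in this hunk rests on a general property of bulk allocation: views of one shared buffer cannot release device memory independently. A minimal sketch with plain tensors (illustrative only, not TE's allocator):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
bulk = torch.empty(1024, device=device)
a, b = bulk[:512], bulk[512:]     # two views sharing bulk's storage
assert a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

# "Offloading" a only copies it; the shared bulk buffer (and all of its device
# memory) stays allocated as long as any view, such as b, is still alive.
a_host = a.to("cpu", copy=True)
del a
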
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/linear.py
@@ -428,7 +428,8 @@ def forward(
# weights if weights are externally touched outside this module
ctx.weight_object = weight

mark_not_offload(weight, weightmat, bias)
if cpu_offloading:
mark_not_offload(weight, weightmat, bias)
# TODO(ksivamani): Check memory usage
tensors_to_save, tensor_objects = prepare_for_saving(
saved_inputmat,
7 changes: 5 additions & 2 deletions transformer_engine/pytorch/optimizers/fused_adam.py
@@ -14,6 +14,7 @@
from torch.distributed._tensor import DTensor
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
from .multi_tensor_apply import multi_tensor_applier


@@ -372,10 +373,12 @@ def _initialize_state(
store_param_remainders (bool): Store only trailing remainder bits.
"""
dtype = self.name_to_dtype_map[state_name]
# Handle QuantizedTensor by dequantizing first
param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
if store_param_remainders:
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
data = torch.zeros_like(param_for_empty, dtype=torch.int16)
else:
data = torch.empty(param.shape, dtype=dtype, device=param.device)
data = torch.empty_like(param_for_empty, dtype=dtype)
Comment on lines 375 to +381 (Contributor):

The dequantization of QuantizedTensor parameters here breaks DTensor sharding preservation. When param is a QuantizedTensor wrapping a DTensor, calling dequantize() creates a new plain tensor that loses the DTensor sharding metadata. This defeats the purpose of using .empty_like() to preserve DTensor sharding.

The fix should use the original parameter directly without dequantization, since .empty_like() respects the sharding annotations of the input tensor regardless of whether it's quantized:

Suggested change: replace

dtype = self.name_to_dtype_map[state_name]
# Handle QuantizedTensor by dequantizing first
param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
if store_param_remainders:
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
data = torch.zeros_like(param_for_empty, dtype=torch.int16)
else:
data = torch.empty(param.shape, dtype=dtype, device=param.device)
data = torch.empty_like(param_for_empty, dtype=dtype)

with

data = torch.zeros_like(param, dtype=torch.int16)
...
data = torch.empty_like(param, dtype=dtype)

Alternatively, if dequantization is necessary for some reason, the sharding information from the original parameter should be explicitly preserved.

Reply (Collaborator, Author):

We have never supported quantized DTensors in the optimizer, and we will not support them after my change.

Comment on lines +376 to +381 (Contributor):

Good fix for DTensor and QuantizedTensor compatibility

This change correctly addresses the issue where torch.empty(param.shape, ...) doesn't respect DTensor sharding (would create full tensors on each device instead of sharded ones).

The fix:

  1. Dequantizes QuantizedTensor first to get a standard tensor
  2. Uses torch.empty_like() which properly handles DTensor sharding
  3. Then applies the dtype override

This ensures both DTensor sharding and QuantizedTensor are handled correctly.

Verification needed: confirm that dequantize() for a QuantizedTensor doesn't trigger expensive computation when the tensor is already on the correct device, or that the overhead is acceptable for this initialization path.
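
As a small aside on the empty_like point above, a sketch with plain tensors showing that torch.empty_like inherits device and layout from its input while torch.empty(shape) falls back to defaults unless told otherwise; DTensor sharding behaves analogously per the discussion, but is not shown here.

import torch

p = torch.randn(8, 8, dtype=torch.bfloat16)
a = torch.empty(p.shape, dtype=torch.int16)   # device/layout come from defaults
b = torch.empty_like(p, dtype=torch.int16)    # device/layout come from p
assert b.shape == p.shape and b.device == p.device and b.dtype == torch.int16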

if zero_buffer:
data.zero_()

14 changes: 0 additions & 14 deletions transformer_engine/pytorch/quantized_tensor.py
@@ -20,11 +20,6 @@
_stride_from_shape,
)

_quantized_tensor_cpu_supported_ops = (
torch.ops.aten.empty_like.default,
torch.ops.aten.copy_.default,
)


class QuantizedTensorStorage:
r"""Base class for all TensorStorage classes.
@@ -539,15 +534,6 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}

def check_if_cpu(arg):
if isinstance(cls, QuantizedTensor) and arg.device.type == "cpu":
assert (
func in _quantized_tensor_cpu_supported_ops
), f"QuantizedTensor on CPU does not support this operation: {func}"
return arg

args = tree_map(check_if_cpu, args)

# Do not force the QuantizedTensor type on the returned tensor
return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
