[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974

Open
vthumbe1503 wants to merge 16 commits into NVIDIA:main from vthumbe1503:fsdp2_dcp_laod_fix

@vthumbe1503
Collaborator

@vthumbe1503 vthumbe1503 commented May 11, 2026

Description

Fixes DCP sync checkpoint loading for MXFP8/NVFP4.
Fixes DCP async checkpoint loading for all quantization recipes.
Fixes NVFP4 allgather + dequant numerical errors for FSDP2. These turned out to be caused by not setting the FSDP group as the amax reduction group in the quantizer.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • untyped_storage implementation needed for FSDP2 + DCP

    • untyped_storage is now defined on the base QuantizedTensor and returns an empty storage. untyped_storage refers to the backing storage from which a tensor's internal data is created. Since we use make_wrapper_subclass to create TE QuantizedTensors, we don't have any backing storage associated with the tensor itself. data_ptr on our custom QuantizedTensor also returns 0.
    • The main issue is that FSDP2 maintains a sharded parameter tensor for checkpointing. It does so by calling view(-1) on our quantized sharded model parameters, for which TE returns a dequantized 1D tensor. So the sharded tensor that FSDP2 maintains for checkpointing is BF16, while the quantized sharded parameter is our custom FP8 tensor. FSDP2 evaluates untyped_storage(BF16 sharded tensor reloaded from disk) == untyped_storage(quantized sharded parameter) to check whether they are the same tensor. Now that we return an empty storage, this is never equal to the sharded tensor's untyped storage.
  • DCP Async/Sync Checkpoint loading

    • For sync cases, we previously dequantized to BF16 before saving to disk, which happened through the to_new_empty function.
    • For both sync and async, dequantizing is not ideal. So we now implement .cpu() and .to() for QuantizedTensor; these don't go through dequantization and instead just copy the inner tensors of the QuantizedTensor to the CPU if needed, in a blocking or non-blocking way.
  • NVFP4 Allgather Correctness issues

    • The result of the quantized allgather with FSDP2 was very far from that of an fp32 allgather for the same values. This was due to us not setting the amax reduction group in the quantizer.
  • TE_DType Serialization issues with DCP Checkpointing

    • DCP uses torch.load(weights_only=True), whose Unpickler rejects every GLOBAL reference that isn't in add_safe_globals — and getattr is intentionally not allow-listed.
    • So we override the default enum reduction in pybind:
default:      (getattr, (tex.DType, "kFloat8E4M3"))   # needs getattr + tex.DType allow-listed
pybind override: (tex.DType, (int_value,))            # only needs tex.DType allow-listed
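
The difference between the two reductions can be reproduced without PyTorch. The sketch below uses only the standard library and rests on stated assumptions: `AllowListUnpickler` is a hypothetical stand-in for the allow-list behaviour of `torch.load(weights_only=True)` (it is not PyTorch's actual unpickler), and `IntDType` / `GetattrDType` are made-up classes that mimic the overridden and the default enum reduction respectively. The value 7 is arbitrary.

```python
import io
import pickle


class AllowListUnpickler(pickle.Unpickler):
    """Toy unpickler that resolves only explicitly allow-listed globals."""

    def __init__(self, file, allowed):
        super().__init__(file)
        self._allowed = {(obj.__module__, obj.__qualname__): obj for obj in allowed}

    def find_class(self, module, name):
        key = (module, name)
        if key not in self._allowed:
            raise pickle.UnpicklingError(f"global {module}.{name} is not allow-listed")
        return self._allowed[key]


def restricted_loads(payload, allowed):
    """Unpickle `payload`, permitting only the globals in `allowed`."""
    return AllowListUnpickler(io.BytesIO(payload), allowed).load()


class IntDType:
    """Hypothetical enum-like class using the override: (cls, (int_value,))."""

    def __init__(self, value):
        self.value = int(value)

    def __eq__(self, other):
        return type(other) is type(self) and other.value == self.value

    def __hash__(self):
        return hash((type(self), self.value))

    def __reduce__(self):
        return (type(self), (self.value,))


class GetattrDType(IntDType):
    """Hypothetical class using the default style: (getattr, (cls, name))."""

    def __reduce__(self):
        return (getattr, (type(self), "kFloat8E4M3"))


# Class attribute that the getattr-based reduction looks up on load.
GetattrDType.kFloat8E4M3 = GetattrDType(7)

# Override style: allow-listing the class alone is enough.
restored = restricted_loads(pickle.dumps(IntDType(7)), allowed=[IntDType])
print(restored == IntDType(7))

# Default style: loading fails because getattr itself is not allow-listed.
try:
    restricted_loads(pickle.dumps(GetattrDType.kFloat8E4M3), allowed=[GetattrDType])
except pickle.UnpicklingError as err:
    print("rejected:", err)
```

In this toy model the getattr-based payload only loads if `getattr` is added to the allow-list, which is exactly the entry the PR avoids needing.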

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@vthumbe1503 vthumbe1503 changed the title from "[Pytorch][Bug] DCP Load Fixes for FSDP2 with QuantizedModelInit" to "[Pytorch][Bug] DCP Checkpoint Load Fixes for FSDP2 with QuantizedModelInit" May 11, 2026
@greptile-apps
Contributor

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR fixes DCP (Distributed Checkpoint) sync and async checkpoint loading for quantized tensor types (MXFP8, NVFP4, Float8Blockwise) under FSDP2 with QuantizedModelInit, and corrects a numerical error in NVFP4 all-gather under FSDP2.

  • untyped_storage + _to_copy overrides on QuantizedTensor: FSDP2's checkpoint staging calls untyped_storage() for identity checks and uses _to_copy to move tensors to CPU; both now work correctly — the former returns an empty storage (avoiding false-positive same-tensor matches), the latter moves all inner buffers (data, scales, amax) without dequantizing.
  • Module-level reconstruct functions + pybind DType override: Each tensor type's __reduce_ex__ now references a module-level function instead of a bound classmethod, removing the need for getattr in PyTorch's weights_only=True safe-globals list; tex.DType is overridden to serialize as (tex.DType, (int,)) so only the class itself need be allow-listed.
  • NVFP4 amax_reduction_group fix: NVFP4Quantizer is now included in the FSDP2 DTensor path that sets the shard process group as the amax reduction group, fixing the all-gather numerical errors reported for FP4 training.
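
Why a shared amax matters for an all-gather of quantized shards can be illustrated with a small numeric sketch. Everything here is an assumption made for illustration: the symmetric integer grid (`QMAX = 7`) is a toy format, not NVFP4, and the failure mode shown (decoding the gathered shards with a scale derived from a single rank's amax) is one plausible way such a mismatch can arise, not a description of the actual kernels. The `max` over all values stands in for an all-reduce(MAX) across the shard group.

```python
QMAX = 7  # toy symmetric integer grid, not the real NVFP4 format


def quantize(values, amax):
    """Scale values so that amax maps to QMAX, then round to the grid."""
    scale = amax / QMAX
    return [round(v / scale) for v in values], scale


def dequantize(codes, scale):
    """Map grid codes back to real values with a single scale."""
    return [c * scale for c in codes]


def max_abs_error(reference, approx):
    return max(abs(r - a) for r, a in zip(reference, approx))


# Two ranks each hold one shard of the same weight.
shards = [[0.5, -1.0, 0.25], [4.0, -7.0, 2.0]]
full = [v for shard in shards for v in shard]

# Without amax reduction: each rank scales by its local amax, but the
# gathered tensor is decoded with a single scale (rank 0's here).
local = [quantize(s, max(abs(v) for v in s)) for s in shards]
gathered_codes = [c for codes, _ in local for c in codes]
bad = dequantize(gathered_codes, local[0][1])

# With amax reduction: every rank uses the max over the whole shard group.
global_amax = max(abs(v) for v in full)  # stands in for all-reduce(MAX)
shared = [quantize(s, global_amax) for s in shards]
good = dequantize([c for codes, _ in shared for c in codes], shared[0][1])

print("error without shared amax:", max_abs_error(full, bad))
print("error with shared amax:   ", max_abs_error(full, good))
```

With a shared amax the error is bounded by half a grid step; without it, the second shard is decoded at the wrong magnitude.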

Confidence Score: 5/5

The changes are well-scoped bug fixes for DCP checkpoint loading; the _to_copy path, untyped_storage override, reconstruct functions, and amax_reduction_group extension are all targeted and low-risk.

All changes address concrete, previously-failing DCP checkpointing paths. The current safe-globals list contains only concrete classes and module-level functions. The only findings are documentation clarity issues that do not affect runtime behavior.

The backward-compat _make_in_reduce_ex classmethods in float8_tensor.py, mxfp8_tensor.py, nvfp4_tensor.py, and float8_blockwise_tensor.py carry comments that overstate their utility for weights_only=True loading.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/quantized_tensor.py | Adds untyped_storage() returning empty storage and a _to_copy dispatch handler that moves all inner tensor buffers to the target device while preserving the QuantizedTensor subclass. Also fixes make_like to propagate device. |
| transformer_engine/pytorch/__init__.py | Adds add_safe_globals registration for all QuantizedTensor types and their module-level reconstruct functions; getattr is correctly not included after the previous security fix. |
| transformer_engine/pytorch/module/base.py | Extends the amax_reduction_group assignment to include NVFP4Quantizer alongside Float8CurrentScalingQuantizer, fixing NVFP4 allgather numerical errors under FSDP2. |
| transformer_engine/pytorch/tensor/float8_tensor.py | Adds __reduce_ex__ that always serializes raw FP8 buffers (no CPU dequantization fallback), moves _make_in_reduce_ex to a module-level function for DCP weights_only=True compat, and propagates device to all tensor construction sites. |
| transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py | Same CPU-bounce pattern as mxfp8_tensor_storage; has a truncated inline comment on the guard condition. |
| transformer_engine/common/util/pybind_helper.h | Adds __reduce__ and __reduce_ex__ overrides to the DType pybind11 enum so it serializes as (tex.DType, (int,)) rather than (getattr, (tex.DType, name)), eliminating the need for getattr in safe globals. |

Sequence Diagram

sequenceDiagram
    participant FSDP2
    participant QT as QuantizedTensor
    participant Storage as TensorStorage
    participant DCP
    participant tex

    Note over FSDP2,DCP: DCP Async Checkpoint Staging
    FSDP2->>QT: untyped_storage()
    QT-->>FSDP2: UntypedStorage(0 bytes)

    FSDP2->>QT: ".to(device=cpu) / .cpu()"
    QT->>QT: __torch_dispatch__(_to_copy)
    QT->>Storage: get_metadata()
    Storage-->>QT: rowwise_data, scale_inv, amax, ...
    QT->>QT: move each tensor buffer to CPU
    QT-->>FSDP2: CPU-resident QuantizedTensor (subclass preserved)

    DCP->>QT: __reduce_ex__(protocol)
    QT-->>DCP: (_make_X_tensor_in_reduce_ex, (data, scale_inv, fp8_dtype, ...))
    Note over DCP: weights_only=True safe

    Note over FSDP2,tex: NVFP4 All-Gather Fix
    FSDP2->>QT: quantize weight shard
    QT->>Storage: "quantizer.amax_reduction_group = shard_pg"
    Storage->>tex: all_reduce amax across shard group
    tex-->>FSDP2: correctly scaled NVFP4 all-gathered param

Reviews (9): Last reviewed commit: "Merge branch 'main' into fsdp2_dcp_laod_..." | Re-trigger Greptile

Comment on lines +536 to +545
    def untyped_storage(self) -> torch.UntypedStorage:
        """Return an empty UntypedStorage on the tensor's device.

        ``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
        backing storage of its own; the actual bytes live in the inner
        buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
        an implementation detail of the quantization scheme. Need to define
        this method to avoid DCP staging errors with FSDP2.
        """
        return torch.UntypedStorage(0, device=self.device)
Contributor

P1 Empty storage breaks shared-storage detection in existing callers

QuantizedTensor.untyped_storage() now returns a freshly allocated zero-byte storage every call. Code in module/_common.py:128 compares tensors[0].untyped_storage().nbytes() against expected size to decide between a no-op view and an out-of-place torch.cat. With 0 bytes returned, that condition is always true, silently disabling the in-place fast path for any QuantizedTensor through ConcatMerge.forward. More critically, utils.py:403-412 in SplitAlongDim.backward uses data_ptr() for noop detection — if all zero-size CUDA allocations return data_ptr() == 0, every QuantizedTensor pair incorrectly appears co-located, setting noop_ok = True and crashing on ret.set_() against a 0-byte storage.

Collaborator

The correct behavior for these functions is to fall back to the slow path for QuantizedTensors, unless there is a dedicated implementation to handle quantized data.

Member

Yeah, while I don't think we use QuantizedTensors in the SplitAlongDim ever, the concat sounds plausible to be hit.

Collaborator Author

I need to resolve this comment after going thoroughly over the consequences of noop_cat for QuantizedTensors.

Collaborator Author

The behavior is unchanged by this change, and I would argue the implementation is now more correct. Before this change, the default untyped_storage() implementation inherited by QuantizedTensor(torch.Tensor) returned a storage with two properties:

  1. storage.nbytes() returns a byte count based on the fake_dtype that we use to register our QuantizedTensor as a torch.Tensor via torch's make_wrapper_subclass method.

  2. storage.data_ptr() raises an error saying that the storage is invalid and has no data_ptr().

Neither is ideal.
The first is grossly incorrect for two reasons. First, we manage the backing storage for the inner tensors of QuantizedTensor and torch has no knowledge of it. Second, nbytes based on fake_dtype is misleading, since it may not match the number of bytes we actually allocate.
The second is now causing problems with FSDP2, since it expects some storage for its identity check.

For QuantizedTensor, noop_cat today always returns an actual torch.cat, which goes through dequantization, because this condition happens to be true. The condition remains true with this change, since nbytes() now returns 0.

Today QuantizedTensor.data_ptr() returns 0, while QuantizedTensor.untyped_storage().data_ptr() raises an invalid-storage error, which is inconsistent. Returning an empty storage fixes this inconsistency.

As for identity checking, FSDP2 runs its comparison logic only if data_ptr() is not 0. It also does not make sense to compare two empty storages.
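
To put rough numbers on the point that nbytes based on fake_dtype can be misleading, here is a back-of-the-envelope sketch. The block sizes and scale widths are assumptions taken from the public MXFP8 and NVFP4 format descriptions (32-element blocks with a 1-byte scale, and 16-element blocks with a 1-byte scale, respectively), not values read from this PR. The functions are hypothetical helpers and ignore padding or alignment of scale tensors, columnwise copies, and per-tensor amax/scale values.

```python
from math import prod


def nominal_nbytes(shape, fake_dtype_bytes=2):
    """Bytes a dense tensor of the advertised (fake) dtype would occupy."""
    return prod(shape) * fake_dtype_bytes


def mxfp8_rowwise_nbytes(shape, block=32):
    """1 byte per FP8 element plus 1 scale byte per block of 32 elements."""
    n = prod(shape)
    return n + n // block


def nvfp4_rowwise_nbytes(shape, block=16):
    """4 bits per FP4 element plus 1 scale byte per block of 16 elements."""
    n = prod(shape)
    return n // 2 + n // block


shape = (1024, 1024)
print("BF16 nominal :", nominal_nbytes(shape))
print("MXFP8 rowwise:", mxfp8_rowwise_nbytes(shape))
print("NVFP4 rowwise:", nvfp4_rowwise_nbytes(shape))
```

Under these assumptions the nominal BF16 size overstates the rowwise payload by roughly 2x for MXFP8 and more than 3x for NVFP4, so a byte count derived from the wrapper's dtype says little about what is actually allocated.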

Comment thread tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py Outdated
Comment thread tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py Outdated
Comment thread tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py
@vthumbe1503 vthumbe1503 changed the title from "[Pytorch][Bug] DCP Checkpoint Load Fixes for FSDP2 with QuantizedModelInit" to "[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit" May 11, 2026
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}",
)
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so need to relax tolerances
Member

Why do we need dequant + quant here?

Collaborator Author

We are not doing it anymore.

Comment on lines +613 to +616
# When a CPU copy of a quantized tensor is requested (e.g. by
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
Collaborator

@timmoon10 timmoon10 May 11, 2026

This fix seems ad hoc to me. It's not obvious why qtensor.new_empty(..., device="cpu") returns a quantized tensor while qtensor.new_empty(..., device="cuda") returns a plain tensor. I wonder if it would be cleaner to just return a plain tensor in all cases. Thoughts:

  • It's uncomfortable how new_empty and empty_like would have different behavior. I suppose we could interpret empty_like as "make a tensor that matches the input" and new_empty as "call torch.empty with defaults taken from input", but that would be a private interpretation that no one else follows.
  • Would this affect FSDP or CPU offloading?
  • Given the weirdness, would it be worthwhile raising a warning if new_empty is called outside of DCP?

Collaborator Author

I agree it was ugly. We now have a new solution: the _to_copy op is handled in __torch_dispatch__. This allows staging the inner tensors of a QuantizedTensor on the CPU in a blocking or non-blocking way for sync/async DCP checkpointing.

We only do this in _to_copy if the dtype is unchanged. Otherwise we still go through the dequantize route.
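
The rule described in this reply (move the inner buffers when the dtype is unchanged, dequantize otherwise) can be modelled without PyTorch. The sketch below is a toy model only: `Buffer`, `ToyQuantized`, and `to_copy` are hypothetical names, devices are plain strings, and nothing here reflects the real `__torch_dispatch__` signature or the blocking/non-blocking copy semantics.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Buffer:
    """Hypothetical stand-in for an inner tensor: payload plus device tag."""

    data: tuple
    device: str

    def to(self, device):
        return self if device == self.device else replace(self, device=device)


@dataclass(frozen=True)
class ToyQuantized:
    """Hypothetical stand-in for a QuantizedTensor with two inner buffers."""

    codes: Buffer       # quantized payload
    scale_inv: Buffer   # single-element dequantization scale
    dtype: str = "bfloat16"  # nominal dtype the wrapper advertises

    def dequantize(self, device):
        s = self.scale_inv.data[0]
        return Buffer(tuple(c * s for c in self.codes.data), device)

    def to_copy(self, device, dtype=None):
        if dtype is None or dtype == self.dtype:
            # dtype unchanged: move every inner buffer, stay quantized.
            return replace(
                self,
                codes=self.codes.to(device),
                scale_inv=self.scale_inv.to(device),
            )
        # dtype change requested: fall back to a dense high-precision copy.
        return self.dequantize(device)


q = ToyQuantized(Buffer((2, -4, 6), "cuda"), Buffer((0.5,), "cuda"))
staged = q.to_copy("cpu")                 # stays a ToyQuantized
dense = q.to_copy("cpu", dtype="float32")  # becomes a plain Buffer
print(type(staged).__name__, staged.codes.device, staged.codes.data)
print(type(dense).__name__, dense.device, dense.data)
```

The design point the toy model captures is that staging to CPU preserves the quantized payload bit for bit, so a later load does not need a dequantize and requantize round trip.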

# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
target_size = torch.Size(size) if len(size) > 0 else tensor.size()
Collaborator

An empty size is valid and it corresponds to a tensor with 1 entry (for the same reason 2^0=1).

>>> import torch
>>> x = torch.ones(123).new_empty([])
>>> print(x.numel())
1
Suggested change
- target_size = torch.Size(size) if len(size) > 0 else tensor.size()
+ target_size = size

Collaborator Author

The torch dispatch function has changed, so we no longer have size here.

Comment thread tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py Outdated
Comment thread transformer_engine/pytorch/quantized_tensor.py Outdated
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 force-pushed the fsdp2_dcp_laod_fix branch from 3589ffa to 4197bee Compare May 13, 2026 04:00
@vthumbe1503 vthumbe1503 requested a review from ksivaman as a code owner May 13, 2026 04:00
pre-commit-ci Bot and others added 2 commits May 13, 2026 04:01
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

Comment thread transformer_engine/pytorch/__init__.py Outdated
vthumbe1503 and others added 6 commits May 13, 2026 04:19
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Comment thread transformer_engine/pytorch/tensor/float8_tensor.py
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503 vthumbe1503 added the bug Something isn't working label May 18, 2026
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

timmoon10
timmoon10 previously approved these changes May 19, 2026
Collaborator

@timmoon10 timmoon10 left a comment

Overall LGTM

Comment thread transformer_engine/pytorch/tensor/float8_tensor.py
Comment thread transformer_engine/pytorch/tensor/float8_tensor.py Outdated
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

Labels

2.16.0 bug Something isn't working

4 participants