Skip to content

Add DCP compatibility for FSDP2-TP sharding in TransformerEngine.#2713

Open
cspades wants to merge 6 commits intoNVIDIA:mainfrom
cspades:cye/fsdp2-tp-dcp
Open

Add DCP compatibility for FSDP2-TP sharding in TransformerEngine.#2713
cspades wants to merge 6 commits intoNVIDIA:mainfrom
cspades:cye/fsdp2-tp-dcp

Conversation

@cspades
Copy link
Member

@cspades cspades commented Feb 26, 2026

Summary

  • Support (H/F)SDP2 x TP strided sharding, and DTensor FP8 parameters for Torch DCP checkpointing, across all TransformerEngineBaseModule(s).
    • Except GroupedLinear, pending FSDP2 standalone pipe-cleaning. All other modules under transformer_engine.pytorch.modules are supported.
    • FusibleOperation support is also a WIP, except for LayerNorm or RMSNorm which are TE modules.
  • Associated with BioNeMo-Recipes Llama3 TP: Enable TransformerEngine-backed Tensor Parallelism with Llama3. bionemo-framework#1483
    • Notably, TransformerEngine TP can be easily mixed with DTensor-based TP when unified by Torch DCP! In the Llama3 recipe, we use DTensor-based TP on the torch.nn.Embedding, TransformerEngine-based TP on the LM head, and weight-tie the LM head to the torch.nn.Embedding, which is why we do not need to call set_device_mesh for the LM head!
  • Credit to @pstjohn for coming up with this idea!

Usage / Documentation

(tp_mesh and weight_mesh can also be passed in TEModule.__init__.)

    def set_device_mesh(
        self,
        tp_mesh: Optional[DeviceMesh] = None,
        weight_mesh: Optional[DeviceMesh] = None,
    ) -> None:
        """
        Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor
        depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2.

        TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless
        integration with Torch DCP checkpointing. This method should only be invoked when
        using DTensor parameters, e.g. when using FSDP2 or DCP.

        When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically
        convert them into FSDP-TP strided or non-strided shards depending on the current
        sharding dimension and factor of the DTensor. When the sharding dimension of FSDP
        matches that of TP, FSDP uses a _StridedShard placement type instead of Shard.
        This experimental FSDP-TP logic presides in this FSDP2 initialization function:
        ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param``

        Parameters
        ----------
        tp_mesh : Optional[DeviceMesh]
            A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"].
            Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP.
        weight_mesh : Optional[DeviceMesh]
            A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required
            when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor
            parameters and if the DTensor DeviceMesh includes dimensions that do not
            shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard).
            For example:
                - device_mesh["dp"] for FSDP.
                - device_mesh["dp_cp"] if using CP ranks in FSDP.
                - device_mesh["dp_shard"] if using HSDP ("dp_replicate", "dp_shard").
                - device_mesh["tp"] if using TP.
                - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.
        """

Details

DTensor Lifecycle in TransformerEngine

  • Initialization
    • __init__
      • TransformerEngine model parameters are initialized either on device or meta device with the appropriate tp_size and TP sharding strategy, e.g. parallel_mode and sequence_parallel.
    • TransformerEngineModule.set_device_mesh(tp_mesh, weight_mesh)
      • Converts parameters to DTensor with appropriate TP placement(s) based on the TP sharding strategy specified in __init__, using transformer_engine.pytorch.distributed._convert_param_to_dtensor_param.
        • tp_mesh is a 1-D DeviceMesh containing the TP ProcessGroup that will be registered with the TransformerEngine module.
        • weight_mesh is the 1-D DeviceMesh containing the ProcessGroup that shards TransformerEngine module weights, the flattened combination of groups such as FSDP and TP. Specifically, it excludes non-weight groups such as DP-Replicate when using HSDP or HSDP-TP and is mainly required for per-Tensor scaling recipes like Float8CurrentScaling.
      • Needs to be invoked prior to fully_shard (which responds to the TP placements) and prior to reset_parameters(defer_init=False), which quantizes parameters.
      • Can also be directly invoked during __init__(tp_mesh, weight_mesh) for supported TransformerEngine modules.
    • fully_shard shards the TransformerEngine model with FSDP2.
      • When fully_shard encounters TP sharding on dim=0, it will use a _StridedShard for DP. Put simply, this "pre-shards" the data prior to sharding on the current placement, followed by concatenating the pre-shards to get strided shards that will be re-sharded by the next placement. This effectively reverses the sharding order when processing the placements from left-to-right, and distributes shards as if we sharded on TP first, then FSDP, as required, even though DP appears before TP in the DeviceMesh and DTensor.placements. (See Appendix for visualization of this sharding strategy.)
    • reset_parameters is called if using meta device initialization.
  • Training
    • Pre-forward, FSDP2 all-gathers the sharded DTensor "main" weight that it registered during fully_shard. (Note that this essentially shares the same properties as the compute weight besides shape, and supporting tools such as FusedAdam must be used to properly handle high-precision main weights.)
      • When using FSDP2 x TP, the all-gathered Tensor is actually a TP-sharded DTensor, which deviates from the original FSDP2 paradigm where the all-gathered Tensor is fully-unsharded and the DTensor wrapping is discarded. To support these DTensor compute weights in TransformerEngine modules, we utilize transformer_engine.pytorch.distributed._extract_trainable_tensor_from_dtensor to localize the DTensor and also inherit requires_grad attribute from the DTensor parameter as the local Tensor has this un-set during DTensor.from_local(Tensor) for FP8 parameters specifically!
    • Post-backward, the Tensor gradient is converted and attached to the DTensor.grad attribute.
      • NOTE(@cspades, @vthumbe1503): For some reason, FusibleOperation (RMSNorm and LayerNorm) require casting the gradient from Tensor to a DTensor matching the configuration of the DTensor weights. I have confirmed the gradient is installed correctly on RMSNorm weights (same shape and sharding configuration as the sharded optimizer state), and it will not affect normal TransfomerEngine operations, but it is not totally clear why this is necessary with FSDP2 x TP.

Bugs

  • Fix bug where "shard" was the presumed weight sharding sub-mesh in the DTensor.device_mesh. Now, users can precisely specify their own custom weight-sharding DeviceMesh for per-tensor amax_reduction_group via the set_device_mesh(weight_mesh) API.
  • TransformerEngineBaseModule: self.quantizers = {"scaling_fwd": [], "scaling_bwd": []}

Testing

# TransformerEngine Main
[Rank 0] (after 1 iterations) memory (MB) | allocated: 23511.65 | max allocated: 25189.68 | reserved: 25678.00 | max reserved: 25678.00
 [2026-03-02 09:55:17.189564] iteration       99/15258789 | consumed samples:        12672 | elapsed time per iteration (ms): 12715.7 | throughput per GPU (TFLOP/s/GPU): 530.6 | learning rate: 4.866046E-07 | global batch size:   128 | lm loss: 1.124915E+00 | loss scale: 1.0 | grad norm: 5.474 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-03-02 09:55:29.768521] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 12578.7 | throughput per GPU (TFLOP/s/GPU): 536.4 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.143806E+00 | loss scale: 1.0 | grad norm: 5.366 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Post-DCP Modifications (This PR)
[Rank 0] (after 2 iterations) memory (MB) | allocated: 23511.65 | max allocated: 29783.24 | reserved: 25678.00 | max reserved: 31510.00
 [2026-03-02 09:29:36.550070] iteration       99/15258789 | consumed samples:        12672 | elapsed time per iteration (ms): 12556.5 | throughput per GPU (TFLOP/s/GPU): 537.3 | learning rate: 4.866046E-07 | global batch size:   128 | lm loss: 1.124463E+00 | loss scale: 1.0 | grad norm: 5.471 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-03-02 09:29:49.216068] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 12665.7 | throughput per GPU (TFLOP/s/GPU): 532.7 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.142863E+00 | loss scale: 1.0 | grad norm: 5.355 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • NOTE(@cspades): DelayedScaling has DCP save/load disparity issues, i.e. on the scale of +/-1 to the uint8 parameter checkpoint!

Appendix

_StridedShard - Using FSDP2 x TP Strided-Sharding

# (DP=4, TP=2)
(_StridedShard(dim=0, sf=2), Shard(dim=0))

┌───┬───┐
│ 0 │ 4 │ ← DP=0
├───┼───┤
│ 1 │ 5 │ ← DP=1
├───┼───┤          FSDP all-gather happens across the DP ranks,
│ 2 │ 6 │ ← DP=2   so we need to form the 0-3 and 4-7 TP shards!
├───┼───┤
│ 3 │ 7 │ ← DP=3
└───┴───┘
  ↑   ↑
TP=0 TP=1

When redistribute'ing a global DTensor to (_StridedShard(dim=0, sf=2), Shard(dim=0)), DTensor will perform the following steps:

  • Pre-shard the Tensor data with respect to the stride / shard factor, which is defined as the product of the parallelism sizes of all Shard placements to the right of _StridedShard. (In the above example, since TP=2, the factor is 2.)
    • [0 1 2 3 4 5 6 7] -> [0 1 2 3] and [4 5 6 7].
    • In the context of this PR and fully_shard, this has already been done via initializing the TransformerEngine module with TP and calling _convert_param_to_dtensor_param!
  • Shard the pre-shards for _StridedShard.
    • [0] [1] [2] [3] and [4] [5] [6] [7]
  • Concatenate the strided shards.
    • [0 4] [1 5] [2 6] [3 7], which are assigned to the _StridedShard ranks.
    • Note that this is very different if we did left-to-right-sharding, which would have given us [0 1] [2 3] [4 5] [6 7]!
  • Subsequently / finally, each strided shard is sharded on the Shard placement.
    • [0] [4] / [1] [5] / [2] [6] / [3] [7], which are assigned to the Shard ranks.
    • Note that this is very different if we did left-to-right sharding, which would have given us [0] [1] / [2] [3] / [4] [5] / [6] [7]!

PyTorch also supports the inverse / un-sharding of this redistribute, which is literally the inverse of these simple operations! (Though things get a bit more complicated with un-even shards from odd-numbered dimension sizes.)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 4, 2026

Greptile Summary

This PR adds Torch DCP (Distributed Checkpoint) compatibility for FSDP2 × TP strided sharding across all TransformerEngineBaseModules by converting module parameters to DTensor via a new set_device_mesh(tp_mesh, weight_mesh) API, and teaching every forward/backward pass to unwrap those DTensors back to plain Tensors before calling TE's C++ kernels. It also introduces a full DCP save/load round-trip test and fixes a pre-existing bug in _LayerNormMLP backward (ctx.fc1_weight_quantizerctx.fc1_weight in the QuantizedTensorStorage isinstance check) along with changing self.quantizers from a dict-of-dicts to a dict-of-lists.

Key changes:

  • distributed.py: Two new helpers — _convert_param_to_dtensor_param (wraps a Parameter as a DTensor while preserving custom attributes) and _extract_trainable_tensor_from_dtensor (localises and re-enables requires_grad cleared by FSDP2 for FP8 params).
  • Module set_device_mesh implementations: Added to Linear, LayerNormLinear, LayerNormMLP, LayerNorm, RMSNorm, DotProductAttention, MultiheadAttention, and TransformerLayer. Each applies the correct Shard/Replicate DTensor placement based on the TP parallelism mode (column → Shard(0), row → Shard(1), norms → Replicate).
  • FP8/MXFP8 tensors: Guard added to unwrap out from a DTensor using ._local_tensor (.to_local() is incompatible with Torch Dispatch for quantised tensors) before the all-gather output assignment.
  • Tests: AppState (Stateful DCP wrapper), full save/load parity validation for model weights and optimizer state, and FSDP-TP parametrisation in test_torch_fsdp2.py.
  • test_fp8_fsdp2_allgather: The function was not updated for the TP case — the manual all-gather uses the TP group (the only group visible on the TP device_mesh) while the FSDP all-gather operates over the dp_shard group, producing mismatched shapes/values in the parity assertion for TP+fp8_init=True configurations.

Confidence Score: 3/5

  • The core production code changes (DTensor conversion, kernel unwrapping, FP8 all-gather guard) are sound and well-tested via Megatron parity runs; the main risk lies in the test harness, not the feature itself.
  • The DTensor lifecycle design is correct and the Megatron-LM parity tests show near-identical loss curves. However, test_fp8_fsdp2_allgather appears to have a shape-mismatch bug for the new FSDP-TP parametrised case when fp8_init=True, which would cause CI failures without a guard or fix. Additionally, the pre-existing args.sharding_dims null-dereference in _parse_args (flagged in a previous thread but still present) means the test binary crashes when --sharding-dims is omitted. These are test-layer issues rather than runtime regressions for real workloads, which keeps the score at 3 rather than lower.
  • tests/pytorch/distributed/run_fsdp2_model.pytest_fp8_fsdp2_allgather logic for TP and the unguarded len(args.sharding_dims) in _parse_args.

Important Files Changed

Filename Overview
tests/pytorch/distributed/run_fsdp2_model.py Major DCP checkpoint testing infrastructure added (AppState, save/load round-trip, state parity validation). test_fp8_fsdp2_allgather has a shape-mismatch bug for FSDP-TP configs (wrong all-gather group). _parse_args still has an unguarded len(args.sharding_dims) call when sharding_dims is None (previously flagged).
transformer_engine/pytorch/distributed.py Adds two new helper functions: _convert_param_to_dtensor_param (wraps a Parameter as a DTensor, preserving custom attributes) and _extract_trainable_tensor_from_dtensor (localises a DTensor and re-enables requires_grad that FSDP2 clears for FP8 params). Logic is sound; attribute inheritance loop is a solid way to carry over Megatron/TE-specific metadata.
transformer_engine/pytorch/module/base.py Fixes quantizers initialisation from {} to [] (corrects integer-indexed access), adds DTensor localization of the input before TE C++ kernels, and improves the reset_parameters DTensor recreation to preserve parameter attributes via _convert_param_to_dtensor_param. Also fixes the fallback amax_reduction_group to use the full device_mesh group when no explicit group is set.
transformer_engine/pytorch/module/layernorm_mlp.py Adds set_device_mesh for FC1 (column-parallel Shard(0)), FC2 weight (row-parallel Shard(1)), FC2 bias and LayerNorm (Replicate). Fixes unrelated pre-existing bug where isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) was always False (now correctly checks ctx.fc1_weight). Bias DTensor conversion when use_bias=False acknowledged in previous threads.
transformer_engine/pytorch/tensor/float8_tensor.py Adds a guard to unwrap out from DTensor using ._local_tensor (not .to_local()) before the FP8 all-gather output assignment. Comment correctly explains that .to_local() fails under Torch Dispatch for quantized tensors with _transpose usage.

Sequence Diagram

sequenceDiagram
    participant User
    participant TEModule
    participant set_device_mesh
    participant _convert_param
    participant fully_shard
    participant reset_parameters
    participant forward

    User->>TEModule: __init__(tp_mesh, weight_mesh)
    TEModule->>set_device_mesh: set_device_mesh(tp_mesh, weight_mesh)
    set_device_mesh->>_convert_param: _convert_param_to_dtensor_param(weight, tp_mesh, Shard/Replicate)
    _convert_param-->>set_device_mesh: nn.Parameter(DTensor)
    set_device_mesh->>set_device_mesh: set amax_reduction_group via weight_mesh
    set_device_mesh-->>TEModule: DTensor parameters set

    TEModule->>reset_parameters: reset_parameters(defer_init=False)
    reset_parameters->>reset_parameters: reinitialise + quantise (handling DTensors)
    reset_parameters->>reset_parameters: _set_tensor_parallel_attributes()

    User->>fully_shard: fully_shard(model, mesh[dp_shard])
    fully_shard->>fully_shard: detects DTensor Shard(0) → _StridedShard(dim=0, sf=tp_size)
    fully_shard-->>User: FSDP2-TP sharded model

    User->>forward: forward(input)
    forward->>forward: FSDP2 all-gather (restores TP-sharded DTensor)
    forward->>forward: _extract_trainable_tensor_from_dtensor(weight)
    forward->>forward: TE C++ GEMM kernel (plain Tensor)
    forward-->>User: output

    User->>User: DCP save(AppState)
    User->>User: AppState.state_dict() evicts _extra_state
    User->>User: DCP load(AppState) with strict=False
Loading

Comments Outside Diff (1)

  1. tests/pytorch/distributed/run_fsdp2_model.py, line 253-275 (link)

    test_fp8_fsdp2_allgather shape mismatch for FSDP-TP configurations

    When tp_size > 1, TP-sharded parameters carry a 1-D TP device_mesh. In the manual all-gather section (before module.unshard()), dist_group = device_mesh.get_group() therefore resolves to the TP group, not the FSDP dp_shard group.

    Consider an example of DP=4, TP=2 with a weight of 8 rows. The strided-sharded FSDP local slice on rank (dp=0, tp=0) is 1 row ([0]). Gathering over the TP group (size=2) yields 2 rows ([0, 4]) stored in fp32_allgathered_params[name].

    After module.unshard(), the FSDP2 all-gather restores the full TP-local shard (4 rows) on that rank. The comparison then becomes:

    param._local_tensor.dequantize()  →  shape [4]
    fp32_allgathered_params[name]     →  shape [2]
    

    torch.testing.assert_close will raise a shape-mismatch error for every TP-sharded FP8 parameter when fp8_init=True and TP is active.

    A simple guard that skips this test for TP configurations (or rewrites the manual all-gather to operate over the dp_shard group using the known args.mesh["dp_shard"]) is needed to prevent spurious failures on the sharding_dims=[..., tp_size] parametrised cases.

Last reviewed commit: ec5effe

@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch from 4ec2947 to dbb9d14 Compare March 4, 2026 18:10
@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch from fcdd5bd to c912f5b Compare March 5, 2026 16:06
@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch from c912f5b to 2aadb35 Compare March 5, 2026 18:30
cspades and others added 4 commits March 5, 2026 15:50
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch 3 times, most recently from a7a17c2 to bc82f02 Compare March 6, 2026 17:02
@vthumbe1503
Copy link
Collaborator

/te-ci L1 pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants