Skip to content

Generalized Tensor Parallelism (GTP) #3005

Open
fanshiqing wants to merge 8 commits into
NVIDIA:mainfrom
fanshiqing:gtp_release
Open

Generalized Tensor Parallelism (GTP) #3005
fanshiqing wants to merge 8 commits into
NVIDIA:mainfrom
fanshiqing:gtp_release

Conversation

@fanshiqing
Copy link
Copy Markdown
Member

@fanshiqing fanshiqing commented May 18, 2026

Deisgn doc: GTP.docx

Description

Core-idea: add Generalized Tensor Parallelism (GTP), which is a flexible fine-grained sharding/just-in time materialization of both activations and parameters with efficient computation-communication overlap.

Mission: improve LLM pretraining efficiency through generalized tensor parallelism, enabling high performance, memory efficiency, ease of use, and strong scalability.

Summary of features

  1. Fine-grained materialization & gradient reduction
  • Weight, gradient, and optimizer states are sharded along the GTP group.
  • Weights are temporarily materialized through prefetching in both the forward and backward passes.
  1. Composability with TP / SP / EP / DDP with efficient overlapping of computation and communication
  • GEMM + TP/EP communication + GTP communication + DDP communication.
  1. GTP + partial Cudagraphs with fine-grained synchronization across graphs
  • GTP reduce-scatter overlapping across graphs.
  1. Low-Precision quantize-then-gather
  • MXFP8 / NVFP4
  • Auto-padding/stripping to satisfy low-precision alignment requirements.
  1. Parallel folding for MoE layer
  • Support configuring the GTP size for dense layers and MoE layers separately.
  1. Distributed checkpointing

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

# Type File Role
1 NEW transformer_engine/pytorch/module/generalized_tensor_parallelism.py Core GTP module: GTPShardedParam, GTPWeightCache, _TicketSlot, GTPShardHandle, GTPWeightState, GTPChain + _chain_state (GRAPHED / UNGRAPHED prefetch chains), GTPConfig, AG/RS stream helpers keyed by (chain_id, group), wrap_module_params_gtp, classify_gtp_chains, tag_gtp_params_with_names, update_config, reallocate_gtp_cache_to_mempool, wait_async_comms, NVFP4 coalesced-amax fast path
2 NEW transformer_engine/common/recipe/multi_amax.cu NVFP4 multi-tensor amax CUDA kernel (compute_multi_amax_nvfp4) fusing N per-expert (zero_amax + amax + D2D) chains into one launch — enables the coalesced-amax path for grouped experts
3 NEW tests/pytorch/distributed/test_gtp.py 61 unit tests: weight-state machine, buffer cache, weight sharding, module-param replacement, Linear/LayerNormLinear/GroupedLinear fwd/bwd correctness (BF16 + NVFP4 + MXFP8, aligned + unaligned), prefetch chain, wgrad reduce-scatter, microbatches, fp8-param-gather, DDP grad-accum hook
4 NEW tests/pytorch/distributed/test_tp_gtp.py 7 TP+GTP integration tests on 4 GPUs (TP=2, GTP=2): process group layout, Linear (column / row parallel) weight shape and fwd/bwd correctness, LayerNormLinear smoke test
5 MOD transformer_engine/pytorch/module/linear.py te.Linear accepts gtp_group kwarg; fwd/bwd dispatch through GTPShardedParam.all_gather_and_prefetch{,_bwd} and wgrad_reduce_scatter
6 MOD transformer_engine/pytorch/module/layernorm_linear.py Same GTP integration for te.LayerNormLinear
7 MOD transformer_engine/pytorch/module/grouped_linear.py te.GroupedLinear integration with batched AG/RS for routed experts
8 MOD transformer_engine/pytorch/module/base.py Reset-parameters hook: slice + quantize order coordinated with GTP shards; _gtp_sharded_weight_names tracking
9 MOD transformer_engine/pytorch/distributed.py gather_along_first_dim, _all_gather_nvfp4 (rowwise + columnwise + interleaved layout), _NVFP4AllGatherAsyncHandle
10 MOD transformer_engine/pytorch/csrc/quantizer.cpp set_usage(rowwise, columnwise) toggles for GTP fwd/bwd asymmetry; make_empty(..., dtype=BF16) for padded GTP buffers
11 MOD transformer_engine/pytorch/csrc/extensions/cast.cpp Bindings for compute_multi_amax_nvfp4 + quantize_cast_only_nvfp4 (NVFP4 cast without recomputing amax — the GTP fast-path)
12 MOD transformer_engine/pytorch/csrc/extensions/pybind.cpp Register new amax / cast-only bindings
13 MOD transformer_engine/pytorch/csrc/extensions.h Declarations for the new bindings
14 MOD transformer_engine/pytorch/csrc/common.h Helper declarations
15 MOD transformer_engine/common/include/transformer_engine/recipe.h C ABI declaration for nvte_compute_multi_amax_nvfp4
16 MOD transformer_engine/common/CMakeLists.txt Register multi_amax.cu in the build

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 18, 2026

Greptile Summary

This PR introduces Generalized Tensor Parallelism (GTP), a large new subsystem that shards weight parameters 1/N across a GTP process group, materializes them on demand via async all-gather with a prefetch chain, and accumulates gradients via async reduce-scatter — composable with existing TP/SP/EP/DDP and CUDA graph capture.

  • Core module (generalized_tensor_parallelism.py): GTPShardedParam, ticket-based GTPWeightCache, per-chain prefetch linked lists, AG/RS stream management, NVFP4 coalesced-amax fast path, and wait_async_comms for cross-graph boundary sync.
  • Distributed layer (distributed.py): gather_along_first_dim and _all_gather_nvfp4/mxfp8 extended with output_tensor pre-allocation and grouped mode; all address-rebinding = assignments replaced with .copy_() for CUDA-graph pointer stability.
  • CUDA kernel (multi_amax.cu): fused multi-tensor amax kernel that collapses N per-expert zero+amax+D2D chains into two launches, enabling the coalesced allreduce path for grouped experts.

Confidence Score: 4/5

The functional correctness of the core GTP forward/backward path is solid for single-model use; the issues found are design and robustness concerns rather than data-corrupting bugs in the primary training loop.

The _chain_state class-level dict causes cross-contamination between model instances in the same process, silently mis-wiring new chain heads and forcing async RS where sync RS is expected. This does not corrupt training results for the standard single-model-per-process deployment but will cause confusing failures in test isolation and multi-model scenarios.

transformer_engine/pytorch/module/generalized_tensor_parallelism.py — specifically the class-level _chain_state dict and the bare assert False in reallocate_to_mempool.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/generalized_tensor_parallelism.py New 1790-line core module; class-level _chain_state dict is shared across model instances and module-level globals have no reset path.
transformer_engine/pytorch/distributed.py Added output_tensor/grouped params to gather functions; in-place .copy_() for CUDA-graph stability; dead handle.output = out assignment when columnwise_usage is False and async_op is True.
transformer_engine/common/recipe/multi_amax.cu New CUDA kernel fusing N per-expert zero_amax + amax + D2D chains; alignment and batching logic looks correct.
transformer_engine/pytorch/csrc/extensions/cast.cpp Adds compute_amax_nvfp4, quantize_cast_only_nvfp4, and compute_multi_amax_nvfp4 bindings.
transformer_engine/pytorch/csrc/quantizer.cpp Adds skip_amax_reduction flag and new split-phase compute_amax_only/quantize_cast_only methods; well-isolated.
transformer_engine/pytorch/module/linear.py GTP dispatch added via gtp_size field; weight gather, wgrad RS, and main_grad handling correctly gated.
transformer_engine/pytorch/module/layernorm_linear.py Same GTP dispatch pattern as linear.py with wgrad_before_dgrad opt-in.
transformer_engine/pytorch/module/grouped_linear.py Adds batched AG/RS path for routed experts.
transformer_engine/pytorch/module/base.py Hooks GTP slicing into reset_parameters before FP8 quantize; clean implementation.

Sequence Diagram

sequenceDiagram
    participant Fwd as Forward Pass
    participant AG as AG Stream
    participant NCCL as NCCL (GTP group)
    participant Compute as Compute Stream
    participant RS as RS Stream
    participant Bwd as Backward Pass

    Fwd->>AG: async prefetch next_w (all_gather_and_prefetch)
    AG->>NCCL: all_gather_into_tensor (weight shard to full weight)
    Fwd->>Compute: GEMM with current gathered weight
    NCCL-->>AG: AG complete, ag_event.record()
    Compute-->>Fwd: output activations
    Fwd->>AG: next layer: wait ag_event, use prefetched weight
    Bwd->>AG: async prefetch prev_w (all_gather_and_prefetch_bwd)
    Bwd->>Compute: wgrad GEMM
    Bwd->>RS: async wgrad_reduce_scatter
    RS->>NCCL: reduce_scatter_tensor
    NCCL-->>RS: RS complete, rs_event.record()
    Bwd->>Compute: dgrad GEMM (overlaps with RS)
    RS-->>Bwd: cascade: rs_event.wait + main_grad.add_
Loading

Reviews (4): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/module/generalized_tensor_parallelism.py
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/module/generalized_tensor_parallelism.py
Comment thread transformer_engine/pytorch/module/generalized_tensor_parallelism.py Outdated
fanshiqing and others added 5 commits May 18, 2026 00:50
Co-authored-by: Jieming Zhang <jiemingz@nvidia.com>
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Comment thread transformer_engine/pytorch/module/generalized_tensor_parallelism.py Outdated
fanshiqing and others added 2 commits May 18, 2026 02:25
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
@fanshiqing
Copy link
Copy Markdown
Member Author

/te-ci L1 pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants