Generalized Tensor Parallelism (GTP) #3005
Greptile Summary
This PR introduces Generalized Tensor Parallelism (GTP), a large new subsystem that shards weight parameters 1/N across a GTP process group, materializes them on demand via async all-gather with a prefetch chain, and accumulates gradients via async reduce-scatter. It is composable with existing TP/SP/EP/DDP and CUDA graph capture.
Confidence Score: 4/5
The functional correctness of the core GTP forward/backward path is solid for single-model use; the issues found are design and robustness concerns rather than data-corrupting bugs in the primary training loop. The transformer_engine/pytorch/module/generalized_tensor_parallelism.py, specifically the class-level …
Important Files Changed
Sequence Diagram
sequenceDiagram
    participant Fwd as Forward Pass
    participant AG as AG Stream
    participant NCCL as NCCL (GTP group)
    participant Compute as Compute Stream
    participant RS as RS Stream
    participant Bwd as Backward Pass
    Fwd->>AG: async prefetch next_w (all_gather_and_prefetch)
    AG->>NCCL: all_gather_into_tensor (weight shard to full weight)
    Fwd->>Compute: GEMM with current gathered weight
    NCCL-->>AG: AG complete, ag_event.record()
    Compute-->>Fwd: output activations
    Fwd->>AG: next layer: wait ag_event, use prefetched weight
    Bwd->>AG: async prefetch prev_w (all_gather_and_prefetch_bwd)
    Bwd->>Compute: wgrad GEMM
    Bwd->>RS: async wgrad_reduce_scatter
    RS->>NCCL: reduce_scatter_tensor
    NCCL-->>RS: RS complete, rs_event.record()
    Bwd->>Compute: dgrad GEMM (overlaps with RS)
    RS-->>Bwd: cascade: rs_event.wait + main_grad.add_
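The ordering in the diagram can be read as a simple schedule. The following is a minimal sketch in plain Python that only logs the order of operations; it performs no communication and does not use the PR's classes. The function names `forward_schedule` and `backward_schedule` and the log strings are hypothetical, and the point at which the reduce-scatter result is accumulated is an assumption (placed after the dgrad GEMM of the same layer).

```python
from typing import List


def forward_schedule(num_layers: int) -> List[str]:
    """Log the forward order: wait for layer i, prefetch layer i+1, then GEMM."""
    log: List[str] = []
    in_flight = set()  # layers whose all-gather has already been issued
    for i in range(num_layers):
        if i not in in_flight:
            # Only the first layer has no earlier layer to prefetch it.
            log.append(f"gather w{i}")
            in_flight.add(i)
        log.append(f"wait w{i}")  # corresponds to waiting on ag_event
        if i + 1 < num_layers:
            # Issue the next all-gather before the GEMM so the two overlap.
            log.append(f"prefetch w{i + 1}")
            in_flight.add(i + 1)
        log.append(f"gemm w{i}")
        in_flight.discard(i)  # the gathered copy is no longer needed
    return log


def backward_schedule(num_layers: int) -> List[str]:
    """Log the backward order shown in the diagram, from last layer to first.

    How the last layer's weight is obtained in backward is not shown in the
    diagram, so it is omitted here.
    """
    log: List[str] = []
    for i in reversed(range(num_layers)):
        if i - 1 >= 0:
            log.append(f"prefetch w{i - 1}")
        log.append(f"wgrad w{i}")
        log.append(f"launch rs w{i}")   # async reduce-scatter of the wgrad
        log.append(f"dgrad w{i}")       # overlaps with the reduce-scatter
        log.append(f"accumulate w{i}")  # assumed: rs_event.wait + main_grad.add_
    return log


if __name__ == "__main__":
    for step in forward_schedule(3):
        print("fwd:", step)
    for step in backward_schedule(2):
        print("bwd:", step)
```

In this sketch every GEMM except the first is preceded only by a wait, never by a blocking gather, which is the overlap the prefetch chain is meant to provide.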
Co-authored-by: Jieming Zhang <jiemingz@nvidia.com> Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
/te-ci L1 pytorch
Design doc: GTP.docx
Description
Core idea: add Generalized Tensor Parallelism (GTP), a flexible, fine-grained scheme that shards both activations and parameters and materializes them just in time, with efficient computation-communication overlap.
Mission: improve LLM pretraining efficiency through generalized tensor parallelism, enabling high performance, memory efficiency, ease of use, and strong scalability.
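To make the sharding and just-in-time materialization concrete, here is a minimal single-process sketch in plain Python. It simulates what each rank would hold, with no real process group or NCCL calls. The helper names `shard_rows`, `all_gather`, and `reduce_scatter` are hypothetical and are not the PR's API; sharding along the first dimension and an evenly divisible row count are assumptions of the sketch.

```python
from typing import List

Matrix = List[List[float]]


def shard_rows(weight: Matrix, world_size: int) -> List[Matrix]:
    """Split a weight matrix along its first dimension into equal shards."""
    rows = len(weight)
    assert rows % world_size == 0, "this sketch assumes evenly divisible rows"
    per_rank = rows // world_size
    return [weight[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


def all_gather(shards: List[Matrix]) -> List[Matrix]:
    """Simulated all-gather: every rank receives the concatenation of all shards."""
    full = [row for shard in shards for row in shard]
    return [[row[:] for row in full] for _ in shards]


def reduce_scatter(per_rank_grads: List[Matrix], world_size: int) -> List[Matrix]:
    """Simulated reduce-scatter: sum full-size grads, then keep each rank's slice."""
    rows = len(per_rank_grads[0])
    cols = len(per_rank_grads[0][0])
    summed = [
        [sum(grad[i][j] for grad in per_rank_grads) for j in range(cols)]
        for i in range(rows)
    ]
    return shard_rows(summed, world_size)


# Hypothetical 4x2 weight sharded across 2 simulated ranks.
weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
shards = shard_rows(weight, 2)           # each rank stores 2 of the 4 rows
gathered = all_gather(shards)            # full weight materialized for the GEMM
# Each simulated rank r produces a full-size weight gradient filled with r + 1.
grads = [[[float(r + 1)] * 2 for _ in range(4)] for r in range(2)]
grad_shards = reduce_scatter(grads, 2)   # each rank keeps its summed slice

if __name__ == "__main__":
    print("rank 1 shard:", shards[1])
    print("rank 0 gathered:", gathered[0])
    print("rank 0 grad shard:", grad_shards[0])
```

Between uses, each simulated rank holds only its 1/N slice of both the weight and the accumulated gradient; the full weight exists only while it is needed for compute.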
Summary of features
Type of change
Changes
Please list the changes introduced in this PR:
| File | Change |
| --- | --- |
| `transformer_engine/pytorch/module/generalized_tensor_parallelism.py` | `GTPShardedParam`, `GTPWeightCache`, `_TicketSlot`, `GTPShardHandle`, `GTPWeightState`, `GTPChain` + `_chain_state` (GRAPHED / UNGRAPHED prefetch chains), `GTPConfig`, AG/RS stream helpers keyed by `(chain_id, group)`, `wrap_module_params_gtp`, `classify_gtp_chains`, `tag_gtp_params_with_names`, `update_config`, `reallocate_gtp_cache_to_mempool`, `wait_async_comms`, NVFP4 coalesced-amax fast path |
| `transformer_engine/common/recipe/multi_amax.cu` | `compute_multi_amax_nvfp4`: fuses N per-expert (zero_amax + amax + D2D) chains into one launch; enables the coalesced-amax path for grouped experts |
| `tests/pytorch/distributed/test_gtp.py` | `Linear` / `LayerNormLinear` / `GroupedLinear` fwd/bwd correctness (BF16 + NVFP4 + MXFP8, aligned + unaligned), prefetch chain, wgrad reduce-scatter, microbatches, fp8-param-gather, DDP grad-accum hook |
| `tests/pytorch/distributed/test_tp_gtp.py` | `Linear` (column / row parallel) weight shape and fwd/bwd correctness, `LayerNormLinear` smoke test |
| `transformer_engine/pytorch/module/linear.py` | `te.Linear` accepts `gtp_group` kwarg; fwd/bwd dispatch through `GTPShardedParam.all_gather_and_prefetch{,_bwd}` and `wgrad_reduce_scatter` |
| `transformer_engine/pytorch/module/layernorm_linear.py` | `te.LayerNormLinear` |
| `transformer_engine/pytorch/module/grouped_linear.py` | `te.GroupedLinear` integration with batched AG/RS for routed experts |
| `transformer_engine/pytorch/module/base.py` | `_gtp_sharded_weight_names` tracking |
| `transformer_engine/pytorch/distributed.py` | `gather_along_first_dim`, `_all_gather_nvfp4` (rowwise + columnwise + interleaved layout), `_NVFP4AllGatherAsyncHandle` |
| `transformer_engine/pytorch/csrc/quantizer.cpp` | `set_usage(rowwise, columnwise)` toggles for GTP fwd/bwd asymmetry; `make_empty(..., dtype=BF16)` for padded GTP buffers |
| `transformer_engine/pytorch/csrc/extensions/cast.cpp` | `compute_multi_amax_nvfp4` + `quantize_cast_only_nvfp4` (NVFP4 cast without recomputing amax, the GTP fast-path) |
| `transformer_engine/pytorch/csrc/extensions/pybind.cpp`, `transformer_engine/pytorch/csrc/extensions.h`, `transformer_engine/pytorch/csrc/common.h`, `transformer_engine/common/include/transformer_engine/recipe.h` | `nvte_compute_multi_amax_nvfp4` |
| `transformer_engine/common/CMakeLists.txt` | `multi_amax.cu` in the build |

Checklist: