diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 9793df569c..deda80e537 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 9793df569ce413f4b1844a9176f7ae24dd981603 +Subproject commit deda80e5372d50e925d7bf4f76c5db779be3fbd5 diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index ba610dcf02..24ba9a38de 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.7.0.dev0 +2.7.0 diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index b4c8767a59..826d0d2fc7 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -219,7 +219,9 @@ def train_and_evaluate(args): else: fp8_recipe = None - with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe): + with te.fp8_autocast( + enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() + ): encoder = Net(num_embed) # We use nn.Embed, thus inputs need to be in int inputs = jnp.zeros(input_shape, dtype=jnp.int32) diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 110705d015..92baf4b0c5 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -193,7 +193,9 @@ def train_and_evaluate(args): else: fp8_recipe = None - with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe): + with te.fp8_autocast( + enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() + ): cnn = Net(args.use_te) var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16)) tx = optax.sgd(args.lr, args.momentum) diff --git a/qa/L0_cppunittest/test.sh b/qa/L0_cppunittest/test.sh index cd46b0b63c..aa56d69ed6 100755 --- a/qa/L0_cppunittest/test.sh +++ b/qa/L0_cppunittest/test.sh @@ -17,4 +17,4 @@ cd $TE_PATH/tests/cpp cmake -GNinja -Bbuild . 
cmake --build build export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS)) -ctest --test-dir build -j$NUM_PARALLEL_JOBS +ctest --test-dir build -j$NUM_PARALLEL_JOBS -E '(AgGemm|GemmRs|GemmAr)' diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 482ae6dcab..394273ca47 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -23,8 +23,6 @@ set -x mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime" -pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" @@ -40,7 +38,6 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail 
"test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" diff --git a/qa/L1_cpp_distributed/test.sh b/qa/L1_cpp_distributed/test.sh new file mode 100755 index 0000000000..f4f914b3e9 --- /dev/null +++ b/qa/L1_cpp_distributed/test.sh @@ -0,0 +1,15 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +# Find TE +: ${TE_PATH:=/opt/transformerengine} +TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}') +export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH + +cd $TE_PATH/tests/cpp +cmake -GNinja -S. -Bbuild +cmake --build build +mpirun --allow-run-as-root --np 4 --oversubscribe ./build/comm_gemm/test_comm_gemm diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh new file mode 100644 index 0000000000..1486d50971 --- /dev/null +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -0,0 +1,11 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ + +pip3 install onnxruntime==1.20.1 +pip3 install onnxruntime_extensions==0.13.0 + +: ${TE_PATH:=/opt/transformerengine} + +python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/setup.py b/setup.py index 0b1b523277..52adaf9238 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ """Installation script.""" +from importlib import metadata import os import time from pathlib import Path @@ -66,6 +67,18 @@ def setup_common_extension() -> CMakeExtension: if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))): cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON") + if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))): + cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON") + cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution( + "nvidia-cublasmp-cu12" + ).locate_file("nvidia/cublasmp/cu12") + cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}") + nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution( + "nvidia-nvshmem-cu12" + ).locate_file("nvidia/nvshmem") + cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}") + print("CMAKE_FLAGS:", cmake_flags[-2:]) + # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") if nvte_cmake_extra_args: diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index eb2825ba41..c2c9d0d915 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -37,6 +37,7 @@ find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." 
${TE_LIB_ message(STATUS "Found transformer_engine library: ${TE_LIB}") include_directories(../../transformer_engine/common/include) include_directories(../../transformer_engine/common) +include_directories(../../transformer_engine) include_directories(${CMAKE_SOURCE_DIR}) find_package(CUDAToolkit REQUIRED) diff --git a/tests/cpp/comm_gemm/CMakeLists.txt b/tests/cpp/comm_gemm/CMakeLists.txt new file mode 100644 index 0000000000..55f5207acf --- /dev/null +++ b/tests/cpp/comm_gemm/CMakeLists.txt @@ -0,0 +1,19 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +add_executable(test_comm_gemm + test_comm_gemm.cu + ../test_common.cu) + +find_package(OpenMP REQUIRED) +find_package(MPI REQUIRED) +find_library(NCCL_LIB + NAMES nccl libnccl + PATH_SUFFIXES lib + REQUIRED) +target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) +target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc CUDNN::cudnn MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) + +include(GoogleTest) +gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp/comm_gemm/test_comm_gemm.cu b/tests/cpp/comm_gemm/test_comm_gemm.cu new file mode 100644 index 0000000000..b34d4db4b8 --- /dev/null +++ b/tests/cpp/comm_gemm/test_comm_gemm.cu @@ -0,0 +1,441 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../test_common.h" +#include "common.h" + +using transformer_engine::DType; +using transformer_engine::TypeInfo; + +#define CHECK_MPI(expr) \ + do { \ + int err = (expr); \ + if (err != MPI_SUCCESS) { \ + char err_str[MPI_MAX_ERROR_STRING + 1]{}; \ + int _len{}; \ + MPI_Error_string(err, err_str, &_len); \ + EXPECT_TRUE(false) << "MPI error: " << err << ": " << err_str; \ + } \ + } while (false) + +#define CHECK_NCCL(expr) \ + do { \ + ncclResult_t err = (expr); \ + if (err != ncclSuccess) { \ + EXPECT_TRUE(false) << "NCCL error: " << err << ": " << ncclGetErrorString(err); \ + } \ + } while (false) + +#define CHECK_CU(expr) \ + do { \ + CUresult err = (expr); \ + if (err != CUDA_SUCCESS) { \ + const char* str{}; \ + CUresult e_str = cuGetErrorString(err, &str); \ + if (e_str != CUDA_SUCCESS) str = "(unknown)"; \ + EXPECT_TRUE(false) << "CU error: " << err << ": " << str; \ + } \ + } while (false) + +int main(int argc, char* argv[]) { + ::testing::InitGoogleTest(&argc, argv); + CHECK_MPI(MPI_Init(&argc, &argv)); + auto ret = RUN_ALL_TESTS(); + CHECK_MPI(MPI_Finalize()); + return ret; +} + +bool IsMulticastSupported(int device_id) { + int supported = 0; + CHECK_CU(cuDeviceGetAttribute(&supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, device_id)); + return supported; +} + +template +std::vector CopyMatrix(const std::vector& data, size_t mstart, size_t nstart, size_t msize, + size_t nsize, size_t ld) { + std::vector ret(msize * nsize); + size_t dst = 0; + for (size_t j = nstart; j < nstart + nsize; ++j) { + for (size_t i = mstart; i < mstart + msize; ++i) { + ret[dst++] = data[j * ld + i]; + } + } + return ret; +} + +template +test::Tensor Make(size_t m, size_t n, float scale) { + test::Tensor ret("", std::vector{n, m}, TypeInfo::dtype); + 
ret.set_scale(scale); + ret.set_scale_inv(1.0 / scale); + return ret; +} + +template +test::Tensor MakeFromData(const std::vector& data, size_t mstart, size_t nstart, size_t msize, + size_t nsize, size_t ld, float scale) { + test::Tensor ret("", std::vector{nsize, msize}, TypeInfo::dtype); + ret.set_scale(scale); + ret.set_scale_inv(1.0 / scale); + auto local = CopyMatrix(data, mstart, nstart, msize, nsize, ld); + NVTE_CHECK_CUDA(cudaMemcpy(ret.rowwise_dptr(), local.data(), local.size() * sizeof local[0], + cudaMemcpyDefault)); + return ret; +} + +template +float GetScale(float amax) { + if constexpr (sizeof(T) > 1) return 1.0; + return static_cast(static_cast(std::numeric_limits::max())) / amax; +} + +struct Params { + DType a_type; + DType b_type; + DType d_type; + bool transa; + bool transb; + size_t m; + size_t n; + size_t k; + float tol; +}; + +class CommGemmFixure : public ::testing::TestWithParam { + protected: + CommGemmFixure() { + CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &nranks_)); + CHECK_MPI(MPI_Comm_rank(MPI_COMM_WORLD, &rank_)); + NVTE_CHECK_CUDA(cudaSetDevice(rank_)); + ncclUniqueId id{}; + if (rank_ == 0) CHECK_NCCL(ncclGetUniqueId(&id)); + CHECK_MPI(MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); + CHECK_NCCL(ncclCommInitRank(&comm_, nranks_, id, rank_)); + ctx_ = nvte_comm_gemm_ctx_create(comm_, nranks_, rank_); + } + ~CommGemmFixure() { + nvte_comm_gemm_ctx_destroy(ctx_); + ncclCommDestroy(comm_); + } + + struct PatternDims { + int64_t a_rows_start; + int64_t a_rows_num; + int64_t a_cols_start; + int64_t a_cols_num; + int64_t b_rows_start; + int64_t b_rows_num; + int64_t b_cols_start; + int64_t b_cols_num; + int64_t d_rows_start; + int64_t d_rows_num; + int64_t d_cols_start; + int64_t d_cols_num; + }; + + virtual PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) = 0; + + virtual void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const 
NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) = 0; + + template + void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) { + cudaStream_t stream{}; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + constexpr float MAX_IN = 1.0; + std::mt19937 rng(12); + std::uniform_real_distribution dist(0.0, MAX_IN); + + float a_scale = GetScale(MAX_IN); + float b_scale = GetScale(MAX_IN); + float d_scale = GetScale(MAX_IN * MAX_IN * k); + float bias_scale = GetScale(MAX_IN); + + std::vector adata(m * k); + std::generate(adata.begin(), adata.end(), + [&rng, &dist, a_scale] { return static_cast(dist(rng) * a_scale); }); + std::vector bdata(k * n); + std::generate(bdata.begin(), bdata.end(), + [&rng, &dist, b_scale] { return static_cast(dist(rng) * b_scale); }); + std::vector biasdata(m * n); + std::generate(biasdata.begin(), biasdata.end(), [&rng, &dist, bias_scale] { + return static_cast(dist(rng) * bias_scale); + }); + + auto ga = transa ? MakeFromData(adata, 0, 0, k, m, k, a_scale) + : MakeFromData(adata, 0, 0, m, k, m, a_scale); + auto gb = transb ? MakeFromData(bdata, 0, 0, n, k, n, b_scale) + : MakeFromData(bdata, 0, 0, k, n, k, b_scale); + auto gbias = MakeFromData(biasdata, 0, 0, m, n, m, bias_scale); + auto gd = Make(m, n, d_scale); + auto gaux = Make(m, n, d_scale); + + auto dims = DistributeTensors(m, n, k); + auto a = transa ? MakeFromData(adata, dims.a_rows_start, dims.a_cols_start, + dims.a_rows_num, dims.a_cols_num, k, a_scale) + : MakeFromData(adata, dims.a_cols_start, dims.a_rows_start, + dims.a_cols_num, dims.a_rows_num, m, a_scale); + auto b = transb ? 
MakeFromData(bdata, dims.b_cols_start, dims.b_rows_start, + dims.b_cols_num, dims.b_rows_num, n, b_scale) + : MakeFromData(bdata, dims.b_rows_start, dims.b_cols_start, + dims.b_rows_num, dims.b_cols_num, k, b_scale); + auto bias = MakeFromData(biasdata, dims.d_rows_start, dims.d_cols_start, + dims.d_rows_num, dims.d_cols_num, m, bias_scale); + auto d = Make(dims.d_rows_num, dims.d_cols_num, d_scale); + auto aux = Make(dims.d_rows_num, dims.d_cols_num, d_scale); + + bool grad = false; + bool accumulate = false; + CommGemm(m, n, k, a.data(), b.data(), d.data(), bias.data(), aux.data(), transa, transb, grad, + accumulate, 0 /*comm_sm_count*/, stream); + auto workspace = Make(1, 32 << 20, 1.0); + nvte_cublas_gemm(ga.data(), gb.data(), gd.data(), gbias.data(), gaux.data(), transa, transb, + grad, workspace.data(), accumulate, false /* use_split_accumulator */, + 0 /* math_sm_count */, stream); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); + std::vector out(dims.d_rows_num * dims.d_cols_num); + NVTE_CHECK_CUDA( + cudaMemcpy(out.data(), d.rowwise_dptr(), out.size() * sizeof out[0], cudaMemcpyDefault)); + std::vector out_golden_global(m * n); + NVTE_CHECK_CUDA(cudaMemcpy(out_golden_global.data(), gd.rowwise_dptr(), + out_golden_global.size() * sizeof out_golden_global[0], + cudaMemcpyDefault)); + + auto out_golden = CopyMatrix(out_golden_global, dims.d_rows_start, dims.d_cols_start, + dims.d_rows_num, dims.d_cols_num, m); + NVTE_CHECK(out.size() == out_golden.size()); + for (size_t i = 0; i < out.size(); ++i) { + EXPECT_NEAR(static_cast(out[i]), static_cast(out_golden[i]), tol * k); + } + } + + NVTECommGemmCtx* ctx_{}; + int nranks_{}; + int rank_{}; + ncclComm_t comm_{}; +}; + +struct AgGemm : public CommGemmFixure { + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { + auto a_cols_num = nvte_comm_gemm_numroc(ctx_, m); + auto b_cols_num = nvte_comm_gemm_numroc(ctx_, n); + + int64_t a_cols_start{}; + 
int64_t b_cols_start{}; + MPI_Exscan(&a_cols_num, &a_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + MPI_Exscan(&b_cols_num, &b_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + + return PatternDims{ + .a_rows_start = 0, + .a_rows_num = k, + .a_cols_start = a_cols_start, + .a_cols_num = a_cols_num, + .b_rows_start = 0, + .b_rows_num = k, + .b_cols_start = b_cols_start, + .b_cols_num = b_cols_num, + .d_rows_start = a_cols_start, + .d_rows_num = a_cols_num, + .d_cols_start = 0, + .d_cols_num = n, + }; + } + + void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) override { + nvte_all_gather_gemm(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad, + accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault); + } +}; + +struct GemmRs : public CommGemmFixure { + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { + auto rows_num = nvte_comm_gemm_numroc(ctx_, k); + auto d_cols_num = nvte_comm_gemm_numroc(ctx_, n); + + int64_t rows_start{}; + int64_t d_cols_start{}; + MPI_Exscan(&rows_num, &rows_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + MPI_Exscan(&d_cols_num, &d_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + + return PatternDims{ + .a_rows_start = rows_start, + .a_rows_num = rows_num, + .a_cols_start = 0, + .a_cols_num = m, + .b_rows_start = rows_start, + .b_rows_num = rows_num, + .b_cols_start = 0, + .b_cols_num = n, + .d_rows_start = 0, + .d_rows_num = m, + .d_cols_start = d_cols_start, + .d_cols_num = d_cols_num, + }; + } + + void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) override { + 
nvte_gemm_reduce_scatter(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad, + accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault); + } +}; + +struct GemmAr : public CommGemmFixure { + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { + auto rows_num = nvte_comm_gemm_numroc(ctx_, k); + + int64_t rows_start{}; + MPI_Exscan(&rows_num, &rows_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + + return PatternDims{ + .a_rows_start = rows_start, + .a_rows_num = rows_num, + .a_cols_start = 0, + .a_cols_num = m, + .b_rows_start = rows_start, + .b_rows_num = rows_num, + .b_cols_start = 0, + .b_cols_num = n, + .d_rows_start = 0, + .d_rows_num = m, + .d_cols_start = 0, + .d_cols_num = n, + }; + } + + void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) override { + nvte_gemm_all_reduce(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad, + accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault); + } + + void SetUp() override { + if (!IsMulticastSupported(rank_)) + GTEST_SKIP() << "Multicast is not supported on device " << rank_; + } +}; + +TEST_P(AgGemm, Gemm) { + auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam(); + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + a_type, AType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + b_type, BType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + d_type, DType, Run(transa, transb, m, n, k, tol);))); +} + +TEST_P(GemmRs, Gemm) { + auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam(); + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + a_type, AType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + b_type, BType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + d_type, DType, Run(transa, transb, m, n, k, tol);))); +} + +TEST_P(GemmAr, Gemm) { + auto [a_type, b_type, 
d_type, transa, transb, m, n, k, tol] = GetParam(); + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + a_type, AType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + b_type, BType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + d_type, DType, Run(transa, transb, m, n, k, tol);))); +} + +std::string ParamSuffix(const testing::TestParamInfo& info) { + const auto [a_type, b_type, d_type, transa, transb, m, n, k, _tol] = info.param; + std::ostringstream ss; + ss << static_cast(a_type) << "_" << static_cast(b_type) << "_" + << static_cast(d_type) << "_" << (transa ? "T" : "N") << (transb ? "T" : "N") << "_" << m + << "x" << n << "x" << k; + return ss.str(); +} + +INSTANTIATE_TEST_SUITE_P(AgGemm, AgGemm, + testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, true, 256, 128, 64, 1e-3}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + true, false, 256, 128, 64, 1e-3}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, false, 256, 128, 64, 1e-3}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, true, 256, 128, 64, 1e-3}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, true, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat8E4M3, DType::kFloat8E4M3, + DType::kFloat16, true, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat8E4M3, DType::kFloat8E5M2, + DType::kFloat16, true, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat8E5M2, DType::kFloat8E4M3, + DType::kFloat16, true, false, 256, 128, 64, 1e-3}), + &ParamSuffix); + +INSTANTIATE_TEST_SUITE_P(GemmRs, GemmRs, + testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, true, 64, 128, 256, 5e-2}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + true, false, 64, 128, 256, 5e-2}, + Params{DType::kBFloat16, 
DType::kBFloat16, + DType::kBFloat16, false, false, 64, 128, 256, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, true, 64, 128, 256, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, true, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E4M3, + DType::kFloat16, true, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E5M2, + DType::kFloat16, true, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat8E5M2, DType::kFloat8E4M3, + DType::kFloat16, true, false, 64, 128, 256, 5e-2}), + &ParamSuffix); + +INSTANTIATE_TEST_SUITE_P( + GemmAr, GemmAr, + testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, true, false, 64, + 64 * 4, 64 * 4, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, true, false, 64, + 64 * 4, 64 * 4, 5e-2}, + Params{DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kFloat16, true, false, + 128, 128 * 4, 128 * 4, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kFloat16, true, false, + 128, 128 * 4, 128 * 4, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kFloat16, true, false, + 128, 128 * 4, 128 * 4, 5e-2}), + &ParamSuffix); diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 79186aa478..e3b1ecac96 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -173,7 +173,7 @@ def _test_layernorm_mlp_grad( ) # Single GPU - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): single_jitter = jax.jit( value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs)), @@ -330,7 +330,7 @@ def _test_layernorm_mlp( with use_jax_gemm(enabled=with_jax_gemm): # Single GPUs - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with fp8_autocast(enabled=use_fp8, 
fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): ln_mlp_single = LayerNormMLP( layernorm_type=layernorm_type, intermediate_dim=INTERMEDIATE, diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index d59e130530..0d0dba5475 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -28,6 +28,7 @@ is_fp8_available, update_collections, ) +from transformer_engine.jax.sharding import MeshResource, global_shard_guard @pytest.fixture(autouse=True, scope="function") @@ -490,19 +491,28 @@ class BaseTester: def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" QuantizeConfig.finalize() # Ensure FP8 disabled. - self.runner(attrs).test_forward(data_shape, dtype) + with global_shard_guard( + MeshResource() + ): # Empty MeshResource is used as we are running on a single device + self.runner(attrs).test_forward(data_shape, dtype) def test_backward(self, data_shape, dtype, attrs): """Test normal datatype backward""" QuantizeConfig.finalize() # Ensure FP8 disabled. 
- self.runner(attrs).test_backward(data_shape, dtype) + with global_shard_guard( + MeshResource() + ): # Empty MeshResource is used as we are running on a single device + self.runner(attrs).test_backward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test forward with fp8 enabled""" QuantizeConfig.initialize(fp8_recipe=fp8_recipe) - self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) + with global_shard_guard( + MeshResource() + ): # Empty MeshResource is used as we are running on a single device + self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) QuantizeConfig.finalize() @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -510,7 +520,10 @@ def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test backward with fp8 enabled""" QuantizeConfig.initialize(fp8_recipe=fp8_recipe) - self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) + with global_shard_guard( + MeshResource() + ): # Empty MeshResource is used as we are running on a single device + self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) QuantizeConfig.finalize() diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 3088853a25..56bfa14234 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -274,6 +274,8 @@ def test_dpa_checkpoint(dtype, model_configs, model): "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference + "mla_3_3": ModelConfig(8, 1, 16, 160, 
max_seqlen_kv=2048, head_dim_v=128), # inference + "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference } diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 543f5f08d4..773031ece7 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -111,13 +111,18 @@ def is_fused_attn_available( - config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True + config: ModelConfig, + dtype: torch.dtype, + qkv_layout="bshd_bshd_bshd", + is_training=True, + deterministic=False, ): _, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, is_training=is_training, + deterministic=deterministic, ) return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends @@ -825,7 +830,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= @pytest.mark.parametrize("model", ["126m"]) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] - if not is_fused_attn_available(config, dtype): + if not is_fused_attn_available(config, dtype, deterministic=True): pytest.skip("No attention backend available.") outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) @@ -873,7 +878,9 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): config = model_configs[model] - if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False): + if not is_fused_attn_available( + config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True + ): pytest.skip("No attention backend available.") te_gpt = TransformerLayer( @@ -986,7 +993,9 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): 
@pytest.mark.parametrize("mask_type", mask_types) def test_mha_accuracy(dtype, bs, model, mask_type): config = model_configs[model] - if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False): + if not is_fused_attn_available( + config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True + ): pytest.skip("No attention backend available.") te_mha = MultiheadAttention( diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 839fb8dff8..b353333a50 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -36,6 +36,7 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import get_default_init_method +import tensorrt as trt # Global test configuration knobs. @@ -113,7 +114,7 @@ def trt_fp8_dequantize(t, scale): @onnx_op( - op_type="trt::TRT_MXFP8QuantizeLinear", + op_type="trt::TRT_MXFP8DynamicQuantize", domain="trt", inputs=[ PyCustomOpDef.dt_float, @@ -1139,3 +1140,59 @@ def test_export_ctx_manager(enabled): with te.onnx_export(enabled): assert is_in_onnx_export_mode() == enabled assert is_in_onnx_export_mode() == False + + +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +def test_trt_integration(fp8_recipe: recipe.Recipe): + + model = te.TransformerLayer( + hidden_size=128, + ffn_hidden_size=128, + num_attention_heads=4, + ).eval() + inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),) + + with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + out_ref = model(*inps) + + onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx") + os.close(onnx_fd) + try: + with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.onnx_export(enabled=True): + torch.onnx.export( + model, + inps, + onnx_path, + output_names=["output"], + dynamo=True, + 
custom_translation_table=te_translation_table, + ) + + os.system(f"trtexec --onnx={onnx_path} --saveEngine={onnx_path}.engine") + + # Run TRT engine + logger = trt.Logger(trt.Logger.WARNING) + runtime = trt.Runtime(logger) + with open(onnx_path + ".engine", "rb") as f: + engine_data = f.read() + engine = runtime.deserialize_cuda_engine(engine_data) + context = engine.create_execution_context() + context.set_tensor_address(engine.get_tensor_name(0), inps[0].data_ptr()) + stream = torch.cuda.Stream() + + out = torch.zeros_like(out_ref) + context.set_tensor_address("output", out.data_ptr()) + + context.execute_async_v3(stream_handle=stream.cuda_stream) + stream.synchronize() + + # Compare TRT and TE outputs + atol = 5e-2 if fp8_recipe is not None else 1e-4 + rtol = 5e-2 if fp8_recipe is not None else 1e-4 + torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol) + finally: + try: + os.remove(onnx_path) + except FileNotFoundError: + pass diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 524bd3289c..38f400f659 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -266,8 +266,8 @@ def test(): ) ( use_flash_attention, - use_fused_attention, flash_attention_backend, + use_fused_attention, fused_attention_backend, use_unfused_attention, available_backends, diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b51e61929b..183a7a72ec 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -110,6 +110,12 @@ list(APPEND transformer_engine_SOURCES comm_gemm_overlap/userbuffers/userbuffers-host.cpp comm_gemm_overlap/userbuffers/userbuffers.cu comm_gemm_overlap/comm_gemm_overlap.cpp) + +if (NVTE_WITH_CUBLASMP) +list(APPEND transformer_engine_SOURCES + comm_gemm/comm_gemm.cpp) +endif() + add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") @@ 
-123,6 +129,8 @@ target_link_libraries(transformer_engine PUBLIC CUDNN::cudnn_all) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(transformer_engine SYSTEM PRIVATE + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI @@ -141,6 +149,25 @@ if (NVTE_ENABLE_NVSHMEM) target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR}) endif() +option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF) +if (NVTE_WITH_CUBLASMP) + target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP) + target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include) + find_library(CUBLASMP_LIB + NAMES cublasmp libcublasmp + PATHS ${CUBLASMP_DIR} + PATH_SUFFIXES lib + REQUIRED) + find_library(NVSHMEM_HOST_LIB + NAMES nvshmem_host libnvshmem_host.so.3 + PATHS ${NVSHMEM_DIR} + PATH_SUFFIXES lib + REQUIRED) + target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB}) + message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}") + message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}") +endif() + # Hack to enable dynamic loading in cuDNN frontend target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp new file mode 100644 index 0000000000..76f46298db --- /dev/null +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -0,0 +1,519 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "transformer_engine/comm_gemm.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" + +using namespace transformer_engine; + +namespace { + +// TODO: log warnings on failures of the *Destroy calls below, once TE has such ability. +// For now, just silently ignoring the errors, since the only diag available in TE is throwing +// exceptions, but these calls will typically be made from destructors, so cannot throw. + +template +auto CreateWithCudaCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) { + using Handle = std::remove_pointer_t; + HandlePtr raw{}; + NVTE_CHECK_CUDA(create_fn(&raw, std::forward(args)...)); + return std::unique_ptr(raw, destroy_fn); +} + +using CudaStream = + std::unique_ptr, decltype(&cudaStreamDestroy)>; + +CudaStream CudaStreamCreate() { + return CreateWithCudaCheck(cudaStreamCreate, cudaStreamDestroy); +} + +using CudaEvent = std::unique_ptr, decltype(&cudaEventDestroy)>; + +CudaEvent CudaEventCreate(unsigned flags) { + return CreateWithCudaCheck(cudaEventCreateWithFlags, cudaEventDestroy, flags); +} + +template +auto CreateWithCublasMpCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... 
args) { + using Handle = std::remove_pointer_t; + HandlePtr raw{}; + if constexpr (raw_last) { + NVTE_CHECK_CUBLASMP(create_fn(std::forward(args)..., &raw)); + } else { + NVTE_CHECK_CUBLASMP(create_fn(&raw, std::forward(args)...)); + } + return std::unique_ptr(raw, destroy_fn); +} + +using CublasMp = + std::unique_ptr, decltype(&cublasMpDestroy)>; + +CublasMp CublasMpCreate(cudaStream_t stream) { + return CreateWithCublasMpCheck(cublasMpCreate, cublasMpDestroy, stream); +} + +using CublasMpGrid = + std::unique_ptr, decltype(&cublasMpGridDestroy)>; + +CublasMpGrid CublasMpGridCreate(int64_t nprow, int64_t npcol, cublasMpGridLayout_t layout, + ncclComm_t comm) { + return CreateWithCublasMpCheck(cublasMpGridCreate, cublasMpGridDestroy, + nprow, npcol, layout, comm); +} + +using CublasMpMatrixDesc = std::unique_ptr, + decltype(&cublasMpMatrixDescriptorDestroy)>; + +CublasMpMatrixDesc CublasMpMatrixDescCreate(int64_t m, int64_t n, int64_t mb, int64_t nb, + int64_t rsrc, int64_t csrc, int64_t lld, + cudaDataType_t type, cublasMpGrid_t grid) { + return CreateWithCublasMpCheck( + cublasMpMatrixDescriptorCreate, cublasMpMatrixDescriptorDestroy, m, n, mb, nb, rsrc, csrc, + lld, type, grid); +} + +using CublasMpMatmulDesc = std::unique_ptr, + decltype(&cublasMpMatmulDescriptorDestroy)>; + +CublasMpMatmulDesc CublasMpMatmulDescCreate(cublasComputeType_t compute_type) { + return CreateWithCublasMpCheck( + cublasMpMatmulDescriptorCreate, cublasMpMatmulDescriptorDestroy, compute_type); +} + +} // namespace + +struct NVTECommGemmCtx { + int64_t nranks; + int64_t rank; + ncclComm_t comm; + CudaStream stream; + CudaEvent event; + CublasMp cublas_mp; + CublasMpGrid grid_col_major; + CublasMpGrid grid_row_major; + CublasMpMatrixDesc a_desc; + CublasMpMatrixDesc b_desc; + CublasMpMatrixDesc d_desc; + CublasMpMatmulDesc matmul_desc; + void* workspace; + size_t workspace_size; +}; + +namespace { + +int64_t block_size(NVTECommGemmCtx* ctx, int64_t global_size) { + // Use non-cyclic layout 
to maximize opportunity for comm overlap. + return (global_size + ctx->nranks - 1) / ctx->nranks; +} + +void AgGemmInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k, + const Tensor* a, const Tensor* b, const Tensor* d, bool transa, + bool transb) { + const auto a0 = a->flat_first_dim(); + const auto a1 = a->flat_last_dim(); + const auto b0 = b->flat_first_dim(); + const auto b1 = b->flat_last_dim(); + const auto d0 = d->flat_first_dim(); + const auto d1 = d->flat_last_dim(); + + if (transa) { + NVTE_CHECK(a1 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, k, block_size(ctx, m), 0, 0, k, + get_cuda_dtype(a->dtype()), + ctx->grid_row_major.get(), ctx->a_desc.get())); + } else { + NVTE_CHECK(a0 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, k, block_size(ctx, m), k, 0, 0, + block_size(ctx, m), get_cuda_dtype(a->dtype()), + ctx->grid_col_major.get(), ctx->a_desc.get())); + } + if (transb) { + NVTE_CHECK(b0 == k, "Unsupported tensor dimensionin B: expected ", k, ", got ", b0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(n, k, block_size(ctx, n), k, 0, 0, + block_size(ctx, n), get_cuda_dtype(b->dtype()), + ctx->grid_col_major.get(), ctx->b_desc.get())); + } else { + NVTE_CHECK(b1 == k, "Unsupported tensor dimension in B: expected ", k, ", got ", b1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, n, k, block_size(ctx, n), 0, 0, k, + get_cuda_dtype(b->dtype()), + ctx->grid_row_major.get(), ctx->b_desc.get())); + } + NVTE_CHECK(d0 == n, "Unsupported tensor dimension in D: expected ", n, ", got ", d0); + *ldd = block_size(ctx, m); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, block_size(ctx, m), block_size(ctx, n), 0, + 0, *ldd, get_cuda_dtype(d->dtype()), + ctx->grid_col_major.get(), ctx->d_desc.get())); +} + +void GemmRsInitMatrices(NVTECommGemmCtx* ctx, 
int64_t* ldd, int64_t m, int64_t n, int64_t k, + const Tensor* a, const Tensor* b, const Tensor* d, bool transa, + bool transb) { + const auto a0 = a->flat_first_dim(); + const auto a1 = a->flat_last_dim(); + const auto b0 = b->flat_first_dim(); + const auto b1 = b->flat_last_dim(); + const auto d0 = d->flat_first_dim(); + const auto d1 = d->flat_last_dim(); + + if (transa) { + NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, block_size(ctx, k), m, 0, 0, + block_size(ctx, k), get_cuda_dtype(a->dtype()), + ctx->grid_col_major.get(), ctx->a_desc.get())); + } else { + NVTE_CHECK(a1 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, k, m, block_size(ctx, k), 0, 0, m, + get_cuda_dtype(a->dtype()), + ctx->grid_row_major.get(), ctx->a_desc.get())); + } + if (transb) { + NVTE_CHECK(b1 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit( + n, k, block_size(ctx, n), block_size(ctx, k), 0, 0, block_size(ctx, n), + get_cuda_dtype(b->dtype()), ctx->grid_row_major.get(), ctx->b_desc.get())); + } else { + NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit( + k, n, block_size(ctx, k), block_size(ctx, n), 0, 0, block_size(ctx, k), + get_cuda_dtype(b->dtype()), ctx->grid_col_major.get(), ctx->b_desc.get())); + } + NVTE_CHECK(d1 == m, "Unsupported tensor dimension in D: expected ", m, ", got ", d1); + *ldd = m; + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, m, block_size(ctx, n), 0, 0, *ldd, + get_cuda_dtype(d->dtype()), + ctx->grid_row_major.get(), ctx->d_desc.get())); +} + +void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k, + const Tensor* a, const Tensor* b, const Tensor* d, bool transa, + bool transb) { + 
const auto a0 = a->flat_first_dim(); + const auto a1 = a->flat_last_dim(); + const auto b0 = b->flat_first_dim(); + const auto b1 = b->flat_last_dim(); + const auto d0 = d->flat_first_dim(); + const auto d1 = d->flat_last_dim(); + + if (transa) { + NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, block_size(ctx, k), m, 0, 0, + block_size(ctx, k), get_cuda_dtype(a->dtype()), + ctx->grid_col_major.get(), ctx->a_desc.get())); + } else { + NVTE_ERROR("N transpose flag is not supported for input A"); + } + if (transb) { + NVTE_ERROR("T transpose flag is not supported for input B"); + } else { + NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, n, block_size(ctx, k), n, 0, 0, + block_size(ctx, k), get_cuda_dtype(b->dtype()), + ctx->grid_col_major.get(), ctx->b_desc.get())); + } + NVTE_CHECK(d1 == m, "Unsupported tensor dimension in D: expected ", m, ", got ", d1); + *ldd = m; + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n * ctx->nranks, m, n, 0, 0, *ldd, + get_cuda_dtype(d->dtype()), + ctx->grid_row_major.get(), ctx->d_desc.get())); + + const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE; + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, + sizeof epilogue)); +} + +using InitMatricesFn = void (*)(NVTECommGemmCtx*, int64_t*, int64_t, int64_t, int64_t, + const Tensor*, const Tensor*, const Tensor*, bool, bool); + +cublasMpMatmulAlgoType_t cublasmp_algo(NVTECommGemmAlgoType algo) { + static const std::unordered_map s_map{ + {kNVTECommGemmAlgoDefault, CUBLASMP_MATMUL_ALGO_TYPE_DEFAULT}, + {kNVTECommGemmAlgoSplitP2P, CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_P2P}, + {kNVTECommGemmAlgoSplitMulticast, CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_MULTICAST}, + {kNVTECommGemmAlgoAtomicP2P, 
CUBLASMP_MATMUL_ALGO_TYPE_ATOMIC_P2P}, + {kNVTECommGemmAlgoAtomicMulticast, CUBLASMP_MATMUL_ALGO_TYPE_ATOMIC_MULTICAST}, + }; + auto it = s_map.find(algo); + return it != s_map.end() ? it->second : static_cast(algo); +} + +void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECommGemmAlgoType algo, + int64_t m, int64_t n, int64_t k, const Tensor* a, const Tensor* b, + const Tensor* d, const Tensor* bias, const Tensor* pre_act_out, bool transa, + bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t main_stream) { + for (auto t : {a, b, d}) { + NVTE_CHECK(is_tensor_scaling(t->scaling_mode), + "Unsupported scaling mode: " + std::to_string(t->scaling_mode)); + } + + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorInit(ctx->matmul_desc.get(), CUBLAS_COMPUTE_32F)); + + int64_t ldd{}; + init_matrices_fn(ctx, &ldd, m, n, k, a, b, d, transa, transb); + + const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t trans_b = transb ? 
CUBLAS_OP_T : CUBLAS_OP_N; + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a, + sizeof trans_a)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b, + sizeof trans_b)); + cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr, + sizeof algo_attr)); + + const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32; + if (is_fp8_dtype(a->dtype())) { + NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode, + sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER, + &a->scale_inv.dptr, sizeof(void*))); + } + if (is_fp8_dtype(b->dtype())) { + NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode, + sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER, + &b->scale_inv.dptr, sizeof(void*))); + } + if (is_fp8_dtype(d->dtype())) { + NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode, + sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), 
CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER, + &d->scale.dptr, sizeof(void*))); + if (d->amax.dptr) { + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER, + &d->amax.dptr, sizeof(void*))); + } + } + + // Might be set to ALLREDUCE before, need to OR with the new flags to set. + cublasMpMatmulEpilogue_t epilogue{}; + size_t size_read{}; + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, + sizeof epilogue, &size_read)); + NVTE_CHECK(size_read == sizeof epilogue); + // (bias, gelu, grad) -> epilogue + const std::map, cublasMpMatmulEpilogue_t> flags_to_epilogue{ + {{true, true, false}, CUBLASMP_MATMUL_EPILOGUE_GELU_AUX_BIAS}, + {{true, true, true}, CUBLASMP_MATMUL_EPILOGUE_DGELU_BGRAD}, + {{true, false, false}, CUBLASMP_MATMUL_EPILOGUE_BIAS}, + {{true, false, true}, CUBLASMP_MATMUL_EPILOGUE_BGRADB}, + {{false, true, false}, CUBLASMP_MATMUL_EPILOGUE_GELU_AUX}, + {{false, true, true}, CUBLASMP_MATMUL_EPILOGUE_DGELU}, + }; + if (auto it = + flags_to_epilogue.find({bias ? bias->data.dptr != nullptr : false, + pre_act_out ? 
pre_act_out->data.dptr != nullptr : false, grad}); + it != flags_to_epilogue.end()) { + epilogue = static_cast(epilogue | it->second); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, + sizeof epilogue)); + } + + if (bias && bias->data.dptr) { + cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type, + sizeof bias_type)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr, + sizeof bias->data.dptr)); + } + + if (pre_act_out && pre_act_out->data.dptr) { + cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE, + &aux_type, sizeof aux_type)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER, + &pre_act_out->data.dptr, sizeof pre_act_out->data.dptr)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd, + sizeof ldd)); + if (is_fp8_dtype(pre_act_out->dtype())) { + NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE, + &scale_mode, sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER, + &pre_act_out->scale.dptr, sizeof(void*))); + if (pre_act_out->amax.dptr) { + 
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER, + &pre_act_out->amax.dptr, sizeof(void*))); + } + } + } + + if (comm_sm_count) { + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT, + &comm_sm_count, sizeof comm_sm_count)); + } + + NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream)); + + size_t wrksp_size_device{}; + size_t wrksp_size_host{}; + + float alpha = 1.0; + float beta = accumulate ? 1.0 : 0.0; + std::tuple args{ctx->cublas_mp.get(), + ctx->matmul_desc.get(), + m, + n, + k, + &alpha, + a->data.dptr, + 1, + 1, + ctx->a_desc.get(), + b->data.dptr, + 1, + 1, + ctx->b_desc.get(), + &beta, + accumulate ? d->data.dptr : nullptr, + 1, + 1, + accumulate ? ctx->d_desc.get() : nullptr, + d->data.dptr, + 1, + 1, + ctx->d_desc.get()}; + NVTE_CHECK_CUBLASMP( + std::apply(cublasMpMatmul_bufferSize, + std::tuple_cat(args, std::tuple{&wrksp_size_device, &wrksp_size_host}))); + + std::vector workspace_host(wrksp_size_host); + if (ctx->workspace_size < wrksp_size_device) { + nvshmem_free(ctx->workspace); + ctx->workspace = nvshmem_malloc(wrksp_size_device); + ctx->workspace_size = wrksp_size_device; + } + + NVTE_CHECK_CUBLASMP( + std::apply(cublasMpMatmul, + std::tuple_cat(args, std::tuple{ctx->workspace, ctx->workspace_size, + workspace_host.data(), workspace_host.size()}))); + + NVTE_CHECK_CUDA(cudaEventRecord(ctx->event.get(), main_stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream.get(), ctx->event.get(), 0)); +} + +} // namespace + +NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank) { + NVTE_API_CALL(nvte_comm_gemm_ctx_create); + auto stream = CudaStreamCreate(); + auto event = CudaEventCreate(cudaEventDisableTiming); + auto cublas_mp = CublasMpCreate(stream.get()); + + auto col_major = CublasMpGridCreate(nranks, 1, 
CUBLASMP_GRID_LAYOUT_COL_MAJOR, comm); + auto row_major = CublasMpGridCreate(1, nranks, CUBLASMP_GRID_LAYOUT_ROW_MAJOR, comm); + + // Pre-creating matrix descriptors here, will be initialized with the actual params later. + auto a_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get()); + auto b_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get()); + auto d_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get()); + + auto matmul_desc = CublasMpMatmulDescCreate(CUBLAS_COMPUTE_32F); + + return new NVTECommGemmCtx{ + .nranks = nranks, + .rank = rank, + .comm = comm, + .stream = std::move(stream), + .event = std::move(event), + .cublas_mp = std::move(cublas_mp), + .grid_col_major = std::move(col_major), + .grid_row_major = std::move(row_major), + .a_desc = std::move(a_desc), + .b_desc = std::move(b_desc), + .d_desc = std::move(d_desc), + .matmul_desc = std::move(matmul_desc), + }; +} + +void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) { + NVTE_API_CALL(nvte_comm_gemm_ctx_destroy); + nvshmemx_sync_all_on_stream(ctx->stream.get()); + delete ctx; +} + +void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo) { + NVTE_API_CALL(nvte_all_gather_gemm); + cublasmp_gemm(AgGemmInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a), + convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias), + convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, + comm_sm_count, main_stream); +} + +void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, + const NVTETensor a, const NVTETensor b, const NVTETensor d, + const NVTETensor bias, const NVTETensor pre_act_out, 
bool transa, + bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t main_stream, NVTECommGemmAlgoType algo) { + NVTE_API_CALL(nvte_gemm_reduce_scatter); + cublasmp_gemm(GemmRsInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a), + convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias), + convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, + comm_sm_count, main_stream); +} + +void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo) { + NVTE_API_CALL(nvte_gemm_all_reduce); + cublasmp_gemm(GemmArInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a), + convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias), + convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, + comm_sm_count, main_stream); +} + +int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size) { + NVTE_API_CALL(nvte_comm_gemm_numroc); + return cublasMpNumroc(global_size, block_size(ctx, global_size), ctx->rank, 0, ctx->nranks); +} diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 4e697979d8..a810fb4717 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -26,6 +26,24 @@ __global__ void __launch_bounds__(1) } // namespace +cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { + using namespace transformer_engine; + switch (t) { + case DType::kFloat16: + return CUDA_R_16F; + case DType::kFloat32: + return CUDA_R_32F; + case DType::kBFloat16: + return CUDA_R_16BF; + case DType::kFloat8E4M3: + return CUDA_R_8F_E4M3; + case DType::kFloat8E5M2: + return CUDA_R_8F_E5M2; + default: + NVTE_ERROR("Invalid type"); + } 
+} + void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) { if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) { NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv."); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index aa47f2c3d9..e2a3c52aa2 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -270,6 +270,8 @@ struct QuantizationConfig { }; }; +cudaDataType_t get_cuda_dtype(const transformer_engine::DType t); + template constexpr T DIVUP(const T &x, const T &y) { return (((x) + ((y)-1)) / (y)); @@ -382,9 +384,19 @@ struct BitsNumber { template struct TypeInfo { #if FP4_TYPE_SUPPORTED - using types = std::tuple; + using types = std::tuple= 12080 + , + fp8e8m0 +#endif + >; #else - using types = std::tuple; + using types = std::tuple= 12080 + , + fp8e8m0 +#endif + >; #endif template diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index bb30261b91..60b10862e6 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -252,8 +252,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 91100)) && // 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA - (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200) && is_training && - sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 && + (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200 || + cudnn_runtime_version == 91300) && + is_training && sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || diff --git 
a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 1c4af23eb8..9e6c5417bc 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -22,24 +22,6 @@ namespace { -cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { - using namespace transformer_engine; - switch (t) { - case DType::kFloat16: - return CUDA_R_16F; - case DType::kFloat32: - return CUDA_R_32F; - case DType::kBFloat16: - return CUDA_R_16BF; - case DType::kFloat8E4M3: - return CUDA_R_8F_E4M3; - case DType::kFloat8E5M2: - return CUDA_R_8F_E5M2; - default: - NVTE_ERROR("Invalid type"); - } -} - uint32_t _getAlignment(uintptr_t address) { // alignment are in bytes uint32_t alignment = 256; @@ -517,22 +499,22 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, &epilogue, sizeof(epilogue))); if (counter != nullptr) { -#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000) - NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ", +#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) + NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); #endif #if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) NVTE_ERROR( - "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ", + "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", CUBLAS_VERSION); #endif #if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \ CUBLAS_VERSION < 130000 NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, - "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA verson is ", + "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ", cuda::cudart_version()); NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000, - "Atomic 
GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ", + "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", cublas_version()); if (m_split == 0) m_split = 1; if (n_split == 0) n_split = 1; @@ -658,20 +640,22 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor using namespace transformer_engine; // Check CUDA and cuBLAS versions -#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000) - NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ", +#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) + NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); #endif #if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) - NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ", - CUBLAS_VERSION); + NVTE_ERROR( + "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", + CUBLAS_VERSION); #endif - NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, - "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA verson is ", - cuda::cudart_version()); + NVTE_CHECK( + cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, + "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ", + cuda::cudart_version()); NVTE_CHECK( cublas_version() >= 120205 && cublas_version() < 130000, - "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ", + "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", cublas_version()); const Tensor *inputA = convertNVTETensorCheck(A); diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm.h b/transformer_engine/common/include/transformer_engine/comm_gemm.h new file mode 100644 index 0000000000..14cf56a002 --- /dev/null +++ 
b/transformer_engine/common/include/transformer_engine/comm_gemm.h @@ -0,0 +1,156 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file comm_gemm.h + * \brief Functions for distributed (multi-GPU) matrix multiplication. + * + * This API is a TE-native binding to cuBLASMp library. + * Refer here: https://docs.nvidia.com/cuda/cublasmp/usage/tp.html for specific + * patterns, which allow communication-computation overlap. + * + * All GEMM functions here have the same computation semantic, as expressed + * on global matrices, similar to nvte_cublas_gemm call: + * - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors + * - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty + * - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors + * + * Functions differ in matrix distribution patterns + */ + +#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_ +#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_ + +#include +#include + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#else +#include +#endif + +typedef struct NVTECommGemmCtx NVTECommGemmCtx; + +enum NVTECommGemmAlgoType { + kNVTECommGemmAlgoDefault = 0, + kNVTECommGemmAlgoSplitP2P = 1, + kNVTECommGemmAlgoSplitMulticast = 2, + kNVTECommGemmAlgoAtomicP2P = 3, + kNVTECommGemmAlgoAtomicMulticast = 4 +}; + +/*! \brief Create a comm-gemm context. + * + * \param[in] comm NCCL communicator. + * \param[in] nranks Number of ranks. + * \param[in] rank Local rank. + */ +NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank); + +/*! \brief Destroy a comm-gemm context. + * + * \param[in] ctx Context to destroy. + */ +void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx); + +/*! 
\brief Perform AllGather communication followed by GEMM + * + * Gathers distributed data from all ranks, then computes matrix multiplication. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] m Global m dimension. + * \param[in] n Global n dimension. + * \param[in] k Global k dimension. + * \param[in] a Local part of A matrix. + * \param[in] b Local part of B matrix. + * \param[in,out] d Local part of D matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_act_out Local part of output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of gradient computation. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics) + * \param[in] main_stream CUDA stream used for computation. + * \param[in] algo Algorithm to use. + */ +void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo); + +/*! \brief Perform GEMM followed by ReduceScatter communication + * + * Computes matrix multiplication, then distributes results across ranks with reduction. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] m Global m dimension. + * \param[in] n Global n dimension. + * \param[in] k Global k dimension. + * \param[in] a Local part of A matrix. + * \param[in] b Local part of B matrix. + * \param[in,out] d Local part of D matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_act_out Local part of output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. 
+ * \param[in] grad Whether this operation is part of gradient computation. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics) + * \param[in] main_stream CUDA stream used for computation. + * \param[in] algo Algorithm to use. + */ +void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, + const NVTETensor a, const NVTETensor b, const NVTETensor d, + const NVTETensor bias, const NVTETensor pre_act_out, bool transa, + bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t main_stream, NVTECommGemmAlgoType algo); + +/*! \brief Perform GEMM followed by AllReduce communication + * + * Computes matrix multiplication, then reduces results across all ranks. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] m Global m dimension. + * \param[in] n Global n dimension. + * \param[in] k Global k dimension. + * \param[in] a Local part of A matrix. + * \param[in] b Local part of B matrix. + * \param[in,out] d Local part of D matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_act_out Local part of output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of gradient computation. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics) + * \param[in] main_stream CUDA stream used for computation. + * \param[in] algo Algorithm to use. 
+ */ +void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo); + +/*! \brief Get local number of rows or columns. + * + * Utility function to get local dimension. + * Block size, nranks and local rank are derived from the context ctx. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] global_size Global dimension. + */ +int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_ diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 3a2247f5cf..a603d1f1a2 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -14,7 +14,6 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" -#include "common/util/cuda_runtime.h" #include "common/util/ptx.cuh" #include "common/utils.cuh" @@ -168,12 +167,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) } } -// Trigger the next kernel here so that it's load from global memory can overlap with this kernel's -// store to global memory. -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - cudaTriggerProgrammaticLaunchCompletion(); -#endif - // Step 3: Store cast output, Step 4: do transpose within thread tile OVecCast tmp_output_c; @@ -397,12 +390,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose } } -// Trigger the next kernel here so that it's load from global memory can overlap with this kernel's -// store to global memory.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - cudaTriggerProgrammaticLaunchCompletion(); -#endif - // Step 3: Store cast output, Step 4: do transpose within thread tile // Edge case: in the non-full tile case, there are three subcases // for full thread tile, it's the same thing here @@ -526,15 +513,6 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM); const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM); - dim3 grid(num_blocks_x, num_blocks_y, 1); - cudaLaunchAttribute attribute[1]; - attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attribute[0].val.programmaticStreamSerializationAllowed = 1; - cudaLaunchConfig_t cfg = {grid, THREADS_PER_BLOCK, 0, stream, NULL, 0}; - if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= 90) { - cfg.attrs = attribute; - cfg.numAttrs = 1; - } TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.dtype, InputType, @@ -545,6 +523,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor TRANSFORMER_ENGINE_SWITCH_CONDITION( return_transpose, kReturnTranspose, + dim3 grid(num_blocks_x, num_blocks_y, 1); const bool full_tile = row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; @@ -554,28 +533,26 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor tensor_map_output_trans = get_tensor_map(output_t, num_rows, row_length); } - cudaLaunchKernelEx(&cfg, - block_scaled_cast_transpose_kernel, - reinterpret_cast(input.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, - scale_stride_x, scale_stride_y, scale_t_stride_x, - scale_t_stride_y, epsilon, tensor_map_output_trans, pow_2_scale); + block_scaled_cast_transpose_kernel + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + 
reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, + tensor_map_output_trans, pow_2_scale); } else { - cudaLaunchKernelEx( - &cfg, - block_scaled_cast_transpose_kernel_notaligned, - reinterpret_cast(input.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, - scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, - pow_2_scale); + block_scaled_cast_transpose_kernel_notaligned + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, + pow_2_scale); } // full-tile ) // return_transpose ) // OutputType diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 5bf2f52010..6f5c0f3a6c 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -17,7 +17,6 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/transpose/cast_transpose.h" -#include "common/util/cuda_runtime.h" #include "common/utils.cuh" namespace transformer_engine { @@ -235,14 +234,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo __syncthreads(); -// If not return columnwise, we trigger the next kernel here so that it's load from global memory -// can overlap with this kernel's return rowwise. 
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - if (!return_columnwise_gemm_ready && !return_columnwise_compact) { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif - // Step 2: Cast and store to output_c if (return_rowwise) { constexpr int r_stride = @@ -334,14 +325,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } -// If return columnwise, we trigger the next kernel here so that it's load from global memory -// can overlap with this kernel's return columnwise. -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - if (return_columnwise_gemm_ready || return_columnwise_compact) { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif - // Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t if (return_columnwise_gemm_ready) { constexpr int c_stride = @@ -601,10 +584,6 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); - dim3 grid(num_blocks_x, num_blocks_y, 1); - cudaLaunchAttribute attribute[1]; - attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attribute[0].val.programmaticStreamSerializationAllowed = 1; TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.dtype, InputType, @@ -612,38 +591,31 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output.dtype, OutputType, + dim3 grid(num_blocks_x, num_blocks_y, 1); + const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; TRANSFORMER_ENGINE_SWITCH_CONDITION( full_tile, kAligned, size_t smem_bytes = kSMemSize * sizeof(InputType); - - cudaLaunchConfig_t cfg = {grid, kThreadsPerBlock, smem_bytes, stream, NULL, 0}; - if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= - 90) { - cfg.attrs = attribute; - cfg.numAttrs = 1; - } // shared memory 
must be requested up if (smem_bytes >= 48 * 1024) { cudaError_t err = cudaFuncSetAttribute( &block_scaled_1d_cast_transpose_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size."); - } cudaLaunchKernelEx(&cfg, - block_scaled_1d_cast_transpose_kernel, - reinterpret_cast(input.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, - scale_stride_x, scale_stride_y, scale_t_stride_x, - scale_t_stride_y, epsilon, rowwise_option, columnwise_option, - pow2_scale);) // kAligned - ) // OutputType - ) // InputType + } block_scaled_1d_cast_transpose_kernel + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, + scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, + columnwise_option, pow2_scale);) // kAligned + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index c084c31165..9a02d71f2d 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -203,11 +203,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], parity); - // Trigger the next kernel, so its TMA load can be overlapped with the current kernel - if (stage == STAGES - 1) { - cudaTriggerProgrammaticLaunchCompletion(); - } - float thread_amax = 0.0f; if constexpr (COLWISE_SCALING) { const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; @@ -1127,13 +1122,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const size_t 
dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - cudaLaunchConfig_t cfg = {grid, block_size, dshmem_size, stream, NULL, 0}; - // This kernel will only be called on sm100+, so no need to check sm_arch - cudaLaunchAttribute attribute[1]; - attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attribute[0].val.programmaticStreamSerializationAllowed = 1; cfg.attrs = attribute; - cfg.numAttrs = 1; - switch (scaling_type) { case ScalingType::ROWWISE: cudaFuncSetAttribute( @@ -1141,13 +1129,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - cudaLaunchKernelEx( - &cfg, - cast_mxfp8_2D_kernel, - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); break; case ScalingType::COLWISE: cudaFuncSetAttribute( @@ -1155,13 +1143,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - cudaLaunchKernelEx( - &cfg, - cast_mxfp8_2D_kernel, - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, 
rows, cols, scale_stride_rowwise, + scale_stride_colwise); break; case ScalingType::BIDIMENSIONAL: cudaFuncSetAttribute( @@ -1169,13 +1157,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - cudaLaunchKernelEx( - &cfg, - cast_mxfp8_2D_kernel, - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); break; } diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 173aad52af..941899b28c 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -12,8 +12,13 @@ #include #include +#ifdef NVTE_WITH_CUBLASMP +#include +#endif // NVTE_WITH_CUBLASMP + #include #include +#include #include "../util/string.h" @@ -87,4 +92,16 @@ } \ } while (false) +#ifdef NVTE_WITH_CUBLASMP + +#define NVTE_CHECK_CUBLASMP(expr) \ + do { \ + const cublasMpStatus_t status = (expr); \ + if (status != CUBLASMP_STATUS_SUCCESS) { \ + NVTE_ERROR("cuBLASMp Error: ", std::to_string(status)); \ + } \ + } while (false) + +#endif // NVTE_WITH_CUBLASMP + #endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 9975f558bf..188b376015 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -8,6 +8,7 @@ from collections.abc import Iterable from typing import Tuple, Sequence, Union from 
functools import partial, reduce +import warnings import jax import jax.numpy as jnp @@ -34,6 +35,7 @@ is_fp8_gemm_with_all_layouts_supported, apply_padding_to_scale_inv, ) +from ..sharding import global_mesh_resource from .misc import get_padded_spec @@ -490,7 +492,8 @@ def _parse_operand_output_specs( # Non-contracting dims of RHS always needs to be gathered along the FSDP axis rhs_non_cspecs = tuple( - None if spec is not None and "fsdp" in spec else spec for spec in rhs_non_cspecs + None if spec is not None and spec == global_mesh_resource().fsdp_resource else spec + for spec in rhs_non_cspecs ) # Non-contracting dims of LHS to be gathered along the SP axis. @@ -656,6 +659,12 @@ def shardy_sharding_rule( prefix = "GemmPrimitive_" + warnings.warn( + "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now," + " please turn off Shardy by exporting the environment variable" + " 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems." + ) + def _generate_operand_rules(name, ndim, cdims): specs = [] ldims = tuple(i for i in range(ndim) if i not in cdims) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index d85593c1e4..fb3ac7b9ae 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -26,6 +26,7 @@ from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type from ..attention import fused_attn +from ..attention import CPStrategy from ..softmax import SoftmaxType from ..sharding import num_of_devices from ..sharding import get_sharding_map_logic_axis_to_mesh_axis @@ -274,6 +275,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me max_segments_per_seq: Optional[int] = 1 context_parallel_causal_load_balanced: bool = False context_parallel_axis: str = "" + context_parallel_strategy: 
CPStrategy = CPStrategy.DEFAULT context_checkpoint_name: str = "context" @nn.compact @@ -323,6 +325,7 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_parallel_strategy=self.context_parallel_strategy, context_checkpoint_name=self.context_checkpoint_name, ) elif self.qkv_layout.is_kvpacked(): @@ -350,6 +353,7 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_parallel_strategy=self.context_parallel_strategy, context_checkpoint_name=self.context_checkpoint_name, ) elif self.qkv_layout.is_separate(): @@ -372,6 +376,7 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_parallel_strategy=self.context_parallel_strategy, context_checkpoint_name=self.context_checkpoint_name, ) else: @@ -505,6 +510,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. + context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING. context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention. 
Optimization parameters @@ -529,6 +535,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods max_segments_per_seq: Optional[int] = 1 context_parallel_causal_load_balanced: bool = False context_parallel_axis: str = "" + context_parallel_strategy: str = "DEFAULT" context_checkpoint_name: str = "context" @nn.compact @@ -648,6 +655,24 @@ def __call__( scale_factor = self.scale_factor del self.scale_factor + # case-insensitive mapping for context parallel strategy + cp_strategy_map = { + "DEFAULT": CPStrategy.DEFAULT, + "ALL_GATHER": CPStrategy.ALL_GATHER, + "ALLGATHER": CPStrategy.ALL_GATHER, # Alternative spelling + "RING": CPStrategy.RING, + } + + strategy_key = self.context_parallel_strategy.upper() + if strategy_key in cp_strategy_map: + context_parallel_strategy = cp_strategy_map[strategy_key] + else: + valid_strategies = list(cp_strategy_map.keys()) + raise ValueError( + f"Invalid context parallel strategy: {self.context_parallel_strategy}. " + f"Valid options are: {valid_strategies} (case insensitive)" + ) + if not use_fused_attn: # unfused attention only supports splitted query, key, value if qkv_layout.is_qkvpacked(): @@ -696,6 +721,7 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_parallel_strategy=context_parallel_strategy, context_checkpoint_name=self.context_checkpoint_name, )( query, diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 122265ea27..f8d18983e4 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -404,9 +404,6 @@ def fp8_autocast( if fp8_recipe is None: fp8_recipe = recipe.DelayedScaling() - if mesh_resource is None: - mesh_resource = MeshResource() - Config = DelayedScalingQuantizeConfig if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): Config = 
BlockScalingQuantizeConfig diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 6d4894fd89..480989dcd6 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -286,7 +286,7 @@ class MeshResource: cp_resource: str = None -_GLOBAL_MESH_RESOURCE = MeshResource() +_GLOBAL_MESH_RESOURCE = None @contextmanager @@ -314,6 +314,11 @@ def global_mesh_resource() -> MeshResource: Returns: The current MeshResource instance """ + assert _GLOBAL_MESH_RESOURCE is not None, ( + "Global mesh resource is not set. Please set the MeshResource via a global_shard_guard" + " context. If you are not using multiple GPUs, you can use an empty MeshResource by" + " wrapping your program in 'with global_shard_guard(MeshResource()):'" + ) return _GLOBAL_MESH_RESOURCE diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel_nvshmem.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel_nvshmem.py new file mode 100644 index 0000000000..b8689ef579 --- /dev/null +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel_nvshmem.py @@ -0,0 +1,4096 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Context Parallelism.""" +import os +from typing import List, Union +import torch +import transformer_engine_torch as tex +import nvshmem.core as nvshmem +import torch.distributed as dist +from cuda.core.experimental import Device +from cuda.core.experimental import Stream +from nvshmem.core.interop.torch import tensor_get_buffer +from transformer_engine.pytorch.utils import ( + combine_tensors, + get_cudnn_version, + nvtx_range_pop, + nvtx_range_push, + get_device_compute_capability, +) +from transformer_engine.pytorch.cpp_extensions.fused_attn import ( + fused_attn_fwd, + fused_attn_bwd, + FusedAttnBackend, +) +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.jit import jit_fuser +from transformer_engine.pytorch.constants import ( + dist_group_type, + TE_DType, +) +from transformer_engine.pytorch.distributed import ( + get_distributed_world_size, + get_distributed_rank, + gather_along_first_dim, + reduce_scatter_along_first_dim, +) +from transformer_engine.pytorch.tensor.quantized_tensor import ( + prepare_for_saving, + restore_from_saved, +) + +# Import attention utils +import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils +from transformer_engine.pytorch.attention.dot_product_attention.utils import ( + FlashAttentionUtils as fa_utils, +) + +_cu_seqlens_info_with_cp_cache = {} +_seq_chunk_ids_cache_for_reordering_before_attn = {} +_seq_chunk_ids_cache_for_reordering_after_attn = {} + + +def flash_attn_p2p_communicate( + rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm +): + """Point-to-point communications of KV and dKV in Attention with context parallelism""" + send_recv_ops = [] + + if batch_p2p_comm: + if rank % 2 == 0: + send_op = torch.distributed.P2POp( + torch.distributed.isend, send_tensor, send_dst, cp_group + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_tensor, recv_src, cp_group + ) + send_recv_ops.append(send_op) 
+ send_recv_ops.append(recv_op) + else: + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_tensor, recv_src, cp_group + ) + send_op = torch.distributed.P2POp( + torch.distributed.isend, send_tensor, send_dst, cp_group + ) + send_recv_ops.append(recv_op) + send_recv_ops.append(send_op) + send_recv_reqs = torch.distributed.batch_isend_irecv(send_recv_ops) + else: + if rank % 2 == 0: + send_op = torch.distributed.isend(send_tensor, send_dst, cp_group) + recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group) + send_recv_ops.append(send_op) + send_recv_ops.append(recv_op) + else: + recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group) + send_op = torch.distributed.isend(send_tensor, send_dst, cp_group) + send_recv_ops.append(recv_op) + send_recv_ops.append(send_op) + send_recv_reqs = send_recv_ops + + return send_recv_reqs + +@jit_fuser +def flash_attn_fwd_out_correction_init( + out_init_step: torch.Tensor, + softmax_lse: torch.Tensor, + softmax_lse_init_step: torch.Tensor, + seq_dim: int, +): + """Merge partial outputs of the first step in Attention with context parallelism""" + softmax_lse_corrected_exp = torch.exp(softmax_lse_init_step - softmax_lse).movedim(2, seq_dim) + softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) + out_corrected = out_init_step * softmax_lse_corrected_exp + return out_corrected.to(out_init_step.dtype) + + +@jit_fuser +def flash_attn_fwd_out_correction( + out: torch.Tensor, + out_per_step: torch.Tensor, + softmax_lse: torch.Tensor, + softmax_lse_per_step: torch.Tensor, + seq_dim: int, +): + """Merge partial outputs of each step in Attention with context parallelism""" + softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) + softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) + out_corrected = out_per_step * softmax_lse_corrected_exp + out.add_(out_corrected) + + +@jit_fuser +def flash_attn_fwd_second_half_out_correction( + out: 
torch.Tensor, + out_per_step: torch.Tensor, + softmax_lse: torch.Tensor, + softmax_lse_per_step: torch.Tensor, + seq_dim: int, +): + """Merge second half of partial outputs of each step in Attention with context parallelism""" + out_ = out.select(seq_dim, 1) + softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1)[..., 1, :] + softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse_).movedim(2, seq_dim) + softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) + out_corrected = out_per_step * softmax_lse_corrected_exp + out_.add_(out_corrected) + + +@jit_fuser +def flash_attn_fwd_softmax_lse_correction( + softmax_lse: torch.Tensor, + softmax_lse_per_step: torch.Tensor, +): + """Merge softmax stats of each step in Attention with context parallelism""" + max_scale = torch.max(softmax_lse, softmax_lse_per_step) + min_scale = torch.min(softmax_lse, softmax_lse_per_step) + new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale)) + softmax_lse.copy_(new_scale) + + +@jit_fuser +def flash_attn_fwd_second_half_softmax_lse_correction( + softmax_lse: torch.Tensor, + softmax_lse_per_step: torch.Tensor, +): + """Merge second half of softmax stats of each step in Attention with context parallelism""" + softmax_lse_ = softmax_lse[..., 1, :] + max_scale = torch.max(softmax_lse_, softmax_lse_per_step) + min_scale = torch.min(softmax_lse_, softmax_lse_per_step) + new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale)) + softmax_lse_.copy_(new_scale) + + +@jit_fuser +def get_cu_seqlens_on_cp_rank( + cu_seqlens: torch.Tensor, + cu_seqlens_padded_on_cp_rank: torch.Tensor, + cp_size: int, + cp_rank: int, + first_half: bool, + second_half: bool, +): + """Compute cu_seqlens of a context parallelism rank""" + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + seqlens_padded = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2 + zeros = torch.zeros_like(seqlens) + cu_seqlens_on_cp_rank = 
torch.zeros_like(cu_seqlens) + if first_half: + seqlens_1 = seqlens - cp_rank * seqlens_padded + seqlens_1 = seqlens_1.clamp(zeros, seqlens_padded) + cu_seqlens_on_cp_rank[1:].add_(seqlens_1) + if second_half: + seqlens_2 = seqlens - (2 * cp_size - cp_rank - 1) * seqlens_padded + seqlens_2 = seqlens_2.clamp(zeros, seqlens_padded) + cu_seqlens_on_cp_rank[1:].add_(seqlens_2) + cu_seqlens_on_cp_rank.cumsum_(dim=0) + return cu_seqlens_on_cp_rank + + +@jit_fuser +def get_seq_chunk_ids_for_reordering_before_attn(cp_size, device): + """ + Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. + To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks to + be contigupus before attention compute. This function is to compute sequence chunk ids for + reordering. + """ + global _seq_chunk_ids_cache_for_reordering_before_attn + if (cp_size, device) not in _seq_chunk_ids_cache_for_reordering_before_attn: + chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) + for rank in range(cp_size): + chunk_ids[rank] = 2 * rank + chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 + _seq_chunk_ids_cache_for_reordering_before_attn[(cp_size, device)] = chunk_ids + return _seq_chunk_ids_cache_for_reordering_before_attn[(cp_size, device)] + + +@jit_fuser +def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device): + """ + Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. + We need to reorder sequence chunks back to discontiguous after attention compute. This function + is to compute sequence chunk ids for reordering. 
+ """ + global _seq_chunk_ids_cache_for_reordering_after_attn + if (cp_size, device) not in _seq_chunk_ids_cache_for_reordering_after_attn: + chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) + for rank in range(cp_size): + chunk_ids[2 * rank] = rank + chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1 + _seq_chunk_ids_cache_for_reordering_after_attn[(cp_size, device)] = chunk_ids + return _seq_chunk_ids_cache_for_reordering_after_attn[(cp_size, device)] + + +@jit_fuser +def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): + """Reorder sequence chunk for A2A communication before attention compute.""" + # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] + # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] + x = x.movedim(0, seq_dim).contiguous() + # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] + # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) + # reorder the sequence chunks + x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) + return x + + +@jit_fuser +def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): + """Reorder sequence chunk for A2A communication after attention compute.""" + # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] + # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.movedim(seq_dim, 0).contiguous() + # reorder the sequence chunks + x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) + # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] + # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] + x = x.view(cp_size, 2, *x.shape[1:]) + return x + + +def flash_attn_a2a_communicate( + a2a_inputs: Union[torch.Tensor, List[torch.Tensor]], + chunk_ids_for_a2a: torch.Tensor, + seq_dim: int, + cp_size: int, + cp_group: dist_group_type, + cp_stream: torch.cuda.Stream, + before_attn: bool, +) -> 
Union[torch.Tensor, List[torch.Tensor]]: + """A2A communication for context parallelism.""" + a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs + a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) + if before_attn: + for i in range(len(a2a_inputs) + 2): + if 0 < i < len(a2a_inputs) + 1: + a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) + a2a_reqs[i - 1] = torch.distributed.all_to_all_single( + a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True + ) + if i > 1: + with torch.cuda.stream(cp_stream): + a2a_reqs[i - 2].wait() + x = a2a_outputs[i - 2] + # reorder the sequence chunks + x = reorder_seq_chunks_for_a2a_before_attn( + x, chunk_ids_for_a2a, seq_dim, cp_size + ) + # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] + # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + if i < len(a2a_inputs): + x = a2a_inputs[i] + # [b, s, np, hn] -> [b, s, cp, np//cp, hn] + # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] + x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) + # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] + # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] + a2a_inputs[i] = x.movedim(-3, 0).contiguous() + else: + for i in range(len(a2a_inputs) + 2): + if 0 < i < len(a2a_inputs) + 1: + a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1]) + a2a_reqs[i - 1] = torch.distributed.all_to_all_single( + a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True + ) + if i < len(a2a_inputs): + x = a2a_inputs[i] + # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] + # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) + # reorder the sequence chunks + a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( + x, chunk_ids_for_a2a, seq_dim, cp_size + ) + if i > 1: + with 
torch.cuda.stream(cp_stream): + a2a_reqs[i - 2].wait() + x = a2a_outputs[i - 2] + # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] + # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() + # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] + # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] + a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) + torch.cuda.current_stream().wait_stream(cp_stream) + return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs + + +def _get_cu_seqlens_info_with_cp( + batch_size: int, + max_seqlen: int, + cp_size: int, + cu_seqlens: torch.Tensor, +): + """Cumulative sequence lengths with CP being considered.""" + global _cu_seqlens_info_with_cp_cache + if (batch_size, max_seqlen, cp_size) not in _cu_seqlens_info_with_cp_cache: + _cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)] = ( + cu_seqlens // cp_size, + cu_seqlens // (cp_size * 2), + ) + return _cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)] + + +def get_fa_args( + forward: bool, + use_flash_attn_3: bool, + qkv_format: str, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + dq=None, + dk=None, + dv=None, +): + """Get forward/backward arguments for flash-attn v2 and v3.""" + if use_flash_attn_3: + if forward: + if qkv_format == "thd": + return [ + *[None] * 4, # k_new, v_new, qv, out + cu_seqlens_q, + cu_seqlens_kv, + *[None] * 3, # cu_seqlens_k_new, seqused_q, seqused_k + max_seqlen_q, + max_seqlen_kv, + *[None] + * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + ] + return [ + *[None] + * 9, # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k + max_seqlen_q, + max_seqlen_kv, + *[None] + * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + ] + if qkv_format == "thd": + 
return [ + cu_seqlens_q, + cu_seqlens_kv, + None, # sequed_q + None, # sequed_k + max_seqlen_q, + max_seqlen_kv, + dq, + dk, + dv, + ] + return [ + None, # cu_seqlens_q + None, # cu_seqlens_kv + None, # sequed_q + None, # sequed_k + max_seqlen_q, + max_seqlen_kv, + dq, + dk, + dv, + ] + if forward: + if qkv_format == "thd": + return [ + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ] + return [] + if qkv_format == "thd": + return [ + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ] + return [ + dq, + dk, + dv, + ] + +def nvshmem_get_on_stream(dst_tensor: torch.Tensor, src_tensor: torch.Tensor, + peer: int, stream: Stream = None) -> None: + if stream is None: + stream = Stream.from_handle(torch.cuda.current_stream().cuda_stream) + + nvshmem.get(dst_tensor, src_tensor, remote_pe=peer, stream=stream) + +def torchrun_uid_init_bcast_object_no_reinit(cp_group=None): + local_rank = torch.cuda.current_device() + dev = Device(local_rank) + dev.set_current() + + if cp_group is None: + rank_id = dist.get_rank() + num_ranks = dist.get_world_size() + else: + rank_id = dist.get_rank(group=cp_group) + num_ranks = dist.get_world_size(group=cp_group) + + uniqueid = nvshmem.get_unique_id(empty=True) + + if rank_id == 0: + uniqueid = nvshmem.get_unique_id() + broadcast_objects = [uniqueid] + else: + broadcast_objects = [None] + + dist.broadcast_object_list( + broadcast_objects, + src=0, + group=cp_group + ) + + dist.barrier(group=cp_group) + + nvshmem.init( + device=dev, + uid=broadcast_objects[0], + rank=rank_id, + nranks=num_ranks, + initializer_method="uid" + ) + + return True + +class AttnFuncWithCPAndKVP2P(torch.autograd.Function): + """ + Attention implementation with context parallelism. Exchange KV between CP ranks + with P2P in ring topology. Split attention compute into multiple steps, and overlap + current-step compute with next-step communication. 
+ + This implementation also supports hierarchical CP, which parallelizes attention + heads in low-level CP groups and parallelizes sequence dimension in high-level CP + groups. For more details, please refer to `LongVILA `_ + and `USP `_. + """ + + @staticmethod + def forward( + ctx, + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + fp8, + fp8_meta, + cp_group, + cp_global_ranks, + cp_stream, + quantizers, + pad_between_seqs, + use_flash_attn_3, + ): + # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") + enable_mla = k.shape[-1] != v.shape[-1] + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if isinstance(cp_group, list): + assert ( + qkv_format != "thd" + ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!" + assert attn_bias_type == "no_bias", ( + f"{attn_bias_type} bias type is not supported with hierarchical CP implementation" + " yet!" 
+ ) + cp_group_a2a = cp_group[0] + cp_size_a2a = get_distributed_world_size(cp_group_a2a) + rank_a2a = get_distributed_rank(cp_group_a2a) + cp_group = cp_group[1] + else: + cp_group_a2a = None + cp_size_a2a = 1 + rank_a2a = 0 + + cp_size = get_distributed_world_size(cp_group) + rank = get_distributed_rank(cp_group) + send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] + recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] + device_compute_capability = get_device_compute_capability() + batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or ( + device_compute_capability < (10, 0) and cp_size == 2 + ) + + causal = "causal" in attn_mask_type + padding = "padding" in attn_mask_type + + batch_dim = None + seq_dim = None + cu_seqlens_q_half, cu_seqlens_kv_half = None, None + if qkv_format in ["bshd", "sbhd"]: + seq_dim = qkv_format.index("s") + if enable_mla: + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + else: + qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] + cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None + if use_fused_attention: + batch_dim = qkv_format.index("b") + cu_seqlens_q, cu_seqlens_q_half = _get_cu_seqlens_info_with_cp( + q.shape[batch_dim], max_seqlen_q, cp_size, cu_seqlens_q + ) + cu_seqlens_kv, cu_seqlens_kv_half = _get_cu_seqlens_info_with_cp( + q.shape[batch_dim], max_seqlen_kv, cp_size, cu_seqlens_kv + ) + else: + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size + cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size + + max_seqlen_q = max_seqlen_q // cp_size + max_seqlen_kv = max_seqlen_kv // cp_size + cu_seqlens_q_per_step = [None for _ in range(cp_size)] + cu_seqlens_kv_per_step = [None for _ in range(cp_size)] + + fused_attn_backend = None + qkv_dtype = q.dtype + amax_per_step = None + S_quantizer_per_step = [None for _ in range(cp_size)] + O_CP_quantizer_per_step = [None for 
_ in range(cp_size)] + # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype + is_input_fp8 = False + is_output_fp8 = False + + ( + QKV_quantizer, + O_quantizer, + O_CP_quantizer, + S_quantizer, + dQKV_quantizer, + dQKV_CP_quantizer, + dO_quantizer, + dP_quantizer, + ) = dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True) + + if fp8: + if use_fused_attention: + fused_attn_backend = FusedAttnBackend["FP8"] + + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, and v must have the same type." + is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha + if is_input_fp8: + QKV_quantizer = q._quantizer + q, k, v = q._data, k._data, v._data + else: + q_f16, k_f16, v_f16 = q, k, v + if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q = QKV_quantizer(q_f16)._data + if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]] + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + # partial result quantizer + for i in range(cp_size): + S_quantizer_per_step[i] = S_quantizer.copy() + S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + O_CP_quantizer_per_step[i] = O_CP_quantizer.copy() + O_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + else: + assert False, "FP8 is only supported with Fused Attention!" 
+ else: + q_f16 = q + if use_fused_attention: + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + if cp_size_a2a > 1: + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) + + q, k, v = flash_attn_a2a_communicate( + [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True + ) + if not fp8: + q_f16 = q + elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_f16 = q + q = QKV_quantizer(q_f16)._data + + assert qkv_format == "thd" or ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" + if causal: + if qkv_format == "bshd": + # [b, s, np, hn] -> [b, 2, s//2, np, hn] + q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]] + elif qkv_format == "sbhd": + # [s, b, np, hn] -> [2, s//2, b, np, hn] + q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]] + if attn_bias is not None: + assert len(attn_bias.shape) == 4, ( + "Only support bias shape of [b, h, sq, sk] for forward, " + "and [1, h, sq, sk] for backward!" + ) + assert ( + attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0 + ), "Sequence length does not meet divisible requirements!" 
+ # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] + attn_bias_ = attn_bias.view( + *attn_bias.shape[:-2], + 2, + attn_bias.shape[-2] // 2, + 2 * cp_size, + attn_bias.shape[-1] // (2 * cp_size), + ) + # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)] + attn_bias = attn_bias.view( + *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) + ) + assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" + + softmax_lse_in_packed_format = False + if qkv_format == "thd": + if use_fused_attention: + softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0) + else: + softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3 + + flash_attn_fwd = None + if not use_fused_attention: + fa_forward_kwargs = {"softmax_scale": softmax_scale} + if use_flash_attn_3: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_fwd_v3, + ) + + flash_attn_fwd = ( + _flash_attn_fwd_v3 # pylint: disable=possibly-used-before-assignment + ) + fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) + else: + if qkv_format == "thd": + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_varlen_fwd, + ) + + flash_attn_fwd = _flash_attn_varlen_fwd + else: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_fwd, + ) + + flash_attn_fwd = _flash_attn_fwd + fa_forward_kwargs["dropout_p"] = dropout_p + fa_forward_kwargs["return_softmax"] = False + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = 0 if causal else -1 + if fa_utils.v2_4_plus: + fa_forward_kwargs["alibi_slopes"] = None + if fa_utils.v2_5_7_plus and qkv_format == "thd": + fa_forward_kwargs["block_table"] = None + if fa_utils.v2_6_0_plus: + 
fa_forward_kwargs["softcap"] = 0.0 + + # Flash Attn inputs + q_inputs = [None, None] + kv_inputs = [None, None] + attn_bias_inputs = [None, None] + # Flash Attn outputs + out_per_step = [None for _ in range(cp_size)] + softmax_lse_per_step = [None for _ in range(cp_size)] + rng_states = [None for _ in range(cp_size)] + attn_biases = [None for _ in range(cp_size)] + + # create two streams to resolve wave quantization issue of Flash Attn in each step + flash_attn_streams = [torch.cuda.current_stream(), cp_stream] + # synchronize fwd results correction across steps + fwd_results_correction_done = torch.cuda.Event() + + + init_ok = torchrun_uid_init_bcast_object_no_reinit(cp_group) + p2p_comm_buffers = [None for _ in range(cp_size)] + if enable_mla: + # If MLA, the shape of k and v does not match, so we flatten them + # and split them after receiving them. + k_shape = k.shape + k_numel = k.numel() + v_shape = v.shape + buffer = torch.cat((k.view(-1), v.view(-1)), dim=-1) + p2p_comm_buffers[0] = nvshmem.tensor(list(buffer.shape), dtype=buffer.dtype) + p2p_comm_buffers[0].copy_(buffer) + elif qkv_format in ["bshd", "sbhd"]: + # p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) + buffer = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) + p2p_comm_buffers[0] = nvshmem.tensor(list(buffer.shape), dtype=buffer.dtype) + p2p_comm_buffers[0].copy_(buffer) + else: # qkv_format == "thd" + # p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) + buffer = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) + p2p_comm_buffers[0] = nvshmem.tensor(list(buffer.shape), dtype=buffer.dtype) + p2p_comm_buffers[0].copy_(buffer) + send_recv_reqs = [[], []] + + # Initialize NVSHMEM backend and allocate symmetric KV storage and signal. + # try: + # tex.init_nvshmem_backend(cp_group) + # Create a symmetric NVSHMEM tensor to hold this rank's KV chunk. 
+ # nvshmem_kv = tex.create_nvshmem_tensor(list(p2p_comm_buffers[0].shape), p2p_comm_buffers[0].dtype) + + + # except Exception: + # print("nvshmem init failed, fallback to p2p") + # # If NVSHMEM is not enabled or initialization fails, fall back to P2P comm + # nvshmem_kv = None + # nvshmem_signal = None + nvshmem_kv = nvshmem.tensor(list(p2p_comm_buffers[0].shape), dtype=p2p_comm_buffers[0].dtype) + + # copy local KV into symmetric heap so remote ranks can fetch it + nvshmem_kv.copy_(p2p_comm_buffers[0]) + # p2p_comm_buffers[0] = nvshmem.tensor(list(p2p_comm_buffers[0].shape), dtype=p2p_comm_buffers[0].dtype) + # p2p_comm_buffers[0].copy_(p2p_comm_buffers[0]) + # p2p_comm_buffers[1] = nvshmem.tensor(list(p2p_comm_buffers[0].shape), dtype=p2p_comm_buffers[0].dtype) + device = Device() + communicate_stream = device.create_stream() + out = None + for i in range(cp_size + 1): + if i < cp_size: + with torch.cuda.stream(flash_attn_streams[i % 2]): + # wait until KV is received + for req in send_recv_reqs[(i + 1) % 2]: + req.wait() + if nvshmem_kv is not None: + # If NVSHMEM is available, ensure the get is completed before using the buffer + communicate_stream.sync() + + if i < (cp_size - 1): + p2p_comm_buffers[i + 1] = nvshmem.tensor(list(p2p_comm_buffers[i].shape), dtype=p2p_comm_buffers[i].dtype) + if nvshmem_kv is not None: + # Use NVSHMEM get: compute owner of the (i+1)-th step KV block + owner_idx = (rank - (i + 1)) % cp_size + # Map owner idx to global rank (accounting for a2a groups) + owner_global = cp_global_ranks[owner_idx * cp_size_a2a + rank_a2a] + # nvshmem_get: dst (local buffer), src (symmetric address), peer=owner_global + nvshmem_get_on_stream(p2p_comm_buffers[i + 1], nvshmem_kv, owner_global, stream=communicate_stream) + else: + # fallback to P2P if NVSHMEM not available + send_recv_reqs[i % 2] = flash_attn_p2p_communicate( + rank, + p2p_comm_buffers[i], + send_dst, + p2p_comm_buffers[i + 1], + recv_src, + cp_group, + batch_p2p_comm, + ) + + if not fp8 
or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + kv_inputs[i % 2] = p2p_comm_buffers[i] + else: + # KV exchange is in BF16/FP16, cast received KV in each step + kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data + if enable_mla: + # If MLA, k and v are flattened, so split them after receiving. + k_part = kv_inputs[i % 2][:k_numel].view(*k_shape) + v_part = kv_inputs[i % 2][k_numel:].view(*v_shape) + if causal: + if i == 0: + if pad_between_seqs: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q + cu_seqlens_kv_per_step[i] = cu_seqlens_kv + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) + if enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + k.shape[0], -1, 2, *k.shape[-2:] + ) + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) + if enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[2:]) + v_part = v_part.view(-1, *v_part.shape[2:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) + elif qkv_format == "thd": + q_inputs[i % 2] = q + if use_fused_attention: + if attn_bias is not None: + idx = (rank - i) % cp_size + attn_bias_inputs[i % 2] = 
torch.cat( + ( + attn_bias[..., idx, :], + attn_bias[..., (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + + q_part = q_inputs[i % 2] + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + fp8_meta_kwargs = {} + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] + + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_part, + k_part, + v_part, + fake_dtype=qkv_dtype, + fused_attention_backend=fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, + ) + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None + else: + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + ) + # Need to add MLA support once Flash 
Attention supports MLA + fa_outputs = flash_attn_fwd( + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, + causal=True, + **fa_forward_kwargs, + ) + if not fa_utils.v2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[3] + elif i <= rank: + if pad_between_seqs: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - i) % cp_size, + True, + False, + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2) + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q + cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) + if enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn] + k_part = k_part[:, 0, ...] + v_part = v_part[:, 0, ...] + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] 
+ elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) + if enable_mla: + # [2, sk//2, b, np, hn] -> [sk//2, b, np, hn] + k_part = k_part[0] + v_part = v_part[0] + else: + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][0] + elif qkv_format == "thd": + q_inputs[i % 2] = q + if enable_mla: + # [t, np, hn] -> [t/2, np, hn] + k_part = tex.thd_read_half_tensor( + k_part, cu_seqlens_kv_padded, 0 + ) + v_part = tex.thd_read_half_tensor( + v_part, cu_seqlens_kv_padded, 0 + ) + else: + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_inputs[i % 2] = tex.thd_read_half_tensor( + kv_inputs[i % 2], cu_seqlens_kv_padded, 0 + ) + if use_fused_attention: + if enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + else: + kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() + if attn_bias is not None: + idx = (rank - i) % cp_size + attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() + + q_part = q_inputs[i % 2] + if not enable_mla: + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + fp8_meta_kwargs = {} + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv // 2, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_part, + k_part, + v_part, + qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + 
dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=( + None + if cu_seqlens_kv_padded is None + else cu_seqlens_kv_padded // 2 + ), + **fp8_meta_kwargs, + ) + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None + else: + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv // 2, + ) + if use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): + fa_forward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 + # Need to add MLA support once Flash Attention supports MLA + fa_outputs = flash_attn_fwd( + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, + causal=False, + **fa_forward_kwargs, + ) + if not fa_utils.v2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[3] + else: + if pad_between_seqs: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True + ) + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - 
i) % cp_size, + True, + True, + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2) + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q_half + cu_seqlens_kv_per_step[i] = cu_seqlens_kv + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_inputs[i % 2] = q[:, 1, ...] + if enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + k.shape[0], -1, 2, *k.shape[-2:] + ) + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_inputs[i % 2] = q[1] + if enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[2:]) + v_part = v_part.view(-1, *v_part.shape[2:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) + elif qkv_format == "thd": + # [t, np, hn] -> [t/2, np, hn] + q_inputs[i % 2] = tex.thd_read_half_tensor( + q, cu_seqlens_q_padded, 1 + ) + if use_fused_attention: + q_inputs[i % 2] = q_inputs[i % 2].contiguous() + if attn_bias is not None: + idx = (rank - i) % cp_size + attn_bias_inputs[i % 2] = torch.cat( + ( + attn_bias_[..., 1, :, idx, :], + attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + + q_part = q_inputs[i % 2] + if not enable_mla: + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + fp8_meta_kwargs = {} + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = 
QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q // 2, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_part, + k_part, + v_part, + qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=( + None + if cu_seqlens_q_padded is None + else cu_seqlens_q_padded // 2 + ), + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, + ) + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None + else: + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q // 2, + max_seqlen_kv=max_seqlen_kv, + ) + if use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): + fa_forward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 + # Need to add MLA support once Flash Attention supports MLA + fa_outputs = flash_attn_fwd( + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, + causal=False, + **fa_forward_kwargs, + ) + if not 
fa_utils.v2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[3] + else: + if pad_between_seqs: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - i) % cp_size, + True, + True, + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q + cu_seqlens_kv_per_step[i] = cu_seqlens_kv + if use_fused_attention: + if attn_bias is not None: + idx = (rank - i) % cp_size + attn_bias_inputs[i % 2] = torch.cat( + ( + attn_bias[..., idx, :], + attn_bias[..., (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + + q_part = q + if not enable_mla: + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + fp8_meta_kwargs = {} + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_part, + k_part, + v_part, + qkv_dtype, + 
fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, + ) + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None + else: + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + ) + # Need to add MLA support once Flash Attention supports MLA + fa_outputs = flash_attn_fwd( + q, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, + causal=False, + **fa_forward_kwargs, + ) + if not fa_utils.v2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[3] + + if i > 0: + # wait until fwd restuls correction of last step is done + if i > 1: + flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done) + + with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): + if use_fused_attention: + # [b, np, sq, 1] -> [b, np, sq] or + # [t, np, 1] -> [t, np] + softmax_lse_per_step[i - 1].squeeze_(-1) + if softmax_lse_in_packed_format: + softmax_lse_per_step[i - 1] = ( + softmax_lse_per_step[i - 1].transpose(0, 1).contiguous() + ) + if fp8: + out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32) 
+ if i == 1: + softmax_lse = torch.clone(softmax_lse_per_step[0]) + if qkv_format == "thd": + if enable_mla: + out = torch.zeros_like(v if not fp8 else out_per_step[0]).view( + v_shape + ) + else: + # MHA or GQA + out = torch.zeros_like(q if not fp8 else out_per_step[0]).view( + q.shape + ) + elif (i - 1) <= rank or not causal: + flash_attn_fwd_softmax_lse_correction( + softmax_lse, softmax_lse_per_step[i - 1] + ) + else: + if qkv_format == "thd": + tex.thd_second_half_lse_correction( + softmax_lse, + softmax_lse_per_step[i - 1], + cu_seqlens_q_padded, + softmax_lse_in_packed_format, + ) + else: + flash_attn_fwd_second_half_softmax_lse_correction( + softmax_lse.view(*softmax_lse.shape[:-1], 2, -1), + softmax_lse_per_step[i - 1], + ) + + if i < cp_size: + flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done) + + torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) + + second_half_lse_seqlen = None + if causal and rank < (cp_size - 1): + second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1] + + for i in range(cp_size): + if i <= rank or not causal: + if qkv_format in ["bshd", "sbhd"]: + if i == 0: + out = flash_attn_fwd_out_correction_init( + out_per_step[0], + softmax_lse, + softmax_lse_per_step[0], + seq_dim, + ) + if enable_mla: + out = out.view(v_shape) + else: + out = out.view(q.shape) + else: + flash_attn_fwd_out_correction( + out.view(*out_per_step[i].shape), + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + seq_dim, + ) + elif qkv_format == "thd": + tex.thd_out_correction( + out, + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + cu_seqlens_q_padded, + False, + softmax_lse_in_packed_format, + ) + else: + if qkv_format in ["bshd", "sbhd"]: + flash_attn_fwd_second_half_out_correction( + out, + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + seq_dim, + ) + elif qkv_format == "thd": + tex.thd_out_correction( + out, + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + 
cu_seqlens_q_padded, + True, + softmax_lse_in_packed_format, + ) + + kv = p2p_comm_buffers[-1] + if qkv_format == "bshd": + out = out.view(out.shape[0], -1, *out.shape[-2:]) + ctx.batch_size = out.shape[0] + elif qkv_format == "sbhd": + out = out.view(-1, *out.shape[-3:]) + ctx.batch_size = out.shape[1] + + if cp_size_a2a > 1: + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) + out = flash_attn_a2a_communicate( + out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False + ) + if use_fused_attention: + if qkv_format == "bshd": + # [b*s, np, hn] -> [b, s, np, hn] + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + # [s*b, np, hn] -> [s, b, np, hn] + out = out.view(-1, ctx.batch_size, *out.shape[-2:]) + elif not use_fused_attention: + out = out.view(-1, *out.shape[-2:]) + + if fp8 and use_fused_attention: + amax_cp_fwd = amax_per_step.amax(dim=1) + S_quantizer.amax.copy_(amax_cp_fwd[0]) + O_CP_quantizer.amax.copy_(amax_cp_fwd[1]) + + out_fp8 = None + out_f16 = out.to(qkv_dtype) + + if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): + out_fp8 = O_quantizer(out_f16) # final result + + out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 + + if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_save, kv_save, out_save = q, kv, out_fp8._data + elif fp8 and is_input_fp8: + q_save, kv_save, out_save = q, kv, out_f16 + else: + q_f16 = q_f16.view(q.shape) + q_save, kv_save, out_save = q_f16, kv, out_f16 + + tensors_to_save, tensor_objects = prepare_for_saving( + q_save, + kv_save, + out_save, + softmax_lse, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + *cu_seqlens_q_per_step, + *cu_seqlens_kv_per_step, + *rng_states, + *attn_biases, + ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.cp_group_a2a = cp_group_a2a + ctx.cp_size_a2a = cp_size_a2a + ctx.rank_a2a = rank_a2a + ctx.cp_group = cp_group + ctx.cp_global_ranks 
= cp_global_ranks + ctx.cp_stream = cp_stream + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_kv = max_seqlen_kv + ctx.softmax_scale = softmax_scale + ctx.qkv_format = qkv_format + ctx.attn_mask_type = attn_mask_type + ctx.attn_bias_type = attn_bias_type + ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape + ctx.deterministic = deterministic + ctx.use_fused_attention = use_fused_attention + ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format + ctx.second_half_lse_seqlen = second_half_lse_seqlen + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.fp8_meta = fp8_meta + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 + ctx.use_flash_attn_3 = use_flash_attn_3 + + ctx.enable_mla = enable_mla + if enable_mla: + ctx.k_numel = k_numel + ctx.k_shape = k_shape + ctx.v_shape = v_shape + + ctx.qkv_dtype = qkv_dtype + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dQKV_CP_quantizer = dQKV_CP_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.S_quantizer = S_quantizer + if ctx.fp8: + ctx.QKV_quantizer = QKV_quantizer.copy() + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer = O_quantizer.copy() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() + ctx.S_quantizer.scale = S_quantizer.scale.clone() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward") + + # free up some nvshmem buffers + nvshmem.free_tensor(nvshmem_kv) + for i in range(cp_size): + nvshmem.free_tensor(p2p_comm_buffers[i]) + + return out_ret + + @staticmethod + def backward(ctx, dout): + # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward") + cp_size_a2a = ctx.cp_size_a2a + rank_a2a = ctx.rank_a2a + + cp_size = get_distributed_world_size(ctx.cp_group) + rank = 
get_distributed_rank(ctx.cp_group) + send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] + recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] + device_compute_capability = get_device_compute_capability() + batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or ( + device_compute_capability < (10, 0) and cp_size == 2 + ) + + q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( + restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + ) + cu_seqlens_q_per_step = other_tensors[:cp_size] + cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2] + rng_states = other_tensors[cp_size * 2 : cp_size * 3] + attn_biases = other_tensors[cp_size * 3 : cp_size * 4] + + causal = "causal" in ctx.attn_mask_type + padding = "padding" in ctx.attn_mask_type + + seq_dim = None + if ctx.qkv_format in ["bshd", "sbhd"]: + seq_dim = ctx.qkv_format.index("s") + if ctx.enable_mla: + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + else: + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] + else: + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + + if attn_biases[0] is not None: + # [b, np, sq, 2*cp, sk//(2*cp)] + attn_dbias = torch.zeros( + *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device + ) + # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] + attn_dbias_ = attn_dbias.view( + *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:] + ) + else: + attn_dbias = None + attn_dbias_ = None + + softmax_lse_ = None + if causal and ctx.second_half_lse_seqlen is not None: + if ctx.qkv_format == "thd": + softmax_lse_ = tex.thd_read_second_half_lse( + softmax_lse, + cu_seqlens_q_padded, + ctx.softmax_lse_in_packed_format, + ctx.second_half_lse_seqlen, + ) + else: + # [b, np, sq] -> [b, np, 2, sq//2] + softmax_lse_ = 
softmax_lse.view(*softmax_lse.shape[:-1], 2, -1) + softmax_lse_ = softmax_lse_[..., 1, :].contiguous() + if ctx.use_fused_attention: + if ctx.softmax_lse_in_packed_format: + softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous() + # [b, np, sq//2] -> [b, np, sq//2, 1] or + # [t//2, np] -> [t//2, np, 1] + softmax_lse_.unsqueeze_(-1) + if ctx.use_fused_attention: + if ctx.softmax_lse_in_packed_format: + softmax_lse = softmax_lse.transpose(0, 1).contiguous() + # [b, np, sq] -> [b, np, sq, 1] or + # [t, np] -> [t, np, 1] + softmax_lse.unsqueeze_(-1) + dout = dout.contiguous() + + dq = None + dout_dtype = dout.dtype + fused_attn_backend = None + fused_attn_dqkv_dtype = None + amax_per_step = None + dP_quantizer_per_step = [None for _ in range(cp_size)] + dQKV_CP_quantizer_per_step = [None for _ in range(cp_size)] + if ctx.fp8: + if ctx.use_fused_attention: + fused_attn_backend = FusedAttnBackend["FP8"] + + if ctx.is_output_fp8: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + ctx.dO_quantizer = dout._quantizer + else: + dout = ctx.dO_quantizer(dout) + fused_attn_dqkv_dtype = TE_DType[dout._data.dtype] + dq_fp8 = torch.empty((cp_size, *q.shape), dtype=dout._data.dtype, device=q.device) + dkv_fp8 = torch.empty( + (cp_size, *kv.shape), dtype=dout._data.dtype, device=kv.device + ) + dkv_fp8_ = torch.empty_like(dkv_fp8) + p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] + dout = dout._data + fp8_meta_kwargs = {} + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + for i in range(cp_size): + dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() + dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy() + dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + else: + assert False, "FP8 is only supported with Fused Attention!" 
+ else: + if ctx.fp8_meta is not None: + if ctx.is_input_fp8: + q = ctx.QKV_quantizer.create_tensor_from_data( + q, fake_dtype=ctx.qkv_dtype, internal=True + ) + kv = ctx.QKV_quantizer.create_tensor_from_data( + kv, fake_dtype=ctx.qkv_dtype, internal=True + ) + q = q.dequantize(dtype=ctx.qkv_dtype) + kv = kv.dequantize(dtype=ctx.qkv_dtype) + if ctx.is_output_fp8: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + if cp_size_a2a == 1: + dout = dout.dequantize(dtype=dout_dtype) + else: + ctx.dO_quantizer = dout._quantizer + dout = dout._data + dq = torch.empty_like(q) + # p2p_comm_buffers = [ + # torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), + # torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), + # ] + # p2p_comm_buffers[0][0].copy_(kv) + if ctx.use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_dqkv_dtype = TE_DType[dout_dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + if cp_size_a2a > 1: + if not ctx.use_fused_attention: + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + dout = dout.view(*out.shape) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn( + cp_size_a2a, out.device + ) + out, dout = flash_attn_a2a_communicate( + [out, dout], + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + ctx.cp_group_a2a, + ctx.cp_stream, + True, + ) + if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: + dout = ctx.dO_quantizer.create_tensor_from_data( + dout, fake_dtype=dout_dtype, internal=True + ) + dout = dout.dequantize(dtype=dout_dtype) + + if ctx.enable_mla: + out = out.view(*ctx.v_shape) + dout = dout.view(*ctx.v_shape) + else: + # MHA or GQA + out = out.view(*q.shape) + dout = dout.view(*q.shape) + send_recv_reqs = [] + + flash_attn_bwd = None + if not ctx.use_fused_attention: + fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} + if ctx.use_flash_attn_3: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + 
_flash_attn_bwd_v3, + ) + + flash_attn_bwd = ( + _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment + ) + fa_backward_kwargs["deterministic"] = ctx.deterministic + else: + if ctx.qkv_format == "thd": + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_varlen_bwd, + ) + + flash_attn_bwd = _flash_attn_varlen_bwd + else: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_bwd, + ) + + flash_attn_bwd = _flash_attn_bwd + fa_backward_kwargs["dropout_p"] = ctx.dropout_p + if fa_utils.v2_4_plus: + fa_backward_kwargs["alibi_slopes"] = None + if fa_utils.v2_4_1_plus: + fa_backward_kwargs["deterministic"] = ctx.deterministic + if fa_utils.v2_6_0_plus: + fa_backward_kwargs["softcap"] = 0.0 + + p2p_comm_buffers = [ + nvshmem.tensor((2, *kv.shape), dtype=kv.dtype), + nvshmem.tensor((2, *kv.shape), dtype=kv.dtype), + ] + + # 复制数据到 NVSHMEM 缓冲区 + p2p_comm_buffers[0][0].copy_(kv) + + send_tensor = p2p_comm_buffers[i % 2] + recv_tensor = p2p_comm_buffers[(i + 1) % 2] + device = Device() + communicate_stream = device.create_stream() + + def flash_attn_p2p_communicate_nvshmem( + rank, send_tensor, send_dst, recv_tensor, recv_src, stream + ): + """NVSHMEM版本的双向P2P通信""" + + # 根据rank顺序避免死锁 + if rank % 2 == 0: + # 偶数rank:先发送后接收 + nvshmem.core.put( + recv_tensor, # 目标:远程PE的接收缓冲区 + send_tensor, # 源:本地发送数据 + remote_pe=send_dst, + stream=stream + ) + + # 接收数据 + nvshmem.core.get( + recv_tensor, # 目标:本地接收缓冲区 + send_tensor, # 源:远程PE的发送数据 + remote_pe=recv_src, + stream=stream + ) + else: + # 奇数rank:先接收后发送 + nvshmem.core.get( + recv_tensor, # 目标:本地接收缓冲区 + send_tensor, # 源:远程PE的发送数据 + remote_pe=recv_src, + stream=stream + ) + + # 发送数据 + nvshmem.core.put( + recv_tensor, # 目标:远程PE的接收缓冲区 + send_tensor, # 源:本地发送数据 + remote_pe=send_dst, + stream=stream + ) + + for i in range(cp_size): + # wait until KV is received + # for req in send_recv_reqs: + # req.wait() + communicate_stream.sync() + + 
send_tensor = p2p_comm_buffers[i % 2] + recv_tensor = p2p_comm_buffers[(i + 1) % 2] + if ctx.fp8: + if i < cp_size - 1: + # if nvshmem_kv is not None: + # # owner of the next KV block + # owner_idx = (rank - (i + 1)) % cp_size + # owner_global = cp_global_ranks[owner_idx * cp_size_a2a + rank_a2a] + # tex.nvshmem_get_on_current_stream(recv_tensor[0], nvshmem_kv, int(owner_global)) + # send_recv_reqs = [] + # else: + # send_recv_reqs = flash_attn_p2p_communicate( + # rank, + # send_tensor[0], + # send_dst, + # recv_tensor[0], + # recv_src, + # ctx.cp_group, + # batch_p2p_comm, + # ) + flash_attn_p2p_communicate_nvshmem( + rank, send_buffer, send_dst, recv_buffer, recv_src, communicate_stream + ) + else: + dkv_a2a_req = torch.distributed.all_to_all_single( + dkv_fp8, + dkv_fp8_, + group=ctx.cp_group, + async_op=True, + ) + send_recv_reqs = [dkv_a2a_req] + else: + if i == 0: + send_tensor = send_tensor[0] + recv_tensor = recv_tensor[0] + if i == (cp_size - 1): + send_tensor = send_tensor[1] + recv_tensor = recv_tensor[1] + # if nvshmem_kv is not None: + # owner_idx = (rank - (i + 1)) % cp_size + # owner_global = cp_global_ranks[owner_idx * cp_size_a2a + rank_a2a] + # tex.nvshmem_get_on_current_stream(recv_tensor, nvshmem_kv, int(owner_global)) + # send_recv_reqs = [] + # else: + # send_recv_reqs = flash_attn_p2p_communicate( + # rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm + # ) + flash_attn_p2p_communicate_nvshmem( + rank, send_buffer, send_dst, recv_buffer, recv_src, communicate_stream + ) + + kv = p2p_comm_buffers[i % 2][0] + q_, kv_, out_, dout_ = None, None, None, None + dq_, dk_, dv_ = None, None, None + if ctx.enable_mla: + k_part = kv[: ctx.k_numel].view(*ctx.k_shape) + v_part = kv[ctx.k_numel :].view(*ctx.v_shape) + # In reversed order of fwd + if causal: + if i == (cp_size - 1): + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_, out_, dout_ = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in 
[q, out, dout] + ] + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[-3:]) + v_part = v_part.view(-1, *v_part.shape[-3:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) + elif ctx.qkv_format == "thd": + q_, kv_, out_, dout_ = q, kv, out, dout + if ctx.use_fused_attention: + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if attn_dbias is not None: + aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + q_part = q_ + if not ctx.enable_mla: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=dout_dtype, internal=True + ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = 
dQKV_CP_quantizer_per_step[i] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + q_part, + k_part, + v_part, + out_part, + dout_part, + dout_dtype, + fused_attn_dqkv_dtype, + aux_ctx_tensors, + fused_attn_backend, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=ctx.attn_mask_type, + attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, + ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data + else: + dq_ = torch.empty_like(q_) + dkv_ = torch.empty_like(kv_) + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq_, + dk=( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ), + dv=( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ), + ) + if ctx.use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): + fa_backward_kwargs["window_size"] = (-1, 0) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = 0 + if not ctx.use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support once Flash Attention supports MLA + flash_attn_bwd( + dout_, + q_, + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + out_, + softmax_lse, + *fa_backward_args_thd, + causal=True, + **fa_backward_kwargs, + ) + elif i >= (cp_size - rank - 1): + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] 
-> [b, sq, np, hn] + q_, out_, dout_ = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] + ] + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part[:, 0] + v_part = v_part[:, 0] + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_ = kv[:, 0] + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part[0] + v_part = v_part[0] + else: + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_ = kv[0] + elif ctx.qkv_format == "thd": + q_, out_, dout_ = q, out, dout + if ctx.enable_mla: + # [t, np, hn] -> [t/2, np, hn] + k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) + v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) + else: + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) + if ctx.use_fused_attention: + if ctx.enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + else: + kv_ = kv_.contiguous() + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if attn_dbias is not None: + aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + q_part = q_ + if not ctx.enable_mla: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + 
out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=dout_dtype, internal=True + ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv // 2, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + q_part, + k_part, + v_part, + out_part, + dout_part, + dout_dtype, + fused_attn_dqkv_dtype, + aux_ctx_tensors, + fused_attn_backend, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=( + None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 + ), + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, + ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data + else: + dq_ = torch.empty_like(q_) + dkv_ = torch.empty_like(kv_) + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv // 2, + dq=dq_, + dk=( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ), + dv=( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ), + ) + if ctx.use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): + fa_backward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 + if not ctx.use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add 
MLA support once Flash Attention supports MLA + flash_attn_bwd( + dout_, + q_, + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + out_, + softmax_lse, + *fa_backward_args_thd, + causal=False, + **fa_backward_kwargs, + ) + else: + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1] + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_, out_, dout_ = q[1], out[1], dout[1] + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[-3:]) + v_part = v_part.view(-1, *v_part.shape[-3:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) + elif ctx.qkv_format == "thd": + # [t, np, hn] -> [t/2, np, hn] + q_, out_, dout_ = [ + tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1) + for x in [q, out, dout] + ] + kv_ = kv + if ctx.use_fused_attention: + q_, out_, dout_ = [x.contiguous() for x in [q_, out_, dout_]] + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse_, + softmax_lse_, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] + if attn_dbias is not None: + aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + + q_part = q_ + if not ctx.enable_mla: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, 
fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=dout_dtype, internal=True + ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( + ctx.max_seqlen_q // 2, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + q_part, + k_part, + v_part, + out_part, + dout_part, + dout_dtype, + fused_attn_dqkv_dtype, + aux_ctx_tensors, + fused_attn_backend, + cu_seqlens_q_padded=( + None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 + ), + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, + ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data + else: + dq_ = torch.empty_like(q_) + dkv_ = torch.empty_like(kv_) + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q // 2, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq_, + dk=( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ), + dv=( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ), + ) + if ctx.use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): + 
fa_backward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 + if not ctx.use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support once Flash Attention supports MLA + flash_attn_bwd( + dout_, + q_, + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + out_, + softmax_lse_, + *fa_backward_args_thd, + causal=False, + **fa_backward_kwargs, + ) + else: + if ctx.use_fused_attention: + if ctx.fp8: + aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if attn_dbias is not None: + aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + q_part = q + if not ctx.enable_mla: + k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] + v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] + out_part = out + dout_part = dout + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=dout_dtype, internal=True + ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + q_part, + k_part, + v_part, + out_part, 
+ dout_part, + dout_dtype, + fused_attn_dqkv_dtype, + aux_ctx_tensors, + fused_attn_backend, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=ctx.attn_mask_type, + attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, + ) + + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data + + else: + dq_ = torch.empty_like(q) + dkv_ = torch.empty_like(kv) + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq_, + dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + ) + if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_backward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 + if not ctx.use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support once Flash Attention supports MLA + flash_attn_bwd( + dout, + q, + kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0], + kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], + out, + softmax_lse, + *fa_backward_args_thd, + causal=False, + **fa_backward_kwargs, + ) + + if ctx.fp8: + dq = dq_fp8[(rank + i + 1) % cp_size] + if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1): + # [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or + # [sq, b, np, hn] -> [2, sq//2, b, np, hn] + dq_ = dq_.view(*dq.shape) + + if ctx.fp8: + if i >= (cp_size - rank - 1) or not causal: + dq.copy_(dq_) + else: + if 
ctx.qkv_format == "bshd": + dq[:, 0, ...].fill_(0) + dq[:, 1, ...].copy_(dq_) + elif ctx.qkv_format == "sbhd": + dq[0].fill_(0) + dq[1].copy_(dq_) + elif causal: + if i > (cp_size - rank - 1): + dq.add_(dq_) + elif i == (cp_size - rank - 1): + if rank == (cp_size - 1): + dq.copy_(dq_) + else: + if ctx.qkv_format == "bshd": + dq[:, 0, ...].copy_(dq_[:, 0, ...]) + dq[:, 1, ...].add_(dq_[:, 1, ...]) + elif ctx.qkv_format == "sbhd": + dq[0].copy_(dq_[0]) + dq[1].add_(dq_[1]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add") + elif i > 0: + if ctx.qkv_format == "bshd": + dq[:, 1, ...].add_(dq_) + elif ctx.qkv_format == "sbhd": + dq[1].add_(dq_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add") + else: + if ctx.qkv_format == "bshd": + dq[:, 1, ...].copy_(dq_) + elif ctx.qkv_format == "sbhd": + dq[1].copy_(dq_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy") + else: + if i == 0: + dq.copy_(dq_) + else: + dq.add_(dq_) + + if attn_dbias is not None: + idx = (rank + i + 1) % cp_size + if i == (cp_size - 1) or not causal: + # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)] + dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) + attn_dbias[..., idx, :].copy_(dbias_[..., 0, :]) + attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) + elif i >= (cp_size - rank - 1): + # [b, np, sq, sk//(2*cp)] + attn_dbias[..., idx, :].copy_(dbias_) + else: + # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)] + dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) + attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :]) + attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) + + # wait until dKV is received + for req in send_recv_reqs: + req.wait() + + if ctx.fp8: + if i < cp_size - 1: + dkv = dkv_fp8_[(rank + i + 1) % cp_size] + else: + dkv = dkv_fp8[(rank + i 
+ 1) % cp_size] + else: + dkv = p2p_comm_buffers[(i + 1) % 2][1] + if ctx.use_fused_attention: + if ctx.enable_mla: + dkv_ = None + elif ctx.qkv_format in ["bshd", "sbhd"]: + dkv_ = combine_tensors([dk_, dv_], -2) + elif ctx.qkv_format == "thd": + dkv_ = torch.cat( + (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 + ) # pylint: disable=used-before-assignment + if not ctx.enable_mla and ctx.qkv_format in ["bshd", "sbhd"]: + # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or + # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] + # dkv is a buffer, so we do not need to transpose it, but only need to reshape it. + dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) + dkv_ = dkv_.movedim(-3, 0) + if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): + # [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or + # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn] + dkv_ = dkv_.view(*dkv.shape) + + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] or + # [2, sk//2, b, np, hn] + dk = dkv[: ctx.k_numel].view(*ctx.k_shape) + dv = dkv[ctx.k_numel :].view(*ctx.v_shape) + if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): + dk_ = dk_.view(*ctx.k_shape) + dv_ = dv_.view(*ctx.v_shape) + + if ctx.fp8: + # enable_mla and fp8 + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + if ctx.qkv_format == "bshd": + dk[:, 0, ...].copy_(dk_) + dk[:, 1, ...].fill_(0) + dv[:, 0, ...].copy_(dv_) + dv[:, 1, ...].fill_(0) + elif ctx.qkv_format == "sbhd": + dk[0].copy_(dk_) + dk[1].fill_(0) + dv[0].copy_(dv_) + dv[1].fill_(0) + else: + dk.copy_(dk_) + dv.copy_(dv_) + elif causal: + # enable_mla and not fp8 and causal + if i == (cp_size - 1): + if rank == 0: + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_[:, 0, ...]) + dk[:, 1, ...].copy_(dk_[:, 1, ...]) + dv[:, 0, ...].add_(dv_[:, 0, ...]) + dv[:, 1, ...].copy_(dv_[:, 1, ...]) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_[0, ...]) + dk[1, ...].copy_(dk_[1, ...]) + dv[0, ...].add_(dv_[0, ...]) + dv[1, 
...].copy_(dv_[1, ...]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dk, dk_, cu_seqlens_kv_padded, "add", "copy" + ) + tex.thd_grad_correction( + dv, dv_, cu_seqlens_kv_padded, "add", "copy" + ) + else: + dk.add_(dk_) + dv.add_(dv_) + elif i >= (cp_size - rank - 1): + if i == 0 and rank == (cp_size - 1): + if ctx.qkv_format == "bshd": + dk[:, 0, ...].copy_(dk_) + dv[:, 0, ...].copy_(dv_) + elif ctx.qkv_format == "sbhd": + dk[0, ...].copy_(dk_) + dv[0, ...].copy_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dk, dk_, cu_seqlens_kv_padded, "copy", "none" + ) + tex.thd_grad_correction( + dv, dv_, cu_seqlens_kv_padded, "copy", "none" + ) + else: + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_) + dv[:, 0, ...].add_(dv_) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_) + dv[0, ...].add_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dk, dk_, cu_seqlens_kv_padded, "add", "none" + ) + tex.thd_grad_correction( + dv, dv_, cu_seqlens_kv_padded, "add", "none" + ) + elif i > 0: + dk.add_(dk_) + dv.add_(dv_) + else: # i == 0 + dk.copy_(dk_) + dv.copy_(dv_) + else: + # enable_mla and not fp8 and not causal + if i == 0: + dk.copy_(dk_) + dv.copy_(dv_) + else: # i > 0 + dk.add_(dk_) + dv.add_(dv_) + else: + if ctx.fp8: + # fp8 + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].copy_(dkv_) + dkv[:, :, 1, ...].fill_(0) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].copy_(dkv_) + dkv[:, 1, ...].fill_(0) + else: + dkv.copy_(dkv_) + elif causal: + # not fp8 and causal + if i == (cp_size - 1): + if rank == 0: + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) + dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...]) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].add_(dkv_[:, 0, ...]) + dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dkv, dkv_, cu_seqlens_kv_padded, "add", "copy" + ) + 
else: + dkv.add_(dkv_) + elif i >= (cp_size - rank - 1): + if i == 0 and rank == (cp_size - 1): + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].copy_(dkv_) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].copy_(dkv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dkv, dkv_, cu_seqlens_kv_padded, "copy", "none" + ) + else: + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].add_(dkv_) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].add_(dkv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dkv, dkv_, cu_seqlens_kv_padded, "add", "none" + ) + elif i > 0: + dkv.add_(dkv_) + else: # i == 0 + dkv.copy_(dkv_) + else: + # not fp8 and not causal + if i == 0: + dkv.copy_(dkv_) + else: # i > 0 + dkv.add_(dkv_) + + if ctx.fp8 and ctx.use_fused_attention: + amax_cp_bwd = amax_per_step.amax(dim=1) + ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) + ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1]) + dq = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dq_fp8, fake_dtype=torch.float32, internal=True + ) + + if ctx.enable_mla: + # [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn] + dk_fp8 = dkv_fp8[:, : ctx.k_numel].view(cp_size, *ctx.k_shape) + dv_fp8 = dkv_fp8[:, ctx.k_numel :].view(cp_size, *ctx.v_shape) + dk = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dk_fp8, fake_dtype=torch.float32, internal=True + ) + dv = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dv_fp8, fake_dtype=torch.float32, internal=True + ) + dq, dk, dv = [x.dequantize(dtype=torch.float32) for x in [dq, dk, dv]] + dq, dk, dv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dk, dv]] + else: + if ctx.qkv_format in ["bshd", "sbhd"]: + # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or + # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] + dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) + dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dkv_fp8, fake_dtype=torch.float32, internal=True + ) + dq, dkv = 
[x.dequantize(dtype=torch.float32) for x in [dq, dkv]] + dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] + + if causal: + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + dk = dk.view(dk.shape[0], -1, *dk.shape[-2:]) + dv = dv.view(dv.shape[0], -1, *dv.shape[-2:]) + else: + # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] + dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + dq = dq.view(-1, *dq.shape[-3:]) + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + dk = dk.view(-1, *dk.shape[-3:]) + dv = dv.view(-1, *dv.shape[-3:]) + else: + # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] + dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) + + if ctx.qkv_format == "thd" and not ctx.use_fused_attention: + dq[cu_seqlens_q_padded[-1] :].fill_(0) + if ctx.enable_mla: + dk[cu_seqlens_kv_padded[-1] :].fill_(0) + dv[cu_seqlens_kv_padded[-1] :].fill_(0) + else: + dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) + + if ctx.fp8 and ctx.is_input_fp8: + assert torch.uint8 not in [dq.dtype, dkv.dtype] + if ctx.enable_mla: + dq, dk, dv = [ctx.dQKV_quantizer(x)._data for x in [dq, dk, dv]] + else: + dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] + if not ctx.enable_mla: + dk, dv = dkv[0], dkv[1] + + if cp_size_a2a > 1: + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) + dq, dk, dv = flash_attn_a2a_communicate( + [dq, dk, dv], + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + ctx.cp_group_a2a, + ctx.cp_stream, + False, + ) + if ctx.qkv_format == "bshd": + dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] + elif ctx.qkv_format == "sbhd": + dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + + if attn_dbias is not None: + # [b, np, sq, 2*cp, 
sk//(2*cp)] -> [b, np, sq, sk] + attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) + # converting torch.uint8 to float8tensor + if ctx.fp8 and ctx.is_input_fp8: + dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype) + dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype) + dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward") + + return ( + None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + attn_dbias, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def get_kv_seq_info_after_all_gather( + local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size, causal +): + """Compute KV sequence index range and update window size after all-gather.""" + local_chunk_end_idx = (local_chunk_id + 1) * max_seqlen_kv + full_seq_end_idx = max_seqlen_kv * cp_size * 2 + + if window_size is None: + window_size = (-1, 0) if causal else (-1, -1) + + if window_size[1] == -1: + seq_end_idx = full_seq_end_idx + window_size_right = -1 + else: + seq_end_idx = min(full_seq_end_idx, local_chunk_end_idx + window_size[1]) + window_size_right = local_chunk_end_idx + window_size[1] - seq_end_idx + + if window_size[0] == -1: + seq_start_idx = 0 + window_size_left = -1 + else: + seq_start_idx = max(0, local_chunk_end_idx - max_seqlen_q - window_size[0]) + window_size_left = window_size[0] + seq_end_idx - local_chunk_end_idx + + return (seq_start_idx, seq_end_idx), (window_size_left, window_size_right) + + +class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): + """ + Attention implementation with context parallelism. KV all-gather between CP ranks is exposed. + Refer section 3.3.2 of `The Llama 3 Herd of Models `_. 
+ """ + + @staticmethod + def forward( + ctx, + is_training, + q, + k, + v, + cu_seqlens_q, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + dropout_p, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + window_size, + cp_group, + cp_stream, + use_flash_attn_3, + ): + # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + cp_size = get_distributed_world_size(cp_group) + rank = get_distributed_rank(cp_group) + + qkv_dtype = q.dtype + + causal = "causal" in attn_mask_type + padding = "padding" in attn_mask_type + assert not padding, f"{attn_mask_type} mask type is not supported!" + if use_fused_attention and causal and "bottom_right" not in attn_mask_type: + attn_mask_type = attn_mask_type + "_bottom_right" + assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" + assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + assert ( + use_fused_attention or fa_utils.v2_3_plus + ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" 
+ + flash_attn_fwd = None + if not use_fused_attention: + fa_forward_kwargs = {"softmax_scale": softmax_scale} + if use_flash_attn_3: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_fwd_v3, + ) + + flash_attn_fwd = _flash_attn_fwd_v3 + else: + if qkv_format == "thd": + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_varlen_fwd, + ) + + flash_attn_fwd = _flash_attn_varlen_fwd + else: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_fwd, + ) + + flash_attn_fwd = _flash_attn_fwd + fa_forward_kwargs["dropout_p"] = dropout_p + fa_forward_kwargs["return_softmax"] = False + if fa_utils.v2_4_plus: + fa_forward_kwargs["alibi_slopes"] = None + if fa_utils.v2_5_7_plus and qkv_format == "thd": + fa_forward_kwargs["block_table"] = None + if fa_utils.v2_6_0_plus: + fa_forward_kwargs["softcap"] = 0.0 + + assert qkv_format != "thd", f"{qkv_format} format is not supported!" + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + + seq_dim = qkv_format.index("s") + assert ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" 
+ + max_seqlen_q = max_seqlen_q // (2 * cp_size) + max_seqlen_kv = max_seqlen_kv // (2 * cp_size) + if use_fused_attention or qkv_format == "thd": + cu_seqlens_q = cu_seqlens_q // (2 * cp_size) + if cu_seqlens_q_padded is not None and qkv_format == "thd": + cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) + else: + cu_seqlens_q_padded = None + + # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn] + q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) + # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn] + k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + + # [s, b, np, hn] -> [cp, s, b, np, hn] + k_ag, _ = gather_along_first_dim(k, cp_group) + v_ag, _ = gather_along_first_dim(v, cp_group) + + # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] + k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) + v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) + cp_stream.wait_stream(torch.cuda.current_stream()) + + # create two streams to resolve wave quantization issue of Flash Attn in each step + flash_attn_streams = [torch.cuda.current_stream(), cp_stream] + + local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] + kv_seq_range_per_step = [None, None] + window_size_per_step = [None, None] + cu_seqlens_kv_per_step = [None, None] + out_per_step = [None, None] + softmax_lse_per_step = [None, None] + rng_states = [None, None] + out = torch.empty_like(q) + + for i in range(len(local_seq_chunk_ids) + 1): + if i < len(local_seq_chunk_ids): + with torch.cuda.stream(flash_attn_streams[i]): + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + # or 
[2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_ = q.select(seq_dim, i).contiguous() + kv_seq_range_per_step[i], window_size_per_step[i] = ( + get_kv_seq_info_after_all_gather( + local_seq_chunk_ids[i], + cp_size, + max_seqlen_q, + max_seqlen_kv, + window_size, + causal, + ) + ) + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], + ) + max_seqlen_kv_ = seq_end_idx - seq_start_idx + if use_fused_attention or qkv_format == "thd": + cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( + k.shape[1], max_seqlen_kv_, k.device + ) + k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + if use_fused_attention: + out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv_, + cu_seqlens_q, + cu_seqlens_kv_per_step[i], + q_, + k_, + v_, + qkv_dtype, + tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], + window_size=window_size_per_step[i], + ) + else: + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv_, + ) + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_forward_kwargs["window_size"] = window_size_per_step[i] + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] + fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1] + fa_outputs = flash_attn_fwd( + q_, + k_, + v_, + *fa_forward_args_thd, + causal=causal, + 
**fa_forward_kwargs, + ) + if not fa_utils.v2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[3] + + if i > 0: + with torch.cuda.stream(flash_attn_streams[i - 1]): + if qkv_format == "bshd": + out[:, i - 1].copy_(out_per_step[i - 1]) + elif qkv_format == "sbhd": + out[i - 1].copy_(out_per_step[i - 1]) + + torch.cuda.current_stream().wait_stream(cp_stream) + + if use_fused_attention: + if qkv_format == "bshd": + out = out.view(out.shape[0], -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + out = out.view(-1, *out.shape[-3:]) + else: + out = out.view(-1, *out.shape[-2:]) + + ctx.save_for_backward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_q_padded, + *cu_seqlens_kv_per_step, + *out_per_step, + *softmax_lse_per_step, + *rng_states, + ) + + ctx.qkv_dtype = qkv_dtype + ctx.kv_seq_range_per_step = kv_seq_range_per_step + ctx.window_size_per_step = window_size_per_step + ctx.cp_group = cp_group + ctx.cp_stream = cp_stream + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.softmax_scale = softmax_scale + ctx.qkv_format = qkv_format + ctx.attn_bias_type = attn_bias_type + ctx.attn_mask_type = attn_mask_type + ctx.deterministic = deterministic + ctx.use_fused_attention = use_fused_attention + ctx.use_flash_attn_3 = use_flash_attn_3 + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") + return out + + @staticmethod + def backward(ctx, dout): + # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") + cp_size = get_distributed_world_size(ctx.cp_group) + rank = get_distributed_rank(ctx.cp_group) + + (*saved_tensors,) = ctx.saved_tensors + (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5] + cu_seqlens_kv_per_step = 
saved_tensors[5:7] + out_per_step = saved_tensors[7:9] + softmax_lse_per_step = saved_tensors[9:11] + rng_states = saved_tensors[11:13] + kv_seq_range_per_step = ctx.kv_seq_range_per_step + window_size_per_step = ctx.window_size_per_step + + seq_dim = ctx.qkv_format.index("s") + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + + dout = dout.view(q.shape) + dq = torch.empty_like(q) + dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device) + dv = torch.zeros_like(dk) + dq_per_step = [None, None] + dk_per_step = [None, None] + dv_per_step = [None, None] + + # create two streams to resolve wave quantization issue of Flash Attn in each step + flash_attn_streams = [torch.cuda.current_stream(), ctx.cp_stream] + # synchronize dkv update across steps + dkv_update_done = torch.cuda.Event() + + # [s, b, np, hn] -> [cp, s, b, np, hn] + k_ag, _ = gather_along_first_dim(k, ctx.cp_group) + v_ag, _ = gather_along_first_dim(v, ctx.cp_group) + + # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] + k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) + v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) + ctx.cp_stream.wait_stream(torch.cuda.current_stream()) + + local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] + + flash_attn_bwd = None + if not ctx.use_fused_attention: + fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} + if ctx.use_flash_attn_3: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_bwd_v3, + ) + + flash_attn_bwd = _flash_attn_bwd_v3 + fa_backward_kwargs["deterministic"] = ctx.deterministic + else: + 
if ctx.qkv_format == "thd": + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_varlen_bwd, + ) + + flash_attn_bwd = _flash_attn_varlen_bwd + else: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_bwd, + ) + + flash_attn_bwd = _flash_attn_bwd + fa_backward_kwargs["dropout_p"] = ctx.dropout_p + if fa_utils.v2_4_plus: + fa_backward_kwargs["alibi_slopes"] = None + if fa_utils.v2_4_1_plus: + fa_backward_kwargs["deterministic"] = ctx.deterministic + if fa_utils.v2_6_0_plus: + fa_backward_kwargs["softcap"] = 0.0 + + for i in range(len(local_seq_chunk_ids) + 1): + if i < len(local_seq_chunk_ids): + with torch.cuda.stream(flash_attn_streams[i]): + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_ = q.select(seq_dim, i).contiguous() + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], + ) + max_seqlen_kv = seq_end_idx - seq_start_idx + k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + out_ = out_per_step[i] + dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) + if ctx.use_fused_attention: + aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] + dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd( + ctx.max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv_per_step[i], + q_, + k_, + v_, + out_, + dout_, + ctx.qkv_dtype, + TE_DType[dout.dtype], + aux_ctx_tensors, + tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=ctx.attn_mask_type, + attn_bias_type=ctx.attn_bias_type, + window_size=window_size_per_step[i], + 
deterministic=ctx.deterministic, + ) + else: + dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ + torch.empty_like(x) for x in [q_, k_, v_] + ] + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + dq=dq_per_step[i], + dk=dk_per_step[i], + dv=dv_per_step[i], + ) + if not ctx.use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[i] + if ctx.use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): + fa_backward_kwargs["window_size"] = window_size_per_step[i] + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] + fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] + flash_attn_bwd( + dout_, + q_, + k_, + v_, + out_, + softmax_lse_per_step[i], + *fa_backward_args_thd, + causal="causal" in ctx.attn_mask_type, + **fa_backward_kwargs, + ) + + if i > 0: + with torch.cuda.stream(flash_attn_streams[i - 1]): + if ctx.qkv_format == "bshd": + dq[:, i - 1].copy_(dq_per_step[i - 1]) + elif ctx.qkv_format == "sbhd": + dq[i - 1].copy_(dq_per_step[i - 1]) + # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn] + dk_per_step[i - 1], dv_per_step[i - 1] = [ + x.movedim(seq_dim, 0).contiguous() + for x in [dk_per_step[i - 1], dv_per_step[i - 1]] + ] + # wait until dkv update of last step is done + if i > 1: + flash_attn_streams[i - 1].wait_event(dkv_update_done) + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i - 1][0], + kv_seq_range_per_step[i - 1][1], + ) + dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) + dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) + if i < len(local_seq_chunk_ids): + flash_attn_streams[i - 1].record_event(dkv_update_done) + + torch.cuda.current_stream().wait_stream(ctx.cp_stream) + + # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn] + dk = dk.view(2 * cp_size, 
-1, *dk.shape[-3:]) + dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) + dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) + dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + dk = dk.view(-1, *dk.shape[-3:]) + dv = dv.view(-1, *dv.shape[-3:]) + dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) + dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) + + dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) + dk = dk.movedim(0, seq_dim).contiguous() + dv = dv.movedim(0, seq_dim).contiguous() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") + + return ( + None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function): + """ + Attention implementation with context parallelism. Like Ulysses, applying A2A to QKVO. + Refer the paper `DeepSpeed Ulysses `_. + """ + + @staticmethod + def forward( + ctx, + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + window_size, + fp8, + fp8_meta, + cp_group, + cp_stream, + quantizers, + use_flash_attn_3, + ): + # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + cp_size = get_distributed_world_size(cp_group) + qkv_dtype = q.dtype + + causal = "causal" in attn_mask_type + padding = "padding" in attn_mask_type + assert not padding, f"{attn_mask_type} mask type is not supported!" 
+ assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" + assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" + assert ( + window_size == (-1, 0) + or window_size == (-1, -1) + or use_fused_attention + or fa_utils.v2_3_plus + ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" + + flash_attn_fwd = None + if not use_fused_attention: + fa_forward_kwargs = {"softmax_scale": softmax_scale} + if use_flash_attn_3: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_fwd_v3, + ) + + flash_attn_fwd = _flash_attn_fwd_v3 + fa_forward_kwargs["window_size"] = window_size + else: + if qkv_format == "thd": + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_varlen_fwd, + ) + + flash_attn_fwd = _flash_attn_varlen_fwd + else: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_fwd, + ) + + flash_attn_fwd = _flash_attn_fwd + fa_forward_kwargs["dropout_p"] = dropout_p + fa_forward_kwargs["return_softmax"] = False + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size"] = window_size + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = window_size[0] + fa_forward_kwargs["window_size_right"] = window_size[1] + if fa_utils.v2_4_plus: + fa_forward_kwargs["alibi_slopes"] = None + if fa_utils.v2_5_7_plus and qkv_format == "thd": + fa_forward_kwargs["block_table"] = None + if fa_utils.v2_6_0_plus: + fa_forward_kwargs["softcap"] = 0.0 + + assert ( + q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 + ), "The number of attention heads needs to be divisible by CP size!" + + assert qkv_format != "thd", f"{qkv_format} format is not supported!" 
+ qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + + batch_dim = qkv_format.index("b") + seq_dim = qkv_format.index("s") + assert ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" + + fused_attn_backend = None + # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype + is_input_fp8 = False + is_output_fp8 = False + + QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( + dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + ) + if fp8: + if use_fused_attention: + fused_attn_backend = FusedAttnBackend["FP8"] + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, and v must have the same type." + is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha + if is_input_fp8: + QKV_quantizer = q._quantizer + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = q_fp8._data, k_fp8._data, v_fp8._data + elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_f16, k_f16, v_f16 = q, k, v + q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] + fp8_meta_kwargs = {} + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer # partial result quantizer + else: + assert False, "FP8 is only supported with Fused Attention!" 
+ else: + if use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device) + q, k, v = flash_attn_a2a_communicate( + [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True + ) + + if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_f16, k_f16, v_f16 = q, k, v + q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] + + batch_size = q.shape[batch_dim] + if use_fused_attention: + q_part, k_part, v_part = q, k, v + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v, fake_dtype=qkv_dtype, internal=True + ) + out, aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q_part, + k_part, + v_part, + qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + window_size=window_size, + **fp8_meta_kwargs, + ) + if fp8: + out = out._data + else: + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + ) + fa_outputs = flash_attn_fwd( + q, + k, + v, + *fa_forward_args_thd, + causal=causal, + **fa_forward_kwargs, + ) + if not fa_utils.v2_7_0_plus: + out, softmax_lse = fa_outputs[4], fa_outputs[5] + rng_state = fa_outputs[7] if not use_flash_attn_3 else None + else: + out, softmax_lse = fa_outputs[0], fa_outputs[1] + rng_state = fa_outputs[3] if not use_flash_attn_3 
else None + aux_ctx_tensors = [softmax_lse, rng_state] + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device) + out = flash_attn_a2a_communicate( + out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False + ) + + if use_fused_attention: + if qkv_format == "bshd": + # [b*s, np, hn] -> [b, s, np, hn] + out = out.view(batch_size, -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + # [s*b, np, hn] -> [s, b, np, hn] + out = out.view(-1, batch_size, *out.shape[-2:]) + + if fp8: + if is_output_fp8: + out_fp8 = O_quantizer.create_tensor_from_data( + out, fake_dtype=qkv_dtype, internal=False + ) + out_ret = out_fp8 + out = out_fp8._data + else: + out_fp8 = O_quantizer.create_tensor_from_data( + out, fake_dtype=qkv_dtype, internal=True + ) + out_f16 = out_fp8.dequantize(dtype=qkv_dtype) + out_ret = out_f16 + else: + out_ret = out + + if not fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_save, k_save, v_save, out_save = q, k, v, out + else: + if is_input_fp8: + q_save, k_save, v_save = q, k, v + else: + q_save, k_save, v_save = q_f16, k_f16, v_f16 + if is_output_fp8: + out_save = out + else: + out_save = out_f16 + + tensors_to_save, tensor_objects = prepare_for_saving( + q_save, + k_save, + v_save, + out_save, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + *aux_ctx_tensors, + ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.batch_size = batch_size + ctx.cp_group = cp_group + ctx.cp_stream = cp_stream + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_kv = max_seqlen_kv + ctx.softmax_scale = softmax_scale + ctx.qkv_format = qkv_format + ctx.attn_mask_type = attn_mask_type + ctx.attn_bias_type = attn_bias_type + ctx.deterministic = deterministic + ctx.window_size = window_size + ctx.use_fused_attention = use_fused_attention + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.fp8_meta = fp8_meta + 
ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 + ctx.use_flash_attn_3 = use_flash_attn_3 + + ctx.qkv_dtype = qkv_dtype + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.S_quantizer = S_quantizer + if ctx.fp8: + ctx.QKV_quantizer = QKV_quantizer.copy() + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer = O_quantizer.copy() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() + ctx.S_quantizer.scale = S_quantizer.scale.clone() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") + return out_ret + + @staticmethod + def backward(ctx, dout): + # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") + cp_size = get_distributed_world_size(ctx.cp_group) + + ( + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + *aux_ctx_tensors, + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + causal = "causal" in ctx.attn_mask_type + seq_dim = ctx.qkv_format.index("s") + + dout_dtype = dout.dtype + fused_attn_backend = None + fused_attn_dqkv_dtype = None + if ctx.fp8: + if ctx.use_fused_attention: + fused_attn_backend = FusedAttnBackend["FP8"] + if ctx.is_output_fp8: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + ctx.dO_quantizer = dout._quantizer + else: + dout = ctx.dO_quantizer(dout) + fused_attn_dqkv_dtype = TE_DType[dout._data.dtype] + dout = dout._data + fp8_meta_kwargs = {} + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer + + else: + assert False, "FP8 is only supported with Fused Attention!" 
+ else: + if ctx.fp8_meta is not None: + if ctx.is_output_fp8: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + ctx.dO_quantizer = dout._quantizer + dout = dout._data + if ctx.is_input_fp8: + q = ctx.QKV_quantizer.create_tensor_from_data( + q, fake_dtype=ctx.qkv_dtype, internal=True + ) + k = ctx.QKV_quantizer.create_tensor_from_data( + k, fake_dtype=ctx.qkv_dtype, internal=True + ) + v = ctx.QKV_quantizer.create_tensor_from_data( + v, fake_dtype=ctx.qkv_dtype, internal=True + ) + q, k, v = [x.dequantize(dtype=ctx.qkv_dtype) for x in [q, k, v]] + if ctx.use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_dqkv_dtype = TE_DType[dout_dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + if not ctx.use_fused_attention: + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + dout = dout.view(*out.shape) + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, out.device) + out, dout = flash_attn_a2a_communicate( + [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True + ) + if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: + out = ctx.O_quantizer.create_tensor_from_data( + out, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout = ctx.dO_quantizer.create_tensor_from_data( + dout, fake_dtype=dout_dtype, internal=True + ) + out = out.dequantize(dtype=ctx.qkv_dtype) + dout = dout.dequantize(dtype=dout_dtype) + + flash_attn_bwd = None + if not ctx.use_fused_attention: + fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} + if ctx.use_flash_attn_3: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_bwd_v3, + ) + + flash_attn_bwd = ( + _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment + ) + fa_backward_kwargs["window_size"] = ctx.window_size + fa_backward_kwargs["deterministic"] = ctx.deterministic + else: + if ctx.qkv_format == "thd": + from 
transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_varlen_bwd, + ) + + flash_attn_bwd = _flash_attn_varlen_bwd + else: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_bwd, + ) + + flash_attn_bwd = _flash_attn_bwd + fa_backward_kwargs["dropout_p"] = ctx.dropout_p + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size"] = ctx.window_size + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = ctx.window_size[0] + fa_backward_kwargs["window_size_right"] = ctx.window_size[1] + if fa_utils.v2_4_plus: + fa_backward_kwargs["alibi_slopes"] = None + if fa_utils.v2_4_1_plus: + fa_backward_kwargs["deterministic"] = ctx.deterministic + if fa_utils.v2_6_0_plus: + fa_backward_kwargs["softcap"] = 0.0 + + if ctx.use_fused_attention: + q_part = q + k_part = k + v_part = v + out_part = out + dout_part = dout + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=dout_dtype, internal=True + ) + + dq, dk, dv, _ = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q_part, + k_part, + v_part, + out_part, + dout_part, + dout_dtype, + fused_attn_dqkv_dtype, + aux_ctx_tensors, + fused_attn_backend, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=ctx.attn_mask_type, + attn_bias_type=ctx.attn_bias_type, + 
window_size=ctx.window_size, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, + ) + if ctx.fp8: + dq = dq._data + dk = dk._data + dv = dv._data + else: + softmax_lse, rng_state = aux_ctx_tensors + dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq, + dk=dk, + dv=dv, + ) + if not ctx.use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_state + flash_attn_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + *fa_backward_args_thd, + causal=causal, + **fa_backward_kwargs, + ) + + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, q.device) + dq, dk, dv = flash_attn_a2a_communicate( + [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False + ) + + if ctx.qkv_format == "bshd": + dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] + elif ctx.qkv_format == "sbhd": + dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + + if ctx.fp8: + dq = ctx.dQKV_quantizer.create_tensor_from_data( + dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 + ) + dk = ctx.dQKV_quantizer.create_tensor_from_data( + dk, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 + ) + dv = ctx.dQKV_quantizer.create_tensor_from_data( + dv, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 + ) + if not ctx.is_input_fp8: + dq, dk, dv = [x.dequantize(dtype=dout_dtype) for x in [dq, dk, dv]] + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") + + return ( + None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def attn_forward_func_with_cp( + is_training, + q, + k, + v, + cu_seqlens_q, + 
cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + cp_group, + cp_global_ranks, + cp_stream, + cp_comm_type, + softmax_scale=None, + qkv_format="bshd", + attn_mask_type="causal", + attn_bias_type="no_bias", + attn_bias=None, + deterministic=False, + use_fused_attention=False, + window_size=None, + fp8=False, + fp8_meta=None, + quantizers=None, + pad_between_seqs=False, + use_flash_attn_3=False, +) -> torch.Tensor: + """ + Attention implementation with context parallelism (CP). CP partitions tensors along the sequence + dimension, and by reducing the memory and computational pressure on each GPU, it enables long-context + LLMs in a distributed fashion. Transformer Engine's PyTorch CP implementation currently utilizes + the DualChunkSwap strategy to ensure load balancing across CP ranks. It is applied to all `attn_mask_type`s + and all `qkv_format`s, and it requires sequence lengths to be, or are padded to be, divisible by + (cp_size * 2). It also requires tokens to be re-ordered before entering this function. + + For qkv_format = {'bshd', 'sbhd'}, the token re-ordering is illustrated as below, for an example + use case of s = 12, attn_mask_type = 'causal', and cp_size = 2. seq_pos indicates each token's position + in their corresponding sequence. 
+ + GPU0 | GPU1 GPU0 | GPU1 + seq_pos | 0 1 2 3 4 5 | 6 7 8 9 10 11 seq_pos | 0 1 2 9 10 11 | 3 4 5 6 7 8 + ---------------------------|----------------- ---------------------------|------------------ + 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + U 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 9 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 1, 1, + 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 10 | 1, 1, 1, 1, 1, 0,| 1, 1, 1, 1, 1, 1, + 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1, + ---------------------------|----------------- ---------------------------|------------------ + 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 3 | 1, 1, 1, 0, 0, 0,| 1, 0, 0, 0, 0, 0, + G 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 4 | 1, 1, 1, 0, 0, 0,| 1, 1, 0, 0, 0, 0, + P 8 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 0, 0, 0, P 5 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 0, 0, 0, + U 9 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 0, 0, U 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + 1 10 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 0, 1 7 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 0, + 11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1, 8 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 1, + + For qkv_format = 'thd', multiple sequences may be packed into the batch, and they may be of different + lengths. DualChunkSwap divides each sequence into (cp_size * 2) chunks and distributes 2 chunks of + every sequence onto a CP rank. The token matrix transformation is shown as follows, for an example of + batch_size = 2, seq_ids = [0, 1], seq_lens = [8, 4], t = 12, attn_mask_type = 'padding_causal', and + cp_size = 2. 
+ + GPU0 | GPU1 GPU0 | GPU1 + seq_id | 0 0 0 0 0 0 | 0 0 1 1 1 1 seq_id | 0 0 0 0 1 1 | 0 0 0 0 1 1 + seq_pos | 0 1 2 3 4 5 | 6 7 0 1 2 3 seq_pos | 0 1 6 7 0 3 | 2 3 4 5 1 2 + ---------------------------|----------------- ---------------------------|------------------ + 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + P 0 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 0 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + U 0 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 0 7 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 0, 0, + 0 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 1 0 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 0, 0, + 0 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 1 3 | 0, 0, 0, 0, 2, 2,| 0, 0, 0, 0, 2, 2, + ---------------------------|----------------- ---------------------------|------------------ + 0 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 0 2 | 1, 1, 0, 0, 0, 0,| 1, 0, 0, 0, 0, 0, + G 0 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 0 3 | 1, 1, 0, 0, 0, 0,| 1, 1, 0, 0, 0, 0, + P 1 0 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 0, 0, 0 P 0 4 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 0, 0, 0, + U 1 1 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 0, 0 U 0 5 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + 1 1 2 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 0 1 1 1 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 0, + 1 3 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 2 1 2 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 2, + + When all transformer layers in a model share the same CP configuration, i.e. cp_group, cp_global_ranks, + cp_comm_type and cp_stream, token re-ordering can take place in the dataloader, i.e. only once for + all the layers. An example of the re-ordering code is `get_batch_on_this_cp_rank + `_ + in Megatron-LM. + + """ + + if cp_comm_type == "a2a+p2p": + assert isinstance( + cp_group, list + ), "Hierarchical CP implementation needs multi-level CP groups!" + assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" 
+ if get_distributed_world_size(cp_group[0]) == 1: + cp_group = cp_group[1] + cp_comm_type = "p2p" + elif get_distributed_world_size(cp_group[1]) == 1: + cp_group = cp_group[0] + cp_comm_type = "a2a" + else: + assert isinstance( + cp_group, dist_group_type + ), f"Unsupported process group for CP communication type {cp_comm_type}!" + + assert qkv_format in [ + "bshd", + "sbhd", + "thd", + ], f"QKV format of {qkv_format} is not supported with context parallelism!" + assert ( + qkv_format != "sbhd" or use_fused_attention + ), "FlashAttention does not support sbhd format!" + assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), ( + """Attention bias is only supported with FusedAttention and "causal" """ + """or "no_mask" mask types!""" + ) + assert qkv_format != "thd" or ( + cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None + ), "cu_seqlens_padded cannot be None with context parallelism + THD format!" + + sliding_window_attn = ( + window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) + ) + assert not sliding_window_attn or cp_comm_type in [ + "a2a", + "all_gather", + ], "The context parallel running configs cannot support sliding window attetnion!" + + enable_mla = k.shape[-1] != v.shape[-1] + assert not enable_mla or cp_comm_type in [ + "p2p", + "a2a+p2p", + ], "The context parallel running configs cannot support MLA!" 
+ + args = [ + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + ] + + if cp_comm_type in ["p2p", "a2a+p2p"]: + args += [ + fp8, + fp8_meta, + cp_group, + cp_global_ranks, + cp_stream, + quantizers, + pad_between_seqs, + use_flash_attn_3, + ] + out = AttnFuncWithCPAndKVP2P.apply(*args) + elif cp_comm_type == "all_gather": + args.pop(5) + args.pop(8) + args += [window_size, cp_group, cp_stream, use_flash_attn_3] + out = AttnFuncWithCPAndKVAllGather.apply(*args) + elif cp_comm_type == "a2a": + args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_flash_attn_3] + out = AttnFuncWithCPAndQKVOA2A.apply(*args) + else: + raise ValueError(f"Unsupported communication type: {cp_comm_type}!") + + return out diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel_nvshmem_enhanced.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel_nvshmem_enhanced.py new file mode 100644 index 0000000000..80fb939722 --- /dev/null +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel_nvshmem_enhanced.py @@ -0,0 +1,3223 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Context Parallelism.""" +import os +from typing import List, Union +import torch +import transformer_engine_torch as tex +import nvshmem.core as nvshmem +import torch.distributed as dist +from cuda.core.experimental import Device +from cuda.core.experimental import Stream +import numpy as np +import cupy as cp +import torch.cuda.nvtx as nvtx +from transformer_engine.pytorch.utils import ( + combine_tensors, + get_cudnn_version, + nvtx_range_pop, + nvtx_range_push, + get_device_compute_capability, +) +from transformer_engine.pytorch.cpp_extensions.fused_attn import ( + fused_attn_fwd, + fused_attn_bwd, + FusedAttnBackend, +) +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.jit import jit_fuser +from transformer_engine.pytorch.constants import ( + dist_group_type, + TE_DType, +) +from transformer_engine.pytorch.distributed import ( + get_distributed_world_size, + get_distributed_rank, + gather_along_first_dim, + reduce_scatter_along_first_dim, +) +from transformer_engine.pytorch.tensor.quantized_tensor import ( + prepare_for_saving, + restore_from_saved, +) + +# Import attention utils +import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils +from transformer_engine.pytorch.attention.dot_product_attention.utils import ( + FlashAttentionUtils as fa_utils, +) + +_cu_seqlens_info_with_cp_cache = {} +_seq_chunk_ids_cache_for_reordering_before_attn = {} +_seq_chunk_ids_cache_for_reordering_after_attn = {} + + +def flash_attn_p2p_communicate( + rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm +): + """Point-to-point communications of KV and dKV in Attention with context parallelism""" + send_recv_ops = [] + + if batch_p2p_comm: + if rank % 2 == 0: + send_op = torch.distributed.P2POp( + torch.distributed.isend, send_tensor, send_dst, cp_group + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_tensor, recv_src, cp_group + ) + 
send_recv_ops.append(send_op) + send_recv_ops.append(recv_op) + else: + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, recv_tensor, recv_src, cp_group + ) + send_op = torch.distributed.P2POp( + torch.distributed.isend, send_tensor, send_dst, cp_group + ) + send_recv_ops.append(recv_op) + send_recv_ops.append(send_op) + send_recv_reqs = torch.distributed.batch_isend_irecv(send_recv_ops) + else: + if rank % 2 == 0: + send_op = torch.distributed.isend(send_tensor, send_dst, cp_group) + recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group) + send_recv_ops.append(send_op) + send_recv_ops.append(recv_op) + else: + recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group) + send_op = torch.distributed.isend(send_tensor, send_dst, cp_group) + send_recv_ops.append(recv_op) + send_recv_ops.append(send_op) + send_recv_reqs = send_recv_ops + + return send_recv_reqs + +def nvshmem_get_on_stream(dst_tensor: torch.Tensor, src_tensor: torch.Tensor, + peer: int, stream: Stream = None) -> None: + if stream is None: + stream = Stream.from_handle(torch.cuda.current_stream().cuda_stream) + + nvshmem.get(dst_tensor, src_tensor, remote_pe=peer, stream=stream) + +def torchrun_uid_init_bcast_object_no_reinit(cp_group=None): + local_rank = torch.cuda.current_device() + dev = Device(local_rank) + dev.set_current() + + if cp_group is None: + rank_id = dist.get_rank() + num_ranks = dist.get_world_size() + else: + rank_id = dist.get_rank(group=cp_group) + num_ranks = dist.get_world_size(group=cp_group) + + uniqueid = nvshmem.get_unique_id(empty=True) + + if rank_id == 0: + uniqueid = nvshmem.get_unique_id() + broadcast_objects = [uniqueid] + else: + broadcast_objects = [None] + + dist.broadcast_object_list( + broadcast_objects, + src=0, + group=cp_group + ) + + dist.barrier(group=cp_group) + + nvshmem.init( + device=dev, + uid=broadcast_objects[0], + rank=rank_id, + nranks=num_ranks, + initializer_method="uid" + ) + + return True + +def 
torchrun_uid_init_bcast_object_no_reinit(cp_group=None): + # Ideally the CUDA device has already been set outside this function + local_rank = torch.cuda.current_device() + dev = Device(local_rank) + dev.set_current() + + # Do not call init_process_group again here !!! + + if cp_group is None: + rank_id = dist.get_rank() + num_ranks = dist.get_world_size() + else: + rank_id = dist.get_rank(group=cp_group) + num_ranks = dist.get_world_size(group=cp_group) + + uniqueid = nvshmem.get_unique_id(empty=True) + + if rank_id == 0: + uniqueid = nvshmem.get_unique_id() + broadcast_objects = [uniqueid] + else: + broadcast_objects = [None] + + dist.broadcast_object_list( + broadcast_objects, + src=0, + group=cp_group + ) + + dist.barrier(group=cp_group) + + nvshmem.init( + device=dev, + uid=broadcast_objects[0], + rank=rank_id, + nranks=num_ranks, + initializer_method="uid" + ) + + return True + +@jit_fuser +def flash_attn_fwd_out_correction_init( + out_init_step: torch.Tensor, + softmax_lse: torch.Tensor, + softmax_lse_init_step: torch.Tensor, + seq_dim: int, +): + """Merge partial outputs of the first step in Attention with context parallelism""" + softmax_lse_corrected_exp = torch.exp(softmax_lse_init_step - softmax_lse).movedim(2, seq_dim) + softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) + out_corrected = out_init_step * softmax_lse_corrected_exp + return out_corrected.to(out_init_step.dtype) + + +@jit_fuser +def flash_attn_fwd_out_correction( + out: torch.Tensor, + out_per_step: torch.Tensor, + softmax_lse: torch.Tensor, + softmax_lse_per_step: torch.Tensor, + seq_dim: int, +): + """Merge partial outputs of each step in Attention with context parallelism""" + softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) + softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) + out_corrected = out_per_step * softmax_lse_corrected_exp + out.add_(out_corrected) + + +@jit_fuser +def flash_attn_fwd_second_half_out_correction( + out: torch.Tensor, + out_per_step: torch.Tensor, + 
softmax_lse: torch.Tensor, + softmax_lse_per_step: torch.Tensor, + seq_dim: int, +): + """Merge second half of partial outputs of each step in Attention with context parallelism""" + out_ = out.select(seq_dim, 1) + softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1)[..., 1, :] + softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse_).movedim(2, seq_dim) + softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) + out_corrected = out_per_step * softmax_lse_corrected_exp + out_.add_(out_corrected) + + +@jit_fuser +def flash_attn_fwd_softmax_lse_correction( + softmax_lse: torch.Tensor, + softmax_lse_per_step: torch.Tensor, +): + """Merge softmax stats of each step in Attention with context parallelism""" + max_scale = torch.max(softmax_lse, softmax_lse_per_step) + min_scale = torch.min(softmax_lse, softmax_lse_per_step) + new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale)) + softmax_lse.copy_(new_scale) + + +@jit_fuser +def flash_attn_fwd_second_half_softmax_lse_correction( + softmax_lse: torch.Tensor, + softmax_lse_per_step: torch.Tensor, +): + """Merge second half of softmax stats of each step in Attention with context parallelism""" + softmax_lse_ = softmax_lse[..., 1, :] + max_scale = torch.max(softmax_lse_, softmax_lse_per_step) + min_scale = torch.min(softmax_lse_, softmax_lse_per_step) + new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale)) + softmax_lse_.copy_(new_scale) + + +@jit_fuser +def get_cu_seqlens_on_cp_rank( + cu_seqlens: torch.Tensor, + cu_seqlens_padded_on_cp_rank: torch.Tensor, + cp_size: int, + cp_rank: int, + first_half: bool, + second_half: bool, +): + """Compute cu_seqlens of a context parallelism rank""" + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + seqlens_padded = (cu_seqlens_padded_on_cp_rank[1:] - cu_seqlens_padded_on_cp_rank[:-1]) // 2 + zeros = torch.zeros_like(seqlens) + cu_seqlens_on_cp_rank = torch.zeros_like(cu_seqlens) + if first_half: + seqlens_1 = 
seqlens - cp_rank * seqlens_padded + seqlens_1 = seqlens_1.clamp(zeros, seqlens_padded) + cu_seqlens_on_cp_rank[1:].add_(seqlens_1) + if second_half: + seqlens_2 = seqlens - (2 * cp_size - cp_rank - 1) * seqlens_padded + seqlens_2 = seqlens_2.clamp(zeros, seqlens_padded) + cu_seqlens_on_cp_rank[1:].add_(seqlens_2) + cu_seqlens_on_cp_rank.cumsum_(dim=0) + return cu_seqlens_on_cp_rank + + +@jit_fuser +def get_seq_chunk_ids_for_reordering_before_attn(cp_size, device): + """ + Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. + To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks to + be contiguous before attention compute. This function is to compute sequence chunk ids for + reordering. + """ + global _seq_chunk_ids_cache_for_reordering_before_attn + if (cp_size, device) not in _seq_chunk_ids_cache_for_reordering_before_attn: + chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) + for rank in range(cp_size): + chunk_ids[rank] = 2 * rank + chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 + _seq_chunk_ids_cache_for_reordering_before_attn[(cp_size, device)] = chunk_ids + return _seq_chunk_ids_cache_for_reordering_before_attn[(cp_size, device)] + + +@jit_fuser +def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device): + """ + Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. + We need to reorder sequence chunks back to discontiguous after attention compute. This function + is to compute sequence chunk ids for reordering.
+ """ + global _seq_chunk_ids_cache_for_reordering_after_attn + if (cp_size, device) not in _seq_chunk_ids_cache_for_reordering_after_attn: + chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) + for rank in range(cp_size): + chunk_ids[2 * rank] = rank + chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1 + _seq_chunk_ids_cache_for_reordering_after_attn[(cp_size, device)] = chunk_ids + return _seq_chunk_ids_cache_for_reordering_after_attn[(cp_size, device)] + + +def get_p2p_buffer_index_for_kv(fast_rank: int, slow_rank: int, slow_round: int, cp_size: int) -> int: + """Map a (fast_rank, slow_rank, slow_round) -> index in fast_rank's p2p_comm_buffers. + + Rationale: + - At runtime each rank keeps an array `p2p_comm_buffers` where index `j` holds the + KV chunk owned by owner = (fast_rank - j) % cp_size. + - When a slow rank wants the KV chunk it needs owner = (slow_rank - slow_round) % cp_size. + - So we solve for j: (fast_rank - j) % cp_size == owner -> j == (fast_rank - owner) % cp_size + - Substituting owner gives j = (fast_rank - (slow_rank - slow_round)) % cp_size + + Arguments: + fast_rank: the rank index of the 'fast' GPU (0..cp_size-1) + slow_rank: the rank index of the 'slow' GPU that needs the KV chunk + slow_round: the round index on the slow GPU (i in your description), integer >= 0 + cp_size: context-parallel group size + + Returns: + index into `p2p_comm_buffers` on the fast_rank that contains the requested KV chunk. 
Example:
        cp_size=4, slow_rank=2, slow_round=0 -> owner=(2-0)%4=2
        if fast_rank=3 -> j=(3-2+0)%4=1 -> p2p_comm_buffers[1] on rank 3 holds owner 2's KV
    """
    if cp_size <= 0:
        raise ValueError("cp_size must be > 0")
    # Normalize inputs so out-of-range ranks/rounds wrap into [0, cp_size).
    fast_rank_mod = int(fast_rank) % cp_size
    slow_rank_mod = int(slow_rank) % cp_size
    slow_round_mod = int(slow_round) % cp_size
    # Owner of the KV chunk the slow rank needs at this round.
    owner = (slow_rank_mod - slow_round_mod) % cp_size
    # Buffer slot on fast_rank currently holding that owner's KV chunk.
    idx = (fast_rank_mod - owner) % cp_size
    return idx


@jit_fuser
def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size):
    """Reorder sequence chunk for A2A communication before attention compute."""
    # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn]
    # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn]
    x = x.movedim(0, seq_dim).contiguous()
    # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
    # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
    x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :])
    # reorder the sequence chunks
    x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a)
    return x


@jit_fuser
def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size):
    """Reorder sequence chunk for A2A communication after attention compute."""
    # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn]
    # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
    x = x.movedim(seq_dim, 0).contiguous()
    # reorder the sequence chunks
    x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a)
    # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn]
    # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn]
    x = x.view(cp_size, 2, *x.shape[1:])
    return x


def flash_attn_a2a_communicate(
    a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],
    chunk_ids_for_a2a: torch.Tensor,
    seq_dim: int,
    cp_size: int,
    cp_group: dist_group_type,
    cp_stream: torch.cuda.Stream,
    before_attn: bool,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """A2A communication for context parallelism.

    Software-pipelined over the input tensors: at pipeline step ``i`` this launches the
    async all-to-all for tensor ``i-1``, post-processes (on ``cp_stream``) the result of
    tensor ``i-2``, and pre-processes (reshapes) tensor ``i``, overlapping communication
    with the reshape/reorder work. Hence the ``len(a2a_inputs) + 2`` loop trip count.
    """
    a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs
    a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs)
    if before_attn:
        for i in range(len(a2a_inputs) + 2):
            if 0 < i < len(a2a_inputs) + 1:
                # Stage 2: launch async A2A for the tensor prepared in the previous step.
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            if i > 1:
                # Stage 3: on the side stream, wait for the A2A launched two steps ago
                # and reorder/flatten its result.
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # reorder the sequence chunks
                    x = reorder_seq_chunks_for_a2a_before_attn(
                        x, chunk_ids_for_a2a, seq_dim, cp_size
                    )
                    # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn]
                    # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn]
                    a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :])
            if i < len(a2a_inputs):
                # Stage 1: split heads across CP ranks and move the cp dim to the front.
                x = a2a_inputs[i]
                # [b, s, np, hn] -> [b, s, cp, np//cp, hn]
                # or [s, b, np, hn] -> [s, b, cp, np//cp, hn]
                x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1])
                # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn]
                # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn]
                a2a_inputs[i] = x.movedim(-3, 0).contiguous()
    else:
        for i in range(len(a2a_inputs) + 2):
            if 0 < i < len(a2a_inputs) + 1:
                # Stage 2: launch async A2A for the tensor prepared in the previous step.
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = torch.distributed.all_to_all_single(
                    a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True
                )
            if i < len(a2a_inputs):
                # Stage 1: split the sequence back into 2*cp chunks and restore CP order.
                x = a2a_inputs[i]
                # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn]
                # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn]
                x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :])
                # reorder the sequence chunks
                a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn(
                    x, chunk_ids_for_a2a, seq_dim, cp_size
                )
            if i > 1:
                # Stage 3: on the side stream, wait and merge heads back into one dim.
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    x = a2a_outputs[i - 2]
                    # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn]
                    # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn]
                    x = x.movedim(0, -3).movedim(0, seq_dim).contiguous()
                    # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn]
                    # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn]
                    a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1])
    # Make the main stream wait for all side-stream post-processing before returning.
    torch.cuda.current_stream().wait_stream(cp_stream)
    return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs


def _get_cu_seqlens_info_with_cp(
    batch_size: int,
    max_seqlen: int,
    cp_size: int,
    cu_seqlens: torch.Tensor,
):
    """Cumulative sequence lengths with CP being considered.

    Returns a pair: (cu_seqlens // cp_size, cu_seqlens // (cp_size * 2)), i.e. the
    per-rank offsets for full-chunk and half-chunk attention steps.

    NOTE(review): the cache key is (batch_size, max_seqlen, cp_size) and does not include
    the values of ``cu_seqlens`` — this assumes the same key always comes with identical
    cu_seqlens contents; verify against callers before reusing this helper elsewhere.
    """
    global _cu_seqlens_info_with_cp_cache
    if (batch_size, max_seqlen, cp_size) not in _cu_seqlens_info_with_cp_cache:
        _cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)] = (
            cu_seqlens // cp_size,
            cu_seqlens // (cp_size * 2),
        )
    return _cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)]


def get_fa_args(
    forward: bool,
    use_flash_attn_3: bool,
    qkv_format: str,
    cu_seqlens_q=None,
    cu_seqlens_kv=None,
    max_seqlen_q=None,
    max_seqlen_kv=None,
    dq=None,
    dk=None,
    dv=None,
):
    """Get forward/backward arguments for flash-attn v2 and v3.

    Returns a positional argument list; the ``None`` placeholders must line up
    exactly with the flash-attn API's parameter order, so do not reorder them.
    """
    if use_flash_attn_3:
        if forward:
            if qkv_format == "thd":
                return [
                    *[None] * 4,  # k_new, v_new, qv, out
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    *[None] * 3,  # cu_seqlens_k_new, seqused_q, seqused_k
                    max_seqlen_q,
                    max_seqlen_kv,
                    *[None]
                    * 8,  # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
                ]
            return [
                *[None]
                * 9,  # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k
                max_seqlen_q,
                max_seqlen_kv,
                *[None]
                * 8,  # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
            ]
        if qkv_format == "thd":
return [ + cu_seqlens_q, + cu_seqlens_kv, + None, # sequed_q + None, # sequed_k + max_seqlen_q, + max_seqlen_kv, + dq, + dk, + dv, + ] + return [ + None, # cu_seqlens_q + None, # cu_seqlens_kv + None, # sequed_q + None, # sequed_k + max_seqlen_q, + max_seqlen_kv, + dq, + dk, + dv, + ] + if forward: + if qkv_format == "thd": + return [ + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ] + return [] + if qkv_format == "thd": + return [ + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ] + return [ + dq, + dk, + dv, + ] + + +class AttnFuncWithCPAndKVP2P(torch.autograd.Function): + """ + Attention implementation with context parallelism. Exchange KV between CP ranks + with P2P in ring topology. Split attention compute into multiple steps, and overlap + current-step compute with next-step communication. + + This implementation also supports hierarchical CP, which parallelizes attention + heads in low-level CP groups and parallelizes sequence dimension in high-level CP + groups. For more details, please refer to `LongVILA `_ + and `USP `_. + """ + + @staticmethod + def forward( + ctx, + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + fp8, + fp8_meta, + cp_group, + cp_global_ranks, + cp_stream, + quantizers, + pad_between_seqs, + use_flash_attn_3, + ): + # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") + enable_mla = k.shape[-1] != v.shape[-1] + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if isinstance(cp_group, list): + assert ( + qkv_format != "thd" + ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!" 
+ assert attn_bias_type == "no_bias", ( + f"{attn_bias_type} bias type is not supported with hierarchical CP implementation" + " yet!" + ) + cp_group_a2a = cp_group[0] + cp_size_a2a = get_distributed_world_size(cp_group_a2a) + rank_a2a = get_distributed_rank(cp_group_a2a) + cp_group = cp_group[1] + else: + cp_group_a2a = None + cp_size_a2a = 1 + rank_a2a = 0 + + cp_size = get_distributed_world_size(cp_group) + rank = get_distributed_rank(cp_group) + send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] + recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] + device_compute_capability = get_device_compute_capability() + batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or ( + device_compute_capability < (10, 0) and cp_size == 2 + ) + + causal = "causal" in attn_mask_type + padding = "padding" in attn_mask_type + + batch_dim = None + seq_dim = None + cu_seqlens_q_half, cu_seqlens_kv_half = None, None + if qkv_format in ["bshd", "sbhd"]: + seq_dim = qkv_format.index("s") + if enable_mla: + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + else: + qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] + cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None + if use_fused_attention: + batch_dim = qkv_format.index("b") + cu_seqlens_q, cu_seqlens_q_half = _get_cu_seqlens_info_with_cp( + q.shape[batch_dim], max_seqlen_q, cp_size, cu_seqlens_q + ) + cu_seqlens_kv, cu_seqlens_kv_half = _get_cu_seqlens_info_with_cp( + q.shape[batch_dim], max_seqlen_kv, cp_size, cu_seqlens_kv + ) + else: + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size + cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size + + max_seqlen_q = max_seqlen_q // cp_size + max_seqlen_kv = max_seqlen_kv // cp_size + cu_seqlens_q_per_step = [None for _ in range(cp_size)] + cu_seqlens_kv_per_step = [None for _ in range(cp_size)] + + fused_attn_backend = None + 
qkv_dtype = q.dtype + amax_per_step = None + S_quantizer_per_step = [None for _ in range(cp_size)] + O_CP_quantizer_per_step = [None for _ in range(cp_size)] + # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype + is_input_fp8 = False + is_output_fp8 = False + + ( + QKV_quantizer, + O_quantizer, + O_CP_quantizer, + S_quantizer, + dQKV_quantizer, + dQKV_CP_quantizer, + dO_quantizer, + dP_quantizer, + ) = dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True) + + if fp8: + if use_fused_attention: + fused_attn_backend = FusedAttnBackend["FP8"] + + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, and v must have the same type." + is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha + if is_input_fp8: + QKV_quantizer = q._quantizer + q, k, v = q._data, k._data, v._data + else: + q_f16, k_f16, v_f16 = q, k, v + if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q = QKV_quantizer(q_f16)._data + if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]] + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + # partial result quantizer + for i in range(cp_size): + S_quantizer_per_step[i] = S_quantizer.copy() + S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + O_CP_quantizer_per_step[i] = O_CP_quantizer.copy() + O_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + else: + assert False, "FP8 is only supported with Fused Attention!" 
+ else: + q_f16 = q + if use_fused_attention: + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + if cp_size_a2a > 1: + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) + + q, k, v = flash_attn_a2a_communicate( + [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True + ) + if not fp8: + q_f16 = q + elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_f16 = q + q = QKV_quantizer(q_f16)._data + + assert qkv_format == "thd" or ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" + if causal: + if qkv_format == "bshd": + # [b, s, np, hn] -> [b, 2, s//2, np, hn] + q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]] + elif qkv_format == "sbhd": + # [s, b, np, hn] -> [2, s//2, b, np, hn] + q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]] + if attn_bias is not None: + assert len(attn_bias.shape) == 4, ( + "Only support bias shape of [b, h, sq, sk] for forward, " + "and [1, h, sq, sk] for backward!" + ) + assert ( + attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0 + ), "Sequence length does not meet divisible requirements!" 
+ # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] + attn_bias_ = attn_bias.view( + *attn_bias.shape[:-2], + 2, + attn_bias.shape[-2] // 2, + 2 * cp_size, + attn_bias.shape[-1] // (2 * cp_size), + ) + # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)] + attn_bias = attn_bias.view( + *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) + ) + assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" + + softmax_lse_in_packed_format = False + if qkv_format == "thd": + if use_fused_attention: + softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0) + else: + softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3 + + flash_attn_fwd = None + if not use_fused_attention: + fa_forward_kwargs = {"softmax_scale": softmax_scale} + if use_flash_attn_3: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_fwd_v3, + ) + + flash_attn_fwd = ( + _flash_attn_fwd_v3 # pylint: disable=possibly-used-before-assignment + ) + fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) + else: + if qkv_format == "thd": + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_varlen_fwd, + ) + + flash_attn_fwd = _flash_attn_varlen_fwd + else: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_fwd, + ) + + flash_attn_fwd = _flash_attn_fwd + fa_forward_kwargs["dropout_p"] = dropout_p + fa_forward_kwargs["return_softmax"] = False + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = 0 if causal else -1 + if fa_utils.v2_4_plus: + fa_forward_kwargs["alibi_slopes"] = None + if fa_utils.v2_5_7_plus and qkv_format == "thd": + fa_forward_kwargs["block_table"] = None + if fa_utils.v2_6_0_plus: + 
fa_forward_kwargs["softcap"] = 0.0 + + # Flash Attn inputs + q_inputs = [None, None] + kv_inputs = [None, None] + attn_bias_inputs = [None, None] + # Flash Attn outputs + out_per_step = [None for _ in range(cp_size)] + softmax_lse_per_step = [None for _ in range(cp_size)] + rng_states = [None for _ in range(cp_size)] + attn_biases = [None for _ in range(cp_size)] + # NVSHMEM storage for FlashAttention outputs (per-step) + nvshmem_fa_out = None + out_for_slow = None + nvshmem_fa_softmax_lse = None + softmax_lse_for_slow = None + + # NVSHMEM global storage for a full-attention output (non-stepwise path) + nvshmem_fa_out_global = None + nvshmem_fa_softmax_lse_global = None + + + + + def _store_fa_nvshmem(out_tensor, softmax_tensor): + # allocate symmetric tensors lazily and copy + # if not causal: + # return + nvshmem_fa_out = nvshmem.tensor((cp_size, *out_tensor.shape), out_tensor.dtype) + out_for_slow = nvshmem.tensor((cp_size, *out_tensor.shape), out_tensor.dtype) + nvshmem_fa_softmax_lse = nvshmem.tensor((cp_size, *softmax_tensor.shape), softmax_tensor.dtype) + softmax_lse_for_slow = nvshmem.tensor((cp_size, *softmax_tensor.shape), softmax_tensor.dtype) + return nvshmem_fa_out, out_for_slow, nvshmem_fa_softmax_lse, softmax_lse_for_slow + + + + # create two streams to resolve wave quantization issue of Flash Attn in each step + flash_attn_streams = [torch.cuda.current_stream(), cp_stream] + # synchronize fwd results correction across steps + fwd_results_correction_done = torch.cuda.Event() + + init_ok = torchrun_uid_init_bcast_object_no_reinit(cp_group) + p2p_comm_buffers = [None for _ in range(cp_size)] + if enable_mla: + # If MLA, the shape of k and v does not match, so we flatten them + # and split them after receiving them. 
+ k_shape = k.shape + k_numel = k.numel() + v_shape = v.shape + buffer = torch.cat((k.view(-1), v.view(-1)), dim=-1) + p2p_comm_buffers[0] = nvshmem.tensor(list(buffer.shape), dtype=buffer.dtype) + p2p_comm_buffers[0].copy_(buffer) + elif qkv_format in ["bshd", "sbhd"]: + # p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) + buffer = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) + p2p_comm_buffers[0] = nvshmem.tensor(list(buffer.shape), dtype=buffer.dtype) + p2p_comm_buffers[0].copy_(buffer) + else: # qkv_format == "thd" + # p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) + buffer = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) + p2p_comm_buffers[0] = nvshmem.tensor(list(buffer.shape), dtype=buffer.dtype) + p2p_comm_buffers[0].copy_(buffer) + send_recv_reqs = [[], []] + + # Create a symmetric NVSHMEM tensor to hold this rank's KV chunk. + # nvshmem_kv = tex.create_nvshmem_tensor(list(p2p_comm_buffers[0].shape), p2p_comm_buffers[0].dtype) + nvshmem_kv = nvshmem.tensor(list(p2p_comm_buffers[0].shape), dtype=p2p_comm_buffers[0].dtype) + # copy local KV into symmetric heap so remote ranks can fetch it + nvshmem_kv.copy_(p2p_comm_buffers[0]) + + # nvshmem_q = tex.create_nvshmem_tensor(list(q.shape), q.dtype) + nvshmem_q = nvshmem.tensor(list(q.shape), dtype=q.dtype) + nvshmem_q.copy_(q) + # Store slow rank's q in nvshmem + help_comm_buffer = nvshmem.tensor(list(q.shape), dtype=q.dtype) + + # NVSHMEM arrays for counting each rank's completed steps + count_array = nvshmem.array((cp_size, ), dtype=cp.int32) + # total_count_array = nvshmem.array((1,), dtype=cp.int32) + count_array[:] = cp.zeros(cp_size, dtype=cp.int32) + count_gather_array = nvshmem.array((cp_size * cp_size,), dtype=cp.int32) + device = Device() + communicate_stream = device.create_stream() + + out = None + for i in range(cp_size + 1): + if i < cp_size: + with torch.cuda.stream(flash_attn_streams[i % 2]): + # wait until KV is received + # for 
req in send_recv_reqs[(i + 1) % 2]: + # req.wait() + + if i < (cp_size - 1): + p2p_comm_buffers[i + 1] = nvshmem.tensor(list(p2p_comm_buffers[i].shape), dtype=p2p_comm_buffers[i].dtype) + # if nvshmem_kv is not None: + # Use NVSHMEM get: compute owner of the (i+1)-th step KV block + owner_idx = (rank - (i + 1)) % cp_size + # Map owner idx to global rank (accounting for a2a groups) + owner_global = cp_global_ranks[owner_idx * cp_size_a2a + rank_a2a] + # nvshmem_get: dst (local buffer), src (symmetric address), peer=owner_global + nvshmem_get_on_stream(p2p_comm_buffers[i + 1], nvshmem_kv, owner_global, stream=communicate_stream) + # else: + # # fallback to P2P if NVSHMEM not available + # send_recv_reqs[i % 2] = flash_attn_p2p_communicate( + # rank, + # p2p_comm_buffers[i], + # send_dst, + # p2p_comm_buffers[i + 1], + # recv_src, + # cp_group, + # batch_p2p_comm, + # ) + + if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + kv_inputs[i % 2] = p2p_comm_buffers[i] + else: + # KV exchange is in BF16/FP16, cast received KV in each step + kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data + if enable_mla: + # If MLA, k and v are flattened, so split them after receiving. 
+ k_part = kv_inputs[i % 2][:k_numel].view(*k_shape) + v_part = kv_inputs[i % 2][k_numel:].view(*v_shape) + if causal: + if i == 0: + if pad_between_seqs: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q + cu_seqlens_kv_per_step[i] = cu_seqlens_kv + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) + if enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + k.shape[0], -1, 2, *k.shape[-2:] + ) + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) + if enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[2:]) + v_part = v_part.view(-1, *v_part.shape[2:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) + elif qkv_format == "thd": + q_inputs[i % 2] = q + if use_fused_attention: + if attn_bias is not None: + idx = (rank - i) % cp_size + attn_bias_inputs[i % 2] = torch.cat( + ( + attn_bias[..., idx, :], + attn_bias[..., (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + + q_part = q_inputs[i % 2] + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. 
+ k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + fp8_meta_kwargs = {} + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] + + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_part, + k_part, + v_part, + fake_dtype=qkv_dtype, + fused_attention_backend=fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, + ) + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None + else: + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + ) + # Need to add MLA support once Flash Attention supports MLA + fa_outputs = flash_attn_fwd( + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + 
*fa_forward_args_thd, + causal=True, + **fa_forward_kwargs, + ) + if not fa_utils.v2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[3] + + # Store flash-attn outputs in NVSHMEM heap for CP communication (causal & flash-attn) + nvshmem_fa_out, out_for_slow, nvshmem_fa_softmax_lse, softmax_lse_for_slow = _store_fa_nvshmem(out_per_step[i], softmax_lse_per_step[i]) + count_array[i] = True # Mark that this step has completed + # total_count_array[0] = np.sum(count_array) + + elif i <= rank: + if count_array[i] == True: + continue # This step has been computed + if pad_between_seqs: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - i) % cp_size, + True, + False, + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2) + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q + cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) + if enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn] + k_part = k_part[:, 0, ...] + v_part = v_part[:, 0, ...] + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] 
+ elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) + if enable_mla: + # [2, sk//2, b, np, hn] -> [sk//2, b, np, hn] + k_part = k_part[0] + v_part = v_part[0] + else: + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][0] + elif qkv_format == "thd": + q_inputs[i % 2] = q + if enable_mla: + # [t, np, hn] -> [t/2, np, hn] + k_part = tex.thd_read_half_tensor( + k_part, cu_seqlens_kv_padded, 0 + ) + v_part = tex.thd_read_half_tensor( + v_part, cu_seqlens_kv_padded, 0 + ) + else: + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_inputs[i % 2] = tex.thd_read_half_tensor( + kv_inputs[i % 2], cu_seqlens_kv_padded, 0 + ) + if use_fused_attention: + if enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + else: + kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() + if attn_bias is not None: + idx = (rank - i) % cp_size + attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() + + q_part = q_inputs[i % 2] + if not enable_mla: + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + fp8_meta_kwargs = {} + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv // 2, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_part, + k_part, + v_part, + qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + 
dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=( + None + if cu_seqlens_kv_padded is None + else cu_seqlens_kv_padded // 2 + ), + **fp8_meta_kwargs, + ) + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None + else: + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv // 2, + ) + if use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): + fa_forward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 + # Need to add MLA support once Flash Attention supports MLA + fa_outputs = flash_attn_fwd( + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, + causal=False, + **fa_forward_kwargs, + ) + if not fa_utils.v2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[3] + + count_array[i] = True # Mark that this step has completed + # total_count_array[0] = np.sum(count_array) + else: + if pad_between_seqs: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True + ) + 
cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - i) % cp_size, + True, + True, + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2) + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q_half + cu_seqlens_kv_per_step[i] = cu_seqlens_kv + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_inputs[i % 2] = q[:, 1, ...] + if enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + k.shape[0], -1, 2, *k.shape[-2:] + ) + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_inputs[i % 2] = q[1] + if enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[2:]) + v_part = v_part.view(-1, *v_part.shape[2:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) + elif qkv_format == "thd": + # [t, np, hn] -> [t/2, np, hn] + q_inputs[i % 2] = tex.thd_read_half_tensor( + q, cu_seqlens_q_padded, 1 + ) + if use_fused_attention: + q_inputs[i % 2] = q_inputs[i % 2].contiguous() + if attn_bias is not None: + idx = (rank - i) % cp_size + attn_bias_inputs[i % 2] = torch.cat( + ( + attn_bias_[..., 1, :, idx, :], + attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + + q_part = q_inputs[i % 2] + if not enable_mla: + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + fp8_meta_kwargs = {} + if fp8: + q_part = 
QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q // 2, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_part, + k_part, + v_part, + qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=( + None + if cu_seqlens_q_padded is None + else cu_seqlens_q_padded // 2 + ), + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, + ) + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None + else: + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q // 2, + max_seqlen_kv=max_seqlen_kv, + ) + if use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): + fa_forward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 + # Need to add MLA support once Flash Attention supports MLA + fa_outputs = flash_attn_fwd( + q_inputs[i % 2], + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else 
kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, + causal=False, + **fa_forward_kwargs, + ) + if not fa_utils.v2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[3] + count_array[i] = True # Mark that this step has completed + # total_count_array[0] = np.sum(count_array) + else: + if pad_between_seqs: + cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - i) % cp_size, + True, + True, + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q + cu_seqlens_kv_per_step[i] = cu_seqlens_kv + if use_fused_attention: + if attn_bias is not None: + idx = (rank - i) % cp_size + attn_bias_inputs[i % 2] = torch.cat( + ( + attn_bias[..., idx, :], + attn_bias[..., (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + + q_part = q + if not enable_mla: + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + fp8_meta_kwargs = {} + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] + 
out_per_step[i], aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + q_part, + k_part, + v_part, + qkv_dtype, + fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs[i % 2], + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + **fp8_meta_kwargs, + ) + if fp8: + softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + else: + softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors + attn_biases[i] = rest[0] if len(rest) > 0 else None + else: + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + ) + # Need to add MLA support once Flash Attention supports MLA + fa_outputs = flash_attn_fwd( + q, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, + causal=False, + **fa_forward_kwargs, + ) + if not fa_utils.v2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[3] + + if i > 0: + # wait until fwd restuls correction of last step is done + if i > 1: + flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done) + + with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): + if use_fused_attention: + # [b, np, sq, 1] -> [b, np, sq] or + # [t, np, 1] -> [t, np] + softmax_lse_per_step[i - 1].squeeze_(-1) + if 
softmax_lse_in_packed_format: + softmax_lse_per_step[i - 1] = ( + softmax_lse_per_step[i - 1].transpose(0, 1).contiguous() + ) + if fp8: + out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32) + if i == 1: + softmax_lse = torch.clone(softmax_lse_per_step[0]) + if qkv_format == "thd": + if enable_mla: + out = torch.zeros_like(v if not fp8 else out_per_step[0]).view( + v_shape + ) + else: + # MHA or GQA + out = torch.zeros_like(q if not fp8 else out_per_step[0]).view( + q.shape + ) + elif (i - 1) <= rank or not causal: + flash_attn_fwd_softmax_lse_correction( + softmax_lse, softmax_lse_per_step[i - 1] + ) + else: + if qkv_format == "thd": + tex.thd_second_half_lse_correction( + softmax_lse, + softmax_lse_per_step[i - 1], + cu_seqlens_q_padded, + softmax_lse_in_packed_format, + ) + else: + flash_attn_fwd_second_half_softmax_lse_correction( + softmax_lse.view(*softmax_lse.shape[:-1], 2, -1), + softmax_lse_per_step[i - 1], + ) + + if i < cp_size: + flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done) + + torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) + + second_half_lse_seqlen = None + if causal and rank < (cp_size - 1): + second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1] + + + signal_array = nvshmem.array((1,), dtype="bool") + # Let the fast rank help computing for the slow ranks + # fast ranks will finish cpsize + 1 for loops while slow ranks not done + # so let fast ranks get the q from the slow ranks, compute and send back the out by NVSHMEM + def help_slow_ranks(slow_rank, idx): + # Get the q from slow rank + + # Use NVSHMEM get: compute owner of the (i+1)-th step Q block + # Map owner idx to global rank (accounting for a2a groups) + owner_global = cp_global_ranks[slow_rank * cp_size_a2a + rank_a2a] + # nvshmem_get: dst (local buffer), src (symmetric address), peer=owner_global + nvshmem_get_on_stream(help_comm_buffer, nvshmem_q, owner_global, stream=communicate_stream) + # q = 
tex.nvshmem_get_on_current_stream(help_comm_buffer, nvshmem_q, int(owner_global)) + q = help_comm_buffer + + kv_idx = get_p2p_buffer_index_for_kv(rank, slow_rank, idx, cp_size) + kv = p2p_comm_buffers[kv_idx] + # if pad_between_seqs: + # cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( + # cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + # ) + # cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( + # cu_seqlens_kv, + # cu_seqlens_kv_padded, + # cp_size, + # (rank - i) % cp_size, + # True, + # False, + # ) + if qkv_format == "thd": + cu_seqlens_q_per_step_help = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step_help = cu_seqlens_kv // (cp_size * 2) + else: + cu_seqlens_q_per_step_help = cu_seqlens_q + cu_seqlens_kv_per_step_help = cu_seqlens_kv_half + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q = q.view(q.shape[0], -1, *q.shape[-2:]) + kv = kv[:, 0, ...] + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv[i % 2] = kv[i % 2][0] + elif qkv_format == "thd": + # q = q + # [2, t, np, hn] -> [2, t/2, np, hn] + kv = tex.thd_read_half_tensor( + kv, cu_seqlens_kv_padded, 0 + ) + + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step_help, + cu_seqlens_kv=cu_seqlens_kv_per_step_help, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv // 2, + ) + if use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): + fa_forward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 + + if out_for_slow is None: + print("warning: out_for_slow is None, allocating a new tensor for it. 
This may cause extra memory usage.") + # Need to add MLA support once Flash Attention supports MLA + fa_outputs = flash_attn_fwd( + q, + ( + kv[..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv[0] + ), + ( + kv[..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv[1] + ), + *fa_forward_args_thd, + causal=False, + **fa_forward_kwargs, + ) + if not fa_utils.v2_7_0_plus: + out_for_slow[:] = fa_outputs[4] + softmax_lse_for_slow[:] = fa_outputs[5] + if not use_flash_attn_3: + rng_states_for_slow = fa_outputs[7] + else: + out_for_slow[:] = fa_outputs[0] + softmax_lse_for_slow[:] = fa_outputs[1] + if not use_flash_attn_3: + rng_states_for_slow = fa_outputs[3] + + # Send back the out to slow rank + # Use NVSHMEM put: dst (symmetric address), src (local buffer), + target_index = idx + # nvshmem.put(count_array[target_index:target_index+1], signal_array, slow_rank, stream=communicate_stream) + nvshmem.rma.put_offset(count_array, signal_array, target_index, slow_rank, stream=communicate_stream) + # NVSHMEM put: dst (symmetric address), src (local buffer), peer=slow_rank, target_index=idx + nvshmem.rma.put_offset(nvshmem_fa_out, out_for_slow, target_index, slow_rank, stream=communicate_stream) + nvshmem.rma.put_offset(nvshmem_fa_softmax_lse, softmax_lse_for_slow, target_index, slow_rank, stream=communicate_stream) + + # device = Device() + # stream = device.create_stream() + # print("count array:", count_array) + j = 0 + nvtx.range_push("helping slow ranks") + while True: + + print("iteration: ", j) + # nvshmem.fcollect(count_gather_array, count_array, cp_size) + nvshmem.fcollect(team=nvshmem.Teams.TEAM_WORLD, dst_array=count_gather_array, src_array=count_array, stream=communicate_stream) + print("count_gather_array:", count_gather_array) + true_count_per_pe = cp.array([cp.sum(count_gather_array[i*cp_size:(i+1)*cp_size]) for i in range(cp_size)]) + min_true_pe = int(cp.argmin(true_count_per_pe)) + all_min_true_count = true_count_per_pe[min_true_pe] + # 
print("all_min_true_count", all_min_true_count) + + if all_min_true_count == cp_size: + break + slow_rank = min_true_pe + slow_count = count_gather_array[slow_rank*cp_size:(slow_rank+1)*cp_size] + last_false_index = len(slow_count) - cp.argmax(cp.flip(slow_count) == False) - 1 + + # q_local = q + print("finished rank", rank, "helping slow rank", slow_rank, "for step", last_false_index) + # print("count array", count_array) + help_slow_ranks(slow_rank, last_false_index) + j += 1 + nvtx.range_pop() + + # q = q_local + + for i in range(cp_size): + if i <= rank or not causal: + if qkv_format in ["bshd", "sbhd"]: + if i == 0: + out = flash_attn_fwd_out_correction_init( + out_per_step[0], + softmax_lse, + softmax_lse_per_step[0], + seq_dim, + ) + if enable_mla: + out = out.view(v_shape) + else: + out = out.view(q.shape) + else: + flash_attn_fwd_out_correction( + out.view(*out_per_step[i].shape), + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + seq_dim, + ) + elif qkv_format == "thd": + tex.thd_out_correction( + out, + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + cu_seqlens_q_padded, + False, + softmax_lse_in_packed_format, + ) + else: + if qkv_format in ["bshd", "sbhd"]: + flash_attn_fwd_second_half_out_correction( + out, + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + seq_dim, + ) + elif qkv_format == "thd": + tex.thd_out_correction( + out, + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + cu_seqlens_q_padded, + True, + softmax_lse_in_packed_format, + ) + + kv = p2p_comm_buffers[-1] + if qkv_format == "bshd": + out = out.view(out.shape[0], -1, *out.shape[-2:]) + ctx.batch_size = out.shape[0] + elif qkv_format == "sbhd": + out = out.view(-1, *out.shape[-3:]) + ctx.batch_size = out.shape[1] + + if cp_size_a2a > 1: + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) + out = flash_attn_a2a_communicate( + out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, 
False + ) + if use_fused_attention: + if qkv_format == "bshd": + # [b*s, np, hn] -> [b, s, np, hn] + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + # [s*b, np, hn] -> [s, b, np, hn] + out = out.view(-1, ctx.batch_size, *out.shape[-2:]) + elif not use_fused_attention: + out = out.view(-1, *out.shape[-2:]) + + if fp8 and use_fused_attention: + amax_cp_fwd = amax_per_step.amax(dim=1) + S_quantizer.amax.copy_(amax_cp_fwd[0]) + O_CP_quantizer.amax.copy_(amax_cp_fwd[1]) + + out_fp8 = None + out_f16 = out.to(qkv_dtype) + + if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): + out_fp8 = O_quantizer(out_f16) # final result + + out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 + + if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q_save, kv_save, out_save = q, kv, out_fp8._data + elif fp8 and is_input_fp8: + q_save, kv_save, out_save = q, kv, out_f16 + else: + q_f16 = q_f16.view(q.shape) + q_save, kv_save, out_save = q_f16, kv, out_f16 + + tensors_to_save, tensor_objects = prepare_for_saving( + q_save, + kv_save, + out_save, + softmax_lse, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + *cu_seqlens_q_per_step, + *cu_seqlens_kv_per_step, + *rng_states, + *attn_biases, + ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.cp_group_a2a = cp_group_a2a + ctx.cp_size_a2a = cp_size_a2a + ctx.rank_a2a = rank_a2a + ctx.cp_group = cp_group + ctx.cp_global_ranks = cp_global_ranks + ctx.cp_stream = cp_stream + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_kv = max_seqlen_kv + ctx.softmax_scale = softmax_scale + ctx.qkv_format = qkv_format + ctx.attn_mask_type = attn_mask_type + ctx.attn_bias_type = attn_bias_type + ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape + ctx.deterministic = deterministic + ctx.use_fused_attention = use_fused_attention + ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format + 
ctx.second_half_lse_seqlen = second_half_lse_seqlen + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.fp8_meta = fp8_meta + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 + ctx.use_flash_attn_3 = use_flash_attn_3 + + ctx.enable_mla = enable_mla + if enable_mla: + ctx.k_numel = k_numel + ctx.k_shape = k_shape + ctx.v_shape = v_shape + + ctx.qkv_dtype = qkv_dtype + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dQKV_CP_quantizer = dQKV_CP_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.S_quantizer = S_quantizer + if ctx.fp8: + ctx.QKV_quantizer = QKV_quantizer.copy() + ctx.QKV_quantizer.scale = QKV_quantizer.scale.clone() + ctx.O_quantizer = O_quantizer.copy() + ctx.O_quantizer.scale = O_quantizer.scale.clone() + ctx.S_quantizer = S_quantizer.copy() + ctx.S_quantizer.scale = S_quantizer.scale.clone() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward") + # free up some nvshmem buffers + nvshmem.free_tensor(nvshmem_kv) + nvshmem.free_tensor(help_comm_buffer) + nvshmem.free_tensor(nvshmem_fa_out) + nvshmem.free_tensor(nvshmem_fa_softmax_lse) + nvshmem.free_array(count_array) + nvshmem.free_array(count_gather_array) + nvshmem.free_array(signal_array) + nvshmem.free_tensor(out_for_slow) + nvshmem.free_tensor(softmax_lse_for_slow) + nvshmem.free_tensor(nvshmem_q) + for i in range(cp_size): + nvshmem.free_tensor(p2p_comm_buffers[i]) + + return out_ret + + @staticmethod + def backward(ctx, dout): + # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward") + cp_size_a2a = ctx.cp_size_a2a + rank_a2a = ctx.rank_a2a + + cp_size = get_distributed_world_size(ctx.cp_group) + rank = get_distributed_rank(ctx.cp_group) + send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] + recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] 
+ device_compute_capability = get_device_compute_capability() + batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or ( + device_compute_capability < (10, 0) and cp_size == 2 + ) + + q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( + restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + ) + cu_seqlens_q_per_step = other_tensors[:cp_size] + cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2] + rng_states = other_tensors[cp_size * 2 : cp_size * 3] + attn_biases = other_tensors[cp_size * 3 : cp_size * 4] + + causal = "causal" in ctx.attn_mask_type + padding = "padding" in ctx.attn_mask_type + + seq_dim = None + if ctx.qkv_format in ["bshd", "sbhd"]: + seq_dim = ctx.qkv_format.index("s") + if ctx.enable_mla: + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + else: + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] + else: + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + + if attn_biases[0] is not None: + # [b, np, sq, 2*cp, sk//(2*cp)] + attn_dbias = torch.zeros( + *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device + ) + # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] + attn_dbias_ = attn_dbias.view( + *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:] + ) + else: + attn_dbias = None + attn_dbias_ = None + + softmax_lse_ = None + if causal and ctx.second_half_lse_seqlen is not None: + if ctx.qkv_format == "thd": + softmax_lse_ = tex.thd_read_second_half_lse( + softmax_lse, + cu_seqlens_q_padded, + ctx.softmax_lse_in_packed_format, + ctx.second_half_lse_seqlen, + ) + else: + # [b, np, sq] -> [b, np, 2, sq//2] + softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1) + softmax_lse_ = softmax_lse_[..., 1, :].contiguous() + if ctx.use_fused_attention: + if ctx.softmax_lse_in_packed_format: + softmax_lse_ = softmax_lse_.transpose(0, 
1).contiguous() + # [b, np, sq//2] -> [b, np, sq//2, 1] or + # [t//2, np] -> [t//2, np, 1] + softmax_lse_.unsqueeze_(-1) + if ctx.use_fused_attention: + if ctx.softmax_lse_in_packed_format: + softmax_lse = softmax_lse.transpose(0, 1).contiguous() + # [b, np, sq] -> [b, np, sq, 1] or + # [t, np] -> [t, np, 1] + softmax_lse.unsqueeze_(-1) + dout = dout.contiguous() + + dq = None + dout_dtype = dout.dtype + fused_attn_backend = None + fused_attn_dqkv_dtype = None + amax_per_step = None + dP_quantizer_per_step = [None for _ in range(cp_size)] + dQKV_CP_quantizer_per_step = [None for _ in range(cp_size)] + if ctx.fp8: + if ctx.use_fused_attention: + fused_attn_backend = FusedAttnBackend["FP8"] + + if ctx.is_output_fp8: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + ctx.dO_quantizer = dout._quantizer + else: + dout = ctx.dO_quantizer(dout) + fused_attn_dqkv_dtype = TE_DType[dout._data.dtype] + dq_fp8 = torch.empty((cp_size, *q.shape), dtype=dout._data.dtype, device=q.device) + dkv_fp8 = torch.empty( + (cp_size, *kv.shape), dtype=dout._data.dtype, device=kv.device + ) + dkv_fp8_ = torch.empty_like(dkv_fp8) + p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] + dout = dout._data + fp8_meta_kwargs = {} + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + for i in range(cp_size): + dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() + dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy() + dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + else: + assert False, "FP8 is only supported with Fused Attention!" 
+ else: + if ctx.fp8_meta is not None: + if ctx.is_input_fp8: + q = ctx.QKV_quantizer.create_tensor_from_data( + q, fake_dtype=ctx.qkv_dtype, internal=True + ) + kv = ctx.QKV_quantizer.create_tensor_from_data( + kv, fake_dtype=ctx.qkv_dtype, internal=True + ) + q = q.dequantize(dtype=ctx.qkv_dtype) + kv = kv.dequantize(dtype=ctx.qkv_dtype) + if ctx.is_output_fp8: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + if cp_size_a2a == 1: + dout = dout.dequantize(dtype=dout_dtype) + else: + ctx.dO_quantizer = dout._quantizer + dout = dout._data + dq = torch.empty_like(q) + p2p_comm_buffers = [ + torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), + torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), + ] + p2p_comm_buffers[0][0].copy_(kv) + if ctx.use_fused_attention: + fp8_meta_kwargs = {} + fused_attn_dqkv_dtype = TE_DType[dout_dtype] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + if cp_size_a2a > 1: + if not ctx.use_fused_attention: + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + dout = dout.view(*out.shape) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn( + cp_size_a2a, out.device + ) + out, dout = flash_attn_a2a_communicate( + [out, dout], + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + ctx.cp_group_a2a, + ctx.cp_stream, + True, + ) + if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: + dout = ctx.dO_quantizer.create_tensor_from_data( + dout, fake_dtype=dout_dtype, internal=True + ) + dout = dout.dequantize(dtype=dout_dtype) + + if ctx.enable_mla: + out = out.view(*ctx.v_shape) + dout = dout.view(*ctx.v_shape) + else: + # MHA or GQA + out = out.view(*q.shape) + dout = dout.view(*q.shape) + send_recv_reqs = [] + + flash_attn_bwd = None + if not ctx.use_fused_attention: + fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} + if ctx.use_flash_attn_3: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + 
_flash_attn_bwd_v3, + ) + + flash_attn_bwd = ( + _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment + ) + fa_backward_kwargs["deterministic"] = ctx.deterministic + else: + if ctx.qkv_format == "thd": + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_varlen_bwd, + ) + + flash_attn_bwd = _flash_attn_varlen_bwd + else: + from transformer_engine.pytorch.attention.dot_product_attention.backends import ( + _flash_attn_bwd, + ) + + flash_attn_bwd = _flash_attn_bwd + fa_backward_kwargs["dropout_p"] = ctx.dropout_p + if fa_utils.v2_4_plus: + fa_backward_kwargs["alibi_slopes"] = None + if fa_utils.v2_4_1_plus: + fa_backward_kwargs["deterministic"] = ctx.deterministic + if fa_utils.v2_6_0_plus: + fa_backward_kwargs["softcap"] = 0.0 + + for i in range(cp_size): + # wait until KV is received + for req in send_recv_reqs: + req.wait() + + send_tensor = p2p_comm_buffers[i % 2] + recv_tensor = p2p_comm_buffers[(i + 1) % 2] + if ctx.fp8: + if i < cp_size - 1: + if nvshmem_kv is not None: + # owner of the next KV block + owner_idx = (rank - (i + 1)) % cp_size + owner_global = cp_global_ranks[owner_idx * cp_size_a2a + rank_a2a] + tex.nvshmem_get_on_current_stream(recv_tensor[0], nvshmem_kv, int(owner_global)) + send_recv_reqs = [] + else: + send_recv_reqs = flash_attn_p2p_communicate( + rank, + send_tensor[0], + send_dst, + recv_tensor[0], + recv_src, + ctx.cp_group, + batch_p2p_comm, + ) + else: + dkv_a2a_req = torch.distributed.all_to_all_single( + dkv_fp8, + dkv_fp8_, + group=ctx.cp_group, + async_op=True, + ) + send_recv_reqs = [dkv_a2a_req] + else: + if i == 0: + send_tensor = send_tensor[0] + recv_tensor = recv_tensor[0] + if i == (cp_size - 1): + send_tensor = send_tensor[1] + recv_tensor = recv_tensor[1] + if nvshmem_kv is not None: + owner_idx = (rank - (i + 1)) % cp_size + owner_global = cp_global_ranks[owner_idx * cp_size_a2a + rank_a2a] + tex.nvshmem_get_on_current_stream(recv_tensor, nvshmem_kv, 
int(owner_global)) + send_recv_reqs = [] + else: + send_recv_reqs = flash_attn_p2p_communicate( + rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm + ) + + kv = p2p_comm_buffers[i % 2][0] + q_, kv_, out_, dout_ = None, None, None, None + dq_, dk_, dv_ = None, None, None + if ctx.enable_mla: + k_part = kv[: ctx.k_numel].view(*ctx.k_shape) + v_part = kv[ctx.k_numel :].view(*ctx.v_shape) + # In reversed order of fwd + if causal: + if i == (cp_size - 1): + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_, out_, dout_ = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] + ] + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[-3:]) + v_part = v_part.view(-1, *v_part.shape[-3:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) + elif ctx.qkv_format == "thd": + q_, kv_, out_, dout_ = q, kv, out, dout + if ctx.use_fused_attention: + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if attn_dbias is not None: + aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + q_part = q_ + if not ctx.enable_mla: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = 
ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=dout_dtype, internal=True + ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + q_part, + k_part, + v_part, + out_part, + dout_part, + dout_dtype, + fused_attn_dqkv_dtype, + aux_ctx_tensors, + fused_attn_backend, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=ctx.attn_mask_type, + attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, + ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data + else: + dq_ = torch.empty_like(q_) + dkv_ = torch.empty_like(kv_) + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq_, + dk=( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ), + dv=( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ), + ) + if ctx.use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): + fa_backward_kwargs["window_size"] = (-1, 0) + 
elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = 0 + if not ctx.use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support once Flash Attention supports MLA + flash_attn_bwd( + dout_, + q_, + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + out_, + softmax_lse, + *fa_backward_args_thd, + causal=True, + **fa_backward_kwargs, + ) + elif i >= (cp_size - rank - 1): + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_, out_, dout_ = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] + ] + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part[:, 0] + v_part = v_part[:, 0] + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_ = kv[:, 0] + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part[0] + v_part = v_part[0] + else: + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_ = kv[0] + elif ctx.qkv_format == "thd": + q_, out_, dout_ = q, out, dout + if ctx.enable_mla: + # [t, np, hn] -> [t/2, np, hn] + k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) + v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) + else: + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) + if ctx.use_fused_attention: + if ctx.enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + else: + kv_ = kv_.contiguous() + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if attn_dbias is not None: + aux_ctx_tensors 
+= [attn_biases[cp_size - i - 1]] + q_part = q_ + if not ctx.enable_mla: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=dout_dtype, internal=True + ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv // 2, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + q_part, + k_part, + v_part, + out_part, + dout_part, + dout_dtype, + fused_attn_dqkv_dtype, + aux_ctx_tensors, + fused_attn_backend, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=( + None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 + ), + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, + ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data + else: + dq_ = torch.empty_like(q_) + dkv_ = torch.empty_like(kv_) + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + 
max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv // 2, + dq=dq_, + dk=( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ), + dv=( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ), + ) + if ctx.use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): + fa_backward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 + if not ctx.use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support once Flash Attention supports MLA + flash_attn_bwd( + dout_, + q_, + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + out_, + softmax_lse, + *fa_backward_args_thd, + causal=False, + **fa_backward_kwargs, + ) + else: + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1] + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) + v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) + else: + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_, out_, dout_ = q[1], out[1], dout[1] + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + k_part = k_part.view(-1, *k_part.shape[-3:]) + v_part = v_part.view(-1, *v_part.shape[-3:]) + else: + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) + elif ctx.qkv_format == "thd": + # [t, np, hn] -> [t/2, np, hn] + q_, out_, dout_ = [ + tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1) + for x in [q, out, dout] + ] + kv_ = kv + if ctx.use_fused_attention: + q_, out_, dout_ = 
[x.contiguous() for x in [q_, out_, dout_]] + if ctx.fp8: + aux_ctx_tensors = [ + softmax_lse_, + softmax_lse_, + rng_states[cp_size - i - 1], + ] + else: + aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] + if attn_dbias is not None: + aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + + q_part = q_ + if not ctx.enable_mla: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=dout_dtype, internal=True + ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( + ctx.max_seqlen_q // 2, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + q_part, + k_part, + v_part, + out_part, + dout_part, + dout_dtype, + fused_attn_dqkv_dtype, + aux_ctx_tensors, + fused_attn_backend, + cu_seqlens_q_padded=( + None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 + ), + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type="padding" if padding else "no_mask", + attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, + ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data + 
else: + dq_ = torch.empty_like(q_) + dkv_ = torch.empty_like(kv_) + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q // 2, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq_, + dk=( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ), + dv=( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ), + ) + if ctx.use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): + fa_backward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 + if not ctx.use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support once Flash Attention supports MLA + flash_attn_bwd( + dout_, + q_, + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + out_, + softmax_lse_, + *fa_backward_args_thd, + causal=False, + **fa_backward_kwargs, + ) + else: + if ctx.use_fused_attention: + if ctx.fp8: + aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]] + else: + aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] + if attn_dbias is not None: + aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + q_part = q + if not ctx.enable_mla: + k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] + v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] + out_part = out + dout_part = dout + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = 
ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=dout_dtype, internal=True + ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv_per_step[cp_size - i - 1], + q_part, + k_part, + v_part, + out_part, + dout_part, + dout_dtype, + fused_attn_dqkv_dtype, + aux_ctx_tensors, + fused_attn_backend, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + attn_scale=ctx.softmax_scale, + dropout=ctx.dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=ctx.attn_mask_type, + attn_bias_type=ctx.attn_bias_type, + deterministic=ctx.deterministic, + **fp8_meta_kwargs, + ) + + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data + + else: + dq_ = torch.empty_like(q) + dkv_ = torch.empty_like(kv) + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq_, + dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + ) + if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_backward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 + if not ctx.use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] + # Need to add MLA support 
once Flash Attention supports MLA + flash_attn_bwd( + dout, + q, + kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0], + kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], + out, + softmax_lse, + *fa_backward_args_thd, + causal=False, + **fa_backward_kwargs, + ) + + if ctx.fp8: + dq = dq_fp8[(rank + i + 1) % cp_size] + if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1): + # [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or + # [sq, b, np, hn] -> [2, sq//2, b, np, hn] + dq_ = dq_.view(*dq.shape) + + if ctx.fp8: + if i >= (cp_size - rank - 1) or not causal: + dq.copy_(dq_) + else: + if ctx.qkv_format == "bshd": + dq[:, 0, ...].fill_(0) + dq[:, 1, ...].copy_(dq_) + elif ctx.qkv_format == "sbhd": + dq[0].fill_(0) + dq[1].copy_(dq_) + elif causal: + if i > (cp_size - rank - 1): + dq.add_(dq_) + elif i == (cp_size - rank - 1): + if rank == (cp_size - 1): + dq.copy_(dq_) + else: + if ctx.qkv_format == "bshd": + dq[:, 0, ...].copy_(dq_[:, 0, ...]) + dq[:, 1, ...].add_(dq_[:, 1, ...]) + elif ctx.qkv_format == "sbhd": + dq[0].copy_(dq_[0]) + dq[1].add_(dq_[1]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "copy", "add") + elif i > 0: + if ctx.qkv_format == "bshd": + dq[:, 1, ...].add_(dq_) + elif ctx.qkv_format == "sbhd": + dq[1].add_(dq_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "add") + else: + if ctx.qkv_format == "bshd": + dq[:, 1, ...].copy_(dq_) + elif ctx.qkv_format == "sbhd": + dq[1].copy_(dq_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dq, dq_, cu_seqlens_q_padded, "none", "copy") + else: + if i == 0: + dq.copy_(dq_) + else: + dq.add_(dq_) + + if attn_dbias is not None: + idx = (rank + i + 1) % cp_size + if i == (cp_size - 1) or not causal: + # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)] + dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) + attn_dbias[..., idx, :].copy_(dbias_[..., 
0, :]) + attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) + elif i >= (cp_size - rank - 1): + # [b, np, sq, sk//(2*cp)] + attn_dbias[..., idx, :].copy_(dbias_) + else: + # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)] + dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) + attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :]) + attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) + + # wait until dKV is received + for req in send_recv_reqs: + req.wait() + + if ctx.fp8: + if i < cp_size - 1: + dkv = dkv_fp8_[(rank + i + 1) % cp_size] + else: + dkv = dkv_fp8[(rank + i + 1) % cp_size] + else: + dkv = p2p_comm_buffers[(i + 1) % 2][1] + if ctx.use_fused_attention: + if ctx.enable_mla: + dkv_ = None + elif ctx.qkv_format in ["bshd", "sbhd"]: + dkv_ = combine_tensors([dk_, dv_], -2) + elif ctx.qkv_format == "thd": + dkv_ = torch.cat( + (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 + ) # pylint: disable=used-before-assignment + if not ctx.enable_mla and ctx.qkv_format in ["bshd", "sbhd"]: + # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or + # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] + # dkv is a buffer, so we do not need to transpose it, but only need to reshape it. 
+ dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) + dkv_ = dkv_.movedim(-3, 0) + if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): + # [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or + # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn] + dkv_ = dkv_.view(*dkv.shape) + + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] or + # [2, sk//2, b, np, hn] + dk = dkv[: ctx.k_numel].view(*ctx.k_shape) + dv = dkv[ctx.k_numel :].view(*ctx.v_shape) + if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): + dk_ = dk_.view(*ctx.k_shape) + dv_ = dv_.view(*ctx.v_shape) + + if ctx.fp8: + # enable_mla and fp8 + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + if ctx.qkv_format == "bshd": + dk[:, 0, ...].copy_(dk_) + dk[:, 1, ...].fill_(0) + dv[:, 0, ...].copy_(dv_) + dv[:, 1, ...].fill_(0) + elif ctx.qkv_format == "sbhd": + dk[0].copy_(dk_) + dk[1].fill_(0) + dv[0].copy_(dv_) + dv[1].fill_(0) + else: + dk.copy_(dk_) + dv.copy_(dv_) + elif causal: + # enable_mla and not fp8 and causal + if i == (cp_size - 1): + if rank == 0: + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_[:, 0, ...]) + dk[:, 1, ...].copy_(dk_[:, 1, ...]) + dv[:, 0, ...].add_(dv_[:, 0, ...]) + dv[:, 1, ...].copy_(dv_[:, 1, ...]) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_[0, ...]) + dk[1, ...].copy_(dk_[1, ...]) + dv[0, ...].add_(dv_[0, ...]) + dv[1, ...].copy_(dv_[1, ...]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dk, dk_, cu_seqlens_kv_padded, "add", "copy" + ) + tex.thd_grad_correction( + dv, dv_, cu_seqlens_kv_padded, "add", "copy" + ) + else: + dk.add_(dk_) + dv.add_(dv_) + elif i >= (cp_size - rank - 1): + if i == 0 and rank == (cp_size - 1): + if ctx.qkv_format == "bshd": + dk[:, 0, ...].copy_(dk_) + dv[:, 0, ...].copy_(dv_) + elif ctx.qkv_format == "sbhd": + dk[0, ...].copy_(dk_) + dv[0, ...].copy_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dk, dk_, cu_seqlens_kv_padded, "copy", "none" + ) + 
tex.thd_grad_correction( + dv, dv_, cu_seqlens_kv_padded, "copy", "none" + ) + else: + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_) + dv[:, 0, ...].add_(dv_) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_) + dv[0, ...].add_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dk, dk_, cu_seqlens_kv_padded, "add", "none" + ) + tex.thd_grad_correction( + dv, dv_, cu_seqlens_kv_padded, "add", "none" + ) + elif i > 0: + dk.add_(dk_) + dv.add_(dv_) + else: # i == 0 + dk.copy_(dk_) + dv.copy_(dv_) + else: + # enable_mla and not fp8 and not causal + if i == 0: + dk.copy_(dk_) + dv.copy_(dv_) + else: # i > 0 + dk.add_(dk_) + dv.add_(dv_) + else: + if ctx.fp8: + # fp8 + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].copy_(dkv_) + dkv[:, :, 1, ...].fill_(0) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].copy_(dkv_) + dkv[:, 1, ...].fill_(0) + else: + dkv.copy_(dkv_) + elif causal: + # not fp8 and causal + if i == (cp_size - 1): + if rank == 0: + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) + dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...]) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].add_(dkv_[:, 0, ...]) + dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dkv, dkv_, cu_seqlens_kv_padded, "add", "copy" + ) + else: + dkv.add_(dkv_) + elif i >= (cp_size - rank - 1): + if i == 0 and rank == (cp_size - 1): + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].copy_(dkv_) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].copy_(dkv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dkv, dkv_, cu_seqlens_kv_padded, "copy", "none" + ) + else: + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].add_(dkv_) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].add_(dkv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction( + dkv, dkv_, cu_seqlens_kv_padded, "add", "none" + ) + elif i > 0: + 
dkv.add_(dkv_) + else: # i == 0 + dkv.copy_(dkv_) + else: + # not fp8 and not causal + if i == 0: + dkv.copy_(dkv_) + else: # i > 0 + dkv.add_(dkv_) + + if ctx.fp8 and ctx.use_fused_attention: + amax_cp_bwd = amax_per_step.amax(dim=1) + ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) + ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1]) + dq = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dq_fp8, fake_dtype=torch.float32, internal=True + ) + + if ctx.enable_mla: + # [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn] + dk_fp8 = dkv_fp8[:, : ctx.k_numel].view(cp_size, *ctx.k_shape) + dv_fp8 = dkv_fp8[:, ctx.k_numel :].view(cp_size, *ctx.v_shape) + dk = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dk_fp8, fake_dtype=torch.float32, internal=True + ) + dv = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dv_fp8, fake_dtype=torch.float32, internal=True + ) + dq, dk, dv = [x.dequantize(dtype=torch.float32) for x in [dq, dk, dv]] + dq, dk, dv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dk, dv]] + else: + if ctx.qkv_format in ["bshd", "sbhd"]: + # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or + # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] + dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) + dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dkv_fp8, fake_dtype=torch.float32, internal=True + ) + dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]] + dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] + + if causal: + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) + if ctx.enable_mla: + # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] + dk = dk.view(dk.shape[0], -1, *dk.shape[-2:]) + dv = dv.view(dv.shape[0], -1, *dv.shape[-2:]) + else: + # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] + dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, 
hn] + dq = dq.view(-1, *dq.shape[-3:]) + if ctx.enable_mla: + # [2, sk//2, b, np, hn] -> [sk, b, np, hn] + dk = dk.view(-1, *dk.shape[-3:]) + dv = dv.view(-1, *dv.shape[-3:]) + else: + # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] + dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) + + if ctx.qkv_format == "thd" and not ctx.use_fused_attention: + dq[cu_seqlens_q_padded[-1] :].fill_(0) + if ctx.enable_mla: + dk[cu_seqlens_kv_padded[-1] :].fill_(0) + dv[cu_seqlens_kv_padded[-1] :].fill_(0) + else: + dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) + + if ctx.fp8 and ctx.is_input_fp8: + assert torch.uint8 not in [dq.dtype, dkv.dtype] + if ctx.enable_mla: + dq, dk, dv = [ctx.dQKV_quantizer(x)._data for x in [dq, dk, dv]] + else: + dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] + if not ctx.enable_mla: + dk, dv = dkv[0], dkv[1] + + if cp_size_a2a > 1: + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) + dq, dk, dv = flash_attn_a2a_communicate( + [dq, dk, dv], + chunk_ids_for_a2a, + seq_dim, + cp_size_a2a, + ctx.cp_group_a2a, + ctx.cp_stream, + False, + ) + if ctx.qkv_format == "bshd": + dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] + elif ctx.qkv_format == "sbhd": + dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + + if attn_dbias is not None: + # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] + attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) + # converting torch.uint8 to float8tensor + if ctx.fp8 and ctx.is_input_fp8: + dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype) + dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype) + dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward") + + return ( + None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + 
attn_dbias, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + +def attn_forward_func_with_cp( + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + cp_group, + cp_global_ranks, + cp_stream, + cp_comm_type, + softmax_scale=None, + qkv_format="bshd", + attn_mask_type="causal", + attn_bias_type="no_bias", + attn_bias=None, + deterministic=False, + use_fused_attention=False, + window_size=None, + fp8=False, + fp8_meta=None, + quantizers=None, + pad_between_seqs=False, + use_flash_attn_3=False, +) -> torch.Tensor: + """ + Attention implementation with context parallelism (CP). CP partitions tensors along the sequence + dimension, and by reducing the memory and computational pressure on each GPU, it enables long-context + LLMs in a distributed fashion. Transformer Engine's PyTorch CP implementation currently utilizes + the DualChunkSwap strategy to ensure load balancing across CP ranks. It is applied to all `attn_mask_type`s + and all `qkv_format`s, and it requires sequence lengths to be, or are padded to be, divisible by + (cp_size * 2). It also requires tokens to be re-ordered before entering this function. + + For qkv_format = {'bshd', 'sbhd'}, the token re-ordering is illustrated as below, for an example + use case of s = 12, attn_mask_type = 'causal', and cp_size = 2. seq_pos indicates each token's position + in their corresponding sequence. 
+ + GPU0 | GPU1 GPU0 | GPU1 + seq_pos | 0 1 2 3 4 5 | 6 7 8 9 10 11 seq_pos | 0 1 2 9 10 11 | 3 4 5 6 7 8 + ---------------------------|----------------- ---------------------------|------------------ + 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + U 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 9 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 1, 1, + 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 10 | 1, 1, 1, 1, 1, 0,| 1, 1, 1, 1, 1, 1, + 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1, + ---------------------------|----------------- ---------------------------|------------------ + 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 3 | 1, 1, 1, 0, 0, 0,| 1, 0, 0, 0, 0, 0, + G 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 4 | 1, 1, 1, 0, 0, 0,| 1, 1, 0, 0, 0, 0, + P 8 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 0, 0, 0, P 5 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 0, 0, 0, + U 9 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 0, 0, U 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + 1 10 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 0, 1 7 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 0, + 11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1, 8 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 1, + + For qkv_format = 'thd', multiple sequences may be packed into the batch, and they may be of different + lengths. DualChunkSwap divides each sequence into (cp_size * 2) chunks and distributes 2 chunks of + every sequence onto a CP rank. The token matrix transformation is shown as follows, for an example of + batch_size = 2, seq_ids = [0, 1], seq_lens = [8, 4], t = 12, attn_mask_type = 'padding_causal', and + cp_size = 2. 
+ + GPU0 | GPU1 GPU0 | GPU1 + seq_id | 0 0 0 0 0 0 | 0 0 1 1 1 1 seq_id | 0 0 0 0 1 1 | 0 0 0 0 1 1 + seq_pos | 0 1 2 3 4 5 | 6 7 0 1 2 3 seq_pos | 0 1 6 7 0 3 | 2 3 4 5 1 2 + ---------------------------|----------------- ---------------------------|------------------ + 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + P 0 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 0 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + U 0 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 0 7 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 0, 0, + 0 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 1 0 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 0, 0, + 0 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 1 3 | 0, 0, 0, 0, 2, 2,| 0, 0, 0, 0, 2, 2, + ---------------------------|----------------- ---------------------------|------------------ + 0 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 0 2 | 1, 1, 0, 0, 0, 0,| 1, 0, 0, 0, 0, 0, + G 0 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 0 3 | 1, 1, 0, 0, 0, 0,| 1, 1, 0, 0, 0, 0, + P 1 0 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 0, 0, 0 P 0 4 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 0, 0, 0, + U 1 1 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 0, 0 U 0 5 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + 1 1 2 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 0 1 1 1 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 0, + 1 3 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 2 1 2 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 2, + + When all transformer layers in a model share the same CP configuration, i.e. cp_group, cp_global_ranks, + cp_comm_type and cp_stream, token re-ordering can take place in the dataloader, i.e. only once for + all the layers. An example of the re-ordering code is `get_batch_on_this_cp_rank + `_ + in Megatron-LM. + + """ + + if cp_comm_type == "a2a+p2p": + assert isinstance( + cp_group, list + ), "Hierarchical CP implementation needs multi-level CP groups!" + assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" 
+ if get_distributed_world_size(cp_group[0]) == 1: + cp_group = cp_group[1] + cp_comm_type = "p2p" + elif get_distributed_world_size(cp_group[1]) == 1: + cp_group = cp_group[0] + cp_comm_type = "a2a" + else: + assert isinstance( + cp_group, dist_group_type + ), f"Unsupported process group for CP communication type {cp_comm_type}!" + + assert qkv_format in [ + "bshd", + "sbhd", + "thd", + ], f"QKV format of {qkv_format} is not supported with context parallelism!" + assert ( + qkv_format != "sbhd" or use_fused_attention + ), "FlashAttention does not support sbhd format!" + assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), ( + """Attention bias is only supported with FusedAttention and "causal" """ + """or "no_mask" mask types!""" + ) + assert qkv_format != "thd" or ( + cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None + ), "cu_seqlens_padded cannot be None with context parallelism + THD format!" + + sliding_window_attn = ( + window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) + ) + assert not sliding_window_attn or cp_comm_type in [ + "a2a", + "all_gather", + ], "The context parallel running configs cannot support sliding window attetnion!" + + enable_mla = k.shape[-1] != v.shape[-1] + assert not enable_mla or cp_comm_type in [ + "p2p", + "a2a+p2p", + ], "The context parallel running configs cannot support MLA!" 
+ + args = [ + is_training, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + dropout_p, + softmax_scale, + qkv_format, + attn_mask_type, + attn_bias_type, + attn_bias, + deterministic, + use_fused_attention, + ] + + if cp_comm_type in ["p2p", "a2a+p2p"]: + args += [ + fp8, + fp8_meta, + cp_group, + cp_global_ranks, + cp_stream, + quantizers, + pad_between_seqs, + use_flash_attn_3, + ] + out = AttnFuncWithCPAndKVP2P.apply(*args) + elif cp_comm_type == "all_gather": + args.pop(5) + args.pop(8) + args += [window_size, cp_group, cp_stream, use_flash_attn_3] + out = AttnFuncWithCPAndKVAllGather.apply(*args) + elif cp_comm_type == "a2a": + args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_flash_attn_3] + out = AttnFuncWithCPAndQKVOA2A.apply(*args) + else: + raise ValueError(f"Unsupported communication type: {cp_comm_type}!") + + return out diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 9d6677b628..3b26741cd7 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -73,6 +73,8 @@ def setup_logging(): """ Set up log levels, logger and handlers """ + if AttentionLogging._is_logging_setup: + return _log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} AttentionLogging._log_level = _log_levels[ AttentionLogging._log_level if AttentionLogging._log_level in [0, 1, 2] else 2 @@ -106,7 +108,7 @@ class FlashAttentionUtils: version = PkgVersion("0") version_required = PkgVersion("2.1.1") version_required_blackwell = PkgVersion("2.7.3") - max_version = PkgVersion("2.8.1") + max_version = PkgVersion("2.8.3") v2_plus = False v2_1_plus = False v2_3_plus = False @@ -434,8 +436,8 @@ def get_attention_backend( # | FP8 | non-paged/paged | sm90 | thd | >= 1 # 
Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: - if device_compute_capability == (8, 9) and cudnn_version <= (9, 12, 0): - logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.12") + if device_compute_capability == (8, 9) and cudnn_version <= (9, 13, 0): + logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.13") use_fused_attention = False if context_parallel: logger.debug("Disabling all backends for KV caching with context parallelism") @@ -822,7 +824,7 @@ def get_attention_backend( # flash-attn >=2.4.1 | yes # FusedAttention | # sub-backend 0 | yes - # sub-backend 1 | workspace optimization path and sm90+: yes; + # sub-backend 1 | workspace optimization path and sm90: yes; # | otherwise: no # sub-backend 2 | no # UnfusedDotProductAttention | yes @@ -838,8 +840,9 @@ def get_attention_backend( use_flash_attention_2 = False if use_fused_attention and deterministic: if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: - logger.debug("Disabling FusedAttention for determinism reasons") + logger.debug("Disabling FusedAttention for determinism reasons with FP8") use_fused_attention = False + fused_attention_backend = None if ( fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] and is_training @@ -849,8 +852,13 @@ def get_attention_backend( or cudnn_version < (8, 9, 5) ) ): - logger.debug("Disabling FusedAttention for determinism reasons") + logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias") + use_fused_attention = False + fused_attention_backend = None + if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0): + logger.debug("Disabling FusedAttention for determinism reasons on Blackwell") use_fused_attention = False + fused_attention_backend = None # use_flash_attention may have been set above use_flash_attention_2 = use_flash_attention and use_flash_attention_2 
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2f4414328b..5237f86863 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -422,6 +422,8 @@ void nvshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind); +void nvshmem_get_on_current_stream(at::Tensor dst, at::Tensor src, int peer); + void nvshmem_finalize(); /*************************************************************************************************** diff --git a/transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp b/transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp index 9c31678ee5..bda49ca6d1 100644 --- a/transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp @@ -87,6 +87,21 @@ void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wai #endif } +void nvshmem_get_on_current_stream(torch::Tensor dst, torch::Tensor src, int peer) { +#ifdef NVTE_ENABLE_NVSHMEM + void *dst_ptr = reinterpret_cast(dst.data_ptr()); + void *src_ptr = reinterpret_cast(src.data_ptr()); + auto nelement = dst.numel() * dst.element_size(); + at::cuda::CUDAStream cur_stream = at::cuda::getCurrentCUDAStream(); + + nvshmemx_getmem_on_stream(dst_ptr, src_ptr, nelement, peer, (cudaStream_t)cur_stream); +#else + NVTE_ERROR( + "Internal TE error: nvshmem_get_on_current_stream cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!"); +#endif +} + torch::Tensor create_nvshmem_tensor(const std::vector &shape, c10::ScalarType dtype) { #ifdef NVTE_ENABLE_NVSHMEM auto option_gpu = diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d38348ae9b..7fe3b5165b 100644 --- 
a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -326,6 +326,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &transformer_engine::pytorch::nvshmem_send_on_current_stream, "Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream", py::call_guard()); + m.def("nvshmem_get_on_current_stream", + &transformer_engine::pytorch::nvshmem_get_on_current_stream, + "Asynchronously get tensor data from a remote PE using NVSHMEM on the current CUDA stream", + py::call_guard()); m.def("nvshmem_wait_on_current_stream", &transformer_engine::pytorch::nvshmem_wait_on_current_stream, "Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA " diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 866f0b6390..eda18a185b 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -4,6 +4,8 @@ """Functions for CUDA Graphs support in FP8""" from collections.abc import Iterable +import contextlib +import gc from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union import torch @@ -58,6 +60,25 @@ def graph_pool_handle(): return _graph_pool_handle() +@contextlib.contextmanager +def _graph_context_wrapper(*args, **kwargs): + """Wrapper around `torch.cuda.graph`. + + This wrapper is a temporary workaround for a PyTorch bug: + automatic garbage collection can destroy a graph while another + graph is being captured, resulting in a CUDA error. See + https://github.com/pytorch/pytorch/pull/161037. 
+ + """ + gc_is_enabled = gc.isenabled() + if gc_is_enabled: + gc.disable() + with torch.cuda.graph(*args, **kwargs): + yield + if gc_is_enabled: + gc.enable() + + def _make_graphed_callables( callables: SingleOrTuple[Callable], sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]], @@ -445,7 +466,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument args = sample_args[per_callable_fwd_idx] kwargs = sample_kwargs[per_callable_fwd_idx] fwd_graph = fwd_graphs[per_callable_fwd_idx] - with torch.cuda.graph(fwd_graph, pool=mempool): + with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) flatten_outputs, spec = _tree_flatten(outputs) per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs) @@ -483,7 +504,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) if is_training: - with torch.cuda.graph(bwd_graph, pool=mempool): + with _graph_context_wrapper(bwd_graph, pool=mempool): grad_inputs = torch.autograd.grad( outputs=tuple(o for o in static_outputs if o.requires_grad), inputs=tuple(i for i in static_input_surface if i.requires_grad), @@ -548,7 +569,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument per_callable_output_unflatten_spec = [] graph_id = 0 for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs): - with torch.cuda.graph(fwd_graph, pool=mempool): + with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) graph_callables[graph_id] = func graph_id += 1 @@ -570,7 +591,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) if is_training: - with torch.cuda.graph(bwd_graph, pool=mempool): + with _graph_context_wrapper(bwd_graph, pool=mempool): grad_inputs = torch.autograd.grad( outputs=tuple(o for o in 
static_outputs if o.requires_grad), inputs=tuple(i for i in static_input_surface if i.requires_grad), diff --git a/transformer_engine/pytorch/onnx_extensions.py b/transformer_engine/pytorch/onnx_extensions.py index e34fd78468..42f5a1d551 100644 --- a/transformer_engine/pytorch/onnx_extensions.py +++ b/transformer_engine/pytorch/onnx_extensions.py @@ -194,12 +194,12 @@ def onnx_quantize_mxfp8_symbolic( tensor: onnxscript.onnx_types.TensorType, ) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]: """Symbolic quantize to MXFP8Tensor used for inference.""" - tensor_out, scale_inv_out = TRT_MXFP8QuantizeLinear(tensor) + tensor_out, scale_inv_out = TRT_MXFP8DynamicQuantize(tensor) return tensor_out, scale_inv_out schema = defs.OpSchema( - name="TRT_MXFP8QuantizeLinear", + name="TRT_MXFP8DynamicQuantize", domain="trt", since_version=1, doc="TRT MXFP8 Quantize Linear used for inference.", @@ -214,8 +214,8 @@ def onnx_quantize_mxfp8_symbolic( ], ) -TRT_MXFP8QuantizeLinear = onnxscript.values.Op( - opset=trt_opset, name="TRT_MXFP8QuantizeLinear", op_schema=schema +TRT_MXFP8DynamicQuantize = onnxscript.values.Op( + opset=trt_opset, name="TRT_MXFP8DynamicQuantize", op_schema=schema ) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 8775968249..8336330558 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -12,7 +12,6 @@ import torch -from transformer_engine.pytorch.module.base import get_workspace from ...cpp_extensions import general_gemm from ...distributed import ( CudaRNGStatesTracker, @@ -20,18 +19,24 @@ reduce_scatter_along_first_dim, ) from ...fp8 import FP8GlobalStateManager, Recipe -from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD +from ...module.base import ( + _2X_ACC_FPROP, + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, + get_dummy_wgrad, + get_workspace, +) from ...tensor 
import Quantizer from ...tensor.float8_tensor import Float8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase -from ..op import BasicOperation, OperationContext -from .._common import maybe_dequantize, is_quantized_tensor from ...utils import ( canonicalize_device, canonicalize_dtype, clear_tensor_data, devices_match, ) +from ..op import BasicOperation, OperationContext +from .._common import maybe_dequantize, is_quantized_tensor def _wait_async(handle: Optional[Any]) -> None: @@ -73,7 +78,8 @@ class BasicLinear(BasicOperation): weight's `main_grad` attribute instead of relying on PyTorch autograd. The weight's `main_grad` must be set externally and there is no guarantee that `grad` will be set or be - meaningful. + meaningful. This is primarily intended to integrate with + Megatron-LM. userbuffers_options, dict, optional Options for overlapping tensor-parallel communication with compute using Userbuffers. This feature is highly @@ -979,20 +985,22 @@ def op_backward( # Saved tensors from forward pass (x_local, w) = ctx.saved_tensors - # wgrad fusion + # Megatron-LM wgrad fusion + # Note: Get grad tensor from param so we can accumulate + # directly into it.
accumulate_into_main_grad = self._accumulate_into_main_grad grad_weight = None if ctx.weight_requires_grad and accumulate_into_main_grad: - if hasattr(self.weight, "__fsdp_param__"): - self.weight.main_grad = self.weight.get_main_grad() - - if not hasattr(self.weight, "main_grad"): + weight_param = self.weight + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " "accumulate_into_main_grad=True, " "but weight parameter does not have main_grad attribute" ) - grad_weight = self.weight.main_grad.detach() + grad_weight = weight_param.main_grad.detach() else: accumulate_into_main_grad = False @@ -1019,6 +1027,17 @@ def op_backward( # Clear input tensor if possible clear_tensor_data(x_local) + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. if accumulate_into_main_grad: grad_weight = None + weight_param = self.weight + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weight = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + return grad_input, [grad_weight] diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 8af46a27cd..845ba262a0 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -9,13 +9,10 @@ import torch -from transformer_engine.pytorch.ops.basic import BasicLinear, MakeExtraOutput -from transformer_engine.pytorch.ops.op import ( - FusedOperation, - FusibleOperation, - OperationContext, -) +from ...module.base import get_dummy_wgrad from ...utils import clear_tensor_data +from ..basic import BasicLinear, MakeExtraOutput +from ..op import FusedOperation, FusibleOperation, 
OperationContext class BackwardLinearAdd(FusedOperation): @@ -53,20 +50,22 @@ def fuser_backward( # Saved tensors from forward pass (x_local, w) = linear_op_ctx.saved_tensors - # wgrad fusion + # Megatron-LM wgrad fusion + # Note: Get grad tensor from param so we can accumulate + # directly into it. accumulate_into_main_grad = linear_op._accumulate_into_main_grad grad_weight = None if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: - if hasattr(linear_op.weight, "__fsdp_param__"): - linear_op.weight.main_grad = linear_op.weight.get_main_grad() - - if not hasattr(linear_op.weight, "main_grad"): + weight_param = linear_op.weight + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " "accumulate_into_main_grad=True, " "but weight parameter does not have main_grad attribute" ) - grad_weight = linear_op.weight.main_grad.detach() + grad_weight = weight_param.main_grad.detach() else: accumulate_into_main_grad = False @@ -92,12 +91,23 @@ def fuser_backward( grad_output_quantizer=linear_op_ctx.grad_output_quantizer, grad_input_quantizer=linear_op_ctx.grad_input_quantizer, ) - if accumulate_into_main_grad: - grad_weight = None # Clear input tensor if possible clear_tensor_data(x_local) + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. 
+ if accumulate_into_main_grad: + grad_weight = None + weight_param = linear_op.weight + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weight = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + return grad_input, [(grad_weight,), ()], [(), ()] diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py index 630a631576..a9595d5167 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py @@ -9,13 +9,10 @@ import torch -from ..basic import BasicLinear, ConstantScale -from ..op import ( - FusedOperation, - FusibleOperation, - OperationContext, -) +from ...module.base import get_dummy_wgrad from ...utils import clear_tensor_data +from ..basic import BasicLinear, ConstantScale +from ..op import FusedOperation, FusibleOperation, OperationContext class BackwardLinearScale(FusedOperation): @@ -54,20 +51,22 @@ def fuser_backward( # Saved tensors from forward pass (x_local, w) = linear_op_ctx.saved_tensors - # wgrad fusion + # Megatron-LM wgrad fusion + # Note: Get grad tensor from param so we can accumulate + # directly into it. 
accumulate_into_main_grad = linear_op._accumulate_into_main_grad grad_weight = None if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: - if hasattr(linear_op.weight, "__fsdp_param__"): - linear_op.weight.main_grad = linear_op.weight.get_main_grad() - - if not hasattr(linear_op.weight, "main_grad"): + weight_param = linear_op.weight + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " "accumulate_into_main_grad=True, " "but weight parameter does not have main_grad attribute" ) - grad_weight = linear_op.weight.main_grad.detach() + grad_weight = weight_param.main_grad.detach() else: accumulate_into_main_grad = False @@ -92,12 +91,23 @@ def fuser_backward( grad_output_quantizer=linear_op_ctx.grad_output_quantizer, grad_input_quantizer=linear_op_ctx.grad_input_quantizer, ) - if accumulate_into_main_grad: - grad_weight = None # Clear input tensor if possible clear_tensor_data(x_local) + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. 
+ if accumulate_into_main_grad: + grad_weight = None + weight_param = linear_op.weight + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weight = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + return grad_input, [(), (grad_weight,)], [(), ()] diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 54a4d49db6..c595325212 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -14,11 +14,12 @@ from ...cpp_extensions import general_gemm from ...distributed import get_distributed_world_size from ...module.base import ( + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, fill_userbuffers_buffer_for_all_gather, + get_dummy_wgrad, get_ub, get_workspace, - _2X_ACC_DGRAD, - _2X_ACC_WGRAD, ) from ...tensor.quantized_tensor import Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer @@ -513,20 +514,22 @@ def fuser_backward( # Saved tensors from forward pass (x_local, w) = linear_op_ctx.saved_tensors - # wgrad fusion + # Megatron-LM wgrad fusion + # Note: Get grad tensor from param so we can accumulate + # directly into it. 
accumulate_into_main_grad = linear_op._accumulate_into_main_grad grad_weight = None if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: - if hasattr(linear_op.weight, "__fsdp_param__"): - linear_op.weight.main_grad = linear_op.weight.get_main_grad() - - if not hasattr(linear_op.weight, "main_grad"): + weight_param = linear_op.weight + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " "accumulate_into_main_grad=True, " "but weight parameter does not have main_grad attribute" ) - grad_weight = linear_op.weight.main_grad.detach() + grad_weight = weight_param.main_grad.detach() else: accumulate_into_main_grad = False @@ -558,10 +561,21 @@ def fuser_backward( # Clear input tensor if possible clear_tensor_data(x_local) - # Return gradients - grad_params = [() for _ in range(len(self.basic_ops))] + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. if accumulate_into_main_grad: grad_weight = None + weight_param = linear_op.weight + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weight = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + + # Return gradients + grad_params = [() for _ in range(len(self.basic_ops))] grad_params[self._op_idxs["linear"]] = (grad_weight,) if bias_op is not None: grad_params[self._op_idxs["bias"]] = (grad_bias,) diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index 8686c18531..325126a3d4 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -54,7 +54,8 @@ class Linear(FusedOperation): weight's `main_grad` attribute instead of relying on PyTorch autograd. 
The weight's `main_grad` must be set externally and there is no guarantee that `grad` will be set or be - meaningful. + meaningful. This is primarily intended to integrate with + Megatron-LM. """