diff --git a/source/source_lcao/module_operator_lcao/td_nonlocal_lcao.cpp b/source/source_lcao/module_operator_lcao/td_nonlocal_lcao.cpp index a5eba57688..6f39ad039f 100644 --- a/source/source_lcao/module_operator_lcao/td_nonlocal_lcao.cpp +++ b/source/source_lcao/module_operator_lcao/td_nonlocal_lcao.cpp @@ -1,16 +1,21 @@ #include "td_nonlocal_lcao.h" -#include "source_io/module_parameter/parameter.h" #include "source_base/timer.h" #include "source_base/tool_title.h" #include "source_cell/module_neighbor/sltk_grid_driver.h" -#include "source_lcao/module_operator_lcao/operator_lcao.h" +#include "source_io/module_parameter/parameter.h" #include "source_lcao/module_hcontainer/hcontainer_funcs.h" +#include "source_lcao/module_operator_lcao/operator_lcao.h" #include "source_lcao/module_rt/snap_psibeta_half_tddft.h" +#ifdef __CUDA +#include "source_base/module_device/device.h" +#include "source_lcao/module_rt/kernels/snap_psibeta_gpu.h" +#endif + #include "source_pw/module_pwdft/global.h" #ifdef _OPENMP -#include #include +#include #endif template @@ -127,6 +132,27 @@ void hamilt::TDNonlocal>::calculate_HR() ModuleBase::TITLE("TDNonlocal", "calculate_HR"); ModuleBase::timer::tick("TDNonlocal", "calculate_HR"); + // Determine whether to use GPU path: + // GPU is only used when both __CUDA is defined AND device is set to "gpu" +#ifdef __CUDA + const bool use_gpu = (PARAM.inp.device == "gpu"); +#else + const bool use_gpu = false; +#endif + + // Initialize GPU resources if using GPU + if (use_gpu) + { +#ifdef __CUDA + // Use set_device_by_rank for multi-GPU support + int dev_id = 0; +#ifdef __MPI + dev_id = base_device::information::set_device_by_rank(MPI_COMM_WORLD); +#endif + module_rt::gpu::initialize_gpu_resources(); +#endif + } + const Parallel_Orbitals* paraV = this->hR_tmp->get_atom_pair(0).get_paraV(); const int npol = this->ucell->get_npol(); const int nlm_dim = TD_info::out_current ? 
4 : 1; @@ -145,9 +171,27 @@ void hamilt::TDNonlocal>::calculate_HR() nlm_tot[i].resize(nlm_dim); } - #pragma omp parallel + if (use_gpu) { - #pragma omp for schedule(dynamic) +#ifdef __CUDA + // GPU path: Atom-level GPU batch processing + module_rt::gpu::snap_psibeta_atom_batch_gpu(orb_, + this->ucell->infoNL, + T0, + tau0 * this->ucell->lat0, + cart_At, + adjs, + this->ucell, + paraV, + npol, + nlm_dim, + nlm_tot); +#endif + } + else + { + // CPU path: OpenMP parallel over neighbors to compute nlm_tot +#pragma omp parallel for schedule(dynamic) for (int ad = 0; ad < adjs.adj_num + 1; ++ad) { const int T1 = adjs.ntype[ad]; @@ -160,35 +204,36 @@ void hamilt::TDNonlocal>::calculate_HR() all_indexes.insert(all_indexes.end(), col_indexes.begin(), col_indexes.end()); std::sort(all_indexes.begin(), all_indexes.end()); all_indexes.erase(std::unique(all_indexes.begin(), all_indexes.end()), all_indexes.end()); - for (int iw1l = 0; iw1l < all_indexes.size(); iw1l += npol) + + // CPU path: loop over orbitals + for (size_t iw1l = 0; iw1l < all_indexes.size(); iw1l += npol) { const int iw1 = all_indexes[iw1l] / npol; std::vector>> nlm; - // nlm is a vector of vectors, but size of outer vector is only 1 when out_current is false - // and size of outer vector is 4 when out_current is true (3 for , 1 for - // ) inner loop : all projectors (L0,M0) - - // snap_psibeta_half_tddft() are used to calculate - // and as well if current are needed module_rt::snap_psibeta_half_tddft(orb_, - this->ucell->infoNL, - nlm, - tau1 * this->ucell->lat0, - T1, - atom1->iw2l[iw1], - atom1->iw2m[iw1], - atom1->iw2n[iw1], - tau0 * this->ucell->lat0, - T0, - cart_At, - TD_info::out_current); + this->ucell->infoNL, + nlm, + tau1 * this->ucell->lat0, + T1, + atom1->iw2l[iw1], + atom1->iw2m[iw1], + atom1->iw2n[iw1], + tau0 * this->ucell->lat0, + T0, + cart_At, + TD_info::out_current); for (int dir = 0; dir < nlm_dim; dir++) { nlm_tot[ad][dir].insert({all_indexes[iw1l], nlm[dir]}); } } } + } + // 2. calculate D for each pair of atoms + // This runs for BOTH GPU and CPU paths +#pragma omp parallel + { #ifdef _OPENMP // record the iat number of the adjacent atoms std::set ad_atom_set; @@ -205,7 +250,7 @@ void hamilt::TDNonlocal>::calculate_HR() const int thread_id = omp_get_thread_num(); std::set ad_atom_set_thread; int i = 0; - for(const auto iat1 : ad_atom_set) + for (const auto iat1: ad_atom_set) { if (i % num_threads == thread_id) { @@ -215,7 +260,6 @@ void hamilt::TDNonlocal>::calculate_HR() } #endif - // 2. 
calculate D for each pair of atoms for (int ad1 = 0; ad1 < adjs.adj_num + 1; ++ad1) { const int T1 = adjs.ntype[ad1]; @@ -228,7 +272,7 @@ void hamilt::TDNonlocal>::calculate_HR() continue; } #endif - + const ModuleBase::Vector3& R_index1 = adjs.box[ad1]; for (int ad2 = 0; ad2 < adjs.adj_num + 1; ++ad2) { @@ -247,9 +291,9 @@ void hamilt::TDNonlocal>::calculate_HR() if (TD_info::out_current) { std::complex* tmp_c[3] = {nullptr, nullptr, nullptr}; - for (int i = 0; i < 3; i++) + for (int ii = 0; ii < 3; ii++) { - tmp_c[i] = TD_info::td_vel_op->get_current_term_pointer(i) + tmp_c[ii] = TD_info::td_vel_op->get_current_term_pointer(ii) ->find_matrix(iat1, iat2, R_vector[0], R_vector[1], R_vector[2]) ->get_pointer(); } @@ -276,13 +320,13 @@ void hamilt::TDNonlocal>::calculate_HR() } } } - } - } - + } // end omp parallel for matrix assembly + } // end for iat0 ModuleBase::timer::tick("TDNonlocal", "calculate_HR"); } // cal_HR_IJR() + template void hamilt::TDNonlocal>::cal_HR_IJR( const int& iat1, @@ -396,7 +440,6 @@ void hamilt::TDNonlocal>::set_HR_fixed(void* hR_tmp this->allocated = false; } - // contributeHR() template void hamilt::TDNonlocal>::contributeHR() @@ -436,7 +479,6 @@ void hamilt::TDNonlocal>::contributeHR() return; } - template void hamilt::TDNonlocal>::contributeHk(int ik) { diff --git a/source/source_lcao/module_rt/CMakeLists.txt b/source/source_lcao/module_rt/CMakeLists.txt index cca79af03a..35a65a5aa7 100644 --- a/source/source_lcao/module_rt/CMakeLists.txt +++ b/source/source_lcao/module_rt/CMakeLists.txt @@ -18,6 +18,13 @@ if(ENABLE_LCAO) boundary_fix.cpp ) + if(USE_CUDA) + list(APPEND objects + kernels/cuda/snap_psibeta_kernel.cu + kernels/cuda/snap_psibeta_gpu.cu + ) + endif() + add_library( tddft OBJECT diff --git a/source/source_lcao/module_rt/kernels/cuda/snap_psibeta_gpu.cu b/source/source_lcao/module_rt/kernels/cuda/snap_psibeta_gpu.cu new file mode 100644 index 0000000000..cb381ece78 --- /dev/null +++ b/source/source_lcao/module_rt/kernels/cuda/snap_psibeta_gpu.cu @@ -0,0 +1,380 @@ +/** + * @file snap_psibeta_gpu.cu + * @brief Host-side GPU interface for overlap computation + * + * This file provides the high-level interface for GPU-accelerated computation + * of overlap integrals between atomic orbitals (psi) and non-local projectors + * (beta). It handles: + * - GPU resource initialization and cleanup + * - Data marshalling from ABACUS structures to GPU-friendly formats + * - Kernel launch configuration + * - Result unpacking back to ABACUS data structures + */ + +#include "../snap_psibeta_gpu.h" +#include "snap_psibeta_kernel.cuh" +#include "source_base/timer.h" +#include "source_base/tool_quit.h" + +#include +#include +#include +#include + +namespace module_rt +{ +namespace gpu +{ + +//============================================================================= +// GPU Resource Management +//============================================================================= + +/** + * @brief Initialize GPU resources for snap_psibeta computation + * + * Checks for available CUDA devices and copies integration grids + * (Lebedev-Laikov angular and Gauss-Legendre radial) to constant memory. + * + * @note Call this once at the start of a calculation session before any + * snap_psibeta_atom_batch_gpu calls. 
+ */ +void initialize_gpu_resources() +{ + // Verify CUDA device availability + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) + { + ModuleBase::WARNING_QUIT("snap_psibeta_gpu", "No CUDA devices found or error getting device count!"); + } + + // Initialize integration grids in constant memory + copy_grids_to_device(); + + // Synchronize to ensure initialization is complete + cudaDeviceSynchronize(); +} + +//============================================================================= +// Internal Helper Structures +//============================================================================= + +/** + * @brief Mapping structure for reconstructing output data + * + * Associates each orbital in the flattened GPU array with its original + * neighbor and orbital indices for proper result placement. + */ +struct OrbitalMapping +{ + int neighbor_idx; ///< Index of neighbor atom in adjacency list + int iw_index; ///< Global orbital index for output mapping +}; + +//============================================================================= +// Main GPU Interface Function +//============================================================================= + +/** + * @brief Compute overlap integrals on GPU + * + * This function processes ALL neighbor atoms for a single center atom (where + * the projectors are located) in a single kernel launch, providing significant + * performance improvement over per-neighbor processing. + * + * Workflow: + * 1. Collect all (neighbor, orbital) pairs into flattened arrays + * 2. Prepare projector data for the center atom + * 3. Transfer data to GPU and launch kernel + * 4. Retrieve results and reconstruct nlm_tot structure + * + * @param orb LCAO orbital information + * @param infoNL_ Non-local projector information + * @param T0 Atom type of center atom (projector location) + * @param R0 Position of center atom + * @param A Vector potential for phase factor + * @param adjs Adjacent atom information + * @param ucell Unit cell information + * @param paraV Parallel orbital distribution + * @param npol Number of spin polarizations + * @param nlm_dim Output dimension (1 for overlap only, 4 for overlap + current) + * @param nlm_tot Output: overlap integrals indexed as [neighbor][direction][orbital] + */ +void snap_psibeta_atom_batch_gpu( + const LCAO_Orbitals& orb, + const InfoNonlocal& infoNL_, + const int T0, + const ModuleBase::Vector3& R0, + const ModuleBase::Vector3& A, + const AdjacentAtomInfo& adjs, + const UnitCell* ucell, + const Parallel_Orbitals* paraV, + const int npol, + const int nlm_dim, + std::vector>>>>& nlm_tot) +{ + ModuleBase::timer::tick("module_rt", "snap_psibeta_gpu"); + + //========================================================================= + // Early exit if no projectors on center atom + //========================================================================= + + const int nproj = infoNL_.nproj[T0]; + if (nproj == 0) + { + ModuleBase::timer::tick("module_rt", "snap_psibeta_gpu"); + return; + } + + //========================================================================= + // Compute projector output indices + //========================================================================= + + int natomwfc = 0; // Total number of projector components + std::vector proj_m0_offset_h(nproj); + + for (int ip = 0; ip < nproj; ip++) + { + proj_m0_offset_h[ip] = natomwfc; + int L0 = infoNL_.Beta[T0].Proj[ip].getL(); + + // Validate angular momentum + if (L0 > MAX_L) + { 
+ ModuleBase::WARNING_QUIT("snap_psibeta_gpu", + "L0=" + std::to_string(L0) + " exceeds MAX_L=" + std::to_string(MAX_L)); + } + natomwfc += 2 * L0 + 1; + } + + //========================================================================= + // Collect all (neighbor, orbital) pairs + //========================================================================= + + std::vector neighbor_orbitals_h; + std::vector psi_radial_h; + std::vector orbital_mappings; + + for (int ad = 0; ad < adjs.adj_num + 1; ++ad) + { + const int T1 = adjs.ntype[ad]; + const int I1 = adjs.natom[ad]; + const int iat1 = ucell->itia2iat(T1, I1); + const ModuleBase::Vector3& tau1 = adjs.adjacent_tau[ad]; + const Atom* atom1 = &ucell->atoms[T1]; + + // Get unique orbital indices (union of row and column indices) + auto all_indexes = paraV->get_indexes_row(iat1); + auto col_indexes = paraV->get_indexes_col(iat1); + all_indexes.insert(all_indexes.end(), col_indexes.begin(), col_indexes.end()); + std::sort(all_indexes.begin(), all_indexes.end()); + all_indexes.erase(std::unique(all_indexes.begin(), all_indexes.end()), all_indexes.end()); + + // Process each orbital + for (size_t iw1l = 0; iw1l < all_indexes.size(); iw1l += npol) + { + const int iw1 = all_indexes[iw1l] / npol; + const int L1 = atom1->iw2l[iw1]; + const int m1 = atom1->iw2m[iw1]; + const int N1 = atom1->iw2n[iw1]; + + // Skip orbitals with angular momentum beyond supported limit + if (L1 > MAX_L) + { + continue; + } + + // Get orbital radial function (use getPsi(), not getPsi_r()) + const double* phi_psi = orb.Phi[T1].PhiLN(L1, N1).getPsi(); + int mesh = orb.Phi[T1].PhiLN(L1, N1).getNr(); + double dk = orb.Phi[T1].PhiLN(L1, N1).getDk(); + double rcut = orb.Phi[T1].getRcut(); + + // Append to flattened psi array + size_t psi_offset = psi_radial_h.size(); + psi_radial_h.insert(psi_radial_h.end(), phi_psi, phi_psi + mesh); + + // Create neighbor-orbital data + NeighborOrbitalData norb; + norb.neighbor_idx = ad; + norb.R1 = make_double3(tau1.x * ucell->lat0, tau1.y * ucell->lat0, tau1.z * ucell->lat0); + norb.L1 = L1; + norb.m1 = m1; + norb.N1 = N1; + norb.iw_index = all_indexes[iw1l]; + norb.psi_offset = static_cast(psi_offset); + norb.psi_mesh = mesh; + norb.psi_dk = dk; + norb.psi_rcut = rcut; + + neighbor_orbitals_h.push_back(norb); + + // Track mapping for result reconstruction + OrbitalMapping mapping; + mapping.neighbor_idx = ad; + mapping.iw_index = all_indexes[iw1l]; + orbital_mappings.push_back(mapping); + } + } + + int total_neighbor_orbitals = static_cast(neighbor_orbitals_h.size()); + if (total_neighbor_orbitals == 0) + { + ModuleBase::timer::tick("module_rt", "snap_psibeta_gpu"); + return; + } + + //========================================================================= + // Prepare projector data + //========================================================================= + + std::vector projectors_h(nproj); + std::vector beta_radial_h; + + for (int ip = 0; ip < nproj; ip++) + { + const auto& proj = infoNL_.Beta[T0].Proj[ip]; + int L0 = proj.getL(); + int mesh = proj.getNr(); + double dk = proj.getDk(); + double rcut = proj.getRcut(); + const double* beta_r = proj.getBeta_r(); + const double* radial = proj.getRadial(); + + projectors_h[ip].L0 = L0; + projectors_h[ip].beta_offset = static_cast(beta_radial_h.size()); + projectors_h[ip].beta_mesh = mesh; + projectors_h[ip].beta_dk = dk; + projectors_h[ip].beta_rcut = rcut; + projectors_h[ip].r_min = radial[0]; + projectors_h[ip].r_max = radial[mesh - 1]; + + beta_radial_h.insert(beta_radial_h.end(), 
beta_r, beta_r + mesh); + } + + //========================================================================= + // Allocate GPU memory + //========================================================================= + + NeighborOrbitalData* neighbor_orbitals_d = nullptr; + ProjectorData* projectors_d = nullptr; + double* psi_radial_d = nullptr; + double* beta_radial_d = nullptr; + int* proj_m0_offset_d = nullptr; + cuDoubleComplex* nlm_out_d = nullptr; + + size_t output_size = total_neighbor_orbitals * nlm_dim * natomwfc; + + CUDA_CHECK(cudaMalloc(&neighbor_orbitals_d, total_neighbor_orbitals * sizeof(NeighborOrbitalData))); + CUDA_CHECK(cudaMalloc(&projectors_d, nproj * sizeof(ProjectorData))); + CUDA_CHECK(cudaMalloc(&psi_radial_d, psi_radial_h.size() * sizeof(double))); + CUDA_CHECK(cudaMalloc(&beta_radial_d, beta_radial_h.size() * sizeof(double))); + CUDA_CHECK(cudaMalloc(&proj_m0_offset_d, nproj * sizeof(int))); + CUDA_CHECK(cudaMalloc(&nlm_out_d, output_size * sizeof(cuDoubleComplex))); + + //========================================================================= + // Transfer data to GPU + //========================================================================= + + CUDA_CHECK(cudaMemcpy(neighbor_orbitals_d, + neighbor_orbitals_h.data(), + total_neighbor_orbitals * sizeof(NeighborOrbitalData), + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(projectors_d, projectors_h.data(), nproj * sizeof(ProjectorData), cudaMemcpyHostToDevice)); + CUDA_CHECK( + cudaMemcpy(psi_radial_d, psi_radial_h.data(), psi_radial_h.size() * sizeof(double), cudaMemcpyHostToDevice)); + CUDA_CHECK( + cudaMemcpy(beta_radial_d, beta_radial_h.data(), beta_radial_h.size() * sizeof(double), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(proj_m0_offset_d, proj_m0_offset_h.data(), nproj * sizeof(int), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemset(nlm_out_d, 0, output_size * sizeof(cuDoubleComplex))); + + //========================================================================= + // Launch kernel + //========================================================================= + + double3 R0_d3 = make_double3(R0.x, R0.y, R0.z); + double3 A_d3 = make_double3(A.x, A.y, A.z); + + dim3 grid(total_neighbor_orbitals, nproj, 1); + dim3 block(BLOCK_SIZE, 1, 1); + + snap_psibeta_atom_batch_kernel<<>>(R0_d3, + A_d3, + neighbor_orbitals_d, + projectors_d, + psi_radial_d, + beta_radial_d, + proj_m0_offset_d, + total_neighbor_orbitals, + nproj, + natomwfc, + nlm_dim, + nlm_out_d); + + // Check for launch errors + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + cudaFree(neighbor_orbitals_d); + cudaFree(projectors_d); + cudaFree(psi_radial_d); + cudaFree(beta_radial_d); + cudaFree(proj_m0_offset_d); + cudaFree(nlm_out_d); + ModuleBase::WARNING_QUIT("snap_psibeta_gpu", + std::string("Atom batch kernel launch error: ") + cudaGetErrorString(err)); + } + + CUDA_CHECK(cudaDeviceSynchronize()); + + //========================================================================= + // Retrieve results + //========================================================================= + + std::vector nlm_out_h(output_size); + CUDA_CHECK(cudaMemcpy(nlm_out_h.data(), nlm_out_d, output_size * sizeof(cuDoubleComplex), cudaMemcpyDeviceToHost)); + + //========================================================================= + // Reconstruct output structure + //========================================================================= + + for (int i = 0; i < total_neighbor_orbitals; i++) + { + int ad = 
orbital_mappings[i].neighbor_idx; + int iw_index = orbital_mappings[i].iw_index; + + std::vector>> nlm(nlm_dim); + for (int d = 0; d < nlm_dim; d++) + { + nlm[d].resize(natomwfc); + for (int k = 0; k < natomwfc; k++) + { + size_t idx = i * nlm_dim * natomwfc + d * natomwfc + k; + nlm[d][k] = std::complex(nlm_out_h[idx].x, nlm_out_h[idx].y); + } + } + + // Insert into nlm_tot[neighbor][direction][orbital] + for (int dir = 0; dir < nlm_dim; dir++) + { + nlm_tot[ad][dir].insert({iw_index, nlm[dir]}); + } + } + + //========================================================================= + // Cleanup GPU memory + //========================================================================= + + cudaFree(neighbor_orbitals_d); + cudaFree(projectors_d); + cudaFree(psi_radial_d); + cudaFree(beta_radial_d); + cudaFree(proj_m0_offset_d); + cudaFree(nlm_out_d); + + ModuleBase::timer::tick("module_rt", "snap_psibeta_gpu"); +} + +} // namespace gpu +} // namespace module_rt diff --git a/source/source_lcao/module_rt/kernels/cuda/snap_psibeta_kernel.cu b/source/source_lcao/module_rt/kernels/cuda/snap_psibeta_kernel.cu new file mode 100644 index 0000000000..60ea9ad26b --- /dev/null +++ b/source/source_lcao/module_rt/kernels/cuda/snap_psibeta_kernel.cu @@ -0,0 +1,567 @@ +/** + * @file snap_psibeta_kernel.cu + * @brief CUDA kernel implementation for overlap integrals + * + * This file implements the GPU-accelerated numerical integration for computing + * overlap integrals between atomic orbitals (psi) and non-local projectors (beta). + * The implementation uses: + * - Lebedev-Laikov quadrature (110 points) for angular integration + * - Gauss-Legendre quadrature (140 points) for radial integration + * - Templated spherical harmonics with compile-time L for optimization + * - Warp-level shuffle reduction for efficient parallel summation + */ + +#include "snap_psibeta_kernel.cuh" +#include "source_base/constants.h" +#include "source_base/math_integral.h" + +#include +#include + +namespace module_rt +{ +namespace gpu +{ + +//============================================================================= +// Constant Memory - Integration Grids +//============================================================================= + +// Lebedev-Laikov angular quadrature grid (110 points) +__constant__ double d_lebedev_x[ANGULAR_GRID_NUM]; ///< x-direction cosines +__constant__ double d_lebedev_y[ANGULAR_GRID_NUM]; ///< y-direction cosines +__constant__ double d_lebedev_z[ANGULAR_GRID_NUM]; ///< z-direction cosines +__constant__ double d_lebedev_w[ANGULAR_GRID_NUM]; ///< Angular integration weights + +// Gauss-Legendre radial quadrature grid (140 points) +__constant__ double d_gl_x[RADIAL_GRID_NUM]; ///< Quadrature abscissae on [-1, 1] +__constant__ double d_gl_w[RADIAL_GRID_NUM]; ///< Quadrature weights + +//============================================================================= +// Spherical Harmonics - Helper Functions +//============================================================================= + +/** + * @brief Access element in lower-triangular stored Legendre polynomial array + * + * For associated Legendre polynomials P_l^m, we only need 0 <= m <= l. + * Storage layout: P_0^0, P_1^0, P_1^1, P_2^0, P_2^1, P_2^2, ... 
+ * Linear index: l*(l+1)/2 + m + */ +__device__ __forceinline__ double& p_access(double* p, int l, int m) +{ + return p[l * (l + 1) / 2 + m]; +} + +/** + * @brief Read-only access to Legendre polynomial array + */ +__device__ __forceinline__ double p_get(const double* p, int l, int m) +{ + return p[l * (l + 1) / 2 + m]; +} + +//============================================================================= +// Spherical Harmonics - Main Implementation +//============================================================================= + +/** + * @brief Compute real spherical harmonics Y_lm (templated version) + * + * Uses the recursive computation of associated Legendre polynomials: + * P_l^m = ((2l-1)*cos(theta)*P_{l-1}^m - (l-1+m)*P_{l-2}^m) / (l-m) + * P_l^{l-1} = (2l-1)*cos(theta)*P_{l-1}^{l-1} + * P_l^l = (-1)^l * (2l-1)!! * sin^l(theta) + * + * Real spherical harmonics are defined as: + * Y_{lm} = c_l * P_l^0 for m = 0 + * Y_{l,2m-1} = c_l * sqrt(2/(l-m)!/(l+m)!) * P_l^m * cos(m*phi) for m > 0 + * Y_{l,2m} = c_l * sqrt(2/(l-m)!/(l+m)!) * P_l^m * sin(m*phi) for m > 0 + * where c_l = sqrt((2l+1)/(4*pi)) + * + * @tparam L Maximum angular momentum (compile-time constant) + * @param x, y, z Direction vector components (need not be normalized) + * @param ylm Output array storing Y_lm values in order: Y_00, Y_10, Y_11c, Y_11s, ... + */ +template +__device__ void compute_ylm_gpu(double x, double y, double z, double* ylm) +{ + + constexpr int P_SIZE = (L + 1) * (L + 2) / 2; // Lower triangular storage size + + // Y_00 = 1/(2*sqrt(pi)) + ylm[0] = 0.5 * sqrt(1.0 / ModuleBase::PI); + + if (L == 0) + { + return; + } + + // Compute spherical angles + double r2 = x * x + y * y + z * z; + double r = sqrt(r2); + + double cost, sint, phi; + if (r < 1e-10) + { + // At origin, default to z-axis direction + cost = 1.0; + sint = 0.0; + phi = 0.0; + } + else + { + cost = z / r; + sint = sqrt(1.0 - cost * cost); + phi = atan2(y, x); + } + + // Ensure sint is non-negative (numerical safety) + if (sint < 0.0) + { + sint = 0.0; + } + + // Associated Legendre polynomials P_l^m in lower-triangular storage + double p[P_SIZE]; + + // Base cases + p_access(p, 0, 0) = 1.0; + + if (L >= 1) + { + p_access(p, 1, 0) = cost; // P_1^0 = cos(theta) + p_access(p, 1, 1) = -sint; // P_1^1 = -sin(theta) + } + + // Recurrence relations for l >= 2 +#pragma unroll + for (int l = 2; l <= L; l++) + { + // P_l^m for m = 0 to l-2: standard recurrence +#pragma unroll + for (int m = 0; m <= l - 2; m++) + { + p_access(p, l, m) = ((2 * l - 1) * cost * p_get(p, l - 1, m) - (l - 1 + m) * p_get(p, l - 2, m)) + / static_cast(l - m); + } + + // P_l^{l-1} = (2l-1) * cos(theta) * P_{l-1}^{l-1} + p_access(p, l, l - 1) = (2 * l - 1) * cost * p_get(p, l - 1, l - 1); + + // P_l^l = (-1)^l * (2l-1)!! 
* sin^l(theta) + double double_factorial = 1.0; +#pragma unroll + for (int i = 1; i <= 2 * l - 1; i += 2) + { + double_factorial *= i; + } + + double sint_power = 1.0; +#pragma unroll + for (int i = 0; i < l; i++) + { + sint_power *= sint; + } + + p_access(p, l, l) = double_factorial * sint_power; + if (l % 2 == 1) + { + p_access(p, l, l) = -p_access(p, l, l); + } + } + + // Transform Legendre polynomials to real spherical harmonics + int lm = 0; +#pragma unroll + for (int l = 0; l <= L; l++) + { + double c = sqrt((2.0 * l + 1.0) / ModuleBase::FOUR_PI); + + // m = 0 component + ylm[lm] = c * p_get(p, l, 0); + lm++; + + // m > 0 components (cosine and sine parts) +#pragma unroll + for (int m = 1; m <= l; m++) + { + // Compute normalization factor: sqrt(2 * (l-m)! / (l+m)!) + double factorial_ratio = 1.0; +#pragma unroll + for (int i = l - m + 1; i <= l + m; i++) + { + factorial_ratio *= i; + } + double norm = c * sqrt(1.0 / factorial_ratio) * ModuleBase::SQRT2; + + double sin_mphi, cos_mphi; + sincos(m * phi, &sin_mphi, &cos_mphi); + + ylm[lm] = norm * p_get(p, l, m) * cos_mphi; // Y_{l,m} cosine part + lm++; + + ylm[lm] = norm * p_get(p, l, m) * sin_mphi; // Y_{l,m} sine part + lm++; + } + } +} + +// Explicit template instantiations for L = 0, 1, 2, 3, 4 +template __device__ void compute_ylm_gpu<0>(double x, double y, double z, double* ylm); +template __device__ void compute_ylm_gpu<1>(double x, double y, double z, double* ylm); +template __device__ void compute_ylm_gpu<2>(double x, double y, double z, double* ylm); +template __device__ void compute_ylm_gpu<3>(double x, double y, double z, double* ylm); +template __device__ void compute_ylm_gpu<4>(double x, double y, double z, double* ylm); + +//============================================================================= +// Warp-Level Reduction +//============================================================================= + +/** + * @brief Warp-level sum reduction using shuffle instructions + * + * Performs a parallel reduction within a warp (32 threads) using __shfl_down_sync. + * After this function, lane 0 contains the sum of all input values in the warp. + * + * @param val Input value from each thread + * @return Sum across all threads in the warp (valid only in lane 0) + */ +__device__ __forceinline__ double warp_reduce_sum(double val) +{ + for (int offset = 16; offset > 0; offset /= 2) + { + val += __shfl_down_sync(0xffffffff, val, offset); + } + return val; +} + +//============================================================================= +// Main Kernel Implementation +//============================================================================= + +/** + * @brief Atom-level batch kernel for overlap integrals + * + * Integration is performed using restructured loops for efficiency: + * - Outer loop: angular points (each thread handles different angles) + * - Inner loop: radial points (each thread accumulates all radii) + * + * This structure exploits the fact that Y_lm for the projector (ylm0) only + * depends on the angular direction, not the radial distance, saving + * RADIAL_GRID_NUM redundant ylm0 computations per angular point. 
+ */ +__global__ void snap_psibeta_atom_batch_kernel(double3 R0, + double3 A, + const NeighborOrbitalData* __restrict__ neighbor_orbitals, + const ProjectorData* __restrict__ projectors, + const double* __restrict__ psi_radial, + const double* __restrict__ beta_radial, + const int* __restrict__ proj_m0_offset, + int total_neighbor_orbitals, + int nproj, + int natomwfc, + int nlm_dim, + cuDoubleComplex* __restrict__ nlm_out) +{ + // Thread/block indices + const int norb_idx = blockIdx.x; // Which (neighbor, orbital) pair + const int proj_idx = blockIdx.y; // Which projector + const int tid = threadIdx.x; + + // Early exit for out-of-bounds blocks + if (norb_idx >= total_neighbor_orbitals || proj_idx >= nproj) + { + return; + } + + //------------------------------------------------------------------------- + // Load input data + //------------------------------------------------------------------------- + + const NeighborOrbitalData& norb = neighbor_orbitals[norb_idx]; + const ProjectorData& proj = projectors[proj_idx]; + + const double3 R1 = norb.R1; + const int L1 = norb.L1; + const int m1 = norb.m1; + const int L0 = proj.L0; + const int m0_offset = proj_m0_offset[proj_idx]; + + // Skip if angular momentum exceeds supported limit + if (L1 > MAX_L || L0 > MAX_L) + { + return; + } + + //------------------------------------------------------------------------- + // Compute geometry + //------------------------------------------------------------------------- + + // Note: dR (R1 - R0) is computed inline as dRx/dRy/dRz in the integration loop + + // Orbital cutoff + const double r1_max = norb.psi_rcut; + + // Integration range from projector radial grid + const double r_min = proj.r_min; + const double r_max = proj.r_max; + const double xl = 0.5 * (r_max - r_min); // Half-range for Gauss-Legendre + const double xmean = 0.5 * (r_max + r_min); // Midpoint + + // Phase factor exp(i * A · R0) + const double AdotR0 = A.x * R0.x + A.y * R0.y + A.z * R0.z; + const cuDoubleComplex exp_iAR0 = cu_exp_i(AdotR0); + + //------------------------------------------------------------------------- + // Shared memory for warp reduction + //------------------------------------------------------------------------- + + constexpr int NUM_WARPS = BLOCK_SIZE / 32; // 128 / 32 = 4 warps + __shared__ double s_temp_re[NUM_WARPS]; + __shared__ double s_temp_im[NUM_WARPS]; + + //------------------------------------------------------------------------- + // Initialize accumulators (per-thread registers) + //------------------------------------------------------------------------- + + const int num_m0 = 2 * L0 + 1; + + double result_re[MAX_M0_SIZE]; + double result_im[MAX_M0_SIZE]; + double result_r_re[3][MAX_M0_SIZE]; // For current operator: x, y, z components + double result_r_im[3][MAX_M0_SIZE]; + + for (int m0 = 0; m0 < num_m0; m0++) + { + result_re[m0] = 0.0; + result_im[m0] = 0.0; + for (int d = 0; d < 3; d++) + { + result_r_re[d][m0] = 0.0; + result_r_im[d][m0] = 0.0; + } + } + + //------------------------------------------------------------------------- + // Main integration loop + // Outer: angular points (parallelized across threads) + // Inner: radial points (accumulated per thread) + //------------------------------------------------------------------------- + + for (int ian = tid; ian < ANGULAR_GRID_NUM; ian += BLOCK_SIZE) + { + // Load angular grid point + const double leb_x = d_lebedev_x[ian]; + const double leb_y = d_lebedev_y[ian]; + const double leb_z = d_lebedev_z[ian]; + const double w_ang = 
d_lebedev_w[ian]; + + // Precompute Y_lm for projector (independent of radial distance) + double ylm0[MAX_YLM_SIZE]; + DISPATCH_YLM(L0, leb_x, leb_y, leb_z, ylm0); + const int offset_L0 = L0 * L0; + + // Precompute A · direction (for phase factor) + const double A_dot_leb = A.x * leb_x + A.y * leb_y + A.z * leb_z; + + // Vector from R1 to R0 (for computing distance to orbital center) + const double dRx = R0.x - R1.x; + const double dRy = R0.y - R1.y; + const double dRz = R0.z - R1.z; + + // Radial integration +#pragma unroll 4 + for (int ir = 0; ir < RADIAL_GRID_NUM; ir++) + { + // Transform Gauss-Legendre point from [-1,1] to [r_min, r_max] + const double r_val = xmean + xl * d_gl_x[ir]; + const double w_rad = xl * d_gl_w[ir]; + + // Integration point position relative to R0 + const double rx = r_val * leb_x; + const double ry = r_val * leb_y; + const double rz = r_val * leb_z; + + // Vector from R1 to integration point + const double tx = rx + dRx; + const double ty = ry + dRy; + const double tz = rz + dRz; + const double tnorm = sqrt(tx * tx + ty * ty + tz * tz); + + // Check if within orbital cutoff + if (tnorm <= r1_max) + { + // Compute Y_lm for orbital (depends on direction from R1) + double ylm1[MAX_YLM_SIZE]; + if (tnorm > 1e-10) + { + const double inv_tnorm = 1.0 / tnorm; + DISPATCH_YLM(L1, tx * inv_tnorm, ty * inv_tnorm, tz * inv_tnorm, ylm1); + } + else + { + DISPATCH_YLM(L1, 0.0, 0.0, 1.0, ylm1); + } + + // Interpolate orbital radial function + const double psi_val + = interpolate_radial_gpu(psi_radial + norb.psi_offset, norb.psi_mesh, 1.0 / norb.psi_dk, tnorm); + + // Interpolate projector radial function + const double beta_val + = interpolate_radial_gpu(beta_radial + proj.beta_offset, proj.beta_mesh, 1.0 / proj.beta_dk, r_val); + + // Phase factor exp(i * A · r) + const double phase = r_val * A_dot_leb; + const cuDoubleComplex exp_iAr = cu_exp_i(phase); + + // Orbital Y_lm value + const double ylm_L1_val = ylm1[L1 * L1 + m1]; + + // Combined integration factor: Y_L1m1 * psi * beta * r * dr * dOmega + const double factor = ylm_L1_val * psi_val * beta_val * r_val * w_rad * w_ang; + const cuDoubleComplex common_factor = cu_mul_real(exp_iAr, factor); + + // Accumulate for all m0 components of projector +#pragma unroll + for (int m0 = 0; m0 < num_m0; m0++) + { + const double ylm0_val = ylm0[offset_L0 + m0]; + + result_re[m0] += common_factor.x * ylm0_val; + result_im[m0] += common_factor.y * ylm0_val; + + // Current operator contribution (if requested) + if (nlm_dim == 4) + { + const double r_op_x = rx + R0.x; + const double r_op_y = ry + R0.y; + const double r_op_z = rz + R0.z; + + result_r_re[0][m0] += common_factor.x * ylm0_val * r_op_x; + result_r_im[0][m0] += common_factor.y * ylm0_val * r_op_x; + result_r_re[1][m0] += common_factor.x * ylm0_val * r_op_y; + result_r_im[1][m0] += common_factor.y * ylm0_val * r_op_y; + result_r_re[2][m0] += common_factor.x * ylm0_val * r_op_z; + result_r_im[2][m0] += common_factor.y * ylm0_val * r_op_z; + } + } + } + } // End radial loop + } // End angular loop + + //------------------------------------------------------------------------- + // Parallel reduction and output + // Uses warp shuffle for efficiency, followed by cross-warp reduction + //------------------------------------------------------------------------- + + const int out_base = norb_idx * nlm_dim * natomwfc; + const int warp_id = tid / 32; + const int lane_id = tid % 32; + + for (int m0 = 0; m0 < num_m0; m0++) + { + // Step 1: Warp-level reduction using shuffle + double 
sum_re = warp_reduce_sum(result_re[m0]); + double sum_im = warp_reduce_sum(result_im[m0]); + + // Step 2: First lane of each warp writes to shared memory + if (lane_id == 0) + { + s_temp_re[warp_id] = sum_re; + s_temp_im[warp_id] = sum_im; + } + __syncthreads(); + + // Step 3: First warp reduces across all warps and writes output + if (warp_id == 0) + { + sum_re = (lane_id < NUM_WARPS) ? s_temp_re[lane_id] : 0.0; + sum_im = (lane_id < NUM_WARPS) ? s_temp_im[lane_id] : 0.0; + sum_re = warp_reduce_sum(sum_re); + sum_im = warp_reduce_sum(sum_im); + + if (lane_id == 0) + { + cuDoubleComplex result = make_cuDoubleComplex(sum_re, sum_im); + result = cu_mul(result, exp_iAR0); + result = cu_conj(result); + nlm_out[out_base + 0 * natomwfc + m0_offset + m0] = result; + } + } + __syncthreads(); + + // Process current operator components (if nlm_dim == 4) + if (nlm_dim == 4) + { + for (int d = 0; d < 3; d++) + { + double sum_r_re = warp_reduce_sum(result_r_re[d][m0]); + double sum_r_im = warp_reduce_sum(result_r_im[d][m0]); + + if (lane_id == 0) + { + s_temp_re[warp_id] = sum_r_re; + s_temp_im[warp_id] = sum_r_im; + } + __syncthreads(); + + if (warp_id == 0) + { + sum_r_re = (lane_id < NUM_WARPS) ? s_temp_re[lane_id] : 0.0; + sum_r_im = (lane_id < NUM_WARPS) ? s_temp_im[lane_id] : 0.0; + sum_r_re = warp_reduce_sum(sum_r_re); + sum_r_im = warp_reduce_sum(sum_r_im); + + if (lane_id == 0) + { + cuDoubleComplex result_r = make_cuDoubleComplex(sum_r_re, sum_r_im); + result_r = cu_mul(result_r, exp_iAR0); + result_r = cu_conj(result_r); + nlm_out[out_base + (d + 1) * natomwfc + m0_offset + m0] = result_r; + } + } + __syncthreads(); + } + } + } +} + +//============================================================================= +// Host-side Helper Functions +//============================================================================= + +/** + * @brief Copy integration grids to GPU constant memory + * + * Initializes the constant memory arrays with Lebedev-Laikov angular grid + * and Gauss-Legendre radial grid for use in kernel integration. 
+ */ +void copy_grids_to_device() +{ + // Copy Lebedev-Laikov 110-point angular quadrature grid + CUDA_CHECK(cudaMemcpyToSymbol(d_lebedev_x, + ModuleBase::Integral::Lebedev_Laikov_grid110_x, + ANGULAR_GRID_NUM * sizeof(double))); + CUDA_CHECK(cudaMemcpyToSymbol(d_lebedev_y, + ModuleBase::Integral::Lebedev_Laikov_grid110_y, + ANGULAR_GRID_NUM * sizeof(double))); + CUDA_CHECK(cudaMemcpyToSymbol(d_lebedev_z, + ModuleBase::Integral::Lebedev_Laikov_grid110_z, + ANGULAR_GRID_NUM * sizeof(double))); + CUDA_CHECK(cudaMemcpyToSymbol(d_lebedev_w, + ModuleBase::Integral::Lebedev_Laikov_grid110_w, + ANGULAR_GRID_NUM * sizeof(double))); + + // Compute and copy Gauss-Legendre radial quadrature grid + std::vector h_gl_x(RADIAL_GRID_NUM); + std::vector h_gl_w(RADIAL_GRID_NUM); + ModuleBase::Integral::Gauss_Legendre_grid_and_weight(RADIAL_GRID_NUM, h_gl_x.data(), h_gl_w.data()); + + CUDA_CHECK(cudaMemcpyToSymbol(d_gl_x, h_gl_x.data(), RADIAL_GRID_NUM * sizeof(double))); + CUDA_CHECK(cudaMemcpyToSymbol(d_gl_w, h_gl_w.data(), RADIAL_GRID_NUM * sizeof(double))); +} + +} // namespace gpu +} // namespace module_rt diff --git a/source/source_lcao/module_rt/kernels/cuda/snap_psibeta_kernel.cuh b/source/source_lcao/module_rt/kernels/cuda/snap_psibeta_kernel.cuh new file mode 100644 index 0000000000..157224ea57 --- /dev/null +++ b/source/source_lcao/module_rt/kernels/cuda/snap_psibeta_kernel.cuh @@ -0,0 +1,331 @@ +/** + * @file snap_psibeta_kernel.cuh + * @brief CUDA kernel declarations for computing overlap integrals + * + * This file provides GPU-accelerated computation of overlap integrals between + * atomic orbitals (psi) and non-local projectors (beta) for real-time TDDFT + * calculations. The implementation uses numerical integration on a combined + * radial (Gauss-Legendre) and angular (Lebedev-Laikov) grid. + * + * Key Features: + * - Atom-level batching: processes all neighbors for a center atom in single kernel + * - Templated spherical harmonics for compile-time optimization + * - Efficient memory access via constant memory for integration grids + * - Warp-level reduction for high-performance summation + */ + +#ifndef SNAP_PSIBETA_KERNEL_CUH +#define SNAP_PSIBETA_KERNEL_CUH + +#include "source_base/tool_quit.h" + +#include +#include +#include +#include + +//============================================================================= +// CUDA Error Checking Macro +//============================================================================= + +/** + * @brief CUDA error checking macro with file/line information + * + * Checks the return value of CUDA API calls and calls WARNING_QUIT + * with error information if the call fails. 
+ */ +#define CUDA_CHECK(call) \ + do \ + { \ + cudaError_t err = (call); \ + if (err != cudaSuccess) \ + { \ + ModuleBase::WARNING_QUIT("CUDA_CHECK", \ + std::string("Error at ") + __FILE__ + ":" + std::to_string(__LINE__) + " - " \ + + cudaGetErrorString(err)); \ + } \ + } while (0) + +namespace module_rt +{ +namespace gpu +{ + +//============================================================================= +// Configuration Constants +//============================================================================= + +/// Number of points in radial Gauss-Legendre grid +constexpr int RADIAL_GRID_NUM = 140; + +/// Number of points in angular Lebedev-Laikov grid (110-point rule) +constexpr int ANGULAR_GRID_NUM = 110; + +/// Thread block size for kernel execution +constexpr int BLOCK_SIZE = 128; + +/// Maximum supported angular momentum quantum number L +constexpr int MAX_L = 4; + +/// Size of spherical harmonics array: (MAX_L + 1)^2 = 25 +constexpr int MAX_YLM_SIZE = (MAX_L + 1) * (MAX_L + 1); + +/// Maximum number of magnetic quantum numbers for a single L: 2*MAX_L + 1 = 9 +constexpr int MAX_M0_SIZE = 2 * MAX_L + 1; + +//============================================================================= +// Device Helper Functions - Complex Arithmetic +//============================================================================= + +/** + * @brief Compute exp(i * theta) = cos(theta) + i * sin(theta) + * @param theta Phase angle in radians + * @return Complex exponential as cuDoubleComplex + */ +__device__ __forceinline__ cuDoubleComplex cu_exp_i(double theta) +{ + double s, c; + sincos(theta, &s, &c); + return make_cuDoubleComplex(c, s); +} + +/** + * @brief Complex multiplication: a * b + */ +__device__ __forceinline__ cuDoubleComplex cu_mul(cuDoubleComplex a, cuDoubleComplex b) +{ + return make_cuDoubleComplex(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); +} + +/** + * @brief Complex addition: a + b + */ +__device__ __forceinline__ cuDoubleComplex cu_add(cuDoubleComplex a, cuDoubleComplex b) +{ + return make_cuDoubleComplex(a.x + b.x, a.y + b.y); +} + +/** + * @brief Complex conjugate: conj(a) + */ +__device__ __forceinline__ cuDoubleComplex cu_conj(cuDoubleComplex a) +{ + return make_cuDoubleComplex(a.x, -a.y); +} + +/** + * @brief Complex times real: a * r + */ +__device__ __forceinline__ cuDoubleComplex cu_mul_real(cuDoubleComplex a, double r) +{ + return make_cuDoubleComplex(a.x * r, a.y * r); +} + +//============================================================================= +// Device Helper Functions - Radial Interpolation +//============================================================================= + +/** + * @brief Cubic spline interpolation for radial functions + * + * Implements cubic polynomial interpolation using 4 consecutive grid points. + * This is the GPU equivalent of CPU-side PolyInt::Polynomial_Interpolation. 
+ * + * @param psi Radial function values on uniform grid + * @param mesh Number of grid points + * @param inv_dk Inverse of grid spacing (1/dk) + * @param distance Radial distance r at which to interpolate + * @return Interpolated function value + */ +__device__ __forceinline__ double interpolate_radial_gpu(const double* __restrict__ psi, + int mesh, + double inv_dk, + double distance) +{ + double position = distance * inv_dk; + int iq = __double2int_rd(position); // floor(position) + + // Boundary checks + if (iq > mesh - 4 || iq < 0) + { + return 0.0; + } + + // Lagrange interpolation weights + double x0 = position - static_cast(iq); + double x1 = 1.0 - x0; + double x2 = 2.0 - x0; + double x3 = 3.0 - x0; + + // 4-point Lagrange interpolation formula + return x1 * x2 * (psi[iq] * x3 + psi[iq + 3] * x0) / 6.0 + x0 * x3 * (psi[iq + 1] * x2 - psi[iq + 2] * x1) / 2.0; +} + +//============================================================================= +// Device Helper Functions - Spherical Harmonics +//============================================================================= + +/** + * @brief Compute real spherical harmonics Y_lm at a given direction + * + * TEMPLATED VERSION: L is a compile-time constant enabling loop unrolling + * and register allocation optimizations by the compiler. + * + * Uses the recursive computation of associated Legendre polynomials + * followed by transformation to real spherical harmonics with proper + * normalization (same as ModuleBase::Ylm). + * + * @tparam L Maximum angular momentum (0 <= L <= MAX_L) + * @param x, y, z Direction vector components (need not be normalized, normalization is done internally) + * @param ylm Output array of size (L+1)^2, indexed as ylm[l*l + l + m] + */ +template +__device__ void compute_ylm_gpu(double x, double y, double z, double* ylm); + +/** + * @brief Runtime dispatch macro for templated compute_ylm_gpu + * + * Converts a runtime L value to the appropriate compile-time template + * instantiation for optimal performance. + * + * @param L_val Runtime angular momentum value + * @param x, y, z Direction vector components + * @param ylm Output array for spherical harmonics + */ +#define DISPATCH_YLM(L_val, x, y, z, ylm) \ + do \ + { \ + switch (L_val) \ + { \ + case 0: \ + compute_ylm_gpu<0>(x, y, z, ylm); \ + break; \ + case 1: \ + compute_ylm_gpu<1>(x, y, z, ylm); \ + break; \ + case 2: \ + compute_ylm_gpu<2>(x, y, z, ylm); \ + break; \ + case 3: \ + compute_ylm_gpu<3>(x, y, z, ylm); \ + break; \ + case 4: \ + compute_ylm_gpu<4>(x, y, z, ylm); \ + break; \ + default: \ + compute_ylm_gpu<4>(x, y, z, ylm); \ + break; \ + } \ + } while (0) + +//============================================================================= +// Data Structures for Kernel Input +//============================================================================= + +/** + * @brief Non-local projector (beta function) information + * + * Contains all data needed to evaluate a single projector during integration. 
+ */ +struct ProjectorData +{ + int L0; ///< Angular momentum quantum number + int beta_offset; ///< Offset into flattened beta radial array + int beta_mesh; ///< Number of radial mesh points + double beta_dk; ///< Radial grid spacing + double beta_rcut; ///< Cutoff radius for projector + double r_min; ///< Minimum radial grid value (integration start) + double r_max; ///< Maximum radial grid value (integration end) +}; + +/** + * @brief Neighbor atom orbital information for atom-level batching + * + * Each structure represents one (neighbor_atom, orbital) pair that contributes + * to the overlap integral. This enables processing ALL neighbors for a center + * atom in a single kernel launch, minimizing launch overhead. + */ +struct NeighborOrbitalData +{ + int neighbor_idx; ///< Index of neighbor atom (ad index in adjacency list) + double3 R1; ///< Neighbor atom position in Cartesian coordinates (tau * lat0) + + // Orbital information + int L1; ///< Angular momentum of orbital + int m1; ///< Magnetic quantum number of orbital + int N1; ///< Radial quantum number of orbital + int iw_index; ///< Global orbital index for output mapping + int psi_offset; ///< Offset into flattened psi radial array + int psi_mesh; ///< Number of radial mesh points for orbital + double psi_dk; ///< Radial grid spacing for orbital + double psi_rcut; ///< Cutoff radius for orbital +}; + +//============================================================================= +// Main CUDA Kernel Declaration +//============================================================================= + +/** + * @brief Atom-level batch kernel for overlap computation + * + * This kernel processes ALL neighbor orbitals for a single center atom in one + * launch, significantly reducing kernel launch overhead. Each thread block + * handles the integration for one (neighbor_orbital, projector) pair. 
+ * + * Grid Configuration: + * - gridDim.x = total_neighbor_orbitals (all orbitals from all neighbors) + * - gridDim.y = nproj (number of projectors on center atom) + * + * Block Configuration: + * - blockDim.x = BLOCK_SIZE threads for parallel integration + * + * Integration Strategy: + * - Angular loop (outer): each thread processes different angular points + * - Radial loop (inner): each thread accumulates over all radial points + * - Warp shuffle reduction for efficient summation + * + * @param R0 Center atom position (projector location) + * @param A Vector potential for phase factor + * @param neighbor_orbitals Array of neighbor-orbital data [total_neighbor_orbitals] + * @param projectors Array of projector data [nproj] + * @param psi_radial Flattened array of orbital radial functions + * @param beta_radial Flattened array of projector radial functions + * @param proj_m0_offset Starting index of each projector's m=0 component in output + * @param total_neighbor_orbitals Total number of (neighbor, orbital) pairs + * @param nproj Number of projectors on center atom + * @param natomwfc Total projector components: sum of (2*L0+1) for all projectors + * @param nlm_dim Output dimension: 1 for overlap only, 4 for overlap + current + * @param nlm_out Output array [total_neighbor_orbitals * nlm_dim * natomwfc] + */ +__global__ void snap_psibeta_atom_batch_kernel(double3 R0, + double3 A, + const NeighborOrbitalData* __restrict__ neighbor_orbitals, + const ProjectorData* __restrict__ projectors, + const double* __restrict__ psi_radial, + const double* __restrict__ beta_radial, + const int* __restrict__ proj_m0_offset, + int total_neighbor_orbitals, + int nproj, + int natomwfc, + int nlm_dim, + cuDoubleComplex* __restrict__ nlm_out); + +//============================================================================= +// Host-side Initialization +//============================================================================= + +/** + * @brief Copy integration grids to GPU constant memory + * + * Copies the Lebedev-Laikov angular grid (110 points) and Gauss-Legendre + * radial grid (140 points) to CUDA constant memory for fast access during + * kernel execution. + * + * @note Must be called once before any kernel launches in a calculation session. 
+ */ +void copy_grids_to_device(); + +} // namespace gpu +} // namespace module_rt + +#endif // SNAP_PSIBETA_KERNEL_CUH diff --git a/source/source_lcao/module_rt/kernels/snap_psibeta_gpu.h b/source/source_lcao/module_rt/kernels/snap_psibeta_gpu.h new file mode 100644 index 0000000000..43b617e188 --- /dev/null +++ b/source/source_lcao/module_rt/kernels/snap_psibeta_gpu.h @@ -0,0 +1,70 @@ +#ifndef SNAP_PSIBETA_GPU_H +#define SNAP_PSIBETA_GPU_H + +#include "source_base/vector3.h" +#include "source_basis/module_ao/ORB_read.h" +#include "source_basis/module_ao/parallel_orbitals.h" +#include "source_cell/module_neighbor/sltk_grid_driver.h" +#include "source_cell/setup_nonlocal.h" +#include "source_cell/unitcell.h" + +#include +#include +#include + +#ifdef __CUDA +#include +#endif + +namespace module_rt +{ +namespace gpu +{ + +/** + * @brief Initialize GPU resources (copy grids to constant memory) + * Should be called at the start of each calculate_HR + */ +void initialize_gpu_resources(); + +/** + * @brief Release GPU resources (clear any error states) + * Should be called at the end of each calculate_HR + */ +void finalize_gpu_resources(); + +/** + * @brief Atom-level GPU batch processing interface + * + * Processes ALL neighbors for a center atom in a SINGLE kernel launch. + * This significantly reduces kernel launch overhead compared to neighbor-level batching. + * + * @param orb Orbital information + * @param infoNL_ Non-local pseudopotential information + * @param T0 Center atom type (projector location) + * @param R0 Center atom position (already multiplied by lat0) + * @param A Vector potential + * @param adjs Adjacent atom information for this center atom + * @param ucell Unit cell pointer + * @param paraV Parallel orbitals information + * @param npol Polarization number + * @param nlm_dim 1 for no current, 4 for current calculation + * @param nlm_tot Output: nlm_tot[ad][dir][iw_index] = nlm_vector + */ +void snap_psibeta_atom_batch_gpu( + const LCAO_Orbitals& orb, + const InfoNonlocal& infoNL_, + const int T0, + const ModuleBase::Vector3& R0, + const ModuleBase::Vector3& A, + const AdjacentAtomInfo& adjs, + const UnitCell* ucell, + const Parallel_Orbitals* paraV, + const int npol, + const int nlm_dim, + std::vector>>>>& nlm_tot); + +} // namespace gpu +} // namespace module_rt + +#endif // SNAP_PSIBETA_GPU_H
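
Reviewer note (not part of the patch): the reduction at the end of snap_psibeta_atom_batch_kernel, a warp-level shuffle followed by a cross-warp pass through shared memory, can be exercised in isolation. The sketch below shows that two-stage pattern as a standalone CUDA program. Only BLOCK_SIZE = 128 and the body of warp_reduce_sum are taken from the patch; the kernel name block_sum, the input length, and the single-block launch are illustrative assumptions.

// Minimal sketch of the two-stage block reduction used in the kernel.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

constexpr int BLOCK = 128;              // matches BLOCK_SIZE in snap_psibeta_kernel.cuh
constexpr int NUM_WARPS = BLOCK / 32;

__device__ __forceinline__ double warp_reduce_sum(double val)
{
    // After the loop, lane 0 of each warp holds the warp-wide sum.
    for (int offset = 16; offset > 0; offset /= 2)
    {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}

__global__ void block_sum(const double* __restrict__ in, int n, double* __restrict__ out)
{
    __shared__ double s_partial[NUM_WARPS];

    // Each thread accumulates a strided slice, like the angular loop in the kernel.
    double acc = 0.0;
    for (int i = threadIdx.x; i < n; i += BLOCK)
    {
        acc += in[i];
    }

    // Stage 1: reduce within each warp.
    acc = warp_reduce_sum(acc);
    const int warp_id = threadIdx.x / 32;
    const int lane_id = threadIdx.x % 32;
    if (lane_id == 0)
    {
        s_partial[warp_id] = acc;
    }
    __syncthreads();

    // Stage 2: the first warp reduces the per-warp partials and writes the result.
    if (warp_id == 0)
    {
        acc = (lane_id < NUM_WARPS) ? s_partial[lane_id] : 0.0;
        acc = warp_reduce_sum(acc);
        if (lane_id == 0)
        {
            out[0] = acc;
        }
    }
}

int main()
{
    const int n = 110 * 140; // one angular x radial grid's worth of contributions
    std::vector<double> h(n, 1.0);

    double* in_d = nullptr;
    double* out_d = nullptr;
    cudaMalloc(&in_d, n * sizeof(double));
    cudaMalloc(&out_d, sizeof(double));
    cudaMemcpy(in_d, h.data(), n * sizeof(double), cudaMemcpyHostToDevice);

    block_sum<<<1, BLOCK>>>(in_d, n, out_d);

    double result = 0.0;
    cudaMemcpy(&result, out_d, sizeof(double), cudaMemcpyDeviceToHost);
    printf("sum = %.1f (expected %d)\n", result, n);

    cudaFree(in_d);
    cudaFree(out_d);
    return 0;
}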
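
Reviewer note (not part of the patch): interpolate_radial_gpu uses a 4-point Lagrange formula on a uniform grid, and its results can be checked on the host against the CPU path (ModuleBase::PolyInt::Polynomial_Interpolation). The sketch below is a testing-only assumption; interpolate_radial_ref and the r^2 test grid are hypothetical, while the weights and out-of-range convention mirror the device function in the patch.

// Host-side reference of the 4-point Lagrange interpolation, exact for cubics.
#include <cmath>
#include <cstdio>
#include <vector>

double interpolate_radial_ref(const double* psi, int mesh, double inv_dk, double distance)
{
    const double position = distance * inv_dk;
    const int iq = static_cast<int>(std::floor(position));
    if (iq > mesh - 4 || iq < 0)
    {
        return 0.0; // outside the tabulated range, same convention as the kernel
    }
    const double x0 = position - iq;
    const double x1 = 1.0 - x0;
    const double x2 = 2.0 - x0;
    const double x3 = 3.0 - x0;
    return x1 * x2 * (psi[iq] * x3 + psi[iq + 3] * x0) / 6.0
           + x0 * x3 * (psi[iq + 1] * x2 - psi[iq + 2] * x1) / 2.0;
}

int main()
{
    // Tabulate f(r) = r^2 on a uniform grid; a cubic interpolant reproduces it exactly.
    const int mesh = 64;
    const double dk = 0.1;
    std::vector<double> psi(mesh);
    for (int i = 0; i < mesh; ++i)
    {
        psi[i] = (i * dk) * (i * dk);
    }
    const double r = 1.234;
    printf("interp = %.6f, exact = %.6f\n",
           interpolate_radial_ref(psi.data(), mesh, 1.0 / dk, r),
           r * r);
    return 0;
}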