From 8d6cbdf2a2c4f60ffd260b6804d9cfc526119577 Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Fri, 1 Aug 2025 18:39:28 +0800 Subject: [PATCH 1/4] fix bug in dsp compute --- CMakeLists.txt | 1 + .../source_base/kernels/dsp/dsp_connector.cpp | 6 ++++-- .../source_base/kernels/dsp/dsp_connector.h | 2 ++ .../module_pw/module_fft/fft_dsp.cpp | 2 +- source/source_esolver/esolver_ks_pw.cpp | 2 ++ source/source_io/read_wf2rho_pw.cpp | 6 +++--- source/source_io/to_wannier90_pw.cpp | 19 +++++++++---------- .../module_pwdft/stress_func_exx.cpp | 4 ++-- 8 files changed, 24 insertions(+), 18 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2cc19db8f2..11547d501a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -281,6 +281,7 @@ if (USE_DSP) target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_device.a) target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_host.a) endif() +target_link_libraries(${ABACUS_BIN_NAME} ${SCALAPACK_LIBRARY_DIR}) find_package(Threads REQUIRED) target_link_libraries(${ABACUS_BIN_NAME} Threads::Threads) diff --git a/source/source_base/kernels/dsp/dsp_connector.cpp b/source/source_base/kernels/dsp/dsp_connector.cpp index a3c5f6d897..130646d593 100644 --- a/source/source_base/kernels/dsp/dsp_connector.cpp +++ b/source/source_base/kernels/dsp/dsp_connector.cpp @@ -12,6 +12,8 @@ extern "C" } namespace mtfunc { +std::complex* alp=nullptr; +std::complex* bet=nullptr; void dspInitHandle(int id) { mt_blas_init(id); @@ -271,9 +273,9 @@ void zgemm_mth_(const char* transa, const int* ldc, int cluster_id) { - std::complex* alp = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); + // std::complex* alp = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); *alp = *alpha; - std::complex* bet = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); + // std::complex* bet = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); *bet = *beta; mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor, convertBLASTranspose(transa), diff --git a/source/source_base/kernels/dsp/dsp_connector.h b/source/source_base/kernels/dsp/dsp_connector.h index 34ccbaec4b..babe2b7d9a 100644 --- a/source/source_base/kernels/dsp/dsp_connector.h +++ b/source/source_base/kernels/dsp/dsp_connector.h @@ -15,6 +15,8 @@ void* malloc_ht(size_t bytes, int cluster_id); void free_ht(void* ptr); // mtblas functions +extern std::complex* alp; +extern std::complex* bet; void sgemm_mt_(const char* transa, const char* transb, diff --git a/source/source_basis/module_pw/module_fft/fft_dsp.cpp b/source/source_basis/module_pw/module_fft/fft_dsp.cpp index e26292cf5b..0842066eb0 100644 --- a/source/source_basis/module_pw/module_fft/fft_dsp.cpp +++ b/source/source_basis/module_pw/module_fft/fft_dsp.cpp @@ -1,7 +1,7 @@ #include "fft_dsp.h" #include "source_base/global_variable.h" - +#include "source_base/tool_quit.h" #include #include #include diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 760a597d1c..b3aefcad1f 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -71,6 +71,8 @@ ESolver_KS_PW::ESolver_KS_PW() #ifdef __DSP std::cout << " ** Initializing DSP Hardware..." << std::endl; mtfunc::dspInitHandle(GlobalV::MY_RANK); + mtfunc::alp=(std::complex*)mtfunc::malloc_ht(sizeof(std::complex), GlobalV::MY_RANK); + mtfunc::bet=(std::complex*)mtfunc::malloc_ht(sizeof(std::complex), GlobalV::MY_RANK); #endif } diff --git a/source/source_io/read_wf2rho_pw.cpp b/source/source_io/read_wf2rho_pw.cpp index 1be65a268c..66f41b9448 100644 --- a/source/source_io/read_wf2rho_pw.cpp +++ b/source/source_io/read_wf2rho_pw.cpp @@ -129,8 +129,8 @@ void ModuleIO::read_wf2rho_pw( { const std::complex* wfc_ib = wfc_tmp.c + ib * ng_npol; const std::complex* wfc_ib2 = wfc_tmp.c + ib * ng_npol + ng_npol / 2; - pw_wfc->recip2real(wfc_ib, rho_tmp.data(), ik); - pw_wfc->recip2real(wfc_ib2, rho_tmp2.data(), ik); + pw_wfc->recip_to_real,base_device::DEVICE_CPU>(wfc_ib, rho_tmp.data(), ik); + pw_wfc->recip_to_real,base_device::DEVICE_CPU>(wfc_ib2, rho_tmp2.data(), ik); const double w1 = wg_tmp(ikstot, ib) / pw_wfc->omega; if (w1 != 0.0) @@ -152,7 +152,7 @@ void ModuleIO::read_wf2rho_pw( for (int ib = 0; ib < nbands; ++ib) { const std::complex* wfc_ib = wfc_tmp.c + ib * ng_npol; - pw_wfc->recip2real(wfc_ib, rho_tmp.data(), ik); + pw_wfc->recip_to_real,base_device::DEVICE_CPU>(wfc_ib, rho_tmp.data(), ik); const double w1 = wg_tmp(ikstot, ib) / pw_wfc->omega; diff --git a/source/source_io/to_wannier90_pw.cpp b/source/source_io/to_wannier90_pw.cpp index 9c33cf4976..72d628fe8d 100644 --- a/source/source_io/to_wannier90_pw.cpp +++ b/source/source_io/to_wannier90_pw.cpp @@ -253,7 +253,7 @@ void toWannier90_PW::out_unk( { int ib = cal_band_index[ib_w]; - wfcpw->recip2real(&psi_pw(ik, ib, 0), porter, ik); + wfcpw->recip_to_real,base_device::DEVICE_CPU>(&psi_pw(ik, ib, 0), porter, ik); if (GlobalV::RANK_IN_POOL == 0) { @@ -383,7 +383,7 @@ void toWannier90_PW::unkdotkb( } } - wfcpw->recip2real(phase, phase, cal_ik); + wfcpw->recip_to_real,base_device::DEVICE_CPU>(phase, phase, cal_ik); if (PARAM.inp.nspin == 4) { @@ -396,17 +396,17 @@ void toWannier90_PW::unkdotkb( // (2) fft and get value // int npw_ik = wfcpw->npwk[cal_ik]; int npwx = wfcpw->npwk_max; - wfcpw->recip2real(&psi_pw(cal_ik, im, 0), psir_up, cal_ik); - // wfcpw->recip2real(&psi_pw(cal_ik, im, npw_ik), psir_dn, cal_ik); - wfcpw->recip2real(&psi_pw(cal_ik, im, npwx), psir_dn, cal_ik); + wfcpw->recip_to_real,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, 0), psir_up, cal_ik); + // wfcpw->recip_to_real,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, npw_ik), psir_dn, cal_ik); + wfcpw->recip_to_real,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, npwx), psir_dn, cal_ik); for (int ir = 0; ir < wfcpw->nrxx; ir++) { psir_up[ir] *= phase[ir]; psir_dn[ir] *= phase[ir]; } - wfcpw->real2recip(psir_up, psir_up, cal_ikb); - wfcpw->real2recip(psir_dn, psir_dn, cal_ikb); + wfcpw->real_to_recip,base_device::DEVICE_CPU>(psir_up, psir_up, cal_ikb); + wfcpw->real_to_recip,base_device::DEVICE_CPU>(psir_dn, psir_dn, cal_ikb); for (int n = 0; n < num_bands; n++) { @@ -447,13 +447,12 @@ void toWannier90_PW::unkdotkb( ModuleBase::GlobalFunc::ZEROS(psir, wfcpw->nmaxgr); // (2) fft and get value - wfcpw->recip2real(&psi_pw(cal_ik, im, 0), psir, cal_ik); + wfcpw->recip_to_real,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, 0), psir, cal_ik); for (int ir = 0; ir < wfcpw->nrxx; ir++) { psir[ir] *= phase[ir]; } - - wfcpw->real2recip(psir, psir, cal_ikb); + wfcpw->real_to_recip,base_device::DEVICE_CPU>(psir, psir, cal_ikb); for (int n = 0; n < num_bands; n++) { diff --git a/source/source_pw/module_pwdft/stress_func_exx.cpp b/source/source_pw/module_pwdft/stress_func_exx.cpp index 1885bb16ee..b01b24848a 100644 --- a/source/source_pw/module_pwdft/stress_func_exx.cpp +++ b/source/source_pw/module_pwdft/stress_func_exx.cpp @@ -260,7 +260,7 @@ void Stress_PW::stress_exx(ModuleBase::matrix& sigma, // psi_nk in real space d_psi_in->fix_kb(ik, nband); T* psi_nk = d_psi_in->get_pointer(); - wfcpw->recip2real(psi_nk, psi_nk_real, ik); + wfcpw->recip_to_real,Device>(psi_nk, psi_nk_real, ik); for (int iq = 0; iq < nqs; iq++) { @@ -269,7 +269,7 @@ void Stress_PW::stress_exx(ModuleBase::matrix& sigma, // psi_mq in real space d_psi_in->fix_kb(iq, mband); T* psi_mq = d_psi_in->get_pointer(); - wfcpw->recip2real(psi_mq, psi_mq_real, iq); + wfcpw->recip_to_real,Device>(psi_mq, psi_mq_real, iq); // overlap density in real space setmem_complex_op()(density_real, 0.0, rhopw->nrxx); From dedc0e18b0cfbc5ca0847a6d5903ba036975656e Mon Sep 17 00:00:00 2001 From: A-006 <3158793232@qq.com> Date: Fri, 1 Aug 2025 20:18:04 +0800 Subject: [PATCH 2/4] add convulution for veff --- source/source_basis/module_pw/pw_basis_k.h | 11 ++- .../source_basis/module_pw/pw_transform_k.cpp | 73 +++++++++++++++++++ .../module_pw/pw_transform_k_dsp.cpp | 4 +- .../module_pwdft/operator_pw/veff_pw.cpp | 21 ++++-- 4 files changed, 99 insertions(+), 10 deletions(-) diff --git a/source/source_basis/module_pw/pw_basis_k.h b/source/source_basis/module_pw/pw_basis_k.h index b87da9ca0f..7b7858669d 100644 --- a/source/source_basis/module_pw/pw_basis_k.h +++ b/source/source_basis/module_pw/pw_basis_k.h @@ -135,7 +135,6 @@ class PW_Basis_K : public PW_Basis const int ik, const bool add = false, const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) - #if defined(__DSP) template void convolution(const Device* ctx, const int ik, @@ -145,6 +144,16 @@ class PW_Basis_K : public PW_Basis std::complex* output, const bool add = false, const FPTYPE factor =1.0) const ; + #if defined(__DSP) + template + void convolution_dsp(const Device* ctx, + const int ik, + const int size, + const std::complex* input, + const FPTYPE* input1, + std::complex* output, + const bool add = false, + const FPTYPE factor =1.0) const ; template void real2recip_dsp(const std::complex* in, diff --git a/source/source_basis/module_pw/pw_transform_k.cpp b/source/source_basis/module_pw/pw_transform_k.cpp index 36290d091a..3fc3888c76 100644 --- a/source/source_basis/module_pw/pw_transform_k.cpp +++ b/source/source_basis/module_pw/pw_transform_k.cpp @@ -337,6 +337,79 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/, this->recip2real(in, out, ik, add, factor); #endif } +template <> +void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, + const int ik, + const int size, + const std::complex* input, + const float* input1, + std::complex* output, + const bool add, + const float factor) const +{ +} + +template <> +void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, + const int ik, + const int size, + const std::complex* input, + const double* input1, + std::complex* output, + const bool add, + const double factor) const +{ + ModuleBase::timer::tick(this->classname, "convolution"); + assert(this->gamma_only == false); + // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxg_data(), this->nst * this->nz); + // memset the auxr of 0 in the auxr,here the len of the auxr is nxyz + auto* auxg = this->fft_bundle.get_auxg_data(); + auto* auxr=this->fft_bundle.get_auxr_data(); + + memset(auxg, 0, this->nst * this->nz * 2 * 8); + const int startig = ik * this->npwk_max; + const int npwk = this->npwk[ik]; + + // copy the mapping form the type of stick to the 3dfft + #ifdef _OPENMP + #pragma omp parallel for schedule(static, 4096 / sizeof(double)) + #endif + for (int igl = 0; igl < npwk; ++igl) + { + auxg[this->igl2isz_k[igl + startig]] = input[igl]; + } + + // use 3d fft backward + this->fft_bundle.fftzbac(auxg, auxg); + + this->gathers_scatterp(auxg, auxr); + + this->fft_bundle.fftxybac(auxr, auxr); + for (int ir = 0; ir < size; ir++) + { + auxr[ir] *= input1[ir]; + } + + // 3d fft + this->fft_bundle.fftxyfor(auxr, auxr); + + this->gatherp_scatters(auxr, auxg); + + this->fft_bundle.fftzfor(auxg, auxg); + // copy the result from the auxr to the out ,while consider the add + if (add) + { + double tmpfac = factor / double(this->nxyz); +#ifdef _OPENMP +#pragma omp parallel for schedule(static, 4096 / sizeof(double)) +#endif + for (int igl = 0; igl < npwk; ++igl) + { + output[igl] += tmpfac * auxg[this->igl2isz_k[igl + startig]]; + } + } + ModuleBase::timer::tick(this->classname, "convolution"); +} #if (defined(__CUDA) || defined(__ROCM)) template <> diff --git a/source/source_basis/module_pw/pw_transform_k_dsp.cpp b/source/source_basis/module_pw/pw_transform_k_dsp.cpp index 1449943550..6f35d545ca 100644 --- a/source/source_basis/module_pw/pw_transform_k_dsp.cpp +++ b/source/source_basis/module_pw/pw_transform_k_dsp.cpp @@ -91,7 +91,7 @@ void PW_Basis_K::recip2real_dsp(const std::complex* in, } } template <> -void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, +void PW_Basis_K::convolution_dsp(const base_device::DEVICE_CPU* ctx, const int ik, const int size, const std::complex* input, @@ -103,7 +103,7 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, } template <> -void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, +void PW_Basis_K::convolution_dsp(const base_device::DEVICE_CPU* ctx, const int ik, const int size, const std::complex* input, diff --git a/source/source_pw/module_pwdft/operator_pw/veff_pw.cpp b/source/source_pw/module_pwdft/operator_pw/veff_pw.cpp index e2813b3a9a..0251857d0f 100644 --- a/source/source_pw/module_pwdft/operator_pw/veff_pw.cpp +++ b/source/source_pw/module_pwdft/operator_pw/veff_pw.cpp @@ -61,7 +61,7 @@ void Veff>::act( ModulePW::FFT_Guard guard(wfcpw->fft_bundle); for (int ib = 0; ib < nbands; ib += npol) { - wfcpw->convolution(this->ctx, + wfcpw->convolution_dsp(this->ctx, this->ik, this->veff_col, tmpsi_in, @@ -96,12 +96,19 @@ void Veff>::act( { for (int ib = 0; ib < nbands; ib += npol) { - wfcpw->recip_to_real(tmpsi_in, this->porter, this->ik); - // NOTICE: when MPI threads are larger than the number of Z grids - // veff would contain nothing, and nothing should be done in real space - // but the 3DFFT can not be skipped, it will cause hanging - veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col); - wfcpw->real_to_recip(this->porter, tmhpsi, this->ik, true); + wfcpw->convolution(this->ctx, + this->ik, + this->veff_col, + tmpsi_in, + this->veff + current_spin * this->veff_col, + tmhpsi, + true); + // wfcpw->recip_to_real(tmpsi_in, this->porter, this->ik); + // // NOTICE: when MPI threads are larger than the number of Z grids + // // veff would contain nothing, and nothing should be done in real space + // // but the 3DFFT can not be skipped, it will cause hanging + // veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col); + // wfcpw->real_to_recip(this->porter, tmhpsi, this->ik, true); tmhpsi += psi_offset; tmpsi_in += psi_offset; } From 40e73fd66941929b0ac779369844eceff0d258df Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Sun, 3 Aug 2025 21:20:08 +0800 Subject: [PATCH 3/4] add change on pw_basis_k --- .../source_basis/module_pw/pw_transform_k.cpp | 18 +++++++++--------- source/source_io/cal_ldos.cpp | 4 ++-- source/source_io/cal_mlkedf_descriptors.cpp | 2 +- source/source_io/get_wf_lcao.cpp | 6 +++--- source/source_io/unk_overlap_pw.cpp | 10 +++++----- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/source/source_basis/module_pw/pw_transform_k.cpp b/source/source_basis/module_pw/pw_transform_k.cpp index 3fc3888c76..13ce5ef2e6 100644 --- a/source/source_basis/module_pw/pw_transform_k.cpp +++ b/source/source_basis/module_pw/pw_transform_k.cpp @@ -366,7 +366,7 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, auto* auxg = this->fft_bundle.get_auxg_data(); auto* auxr=this->fft_bundle.get_auxr_data(); - memset(auxg, 0, this->nst * this->nz * 2 * 8); + memset(auxg, 0, this->nst * this->nz * 2 * 8); const int startig = ik * this->npwk_max; const int npwk = this->npwk[ik]; @@ -385,11 +385,14 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, this->gathers_scatterp(auxg, auxr); this->fft_bundle.fftxybac(auxr, auxr); + + #ifdef _OPENMP + #pragma omp parallel for simd schedule(static) aligned(auxr, input1: 64) + #endif for (int ir = 0; ir < size; ir++) { auxr[ir] *= input1[ir]; } - // 3d fft this->fft_bundle.fftxyfor(auxr, auxr); @@ -397,16 +400,13 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, this->fft_bundle.fftzfor(auxg, auxg); // copy the result from the auxr to the out ,while consider the add - if (add) - { - double tmpfac = factor / double(this->nxyz); + double tmpfac = factor / double(this->nxyz); #ifdef _OPENMP #pragma omp parallel for schedule(static, 4096 / sizeof(double)) #endif - for (int igl = 0; igl < npwk; ++igl) - { - output[igl] += tmpfac * auxg[this->igl2isz_k[igl + startig]]; - } + for (int igl = 0; igl < npwk; ++igl) + { + output[igl] += tmpfac * auxg[this->igl2isz_k[igl + startig]]; } ModuleBase::timer::tick(this->classname, "convolution"); } diff --git a/source/source_io/cal_ldos.cpp b/source/source_io/cal_ldos.cpp index ec2f00bfc7..ecea39202f 100644 --- a/source/source_io/cal_ldos.cpp +++ b/source/source_io/cal_ldos.cpp @@ -140,7 +140,7 @@ void stm_mode_pw(const elecstate::ElecStatePW>* pelec, for (int ib = 0; ib < nbands; ib++) { - pelec->basis->recip2real(&psi(ib, 0), wfcr.data(), ik); + pelec->basis->recip_to_real,base_device::DEVICE_CPU>(&psi(ib, 0), wfcr.data(), ik); const double eigenval = (pelec->ekb(ik, ib) - efermi) * ModuleBase::Ry_to_eV; double weight = en > 0 ? pelec->klist->wk[ik] - pelec->wg(ik, ib) : pelec->wg(ik, ib); @@ -210,7 +210,7 @@ void ldos_mode_pw(const elecstate::ElecStatePW>* pelec, for (int ib = 0; ib < nbands; ib++) { - pelec->basis->recip2real(&psi(ib, 0), wfcr.data(), ik); + pelec->basis->recip_to_real,base_device::DEVICE_CPU>(&psi(ib, 0), wfcr.data(), ik); const double weight = pelec->klist->wk[ik] / ucell.omega; for (int ir = 0; ir < pelec->basis->nrxx; ir++) diff --git a/source/source_io/cal_mlkedf_descriptors.cpp b/source/source_io/cal_mlkedf_descriptors.cpp index 58e3f8f259..24b7517963 100644 --- a/source/source_io/cal_mlkedf_descriptors.cpp +++ b/source/source_io/cal_mlkedf_descriptors.cpp @@ -472,7 +472,7 @@ void Cal_MLKEDF_Descriptors::getF_KS( wfcr[ig] = psi->operator()(ibnd, ig) * std::complex(0.0, fact); } - pw_psi->recip2real(wfcr, wfcr, ik); + pw_psi->recip_to_real,base_device::DEVICE_CPU>(wfcr, wfcr, ik); for (int ir = 0; ir < this->nx; ++ir) { diff --git a/source/source_io/get_wf_lcao.cpp b/source/source_io/get_wf_lcao.cpp index 3d6c58a300..e784191b82 100644 --- a/source/source_io/get_wf_lcao.cpp +++ b/source/source_io/get_wf_lcao.cpp @@ -179,7 +179,7 @@ void Get_wf_lcao::begin(const UnitCell& ucell, // Calculate real-space wave functions psi_g.fix_k(is); std::vector> wfc_r(pw_wfc->nrxx); - pw_wfc->recip2real(&psi_g(ib, 0), wfc_r.data(), is); + pw_wfc->recip_to_real,base_device::DEVICE_CPU>(&psi_g(ib, 0), wfc_r.data(), is); // Extract real and imaginary parts std::vector wfc_real(pw_wfc->nrxx); @@ -399,7 +399,7 @@ void Get_wf_lcao::begin(const UnitCell& ucell, // Calculate real-space wave functions std::vector> wfc_r(pw_wfc->nrxx); - pw_wfc->recip2real(&psi_g(ib, 0), wfc_r.data(), ik); + pw_wfc->recip_to_real,base_device::DEVICE_CPU>(&psi_g(ib, 0), wfc_r.data(), ik); // Extract real and imaginary parts std::vector wfc_real(pw_wfc->nrxx); @@ -551,7 +551,7 @@ void Get_wf_lcao::set_pw_wfc(const ModulePW::PW_Basis_K* pw_wfc, } // call FFT - pw_wfc->real2recip(Porter.data(), &wfc_g(ib, 0), ik); + pw_wfc->real_to_recip,base_device::DEVICE_CPU>(Porter.data(), &wfc_g(ib, 0), ik); } #ifdef __MPI diff --git a/source/source_io/unk_overlap_pw.cpp b/source/source_io/unk_overlap_pw.cpp index 0c87b1f6fb..1b4af5e7b1 100644 --- a/source/source_io/unk_overlap_pw.cpp +++ b/source/source_io/unk_overlap_pw.cpp @@ -93,7 +93,7 @@ std::complex unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw, } // (3) calculate the overlap in ik_L and ik_R - wfcpw->real2recip(psi_r, psi_r, ik_R); + wfcpw->real_to_recip,base_device::DEVICE_CPU>(psi_r, psi_r, ik_R); for (int ig = 0; ig < evc->get_ngk(ik_R); ig++) { @@ -197,8 +197,8 @@ std::complex unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho // (2) fft and get value rhopw->recip2real(phase, phase); - wfcpw->recip2real(&evc[0](ik_L, iband_L, 0), psi_up, ik_L); - wfcpw->recip2real(&evc[0](ik_L, iband_L, npwx), psi_down, ik_L); + wfcpw->recip_to_real,base_device::DEVICE_CPU>(&evc[0](ik_L, iband_L, 0), psi_up, ik_L); + wfcpw->recip_to_real,base_device::DEVICE_CPU>(&evc[0](ik_L, iband_L, npwx), psi_down, ik_L); for (int ir = 0; ir < wfcpw->nrxx; ir++) { @@ -207,8 +207,8 @@ std::complex unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho } // (3) calculate the overlap in ik_L and ik_R - wfcpw->real2recip(psi_up, psi_up, ik_L); - wfcpw->real2recip(psi_down, psi_down, ik_L); + wfcpw->real_to_recip,base_device::DEVICE_CPU>(psi_up, psi_up, ik_L); + wfcpw->real_to_recip,base_device::DEVICE_CPU>(psi_down, psi_down, ik_L); for (int i = 0; i < PARAM.globalv.npol; i++) { From 848a52a3635494edb3624899c0848c6ff78560ec Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Sun, 3 Aug 2025 21:45:10 +0800 Subject: [PATCH 4/4] revert convulution --- source/source_basis/module_pw/pw_basis_k.h | 11 +-- .../source_basis/module_pw/pw_transform_k.cpp | 73 ------------------- .../module_pw/pw_transform_k_dsp.cpp | 4 +- .../module_pwdft/operator_pw/veff_pw.cpp | 21 ++---- 4 files changed, 10 insertions(+), 99 deletions(-) diff --git a/source/source_basis/module_pw/pw_basis_k.h b/source/source_basis/module_pw/pw_basis_k.h index 7b7858669d..b87da9ca0f 100644 --- a/source/source_basis/module_pw/pw_basis_k.h +++ b/source/source_basis/module_pw/pw_basis_k.h @@ -135,18 +135,9 @@ class PW_Basis_K : public PW_Basis const int ik, const bool add = false, const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) - template - void convolution(const Device* ctx, - const int ik, - const int size, - const std::complex* input, - const FPTYPE* input1, - std::complex* output, - const bool add = false, - const FPTYPE factor =1.0) const ; #if defined(__DSP) template - void convolution_dsp(const Device* ctx, + void convolution(const Device* ctx, const int ik, const int size, const std::complex* input, diff --git a/source/source_basis/module_pw/pw_transform_k.cpp b/source/source_basis/module_pw/pw_transform_k.cpp index 13ce5ef2e6..36290d091a 100644 --- a/source/source_basis/module_pw/pw_transform_k.cpp +++ b/source/source_basis/module_pw/pw_transform_k.cpp @@ -337,79 +337,6 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/, this->recip2real(in, out, ik, add, factor); #endif } -template <> -void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, - const int ik, - const int size, - const std::complex* input, - const float* input1, - std::complex* output, - const bool add, - const float factor) const -{ -} - -template <> -void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, - const int ik, - const int size, - const std::complex* input, - const double* input1, - std::complex* output, - const bool add, - const double factor) const -{ - ModuleBase::timer::tick(this->classname, "convolution"); - assert(this->gamma_only == false); - // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxg_data(), this->nst * this->nz); - // memset the auxr of 0 in the auxr,here the len of the auxr is nxyz - auto* auxg = this->fft_bundle.get_auxg_data(); - auto* auxr=this->fft_bundle.get_auxr_data(); - - memset(auxg, 0, this->nst * this->nz * 2 * 8); - const int startig = ik * this->npwk_max; - const int npwk = this->npwk[ik]; - - // copy the mapping form the type of stick to the 3dfft - #ifdef _OPENMP - #pragma omp parallel for schedule(static, 4096 / sizeof(double)) - #endif - for (int igl = 0; igl < npwk; ++igl) - { - auxg[this->igl2isz_k[igl + startig]] = input[igl]; - } - - // use 3d fft backward - this->fft_bundle.fftzbac(auxg, auxg); - - this->gathers_scatterp(auxg, auxr); - - this->fft_bundle.fftxybac(auxr, auxr); - - #ifdef _OPENMP - #pragma omp parallel for simd schedule(static) aligned(auxr, input1: 64) - #endif - for (int ir = 0; ir < size; ir++) - { - auxr[ir] *= input1[ir]; - } - // 3d fft - this->fft_bundle.fftxyfor(auxr, auxr); - - this->gatherp_scatters(auxr, auxg); - - this->fft_bundle.fftzfor(auxg, auxg); - // copy the result from the auxr to the out ,while consider the add - double tmpfac = factor / double(this->nxyz); -#ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096 / sizeof(double)) -#endif - for (int igl = 0; igl < npwk; ++igl) - { - output[igl] += tmpfac * auxg[this->igl2isz_k[igl + startig]]; - } - ModuleBase::timer::tick(this->classname, "convolution"); -} #if (defined(__CUDA) || defined(__ROCM)) template <> diff --git a/source/source_basis/module_pw/pw_transform_k_dsp.cpp b/source/source_basis/module_pw/pw_transform_k_dsp.cpp index 6f35d545ca..1449943550 100644 --- a/source/source_basis/module_pw/pw_transform_k_dsp.cpp +++ b/source/source_basis/module_pw/pw_transform_k_dsp.cpp @@ -91,7 +91,7 @@ void PW_Basis_K::recip2real_dsp(const std::complex* in, } } template <> -void PW_Basis_K::convolution_dsp(const base_device::DEVICE_CPU* ctx, +void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, const int ik, const int size, const std::complex* input, @@ -103,7 +103,7 @@ void PW_Basis_K::convolution_dsp(const base_device::DEVICE_CPU* ctx, } template <> -void PW_Basis_K::convolution_dsp(const base_device::DEVICE_CPU* ctx, +void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, const int ik, const int size, const std::complex* input, diff --git a/source/source_pw/module_pwdft/operator_pw/veff_pw.cpp b/source/source_pw/module_pwdft/operator_pw/veff_pw.cpp index 0251857d0f..e2813b3a9a 100644 --- a/source/source_pw/module_pwdft/operator_pw/veff_pw.cpp +++ b/source/source_pw/module_pwdft/operator_pw/veff_pw.cpp @@ -61,7 +61,7 @@ void Veff>::act( ModulePW::FFT_Guard guard(wfcpw->fft_bundle); for (int ib = 0; ib < nbands; ib += npol) { - wfcpw->convolution_dsp(this->ctx, + wfcpw->convolution(this->ctx, this->ik, this->veff_col, tmpsi_in, @@ -96,19 +96,12 @@ void Veff>::act( { for (int ib = 0; ib < nbands; ib += npol) { - wfcpw->convolution(this->ctx, - this->ik, - this->veff_col, - tmpsi_in, - this->veff + current_spin * this->veff_col, - tmhpsi, - true); - // wfcpw->recip_to_real(tmpsi_in, this->porter, this->ik); - // // NOTICE: when MPI threads are larger than the number of Z grids - // // veff would contain nothing, and nothing should be done in real space - // // but the 3DFFT can not be skipped, it will cause hanging - // veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col); - // wfcpw->real_to_recip(this->porter, tmhpsi, this->ik, true); + wfcpw->recip_to_real(tmpsi_in, this->porter, this->ik); + // NOTICE: when MPI threads are larger than the number of Z grids + // veff would contain nothing, and nothing should be done in real space + // but the 3DFFT can not be skipped, it will cause hanging + veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col); + wfcpw->real_to_recip(this->porter, tmhpsi, this->ik, true); tmhpsi += psi_offset; tmpsi_in += psi_offset; }