From bd56acb3a5b12ff2403e1888d5aaa50dd1e0041a Mon Sep 17 00:00:00 2001 From: kunkunblueberry <1833921874@qq.com> Date: Wed, 31 Dec 2025 12:42:47 +0800 Subject: [PATCH] Fix: high performance optimization for force calculation --- .../module_hamilt_pw/hamilt_pwdft/forces.cpp | 61 +++++++++++++++++-- .../hamilt_pwdft/kernels/force_op.h | 31 +++++++++- .../hamilt_pwdft/kernels/rocm/force_op.hip.cu | 6 +- .../kernels/rocm/dngvd_op.hip.cu | 12 ++-- 4 files changed, 96 insertions(+), 14 deletions(-) diff --git a/source/module_hamilt_pw/hamilt_pwdft/forces.cpp b/source/module_hamilt_pw/hamilt_pwdft/forces.cpp index 3c7a2ae4bb..e554d8b6e8 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/forces.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/forces.cpp @@ -16,7 +16,7 @@ #include "module_hamilt_general/module_surchem/surchem.h" #include "module_hamilt_general/module_vdw/vdw.h" #include "kernels/force_op.h" - +#include #ifdef _OPENMP #include #endif @@ -579,7 +579,7 @@ void Forces::cal_force_loc(const UnitCell& ucell, syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, forcelc_d, forcelc.c, this->nat * 3); syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, vloc_d, vloc.c, vloc.nr * vloc.nc); - hamilt::cal_force_loc_op()( + /* hamilt::cal_force_loc_op()( this->nat, rho_basis->npw, ucell.tpiba * ucell.omega, @@ -590,7 +590,34 @@ void Forces::cal_force_loc(const UnitCell& ucell, aux_d, vloc_d, vloc.nc, - forcelc_d); + forcelc_d);*/ + if constexpr (std::is_same::value) { + hamilt::cal_force_loc_sincos_op()( + this->ctx, + this->nat, + rho_basis->npw, + ucell.ntype, + gcar_d, + tau_d, + vloc_d, + aux_d, + static_cast(ucell.tpiba * ucell.omega), + forcelc_d); + } else { + hamilt::cal_force_loc_op()( + this->nat, + rho_basis->npw, + ucell.tpiba * ucell.omega, + iat2it_d, + ig2gg_d, + gcar_d, + tau_d, + aux_d, + vloc_d, + vloc.nc, + forcelc_d); + } + syncmem_var_d2h_op()(this->cpu_ctx, this->ctx, forcelc.c, forcelc_d, this->nat * 3); delmem_int_op()(this->ctx,iat2it_d); @@ -788,7 +815,7 @@ void Forces::cal_force_ew(const UnitCell& ucell, syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, aux_d, aux.data(), rho_basis->npw); syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, forceion_d, forceion.c, this->nat * 3); - hamilt::cal_force_ew_op()( + /* hamilt::cal_force_ew_op()( this->nat, rho_basis->npw, rho_basis->ig_gge0, @@ -798,7 +825,31 @@ void Forces::cal_force_ew(const UnitCell& ucell, it_fact_d, aux_d, forceion_d); - + */ + if constexpr (std::is_same::value) { + hamilt::cal_force_ew_sincos_op()( + this->ctx, + this->nat, + rho_basis->npw, + rho_basis->ig_gge0, + gcar_d, + tau_d, + it_fact_d, + aux_d, + forceion_d); + } else { + hamilt::cal_force_ew_op()( + this->nat, + rho_basis->npw, + rho_basis->ig_gge0, + iat2it_d, + gcar_d, + tau_d, + it_fact_d, + aux_d, + forceion_d); + } + syncmem_var_d2h_op()(this->cpu_ctx, this->ctx, forceion.c, forceion_d, this->nat * 3); delmem_int_op()(this->ctx,iat2it_d); delmem_var_op()(this->ctx,gcar_d); diff --git a/source/module_hamilt_pw/hamilt_pwdft/kernels/force_op.h b/source/module_hamilt_pw/hamilt_pwdft/kernels/force_op.h index ca2c6694cf..6ce0e9b7af 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/kernels/force_op.h +++ b/source/module_hamilt_pw/hamilt_pwdft/kernels/force_op.h @@ -179,6 +179,9 @@ struct cal_force_ew_op{ FPTYPE* forceion ) {}; }; + +template struct cal_force_loc_sincos_op; +template struct cal_force_ew_sincos_op; #if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM template struct cal_vkb1_nl_op @@ -335,6 +338,32 @@ struct cal_force_ew_op{ FPTYPE* forceion ); }; +template +struct cal_force_loc_sincos_op { + void operator()(const base_device::DEVICE_GPU* ctx, + const int& nat, + const int& npw, + const int& ntype, + const FPTYPE* gcar, + const FPTYPE* tau, + const FPTYPE* vloc_per_type, + const std::complex* aux, + const FPTYPE& scale_factor, + FPTYPE* force); +}; + +template +struct cal_force_ew_sincos_op { + void operator()(const base_device::DEVICE_GPU* ctx, + const int& nat, + const int& npw, + const int& ig_gge0, + const FPTYPE* gcar, + const FPTYPE* tau, + const FPTYPE* it_facts, + const std::complex* aux, + FPTYPE* force); +}; #endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM } // namespace hamilt -#endif // W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_FORCE_OP_H \ No newline at end of file +#endif // W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_FORCE_OP_H diff --git a/source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/force_op.hip.cu b/source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/force_op.hip.cu index 6bb3a84e7e..ab5205b344 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/force_op.hip.cu +++ b/source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/force_op.hip.cu @@ -10,6 +10,8 @@ namespace hamilt { +__device__ __forceinline__ void sincos_(float x, float* s, float* c) { sincosf(x, s, c); } +__device__ __forceinline__ void sincos_(double x, double* s, double* c) { sincos(x, s, c); } template __global__ void cal_vkb1_nl( const int npwx, @@ -658,7 +660,7 @@ __global__ void cal_force_loc_sincos_kernel( // Use HIP intrinsic for sincos FPTYPE sinp, cosp; - sincos(phase, &sinp, &cosp); + sincos_(phase, &sinp, &cosp); // Calculate force factor const FPTYPE vloc_factor = vloc_per_type[iat * npw + ig]; @@ -718,7 +720,7 @@ __global__ void cal_force_ew_sincos_kernel( // Use HIP intrinsic for sincos FPTYPE sinp, cosp; - sincos(phase, &sinp, &cosp); + sincos_(phase, &sinp, &cosp); // Calculate Ewald sum contribution (fixed sign error) const FPTYPE factor = it_fact * (-cosp * aux[ig].imag() + sinp * aux[ig].real()); diff --git a/source/module_hsolver/kernels/rocm/dngvd_op.hip.cu b/source/module_hsolver/kernels/rocm/dngvd_op.hip.cu index 8cba06db1b..e2ba4d1d40 100644 --- a/source/module_hsolver/kernels/rocm/dngvd_op.hip.cu +++ b/source/module_hsolver/kernels/rocm/dngvd_op.hip.cu @@ -128,8 +128,8 @@ void dngvd_op, base_device::DEVICE_GPU>::operator()(const ba hipsolverErrcheck(hipsolverDnChegvd_bufferSize( hipsolver_H, HIPSOLVER_EIG_TYPE_1, HIPSOLVER_EIG_MODE_VECTOR, uplo, nstart, - reinterpret_cast(_vcc), ldh, - reinterpret_cast(_scc), ldh, + const_cast(reinterpret_cast(_vcc)), ldh, + const_cast(reinterpret_cast(_scc)), ldh, _eigenvalue, &lwork)); @@ -140,7 +140,7 @@ void dngvd_op, base_device::DEVICE_GPU>::operator()(const ba hipsolverErrcheck(hipsolverDnChegvd( hipsolver_H, HIPSOLVER_EIG_TYPE_1, HIPSOLVER_EIG_MODE_VECTOR, uplo, nstart, - reinterpret_cast(_vcc), ldh, + const_cast(reinterpret_cast(_vcc)), ldh, const_cast(reinterpret_cast(_scc)), ldh, _eigenvalue, work, lwork, devInfo)); @@ -206,8 +206,8 @@ void dngvd_op, base_device::DEVICE_GPU>::operator()(const b hipsolverErrcheck(hipsolverDnZhegvd_bufferSize( hipsolver_H, HIPSOLVER_EIG_TYPE_1, HIPSOLVER_EIG_MODE_VECTOR, uplo, nstart, - reinterpret_cast(_vcc), ldh, - reinterpret_cast(_scc), ldh, + const_cast(reinterpret_cast(_vcc)), ldh, + const_cast(reinterpret_cast(_scc)), ldh, _eigenvalue, &lwork)); @@ -365,4 +365,4 @@ void dngvx_op::operator()(const base_device::DE } #endif // __LCAO -} // namespace hsolver \ No newline at end of file +} // namespace hsolver