From 7d7a6f473105af087dfb87f0ed078c57af92bf59 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Mon, 9 Mar 2026 12:31:29 +0800 Subject: [PATCH 1/7] refactor(esolver): extract update_cell_pw function from esolver_ks_pw - Create new files update_cell_pw.h and update_cell_pw.cpp in source_pw/module_pwdft - Extract cell parameter update logic from ESolver_KS_PW::before_scf() - The new function handles: 1. Rescaling non-local pseudopotential (ppcell.rescale_vnl) 2. Reinitializing plane wave basis grids (pw_wfc->initgrids/initparameters/collect_local_pw) - Keep psi initialization (p_psi_init->prepare_init) in esolver to avoid template dependency - Update CMakeLists.txt and Makefile.Objects for new source files This refactoring improves code organization by moving PW-specific cell update logic out of the esolver, making the esolver code cleaner and more focused on high-level workflow control. --- source/Makefile.Objects | 1 + source/source_esolver/esolver_ks_pw.cpp | 12 ++------ source/source_pw/module_pwdft/CMakeLists.txt | 1 + .../source_pw/module_pwdft/update_cell_pw.cpp | 29 +++++++++++++++++++ .../source_pw/module_pwdft/update_cell_pw.h | 20 +++++++++++++ 5 files changed, 54 insertions(+), 9 deletions(-) create mode 100644 source/source_pw/module_pwdft/update_cell_pw.cpp create mode 100644 source/source_pw/module_pwdft/update_cell_pw.h diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 0ba91a0378..7343f620c5 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -705,6 +705,7 @@ OBJS_SRCPW=H_Ewald_pw.o\ setup_pot.o\ setup_pwrho.o\ setup_pwwfc.o\ + update_cell_pw.o\ forces.o\ forces_us.o\ forces_nl.o\ diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index b3dc5744a0..212be9b176 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -28,6 +28,7 @@ #include "source_io/module_ctrl/ctrl_output_pw.h" // mohan add 20250927 #include "source_estate/module_charge/chgmixing.h" // use charge mixing, mohan add 20251006 #include "source_estate/update_pot.h" // mohan add 20251016 +#include "source_pw/module_pwdft/update_cell_pw.h" // mohan add 20250309 #include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info @@ -119,17 +120,10 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) ESolver_KS::before_scf(ucell, istep); //! Init variables (once the cell has changed) + pw::update_cell_pw(ucell, this->ppcell, this->kv, this->pw_wfc, PARAM.inp); + if (ucell.cell_parameter_updated) { - this->ppcell.rescale_vnl(ucell.omega); - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL"); - - this->pw_wfc->initgrids(ucell.lat0, ucell.latvec, this->pw_wfc->nx, this->pw_wfc->ny, this->pw_wfc->nz); - - this->pw_wfc->initparameters(false, PARAM.inp.ecutwfc, this->kv.get_nks(), this->kv.kvec_d.data()); - - this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma); - this->stp.p_psi_init->prepare_init(PARAM.inp.pw_seed); } diff --git a/source/source_pw/module_pwdft/CMakeLists.txt b/source/source_pw/module_pwdft/CMakeLists.txt index 3c525240ea..06bc2614f7 100644 --- a/source/source_pw/module_pwdft/CMakeLists.txt +++ b/source/source_pw/module_pwdft/CMakeLists.txt @@ -15,6 +15,7 @@ list(APPEND objects setup_pot.cpp setup_pwrho.cpp setup_pwwfc.cpp + update_cell_pw.cpp forces_nl.cpp forces_cc.cpp forces_scc.cpp diff --git a/source/source_pw/module_pwdft/update_cell_pw.cpp b/source/source_pw/module_pwdft/update_cell_pw.cpp new file mode 100644 index 0000000000..a027884cc2 --- /dev/null +++ b/source/source_pw/module_pwdft/update_cell_pw.cpp @@ -0,0 +1,29 @@ +#include "source_pw/module_pwdft/update_cell_pw.h" +#include "source_base/global_variable.h" +#include "source_base/global_function.h" + +namespace pw +{ + +void update_cell_pw(const UnitCell& ucell, + pseudopot_cell_vnl& ppcell, + const K_Vectors& kv, + ModulePW::PW_Basis_K* pw_wfc, + const Input_para& inp) +{ + ModuleBase::TITLE("pw", "update_cell_pw"); + + if (!ucell.cell_parameter_updated) + { + return; + } + + ppcell.rescale_vnl(ucell.omega); + ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL"); + + pw_wfc->initgrids(ucell.lat0, ucell.latvec, pw_wfc->nx, pw_wfc->ny, pw_wfc->nz); + pw_wfc->initparameters(false, inp.ecutwfc, kv.get_nks(), kv.kvec_d.data()); + pw_wfc->collect_local_pw(inp.erf_ecut, inp.erf_height, inp.erf_sigma); +} + +} diff --git a/source/source_pw/module_pwdft/update_cell_pw.h b/source/source_pw/module_pwdft/update_cell_pw.h new file mode 100644 index 0000000000..704d5b6f80 --- /dev/null +++ b/source/source_pw/module_pwdft/update_cell_pw.h @@ -0,0 +1,20 @@ +#ifndef UPDATE_CELL_PW_H +#define UPDATE_CELL_PW_H + +#include "source_io/module_parameter/parameter.h" +#include "source_cell/unitcell.h" +#include "source_cell/klist.h" +#include "source_basis/module_pw/pw_basis_k.h" +#include "source_pw/module_pwdft/vnl_pw.h" + +namespace pw +{ + +void update_cell_pw(const UnitCell& ucell, + pseudopot_cell_vnl& ppcell, + const K_Vectors& kv, + ModulePW::PW_Basis_K* pw_wfc, + const Input_para& inp); + +} +#endif From 95b3cc0174c62038ddec6b83b395047f60e11445 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Mon, 9 Mar 2026 12:52:28 +0800 Subject: [PATCH 2/7] refactor(esolver): extract EXX initialization into Exx_Helper::init - Add init() function to Exx_Helper class for EXX initialization - The init function handles: 1. Check if calculation type is scf/relax/cell-relax/md 2. Check if cal_exx is enabled 3. Set XC first loop if separate_loop is true 4. Set wg pointer for EXX calculation - Simplify ESolver_KS_PW::before_all_runners() by calling exx_helper.init() - Move EXX-specific logic out of esolver, improving code organization This refactoring makes the esolver code cleaner and more focused on high-level workflow control. --- source/source_esolver/esolver_ks_pw.cpp | 15 +----------- source/source_pw/module_pwdft/exx_helper.cpp | 24 ++++++++++++++++++++ source/source_pw/module_pwdft/exx_helper.h | 3 +++ 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 212be9b176..6ee20a8377 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -94,20 +94,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); //! Initialize exx pw - if (inp.calculation == "scf" || inp.calculation == "relax" || inp.calculation == "cell-relax" - || inp.calculation == "md") - { - if (GlobalC::exx_info.info_global.cal_exx && GlobalC::exx_info.info_global.separate_loop == true) - { - XC_Functional::set_xc_first_loop(ucell); - exx_helper.set_firstiter(); - } - - if (GlobalC::exx_info.info_global.cal_exx) - { - exx_helper.set_wg(&this->pelec->wg); - } - } + this->exx_helper.init(ucell, inp, this->pelec->wg); } template diff --git a/source/source_pw/module_pwdft/exx_helper.cpp b/source/source_pw/module_pwdft/exx_helper.cpp index 89c32d1584..eef8a0574c 100644 --- a/source/source_pw/module_pwdft/exx_helper.cpp +++ b/source/source_pw/module_pwdft/exx_helper.cpp @@ -1,6 +1,30 @@ #include "exx_helper.h" #include "source_io/module_parameter/parameter.h" // use PARAM #include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info +#include "source_hamilt/module_xc/xc_functional.h" // use XC_Functional + +template +void Exx_Helper::init(const UnitCell& ucell, const Input_para& inp, const ModuleBase::matrix& wg) +{ + if (inp.calculation != "scf" && inp.calculation != "relax" + && inp.calculation != "cell-relax" && inp.calculation != "md") + { + return; + } + + if (!GlobalC::exx_info.info_global.cal_exx) + { + return; + } + + if (GlobalC::exx_info.info_global.separate_loop) + { + XC_Functional::set_xc_first_loop(ucell); + this->set_firstiter(); + } + + this->set_wg(&wg); +} template double Exx_Helper::cal_exx_energy(psi::Psi *psi_) diff --git a/source/source_pw/module_pwdft/exx_helper.h b/source/source_pw/module_pwdft/exx_helper.h index 283b035760..ba5842256f 100644 --- a/source/source_pw/module_pwdft/exx_helper.h +++ b/source/source_pw/module_pwdft/exx_helper.h @@ -1,6 +1,7 @@ #include "source_psi/psi.h" #include "source_base/matrix.h" #include "source_pw/module_pwdft/op_pw_exx.h" +#include "source_io/module_parameter/input_parameter.h" #ifndef EXX_HELPER_H #define EXX_HELPER_H @@ -14,6 +15,8 @@ struct Exx_Helper Exx_Helper() = default; OperatorEXX *op_exx = nullptr; + void init(const UnitCell& ucell, const Input_para& inp, const ModuleBase::matrix& wg); + void set_firstiter(bool flag = true) { first_iter = flag; } void set_wg(const ModuleBase::matrix *wg_) { wg = wg_; } void set_psi(psi::Psi *psi_); From b5f8ed3508031dfb8727034b9202bfabc26d1367 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Mon, 9 Mar 2026 13:36:21 +0800 Subject: [PATCH 3/7] refactor(esolver): extract DFT+U initialization into pw::iter_init_dftu_pw - Create new files dftu_pw.h and dftu_pw.cpp in source_pw/module_pwdft - Extract DFT+U occupation update logic from ESolver_KS_PW::iter_init() - The new function handles: 1. Check if DFT+U is enabled 2. Check iteration and step conditions 3. Call cal_occ_pw for occupation calculation 4. Output DFT+U results - Use void* for psi parameter to avoid template dependency - Update CMakeLists.txt and Makefile.Objects for new source files This refactoring improves code organization by moving DFT+U specific logic out of the esolver, making the esolver code cleaner and more focused on high-level workflow control. --- source/Makefile.Objects | 1 + source/source_esolver/esolver_ks_pw.cpp | 12 ++------ source/source_pw/module_pwdft/CMakeLists.txt | 1 + source/source_pw/module_pwdft/dftu_pw.cpp | 32 ++++++++++++++++++++ source/source_pw/module_pwdft/dftu_pw.h | 23 ++++++++++++++ 5 files changed, 59 insertions(+), 10 deletions(-) create mode 100644 source/source_pw/module_pwdft/dftu_pw.cpp create mode 100644 source/source_pw/module_pwdft/dftu_pw.h diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 7343f620c5..572a406fe1 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -706,6 +706,7 @@ OBJS_SRCPW=H_Ewald_pw.o\ setup_pwrho.o\ setup_pwwfc.o\ update_cell_pw.o\ + dftu_pw.o\ forces.o\ forces_us.o\ forces_nl.o\ diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 6ee20a8377..f8e1c66df0 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -29,6 +29,7 @@ #include "source_estate/module_charge/chgmixing.h" // use charge mixing, mohan add 20251006 #include "source_estate/update_pot.h" // mohan add 20251016 #include "source_pw/module_pwdft/update_cell_pw.h" // mohan add 20250309 +#include "source_pw/module_pwdft/dftu_pw.h" // mohan add 20250309 #include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info @@ -162,16 +163,7 @@ void ESolver_KS_PW::iter_init(UnitCell& ucell, const int istep, const // 4) update local occupations for DFT+U // should before lambda loop in DeltaSpin - if (PARAM.inp.dft_plus_u && (iter != 1 || istep != 0)) - { - // only old DFT+U method should calculate energy correction in esolver, - // new DFT+U method will calculate energy when evaluating the Hamiltonian - if (this->dftu.omc != 2) - { - this->dftu.cal_occ_pw(iter, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp.mixing_beta); - } - this->dftu.output(ucell); - } + pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp); } // Temporary, it should be replaced by hsolver later. diff --git a/source/source_pw/module_pwdft/CMakeLists.txt b/source/source_pw/module_pwdft/CMakeLists.txt index 06bc2614f7..6a733d9a71 100644 --- a/source/source_pw/module_pwdft/CMakeLists.txt +++ b/source/source_pw/module_pwdft/CMakeLists.txt @@ -16,6 +16,7 @@ list(APPEND objects setup_pwrho.cpp setup_pwwfc.cpp update_cell_pw.cpp + dftu_pw.cpp forces_nl.cpp forces_cc.cpp forces_scc.cpp diff --git a/source/source_pw/module_pwdft/dftu_pw.cpp b/source/source_pw/module_pwdft/dftu_pw.cpp new file mode 100644 index 0000000000..475a34620a --- /dev/null +++ b/source/source_pw/module_pwdft/dftu_pw.cpp @@ -0,0 +1,32 @@ +#include "source_pw/module_pwdft/dftu_pw.h" +#include "source_lcao/module_dftu/dftu.h" + +namespace pw +{ + +void iter_init_dftu_pw(const int iter, + const int istep, + Plus_U& dftu, + const void* psi, + const ModuleBase::matrix& wg, + const UnitCell& ucell, + const Input_para& inp) +{ + if (!inp.dft_plus_u) + { + return; + } + + if (iter == 1 && istep == 0) + { + return; + } + + if (dftu.omc != 2) + { + dftu.cal_occ_pw(iter, psi, wg, ucell, inp.mixing_beta); + } + dftu.output(ucell); +} + +} diff --git a/source/source_pw/module_pwdft/dftu_pw.h b/source/source_pw/module_pwdft/dftu_pw.h new file mode 100644 index 0000000000..8a30b04e76 --- /dev/null +++ b/source/source_pw/module_pwdft/dftu_pw.h @@ -0,0 +1,23 @@ +#ifndef DFTU_PW_H +#define DFTU_PW_H + +#include "source_io/module_parameter/parameter.h" +#include "source_cell/unitcell.h" +#include "source_base/matrix.h" + +class Plus_U; + +namespace pw +{ + +void iter_init_dftu_pw(const int iter, + const int istep, + Plus_U& dftu, + const void* psi, + const ModuleBase::matrix& wg, + const UnitCell& ucell, + const Input_para& inp); + +} + +#endif From 0eb26d6c207db54f516d7dbceb9ac525a8d297de Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Mon, 9 Mar 2026 14:00:16 +0800 Subject: [PATCH 4/7] refactor(esolver): extract DeltaSpin lambda loop into pw::run_deltaspin_lambda_loop - Create new files deltaspin_pw.h and deltaspin_pw.cpp in source_pw/module_pwdft - Extract DeltaSpin lambda loop logic from ESolver_KS_PW::hamilt2rho_single() - The new function handles: 1. Check if DeltaSpin (sc_mag_switch) is enabled 2. Get SpinConstrain singleton instance 3. Run lambda loop to constrain atomic magnetic moments 4. Return skip_solve flag to control solver execution - Add Doxygen-style comments in English - Update CMakeLists.txt and Makefile.Objects for new source files This refactoring improves code organization by moving DeltaSpin-specific logic out of the esolver, making the esolver code cleaner and more focused on high-level workflow control. --- source/Makefile.Objects | 1 + source/source_esolver/esolver_ks_pw.cpp | 22 +--------- source/source_pw/module_pwdft/CMakeLists.txt | 1 + .../source_pw/module_pwdft/deltaspin_pw.cpp | 42 +++++++++++++++++++ source/source_pw/module_pwdft/deltaspin_pw.h | 27 ++++++++++++ 5 files changed, 73 insertions(+), 20 deletions(-) create mode 100644 source/source_pw/module_pwdft/deltaspin_pw.cpp create mode 100644 source/source_pw/module_pwdft/deltaspin_pw.h diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 572a406fe1..bbfe0c1235 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -707,6 +707,7 @@ OBJS_SRCPW=H_Ewald_pw.o\ setup_pwwfc.o\ update_cell_pw.o\ dftu_pw.o\ + deltaspin_pw.o\ forces.o\ forces_us.o\ forces_nl.o\ diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index f8e1c66df0..6cf46ad552 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -30,6 +30,7 @@ #include "source_estate/update_pot.h" // mohan add 20251016 #include "source_pw/module_pwdft/update_cell_pw.h" // mohan add 20250309 #include "source_pw/module_pwdft/dftu_pw.h" // mohan add 20250309 +#include "source_pw/module_pwdft/deltaspin_pw.h" // mohan add 20250309 #include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info @@ -191,26 +192,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; // run the inner lambda loop to contrain atomic moments with the DeltaSpin method - bool skip_solve = false; - - if (PARAM.inp.sc_mag_switch) - { - spinconstrain::SpinConstrain>& sc - = spinconstrain::SpinConstrain>::getScInstance(); - if (!sc.mag_converged() && this->drho > 0 && this->drho < PARAM.inp.sc_scf_thr) - { - // optimize lambda to get target magnetic moments, but the lambda is not near target - sc.run_lambda_loop(iter - 1); - sc.set_mag_converged(true); - skip_solve = true; - } - else if (sc.mag_converged()) - { - // optimize lambda to get target magnetic moments, but the lambda is not near target - sc.run_lambda_loop(iter - 1); - skip_solve = true; - } - } + bool skip_solve = pw::run_deltaspin_lambda_loop(iter - 1, this->drho, PARAM.inp); if (!skip_solve) { diff --git a/source/source_pw/module_pwdft/CMakeLists.txt b/source/source_pw/module_pwdft/CMakeLists.txt index 6a733d9a71..9e34e9c7b4 100644 --- a/source/source_pw/module_pwdft/CMakeLists.txt +++ b/source/source_pw/module_pwdft/CMakeLists.txt @@ -17,6 +17,7 @@ list(APPEND objects setup_pwwfc.cpp update_cell_pw.cpp dftu_pw.cpp + deltaspin_pw.cpp forces_nl.cpp forces_cc.cpp forces_scc.cpp diff --git a/source/source_pw/module_pwdft/deltaspin_pw.cpp b/source/source_pw/module_pwdft/deltaspin_pw.cpp new file mode 100644 index 0000000000..ef509a849a --- /dev/null +++ b/source/source_pw/module_pwdft/deltaspin_pw.cpp @@ -0,0 +1,42 @@ +#include "source_pw/module_pwdft/deltaspin_pw.h" +#include "source_lcao/module_deltaspin/spin_constrain.h" + +namespace pw +{ + +bool run_deltaspin_lambda_loop(const int iter, + const double drho, + const Input_para& inp) +{ + /// Return false if DeltaSpin is not enabled + if (!inp.sc_mag_switch) + { + return false; + } + + /// Get the singleton instance of SpinConstrain + spinconstrain::SpinConstrain>& sc + = spinconstrain::SpinConstrain>::getScInstance(); + + /// Case 1: Magnetic moments not yet converged and SCF is close to convergence. + /// This is the first time we enter the lambda loop after SCF is nearly converged. + if (!sc.mag_converged() && drho > 0 && drho < inp.sc_scf_thr) + { + /// Optimize lambda to get target magnetic moments + sc.run_lambda_loop(iter); + sc.set_mag_converged(true); + return true; + } + /// Case 2: Magnetic moments already converged in previous iteration. + /// Continue to refine lambda in subsequent SCF iterations. + else if (sc.mag_converged()) + { + sc.run_lambda_loop(iter); + return true; + } + + /// Default: run the normal solver + return false; +} + +} diff --git a/source/source_pw/module_pwdft/deltaspin_pw.h b/source/source_pw/module_pwdft/deltaspin_pw.h new file mode 100644 index 0000000000..3482a8d29f --- /dev/null +++ b/source/source_pw/module_pwdft/deltaspin_pw.h @@ -0,0 +1,27 @@ +#ifndef DELTASPIN_PW_H +#define DELTASPIN_PW_H + +#include "source_io/module_parameter/parameter.h" + +namespace pw +{ + +/** + * @brief Run the inner lambda loop for DeltaSpin method to constrain atomic magnetic moments. + * + * This function is used in the PW basis SCF iteration to optimize lambda parameters + * for constraining atomic magnetic moments to target values using the DeltaSpin method. + * + * @param iter The current iteration number (0-indexed). + * @param drho The current charge density difference. + * @param inp The input parameters. + * @return true if the solver should be skipped (lambda loop was executed), + * false otherwise. + */ +bool run_deltaspin_lambda_loop(const int iter, + const double drho, + const Input_para& inp); + +} + +#endif From 2bf6f5de1ac590e5d72e55491ef14dc4d8620c0d Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Mon, 9 Mar 2026 14:07:32 +0800 Subject: [PATCH 5/7] refactor(esolver): extract DeltaSpin oscillation check into pw::check_deltaspin_oscillation - Add check_deltaspin_oscillation() function to deltaspin_pw.h/cpp - Extract DeltaSpin SCF oscillation check logic from ESolver_KS_PW::iter_finish() - The new function handles: 1. Check if DeltaSpin (sc_mag_switch) is enabled 2. Get SpinConstrain singleton instance 3. Detect SCF oscillation using if_scf_oscillate() 4. Set mixing_restart_step if oscillation detected - Add Doxygen-style comments in English This refactoring consolidates all DeltaSpin-related functions in one place, making the code more modular and easier to maintain. --- source/source_esolver/esolver_ks_pw.cpp | 15 +--------- .../source_pw/module_pwdft/deltaspin_pw.cpp | 30 +++++++++++++++++++ source/source_pw/module_pwdft/deltaspin_pw.h | 19 ++++++++++++ 3 files changed, 50 insertions(+), 14 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 6cf46ad552..7a7fdb639b 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -293,20 +293,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } // check if oscillate for delta_spin method - if (PARAM.inp.sc_mag_switch) - { - spinconstrain::SpinConstrain>& sc - = spinconstrain::SpinConstrain>::getScInstance(); - if (!sc.higher_mag_prec) - { - sc.higher_mag_prec = this->p_chgmix->if_scf_oscillate(iter, - this->drho, PARAM.inp.sc_os_ndim, PARAM.inp.scf_os_thr); - if (sc.higher_mag_prec) - { // if oscillate, increase the precision of magnetization and do mixing_restart in next iteration - this->p_chgmix->mixing_restart_step = iter + 1; - } - } - } + pw::check_deltaspin_oscillation(iter, this->drho, this->p_chgmix, PARAM.inp); // the output quantities ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->stp.psi_cpu, diff --git a/source/source_pw/module_pwdft/deltaspin_pw.cpp b/source/source_pw/module_pwdft/deltaspin_pw.cpp index ef509a849a..caf8ea7852 100644 --- a/source/source_pw/module_pwdft/deltaspin_pw.cpp +++ b/source/source_pw/module_pwdft/deltaspin_pw.cpp @@ -1,5 +1,6 @@ #include "source_pw/module_pwdft/deltaspin_pw.h" #include "source_lcao/module_deltaspin/spin_constrain.h" +#include "source_estate/module_charge/charge_mixing.h" namespace pw { @@ -39,4 +40,33 @@ bool run_deltaspin_lambda_loop(const int iter, return false; } +void check_deltaspin_oscillation(const int iter, + const double drho, + Charge_Mixing* p_chgmix, + const Input_para& inp) +{ + /// Return if DeltaSpin is not enabled + if (!inp.sc_mag_switch) + { + return; + } + + /// Get the singleton instance of SpinConstrain + spinconstrain::SpinConstrain>& sc + = spinconstrain::SpinConstrain>::getScInstance(); + + /// Check if higher magnetization precision is needed + if (!sc.higher_mag_prec) + { + /// Detect SCF oscillation + sc.higher_mag_prec = p_chgmix->if_scf_oscillate(iter, drho, inp.sc_os_ndim, inp.scf_os_thr); + + /// If oscillation detected, set mixing restart step for next iteration + if (sc.higher_mag_prec) + { + p_chgmix->mixing_restart_step = iter + 1; + } + } +} + } diff --git a/source/source_pw/module_pwdft/deltaspin_pw.h b/source/source_pw/module_pwdft/deltaspin_pw.h index 3482a8d29f..0509b61bb8 100644 --- a/source/source_pw/module_pwdft/deltaspin_pw.h +++ b/source/source_pw/module_pwdft/deltaspin_pw.h @@ -3,6 +3,8 @@ #include "source_io/module_parameter/parameter.h" +class Charge_Mixing; + namespace pw { @@ -22,6 +24,23 @@ bool run_deltaspin_lambda_loop(const int iter, const double drho, const Input_para& inp); +/** + * @brief Check if SCF oscillation occurs for DeltaSpin method. + * + * This function checks if the SCF iteration is oscillating and sets the + * mixing restart step if oscillation is detected. This is used to increase + * the precision of magnetization calculation. + * + * @param iter The current iteration number (1-indexed). + * @param drho The current charge density difference. + * @param p_chgmix Pointer to the Charge_Mixing object. + * @param inp The input parameters. + */ +void check_deltaspin_oscillation(const int iter, + const double drho, + Charge_Mixing* p_chgmix, + const Input_para& inp); + } #endif From fc3d52fc1f28d6d6bdcfb0c9e9d2a916f5be9935 Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Mon, 9 Mar 2026 14:27:41 +0800 Subject: [PATCH 6/7] refactor(esolver): extract EXX before_scf setup into Exx_Helper::before_scf - Add before_scf() function to Exx_Helper class - Extract EXX setup logic from ESolver_KS_PW::before_scf() - The new function handles: 1. Check if calculation type is valid (scf/relax/cell-relax/md) 2. Check if EXX is enabled and basis type is PW 3. Set EXX helper to Hamiltonian 4. Set psi for EXX calculation - Use void* for p_hamilt parameter to avoid circular dependency - Add Doxygen-style comments in English This refactoring consolidates EXX-related setup logic in the Exx_Helper class, making the code more modular and easier to maintain. --- source/source_esolver/esolver_ks_pw.cpp | 13 ++-------- source/source_pw/module_pwdft/exx_helper.cpp | 25 ++++++++++++++++++++ source/source_pw/module_pwdft/exx_helper.h | 13 ++++++++++ 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 7a7fdb639b..c0ca99a1ab 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -134,17 +134,8 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) // setup psi (electronic wave functions) this->stp.init(this->p_hamilt); - //! Exx calculations - if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" - || PARAM.inp.calculation == "cell-relax" || PARAM.inp.calculation == "md") - { - if (GlobalC::exx_info.info_global.cal_exx && PARAM.inp.basis_type == "pw") - { - auto hamilt_pw = reinterpret_cast*>(this->p_hamilt); - hamilt_pw->set_exx_helper(exx_helper); - exx_helper.set_psi(this->stp.psi_t); - } - } + //! Setup EXX helper for Hamiltonian and psi + exx_helper.before_scf(this->p_hamilt, this->stp.psi_t, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "before_scf"); } diff --git a/source/source_pw/module_pwdft/exx_helper.cpp b/source/source_pw/module_pwdft/exx_helper.cpp index eef8a0574c..682eb7c2b7 100644 --- a/source/source_pw/module_pwdft/exx_helper.cpp +++ b/source/source_pw/module_pwdft/exx_helper.cpp @@ -2,6 +2,7 @@ #include "source_io/module_parameter/parameter.h" // use PARAM #include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info #include "source_hamilt/module_xc/xc_functional.h" // use XC_Functional +#include "source_pw/module_pwdft/hamilt_pw.h" // use HamiltPW template void Exx_Helper::init(const UnitCell& ucell, const Input_para& inp, const ModuleBase::matrix& wg) @@ -26,6 +27,30 @@ void Exx_Helper::init(const UnitCell& ucell, const Input_para& inp, c this->set_wg(&wg); } +template +void Exx_Helper::before_scf(void* p_hamilt, psi::Psi* psi, const Input_para& inp) +{ + /// Return if not a valid calculation type + if (inp.calculation != "scf" && inp.calculation != "relax" + && inp.calculation != "cell-relax" && inp.calculation != "md") + { + return; + } + + /// Return if EXX is not enabled or not PW basis + if (!GlobalC::exx_info.info_global.cal_exx || inp.basis_type != "pw") + { + return; + } + + /// Set EXX helper to Hamiltonian + auto hamilt_pw = reinterpret_cast*>(p_hamilt); + hamilt_pw->set_exx_helper(*this); + + /// Set psi for EXX calculation + this->set_psi(psi); +} + template double Exx_Helper::cal_exx_energy(psi::Psi *psi_) { diff --git a/source/source_pw/module_pwdft/exx_helper.h b/source/source_pw/module_pwdft/exx_helper.h index ba5842256f..bb370ac580 100644 --- a/source/source_pw/module_pwdft/exx_helper.h +++ b/source/source_pw/module_pwdft/exx_helper.h @@ -17,6 +17,19 @@ struct Exx_Helper void init(const UnitCell& ucell, const Input_para& inp, const ModuleBase::matrix& wg); + /** + * @brief Setup EXX helper before SCF iteration. + * + * This function sets up the EXX helper for the Hamiltonian and psi + * before each SCF iteration. It checks if the calculation type and + * EXX settings are appropriate. + * + * @param p_hamilt Pointer to the Hamiltonian object (void* to avoid circular dependency). + * @param psi Pointer to the wave function object. + * @param inp The input parameters. + */ + void before_scf(void* p_hamilt, psi::Psi* psi, const Input_para& inp); + void set_firstiter(bool flag = true) { first_iter = flag; } void set_wg(const ModuleBase::matrix *wg_) { wg = wg_; } void set_psi(psi::Psi *psi_); From 26f2a31860823e3f953f44220848a9e2a061cddd Mon Sep 17 00:00:00 2001 From: abacus_fixer Date: Mon, 9 Mar 2026 15:19:12 +0800 Subject: [PATCH 7/7] refactor(esolver): extract EXX iter_finish logic into Exx_Helper::iter_finish - Add iter_finish() function to Exx_Helper class - Extract EXX convergence handling logic from ESolver_KS_PW::iter_finish() - The new function handles: 1. Check if EXX is enabled 2. Handle separate_loop mode for EXX convergence 3. Calculate EXX energy difference for energy threshold 4. Update potential if SCF not converged 5. Increment EXX iteration counter - Use Charge* and void* parameters to avoid circular dependency - Add Doxygen-style comments in English This refactoring consolidates all EXX-related functions in the Exx_Helper class, making the code more modular and easier to maintain. --- source/source_esolver/esolver_ks_pw.cpp | 45 +------------ source/source_pw/module_pwdft/exx_helper.cpp | 66 ++++++++++++++++++++ source/source_pw/module_pwdft/exx_helper.h | 22 +++++++ 3 files changed, 90 insertions(+), 43 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index c0ca99a1ab..efb17bc0fd 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -239,49 +239,8 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int this->ppcell.cal_effective_D(veff, this->pw_rhod, ucell); } - // Related to EXX - if (GlobalC::exx_info.info_global.cal_exx) - { - if (GlobalC::exx_info.info_global.separate_loop) - { - if (conv_esolver) - { - auto start = std::chrono::high_resolution_clock::now(); - exx_helper.set_firstiter(false); - exx_helper.op_exx->first_iter = false; - double dexx = 0.0; - if (PARAM.inp.exx_thr_type == "energy") - { - dexx = exx_helper.cal_exx_energy(this->stp.psi_t); - exx_helper.set_psi(this->stp.psi_t); - dexx -= exx_helper.cal_exx_energy(this->stp.psi_t); - // std::cout << "dexx = " << dexx << std::endl; - } - bool conv_ene = std::abs(dexx) < PARAM.inp.exx_ene_thr; - - conv_esolver = exx_helper.exx_after_converge(iter, conv_ene); - if (!conv_esolver) - { - if (PARAM.inp.exx_thr_type != "energy") - { - exx_helper.set_psi(this->stp.psi_t); - } - auto duration = std::chrono::high_resolution_clock::now() - start; - std::cout << " Setting Psi for EXX PW Inner Loop took " - << std::chrono::duration_cast(duration).count() / 1000.0 << "s" - << std::endl; - exx_helper.op_exx->first_iter = false; - XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func); - elecstate::update_pot(ucell, this->pelec, this->chr, conv_esolver); - exx_helper.iter_inc(); - } - } - } - else - { - exx_helper.set_psi(this->stp.psi_t); - } - } + // Handle EXX-related operations after SCF iteration + exx_helper.iter_finish(this->pelec, &this->chr, this->stp.psi_t, ucell, PARAM.inp, conv_esolver, iter); // check if oscillate for delta_spin method pw::check_deltaspin_oscillation(iter, this->drho, this->p_chgmix, PARAM.inp); diff --git a/source/source_pw/module_pwdft/exx_helper.cpp b/source/source_pw/module_pwdft/exx_helper.cpp index 682eb7c2b7..2af13d3173 100644 --- a/source/source_pw/module_pwdft/exx_helper.cpp +++ b/source/source_pw/module_pwdft/exx_helper.cpp @@ -3,6 +3,10 @@ #include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info #include "source_hamilt/module_xc/xc_functional.h" // use XC_Functional #include "source_pw/module_pwdft/hamilt_pw.h" // use HamiltPW +#include "source_estate/update_pot.h" // use elecstate::update_pot +#include "source_estate/elecstate_pw.h" // use ElecStatePW +#include "source_estate/module_charge/charge.h" // use Charge +#include // for timing template void Exx_Helper::init(const UnitCell& ucell, const Input_para& inp, const ModuleBase::matrix& wg) @@ -51,6 +55,68 @@ void Exx_Helper::before_scf(void* p_hamilt, psi::Psi* psi, this->set_psi(psi); } +template +bool Exx_Helper::iter_finish(void* p_elec, Charge* p_charge, psi::Psi* psi, + UnitCell& ucell, const Input_para& inp, + bool& conv_esolver, int& iter) +{ + /// Return if EXX is not enabled + if (!GlobalC::exx_info.info_global.cal_exx) + { + return false; + } + + /// Handle separate_loop mode + if (GlobalC::exx_info.info_global.separate_loop) + { + if (conv_esolver) + { + auto start = std::chrono::high_resolution_clock::now(); + + this->set_firstiter(false); + this->op_exx->first_iter = false; + + double dexx = 0.0; + if (inp.exx_thr_type == "energy") + { + dexx = this->cal_exx_energy(psi); + this->set_psi(psi); + dexx -= this->cal_exx_energy(psi); + } + bool conv_ene = std::abs(dexx) < inp.exx_ene_thr; + + conv_esolver = this->exx_after_converge(iter, conv_ene); + + if (!conv_esolver) + { + if (inp.exx_thr_type != "energy") + { + this->set_psi(psi); + } + + auto duration = std::chrono::high_resolution_clock::now() - start; + std::cout << " Setting Psi for EXX PW Inner Loop took " + << std::chrono::duration_cast(duration).count() / 1000.0 << "s" + << std::endl; + + this->op_exx->first_iter = false; + XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func); + + elecstate::ElecState* pelec = reinterpret_cast*>(p_elec); + elecstate::update_pot(ucell, pelec, *p_charge, conv_esolver); + + this->iter_inc(); + } + } + } + else + { + this->set_psi(psi); + } + + return true; +} + template double Exx_Helper::cal_exx_energy(psi::Psi *psi_) { diff --git a/source/source_pw/module_pwdft/exx_helper.h b/source/source_pw/module_pwdft/exx_helper.h index bb370ac580..53056fdbbe 100644 --- a/source/source_pw/module_pwdft/exx_helper.h +++ b/source/source_pw/module_pwdft/exx_helper.h @@ -5,6 +5,9 @@ #ifndef EXX_HELPER_H #define EXX_HELPER_H + +class Charge; + template struct Exx_Helper { @@ -30,6 +33,25 @@ struct Exx_Helper */ void before_scf(void* p_hamilt, psi::Psi* psi, const Input_para& inp); + /** + * @brief Handle EXX-related operations after SCF iteration. + * + * This function handles EXX convergence checking and potential update + * after each SCF iteration. It is called in iter_finish. + * + * @param p_elec Pointer to the ElecState object (void* to avoid circular dependency). + * @param p_charge Pointer to the Charge object. + * @param psi Pointer to the wave function object. + * @param ucell The unit cell (non-const reference for update_pot). + * @param inp The input parameters. + * @param conv_esolver Whether SCF is converged (may be modified). + * @param iter The current iteration number (may be modified). + * @return true if EXX processing was done, false otherwise. + */ + bool iter_finish(void* p_elec, Charge* p_charge, psi::Psi* psi, + UnitCell& ucell, const Input_para& inp, + bool& conv_esolver, int& iter); + void set_firstiter(bool flag = true) { first_iter = flag; } void set_wg(const ModuleBase::matrix *wg_) { wg = wg_; } void set_psi(psi::Psi *psi_);