diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 0ba91a0378..bbfe0c1235 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -705,6 +705,9 @@ OBJS_SRCPW=H_Ewald_pw.o\ setup_pot.o\ setup_pwrho.o\ setup_pwwfc.o\ + update_cell_pw.o\ + dftu_pw.o\ + deltaspin_pw.o\ forces.o\ forces_us.o\ forces_nl.o\ diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index b3dc5744a0..efb17bc0fd 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -28,6 +28,9 @@ #include "source_io/module_ctrl/ctrl_output_pw.h" // mohan add 20250927 #include "source_estate/module_charge/chgmixing.h" // use charge mixing, mohan add 20251006 #include "source_estate/update_pot.h" // mohan add 20251016 +#include "source_pw/module_pwdft/update_cell_pw.h" // mohan add 20250309 +#include "source_pw/module_pwdft/dftu_pw.h" // mohan add 20250309 +#include "source_pw/module_pwdft/deltaspin_pw.h" // mohan add 20250309 #include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info @@ -93,20 +96,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); //! Initialize exx pw - if (inp.calculation == "scf" || inp.calculation == "relax" || inp.calculation == "cell-relax" - || inp.calculation == "md") - { - if (GlobalC::exx_info.info_global.cal_exx && GlobalC::exx_info.info_global.separate_loop == true) - { - XC_Functional::set_xc_first_loop(ucell); - exx_helper.set_firstiter(); - } - - if (GlobalC::exx_info.info_global.cal_exx) - { - exx_helper.set_wg(&this->pelec->wg); - } - } + this->exx_helper.init(ucell, inp, this->pelec->wg); } template @@ -119,17 +109,10 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) ESolver_KS::before_scf(ucell, istep); //! Init variables (once the cell has changed) + pw::update_cell_pw(ucell, this->ppcell, this->kv, this->pw_wfc, PARAM.inp); + if (ucell.cell_parameter_updated) { - this->ppcell.rescale_vnl(ucell.omega); - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL"); - - this->pw_wfc->initgrids(ucell.lat0, ucell.latvec, this->pw_wfc->nx, this->pw_wfc->ny, this->pw_wfc->nz); - - this->pw_wfc->initparameters(false, PARAM.inp.ecutwfc, this->kv.get_nks(), this->kv.kvec_d.data()); - - this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma); - this->stp.p_psi_init->prepare_init(PARAM.inp.pw_seed); } @@ -151,17 +134,8 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) // setup psi (electronic wave functions) this->stp.init(this->p_hamilt); - //! Exx calculations - if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" - || PARAM.inp.calculation == "cell-relax" || PARAM.inp.calculation == "md") - { - if (GlobalC::exx_info.info_global.cal_exx && PARAM.inp.basis_type == "pw") - { - auto hamilt_pw = reinterpret_cast*>(this->p_hamilt); - hamilt_pw->set_exx_helper(exx_helper); - exx_helper.set_psi(this->stp.psi_t); - } - } + //! Setup EXX helper for Hamiltonian and psi + exx_helper.before_scf(this->p_hamilt, this->stp.psi_t, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "before_scf"); } @@ -181,16 +155,7 @@ void ESolver_KS_PW::iter_init(UnitCell& ucell, const int istep, const // 4) update local occupations for DFT+U // should before lambda loop in DeltaSpin - if (PARAM.inp.dft_plus_u && (iter != 1 || istep != 0)) - { - // only old DFT+U method should calculate energy correction in esolver, - // new DFT+U method will calculate energy when evaluating the Hamiltonian - if (this->dftu.omc != 2) - { - this->dftu.cal_occ_pw(iter, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp.mixing_beta); - } - this->dftu.output(ucell); - } + pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp); } // Temporary, it should be replaced by hsolver later. @@ -218,26 +183,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; // run the inner lambda loop to contrain atomic moments with the DeltaSpin method - bool skip_solve = false; - - if (PARAM.inp.sc_mag_switch) - { - spinconstrain::SpinConstrain>& sc - = spinconstrain::SpinConstrain>::getScInstance(); - if (!sc.mag_converged() && this->drho > 0 && this->drho < PARAM.inp.sc_scf_thr) - { - // optimize lambda to get target magnetic moments, but the lambda is not near target - sc.run_lambda_loop(iter - 1); - sc.set_mag_converged(true); - skip_solve = true; - } - else if (sc.mag_converged()) - { - // optimize lambda to get target magnetic moments, but the lambda is not near target - sc.run_lambda_loop(iter - 1); - skip_solve = true; - } - } + bool skip_solve = pw::run_deltaspin_lambda_loop(iter - 1, this->drho, PARAM.inp); if (!skip_solve) { @@ -293,65 +239,11 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int this->ppcell.cal_effective_D(veff, this->pw_rhod, ucell); } - // Related to EXX - if (GlobalC::exx_info.info_global.cal_exx) - { - if (GlobalC::exx_info.info_global.separate_loop) - { - if (conv_esolver) - { - auto start = std::chrono::high_resolution_clock::now(); - exx_helper.set_firstiter(false); - exx_helper.op_exx->first_iter = false; - double dexx = 0.0; - if (PARAM.inp.exx_thr_type == "energy") - { - dexx = exx_helper.cal_exx_energy(this->stp.psi_t); - exx_helper.set_psi(this->stp.psi_t); - dexx -= exx_helper.cal_exx_energy(this->stp.psi_t); - // std::cout << "dexx = " << dexx << std::endl; - } - bool conv_ene = std::abs(dexx) < PARAM.inp.exx_ene_thr; - - conv_esolver = exx_helper.exx_after_converge(iter, conv_ene); - if (!conv_esolver) - { - if (PARAM.inp.exx_thr_type != "energy") - { - exx_helper.set_psi(this->stp.psi_t); - } - auto duration = std::chrono::high_resolution_clock::now() - start; - std::cout << " Setting Psi for EXX PW Inner Loop took " - << std::chrono::duration_cast(duration).count() / 1000.0 << "s" - << std::endl; - exx_helper.op_exx->first_iter = false; - XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func); - elecstate::update_pot(ucell, this->pelec, this->chr, conv_esolver); - exx_helper.iter_inc(); - } - } - } - else - { - exx_helper.set_psi(this->stp.psi_t); - } - } + // Handle EXX-related operations after SCF iteration + exx_helper.iter_finish(this->pelec, &this->chr, this->stp.psi_t, ucell, PARAM.inp, conv_esolver, iter); // check if oscillate for delta_spin method - if (PARAM.inp.sc_mag_switch) - { - spinconstrain::SpinConstrain>& sc - = spinconstrain::SpinConstrain>::getScInstance(); - if (!sc.higher_mag_prec) - { - sc.higher_mag_prec = this->p_chgmix->if_scf_oscillate(iter, - this->drho, PARAM.inp.sc_os_ndim, PARAM.inp.scf_os_thr); - if (sc.higher_mag_prec) - { // if oscillate, increase the precision of magnetization and do mixing_restart in next iteration - this->p_chgmix->mixing_restart_step = iter + 1; - } - } - } + pw::check_deltaspin_oscillation(iter, this->drho, this->p_chgmix, PARAM.inp); // the output quantities ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->stp.psi_cpu, diff --git a/source/source_pw/module_pwdft/CMakeLists.txt b/source/source_pw/module_pwdft/CMakeLists.txt index 3c525240ea..9e34e9c7b4 100644 --- a/source/source_pw/module_pwdft/CMakeLists.txt +++ b/source/source_pw/module_pwdft/CMakeLists.txt @@ -15,6 +15,9 @@ list(APPEND objects setup_pot.cpp setup_pwrho.cpp setup_pwwfc.cpp + update_cell_pw.cpp + dftu_pw.cpp + deltaspin_pw.cpp forces_nl.cpp forces_cc.cpp forces_scc.cpp diff --git a/source/source_pw/module_pwdft/deltaspin_pw.cpp b/source/source_pw/module_pwdft/deltaspin_pw.cpp new file mode 100644 index 0000000000..caf8ea7852 --- /dev/null +++ b/source/source_pw/module_pwdft/deltaspin_pw.cpp @@ -0,0 +1,72 @@ +#include "source_pw/module_pwdft/deltaspin_pw.h" +#include "source_lcao/module_deltaspin/spin_constrain.h" +#include "source_estate/module_charge/charge_mixing.h" + +namespace pw +{ + +bool run_deltaspin_lambda_loop(const int iter, + const double drho, + const Input_para& inp) +{ + /// Return false if DeltaSpin is not enabled + if (!inp.sc_mag_switch) + { + return false; + } + + /// Get the singleton instance of SpinConstrain + spinconstrain::SpinConstrain>& sc + = spinconstrain::SpinConstrain>::getScInstance(); + + /// Case 1: Magnetic moments not yet converged and SCF is close to convergence. + /// This is the first time we enter the lambda loop after SCF is nearly converged. + if (!sc.mag_converged() && drho > 0 && drho < inp.sc_scf_thr) + { + /// Optimize lambda to get target magnetic moments + sc.run_lambda_loop(iter); + sc.set_mag_converged(true); + return true; + } + /// Case 2: Magnetic moments already converged in previous iteration. + /// Continue to refine lambda in subsequent SCF iterations. + else if (sc.mag_converged()) + { + sc.run_lambda_loop(iter); + return true; + } + + /// Default: run the normal solver + return false; +} + +void check_deltaspin_oscillation(const int iter, + const double drho, + Charge_Mixing* p_chgmix, + const Input_para& inp) +{ + /// Return if DeltaSpin is not enabled + if (!inp.sc_mag_switch) + { + return; + } + + /// Get the singleton instance of SpinConstrain + spinconstrain::SpinConstrain>& sc + = spinconstrain::SpinConstrain>::getScInstance(); + + /// Check if higher magnetization precision is needed + if (!sc.higher_mag_prec) + { + /// Detect SCF oscillation + sc.higher_mag_prec = p_chgmix->if_scf_oscillate(iter, drho, inp.sc_os_ndim, inp.scf_os_thr); + + /// If oscillation detected, set mixing restart step for next iteration + if (sc.higher_mag_prec) + { + p_chgmix->mixing_restart_step = iter + 1; + } + } +} + +} diff --git a/source/source_pw/module_pwdft/deltaspin_pw.h b/source/source_pw/module_pwdft/deltaspin_pw.h new file mode 100644 index 0000000000..0509b61bb8 --- /dev/null +++ b/source/source_pw/module_pwdft/deltaspin_pw.h @@ -0,0 +1,46 @@ +#ifndef DELTASPIN_PW_H +#define DELTASPIN_PW_H + +#include "source_io/module_parameter/parameter.h" + +class Charge_Mixing; + +namespace pw +{ + +/** + * @brief Run the inner lambda loop for DeltaSpin method to constrain atomic magnetic moments. + * + * This function is used in the PW basis SCF iteration to optimize lambda parameters + * for constraining atomic magnetic moments to target values using the DeltaSpin method. + * + * @param iter The current iteration number (0-indexed). + * @param drho The current charge density difference. + * @param inp The input parameters. + * @return true if the solver should be skipped (lambda loop was executed), + * false otherwise. + */ +bool run_deltaspin_lambda_loop(const int iter, + const double drho, + const Input_para& inp); + +/** + * @brief Check if SCF oscillation occurs for DeltaSpin method. + * + * This function checks if the SCF iteration is oscillating and sets the + * mixing restart step if oscillation is detected. This is used to increase + * the precision of magnetization calculation. + * + * @param iter The current iteration number (1-indexed). + * @param drho The current charge density difference. + * @param p_chgmix Pointer to the Charge_Mixing object. + * @param inp The input parameters. + */ +void check_deltaspin_oscillation(const int iter, + const double drho, + Charge_Mixing* p_chgmix, + const Input_para& inp); + +} + +#endif diff --git a/source/source_pw/module_pwdft/dftu_pw.cpp b/source/source_pw/module_pwdft/dftu_pw.cpp new file mode 100644 index 0000000000..475a34620a --- /dev/null +++ b/source/source_pw/module_pwdft/dftu_pw.cpp @@ -0,0 +1,32 @@ +#include "source_pw/module_pwdft/dftu_pw.h" +#include "source_lcao/module_dftu/dftu.h" + +namespace pw +{ + +void iter_init_dftu_pw(const int iter, + const int istep, + Plus_U& dftu, + const void* psi, + const ModuleBase::matrix& wg, + const UnitCell& ucell, + const Input_para& inp) +{ + if (!inp.dft_plus_u) + { + return; + } + + if (iter == 1 && istep == 0) + { + return; + } + + if (dftu.omc != 2) + { + dftu.cal_occ_pw(iter, psi, wg, ucell, inp.mixing_beta); + } + dftu.output(ucell); +} + +} diff --git a/source/source_pw/module_pwdft/dftu_pw.h b/source/source_pw/module_pwdft/dftu_pw.h new file mode 100644 index 0000000000..8a30b04e76 --- /dev/null +++ b/source/source_pw/module_pwdft/dftu_pw.h @@ -0,0 +1,23 @@ +#ifndef DFTU_PW_H +#define DFTU_PW_H + +#include "source_io/module_parameter/parameter.h" +#include "source_cell/unitcell.h" +#include "source_base/matrix.h" + +class Plus_U; + +namespace pw +{ + +void iter_init_dftu_pw(const int iter, + const int istep, + Plus_U& dftu, + const void* psi, + const ModuleBase::matrix& wg, + const UnitCell& ucell, + const Input_para& inp); + +} + +#endif diff --git a/source/source_pw/module_pwdft/exx_helper.cpp b/source/source_pw/module_pwdft/exx_helper.cpp index 89c32d1584..2af13d3173 100644 --- a/source/source_pw/module_pwdft/exx_helper.cpp +++ b/source/source_pw/module_pwdft/exx_helper.cpp @@ -1,6 +1,121 @@ #include "exx_helper.h" #include "source_io/module_parameter/parameter.h" // use PARAM #include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info +#include "source_hamilt/module_xc/xc_functional.h" // use XC_Functional +#include "source_pw/module_pwdft/hamilt_pw.h" // use HamiltPW +#include "source_estate/update_pot.h" // use elecstate::update_pot +#include "source_estate/elecstate_pw.h" // use ElecStatePW +#include "source_estate/module_charge/charge.h" // use Charge +#include // for timing + +template +void Exx_Helper::init(const UnitCell& ucell, const Input_para& inp, const ModuleBase::matrix& wg) +{ + if (inp.calculation != "scf" && inp.calculation != "relax" + && inp.calculation != "cell-relax" && inp.calculation != "md") + { + return; + } + + if (!GlobalC::exx_info.info_global.cal_exx) + { + return; + } + + if (GlobalC::exx_info.info_global.separate_loop) + { + XC_Functional::set_xc_first_loop(ucell); + this->set_firstiter(); + } + + this->set_wg(&wg); +} + +template +void Exx_Helper::before_scf(void* p_hamilt, psi::Psi* psi, const Input_para& inp) +{ + /// Return if not a valid calculation type + if (inp.calculation != "scf" && inp.calculation != "relax" + && inp.calculation != "cell-relax" && inp.calculation != "md") + { + return; + } + + /// Return if EXX is not enabled or not PW basis + if (!GlobalC::exx_info.info_global.cal_exx || inp.basis_type != "pw") + { + return; + } + + /// Set EXX helper to Hamiltonian + auto hamilt_pw = reinterpret_cast*>(p_hamilt); + hamilt_pw->set_exx_helper(*this); + + /// Set psi for EXX calculation + this->set_psi(psi); +} + +template +bool Exx_Helper::iter_finish(void* p_elec, Charge* p_charge, psi::Psi* psi, + UnitCell& ucell, const Input_para& inp, + bool& conv_esolver, int& iter) +{ + /// Return if EXX is not enabled + if (!GlobalC::exx_info.info_global.cal_exx) + { + return false; + } + + /// Handle separate_loop mode + if (GlobalC::exx_info.info_global.separate_loop) + { + if (conv_esolver) + { + auto start = std::chrono::high_resolution_clock::now(); + + this->set_firstiter(false); + this->op_exx->first_iter = false; + + double dexx = 0.0; + if (inp.exx_thr_type == "energy") + { + dexx = this->cal_exx_energy(psi); + this->set_psi(psi); + dexx -= this->cal_exx_energy(psi); + } + bool conv_ene = std::abs(dexx) < inp.exx_ene_thr; + + conv_esolver = this->exx_after_converge(iter, conv_ene); + + if (!conv_esolver) + { + if (inp.exx_thr_type != "energy") + { + this->set_psi(psi); + } + + auto duration = std::chrono::high_resolution_clock::now() - start; + std::cout << " Setting Psi for EXX PW Inner Loop took " + << std::chrono::duration_cast(duration).count() / 1000.0 << "s" + << std::endl; + + this->op_exx->first_iter = false; + XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func); + + elecstate::ElecState* pelec = reinterpret_cast*>(p_elec); + elecstate::update_pot(ucell, pelec, *p_charge, conv_esolver); + + this->iter_inc(); + } + } + } + else + { + this->set_psi(psi); + } + + return true; +} template double Exx_Helper::cal_exx_energy(psi::Psi *psi_) diff --git a/source/source_pw/module_pwdft/exx_helper.h b/source/source_pw/module_pwdft/exx_helper.h index 283b035760..53056fdbbe 100644 --- a/source/source_pw/module_pwdft/exx_helper.h +++ b/source/source_pw/module_pwdft/exx_helper.h @@ -1,9 +1,13 @@ #include "source_psi/psi.h" #include "source_base/matrix.h" #include "source_pw/module_pwdft/op_pw_exx.h" +#include "source_io/module_parameter/input_parameter.h" #ifndef EXX_HELPER_H #define EXX_HELPER_H + +class Charge; + template struct Exx_Helper { @@ -14,6 +18,40 @@ struct Exx_Helper Exx_Helper() = default; OperatorEXX *op_exx = nullptr; + void init(const UnitCell& ucell, const Input_para& inp, const ModuleBase::matrix& wg); + + /** + * @brief Setup EXX helper before SCF iteration. + * + * This function sets up the EXX helper for the Hamiltonian and psi + * before each SCF iteration. It checks if the calculation type and + * EXX settings are appropriate. + * + * @param p_hamilt Pointer to the Hamiltonian object (void* to avoid circular dependency). + * @param psi Pointer to the wave function object. + * @param inp The input parameters. + */ + void before_scf(void* p_hamilt, psi::Psi* psi, const Input_para& inp); + + /** + * @brief Handle EXX-related operations after SCF iteration. + * + * This function handles EXX convergence checking and potential update + * after each SCF iteration. It is called in iter_finish. + * + * @param p_elec Pointer to the ElecState object (void* to avoid circular dependency). + * @param p_charge Pointer to the Charge object. + * @param psi Pointer to the wave function object. + * @param ucell The unit cell (non-const reference for update_pot). + * @param inp The input parameters. + * @param conv_esolver Whether SCF is converged (may be modified). + * @param iter The current iteration number (may be modified). + * @return true if EXX processing was done, false otherwise. + */ + bool iter_finish(void* p_elec, Charge* p_charge, psi::Psi* psi, + UnitCell& ucell, const Input_para& inp, + bool& conv_esolver, int& iter); + void set_firstiter(bool flag = true) { first_iter = flag; } void set_wg(const ModuleBase::matrix *wg_) { wg = wg_; } void set_psi(psi::Psi *psi_); diff --git a/source/source_pw/module_pwdft/update_cell_pw.cpp b/source/source_pw/module_pwdft/update_cell_pw.cpp new file mode 100644 index 0000000000..a027884cc2 --- /dev/null +++ b/source/source_pw/module_pwdft/update_cell_pw.cpp @@ -0,0 +1,29 @@ +#include "source_pw/module_pwdft/update_cell_pw.h" +#include "source_base/global_variable.h" +#include "source_base/global_function.h" + +namespace pw +{ + +void update_cell_pw(const UnitCell& ucell, + pseudopot_cell_vnl& ppcell, + const K_Vectors& kv, + ModulePW::PW_Basis_K* pw_wfc, + const Input_para& inp) +{ + ModuleBase::TITLE("pw", "update_cell_pw"); + + if (!ucell.cell_parameter_updated) + { + return; + } + + ppcell.rescale_vnl(ucell.omega); + ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL"); + + pw_wfc->initgrids(ucell.lat0, ucell.latvec, pw_wfc->nx, pw_wfc->ny, pw_wfc->nz); + pw_wfc->initparameters(false, inp.ecutwfc, kv.get_nks(), kv.kvec_d.data()); + pw_wfc->collect_local_pw(inp.erf_ecut, inp.erf_height, inp.erf_sigma); +} + +} diff --git a/source/source_pw/module_pwdft/update_cell_pw.h b/source/source_pw/module_pwdft/update_cell_pw.h new file mode 100644 index 0000000000..704d5b6f80 --- /dev/null +++ b/source/source_pw/module_pwdft/update_cell_pw.h @@ -0,0 +1,20 @@ +#ifndef UPDATE_CELL_PW_H +#define UPDATE_CELL_PW_H + +#include "source_io/module_parameter/parameter.h" +#include "source_cell/unitcell.h" +#include "source_cell/klist.h" +#include "source_basis/module_pw/pw_basis_k.h" +#include "source_pw/module_pwdft/vnl_pw.h" + +namespace pw +{ + +void update_cell_pw(const UnitCell& ucell, + pseudopot_cell_vnl& ppcell, + const K_Vectors& kv, + ModulePW::PW_Basis_K* pw_wfc, + const Input_para& inp); + +} +#endif