Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,9 @@ OBJS_SRCPW=H_Ewald_pw.o\
setup_pot.o\
setup_pwrho.o\
setup_pwwfc.o\
update_cell_pw.o\
dftu_pw.o\
deltaspin_pw.o\
forces.o\
forces_us.o\
forces_nl.o\
Expand Down
134 changes: 13 additions & 121 deletions source/source_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
#include "source_io/module_ctrl/ctrl_output_pw.h" // mohan add 20250927
#include "source_estate/module_charge/chgmixing.h" // use charge mixing, mohan add 20251006
#include "source_estate/update_pot.h" // mohan add 20251016
#include "source_pw/module_pwdft/update_cell_pw.h" // mohan add 20250309
#include "source_pw/module_pwdft/dftu_pw.h" // mohan add 20250309
#include "source_pw/module_pwdft/deltaspin_pw.h" // mohan add 20250309

#include "source_hamilt/module_xc/exx_info.h" // use GlobalC::exx_info

Expand Down Expand Up @@ -93,20 +96,7 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS");

//! Initialize exx pw
if (inp.calculation == "scf" || inp.calculation == "relax" || inp.calculation == "cell-relax"
|| inp.calculation == "md")
{
if (GlobalC::exx_info.info_global.cal_exx && GlobalC::exx_info.info_global.separate_loop == true)
{
XC_Functional::set_xc_first_loop(ucell);
exx_helper.set_firstiter();
}

if (GlobalC::exx_info.info_global.cal_exx)
{
exx_helper.set_wg(&this->pelec->wg);
}
}
this->exx_helper.init(ucell, inp, this->pelec->wg);
}

template <typename T, typename Device>
Expand All @@ -119,17 +109,10 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
ESolver_KS<T, Device>::before_scf(ucell, istep);

//! Init variables (once the cell has changed)
pw::update_cell_pw(ucell, this->ppcell, this->kv, this->pw_wfc, PARAM.inp);

if (ucell.cell_parameter_updated)
{
this->ppcell.rescale_vnl(ucell.omega);
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL");

this->pw_wfc->initgrids(ucell.lat0, ucell.latvec, this->pw_wfc->nx, this->pw_wfc->ny, this->pw_wfc->nz);

this->pw_wfc->initparameters(false, PARAM.inp.ecutwfc, this->kv.get_nks(), this->kv.kvec_d.data());

this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma);

this->stp.p_psi_init->prepare_init(PARAM.inp.pw_seed);
}

Expand All @@ -151,17 +134,8 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
// setup psi (electronic wave functions)
this->stp.init(this->p_hamilt);

//! Exx calculations
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax"
|| PARAM.inp.calculation == "cell-relax" || PARAM.inp.calculation == "md")
{
if (GlobalC::exx_info.info_global.cal_exx && PARAM.inp.basis_type == "pw")
{
auto hamilt_pw = reinterpret_cast<hamilt::HamiltPW<T, Device>*>(this->p_hamilt);
hamilt_pw->set_exx_helper(exx_helper);
exx_helper.set_psi(this->stp.psi_t);
}
}
//! Setup EXX helper for Hamiltonian and psi
exx_helper.before_scf(this->p_hamilt, this->stp.psi_t, PARAM.inp);

ModuleBase::timer::tick("ESolver_KS_PW", "before_scf");
}
Expand All @@ -181,16 +155,7 @@ void ESolver_KS_PW<T, Device>::iter_init(UnitCell& ucell, const int istep, const

// 4) update local occupations for DFT+U
// should before lambda loop in DeltaSpin
if (PARAM.inp.dft_plus_u && (iter != 1 || istep != 0))
{
// only old DFT+U method should calculate energy correction in esolver,
// new DFT+U method will calculate energy when evaluating the Hamiltonian
if (this->dftu.omc != 2)
{
this->dftu.cal_occ_pw(iter, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp.mixing_beta);
}
this->dftu.output(ucell);
}
pw::iter_init_dftu_pw(iter, istep, this->dftu, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp);
}

// Temporary, it should be replaced by hsolver later.
Expand Down Expand Up @@ -218,26 +183,7 @@ void ESolver_KS_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, const int iste
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;

// run the inner lambda loop to contrain atomic moments with the DeltaSpin method
bool skip_solve = false;

if (PARAM.inp.sc_mag_switch)
{
spinconstrain::SpinConstrain<std::complex<double>>& sc
= spinconstrain::SpinConstrain<std::complex<double>>::getScInstance();
if (!sc.mag_converged() && this->drho > 0 && this->drho < PARAM.inp.sc_scf_thr)
{
// optimize lambda to get target magnetic moments, but the lambda is not near target
sc.run_lambda_loop(iter - 1);
sc.set_mag_converged(true);
skip_solve = true;
}
else if (sc.mag_converged())
{
// optimize lambda to get target magnetic moments, but the lambda is not near target
sc.run_lambda_loop(iter - 1);
skip_solve = true;
}
}
bool skip_solve = pw::run_deltaspin_lambda_loop(iter - 1, this->drho, PARAM.inp);

if (!skip_solve)
{
Expand Down Expand Up @@ -293,65 +239,11 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
this->ppcell.cal_effective_D(veff, this->pw_rhod, ucell);
}

// Related to EXX
if (GlobalC::exx_info.info_global.cal_exx)
{
if (GlobalC::exx_info.info_global.separate_loop)
{
if (conv_esolver)
{
auto start = std::chrono::high_resolution_clock::now();
exx_helper.set_firstiter(false);
exx_helper.op_exx->first_iter = false;
double dexx = 0.0;
if (PARAM.inp.exx_thr_type == "energy")
{
dexx = exx_helper.cal_exx_energy(this->stp.psi_t);
exx_helper.set_psi(this->stp.psi_t);
dexx -= exx_helper.cal_exx_energy(this->stp.psi_t);
// std::cout << "dexx = " << dexx << std::endl;
}
bool conv_ene = std::abs(dexx) < PARAM.inp.exx_ene_thr;

conv_esolver = exx_helper.exx_after_converge(iter, conv_ene);
if (!conv_esolver)
{
if (PARAM.inp.exx_thr_type != "energy")
{
exx_helper.set_psi(this->stp.psi_t);
}
auto duration = std::chrono::high_resolution_clock::now() - start;
std::cout << " Setting Psi for EXX PW Inner Loop took "
<< std::chrono::duration_cast<std::chrono::milliseconds>(duration).count() / 1000.0 << "s"
<< std::endl;
exx_helper.op_exx->first_iter = false;
XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func);
elecstate::update_pot(ucell, this->pelec, this->chr, conv_esolver);
exx_helper.iter_inc();
}
}
}
else
{
exx_helper.set_psi(this->stp.psi_t);
}
}
// Handle EXX-related operations after SCF iteration
exx_helper.iter_finish(this->pelec, &this->chr, this->stp.psi_t, ucell, PARAM.inp, conv_esolver, iter);

// check if oscillate for delta_spin method
if (PARAM.inp.sc_mag_switch)
{
spinconstrain::SpinConstrain<std::complex<double>>& sc
= spinconstrain::SpinConstrain<std::complex<double>>::getScInstance();
if (!sc.higher_mag_prec)
{
sc.higher_mag_prec = this->p_chgmix->if_scf_oscillate(iter,
this->drho, PARAM.inp.sc_os_ndim, PARAM.inp.scf_os_thr);
if (sc.higher_mag_prec)
{ // if oscillate, increase the precision of magnetization and do mixing_restart in next iteration
this->p_chgmix->mixing_restart_step = iter + 1;
}
}
}
pw::check_deltaspin_oscillation(iter, this->drho, this->p_chgmix, PARAM.inp);

// the output quantities
ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->stp.psi_cpu,
Expand Down
3 changes: 3 additions & 0 deletions source/source_pw/module_pwdft/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ list(APPEND objects
setup_pot.cpp
setup_pwrho.cpp
setup_pwwfc.cpp
update_cell_pw.cpp
dftu_pw.cpp
deltaspin_pw.cpp
forces_nl.cpp
forces_cc.cpp
forces_scc.cpp
Expand Down
72 changes: 72 additions & 0 deletions source/source_pw/module_pwdft/deltaspin_pw.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#include "source_pw/module_pwdft/deltaspin_pw.h"
#include "source_lcao/module_deltaspin/spin_constrain.h"
#include "source_estate/module_charge/charge_mixing.h"

namespace pw
{

bool run_deltaspin_lambda_loop(const int iter,
const double drho,
const Input_para& inp)
{
/// Return false if DeltaSpin is not enabled
if (!inp.sc_mag_switch)
{
return false;
}

/// Get the singleton instance of SpinConstrain
spinconstrain::SpinConstrain<std::complex<double>>& sc
= spinconstrain::SpinConstrain<std::complex<double>>::getScInstance();

/// Case 1: Magnetic moments not yet converged and SCF is close to convergence.
/// This is the first time we enter the lambda loop after SCF is nearly converged.
if (!sc.mag_converged() && drho > 0 && drho < inp.sc_scf_thr)
{
/// Optimize lambda to get target magnetic moments
sc.run_lambda_loop(iter);
sc.set_mag_converged(true);
return true;
}
/// Case 2: Magnetic moments already converged in previous iteration.
/// Continue to refine lambda in subsequent SCF iterations.
else if (sc.mag_converged())
{
sc.run_lambda_loop(iter);
return true;
}

/// Default: run the normal solver
return false;
}

void check_deltaspin_oscillation(const int iter,
const double drho,
Charge_Mixing* p_chgmix,
const Input_para& inp)
{
/// Return if DeltaSpin is not enabled
if (!inp.sc_mag_switch)
{
return;
}

/// Get the singleton instance of SpinConstrain
spinconstrain::SpinConstrain<std::complex<double>>& sc
= spinconstrain::SpinConstrain<std::complex<double>>::getScInstance();

/// Check if higher magnetization precision is needed
if (!sc.higher_mag_prec)
{
/// Detect SCF oscillation
sc.higher_mag_prec = p_chgmix->if_scf_oscillate(iter, drho, inp.sc_os_ndim, inp.scf_os_thr);

/// If oscillation detected, set mixing restart step for next iteration
if (sc.higher_mag_prec)
{
p_chgmix->mixing_restart_step = iter + 1;
}
}
}

}
46 changes: 46 additions & 0 deletions source/source_pw/module_pwdft/deltaspin_pw.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#ifndef DELTASPIN_PW_H
#define DELTASPIN_PW_H

#include "source_io/module_parameter/parameter.h"

class Charge_Mixing;

namespace pw
{

/**
* @brief Run the inner lambda loop for DeltaSpin method to constrain atomic magnetic moments.
*
* This function is used in the PW basis SCF iteration to optimize lambda parameters
* for constraining atomic magnetic moments to target values using the DeltaSpin method.
*
* @param iter The current iteration number (0-indexed).
* @param drho The current charge density difference.
* @param inp The input parameters.
* @return true if the solver should be skipped (lambda loop was executed),
* false otherwise.
*/
bool run_deltaspin_lambda_loop(const int iter,
const double drho,
const Input_para& inp);

/**
* @brief Check if SCF oscillation occurs for DeltaSpin method.
*
* This function checks if the SCF iteration is oscillating and sets the
* mixing restart step if oscillation is detected. This is used to increase
* the precision of magnetization calculation.
*
* @param iter The current iteration number (1-indexed).
* @param drho The current charge density difference.
* @param p_chgmix Pointer to the Charge_Mixing object.
* @param inp The input parameters.
*/
void check_deltaspin_oscillation(const int iter,
const double drho,
Charge_Mixing* p_chgmix,
const Input_para& inp);

}

#endif
32 changes: 32 additions & 0 deletions source/source_pw/module_pwdft/dftu_pw.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#include "source_pw/module_pwdft/dftu_pw.h"
#include "source_lcao/module_dftu/dftu.h"

namespace pw
{

void iter_init_dftu_pw(const int iter,
const int istep,
Plus_U& dftu,
const void* psi,
const ModuleBase::matrix& wg,
const UnitCell& ucell,
const Input_para& inp)
{
if (!inp.dft_plus_u)
{
return;
}

if (iter == 1 && istep == 0)
{
return;
}

if (dftu.omc != 2)
{
dftu.cal_occ_pw(iter, psi, wg, ucell, inp.mixing_beta);
}
dftu.output(ucell);
}

}
23 changes: 23 additions & 0 deletions source/source_pw/module_pwdft/dftu_pw.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef DFTU_PW_H
#define DFTU_PW_H

#include "source_io/module_parameter/parameter.h"
#include "source_cell/unitcell.h"
#include "source_base/matrix.h"

class Plus_U;

namespace pw
{

void iter_init_dftu_pw(const int iter,
const int istep,
Plus_U& dftu,
const void* psi,
const ModuleBase::matrix& wg,
const UnitCell& ucell,
const Input_para& inp);

}

#endif
Loading
Loading