Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ if (USE_DSP)
target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_device.a)
target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_host.a)
endif()

# Add ${SCALAPACK_LIBRARY_DIR} to the link line of the main binary.
# NOTE(review): the variable name suggests a directory, while target_link_libraries
# takes library files, targets, or linker flags -- confirm what this variable expands to.
target_link_libraries(${ABACUS_BIN_NAME} ${SCALAPACK_LIBRARY_DIR})

if (USE_SW)
add_compile_definitions(__SW)
set(SW ON)
Expand All @@ -295,6 +298,7 @@ if (USE_SW)
target_link_libraries(${ABACUS_BIN_NAME} ${SW_MATH}/libswblas.a)
endif()


find_package(Threads REQUIRED)
target_link_libraries(${ABACUS_BIN_NAME} Threads::Threads)

Expand Down
6 changes: 4 additions & 2 deletions source/source_base/kernels/dsp/dsp_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ extern "C"
}
namespace mtfunc
{
// Scratch pointers through which zgemm_mth_ stores its alpha / beta scalars
// (*alp = *alpha; *bet = *beta) before calling mt_hthread_zgemm.
// They start out as nullptr and must be pointed at malloc_ht storage before
// zgemm_mth_ runs; the ESolver_KS_PW constructor does this under __DSP.
// NOTE(review): these are shared by every zgemm_mth_ call, so concurrent calls would
// write the same storage, and no matching free_ht is visible here -- confirm both.
std::complex<double>* alp=nullptr;
std::complex<double>* bet=nullptr;
void dspInitHandle(int id)
{
mt_blas_init(id);
Expand Down Expand Up @@ -271,9 +273,9 @@ void zgemm_mth_(const char* transa,
const int* ldc,
int cluster_id)
{
std::complex<double>* alp = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
// std::complex<double>* alp = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
*alp = *alpha;
std::complex<double>* bet = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
// std::complex<double>* bet = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
*bet = *beta;
mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor,
convertBLASTranspose(transa),
Expand Down
2 changes: 2 additions & 0 deletions source/source_base/kernels/dsp/dsp_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ void* malloc_ht(size_t bytes, int cluster_id);
void free_ht(void* ptr);

// mtblas functions
// Declarations of the alpha / beta scratch pointers defined in dsp_connector.cpp,
// where they are initialised to nullptr; callers must assign storage before use.
extern std::complex<double>* alp;
extern std::complex<double>* bet;

void sgemm_mt_(const char* transa,
const char* transb,
Expand Down
2 changes: 1 addition & 1 deletion source/source_basis/module_pw/module_fft/fft_dsp.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "fft_dsp.h"

#include "source_base/global_variable.h"

#include "source_base/tool_quit.h"
#include <iostream>
#include <string.h>
#include <vector>
Expand Down
2 changes: 2 additions & 0 deletions source/source_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
#ifdef __DSP
std::cout << " ** Initializing DSP Hardware..." << std::endl;
mtfunc::dspInitHandle(GlobalV::MY_RANK);
mtfunc::alp=(std::complex<double>*)mtfunc::malloc_ht(sizeof(std::complex<double>), GlobalV::MY_RANK);
mtfunc::bet=(std::complex<double>*)mtfunc::malloc_ht(sizeof(std::complex<double>), GlobalV::MY_RANK);
#endif
}

Expand Down
4 changes: 2 additions & 2 deletions source/source_io/cal_ldos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ void stm_mode_pw(const elecstate::ElecStatePW<std::complex<double>>* pelec,

for (int ib = 0; ib < nbands; ib++)
{
pelec->basis->recip2real(&psi(ib, 0), wfcr.data(), ik);
pelec->basis->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi(ib, 0), wfcr.data(), ik);

const double eigenval = (pelec->ekb(ik, ib) - efermi) * ModuleBase::Ry_to_eV;
double weight = en > 0 ? pelec->klist->wk[ik] - pelec->wg(ik, ib) : pelec->wg(ik, ib);
Expand Down Expand Up @@ -210,7 +210,7 @@ void ldos_mode_pw(const elecstate::ElecStatePW<std::complex<double>>* pelec,

for (int ib = 0; ib < nbands; ib++)
{
pelec->basis->recip2real(&psi(ib, 0), wfcr.data(), ik);
pelec->basis->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi(ib, 0), wfcr.data(), ik);
const double weight = pelec->klist->wk[ik] / ucell.omega;

for (int ir = 0; ir < pelec->basis->nrxx; ir++)
Expand Down
2 changes: 1 addition & 1 deletion source/source_io/cal_mlkedf_descriptors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ void Cal_MLKEDF_Descriptors::getF_KS(
wfcr[ig] = psi->operator()(ibnd, ig) * std::complex<double>(0.0, fact);
}

pw_psi->recip2real(wfcr, wfcr, ik);
pw_psi->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(wfcr, wfcr, ik);

for (int ir = 0; ir < this->nx; ++ir)
{
Expand Down
6 changes: 3 additions & 3 deletions source/source_io/get_wf_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ void Get_wf_lcao::begin(const UnitCell& ucell,
// Calculate real-space wave functions
psi_g.fix_k(is);
std::vector<std::complex<double>> wfc_r(pw_wfc->nrxx);
pw_wfc->recip2real(&psi_g(ib, 0), wfc_r.data(), is);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_g(ib, 0), wfc_r.data(), is);

// Extract real and imaginary parts
std::vector<double> wfc_real(pw_wfc->nrxx);
Expand Down Expand Up @@ -399,7 +399,7 @@ void Get_wf_lcao::begin(const UnitCell& ucell,

// Calculate real-space wave functions
std::vector<std::complex<double>> wfc_r(pw_wfc->nrxx);
pw_wfc->recip2real(&psi_g(ib, 0), wfc_r.data(), ik);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_g(ib, 0), wfc_r.data(), ik);

// Extract real and imaginary parts
std::vector<double> wfc_real(pw_wfc->nrxx);
Expand Down Expand Up @@ -551,7 +551,7 @@ void Get_wf_lcao::set_pw_wfc(const ModulePW::PW_Basis_K* pw_wfc,
}

// call FFT
pw_wfc->real2recip(Porter.data(), &wfc_g(ib, 0), ik);
pw_wfc->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(Porter.data(), &wfc_g(ib, 0), ik);
}

#ifdef __MPI
Expand Down
6 changes: 3 additions & 3 deletions source/source_io/read_wf2rho_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ void ModuleIO::read_wf2rho_pw(
{
const std::complex<double>* wfc_ib = wfc_tmp.c + ib * ng_npol;
const std::complex<double>* wfc_ib2 = wfc_tmp.c + ib * ng_npol + ng_npol / 2;
pw_wfc->recip2real(wfc_ib, rho_tmp.data(), ik);
pw_wfc->recip2real(wfc_ib2, rho_tmp2.data(), ik);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(wfc_ib, rho_tmp.data(), ik);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(wfc_ib2, rho_tmp2.data(), ik);
const double w1 = wg_tmp(ikstot, ib) / pw_wfc->omega;

if (w1 != 0.0)
Expand All @@ -152,7 +152,7 @@ void ModuleIO::read_wf2rho_pw(
for (int ib = 0; ib < nbands; ++ib)
{
const std::complex<double>* wfc_ib = wfc_tmp.c + ib * ng_npol;
pw_wfc->recip2real(wfc_ib, rho_tmp.data(), ik);
pw_wfc->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(wfc_ib, rho_tmp.data(), ik);

const double w1 = wg_tmp(ikstot, ib) / pw_wfc->omega;

Expand Down
19 changes: 9 additions & 10 deletions source/source_io/to_wannier90_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ void toWannier90_PW::out_unk(
{
int ib = cal_band_index[ib_w];

wfcpw->recip2real(&psi_pw(ik, ib, 0), porter, ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(ik, ib, 0), porter, ik);

if (GlobalV::RANK_IN_POOL == 0)
{
Expand Down Expand Up @@ -383,7 +383,7 @@ void toWannier90_PW::unkdotkb(
}
}

wfcpw->recip2real(phase, phase, cal_ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(phase, phase, cal_ik);

if (PARAM.inp.nspin == 4)
{
Expand All @@ -396,17 +396,17 @@ void toWannier90_PW::unkdotkb(
// (2) fft and get value
// int npw_ik = wfcpw->npwk[cal_ik];
int npwx = wfcpw->npwk_max;
wfcpw->recip2real(&psi_pw(cal_ik, im, 0), psir_up, cal_ik);
// wfcpw->recip2real(&psi_pw(cal_ik, im, npw_ik), psir_dn, cal_ik);
wfcpw->recip2real(&psi_pw(cal_ik, im, npwx), psir_dn, cal_ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, 0), psir_up, cal_ik);
// wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, npw_ik), psir_dn, cal_ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, npwx), psir_dn, cal_ik);
for (int ir = 0; ir < wfcpw->nrxx; ir++)
{
psir_up[ir] *= phase[ir];
psir_dn[ir] *= phase[ir];
}

wfcpw->real2recip(psir_up, psir_up, cal_ikb);
wfcpw->real2recip(psir_dn, psir_dn, cal_ikb);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psir_up, psir_up, cal_ikb);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psir_dn, psir_dn, cal_ikb);

for (int n = 0; n < num_bands; n++)
{
Expand Down Expand Up @@ -447,13 +447,12 @@ void toWannier90_PW::unkdotkb(
ModuleBase::GlobalFunc::ZEROS(psir, wfcpw->nmaxgr);

// (2) fft and get value
wfcpw->recip2real(&psi_pw(cal_ik, im, 0), psir, cal_ik);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&psi_pw(cal_ik, im, 0), psir, cal_ik);
for (int ir = 0; ir < wfcpw->nrxx; ir++)
{
psir[ir] *= phase[ir];
}

wfcpw->real2recip(psir, psir, cal_ikb);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psir, psir, cal_ikb);

for (int n = 0; n < num_bands; n++)
{
Expand Down
10 changes: 5 additions & 5 deletions source/source_io/unk_overlap_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ std::complex<double> unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw,
}

// (3) calculate the overlap in ik_L and ik_R
wfcpw->real2recip(psi_r, psi_r, ik_R);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psi_r, psi_r, ik_R);

for (int ig = 0; ig < evc->get_ngk(ik_R); ig++)
{
Expand Down Expand Up @@ -197,8 +197,8 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho

// (2) fft and get value
rhopw->recip2real(phase, phase);
wfcpw->recip2real(&evc[0](ik_L, iband_L, 0), psi_up, ik_L);
wfcpw->recip2real(&evc[0](ik_L, iband_L, npwx), psi_down, ik_L);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&evc[0](ik_L, iband_L, 0), psi_up, ik_L);
wfcpw->recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(&evc[0](ik_L, iband_L, npwx), psi_down, ik_L);

for (int ir = 0; ir < wfcpw->nrxx; ir++)
{
Expand All @@ -207,8 +207,8 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho
}

// (3) calculate the overlap in ik_L and ik_R
wfcpw->real2recip(psi_up, psi_up, ik_L);
wfcpw->real2recip(psi_down, psi_down, ik_L);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psi_up, psi_up, ik_L);
wfcpw->real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(psi_down, psi_down, ik_L);

for (int i = 0; i < PARAM.globalv.npol; i++)
{
Expand Down
4 changes: 2 additions & 2 deletions source/source_pw/module_pwdft/stress_func_exx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ void Stress_PW<FPTYPE, Device>::stress_exx(ModuleBase::matrix& sigma,
// psi_nk in real space
d_psi_in->fix_kb(ik, nband);
T* psi_nk = d_psi_in->get_pointer();
wfcpw->recip2real(psi_nk, psi_nk_real, ik);
wfcpw->recip_to_real<std::complex<FPTYPE>,Device>(psi_nk, psi_nk_real, ik);

for (int iq = 0; iq < nqs; iq++)
{
Expand All @@ -269,7 +269,7 @@ void Stress_PW<FPTYPE, Device>::stress_exx(ModuleBase::matrix& sigma,
// psi_mq in real space
d_psi_in->fix_kb(iq, mband);
T* psi_mq = d_psi_in->get_pointer();
wfcpw->recip2real(psi_mq, psi_mq_real, iq);
wfcpw->recip_to_real<std::complex<FPTYPE>,Device>(psi_mq, psi_mq_real, iq);

// overlap density in real space
setmem_complex_op()(density_real, 0.0, rhopw->nrxx);
Expand Down
Loading