diff --git a/pyproject.toml b/pyproject.toml index d973e430..dd950868 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ authors = [ ] description = "Yambopy: a pre/post-processing tool for Yambo" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.8" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: GNU General Public License v2 or later (GPLv2+)", diff --git a/tutorial/databases_yambopy/exc_real_wf.py b/tutorial/databases_yambopy/exc_real_wf.py new file mode 100644 index 00000000..aab96d62 --- /dev/null +++ b/tutorial/databases_yambopy/exc_real_wf.py @@ -0,0 +1,40 @@ +import numpy as np +from yambopy.dbs.excitondb import YamboExcitonDB +from yambopy.dbs import excitondb +from yambopy.dbs.latticedb import YamboLatticeDB +from yambopy.dbs.wfdb import YamboWFDB +import os + +iqpt = 1 # qpt index of exciton +calc_path = '.' +BSE_dir = 'GW_BSE' + +# load lattice db +lattice = YamboLatticeDB.from_db_file(os.path.join(calc_path, 'SAVE','ns.db1')) + +# load exciton db +# note in case too many excitons, load only first few with `neigs' flag +# DO NOT forget to include all degenerate states when giving neigs flag ! +# +filename = 'ndb.BS_diago_Q%d' % (iqpt) +excdb = YamboExcitonDB.from_db_file(lattice, filename=filename, + folder=os.path.join(calc_path, BSE_dir), + neigs = -1) + +#Load the wavefunction database +wfdb = YamboWFDB(path='.', latdb=lattice, + bands_range=[np.min(excdb.table[:, 1]) - 1, + np.max(excdb.table[:, 2])]) + +## plot the exciton wavefunction with hole fixed at [0,0,0] +# in a [1,1,1] supercell with 80 Ry wf cutoff. (give -1 to use full cutoff) +# I want to set the degeneracy threshold to 0.01 eV +# For example I want to plot the 3rd exciton, so iexe = 2 (python indexing ) +# +excdb.real_wf_to_cube(iexe=2, wfdb=wfdb, fixed_postion=[0.0 , 0.0 , 0.0], + supercell=[1,1,1], degen_tol=0.01, wfcCutoffRy=-1, fix_particle='h') +# fixed_postion is in reduced units +# in case, you want to plot hole density by fixing electron, set fix_particle = 'e' + +## .cube will be dumped and use vesta to visualize it ! + diff --git a/yambopy/bse/exciton_matrix_elements.py b/yambopy/bse/exciton_matrix_elements.py new file mode 100644 index 00000000..65797a70 --- /dev/null +++ b/yambopy/bse/exciton_matrix_elements.py @@ -0,0 +1,113 @@ +### +### +# This file contains a genenal functions to compute +# < S | O | S'>, where O is an operator. +# for Ex: is O is dV_scf, then these ,matrix elements are +# ex-ph matrix elements. if O is S_z (spin operator), we get +# spin matrix elements of excitons +### +import numpy as np +from yambopy.kpoints import build_ktree, find_kpt +from yambopy.tools.function_profiler import func_profile + +@func_profile +def exciton_X_matelem(exe_kvec, O_qvec, Akq, Ak, Omn, kpts, contribution='b', diagonal_only=False, ktree=None): + """ + Compute the exciton matrix elements in the Tamm-Dancoff approximation. + + This function calculates the matrix elements , + discarding the third term (disconnected diagram). The calculation is performed in the + Tamm-Dancoff approximation. + + Parameters + ---------- + exe_kvec : array_like + Exciton k-vector in crystal coordinates (k). + O_qvec : array_like + Momentum transfer vector q in crystal coordinates (q). + Akq : array_like + Wavefunction coefficients for k+q (bra wfc) with shape (n_exe_states, 1, ns, nk, nc, nv). + Ak : array_like + Wavefunction coefficients for k (ket wfc) with shape (n_exe_states, 1, ns, nk, nc, nv). + Omn : array_like + Matrix elements of the operator O in the basis of electronic states with shape (nlambda, nk, nspin, m_bnd, n_bnd). + ie Omn = < k+q, m, s | O(q) | n, k, s>, where m_bnd and n_bnd are final and initial bands, respectively. + s is spin index + kpts : array_like + K-points used to construct the BSE with shape (nk, 3). + contribution : str, optional + Specifies the contribution to include in the calculation: + - 'e' : Only electronic contribution. + - 'h' : Only hole contribution. + - 'b' : Both electron and hole contributions (default). + diagonal_only : bool, optional + If True, only the diagonal terms are computed. Default is False. + + ktree : KDtree, optional + If None, will build internally, else use the user provided + Returns + ------- + ex_O_mat : ndarray + The computed exciton matrix elements with shape (nlambda, n_exe_states) if diagonal_only is True, + or (nlambda, n_exe_states (final), n_exe_states (initial)) if diagonal_only is False. + """ + # Number of arbitrary parameters (lambda) in the Omn matrix + nlambda = Omn.shape[0] + # + assert Akq.shape[1] == 1, "Works only with TDA." + # Shape of the wavefunction coefficients + n_exe_states, bse_calc, ns, nk, nc, nv = Akq.shape + # + # Ensure that the shapes of Akq and Ak match + assert Akq.shape == Ak.shape, "Wavefunction coefficient mismatch" + # + # Ensure that the contribution parameter is valid + assert contribution in ['b', 'e', 'h'], "Allowed values for contribution are 'b', 'e', 'h'" + # + # Build a k-point tree for efficient k-point searching + if ktree is None : ktree = build_ktree(kpts) + # + # Find the indices of k-Q-q and k-q in the k-point tree + idx_k_minus_Q_minus_q = find_kpt(ktree, kpts - O_qvec[None, :] - exe_kvec[None, :]) # k-Q-q + idx_k_minus_q = find_kpt(ktree, kpts - O_qvec[None, :]) # k-q + # + # Extract the occupied and unoccupied parts of the Omn matrix + Occ = Omn[:, idx_k_minus_q, :, nv:, nv:].transpose(0,2,1,3,4) # Occupied part + Ovv = Omn[:, idx_k_minus_Q_minus_q, :, :nv, :nv].transpose(0,2,1,3,4) # conduction part + # + # Ensure the arrays are C-contiguous to reduce cache misses + Ak_electron = np.ascontiguousarray(Ak[:,0][:,:,idx_k_minus_q, ...]) + Akq_conj = Akq[:,0].reshape(n_exe_states, -1).conj() + # + # Initialize the output matrix + if diagonal_only: + ex_O_mat = np.zeros((nlambda, n_exe_states), dtype=Ak.dtype) # (nlambda, final, initial) + else: + ex_O_mat = np.zeros((nlambda, n_exe_states, n_exe_states), dtype=Ak.dtype) # (nlambda, final, initial) + # + # Loop over the arbitrary parameters (lambda) + for il in range(nlambda): + # Compute the electron contribution + if contribution == 'e' or contribution == 'b': + tmp_wfc = Occ[il][None, ...] @ Ak_electron + # + # Compute the hole contribution and subtract from the electron contribution + if contribution == 'h' or contribution == 'b': + tmp_h = -Ak[:,0] @ Ovv[il][None, ...] + if contribution == 'b': + tmp_wfc += tmp_h + else: + tmp_wfc = tmp_h + # + # Reshape the temporary wavefunction coefficients + tmp_wfc = tmp_wfc.reshape(n_exe_states, -1) + # + # Compute the final matrix elements + if diagonal_only: + ex_O_mat[il] = np.sum(Akq_conj * tmp_wfc, axis=-1) + else: + np.matmul(Akq_conj, tmp_wfc.T, out=ex_O_mat[il]) + # + # Return the computed exciton matrix elements + return ex_O_mat + diff --git a/yambopy/bse/exciton_spin.py b/yambopy/bse/exciton_spin.py new file mode 100644 index 00000000..8509cc73 --- /dev/null +++ b/yambopy/bse/exciton_spin.py @@ -0,0 +1,179 @@ +import os +import numpy as np +from yambopy.dbs.excitondb import YamboExcitonDB +from yambopy.dbs.latticedb import YamboLatticeDB +from yambopy.dbs.wfdb import YamboWFDB +from .exciton_matrix_elements import exciton_X_matelem +from yambopy.tools.degeneracy_finder import find_degeneracy_evs + + +def compute_exciton_spin(lattice, excdb, wfdb, elec_sz, contribution='b',diagonal=False): + """ + Compute the spin matrix elements for excitons. + + This function calculates the spin matrix elements for excitons using the + wavefunctions and spin operators. The spin matrix is computed in the basis + of exciton states, and off-diagonal elements are included. Diagonalization + of the matrix in degenerate subspaces is required to obtain the spin values. + + Parameters + ---------- + lattice : latticedb + Lattice database + excdb : exciton db + Exciton Database + wfdb : wfc db + wavefunction Database + elec_sz : ndarray + Electron spin matrix elements (nk, nbnds. nbnds). + contribution : str, optional + Specifies which contribution to compute: + - 'b': Total spin (default). + - 'e': Electron spin only. + - 'h': Hole spin only. + diagonal : bool, optional + If True, only diagonal spin elements are computed. Default is False. + + Returns + ------- + exe_Sz : ndarray + Spin matrix elements for excitons with shape (nstates, nstates). + """ + # + # Ensure the calculation is valid only for spinor wavefunctions + assert wfdb.nspinor == 2, "Makes sense only for nspinor = 2" + # + # Sanity check + assert np.min(excdb.table[:, 1]) - 1 == wfdb.min_bnd, \ + "wfdb and exciton db are inconsistant (Bands)" + ## sanity check + assert np.max(excdb.table[:, 2]) == wfdb.min_bnd + wfdb.nbands, \ + "wfdb and exciton db are inconsistant (Bands)" + # + assert elec_sz.shape == (wfdb.nkBZ, wfdb.nbands, wfdb.nbands) + # get Akcv + Akcv = excdb.get_Akcv() + # + # Get the exciton q-point in Cartesian coordinates + excQpt = excdb.car_qpoint + # + # Convert the q-point to crystal coordinates + excQpt = lattice.lat @ excQpt + # + # Compute the exciton spin matrix elements + exe_Sz = exciton_X_matelem(excQpt, np.array([0, 0, 0]), Akcv, + Akcv, elec_sz[None,:,None,...], wfdb.kBZ, + diagonal_only=diagonal,contribution=contribution) + # + return exe_Sz[0] + + + + + +def compute_exc_spin_iqpt(path='.', bse_dir='SAVE', iqpt=1, + nstates=-1, contribution='b', degen_tol = 1e-2, + sz=0.5 * np.array([[1, 0], [0, -1]]), + return_dbs_and_spin=True): + """ + + + Description + ----------- + Compute expectation value of S_z operator for excitons. + + Parameters + ---------- + path : str, optional + Path to the directory containing calculation SAVE and BSE folder. + Default: '.' (current directory) + bse_dir : str, optional + Directory containing BSE calculation data. Default: 'SAVE' + iqpt : int or array-like, optional + Q-point index or list of Q-point indices to analyze. Default: 1 + (Fortran indexing) + nstates : int, optional + Number of excitonic states to consider. Use -1 for all states. Default: -1 + contribution : str, optional + Which contribution to compute: + - 'b': both electron and hole (default) + - 'e': electron only + - 'h': hole only + degen_tol : float, optional + Tolerance for detecting degenerate states. Default: 1e-2 + sz : ndarray, optional + S_z operator matrix representation. Default: 0.5 * np.array([[1, 0], [0, -1]]) + return_dbs_and_spin : bool, optional + If True, returns both spin values and database objects. Default: True + + Returns + ------- + exe_Sz : ndarray + Array containing S_z expectation values for excitonic states + dbs_objects : list, optional + If return_dbs_and_spin=True, returns [lattice, wfdb, excdb, elec_sz] database objects + Examples + -------- + Compute the total spin matrix elements for excitons: + + >>> import numpy as np + >>> from yambopy.bse.exciton_spin import compute_exc_spin_iqpt + >>> Sz_exe = compute_exc_spin_iqpt(bse_dir='GW_BSE', nstates=2) + >>> print(Sz_exe) + + Compute only the electron spin contribution: + + >>> Sz_exe = compute_exc_spin_iqpt(bse_dir='GW_BSE', nstates=2, contribution='e') + + Compute only the hole spin contribution: + + >>> Sz_exe = compute_exc_spin_iqpt(bse_dir='GW_BSE', nstates=2, contribution='h') + """ + # + ## Check if it single Q or multiple Q's + if np.isscalar(iqpt): iqpt = [iqpt] + else : iqpt = list(iqpt) + # Load the lattice database + lattice = YamboLatticeDB.from_db_file(os.path.join(path, 'SAVE', 'ns.db1')) + ## load exbds + excdb = [] + for iq in iqpt: + filename = 'ndb.BS_diago_Q%d' % (iq) + excdb.append(YamboExcitonDB.from_db_file(lattice, filename=filename, + folder=os.path.join(path, bse_dir), + Load_WF=True, neigs=nstates)) + # Load the wavefunction database + wfdb = YamboWFDB(path=path, latdb=lattice, + bands_range=[np.min(excdb[0].table[:, 1]) - 1, + np.max(excdb[0].table[:, 2])]) + # + # Compute the spin matrix elements in the BZ + elec_sz = wfdb.get_spin_m_e_BZ(s_z=sz) + # + exe_Sz = [] + for ixdb in excdb: + smat = compute_exciton_spin(lattice, ixdb, + wfdb, elec_sz, + contribution=contribution, + diagonal=False) + smat = get_spinvals(smat, ixdb.eigenvalues, atol=degen_tol) + ss_tmp = [] + for i in smat: ss_tmp = ss_tmp + list(i) + exe_Sz.append(ss_tmp) + # + exe_Sz = np.array(exe_Sz) + if return_dbs_and_spin : return exe_Sz,[lattice, wfdb, excdb, elec_sz] + else : return exe_Sz + + + +def get_spinvals(spin_matrix, eigenvalues, atol=1e-3, rtol=1e-3): + degen_idx = find_degeneracy_evs(eigenvalues,atol=atol, rtol=rtol) + spins = [] + for id in degen_idx: + w = np.linalg.eigvals(spin_matrix[id,:][:,id]) + spins.append(w) + return spins + + + diff --git a/yambopy/bse/realSpace_excitonwf.py b/yambopy/bse/realSpace_excitonwf.py new file mode 100644 index 00000000..19821322 --- /dev/null +++ b/yambopy/bse/realSpace_excitonwf.py @@ -0,0 +1,443 @@ +### Compute real space exction wavefunction when hole/electron is fixed. +import numpy as np +from yambopy.kpoints import build_ktree, find_kpt +from yambopy.tools.function_profiler import func_profile +from tqdm import tqdm +import os + +## Usage +""" +import numpy as np +from yambopy.dbs.excitondb import YamboExcitonDB +from yambopy.dbs.latticedb import YamboLatticeDB +from yambopy.dbs.wfdb import YamboWFDB +import os + +iqpt = 1 # qpt index of exciton + +# load lattice db +lattice = YamboLatticeDB.from_db_file(os.path.join('.', 'SAVE', 'ns.db1')) + +# load exciton db +# note in case too many excitons, load only first few with `neigs' flag +# DO NOT forget to include all degenerate states when giving neigs flag ! +# +filename = 'ndb.BS_diago_Q%d' % (iqpt) +excdb = YamboExcitonDB.from_db_file(lattice, filename=filename, + folder=os.path.join('.', 'GW_BSE'), + neigs = 20) + +#Load the wavefunction database +wfdb = YamboWFDB(path='.', latdb=lattice, + bands_range=[np.min(excdb.table[:, 1]) - 1, + np.max(excdb.table[:, 2])]) + +## plot the exciton wavefunction with hole fixed at [0,0,0] (crystal coordinates) +# in a [1,1,1] supercell with 80 Ry wf cutoff +# I want to set the degeneracy threshold to 0.01 eV +# For example I want to plot the 3rd exciton, so iexe = 2 (python indexing ) +# +excdb.real_wf_to_cube(iexe=2, wfdb=wfdb, fixed_postion=[0,0,0], supercell=[1,1,1], + degen_tol=0.01, wfcCutoffRy=80, fix_particle='h') + +## fix_particle = 'e' if you want to fix electron and plot hole density +## .cube will be dumped and use vesta to visualize it ! +""" + +## + +@func_profile +def ex_wf2Real(Akcv, Qpt, wfcdb, bse_bnds, fixed_postion, + fix_particle='h', supercell=[1,1,1], wfcCutoffRy=-1, + block_size=256): + """ + Compute real-space exciton wavefunction when hole/electron is fixed. + + This is the main interface function that handles both resonant and anti-resonant parts + of the exciton wavefunction in the case of non-TDA calculations. + + Args: + Akcv (numpy.ndarray): Exciton wavefunction coefficients with shape: + - (Nstates,1,ns,k,c,v) for TDA + - (Nstates,2,ns,k,c,v) for non-TDA (2 for resonant/anti-resonant) + Qpt (numpy.ndarray): Q-point of exciton in crystal coordinates + wfcdb (YamboWFDB): Wavefunction database + bse_bnds (list): Band range used in BSE [min_band, max_band] (python indexing) + fixed_postion (list): Position of fixed particle in crystal coordinates + fix_particle (str): 'e' to fix electron, 'h' to fix hole (default) + supercell (list): Supercell dimensions [nx,ny,nz] + wfcCutoffRy (float): Wavefunction cutoff in Rydberg (-1 for no cutoff) + block_size (int): Block size for memory-efficient computation. + ## choosing lowe block_size will slight lower the memory requirment but also less faster + + Returns: + tuple: (supercell_latvecs, atom_nums, atom_pos, exe_wfc_real) + - supercell_latvecs: Supercell lattice vectors (3,3) + - atom_nums: Atomic numbers. + - atom_pos: Atomic positions in cartisian units + - exe_wfc_real: Real-space exciton wavefunction (nstates, nspin, nspinor_electron, nspinor_hole, + Nx_grid, Ny_grid, Nz_grid) + note if nspin =2, then nspinor = 1, similarly, if nspinor = 2, nspin = 1. + """ + ## first the resonat part + supercell_latvecs,atom_nums,atom_pos,exe_wfc_real = \ + ex_wf2Real_kernel(Akcv[:,0], Qpt, wfcdb, bse_bnds, fixed_postion, + fix_particle=fix_particle, supercell=supercell, + wfcCutoffRy=wfcCutoffRy, block_size=block_size, + ares=False, out_res=None) + # for nonTDA add the anti-resonant part + if Akcv.shape[1] == 2: + supercell_latvecs,atom_nums,atom_pos,exe_wfc_real = \ + ex_wf2Real_kernel(Akcv[:,1], Qpt, wfcdb, bse_bnds, fixed_postion, + fix_particle=fix_particle, supercell=supercell, + wfcCutoffRy=wfcCutoffRy, block_size=block_size, + ares=True, out_res=exe_wfc_real) + + return supercell_latvecs,atom_nums,atom_pos,exe_wfc_real + + + +@func_profile +def ex_wf2Real_kernel(Akcv, Qpt, wfcdb, bse_bnds, fixed_postion, + fix_particle='h', supercell=[1,1,1], wfcCutoffRy=-1, + block_size=256, ares=False, out_res=None): + """ + Core kernel function for computing real-space exciton wavefunction. + + Computes either: + - Akcv * psi_{kc}(r_e) * (psi_{k-Q,v}(r_h))^* (resonant part) + - Akcv * psi_{kv}(r_e) * (psi_{k-Q,c}(r_h))^* (anti-resonant part) + + Note: For density, one must compute the absolute value squared. + + Args: + Akcv (numpy.ndarray): Exciton coefficients [nstates,ns,nk,nc,nv] + Qpt (numpy.ndarray): Q-point in crystal coordinates [3] + wfcdb (YamboWFDB): Wavefunction database + bse_bnds (list): BSE band range used in bse [nb1, nb2]. fortran indexing. i.e index starts from 1. + i.e nb1 and nb2 are same indices used in yambo input i.e + % BSEBands nb1 | nb2 % in yambo input + fixed_postion (list): Fixed particle position in crystal coords [3] + fix_particle (str): 'e'=fix electron, 'h'=fix hole (default) + supercell (list): Supercell dimensions [nx,ny,nz] + wfcCutoffRy (float): Wavefunction cutoff in Rydberg (-1= full cutoff) + block_size (int): Memory block size for computation. is a postive integer, + the default is 256 which is very good but uses more memory. + # decrease it when you run into memory issues + ares (bool): If True, compute anti-resonant part + out_res (numpy.ndarray): Adds to this array + and is returned instead of internally creating. + Make sure it its consistant. (no internal checking done) + + Returns: + tuple: (supercell_latvecs, atom_nums, atom_pos, exe_wfc_real) + Natoms = (natom_in_unit_call * Nsupercell) + 1 (+1 due to hole/electron) + - supercell_latvecs (numpy.ndarray): Supercell lattice vectors [3,3] + - atom_nums (numpy.ndarray): Atomic numbers [Natoms] + - atom_pos (numpy.ndarray): Atomic positions in cartisian units [Natoms,3] + - exe_wfc_real (numpy.ndarray): Wavefunction in real space + [nstates, nspin, nspinor_electron, nspinor_hole, FFTx, FFTy, FFTz] + note if nspin =2, then nspinor = 1, similarly, if nspinor = 2, nspin = 1 + """ + # + # + if block_size < 1: + print('Warning: Wrong block_size. setting to 1') + block_size = 1 + # + # Convert them to + for i in range(3): + if supercell[i]%2 == 0: + print('Warning : Even supercell given, so increasing' + + ' supercell size along %d direction by 1'%(i+1)) + supercell[i] = supercell[i] +1 + # + # + fix_particle = fix_particle.lower() + bse_bnds = [min(bse_bnds)-1,max(bse_bnds)] + assert bse_bnds[0] >= wfcdb.min_bnd, \ + "%d is used in bse but not found in wfcdb, load more wfcs" %(bse_bnds[0]+1) + assert bse_bnds[1] <= wfcdb.min_bnd + wfcdb.nbands, \ + "%d is used in bse but not found in wfcdb, load more wfcs" %(wfcdb.min_bnd + wfcdb.nbands) + bse_bnds = np.array(bse_bnds,dtype=int)-wfcdb.min_bnd + + # ns is number of collinear spin components + nstates, ns, nk, nc, nv = Akcv.shape + + kpt_idx = wfcdb.ydb.kpoints_indexes + sym_idx = wfcdb.ydb.symmetry_indexes + nkBZ = len(sym_idx) + assert nc + nv == bse_bnds[1]-bse_bnds[0], "Band mismatch" + assert nk == nkBZ, "kpoint mismatch" + # + fixed_postion = np.array(fixed_postion) + # + hole_bnds = [bse_bnds[0],bse_bnds[0]+nv] + elec_bnds = [bse_bnds[0]+nv,bse_bnds[1]] + # + if ares: + # if anti-resonant part, we swap the electron and hole bands + # first transpose c,v dimensions + Akcv = Akcv.transpose(0,1,2,4,3) + tmp = hole_bnds + hole_bnds = elec_bnds + elec_bnds = tmp + nstates, ns, nk, nc, nv = Akcv.shape + + lat_vec = wfcdb.ydb.lat.T + blat = np.linalg.inv(lat_vec) + gvecs_iBZ_idx = [] + fft_box = np.zeros(3,dtype=int) + # + + for ik in range(len(wfcdb.gvecs)): + idx_gvecs_tmp = np.arange(wfcdb.ngvecs[ik],dtype=int) + if wfcCutoffRy > 0: + tmp_gvecs = 2*np.pi*np.linalg.norm((wfcdb.gvecs[ik, :wfcdb.ngvecs[ik], :] + + wfcdb.kpts_iBZ[ik][None,:])@blat,axis=-1) + idx_tmp = tmp_gvecs < np.sqrt(wfcCutoffRy) + idx_gvecs_tmp = idx_gvecs_tmp[idx_tmp].copy() + # + gvecs_iBZ_idx.append(idx_gvecs_tmp) + ## Get the fft box + min_fft_idx = np.min(wfcdb.gvecs[ik, :wfcdb.ngvecs[ik], :][idx_gvecs_tmp] , axis=0) + max_fft_idx = np.max(wfcdb.gvecs[ik, :wfcdb.ngvecs[ik], :][idx_gvecs_tmp] , axis=0) + assert np.min(max_fft_idx) >= 0 and np.max(min_fft_idx) < 0, "Invalid G-vectors" + for i in range(3): + fft_box[i] = max([fft_box[i], max_fft_idx[i] - min_fft_idx[i] + 3]) + + # Compute nstates, nk, Nx, Ny, Nz object + if out_res is None : print("Wfc FFT Grid : ",fft_box[0], fft_box[1], fft_box[2]) + # + ## + # find the nearest fft grid point. + fixed_postion = fixed_postion.astype(np.float64) + fx_pnt_int = np.floor(fixed_postion) + fixed_postion -= fx_pnt_int + fixed_postion = np.round(fixed_postion * fft_box) / fft_box + fixed_postion += fx_pnt_int + # shift the position of hole to middle of supercell + fixed_postion += np.array(supercell)//2 + # + if fix_particle == 'h': + print("Position of the hole (reduced units) is set to : ", + fixed_postion[0], fixed_postion[1], fixed_postion[2]) + if fix_particle == 'e': + print("Position of the electron (reduced units) is set to : ", + fixed_postion[0], fixed_postion[1], fixed_postion[2]) + # + ktree = wfcdb.ktree #build_ktree(wfcdb.kBZ) + # + nspinorr = wfcdb.nspinor + if out_res is not None: + exe_wfc_real = out_res.reshape(nstates, ns, nspinorr, nspinorr, + supercell[0],fft_box[0], + supercell[1],fft_box[1], + supercell[2],fft_box[2]) + else: + exe_wfc_real = np.zeros((nstates, ns, nspinorr, nspinorr, + supercell[0],fft_box[0], + supercell[1],fft_box[1], + supercell[2],fft_box[2]), + dtype=np.complex64) + # + Lx = np.arange(supercell[0],dtype=int) + Ly = np.arange(supercell[1],dtype=int) + Lz = np.arange(supercell[2],dtype=int) + Lsupercells = np.zeros((supercell[0],supercell[1],supercell[2],3),dtype=int) + Lsupercells[...,0], Lsupercells[...,1], Lsupercells[...,2] = np.meshgrid(Lx,Ly,Lz,indexing='ij') + # + FFTxx = np.fft.fftfreq(fft_box[0]) + FFTyy = np.fft.fftfreq(fft_box[1]) + FFTzz = np.fft.fftfreq(fft_box[2]) + FFTxx = FFTxx - np.floor(FFTxx) + FFTyy = FFTyy - np.floor(FFTyy) + FFTzz = FFTzz - np.floor(FFTzz) + FFFboxs = np.zeros((fft_box[0],fft_box[1],fft_box[2],3)) + FFFboxs[...,0], FFFboxs[...,1], FFFboxs[...,2] = np.meshgrid(FFTxx,FFTyy,FFTzz,indexing='ij') + # + # ns is nspin + exe_tmp_wf = np.zeros((nstates, ns, nspinorr, nspinorr, min(nk,block_size) , + fft_box[0], fft_box[1], fft_box[2]),dtype=np.complex64) + # + exp_tmp_kL = np.zeros((min(nk,block_size) ,supercell[0], + supercell[1],supercell[2]),dtype=np.complex64) + + nblks = nk//block_size + nrem = nk%block_size + if nrem > 0: nblks = nblks+1 + # + pbar = tqdm(total=nk, desc="Ex-wf") + # + for ibk in range(nblks): + ikstart = ibk*block_size + ikstop = min(ikstart + block_size,nk) + for ik in range(ikstart,ikstop): + ## First get the electronic wfcs + ik_ibz = kpt_idx[ik] + isym = sym_idx[ik] + wfc_tmp, gvecs_tmp = wfcdb.get_iBZ_wf(ik_ibz) + wfc_tmp = wfc_tmp[:,elec_bnds[0]:elec_bnds[1],:,gvecs_iBZ_idx[ik_ibz]] + gvecs_tmp = gvecs_tmp[gvecs_iBZ_idx[ik_ibz]] + kvec = wfcdb.get_iBZ_kpt(ik_ibz) + # get the rotated wf + if isym != 0: + sym_mat = wfcdb.ydb.sym_car[isym] + time_rev = (isym >= len(wfcdb.ydb.sym_car + ) / (1 + int(np.rint(wfcdb.ydb.time_rev)))) + kvec, wfc_tmp, gvecs_tmp = wfcdb.apply_symm( + kvec, wfc_tmp, gvecs_tmp, time_rev, sym_mat) + + kelec = kvec + wfc_elec = wfc_tmp + gvecs_elec = gvecs_tmp + + ## Do the same and get hole wfc + ikhole = find_kpt(ktree, kelec-Qpt) + ik_ibz = kpt_idx[ikhole] + isym = sym_idx[ikhole] + wfc_tmp, gvecs_tmp = wfcdb.get_iBZ_wf(ik_ibz) + wfc_tmp = wfc_tmp[:,hole_bnds[0]:hole_bnds[1],:,gvecs_iBZ_idx[ik_ibz]] + gvecs_tmp = gvecs_tmp[gvecs_iBZ_idx[ik_ibz]] + kvec = wfcdb.get_iBZ_kpt(ik_ibz) + # get the rotated wf + if isym != 0: + sym_mat = wfcdb.ydb.sym_car[isym] + time_rev = (isym >= len(wfcdb.ydb.sym_car) / (1 + int(np.rint(wfcdb.ydb.time_rev)))) + kvec, wfc_tmp, gvecs_tmp = wfcdb.apply_symm(kvec, wfc_tmp, gvecs_tmp, time_rev, sym_mat) + # + khole = -kvec + wfc_hole = wfc_tmp.conj() + gvecs_hole = -gvecs_tmp + + if fix_particle == 'h': + fx_kvec = khole + fx_wfc = wfc_hole + fx_gvec = gvecs_hole + # + ft_kvec = kelec + ft_wfc = wfc_elec + ft_gvec = gvecs_elec + else : + ft_kvec = khole + ft_wfc = wfc_hole + ft_gvec = gvecs_hole + # + fx_kvec = kelec + fx_wfc = wfc_elec + fx_gvec = gvecs_elec + # compute + ## NM : Donot perform FFT as we only need it for one point. + exp_fx = np.exp(2*np.pi*1j*((fx_gvec + fx_kvec[None,:])@fixed_postion)) + fx_wfc *= exp_fx[None,None,None,:] + fx_wfc = np.sum(fx_wfc,axis=-1) #(spin,bnd,spinor) + ns1, nbndc, nspinorr, ng = ft_wfc.shape + #if ft_ikpt not in prev_ikpts: + ft_wfcr = wfcdb.to_real_space(ft_wfc.reshape(-1,nspinorr,ng),ft_gvec, grid=fft_box) + ft_wfcr = ft_wfcr.reshape(ns1,nbndc,nspinorr,fft_box[0],fft_box[1],fft_box[2]) + exp_kx_r = np.exp(2*np.pi*1j*FFFboxs.reshape(-1,3)@ft_kvec).reshape(FFFboxs.shape[:3]) + ft_wfcr *= exp_kx_r[None,None,None,...] + # + fx_wfc = fx_wfc.astype(np.complex64) + ft_wfcr = ft_wfcr.astype(np.complex64) + # + if fix_particle == 'h': + np.einsum('nscv,svy,scxijk->nsxyijk',Akcv[:,:,ik,...].astype(np.complex64),fx_wfc,ft_wfcr, + optimize=True,out=exe_tmp_wf[:,:,:,:,ik-ikstart]) + else : + np.einsum('nscv,scx,svyijk->nsxyijk',Akcv[:,:,ik,...].astype(np.complex64),fx_wfc,ft_wfcr, + optimize=True,out=exe_tmp_wf[:,:,:,:,ik-ikstart]) + # + #exe_tmp_wf[:,:,:,ik-ikstart] *= exp_kx_r[...].reshape(FFFboxs.shape[:3])[None,None,None] + exp_tmp_kL[ik-ikstart] = np.exp(1j*2*np.pi*np.einsum('...x,x->...',Lsupercells,ft_kvec)) + # + # update progess bar + pbar.update(1) + # + ## perform gemm operation + total_gemms_t = nstates*ns*nspinorr**2 + exp_tmp_kL_tmp = exp_tmp_kL.reshape(len(exp_tmp_kL),-1)[:(ikstop-ikstart)].T + exe_tmp_wf_tmp = exe_tmp_wf.reshape(nstates,ns,nspinorr,nspinorr,-1,np.prod(fft_box)) + exe_tmp_wf_tmp = exe_tmp_wf_tmp[...,:(ikstop-ikstart),:] + # + for igemms in range(total_gemms_t): + ii, iis, jj, kk = np.unravel_index(igemms, (nstates,ns,nspinorr,nspinorr)) + Ctmp = (exp_tmp_kL_tmp @ exe_tmp_wf_tmp[ii, iis, jj, kk ]) + Ctmp = Ctmp.reshape(supercell[0], supercell[1], supercell[2], + fft_box[0], fft_box[1], fft_box[2]) + Ctmp *= (1.0/np.prod(supercell)) + exe_wfc_real[ii, iis, jj, kk ] += Ctmp.transpose(0,3,1,4,2,5) + # + exe_wfc_real = exe_wfc_real.reshape(nstates,ns,nspinorr,nspinorr, + supercell[0]*fft_box[0], + supercell[1]*fft_box[1], + supercell[2]*fft_box[2]) + # + # compute postioon of fixed particle in cart units + fixed_postion_cc = lat_vec@fixed_postion + Lsupercells = Lsupercells.reshape(-1,3)#/np.array(supercell)[None,:] + Lsupercells = Lsupercells@lat_vec.T + atom_pos = Lsupercells[:,None,:] + wfcdb.ydb.car_atomic_positions[None,:,:] + atom_pos = np.append(atom_pos.reshape(-1,3),fixed_postion_cc[None,:],axis=0) + supercell_latvecs = lat_vec*np.array(supercell)[None,:] + ## Make atomic numbers + atom_nums = np.zeros(len(atom_pos),dtype=wfcdb.ydb.atomic_numbers.dtype) + attmp = atom_nums[:-1].reshape(-1,len(wfcdb.ydb.atomic_numbers)) + attmp[...]= wfcdb.ydb.atomic_numbers[None,:] + atom_nums[-1] = 200 + # + return supercell_latvecs,atom_nums,atom_pos,exe_wfc_real + +#def compute_exc_wfc_real(path='.', bse_dir='SAVE', iqpt=1, nstates=[1], +# fixed_postion=[0,0,0], fix_particle='h', aveg=True, supercell=[1,1,1], +# wfcCutoffRy=-1, phase=False, block_size=256): +# # +# lattice = YamboLatticeDB.from_db_file(os.path.join(path, 'SAVE', 'ns.db1')) +# filename = 'ndb.BS_diago_Q%d' % (iqpt) +# excdb = YamboExcitonDB.from_db_file(lattice, filename=filename, +# folder=os.path.join(path, bse_dir), +# Load_WF=True, neigs=max(nstates)) +# # Load the wavefunction database +# wfdb = YamboWFDB(path=path, latdb=lattice, +# bands_range=[np.min(excdb.table[:, 1]) - 1, +# np.max(excdb.table[:, 2])]) +# # +# Akcv = excdb.get_Akcv()[min(nstates)-1:max(nstates)] +# excQpt = excdb.car_qpoint +# # +# # Convert the q-point to crystal coordinates +# Qpt = wfdb.ydb.lat @ excQpt + +# sc_latvecs, atom_nums, atom_pos, real_wfc = ex_wf2Real(Akcv, Qpt, wfdb, [np.min(excdb.table[:, 1]), +# np.max(excdb.table[:, 2])], fixed_postion=fixed_postion, +# fix_particle=fix_particle, supercell=supercell, +# wfcCutoffRy=wfcCutoffRy, block_size=block_size) +# # +# # +# nstates_range = np.arange(nstates[0],nstates[1],dtype=int) +# density = np.abs(real_wfc)**2 + +# if fix_particle == 'h': name_file = 'electron' +# else: name_file = 'hole' +# if real_wfc.shape[1] != 1: +# print("phase plot only works for nspin = 1 and nspinor == 1") +# phase = False +# if phase: +# phase = np.sign(real_wfc.real) #np.sign(np.angle(real_wfc)) +# density *= phase +# if aveg: +# real_wfc = np.sum(density,axis=(0,1,2)) +# real_wfc = real_wfc/np.max(np.abs(real_wfc)) +# write_cube('exe_wf_avg_%s_%d-%d.cube' %(name_file,nstates[0],nstates[1]), +# real_wfc, sc_latvecs, atom_pos, atom_nums, header='Real space exciton wavefunction') +# else: +# real_wfc = np.sum(density,axis=(1,2)) +# for i in range(len(real_wfc)): +# real_wfc1 = real_wfc[i]/np.max(np.abs(real_wfc[i])) +# write_cube('exe_wf_%s_%d.cube' %(name_file,nstates[i]), real_wfc1, sc_latvecs, +# atom_pos, atom_nums, header='Real space exciton wavefunction') + + + +##if __name__ == "__main__": + diff --git a/yambopy/bse/rotate_excitonwf.py b/yambopy/bse/rotate_excitonwf.py new file mode 100644 index 00000000..8cac01ef --- /dev/null +++ b/yambopy/bse/rotate_excitonwf.py @@ -0,0 +1,76 @@ +import numpy as np +from yambopy.kpoints import build_ktree, find_kpt +from yambopy.tools.function_profiler import func_profile + + +@func_profile +def rotate_exc_wf(Ak, symm_mat_red, kpoints, exe_qpt, dmats, time_rev, ktree=None): + """ + Rotate the exciton wavefunction Ak using symmetry operations. + + This function applies a symmetry operation to the exciton wavefunction Ak, + which is represented in the basis of electronic states. The rotation is + performed using the symmetry matrix in reduced coordinates and the + corresponding representation matrices. + + Parameters + ---------- + Ak : array_like + Exciton wavefunction coefficients with shape (n_exe_states, 1 or 2, nspin, nk, nc, nv). + 1 for TDA and 2 for coupling + symm_mat_red : array_like + Symmetry matrix in reduced coordinates with shape (3, 3). + kpoints : array_like + K-points in the full Brillouin zone (crystal coordinates) with shape (nk, 3). + exe_qpt : array_like + Momentum of the given exciton (q-point) in crystal coordinates with shape (3,). + dmats : array_like + Representation matrices for the symmetry operation with shape (nk, nspin, Rk_band, k_band). + time_rev : bool + If True, apply time-reversal symmetry to the wavefunction. + ktree : object, optional + Pre-built k-point tree for efficient k-point searching. If not provided, it will be built. + + Returns + ------- + rot_Ak : ndarray + Rotated exciton wavefunction coefficients with the same shape as Ak. + """ + # Initialize the rotated Ak array + rot_Ak = np.zeros(Ak.shape, dtype=Ak.dtype) + # Check TDA + tda = True + if Ak.shape[1] == 2 : tda = False + + ns, nk, nc, nv = Ak.shape[2:] + # Build a k-point tree if not provided + if ktree is None: ktree = build_ktree(kpoints) + + # Compute the indices of Rk and Rk - q + Rkpts = kpoints @ symm_mat_red.T # Rotated k-points + k_minus_q = kpoints - exe_qpt[None, :] # k - q + idx_Rk = find_kpt(ktree, Rkpts) # Indices of rotated k-points + idx_k_minus_q = find_kpt(ktree, k_minus_q) # Indices of k - q + + # Extract the conduction and valence parts of the representation matrices + Dcc = dmats[:, :, nv:, nv:].transpose(1,0,2,3) # Conduction band part + Dvv = dmats[idx_k_minus_q, :, :nv, :nv].transpose(1,0,2,3).conj() # Valence band part (conjugated) + + # Apply time-reversal symmetry if required + Ak_tmp = Ak + if time_rev: Ak_tmp = Ak.conj() + + # Rotate the Ak wavefunction using the representation matrices + ## rotate the resonant part + rot_Ak[:, :1, :, idx_Rk, ...] = ((Dcc[None, ...] @ Ak_tmp[:,0,...]) + @ (Dvv.transpose(0, 1, 3, 2)[None, ...]) + ).reshape(rot_Ak[:,:1].shape) + if not tda : + # Rotate the anti-resonant part + Dvv = dmats[:, :, :nv, :nv].transpose(1,0,2,3) + Dcc = dmats[idx_k_minus_q, :, nv:, nv:].transpose(1,0,2,3).conj() + rot_Ak[:, 1:, :, idx_Rk, ...] = ((Dcc[None, ...] @ Ak_tmp[:,1,...]) + @ (Dvv.transpose(0, 1, 3, 2)[None, ...]) + ).reshape(rot_Ak[:, 1:].shape) + return rot_Ak + diff --git a/yambopy/dbs/excitondb.py b/yambopy/dbs/excitondb.py index 6a21aa79..b9df8c52 100644 --- a/yambopy/dbs/excitondb.py +++ b/yambopy/dbs/excitondb.py @@ -24,6 +24,8 @@ from yambopy.dbs.latticedb import YamboLatticeDB from yambopy.dbs.electronsdb import YamboElectronsDB from yambopy.dbs.qpdb import YamboQPDB +from yambopy.io.cubetools import write_cube +from yambopy.bse.realSpace_excitonwf import ex_wf2Real class ExcitonList(): """ @@ -84,11 +86,13 @@ def __init__(self,lattice,Qpt,eigenvalues,l_residual,r_residual,spin_pol='no',ca self.spin_pol = spin_pol @classmethod - def from_db_file(cls,lattice,filename='ndb.BS_diago_Q1',folder='.',Load_WF=True): + def from_db_file(cls,lattice,filename='ndb.BS_diago_Q1',folder='.',Load_WF=True, neigs=-1): """ Initialize this class from a file Set `Read_WF=False` to avoid reading eigenvectors for faster IO and memory efficiency. + If neigs < 0 ; all eigen values (vectors) are loaded or else first neigs are loaded + " In case of non-TDA, we load right eigenvectors. """ path_filename = os.path.join(folder,filename) if not os.path.isfile(path_filename): @@ -98,6 +102,13 @@ def from_db_file(cls,lattice,filename='ndb.BS_diago_Q1',folder='.',Load_WF=True) Qpt = filename.split("Q",1)[1] with Dataset(path_filename) as database: + #energies + eig = database.variables['BS_Energies'][:]*ha2ev + eigenvalues = eig[:,0]+eig[:,1]*I + neig_full = len(eigenvalues) + if neigs < 0 or neigs > neig_full: neigs = neig_full + eigenvalues = eigenvalues[:neigs] + if 'BS_left_Residuals' in list(database.variables.keys()): # MN: using complex views instead of a+I*b copies to avoid memory duplication # Old (yet instructive) memory duplication code @@ -105,13 +116,13 @@ def from_db_file(cls,lattice,filename='ndb.BS_diago_Q1',folder='.',Load_WF=True) #rer,imr = database.variables['BS_right_Residuals'][:].T #l_residual = rel+iml*I #r_residual = rer+imr*I - l_residual = database.variables['BS_left_Residuals'][:] - r_residual = database.variables['BS_right_Residuals'][:] + l_residual = database['BS_left_Residuals'][:neigs,...].data + r_residual = database['BS_right_Residuals'][:neigs,...].data l_residual = l_residual.view(dtype=CmplxType(l_residual)).reshape(len(l_residual)) r_residual = r_residual.view(dtype=CmplxType(r_residual)).reshape(len(r_residual)) if 'BS_Residuals' in list(database.variables.keys()): # Compatibility with older Yambo versions - rel,iml,rer,imr = database.variables['BS_Residuals'][:].T + rel,iml,rer,imr = database['BS_Residuals'][:neigs,...].data.T l_residual = rel+iml*I r_residual = rer+imr*I @@ -121,19 +132,15 @@ def from_db_file(cls,lattice,filename='ndb.BS_diago_Q1',folder='.',Load_WF=True) car_qpoint = database.variables['Q-point'][:]/lattice.alat if Qpt=="1": car_qpoint = np.zeros(3) - #energies - eig = database.variables['BS_Energies'][:]*ha2ev - eigenvalues = eig[:,0]+eig[:,1]*I - #eigenvectors table = None eigenvectors = None if Load_WF and 'BS_EIGENSTATES' in database.variables: - eiv = database.variables['BS_EIGENSTATES'][:] + eiv = database['BS_EIGENSTATES'][:neigs,...].data #eiv = eiv[:,:,0] + eiv[:,:,1]*I #eigenvectors = eiv eigenvectors = eiv.view(dtype=CmplxType(eiv)).reshape(eiv.shape[:-1]) - table = database.variables['BS_TABLE'][:].T.astype(int) + table = np.rint(database.variables['BS_TABLE'][:].T).astype(int) spin_vars = [int(database.variables['SPIN_VARS'][:][0]), int(database.variables['SPIN_VARS'][:][1])] if spin_vars[0] == 2 and spin_vars[1] == 1: @@ -230,6 +237,118 @@ def write_sorted(self,prefix='yambo'): for i,n in sort_i: f.write("%3d %12.8lf %12.8e\n"%(n+1,eig[n],i)) + def get_Akcv(self): + """ + Convert eigenvectors from (neigs,BS_table) -> (neigs,nblks,nspin,k,c,v) + nblks = 2 for coupling, else 1 for TDA + """ + nspin = 1 + if self.spin_pol == 'pol': nspin = 2 + # + tmp_akcv = getattr(self, 'Akcv', None) + if tmp_akcv is not None: return tmp_akcv + # + if self.eigenvectors is None: return None + eig_wfcs = self.eigenvectors + # + nk = self.nkpoints + nv = self.nvbands + nc = self.ncbands + # Make sure nspin * nc * nv * nk = BS_TABLE length + table_len = nspin*nk*nv*nc + assert table_len == self.table.shape[0], "BS_TABLE length not equal to ns * nc * nv * nk" + # + v_min = np.min(self.table[:,1]) + c_min = np.min(self.table[:,2]) + bs_table0 = self.table[:,0]-1 + bs_table1 = self.table[:,1] - v_min + bs_table2 = self.table[:,2] - c_min + bs_table3 = self.table[:,3]-1 + # + eig_wfcs_returned = np.zeros(eig_wfcs.shape,dtype=eig_wfcs.dtype) + # + sort_idx = bs_table0*nc*nv + bs_table2*nv + bs_table1 + nk*nc*nv*bs_table3 + # + eig_wfcs_returned[:,sort_idx] = eig_wfcs[...,:table_len] + # check if this is coupling . + if eig_wfcs.shape[-1]//table_len == 2: + eig_wfcs_returned[:,sort_idx+table_len] = eig_wfcs[...,table_len:] + # NM : Note that here v and c are inverted i.e + # psi_S = Akcv * phi_v(r_e) * phi_c^*(r_h) + eig_wfcs_returned = eig_wfcs_returned.reshape(-1,2,nspin,nk,nc,nv) + else : + eig_wfcs_returned = eig_wfcs_returned.reshape(-1,1,nspin,nk,nc,nv) + # + self.Akcv = eig_wfcs_returned + return self.Akcv + + def real_wf_to_cube(self, iexe, wfdb, fixed_postion=[0,0,0], supercell=[1,1,1], degen_tol=1e-2, + wfcCutoffRy=-1, fix_particle='h', phase=False, block_size=256): + """ + Function to compute and save real-space exciton wavefunctions and + dump to cube file + + Args: + iexe: index of excitonic states (python indexing. so starts with 0) + wfcb: wavefunction database. + fixed_postion (list): Position of fixed particle in crystal coordinates + supercell (list): Supercell dimensions [nx,ny,nz] + degen_tol (float): degeneracy threshold (in eV). default 0.01 eV + fix_particle (str): 'e' to fix electron, 'h' to fix hole (default) + wfcCutoffRy (float): Wavefunction cutoff in Rydberg (-1 for full cutoff) + phase (bool): If True, include phase information i.e multiply the density with + sign of real part of the wavefunction + block_size (int): Block size for memory-efficient computation. leave it to default + unless you are in exteremely low memory situation. + + Returns: + None (write cube file to disk) + """ + # + # first get all degenerate states + iexe_degen_states = np.array(self.get_degenerate(iexe+1,eps=degen_tol))-1 + print("Degenerate states: ",iexe_degen_states+1) + # nicely arrange eigvectors to Akcv + Akcv = self.get_Akcv()[iexe_degen_states] + excQpt = self.car_qpoint + # Convert the q-point to crystal coordinates + Qpt = wfdb.ydb.lat @ excQpt + # + if fix_particle == 'h': name_file = 'electron' + else: name_file = 'hole' + + if phase and wfdb.wf.shape[1] != 1 and wfdb.wf.shape[3] != 1: + print("phase plot only works for nspin = 1 and nspinor == 1") + phase = False + if phase and len(iexe_degen_states) > 1: + phase = False + print("Warning: phase plots donot work for degenerate states") + + print('Computing exciton wavefunction (%s density) to real space.' %(name_file)) + sc_latvecs, atom_nums, atom_pos, real_wfc = ex_wf2Real(Akcv, Qpt, wfdb, [np.min(self.table[:, 1]), + np.max(self.table[:, 2])], fixed_postion=fixed_postion, + fix_particle=fix_particle, supercell=supercell, + wfcCutoffRy=wfcCutoffRy, block_size=block_size) + # Compute the absoulte^2 + density = np.abs(real_wfc)**2 + # Multiply with phase if necessary + if phase: + phase = np.sign(real_wfc.real) #np.sign(np.angle(real_wfc)) + density *= phase + # + # sum over spinor indices and degenerate states + real_wfc = np.sum(density,axis=(0,1,2,3)) + # normalize with max value + max_normalize_val = np.max(np.abs(real_wfc)) + print('Max Normalization value: ',max_normalize_val) + real_wfc *= (1.0/max_normalize_val) + # write to cube file + print('Writing to .cube file') + write_cube('exe_wf_%s_%d.cube' %(name_file,iexe+1), + real_wfc, sc_latvecs, atom_pos, atom_nums, + header='Real space exciton wavefunction') + + def get_nondegenerate(self,eps=1e-4): """ get a list of non-degenerate excitons @@ -276,12 +395,9 @@ def get_degenerate(self,index,eps=1e-4): Args: eps: maximum energy difference to consider the two excitons degenerate in eV """ - energy = self.eigenvalues[index-1] - excitons = [] - for n,e in enumerate(self.eigenvalues): - if np.isclose(energy,e,atol=eps): - excitons.append(n+1) - return excitons + energy = self.eigenvalues[index-1].real + excitons = np.where(np.isclose(self.eigenvalues.real, energy, atol=eps))[0] + 1 + return excitons.tolist() def exciton_bs(self,energies,path,excitons=(0,),debug=False): """ diff --git a/yambopy/dbs/wfdb.py b/yambopy/dbs/wfdb.py index b5304625..93034d0b 100644 --- a/yambopy/dbs/wfdb.py +++ b/yambopy/dbs/wfdb.py @@ -1,8 +1,6 @@ -# Copyright (c) 2018, Henrique Miranda +# Copyright (c) 2025, Muralidhar Nalabothula # All rights reserved. # -# This file is part of the yambopy project -# # Author: MN from yambopy import * @@ -25,6 +23,7 @@ except ImportError as e: from scipy.spatial import KDTree from yambopy.kpoints import build_ktree, find_kpt +from yambopy.tools.function_profiler import func_profile class YamboWFDB: """ @@ -129,6 +128,7 @@ def read(self, bands_range=[], latdb=None): # K-points in BZ (crystal units) self.kBZ = self.ydb.iku_kpoints / lat_param[None, :] self.kBZ = self.kBZ @ lat_vec + self.ktree = build_ktree(self.kBZ) # G-vectors in cartesian units G_vec = ns_db1['G-VECTORS'][...].data.T @@ -215,6 +215,13 @@ def assert_k_inrange(self, ik): def assert_bnd_range(self, ib): """Assert that the band index is valid.""" assert 0 <= ib < self.nbands, "Invalid band index" + + def kptBZidx(self, kpts): + """ + return the index of the kpoint or kpoints in the wavefunction db. + The kpts must be in crystal (reduced) coordinates. + """ + return find_kpt(self.ktree, kpts) def get_spin_projections(self, ik, ib=-1, s_z=np.array([[1, 0], [0, -1]])): """ @@ -389,7 +396,8 @@ def rotate_wfc(self, ik, isym): time_rev = (isym >= len(self.ydb.sym_car) / (1 + int(np.rint(self.ydb.time_rev)))) return self.apply_symm(kvec, wfc_k, gvecs_k, time_rev, sym_mat) - + + @func_profile def apply_symm(self, kvec, wfc_k, gvecs_k, time_rev, sym_mat, frac_vec=np.array([0, 0, 0])): """ Apply symmetry to wavefunctions. @@ -427,7 +435,8 @@ def apply_symm(self, kvec, wfc_k, gvecs_k, time_rev, sym_mat, frac_vec=np.array( return [Rkvec, wfc_rot, gvec_rot] - + + @func_profile def to_real_space(self, wfc_tmp, gvec_tmp, grid=[]): """ Convert wavefunctions from G-space to real space. @@ -492,6 +501,9 @@ def expand_fullBZ(self): self.kBZ[i] = kbz self.wf_bz[i][...,:ng_t] = w_t self.g_bz[i][:ng_t,:] = g_t + # + self.ktree = build_ktree(self.kBZ) + return def get_BZ_kpt(self, ik): """ @@ -518,7 +530,8 @@ def get_BZ_wf(self, ik): # return [self.wf_bz[ik][..., :self.ngBZ[ik]], self.g_bz[ik, :self.ngBZ[ik], :]] - + + @func_profile def Dmat(self, symm_mat=None, frac_vec=None, time_rev=None): """ Computes the symmetry-adapted matrix elements < Rk | U(R) | k >. @@ -561,7 +574,7 @@ def Dmat(self, symm_mat=None, frac_vec=None, time_rev=None): frac_vec = np.zeros((len(symm_mat),3),dtype=symm_mat.dtype) time_rev = int(np.rint(self.ydb.time_rev)) # - ktree = build_ktree(self.kBZ) + ktree = self.ktree #build_ktree(self.kBZ) Dmat = [] nsym = len(symm_mat) kpt_idx = self.ydb.kpoints_indexes @@ -604,7 +617,40 @@ def Dmat(self, symm_mat=None, frac_vec=None, time_rev=None): # return Dmat + def OverlapUkkp(self, kpt_bra, kpt_ket): + """ + Compute the following matrix elements : < k_bra | e^{i(k_bra-k_ket).r} | k_ket> + in other words, it computes overlap of periodic parts of k_bra an k_ket + """ + kpt_bra = np.array(kpt_bra) + kpt_ket = np.array(kpt_ket) + + ikpt_ket = find_kpt(self.ktree,kpt_ket) + ikpt_bra = find_kpt(self.ktree,kpt_bra) + # + kpt_idx = self.ydb.kpoints_indexes + sym_idx = self.ydb.symmetry_indexes + # + ibz_ket = kpt_idx[ikpt_ket] + isym_ket = sym_idx[ikpt_ket] + # + ibz_bra = kpt_idx[ikpt_bra] + isym_bra = sym_idx[ikpt_bra] + # + ## get the wfcs: + k_rk_ket, w_rk_ket, g_rk_ket = self.rotate_wfc(ibz_ket, isym_ket) + k_rk_bra, w_rk_bra, g_rk_bra = self.rotate_wfc(ibz_bra, isym_bra) + + G0_ket = kpt_ket-k_rk_ket + G0_bra = kpt_bra-k_rk_bra + # + G0 = G0_ket-G0_bra + return wfc_inner_product(G0, w_rk_bra, g_rk_bra, np.array([0,0,0]), w_rk_ket, g_rk_ket) + +## end of class +## +@func_profile def wfc_inner_product(k_bra, wfc_bra, gvec_bra, k_ket, wfc_ket, gvec_ket, ket_Gtree=None): """ Computes the inner product between two wavefunctions in reciprocal space. diff --git a/yambopy/kpoints.py b/yambopy/kpoints.py index 2247968c..e01cdf1e 100644 --- a/yambopy/kpoints.py +++ b/yambopy/kpoints.py @@ -213,4 +213,39 @@ def find_kpt(tree, kpt_search, tol=1e-5): return idx # Return the index of the found k-point +def find_kpatch(kpts, kcentre, kdist, lat_vecs): + """ + find set of kpoints around the kcentre with in kdist + + Parameters + ---------- + kpts : kpoints in crystal coordinates (nk,3) + kcentre : kpoint centre in crystal coordinates (3) + kdist : distance around kcentre to be considered in atomic units + i.e 1/bohr. + lat_vecs: lattice vectors. ith lattice vector is ai = a[:,i] + Returns + ------- + int array + Indices of kpoints in kpts array which satify the given condition i.e + | k - kcentre + G0| <= kdist, where G0 is reciprocal lattice vector to bring to BZ + """ + # + blat = 2*np.pi*np.linalg.inv(lat_vecs) + kdiff = kpts-kcentre[None,:] + kdiff = kdiff-np.floor(kdiff) + # + tmp_arr = np.array([-3, -2, -1, 0, 1, 2, 3]) + nG0 = len(tmp_arr) + G0 = np.zeros((nG0,nG0,nG0,3)) + G0[...,0], G0[...,1], G0[...,2] = np.meshgrid(tmp_arr, tmp_arr, + tmp_arr, indexing='ij') + G0 = G0.reshape(-1,3) + kdiff = kdiff[:,None,:]-G0[None,:,:] + kdiff = kdiff.reshape(-1,3)@blat + kdiff = np.linalg.norm(kdiff,axis=-1).reshape(len(kpts),-1) + kdiff = np.min(kdiff,axis=-1) + return np.where(kdiff <= kdist)[0] + + diff --git a/yambopy/letzelphc_interface/lelph2y.py b/yambopy/letzelphc_interface/lelph2y.py index 701ecf0a..0856db92 100644 --- a/yambopy/letzelphc_interface/lelph2y.py +++ b/yambopy/letzelphc_interface/lelph2y.py @@ -50,9 +50,14 @@ def __init__(self,OBJ,code,SAVE_path,OUT_path=None): self.get_yambo_header_variables(SAVE_path) # Get el-ph data from external code - match code: - case 'lelphc': self.get_elph_variables_LELPHC(OBJ) - case _: raise NotImplementedError("Code %s not found or implemented"%code) + # + # NM : using match will enforce python 3.10. some HPC's still use <= 3.8 + # Commenting it out until it gets old enough. + # match code: + # case 'lelphc': self.get_elph_variables_LELPHC(OBJ) + # case _: raise NotImplementedError("Code %s not found or implemented"%code) + if code.strip() == 'lelphc': self.get_elph_variables_LELPHC(OBJ) + else: raise NotImplementedError("Code %s not found or implemented"%code) if not os.path.isdir(OUT_path): os.mkdir(OUT_path) @@ -137,6 +142,7 @@ def write_header(self,OUT_path): dbs.createDimension('D_%010d'%1,1) dbs.createDimension('D_%010d'%2,2) dbs.createDimension('D_%010d'%4,4) + for value in [self.natoms,self.nkpoints_ibz,len_pars,self.nqpoints_bz,self.nkpoints_bz]: if value not in [1,2,3,4]: try: dbs.createDimension('D_%010d'%value,value) diff --git a/yambopy/letzelphc_interface/lelphcdb.py b/yambopy/letzelphc_interface/lelphcdb.py index d68726d7..d53d8ece 100644 --- a/yambopy/letzelphc_interface/lelphcdb.py +++ b/yambopy/letzelphc_interface/lelphcdb.py @@ -2,12 +2,13 @@ from netCDF4 import Dataset from yambopy.tools.string import marquee from yambopy.units import ha2ev +from yambopy.kpoints import build_ktree, find_kpt class LetzElphElectronPhononDB(): """ Python class to read the electron-phonon matrix elements from LetzElPhC. - About LetzElPhC: https://github.com/yambo-code/LetzElPhC/tree/main + About LetzElPhC: https://gitlab.com/lumen-code/LetzElPhC By default it reads the full database g(q,k,m,s,b1,b2) including phonon energies. @@ -37,6 +38,7 @@ def __init__(self,filename,read_all=True,div_by_energies=True,verbose=False): try: database = Dataset(filename) except: raise FileNotFoundError("error opening %s in LetzElphElectronPhononDB"%filename) + self.filename = filename # Read DB dimensions self.nb1 = database.dimensions['initial_band'].size self.nb2 = database.dimensions['final_band_PH_abs'].size @@ -46,11 +48,34 @@ def __init__(self,filename,read_all=True,div_by_energies=True,verbose=False): self.nq = database.dimensions['nq'].size self.ns = database.dimensions['nspin'].size self.nsym = database.dimensions['nsym_ph'].size - + self.div_by_energies = div_by_energies # if true, the elph store are normalized with 1/(2*w_ph) + # + # + conv = database['convention'][...].data + if isinstance(conv, np.ndarray): + if conv.dtype.kind == 'S': # Byte strings (C chars) + conv = conv.tobytes().decode('utf-8').strip() + else: + conv = str(conv) # Fallback for non-string arrays + elif isinstance(conv, bytes): + conv = conv.decode('utf-8').strip() + else: + conv = str(conv).strip() + conv = conv.strip().replace('\0', '') + # + # + if conv == 'standard': + print("Convention used in Letzelphc : k -> k+q (standard)") + else: + print("Convention used in Letzelphc : k-q -> k (yambo)") + self.convention = conv + # # Read DB self.kpoints = database.variables['kpoints'][:] self.qpoints = database.variables['qpoints'][:] self.bands = database.variables['bands'][:] + self.ktree = build_ktree(self.kpoints) + self.qtree = build_ktree(self.qpoints) self.ph_energies = database.variables['FREQ'][:]*(ha2ev/2.) # Energy units are in Rydberg self.check_energies() @@ -128,6 +153,110 @@ def scale_g(self,dvscf): g[iq,:,inu,:,:,:] = dvscf[iq,:,inu,:,:,:]/np.sqrt(2.*ph_E) return g + + def read_iq(self,iq, bands_range=[], database=None, convention='yambo'): + """ + Reads the electron-phonon matrix elements and phonon eigenvectors for a specific q-point index. + + If the data is already loaded in memory, it returns the corresponding array slice. Otherwise, + it reads from the database without storing the data in memory. + + This function reads data for a single q-point instead of the entire dataset, which is useful + for handling large databases efficiently. + + Parameters + ---------- + iq : int + Index of the q-point. + bands_range : list, optional + Specifies the range of bands to read. The start index follows Python indexing (starting from 0), + and the end index is excluded. If not provided, it defaults to the minimum and maximum bands available. + database : Dataset, optional + If provided, the function will use this open dataset instead of opening the file again. + convention : str, optional + Defines the convention used for electron-phonon matrix elements. + - 'yambo': Outputs \. + - Any other value: Outputs \. + + Returns + ------- + tuple + A tuple containing: + - ph_eigenvectors : ndarray + The phonon eigenvectors. + - ph_elph_me : ndarray + The electron-phonon matrix elements with the specified convention. + ( nk, nm, nspin, initial bnd, final bnd) + """ + # + if len(bands_range) == 0: + bands_range = [min(self.bands)-1,max(self.bands)] + min_bnd = min(bands_range) + max_bnd = max(bands_range) + nbnds = max_bnd - min_bnd + assert (min_bnd >= min(self.bands)-1) + assert (max_bnd <= max(self.bands)) + start_bnd_idx = 1+min_bnd - min(self.bands) + end_bnd = start_bnd_idx + nbnds + + # self.ph_eigenvectors , self.gkkp + if hasattr(self, 'ph_eigenvectors'): + ph_eigs = self.ph_eigenvectors[iq] + eph_mat = self.gkkp[iq, :, :, :, start_bnd_idx:end_bnd, start_bnd_idx:end_bnd ] + else : + ## else we load from the file + close_file = False + if not database : + close_file = True + database = Dataset(self.filename,'r') + eph_mat = database['elph_mat'][iq, :, :, :, start_bnd_idx:end_bnd, start_bnd_idx:end_bnd, :].data + # ( nk, nm, nspin, initial bnd, final bnd) + ph_eigs = database['POLARIZATION_VECTORS'][iq,...].data + eph_mat = eph_mat[...,0] + 1j*eph_mat[...,1] + ph_eigs = ph_eigs[...,0] + 1j*ph_eigs[...,1] + ## normalize with ph_freq + if self.div_by_energies: + ph_freq_iq = np.sqrt(2.0*np.abs(self.ph_energies[iq])/(ha2ev/2.)) + if iq >0: + ph_freq_iq = 1.0/ph_freq_iq + eph_mat *= ph_freq_iq[None,:,None,None,None] + else: + eph_mat[:,:3] = 0.0 + ph_freq_iq = 1.0/ph_freq_iq[3:] + eph_mat[:,3:] *= ph_freq_iq[None,:,None,None,None] + + if close_file :database.close() + ## output elph matrix elements unit (Ry if div_by_energies else Ry^1.5) + # ( nk, nm, nspin, initial bnd, final bnd) + return [ph_eigs, self.change_convention(self.qpoints[iq],eph_mat, convention)] + + def change_convention(self, qpt, elph_iq, convention='yambo'): + """ + Adjusts the convention of the electron-phonon matrix elements. + + Parameters + ---------- + qpt : ndarray + The q-point in crystal coordinates. + elph_iq : ndarray + The electron-phonon matrix elements. + convention : str, optional + Defines the output format: + - 'yambo': Outputs \. + - Any other value: Outputs \. + + Returns + ------- + ndarray + The electron-phonon matrix elements in the desired convention. The returned array is a view, not a copy. + """ + if convention.strip() != 'yambo': convention = 'standard' + if self.convention == convention: return elph_iq + if convention == 'standard': factor = 1.0 + else :factor = -1.0 + idx_q = find_kpt(self.ktree, factor*qpt[None, :] + self.kpoints) + return elph_iq[idx_q, ...] + def __str__(self): lines = []; app = lines.append diff --git a/yambopy/symmetries/__init__.py b/yambopy/symmetries/__init__.py new file mode 100644 index 00000000..22b7d0c7 --- /dev/null +++ b/yambopy/symmetries/__init__.py @@ -0,0 +1,5 @@ +# Copyright (C) 2025, YamboPy project +# All rights reserved. +# +# This file is part of yambopy +# diff --git a/yambopy/symmetries/crystal_symmetries.py b/yambopy/symmetries/crystal_symmetries.py new file mode 100644 index 00000000..e7e6bb69 --- /dev/null +++ b/yambopy/symmetries/crystal_symmetries.py @@ -0,0 +1,84 @@ +import numpy as np +import spgrep +import spglib +from yambopy.dbs.latticedb import YamboLatticeDB + +class Crystal_Symmetries: + def __init__(self, latdb, magnetic_moments=None, tol=1e-5): + """ + Class to handle crystal symmetries + # + Given a lattice database, finds all symmetries of a crystal. + Works irrespective of whether the SAVE has symmetries or not. + # + Parameters: + ----------- + latdb : object + Lattice database object with required attributes: + magnetic_moments : array (natom) or (natom,3) or None, optional + magnetic_moments of each atom (default: None) + In collinear case, natom is sufficient, but in non-collinear + case provide (mx, my, mz) for each atom. + tol : float, optional + Symmetry tolerance (default: 1e-5) + Members: + -------- + dataset : dict + The full symmetry dataset from spglib + spacegroup_type : dict + Space group type information + rotations : ndarray + Rotation matrices in Cartesian coordinates (n_sym, 3, 3) + translations : ndarray + Translation vectors in Cartesian coordinates (n_sym, 3) + international_symbol : str + International symbol of the space group + hall_symbol : str + Hall symbol of the space group + wyckoffs : list + Wyckoff letters for each atom + pointgroup : str + Point group symbol + spacegroup_number : int + International space group number + """ + # Get lattice and atomic information + lattice = latdb.lat + numbers = np.rint(latdb.atomic_numbers).astype(int) + lat_inv = np.linalg.inv(lattice) + positions = latdb.car_atomic_positions @ lat_inv + positions = positions - np.floor(positions) + # + self.magnetic = not (magnetic_moments is None) + # + # Get symmetry dataset + if self.magnetic : + cell = (lattice, positions, numbers, magnetic_moments) + print("Currently Magnetic systems not supported.") + exit() + else : + cell = (lattice, positions, numbers) + self.dataset = spglib.get_symmetry_dataset(cell, symprec=tol) + self.spacegroup_type = spglib.get_spacegroup_type(self.dataset.hall_number) + # + # Convert rotation and fractional translations to Cartesian units + self.rotations = lattice.T[None,:,:] @ self.dataset.rotations @ lat_inv.T[None,:,:] + self.translations = self.dataset.translations @ lattice + # + self.international_symbol = self.dataset.international + self.hall_symbol = self.dataset.hall + self.wyckoffs = self.dataset.wyckoffs + self.pointgroup = self.dataset.pointgroup + self.pointgroup_schoenflies = self.spacegroup_type.pointgroup_schoenflies + self.spacegroup_number = self.dataset.number + + +## Test +if __name__ == "__main__": + import os + np.set_printoptions(suppress=True) + latdb = YamboLatticeDB.from_db_file(os.path.join('.', 'SAVE')) + symm = Crystal_Symmetries(latdb,tol=1e-4) + print(symm.wyckoffs) + print(symm.pointgroup_schoenflies) + diff --git a/yambopy/tools/degeneracy_finder.py b/yambopy/tools/degeneracy_finder.py new file mode 100644 index 00000000..d4e8d48d --- /dev/null +++ b/yambopy/tools/degeneracy_finder.py @@ -0,0 +1,53 @@ +import numpy as np + +def find_degeneracy_evs(eigenvalues, atol=1e-3, rtol=1e-3): + """ + Identify sets of degenerate eigenstates based on eigenvalues. + + Parameters + ---------- + eigenvalues : array-like + A list or array of eigenvalues + atol : float, optional + Absolute tolerance for degeneracy comparison. Default is 1e-3. + rtol : float, optional + Relative tolerance for degeneracy comparison. Default is 1e-3. + + Returns + ------- + list of lists + A list where each sublist contains the indices of degenerate eigenstates. + + Raises + ------ + ValueError + If `eigenvalues` is empty or not a valid array. + """ + eigenvalues = np.asarray(eigenvalues) + + if eigenvalues.size == 0: + raise ValueError("Input eigenvalues must not be empty.") + if atol < 0 or rtol < 0: + raise ValueError("Tolerances `atol` and `rtol` must be non-negative.") + + # Sort eigenvalues and get sorted indices + idx_sorted = np.argsort(eigenvalues) + eigenvalues_sorted = eigenvalues[idx_sorted] + + # Compute differences between consecutive eigenvalues + diffs = np.diff(eigenvalues_sorted) + tolerance = atol + rtol * np.abs(eigenvalues_sorted[:-1]) + + # Identify where the differences exceed the tolerance + split_indices = np.where(diffs > tolerance)[0] + + # Group indices of degenerate states (in sorted order) + degen_sets_sorted = np.split(idx_sorted, split_indices + 1) + + return degen_sets_sorted + + +if __name__ == '__main__': + eigenvalues = [3.0001, 2.99976, 1.99999, 1.0, 1.001, 1.002, 2.0, 2.001, 3.0] + degen_sets = find_degeneracy_evs(eigenvalues, atol=1e-3, rtol=1e-3) + print(degen_sets) diff --git a/yambopy/tools/function_profiler.py b/yambopy/tools/function_profiler.py new file mode 100644 index 00000000..a43d2038 --- /dev/null +++ b/yambopy/tools/function_profiler.py @@ -0,0 +1,40 @@ +import time +from functools import wraps + +# decarator to measure number of calls and wall time in seconds of function +""" +@func_profile +def abc(xxx): + np.log(np.sin(xxx)+1j+np.cos(xxx)+xxx**2) + +x = np.linspace(0,10,1000000) + +for i in range(10): + abc(x) + +# number of calls to abc +print(abc.call_count) + +# total time spent in abc +print(abc.total_time) + +""" +def func_profile(func): + @wraps(func) + def wrapper(*args, **kwargs): + start_time = time.perf_counter() + + # Increment call count (initialize if first call) + if not hasattr(wrapper, "call_count"): + wrapper.call_count = 0 + wrapper.call_count += 1 + + result = func(*args, **kwargs) # Call original function + + # Accumulate time (initialize if first call) + if not hasattr(wrapper, "total_time"): + wrapper.total_time = 0.0 + wrapper.total_time += time.perf_counter() - start_time + + return result + return wrapper