@@ -14,139 +14,226 @@ namespace ABFs_Construct
1414{
1515namespace PCA
1616{
17- void tensor_dsyev (const char jobz, const char uplo, RI::Tensor<double > & a, double *const w, int & info)
17+ template <>
18+ void tensor_syev<double >(char jobz, char uplo, RI::Tensor<double >& a, double * w, int & info)
1819 {
19- // reference: dsyev in lapack_connector.h (for ModuleBase::matrix)
2020 assert (a.shape .size () == 2 );
2121 assert (a.shape [0 ] == a.shape [1 ]);
22- const int nr = a.shape [0 ];
23- const int nc = a.shape [1 ];
2422
25- double work_tmp=0.0 ;
23+ const int n = a.shape [0 ];
24+ const int lda = a.shape [1 ];
25+
26+ double work_query = 0.0 ;
2627 constexpr int minus_one = -1 ;
27- dsyev_ (&jobz, &uplo, &nr, a.ptr (), &nc, w, &work_tmp, &minus_one, &info); // get best lwork
2828
29- const int lwork = work_tmp;
29+ dsyev_ (&jobz, &uplo, &n, a.ptr (), &lda, w, &work_query, &minus_one, &info);
30+
31+ const int lwork = static_cast <int >(work_query);
3032 std::vector<double > work (std::max (1 , lwork));
31- dsyev_ (&jobz, &uplo, &nr, a.ptr (), &nc, w, work.data (), &lwork, &info);
33+
34+ dsyev_ (&jobz, &uplo, &n, a.ptr (), &lda, w, work.data (), &lwork, &info);
35+ }
36+
37+ template <>
38+ void tensor_syev<float >(char jobz, char uplo, RI::Tensor<float >& a, float * w, int & info)
39+ {
40+ assert (a.shape .size () == 2 );
41+ assert (a.shape [0 ] == a.shape [1 ]);
42+
43+ const int n = a.shape [0 ];
44+ const int lda = a.shape [1 ];
45+
46+ float work_query = 0 .0f ;
47+ constexpr int minus_one = -1 ;
48+
49+ ssyev_ (&jobz, &uplo, &n, a.ptr (), &lda, w, &work_query, &minus_one, &info);
50+
51+ const int lwork = static_cast <int >(work_query);
52+ std::vector<float > work (std::max (1 , lwork));
53+
54+ ssyev_ (&jobz, &uplo, &n, a.ptr (), &lda, w, work.data (), &lwork, &info);
3255 }
3356
34- RI::Tensor<double > get_sub_matrix (
35- const RI::Tensor<double > & m, // size: (lcaos, lcaos, abfs)
36- const std::size_t & T,
37- const std::size_t & L,
38- const ModuleBase::Element_Basis_Index::Range & range,
39- const ModuleBase::Element_Basis_Index::IndexLNM & index )
57+ template <>
58+ void tensor_syev<std::complex <double >>(char jobz, char uplo, RI::Tensor<std::complex <double >>& a, double * w, int & info)
59+ {
60+ assert (a.shape .size () == 2 );
61+ assert (a.shape [0 ] == a.shape [1 ]);
62+
63+ const int n = a.shape [0 ];
64+ const int lda = a.shape [1 ];
65+
66+ std::complex <double > work_query;
67+ constexpr int minus_one = -1 ;
68+
69+ zheev_ (&jobz, &uplo, &n, a.ptr (), &lda, w, &work_query, &minus_one, nullptr , &info);
70+
71+ const int lwork = static_cast <int >(work_query.real ());
72+ std::vector<std::complex <double >> work (std::max (1 , lwork));
73+ std::vector<double > rwork (std::max (1 , 3 * n - 2 ));
74+
75+ zheev_ (&jobz, &uplo, &n, a.ptr (), &lda, w, work.data (), &lwork, rwork.data (), &info);
76+ }
77+
78+ template <>
79+ void tensor_syev<std::complex <float >>(char jobz, char uplo, RI::Tensor<std::complex <float >>& a, float * w, int & info)
80+ {
81+ assert (a.shape .size () == 2 );
82+ assert (a.shape [0 ] == a.shape [1 ]);
83+
84+ const int n = a.shape [0 ];
85+ const int lda = a.shape [1 ];
86+
87+ std::complex <float > work_query;
88+ constexpr int minus_one = -1 ;
89+
90+ cheev_ (&jobz, &uplo, &n, a.ptr (), &lda, w, &work_query, &minus_one, nullptr , &info);
91+
92+ const int lwork = static_cast <int >(work_query.real ());
93+ std::vector<std::complex <float >> work (std::max (1 , lwork));
94+ std::vector<float > rwork (std::max (1 , 3 * n - 2 ));
95+
96+ cheev_ (&jobz, &uplo, &n, a.ptr (), &lda, w, work.data (), &lwork, rwork.data (), &info);
97+ }
98+ // void tensor_dsyev(const char jobz, const char uplo, RI::Tensor<double> & a, double*const w, int & info)
99+ // {
100+ // // reference: dsyev in lapack_connector.h (for ModuleBase::matrix)
101+ // assert(a.shape.size() == 2);
102+ // assert(a.shape[0] == a.shape[1]);
103+ // const int nr = a.shape[0];
104+ // const int nc = a.shape[1];
105+
106+ // double work_tmp=0.0;
107+ // constexpr int minus_one = -1;
108+ // dsyev_(&jobz, &uplo, &nr, a.ptr(), &nc, w, &work_tmp, &minus_one, &info); // get best lwork
109+
110+ // const int lwork = work_tmp;
111+ // std::vector<double> work(std::max(1, lwork));
112+ // dsyev_(&jobz, &uplo, &nr, a.ptr(), &nc, w, work.data(), &lwork, &info);
113+ // }
114+
115+ RI::Tensor<double > get_sub_matrix (const RI::Tensor<double >& m, // size: (lcaos, lcaos, abfs)
116+ const std::size_t & T,
117+ const std::size_t & L,
118+ const ModuleBase::Element_Basis_Index::Range& range,
119+ const ModuleBase::Element_Basis_Index::IndexLNM& index)
40120 {
41121 ModuleBase::TITLE (" ABFs_Construct::PCA::get_sub_matrix" );
42122 assert (m.shape .size () == 3 );
43- RI::Tensor<double > m_sub ({ m.shape [0 ], m.shape [1 ], range[T][L].N });
44- for (std::size_t ir=0 ; ir!=m.shape [0 ]; ++ir) {
45- for (std::size_t jr=0 ; jr!=m.shape [1 ]; ++jr) {
46- for (std::size_t N=0 ; N!=range[T][L].N ; ++N) {
123+ RI::Tensor<double > m_sub ({m.shape [0 ], m.shape [1 ], range[T][L].N });
124+ for (std::size_t ir = 0 ; ir != m.shape [0 ]; ++ir)
125+ {
126+ for (std::size_t jr = 0 ; jr != m.shape [1 ]; ++jr)
127+ {
128+ for (std::size_t N = 0 ; N != range[T][L].N ; ++N)
129+ {
47130 m_sub (ir, jr, N) = m (ir, jr, index[T][L][N][0 ]);
48131}
49132}
50133}
51- m_sub = m_sub.reshape ({ m.shape [0 ] * m.shape [1 ], range[T][L].N });
134+ m_sub = m_sub.reshape ({m.shape [0 ] * m.shape [1 ], range[T][L].N });
52135 return m_sub;
53136 }
54137
55- RI::Tensor<double > get_column_mean0_matrix ( const RI::Tensor<double > & m )
138+ RI::Tensor<double > get_column_mean0_matrix (const RI::Tensor<double >& m)
56139 {
57140 ModuleBase::TITLE (" ABFs_Construct::PCA::get_column_mean0_matrix" );
58- RI::Tensor<double > m_new ( m.shape );
59- for ( std::size_t ic= 0 ; ic!= m.shape [1 ]; ++ic )
141+ RI::Tensor<double > m_new (m.shape );
142+ for ( std::size_t ic = 0 ; ic != m.shape [1 ]; ++ic)
60143 {
61- double sum=0 ;
62- for ( std::size_t ir=0 ; ir!=m.shape [0 ]; ++ir ) {
63- sum += m (ir,ic);
144+ double sum = 0 ;
145+ for (std::size_t ir = 0 ; ir != m.shape [0 ]; ++ir)
146+ {
147+ sum += m (ir, ic);
64148}
65- const double mean = sum/m.shape [0 ];
66- for ( std::size_t ir=0 ; ir!=m.shape [0 ]; ++ir ) {
67- m_new (ir,ic) = m (ir,ic) - mean;
149+ const double mean = sum / m.shape [0 ];
150+ for (std::size_t ir = 0 ; ir != m.shape [0 ]; ++ir)
151+ {
152+ m_new (ir, ic) = m (ir, ic) - mean;
68153}
69154 }
70155 return m_new;
71156 }
72157
73158 std::vector<std::vector<std::pair<std::vector<double >, RI::Tensor<double >>>> cal_PCA (
74- const UnitCell & ucell,
159+ const UnitCell& ucell,
75160 const LCAO_Orbitals& orb,
76- const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> & lcaos,
77- const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> & abfs,
78- const double kmesh_times )
161+ const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>>& lcaos,
162+ const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>>& abfs,
163+ const double kmesh_times)
79164 {
80165 ModuleBase::TITLE (" ABFs_Construct::PCA::cal_PCA" );
81166
82- const ModuleBase::Element_Basis_Index::Range
83- range_lcaos = ModuleBase::Element_Basis_Index::construct_range ( lcaos );
84- const ModuleBase::Element_Basis_Index::IndexLNM
85- index_lcaos = ModuleBase::Element_Basis_Index::construct_index ( range_lcaos );
167+ const ModuleBase::Element_Basis_Index::Range range_lcaos = ModuleBase::Element_Basis_Index::construct_range (lcaos);
168+ const ModuleBase::Element_Basis_Index::IndexLNM index_lcaos
169+ = ModuleBase::Element_Basis_Index::construct_index (range_lcaos);
86170
87- const ModuleBase::Element_Basis_Index::Range
88- range_abfs = ModuleBase::Element_Basis_Index::construct_range ( abfs );
89- const ModuleBase::Element_Basis_Index::IndexLNM
90- index_abfs = ModuleBase::Element_Basis_Index::construct_index ( range_abfs );
171+ const ModuleBase::Element_Basis_Index::Range range_abfs = ModuleBase::Element_Basis_Index::construct_range (abfs);
172+ const ModuleBase::Element_Basis_Index::IndexLNM index_abfs
173+ = ModuleBase::Element_Basis_Index::construct_index (range_abfs);
91174
92175 const int Lmax_bak = GlobalC::exx_info.info_ri .abfs_Lmax ;
93176 GlobalC::exx_info.info_ri .abfs_Lmax = std::numeric_limits<int >::min ();
94- for ( std::size_t T=0 ; T!=abfs.size (); ++T ) {
95- GlobalC::exx_info.info_ri .abfs_Lmax = std::max ( GlobalC::exx_info.info_ri .abfs_Lmax , static_cast <int >(abfs[T].size ())-1 );
177+ for (std::size_t T = 0 ; T != abfs.size (); ++T)
178+ {
179+ GlobalC::exx_info.info_ri .abfs_Lmax
180+ = std::max (GlobalC::exx_info.info_ri .abfs_Lmax , static_cast <int >(abfs[T].size ()) - 1 );
96181}
97182
98183 Matrix_Orbs21 m_abfslcaos_lcaos;
99184 ORB_gaunt_table MGT;
100185 int Lmax;
101- m_abfslcaos_lcaos.init ( 1 , ucell , orb, kmesh_times, orb.get_Rmax (), Lmax );
186+ m_abfslcaos_lcaos.init (1 , ucell, orb, kmesh_times, orb.get_Rmax (), Lmax);
102187 MGT.init_Gaunt_CH (Lmax);
103188 MGT.init_Gaunt (Lmax);
104- m_abfslcaos_lcaos.init_radial ( abfs, lcaos, lcaos, MGT );
189+ m_abfslcaos_lcaos.init_radial (abfs, lcaos, lcaos, MGT);
105190
106- std::map<std::size_t ,std::map<std::size_t ,std::set<double >>> delta_R;
107- for ( std::size_t it=0 ; it!=abfs.size (); ++it ) {
191+ std::map<std::size_t , std::map<std::size_t , std::set<double >>> delta_R;
192+ for (std::size_t it = 0 ; it != abfs.size (); ++it)
193+ {
108194 delta_R[it][it] = {0.0 };
109195}
110196 m_abfslcaos_lcaos.init_radial_table (delta_R);
111197
112198 GlobalC::exx_info.info_ri .abfs_Lmax = Lmax_bak;
113199
114- std::vector<std::vector<std::pair<std::vector<double >,RI::Tensor<double >>>> eig (abfs.size ());
115- for ( std::size_t T= 0 ; T!= abfs.size (); ++T )
200+ std::vector<std::vector<std::pair<std::vector<double >, RI::Tensor<double >>>> eig (abfs.size ());
201+ for ( std::size_t T = 0 ; T != abfs.size (); ++T)
116202 {
117- const RI::Tensor<double > A = m_abfslcaos_lcaos.cal_overlap_matrix <double >(
118- T,
203+ const RI::Tensor<double > A = m_abfslcaos_lcaos.cal_overlap_matrix <double >(T,
119204 T,
120- ModuleBase::Vector3<double >{0 ,0 , 0 },
121- ModuleBase::Vector3<double >{0 ,0 , 0 },
205+ ModuleBase::Vector3<double >{0 , 0 , 0 },
206+ ModuleBase::Vector3<double >{0 , 0 , 0 },
122207 index_abfs,
123208 index_lcaos,
124209 index_lcaos,
125210 Matrix_Orbs21::Matrix_Order::A2BA1);
126211
127212 eig[T].resize (abfs[T].size ());
128- for ( std::size_t L= 0 ; L!= abfs[T].size (); ++L )
213+ for ( std::size_t L = 0 ; L != abfs[T].size (); ++L)
129214 {
130- const RI::Tensor<double > A_sub = get_sub_matrix ( A, T, L, range_abfs, index_abfs );
215+ const RI::Tensor<double > A_sub = get_sub_matrix (A, T, L, range_abfs, index_abfs);
131216 RI::Tensor<double > mm = A_sub.transpose () * A_sub;
132217 std::vector<double > eig_value (mm.shape [0 ]);
133218
134- int info= 1 ;
219+ int info = 1 ;
135220
136- tensor_dsyev (' V' , ' L' , mm, eig_value.data (), info);
221+ tensor_syev< double > (' V' , ' L' , mm, eig_value.data (), info);
137222
138- if ( info )
223+ if ( info)
139224 {
140225 std::cout << std::endl << " info_dsyev = " << info << std::endl;
141- auto tensor_print = [](RI::Tensor<double >& m, std::ostream& os, const double threshold)
142- {
226+ auto tensor_print = [](RI::Tensor<double >& m, std::ostream& os, const double threshold) {
143227 for (int ir = 0 ; ir != m.shape [0 ]; ++ir)
144228 {
145229 for (int ic = 0 ; ic != m.shape [1 ]; ++ic)
146230 {
147- if (std::abs (m (ir, ic)) > threshold) {
231+ if (std::abs (m (ir, ic)) > threshold)
232+ {
148233 os << m (ir, ic) << " \t " ;
149- } else {
234+ }
235+ else
236+ {
150237 os << 0 << " \t " ;
151238}
152239 }
@@ -155,15 +242,15 @@ namespace PCA
155242 os << std::endl;
156243 };
157244 tensor_print (mm, GlobalV::ofs_warning, 0.0 );
158- std::cout<< " in file " << __FILE__<< " line " << __LINE__<< std::endl;
245+ std::cout << " in file " << __FILE__ << " line " << __LINE__ << std::endl;
159246 ModuleBase::QUIT ();
160247 }
161- eig[T][L] = std::make_pair ( eig_value, mm );
248+ eig[T][L] = std::make_pair (eig_value, mm);
162249 }
163250 }
164251
165252 return eig;
166253 }
167254
168- } // namespace ABFs_Construct:: PCA
169- } // namespace ABFs_Construct
255+ } // namespace PCA
256+ } // namespace ABFs_Construct
0 commit comments