From fcc2c98ccd26a8904adf1cb30ecd1695ade0c713 Mon Sep 17 00:00:00 2001 From: linpz Date: Sun, 7 Jun 2026 06:16:49 +0800 Subject: [PATCH 1/2] Feature: add row-major in DensityMatrix::_DMK --- source/source_estate/cal_dm.h | 6 +-- source/source_estate/module_dm/cal_dm_psi.cpp | 21 ++++---- source/source_estate/module_dm/cal_dm_psi.h | 8 ++- .../module_dm/density_matrix.cpp | 52 ++++++++++++------- .../source_estate/module_dm/density_matrix.h | 16 +++++- .../module_dm/density_matrix_io.cpp | 34 ++++++------ .../module_dm/test/test_cal_dmk_psi.cpp | 5 +- .../module_dm/test/test_dm_constructor.cpp | 40 ++++++++++++-- .../module_current/td_current_io.cpp | 35 +++++++------ .../module_lr/dm_trans/dmr_complex.cpp | 10 ++-- 10 files changed, 143 insertions(+), 84 deletions(-) diff --git a/source/source_estate/cal_dm.h b/source/source_estate/cal_dm.h index aede5980e0a..ad38e369d02 100644 --- a/source/source_estate/cal_dm.h +++ b/source/source_estate/cal_dm.h @@ -46,8 +46,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!"); } } - if (ib_global >= wg.nc) { continue; -} + if (ib_global >= wg.nc) { continue; } const double wg_local = wg(ik, ib_global); double* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0)); BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1); @@ -105,8 +104,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!"); } } - if (ib_global >= wg.nc) { continue; -} + if (ib_global >= wg.nc) { continue; } const double wg_local = wg(ik, ib_global); std::complex* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0)); BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1); diff --git a/source/source_estate/module_dm/cal_dm_psi.cpp b/source/source_estate/module_dm/cal_dm_psi.cpp index 0445a10a16b..467feddf6ee 100644 --- a/source/source_estate/module_dm/cal_dm_psi.cpp +++ b/source/source_estate/module_dm/cal_dm_psi.cpp @@ -61,9 +61,10 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV, // C++: dm(iw1,iw2) = wfc(ib,iw1).T * wg_wfc(ib,iw2) #ifdef __MPI + assert(!DM.is_DMK_row_major()); psiMulPsiMpi(wg_wfc, wfc, dmk_pointer, ParaV->desc_wfc, ParaV->desc); #else - psiMulPsi(wg_wfc, wfc, dmk_pointer); + psiMulPsi(wg_wfc, wfc, dmk_pointer, DM.is_DMK_row_major()); #endif } ModuleBase::timer::end("elecstate", "cal_dm_psi"); @@ -135,14 +136,15 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV, if (PARAM.inp.ks_solver == "cg_in_lcao") { - psiMulPsi(wg_wfc, wfc, dmk_pointer); + psiMulPsi(wg_wfc, wfc, dmk_pointer, DM.is_DMK_row_major()); } else { + assert(!DM.is_DMK_row_major()); psiMulPsiMpi(wg_wfc, wfc, dmk_pointer, ParaV->desc_wfc, ParaV->desc); } #else - psiMulPsi(wg_wfc, wfc, dmk_pointer); + psiMulPsi(wg_wfc, wfc, dmk_pointer, DM.is_DMK_row_major()); #endif } @@ -222,7 +224,7 @@ void psiMulPsiMpi(const psi::Psi>& psi1, #endif -void psiMulPsi(const psi::Psi& psi1, const psi::Psi& psi2, double* dm_out) +void psiMulPsi(const psi::Psi& psi1, const psi::Psi& psi2, double* dm_out, const bool is_DMK_row_major) { const double one_float = 1.0, zero_float = 0.0; const int one_int = 1; @@ -235,9 +237,9 @@ void psiMulPsi(const psi::Psi& psi1, const psi::Psi& psi2, doubl nlocal, nbands, one_float, - psi1.get_pointer(), + is_DMK_row_major ? psi2.get_pointer() : psi1.get_pointer(), nlocal, - psi2.get_pointer(), + is_DMK_row_major ? psi1.get_pointer() : psi2.get_pointer(), nlocal, zero_float, dm_out, @@ -246,7 +248,8 @@ void psiMulPsi(const psi::Psi& psi1, const psi::Psi& psi2, doubl void psiMulPsi(const psi::Psi>& psi1, const psi::Psi>& psi2, - std::complex* dm_out) + std::complex* dm_out, + const bool is_DMK_row_major) { const int one_int = 1; const char N_char = 'N', T_char = 'T'; @@ -260,9 +263,9 @@ void psiMulPsi(const psi::Psi>& psi1, nlocal, nbands, one_complex, - psi1.get_pointer(), + is_DMK_row_major ? psi2.get_pointer() : psi1.get_pointer(), nlocal, - psi2.get_pointer(), + is_DMK_row_major ? psi1.get_pointer() : psi2.get_pointer(), nlocal, zero_complex, dm_out, diff --git a/source/source_estate/module_dm/cal_dm_psi.h b/source/source_estate/module_dm/cal_dm_psi.h index 09ab3d3974a..56fc2a0afbc 100644 --- a/source/source_estate/module_dm/cal_dm_psi.h +++ b/source/source_estate/module_dm/cal_dm_psi.h @@ -28,11 +28,15 @@ namespace elecstate const int* desc_dm); // for Gamma-Only case without MPI - void psiMulPsi(const psi::Psi& psi1, const psi::Psi& psi2, double* dm_out); + void psiMulPsi(const psi::Psi& psi1, + const psi::Psi& psi2, + double* dm_out, + const bool is_DMK_row_major); // for multi-k case without MPI void psiMulPsi(const psi::Psi>& psi1, const psi::Psi>& psi2, - std::complex* dm_out); + std::complex* dm_out, + const bool is_DMK_row_major); }; #endif diff --git a/source/source_estate/module_dm/density_matrix.cpp b/source/source_estate/module_dm/density_matrix.cpp index b0734e6d065..02f69a39482 100644 --- a/source/source_estate/module_dm/density_matrix.cpp +++ b/source/source_estate/module_dm/density_matrix.cpp @@ -28,8 +28,17 @@ DensityMatrix::~DensityMatrix() } template -DensityMatrix::DensityMatrix(const Parallel_Orbitals* paraV_in, const int nspin, const std::vector>& kvec_d, const int nk) - : _paraV(paraV_in), _nspin(nspin), _kvec_d(kvec_d), _nk((nk > 0 && nk <= _kvec_d.size()) ? nk : _kvec_d.size()) +DensityMatrix::DensityMatrix( + const Parallel_Orbitals* paraV_in, + const int nspin, + const std::vector>& kvec_d, + const int nk, + const bool is_DMK_row_major) + : _paraV(paraV_in), + _nspin(nspin), + _kvec_d(kvec_d), + _nk((nk > 0 && nk <= _kvec_d.size()) ? nk : _kvec_d.size()), + _is_DMK_row_major(is_DMK_row_major) { ModuleBase::TITLE("DensityMatrix", "resize_DMK"); const int nks = _nk * _nspin; @@ -42,7 +51,15 @@ DensityMatrix::DensityMatrix(const Parallel_Orbitals* paraV_in, const in } template -DensityMatrix::DensityMatrix(const Parallel_Orbitals* paraV_in, const int nspin) :_paraV(paraV_in), _nspin(nspin), _kvec_d({ ModuleBase::Vector3(0,0,0) }), _nk(1) +DensityMatrix::DensityMatrix( + const Parallel_Orbitals* paraV_in, + const int nspin, + const bool is_DMK_row_major) + : _paraV(paraV_in), + _nspin(nspin), + _kvec_d({ ModuleBase::Vector3(0,0,0) }), + _nk(1), + _is_DMK_row_major(is_DMK_row_major) { ModuleBase::TITLE("DensityMatrix", "resize_gamma"); this->_DMK.resize(_nspin); @@ -68,7 +85,6 @@ void DensityMatrix_Tools::cal_DMR( assert(dmR_out.size()==dm._nspin && "DMR has not been initialized!"); ModuleBase::timer::start("DensityMatrix", "cal_DMR"); - const int ld_hk = dm._paraV->nrow; for (int is = 1; is <= dm._nspin; ++is) { const int ik_begin = dm._nk * (is - 1); // jump dm._nk for spin_down if nspin==2 @@ -125,13 +141,13 @@ void DensityMatrix_Tools::cal_DMR( for(int ik = 0; ik < dm._nk; ++ik) { if(ik_in >= 0 && ik_in != ik) { continue; } - // copy column-major DMK to row-major DMK_mat_trans (for the purpose of computational efficiency) + // copy DMK to row-major DMK_mat_trans (for the purpose of computational efficiency) const TK*const DMK_mat_ptr = dm._DMK[ik + ik_begin].data() - + col_ap * dm._paraV->nrow + row_ap; + + dm.dmk_index(row_ap, col_ap); for(int icol = 0; icol < col_size; ++icol) { for(int irow = 0; irow < row_size; ++irow) { - DMK_mat_trans[irow * col_size + icol] = DMK_mat_ptr[icol * ld_hk + irow]; + DMK_mat_trans[irow * col_size + icol] = DMK_mat_ptr[dm.dmk_index(irow, icol)]; }} // if nspin != 4, fill DMR @@ -223,7 +239,6 @@ void DensityMatrix_Tools::cal_DMR_td( assert(dmR_out.size()==dm._nspin && "DMR has not been initialized!"); ModuleBase::timer::start("DensityMatrix", "cal_DMR_td"); - const int ld_hk = dm._paraV->nrow; for (int is = 1; is <= dm._nspin; ++is) { const int ik_begin = dm._nk * (is - 1); // jump dm._nk for spin_down if nspin==2 @@ -283,13 +298,13 @@ void DensityMatrix_Tools::cal_DMR_td( for(int ik = 0; ik < dm._nk; ++ik) { if(ik_in >= 0 && ik_in != ik) { continue; } - // copy column-major DMK to row-major DMK_mat_trans (for the purpose of computational efficiency) + // copy DMK to row-major DMK_mat_trans (for the purpose of computational efficiency) const TK*const DMK_mat_ptr = dm._DMK[ik + ik_begin].data() - + col_ap * dm._paraV->nrow + row_ap; + + dm.dmk_index(row_ap, col_ap); for(int icol = 0; icol < col_size; ++icol) { for(int irow = 0; irow < row_size; ++irow) { - DMK_mat_trans[irow * col_size + icol] = DMK_mat_ptr[icol * ld_hk + irow]; + DMK_mat_trans[irow * col_size + icol] = DMK_mat_ptr[dm.dmk_index(irow, icol)]; }} // if nspin != 4, fill DMR @@ -381,7 +396,6 @@ void DensityMatrix_Tools::cal_DMR_full( ModuleBase::TITLE("DensityMatrix", "cal_DMR_full"); ModuleBase::timer::start("DensityMatrix", "cal_DMR_full"); - const int ld_hk = dm._paraV->nrow; hamilt::HContainer* target_DMR = dmR_out; // set zero since this function is called in every scf step target_DMR->set_zero(); @@ -434,13 +448,13 @@ void DensityMatrix_Tools::cal_DMR_full( for(int ik = 0; ik < dm._nk; ++ik) { if(ik_in >= 0 && ik_in != ik) { continue; } - // copy column-major DMK to row-major DMK_mat_trans (for the purpose of computational efficiency) + // copy DMK to row-major DMK_mat_trans (for the purpose of computational efficiency) const TK*const DMK_mat_ptr = dm._DMK[ik].data() - + col_ap * dm._paraV->nrow + row_ap; + + dm.dmk_index(row_ap, col_ap); for(int icol = 0; icol < col_size; ++icol) { for(int irow = 0; irow < row_size; ++irow) { - DMK_mat_trans[irow * col_size + icol] = DMK_mat_ptr[icol * ld_hk + irow]; + DMK_mat_trans[irow * col_size + icol] = DMK_mat_ptr[dm.dmk_index(irow, icol)]; }} for(int iR = 0; iR < R_size; ++iR) @@ -487,7 +501,6 @@ void DensityMatrix::cal_DMR(const int ik_in) assert(this->_DMR.size()==this->_nspin && "DMR has not been initialized!"); ModuleBase::timer::start("DensityMatrix", "cal_DMR"); - const int ld_hk = this->_paraV->nrow; for (int is = 1; is <= this->_nspin; ++is) { const int ik_begin = this->_nk * (is - 1); // jump this->_nk for spin_down if nspin==2 @@ -522,10 +535,9 @@ void DensityMatrix::cal_DMR(const int ik_in) #endif // k index constexpr TK kphase = 1; - // transpose DMK col=>row const TK* DMK_mat_ptr = this->_DMK[0 + ik_begin].data() - + col_ap * this->_paraV->nrow + row_ap; + + this->dmk_index(row_ap, col_ap); // set DMR element TR* target_DMR_ptr = target_mat->get_pointer(); for (int mu = 0; mu < row_size; ++mu) @@ -533,10 +545,10 @@ void DensityMatrix::cal_DMR(const int ik_in) BlasConnector::axpy(col_size, kphase, DMK_mat_ptr, - ld_hk, + this->_is_DMK_row_major ? 1 : this->_paraV->nrow, target_DMR_ptr, 1); - DMK_mat_ptr += 1; + DMK_mat_ptr += this->_is_DMK_row_major ? this->_paraV->ncol : 1; target_DMR_ptr += col_size; } } diff --git a/source/source_estate/module_dm/density_matrix.h b/source/source_estate/module_dm/density_matrix.h index a8b0e1c4ecb..7e517856a61 100644 --- a/source/source_estate/module_dm/density_matrix.h +++ b/source/source_estate/module_dm/density_matrix.h @@ -88,7 +88,8 @@ class DensityMatrix DensityMatrix(const Parallel_Orbitals* _paraV, const int nspin, const std::vector>& kvec_d, - const int nk); + const int nk, + const bool is_DMK_row_major = false); /** * @brief Constructor of class DensityMatrix for gamma-only calculation, where kvector is not required @@ -96,7 +97,7 @@ class DensityMatrix * @param nspin number of spin of the density matrix, set by user according to global nspin * (usually {nspin_global -> nspin_dm} = {1->1, 2->2, 4->1}, but sometimes 2->1 like in LR-TDDFT) */ - DensityMatrix(const Parallel_Orbitals* _paraV, const int nspin); + DensityMatrix(const Parallel_Orbitals* _paraV, const int nspin, const bool is_DMK_row_major = false); /** * @brief initialize density matrix DMR from UnitCell @@ -211,6 +212,15 @@ class DensityMatrix const std::vector>& get_kvec_d() const { return this->_kvec_d; } + bool is_DMK_row_major() const { return this->_is_DMK_row_major; } + + int dmk_index(const int irow, const int icol) const + { + return this->_is_DMK_row_major + ? irow * this->_paraV->ncol + icol + : icol * this->_paraV->nrow + irow; + } + /** * @brief calculate density matrix DMR from dm(k) using blas::axpy * @param ik_in @@ -298,6 +308,8 @@ class DensityMatrix // std::vector _DMK; std::vector> _DMK; + bool _is_DMK_row_major = false; + /** * @brief K_Vectors object, which is used to get k-point information */ diff --git a/source/source_estate/module_dm/density_matrix_io.cpp b/source/source_estate/module_dm/density_matrix_io.cpp index d3f53aa241b..742c5e47e6d 100644 --- a/source/source_estate/module_dm/density_matrix_io.cpp +++ b/source/source_estate/module_dm/density_matrix_io.cpp @@ -213,14 +213,13 @@ void DensityMatrix::set_DMK_pointer(const int ik, TK* DMK_in) // set _DMK element template -void DensityMatrix::set_DMK(const int ispin, const int ik, const int i, const int j, const TK value) +void DensityMatrix::set_DMK(const int ispin, const int ik, const int irow, const int icol, const TK value) { #ifdef __DEBUG assert(ispin > 0 && ispin <= this->_nspin); assert(ik >= 0 && ik < this->_nk); #endif - // consider transpose col=>row - this->_DMK[ik + this->_nk * (ispin - 1)][i * this->_paraV->nrow + j] = value; + this->_DMK[ik + this->_nk * (ispin - 1)][this->dmk_index(irow, icol)] = value; } // set _DMK element @@ -236,13 +235,12 @@ void DensityMatrix::set_DMK_zero() // get a matrix element of density matrix dm(k) template -TK DensityMatrix::get_DMK(const int ispin, const int ik, const int i, const int j) const +TK DensityMatrix::get_DMK(const int ispin, const int ik, const int irow, const int icol) const { #ifdef __DEBUG assert(ispin > 0 && ispin <= this->_nspin); #endif - // consider transpose col=>row - return this->_DMK[ik + this->_nk * (ispin - 1)][i * this->_paraV->nrow + j]; + return this->_DMK[ik + this->_nk * (ispin - 1)][this->dmk_index(irow, icol)]; } // get _DMK nks, nrow, ncol @@ -353,11 +351,11 @@ void DensityMatrix::read_DMK(const std::string directory, const int ispi } // If file exist, read in data. // Finish reading the first part of density matrix. - for (int i = 0; i < this->_paraV->nrow; ++i) + for (int irow = 0; irow < this->_paraV->nrow; ++irow) { - for (int j = 0; j < this->_paraV->ncol; ++j) + for (int icol = 0; icol < this->_paraV->ncol; ++icol) { - ifs >> this->_DMK[ik + this->_nk * (ispin - 1)][i * this->_paraV->ncol + j]; + ifs >> this->_DMK[ik + this->_nk * (ispin - 1)][this->dmk_index(irow, icol)]; } } ifs.close(); @@ -386,15 +384,15 @@ void DensityMatrix::write_DMK(const std::string directory, const ofs << std::setprecision(3); ofs << std::scientific; - for (int i = 0; i < this->_paraV->nrow; ++i) + for (int irow = 0; irow < this->_paraV->nrow; ++irow) { - for (int j = 0; j < this->_paraV->ncol; ++j) + for (int icol = 0; icol < this->_paraV->ncol; ++icol) { - if (j % 8 == 0) + if (icol % 8 == 0) { ofs << "\n"; } - ofs << " " << this->_DMK[ik + this->_nk * (ispin - 1)][i * this->_paraV->ncol + j]; + ofs << " " << this->_DMK[ik + this->_nk * (ispin - 1)][this->dmk_index(irow, icol)]; } } @@ -423,15 +421,15 @@ void DensityMatrix, double>::write_DMK(const std::string di ofs << std::setprecision(3); ofs << std::scientific; - for (int i = 0; i < this->_paraV->nrow; ++i) + for (int irow = 0; irow < this->_paraV->nrow; ++irow) { - for (int j = 0; j < this->_paraV->ncol; ++j) + for (int icol = 0; icol < this->_paraV->ncol; ++icol) { - if (j % 8 == 0) + if (icol % 8 == 0) { ofs << "\n"; } - ofs << " " << this->_DMK[ik + this->_nk * (ispin - 1)][i * this->_paraV->ncol + j].real(); + ofs << " " << this->_DMK[ik + this->_nk * (ispin - 1)][this->dmk_index(irow, icol)].real(); } } @@ -443,4 +441,4 @@ template class DensityMatrix; // Gamma-Only case template class DensityMatrix, double>; // Multi-k case template class DensityMatrix, std::complex>; // For EXX in future -} // namespace elecstate \ No newline at end of file +} // namespace elecstate diff --git a/source/source_estate/module_dm/test/test_cal_dmk_psi.cpp b/source/source_estate/module_dm/test/test_cal_dmk_psi.cpp index 689f4f37920..a82b53c26e0 100644 --- a/source/source_estate/module_dm/test/test_cal_dmk_psi.cpp +++ b/source/source_estate/module_dm/test/test_cal_dmk_psi.cpp @@ -107,7 +107,7 @@ TEST_F(DMTest, cal_dmk_psi_nspin1) std::cout << "dim0: " << paraV->dim0 << " dim1:" << paraV->dim1 << std::endl; std::cout << "nrow: " << paraV->nrow << " ncol:" << paraV->ncol << std::endl; int nspin = 1; - elecstate::DensityMatrix DM(kv, paraV, nspin); + elecstate::DensityMatrix DM(paraV, nspin, kv->kvec_d, kv->get_nks()); // compare EXPECT_EQ(DM.get_DMK_nks(), kv->get_nks()); EXPECT_EQ(DM.get_DMK_nrow(), paraV->nrow); @@ -152,8 +152,7 @@ TEST_F(DMTest, cal_dmk_psi_nspin1) { for (int j = 0; j < paraV->ncol; j++) { - // std::cout << ptr[i*paraV->ncol+j] << " "; - EXPECT_EQ(ptr[i * paraV->ncol + j], is + ik * i + j); + EXPECT_EQ(ptr[DM.dmk_index(i, j)], is + ik * i + j); } } } diff --git a/source/source_estate/module_dm/test/test_dm_constructor.cpp b/source/source_estate/module_dm/test/test_dm_constructor.cpp index f82abf8676b..1f4fbec0e76 100644 --- a/source/source_estate/module_dm/test/test_dm_constructor.cpp +++ b/source/source_estate/module_dm/test/test_dm_constructor.cpp @@ -160,8 +160,7 @@ TEST_F(DMTest, DMConstructor_nspin1) { for (int j = 0; j < paraV->ncol; j++) { - // std::cout << ptr[i*paraV->ncol+j] << " "; - EXPECT_EQ(ptr[i * paraV->ncol + j], is + ik * i + j); + EXPECT_EQ(ptr[DM.dmk_index(i, j)], is + ik * i + j); } } } @@ -170,6 +169,40 @@ TEST_F(DMTest, DMConstructor_nspin1) delete kv; } +TEST_F(DMTest, DMKStorageOrder) +{ + int nspin = 1; + elecstate::DensityMatrix DM_col(paraV, nspin); + elecstate::DensityMatrix DM_row(paraV, nspin, true); + + EXPECT_FALSE(DM_col.is_DMK_row_major()); + EXPECT_TRUE(DM_row.is_DMK_row_major()); + + for (int i = 0; i < paraV->nrow; i++) + { + for (int j = 0; j < paraV->ncol; j++) + { + const double value = 100.0 * i + j; + DM_col.set_DMK(1, 0, i, j, value); + DM_row.set_DMK(1, 0, i, j, value); + } + } + + double* col_ptr = DM_col.get_DMK_pointer(0); + double* row_ptr = DM_row.get_DMK_pointer(0); + for (int i = 0; i < paraV->nrow; i++) + { + for (int j = 0; j < paraV->ncol; j++) + { + const double value = 100.0 * i + j; + EXPECT_EQ(DM_col.get_DMK(1, 0, i, j), value); + EXPECT_EQ(DM_row.get_DMK(1, 0, i, j), value); + EXPECT_EQ(col_ptr[j * paraV->nrow + i], value); + EXPECT_EQ(row_ptr[i * paraV->ncol + j], value); + } + } +} + TEST_F(DMTest, DMConstructor_nspin2) { // initalize a kvectors @@ -227,8 +260,7 @@ TEST_F(DMTest, DMConstructor_nspin2) { for (int j = 0; j < paraV->ncol; j++) { - // std::cout << ptr[i*paraV->ncol+j] << " "; - EXPECT_EQ(ptr[i * paraV->ncol + j], ik * i + j); + EXPECT_EQ(ptr[DM.dmk_index(i, j)], ik * i + j); } } } diff --git a/source/source_io/module_current/td_current_io.cpp b/source/source_io/module_current/td_current_io.cpp index b5a59cbe270..0565fb61d28 100644 --- a/source/source_io/module_current/td_current_io.cpp +++ b/source/source_io/module_current/td_current_io.cpp @@ -203,10 +203,10 @@ void ModuleIO::cal_tmp_DM_k(const UnitCell& ucell, { ModuleBase::TITLE("ModuleIO", "cal_tmp_DM_k"); ModuleBase::timer::start("ModuleIO", "cal_tmp_DM_k"); - int ld_hk = DM_real.get_paraV_pointer()->nrow; - int ld_hk2 = 2 * ld_hk; + const int ld_hk2 = 2 * DM_real.get_paraV_pointer()->nrow; // tmp for is - int ik_begin = DM_real.get_DMK_nks() / nspin * (is - 1); // jump nk for spin_down if nspin==2 + const int ik_begin = DM_real.get_DMK_nks() / nspin * (is - 1); // jump nk for spin_down if nspin==2 + const bool is_DMK_row_major = DM_real.is_DMK_row_major(); //sum spin up and down into up hamilt::HContainer* tmp_DMR_real = DM_real.get_DMR_vector()[0]; hamilt::HContainer* tmp_DMR_imag = DM_imag.get_DMR_vector()[0]; @@ -271,9 +271,10 @@ void ModuleIO::cal_tmp_DM_k(const UnitCell& ucell, std::complex* tmp_DMK_pointer = DM_real.get_DMK_pointer(ik + ik_begin); double* DMK_real_pointer = nullptr; double* DMK_imag_pointer = nullptr; - // jump DMK to fill DMR - // DMR is row-major, DMK is column-major - tmp_DMK_pointer += col_ap * DM_real.get_paraV_pointer()->nrow + row_ap; + // jump DMK to fill DMR; DMR is row-major. + tmp_DMK_pointer += is_DMK_row_major + ? row_ap * DM_real.get_paraV_pointer()->ncol + col_ap + : col_ap * DM_real.get_paraV_pointer()->nrow + row_ap; for (int mu = 0; mu < DM_real.get_paraV_pointer()->get_row_size(iat1); ++mu) { DMK_real_pointer = (double*)tmp_DMK_pointer; @@ -282,29 +283,29 @@ void ModuleIO::cal_tmp_DM_k(const UnitCell& ucell, BlasConnector::axpy(DM_real.get_paraV_pointer()->get_col_size(iat2), -kphase.imag(), DMK_imag_pointer, - ld_hk2, + is_DMK_row_major ? 2 : ld_hk2, tmp_DMR_real_pointer, 1); BlasConnector::axpy(DM_real.get_paraV_pointer()->get_col_size(iat2), kphase.real(), DMK_real_pointer, - ld_hk2, + is_DMK_row_major ? 2 : ld_hk2, tmp_DMR_real_pointer, 1); // calculate imag part BlasConnector::axpy(DM_imag.get_paraV_pointer()->get_col_size(iat2), kphase.imag(), DMK_real_pointer, - ld_hk2, + is_DMK_row_major ? 2 : ld_hk2, tmp_DMR_imag_pointer, 1); BlasConnector::axpy(DM_imag.get_paraV_pointer()->get_col_size(iat2), kphase.real(), DMK_imag_pointer, - ld_hk2, + is_DMK_row_major ? 2 : ld_hk2, tmp_DMR_imag_pointer, 1); - tmp_DMK_pointer += 1; + tmp_DMK_pointer += is_DMK_row_major ? DM_real.get_paraV_pointer()->ncol : 1; tmp_DMR_real_pointer += DM_real.get_paraV_pointer()->get_col_size(iat2); tmp_DMR_imag_pointer += DM_imag.get_paraV_pointer()->get_col_size(iat2); } @@ -337,18 +338,19 @@ void ModuleIO::cal_tmp_DM_k(const UnitCell& ucell, std::complex* tmp_DMK_pointer = DM_real.get_DMK_pointer(ik + ik_begin);; double* DMK_real_pointer = nullptr; double* DMK_imag_pointer = nullptr; - // jump DMK to fill DMR - // DMR is row-major, DMK is column-major - tmp_DMK_pointer += col_ap * DM_real.get_paraV_pointer()->nrow + row_ap; + // jump DMK to fill DMR; DMR is row-major. + tmp_DMK_pointer += is_DMK_row_major + ? row_ap * DM_real.get_paraV_pointer()->ncol + col_ap + : col_ap * DM_real.get_paraV_pointer()->nrow + row_ap; for (int mu = 0; mu < tmp_ap_real.get_row_size(); ++mu) { BlasConnector::axpy(tmp_ap_real.get_col_size(), kphase, tmp_DMK_pointer, - ld_hk, + is_DMK_row_major ? 1 : DM_real.get_paraV_pointer()->nrow, tmp_DMR_pointer, 1); - tmp_DMK_pointer += 1; + tmp_DMK_pointer += is_DMK_row_major ? DM_real.get_paraV_pointer()->ncol : 1; tmp_DMR_pointer += tmp_ap_real.get_col_size(); } } @@ -636,4 +638,3 @@ void ModuleIO::write_current>(const UnitCell& ucell, const Velocity_op>* cal_current, Record_adj& ra); #endif //__LCAO - diff --git a/source/source_lcao/module_lr/dm_trans/dmr_complex.cpp b/source/source_lcao/module_lr/dm_trans/dmr_complex.cpp index a973155c32c..291d8b1aac7 100644 --- a/source/source_lcao/module_lr/dm_trans/dmr_complex.cpp +++ b/source/source_lcao/module_lr/dm_trans/dmr_complex.cpp @@ -54,18 +54,18 @@ namespace elecstate std::complex* tmp_DMR_pointer = tmp_matrix->get_pointer(); const std::complex* tmp_DMK_pointer = this->_DMK[ik + ik_begin].data() - + col_ap * this->_paraV->nrow + row_ap; + + this->dmk_index(row_ap, col_ap); // jump DMK to fill DMR - // DMR is row-major, DMK is column-major + // DMR is row-major. for (int mu = 0; mu < this->_paraV->get_row_size(iat1); ++mu) { BlasConnector::axpy(this->_paraV->get_col_size(iat2), kphase, tmp_DMK_pointer, - this->_paraV->get_row_size(), + this->_is_DMK_row_major ? 1 : this->_paraV->nrow, tmp_DMR_pointer, 1); - tmp_DMK_pointer += 1; + tmp_DMK_pointer += this->_is_DMK_row_major ? this->_paraV->ncol : 1; tmp_DMR_pointer += this->_paraV->get_col_size(iat2); } } @@ -79,4 +79,4 @@ namespace elecstate ModuleBase::timer::end("DensityMatrix", "cal_DMR"); } // template class DensityMatrix, std::complex>; -} \ No newline at end of file +} From 2d70d5ad628483d964c57a69c0a34d6d1772bf33 Mon Sep 17 00:00:00 2001 From: linpz Date: Sun, 7 Jun 2026 06:57:12 +0800 Subject: [PATCH 2/2] fix psiMulPsi() --- source/source_estate/cal_dm.h | 4 ++-- source/source_estate/math_tools.h | 13 +++++++------ source/source_lcao/module_rdmft/rdmft_pot.cpp | 2 +- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/source/source_estate/cal_dm.h b/source/source_estate/cal_dm.h index ad38e369d02..cd831915399 100644 --- a/source/source_estate/cal_dm.h +++ b/source/source_estate/cal_dm.h @@ -56,7 +56,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, #ifdef __MPI psiMulPsiMpi(wg_wfc, wfc, dm[ik], ParaV->desc_wfc, ParaV->desc); #else - psiMulPsi(wg_wfc, wfc, dm[ik]); + psiMulPsi(wg_wfc, wfc, dm[ik], false); #endif } ModuleBase::timer::end("elecstate","cal_dm"); @@ -114,7 +114,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, #ifdef __MPI psiMulPsiMpi(wg_wfc, wfc, dm[ik], ParaV->desc_wfc, ParaV->desc); #else - psiMulPsi(wg_wfc, wfc, dm[ik]); + psiMulPsi(wg_wfc, wfc, dm[ik], false); #endif } diff --git a/source/source_estate/math_tools.h b/source/source_estate/math_tools.h index 049f9a4e8e9..0e8a72a3081 100644 --- a/source/source_estate/math_tools.h +++ b/source/source_estate/math_tools.h @@ -75,7 +75,7 @@ inline void psiMulPsiMpi(const psi::Psi>& psi1, } #else -inline void psiMulPsi(const psi::Psi& psi1, const psi::Psi& psi2, ModuleBase::matrix& dm_out) +inline void psiMulPsi(const psi::Psi& psi1, const psi::Psi& psi2, ModuleBase::matrix& dm_out, const bool is_DMK_row_major) { const double one_float = 1.0, zero_float = 0.0; const int one_int = 1; @@ -88,9 +88,9 @@ inline void psiMulPsi(const psi::Psi& psi1, const psi::Psi& psi2 &nlocal, &nbands, &one_float, - psi1.get_pointer(), + is_DMK_row_major ? psi2.get_pointer() : psi1.get_pointer(), &nlocal, - psi2.get_pointer(), + is_DMK_row_major ? psi1.get_pointer() : psi2.get_pointer(), &nlocal, &zero_float, dm_out.c, @@ -99,7 +99,8 @@ inline void psiMulPsi(const psi::Psi& psi1, const psi::Psi& psi2 inline void psiMulPsi(const psi::Psi>& psi1, const psi::Psi>& psi2, - ModuleBase::ComplexMatrix& dm_out) + ModuleBase::ComplexMatrix& dm_out, + const bool is_DMK_row_major) { const int one_int = 1; const char N_char = 'N', T_char = 'T'; @@ -112,9 +113,9 @@ inline void psiMulPsi(const psi::Psi>& psi1, &nlocal, &nbands, &one_complex, - psi1.get_pointer(), + is_DMK_row_major ? psi2.get_pointer() : psi1.get_pointer(), &nlocal, - psi2.get_pointer(), + is_DMK_row_major ? psi1.get_pointer() : psi2.get_pointer(), &nlocal, &zero_complex, dm_out.c, diff --git a/source/source_lcao/module_rdmft/rdmft_pot.cpp b/source/source_lcao/module_rdmft/rdmft_pot.cpp index 9c3d708fa16..f82738e2680 100644 --- a/source/source_lcao/module_rdmft/rdmft_pot.cpp +++ b/source/source_lcao/module_rdmft/rdmft_pot.cpp @@ -40,7 +40,7 @@ void RDMFT::get_DM_XC(std::vector< std::vector >& DM_XC) #ifdef __MPI elecstate::psiMulPsiMpi(wk_funEta_wfc, wfc, DM_Kpointer, ParaV->desc_wfc, ParaV->desc); #else - elecstate::psiMulPsi(wk_funEta_wfc, wfc, DM_Kpointer); + elecstate::psiMulPsi(wk_funEta_wfc, wfc, DM_Kpointer, false); #endif } }