Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions source/source_estate/cal_dm.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
}
}
if (ib_global >= wg.nc) { continue;
}
if (ib_global >= wg.nc) { continue; }
const double wg_local = wg(ik, ib_global);
double* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
Expand All @@ -57,7 +56,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
#ifdef __MPI
psiMulPsiMpi(wg_wfc, wfc, dm[ik], ParaV->desc_wfc, ParaV->desc);
#else
psiMulPsi(wg_wfc, wfc, dm[ik]);
psiMulPsi(wg_wfc, wfc, dm[ik], false);
#endif
}
ModuleBase::timer::end("elecstate","cal_dm");
Expand Down Expand Up @@ -105,8 +104,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
}
}
if (ib_global >= wg.nc) { continue;
}
if (ib_global >= wg.nc) { continue; }
const double wg_local = wg(ik, ib_global);
std::complex<double>* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
Expand All @@ -116,7 +114,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
#ifdef __MPI
psiMulPsiMpi(wg_wfc, wfc, dm[ik], ParaV->desc_wfc, ParaV->desc);
#else
psiMulPsi(wg_wfc, wfc, dm[ik]);
psiMulPsi(wg_wfc, wfc, dm[ik], false);
#endif
}

Expand Down
13 changes: 7 additions & 6 deletions source/source_estate/math_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ inline void psiMulPsiMpi(const psi::Psi<std::complex<double>>& psi1,
}

#else
inline void psiMulPsi(const psi::Psi<double>& psi1, const psi::Psi<double>& psi2, ModuleBase::matrix& dm_out)
inline void psiMulPsi(const psi::Psi<double>& psi1, const psi::Psi<double>& psi2, ModuleBase::matrix& dm_out, const bool is_DMK_row_major)
{
const double one_float = 1.0, zero_float = 0.0;
const int one_int = 1;
Expand All @@ -88,9 +88,9 @@ inline void psiMulPsi(const psi::Psi<double>& psi1, const psi::Psi<double>& psi2
&nlocal,
&nbands,
&one_float,
psi1.get_pointer(),
is_DMK_row_major ? psi2.get_pointer() : psi1.get_pointer(),
&nlocal,
psi2.get_pointer(),
is_DMK_row_major ? psi1.get_pointer() : psi2.get_pointer(),
&nlocal,
&zero_float,
dm_out.c,
Expand All @@ -99,7 +99,8 @@ inline void psiMulPsi(const psi::Psi<double>& psi1, const psi::Psi<double>& psi2

inline void psiMulPsi(const psi::Psi<std::complex<double>>& psi1,
const psi::Psi<std::complex<double>>& psi2,
ModuleBase::ComplexMatrix& dm_out)
ModuleBase::ComplexMatrix& dm_out,
const bool is_DMK_row_major)
{
const int one_int = 1;
const char N_char = 'N', T_char = 'T';
Expand All @@ -112,9 +113,9 @@ inline void psiMulPsi(const psi::Psi<std::complex<double>>& psi1,
&nlocal,
&nbands,
&one_complex,
psi1.get_pointer(),
is_DMK_row_major ? psi2.get_pointer() : psi1.get_pointer(),
&nlocal,
psi2.get_pointer(),
is_DMK_row_major ? psi1.get_pointer() : psi2.get_pointer(),
&nlocal,
&zero_complex,
dm_out.c,
Expand Down
21 changes: 12 additions & 9 deletions source/source_estate/module_dm/cal_dm_psi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV,

// C++: dm(iw1,iw2) = wfc(ib,iw1).T * wg_wfc(ib,iw2)
#ifdef __MPI
assert(!DM.is_DMK_row_major());
psiMulPsiMpi(wg_wfc, wfc, dmk_pointer, ParaV->desc_wfc, ParaV->desc);
#else
psiMulPsi(wg_wfc, wfc, dmk_pointer);
psiMulPsi(wg_wfc, wfc, dmk_pointer, DM.is_DMK_row_major());
#endif
}
ModuleBase::timer::end("elecstate", "cal_dm_psi");
Expand Down Expand Up @@ -135,14 +136,15 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV,

if (PARAM.inp.ks_solver == "cg_in_lcao")
{
psiMulPsi(wg_wfc, wfc, dmk_pointer);
psiMulPsi(wg_wfc, wfc, dmk_pointer, DM.is_DMK_row_major());
}
else
{
assert(!DM.is_DMK_row_major());
psiMulPsiMpi(wg_wfc, wfc, dmk_pointer, ParaV->desc_wfc, ParaV->desc);
}
#else
psiMulPsi(wg_wfc, wfc, dmk_pointer);
psiMulPsi(wg_wfc, wfc, dmk_pointer, DM.is_DMK_row_major());
#endif
}

Expand Down Expand Up @@ -222,7 +224,7 @@ void psiMulPsiMpi(const psi::Psi<std::complex<double>>& psi1,

#endif

void psiMulPsi(const psi::Psi<double>& psi1, const psi::Psi<double>& psi2, double* dm_out)
void psiMulPsi(const psi::Psi<double>& psi1, const psi::Psi<double>& psi2, double* dm_out, const bool is_DMK_row_major)
{
const double one_float = 1.0, zero_float = 0.0;
const int one_int = 1;
Expand All @@ -235,9 +237,9 @@ void psiMulPsi(const psi::Psi<double>& psi1, const psi::Psi<double>& psi2, doubl
nlocal,
nbands,
one_float,
psi1.get_pointer(),
is_DMK_row_major ? psi2.get_pointer() : psi1.get_pointer(),
nlocal,
psi2.get_pointer(),
is_DMK_row_major ? psi1.get_pointer() : psi2.get_pointer(),
nlocal,
zero_float,
dm_out,
Expand All @@ -246,7 +248,8 @@ void psiMulPsi(const psi::Psi<double>& psi1, const psi::Psi<double>& psi2, doubl

void psiMulPsi(const psi::Psi<std::complex<double>>& psi1,
const psi::Psi<std::complex<double>>& psi2,
std::complex<double>* dm_out)
std::complex<double>* dm_out,
const bool is_DMK_row_major)
{
const int one_int = 1;
const char N_char = 'N', T_char = 'T';
Expand All @@ -260,9 +263,9 @@ void psiMulPsi(const psi::Psi<std::complex<double>>& psi1,
nlocal,
nbands,
one_complex,
psi1.get_pointer(),
is_DMK_row_major ? psi2.get_pointer() : psi1.get_pointer(),
nlocal,
psi2.get_pointer(),
is_DMK_row_major ? psi1.get_pointer() : psi2.get_pointer(),
nlocal,
zero_complex,
dm_out,
Expand Down
8 changes: 6 additions & 2 deletions source/source_estate/module_dm/cal_dm_psi.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@ namespace elecstate
const int* desc_dm);

// for Gamma-Only case without MPI
void psiMulPsi(const psi::Psi<double>& psi1, const psi::Psi<double>& psi2, double* dm_out);
void psiMulPsi(const psi::Psi<double>& psi1,
const psi::Psi<double>& psi2,
double* dm_out,
const bool is_DMK_row_major);

// for multi-k case without MPI
void psiMulPsi(const psi::Psi<std::complex<double>>& psi1,
const psi::Psi<std::complex<double>>& psi2,
std::complex<double>* dm_out);
std::complex<double>* dm_out,
const bool is_DMK_row_major);
};
#endif
52 changes: 32 additions & 20 deletions source/source_estate/module_dm/density_matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,17 @@ DensityMatrix<TK, TR>::~DensityMatrix()
}

template <typename TK, typename TR>
DensityMatrix<TK, TR>::DensityMatrix(const Parallel_Orbitals* paraV_in, const int nspin, const std::vector<ModuleBase::Vector3<double>>& kvec_d, const int nk)
: _paraV(paraV_in), _nspin(nspin), _kvec_d(kvec_d), _nk((nk > 0 && nk <= _kvec_d.size()) ? nk : _kvec_d.size())
DensityMatrix<TK, TR>::DensityMatrix(
const Parallel_Orbitals* paraV_in,
const int nspin,
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
const int nk,
const bool is_DMK_row_major)
: _paraV(paraV_in),
_nspin(nspin),
_kvec_d(kvec_d),
_nk((nk > 0 && nk <= _kvec_d.size()) ? nk : _kvec_d.size()),
_is_DMK_row_major(is_DMK_row_major)
{
ModuleBase::TITLE("DensityMatrix", "resize_DMK");
const int nks = _nk * _nspin;
Expand All @@ -42,7 +51,15 @@ DensityMatrix<TK, TR>::DensityMatrix(const Parallel_Orbitals* paraV_in, const in
}

template <typename TK, typename TR>
DensityMatrix<TK, TR>::DensityMatrix(const Parallel_Orbitals* paraV_in, const int nspin) :_paraV(paraV_in), _nspin(nspin), _kvec_d({ ModuleBase::Vector3<double>(0,0,0) }), _nk(1)
DensityMatrix<TK, TR>::DensityMatrix(
const Parallel_Orbitals* paraV_in,
const int nspin,
const bool is_DMK_row_major)
: _paraV(paraV_in),
_nspin(nspin),
_kvec_d({ ModuleBase::Vector3<double>(0,0,0) }),
_nk(1),
_is_DMK_row_major(is_DMK_row_major)
{
ModuleBase::TITLE("DensityMatrix", "resize_gamma");
this->_DMK.resize(_nspin);
Expand All @@ -68,7 +85,6 @@ void DensityMatrix_Tools::cal_DMR(
assert(dmR_out.size()==dm._nspin && "DMR has not been initialized!");

ModuleBase::timer::start("DensityMatrix", "cal_DMR");
const int ld_hk = dm._paraV->nrow;
for (int is = 1; is <= dm._nspin; ++is)
{
const int ik_begin = dm._nk * (is - 1); // jump dm._nk for spin_down if nspin==2
Expand Down Expand Up @@ -125,13 +141,13 @@ void DensityMatrix_Tools::cal_DMR(
for(int ik = 0; ik < dm._nk; ++ik)
{
if(ik_in >= 0 && ik_in != ik) { continue; }
// copy column-major DMK to row-major DMK_mat_trans (for the purpose of computational efficiency)
// copy DMK to row-major DMK_mat_trans (for the purpose of computational efficiency)
const TK*const DMK_mat_ptr
= dm._DMK[ik + ik_begin].data()
+ col_ap * dm._paraV->nrow + row_ap;
+ dm.dmk_index(row_ap, col_ap);
for(int icol = 0; icol < col_size; ++icol) {
for(int irow = 0; irow < row_size; ++irow) {
DMK_mat_trans[irow * col_size + icol] = DMK_mat_ptr[icol * ld_hk + irow];
DMK_mat_trans[irow * col_size + icol] = DMK_mat_ptr[dm.dmk_index(irow, icol)];
}}

// if nspin != 4, fill DMR
Expand Down Expand Up @@ -223,7 +239,6 @@ void DensityMatrix_Tools::cal_DMR_td(
assert(dmR_out.size()==dm._nspin && "DMR has not been initialized!");

ModuleBase::timer::start("DensityMatrix", "cal_DMR_td");
const int ld_hk = dm._paraV->nrow;
for (int is = 1; is <= dm._nspin; ++is)
{
const int ik_begin = dm._nk * (is - 1); // jump dm._nk for spin_down if nspin==2
Expand Down Expand Up @@ -283,13 +298,13 @@ void DensityMatrix_Tools::cal_DMR_td(
for(int ik = 0; ik < dm._nk; ++ik)
{
if(ik_in >= 0 && ik_in != ik) { continue; }
// copy column-major DMK to row-major DMK_mat_trans (for the purpose of computational efficiency)
// copy DMK to row-major DMK_mat_trans (for the purpose of computational efficiency)
const TK*const DMK_mat_ptr
= dm._DMK[ik + ik_begin].data()
+ col_ap * dm._paraV->nrow + row_ap;
+ dm.dmk_index(row_ap, col_ap);
for(int icol = 0; icol < col_size; ++icol) {
for(int irow = 0; irow < row_size; ++irow) {
DMK_mat_trans[irow * col_size + icol] = DMK_mat_ptr[icol * ld_hk + irow];
DMK_mat_trans[irow * col_size + icol] = DMK_mat_ptr[dm.dmk_index(irow, icol)];
}}

// if nspin != 4, fill DMR
Expand Down Expand Up @@ -381,7 +396,6 @@ void DensityMatrix_Tools::cal_DMR_full(
ModuleBase::TITLE("DensityMatrix", "cal_DMR_full");

ModuleBase::timer::start("DensityMatrix", "cal_DMR_full");
const int ld_hk = dm._paraV->nrow;
hamilt::HContainer<TR_out>* target_DMR = dmR_out;
// set zero since this function is called in every scf step
target_DMR->set_zero();
Expand Down Expand Up @@ -434,13 +448,13 @@ void DensityMatrix_Tools::cal_DMR_full(
for(int ik = 0; ik < dm._nk; ++ik)
{
if(ik_in >= 0 && ik_in != ik) { continue; }
// copy column-major DMK to row-major DMK_mat_trans (for the purpose of computational efficiency)
// copy DMK to row-major DMK_mat_trans (for the purpose of computational efficiency)
const TK*const DMK_mat_ptr
= dm._DMK[ik].data()
+ col_ap * dm._paraV->nrow + row_ap;
+ dm.dmk_index(row_ap, col_ap);
for(int icol = 0; icol < col_size; ++icol) {
for(int irow = 0; irow < row_size; ++irow) {
DMK_mat_trans[irow * col_size + icol] = DMK_mat_ptr[icol * ld_hk + irow];
DMK_mat_trans[irow * col_size + icol] = DMK_mat_ptr[dm.dmk_index(irow, icol)];
}}

for(int iR = 0; iR < R_size; ++iR)
Expand Down Expand Up @@ -487,7 +501,6 @@ void DensityMatrix<double, double>::cal_DMR(const int ik_in)
assert(this->_DMR.size()==this->_nspin && "DMR has not been initialized!");

ModuleBase::timer::start("DensityMatrix", "cal_DMR");
const int ld_hk = this->_paraV->nrow;
for (int is = 1; is <= this->_nspin; ++is)
{
const int ik_begin = this->_nk * (is - 1); // jump this->_nk for spin_down if nspin==2
Expand Down Expand Up @@ -522,21 +535,20 @@ void DensityMatrix<double, double>::cal_DMR(const int ik_in)
#endif
// k index
constexpr TK kphase = 1;
// transpose DMK col=>row
const TK* DMK_mat_ptr
= this->_DMK[0 + ik_begin].data()
+ col_ap * this->_paraV->nrow + row_ap;
+ this->dmk_index(row_ap, col_ap);
// set DMR element
TR* target_DMR_ptr = target_mat->get_pointer();
for (int mu = 0; mu < row_size; ++mu)
{
BlasConnector::axpy(col_size,
kphase,
DMK_mat_ptr,
ld_hk,
this->_is_DMK_row_major ? 1 : this->_paraV->nrow,
target_DMR_ptr,
1);
DMK_mat_ptr += 1;
DMK_mat_ptr += this->_is_DMK_row_major ? this->_paraV->ncol : 1;
target_DMR_ptr += col_size;
}
}
Expand Down
16 changes: 14 additions & 2 deletions source/source_estate/module_dm/density_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,16 @@ class DensityMatrix
DensityMatrix(const Parallel_Orbitals* _paraV,
const int nspin,
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
const int nk);
const int nk,
const bool is_DMK_row_major = false);

/**
* @brief Constructor of class DensityMatrix for gamma-only calculation, where kvector is not required
* @param _paraV pointer of Parallel_Orbitals object
* @param nspin number of spin of the density matrix, set by user according to global nspin
* (usually {nspin_global -> nspin_dm} = {1->1, 2->2, 4->1}, but sometimes 2->1 like in LR-TDDFT)
*/
DensityMatrix(const Parallel_Orbitals* _paraV, const int nspin);
DensityMatrix(const Parallel_Orbitals* _paraV, const int nspin, const bool is_DMK_row_major = false);

/**
* @brief initialize density matrix DMR from UnitCell
Expand Down Expand Up @@ -211,6 +212,15 @@ class DensityMatrix

const std::vector<ModuleBase::Vector3<double>>& get_kvec_d() const { return this->_kvec_d; }

bool is_DMK_row_major() const { return this->_is_DMK_row_major; }

int dmk_index(const int irow, const int icol) const
{
return this->_is_DMK_row_major
? irow * this->_paraV->ncol + icol
: icol * this->_paraV->nrow + irow;
}

/**
* @brief calculate density matrix DMR from dm(k) using blas::axpy
* @param ik_in
Expand Down Expand Up @@ -298,6 +308,8 @@ class DensityMatrix
// std::vector<ModuleBase::ComplexMatrix> _DMK;
std::vector<std::vector<TK>> _DMK;

bool _is_DMK_row_major = false;

/**
* @brief K_Vectors object, which is used to get k-point information
*/
Expand Down
Loading
Loading