Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion source/source_estate/elecstate_pw_cal_tau.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
castmem_var_d2h_op()(this->charge->kin_r[ii], this->kin_r[ii], this->charge->nrxx);
}
}
this->parallelK();
#ifdef __MPI
this->charge->kin_r_mpi();
#endif
ModuleBase::TITLE("ElecStatePW", "cal_tau");
Comment thread
sunliang98 marked this conversation as resolved.
}

Expand Down
6 changes: 6 additions & 0 deletions source/source_estate/module_charge/charge.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ class Charge
*/
void rho_mpi();

/**
* @brief Sum kin_r at different pools (k-point/band parallelism).
* Only used when GlobalV::KPAR * bndpar > 1
*/
void kin_r_mpi();

/**
* @brief Reduce among different pools
* If NPROC_IN_POOLs are all the same, use GlobalV::KP_WORLD
Expand Down
21 changes: 21 additions & 0 deletions source/source_estate/module_charge/charge_mpi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,25 @@ void Charge::rho_mpi()
ModuleBase::timer::end("Charge", "rho_mpi");
return;
}

void Charge::kin_r_mpi()
{
ModuleBase::TITLE("Charge", "kin_r_mpi");
if (GlobalV::KPAR * PARAM.inp.bndpar <= 1)
{
return;
}
ModuleBase::timer::start("Charge", "kin_r_mpi");

if (XC_Functional::get_ked_flag() || PARAM.inp.out_elf[0] > 0)
{
for (int is = 0; is < PARAM.inp.nspin; ++is)
{
reduce_diff_pools(this->kin_r[is]);
}
}
Comment thread
sunliang98 marked this conversation as resolved.

ModuleBase::timer::end("Charge", "kin_r_mpi");
return;
}
#endif
1 change: 1 addition & 0 deletions source/source_estate/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ AddTest(
../elecstate_pw_cal_tau.cpp
../elecstate.cpp
../occupy.cpp
../module_charge/charge_mpi.cpp
../../source_psi/psi.cpp
# ../../source_psi/kernels/psi_memory_op.cpp
../../source_base/module_device/memory_op.cpp
Expand Down
57 changes: 57 additions & 0 deletions source/source_estate/test_mpi/charge_mpi_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,63 @@ TEST_F(ChargeMpiTest, rho_mpi)
charge->rho_mpi();
}

TEST_F(ChargeMpiTest, kin_r_mpi)
{
if (GlobalV::NPROC >= 2 && GlobalV::NPROC % 2 == 0)
{
const bool ked_flag_old = XC_Functional::ked_flag;
XC_Functional::ked_flag = true;
PARAM.input.nspin = 1;
PARAM.input.bndpar = 1;
GlobalV::KPAR = 2;

Parallel_Global::divide_pools(GlobalV::NPROC,
GlobalV::MY_RANK,
PARAM.input.bndpar,
GlobalV::KPAR,
GlobalV::NPROC_IN_BNDGROUP,
GlobalV::RANK_IN_BPGROUP,
GlobalV::MY_BNDGROUP,
GlobalV::NPROC_IN_POOL,
GlobalV::RANK_IN_POOL,
GlobalV::MY_POOL);
ModulePW::PW_Basis* rhopw = new ModulePW::PW_Basis();
rhopw->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD);
rhopw->initgrids(lat0, latvec, 40);
rhopw->initparameters(false, 10);
rhopw->setuptransform();
charge->rhopw = rhopw;

const int nz = rhopw->nz;
const int nrxx = rhopw->nrxx;
const int nxy = rhopw->nxy;
const int nplane = rhopw->nplane;
charge->nrxx = nrxx;
charge->kin_r = new double*[1];
charge->kin_r[0] = new double[nrxx];

for (int ir = 0; ir < nxy; ++ir)
{
for (int iz = 0; iz < nplane; ++iz)
{
charge->kin_r[0][nplane * ir + iz]
= (rhopw->startz_current + iz + ir * nz) / double(nxy * nz);
}
}
const double refsum = sum_array(charge->kin_r[0], nrxx);

charge->init_chgmpi();
charge->kin_r_mpi();
const double sum = sum_array(charge->kin_r[0], nrxx);
EXPECT_EQ(sum, refsum * GlobalV::KPAR);

delete[] charge->kin_r[0];
delete[] charge->kin_r;
delete rhopw;
XC_Functional::ked_flag = ked_flag_old;
}
}

int main(int argc, char** argv)
{
MPI_Init(&argc, &argv);
Expand Down
Loading