diff --git a/source/source_estate/elecstate_pw_cal_tau.cpp b/source/source_estate/elecstate_pw_cal_tau.cpp index 628dd25aeff..a59990600a3 100644 --- a/source/source_estate/elecstate_pw_cal_tau.cpp +++ b/source/source_estate/elecstate_pw_cal_tau.cpp @@ -54,7 +54,9 @@ void ElecStatePW::cal_tau(const psi::Psi& psi) castmem_var_d2h_op()(this->charge->kin_r[ii], this->kin_r[ii], this->charge->nrxx); } } - this->parallelK(); +#ifdef __MPI + this->charge->kin_r_mpi(); +#endif ModuleBase::TITLE("ElecStatePW", "cal_tau"); } diff --git a/source/source_estate/module_charge/charge.h b/source/source_estate/module_charge/charge.h index 79a5f5ca9e5..c49e529fb04 100644 --- a/source/source_estate/module_charge/charge.h +++ b/source/source_estate/module_charge/charge.h @@ -136,6 +136,12 @@ class Charge */ void rho_mpi(); + /** + * @brief Sum kin_r at different pools (k-point/band parallelism). + * Only used when GlobalV::KPAR * bndpar > 1 + */ + void kin_r_mpi(); + /** * @brief Reduce among different pools * If NPROC_IN_POOLs are all the same, use GlobalV::KP_WORLD diff --git a/source/source_estate/module_charge/charge_mpi.cpp b/source/source_estate/module_charge/charge_mpi.cpp index e9c229897b8..442317ecb26 100644 --- a/source/source_estate/module_charge/charge_mpi.cpp +++ b/source/source_estate/module_charge/charge_mpi.cpp @@ -137,4 +137,25 @@ void Charge::rho_mpi() ModuleBase::timer::end("Charge", "rho_mpi"); return; } + +void Charge::kin_r_mpi() +{ + ModuleBase::TITLE("Charge", "kin_r_mpi"); + if (GlobalV::KPAR * PARAM.inp.bndpar <= 1) + { + return; + } + ModuleBase::timer::start("Charge", "kin_r_mpi"); + + if (XC_Functional::get_ked_flag() || PARAM.inp.out_elf[0] > 0) + { + for (int is = 0; is < PARAM.inp.nspin; ++is) + { + reduce_diff_pools(this->kin_r[is]); + } + } + + ModuleBase::timer::end("Charge", "kin_r_mpi"); + return; +} #endif diff --git a/source/source_estate/test/CMakeLists.txt b/source/source_estate/test/CMakeLists.txt index e27a241b9c4..aa69f4e29d8 100644 --- a/source/source_estate/test/CMakeLists.txt +++ b/source/source_estate/test/CMakeLists.txt @@ -56,6 +56,7 @@ AddTest( ../elecstate_pw_cal_tau.cpp ../elecstate.cpp ../occupy.cpp + ../module_charge/charge_mpi.cpp ../../source_psi/psi.cpp # ../../source_psi/kernels/psi_memory_op.cpp ../../source_base/module_device/memory_op.cpp diff --git a/source/source_estate/test_mpi/charge_mpi_test.cpp b/source/source_estate/test_mpi/charge_mpi_test.cpp index 0ddf8346908..e3214fbf655 100644 --- a/source/source_estate/test_mpi/charge_mpi_test.cpp +++ b/source/source_estate/test_mpi/charge_mpi_test.cpp @@ -201,6 +201,63 @@ TEST_F(ChargeMpiTest, rho_mpi) charge->rho_mpi(); } +TEST_F(ChargeMpiTest, kin_r_mpi) +{ + if (GlobalV::NPROC >= 2 && GlobalV::NPROC % 2 == 0) + { + const bool ked_flag_old = XC_Functional::ked_flag; + XC_Functional::ked_flag = true; + PARAM.input.nspin = 1; + PARAM.input.bndpar = 1; + GlobalV::KPAR = 2; + + Parallel_Global::divide_pools(GlobalV::NPROC, + GlobalV::MY_RANK, + PARAM.input.bndpar, + GlobalV::KPAR, + GlobalV::NPROC_IN_BNDGROUP, + GlobalV::RANK_IN_BPGROUP, + GlobalV::MY_BNDGROUP, + GlobalV::NPROC_IN_POOL, + GlobalV::RANK_IN_POOL, + GlobalV::MY_POOL); + ModulePW::PW_Basis* rhopw = new ModulePW::PW_Basis(); + rhopw->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD); + rhopw->initgrids(lat0, latvec, 40); + rhopw->initparameters(false, 10); + rhopw->setuptransform(); + charge->rhopw = rhopw; + + const int nz = rhopw->nz; + const int nrxx = rhopw->nrxx; + const int nxy = rhopw->nxy; + const int nplane = rhopw->nplane; + charge->nrxx = nrxx; + charge->kin_r = new double*[1]; + charge->kin_r[0] = new double[nrxx]; + + for (int ir = 0; ir < nxy; ++ir) + { + for (int iz = 0; iz < nplane; ++iz) + { + charge->kin_r[0][nplane * ir + iz] + = (rhopw->startz_current + iz + ir * nz) / double(nxy * nz); + } + } + const double refsum = sum_array(charge->kin_r[0], nrxx); + + charge->init_chgmpi(); + charge->kin_r_mpi(); + const double sum = sum_array(charge->kin_r[0], nrxx); + EXPECT_EQ(sum, refsum * GlobalV::KPAR); + + delete[] charge->kin_r[0]; + delete[] charge->kin_r; + delete rhopw; + XC_Functional::ked_flag = ked_flag_old; + } +} + int main(int argc, char** argv) { MPI_Init(&argc, &argv);