From 439327678cfad1e34bcd8bc406c5a3940db33d7e Mon Sep 17 00:00:00 2001 From: sunliang98 <1700011430@pku.edu.cn> Date: Mon, 4 May 2026 00:10:27 +0800 Subject: [PATCH 1/4] Fix: Fix ELF for kpar > 1 --- source/source_estate/elecstate_pw_cal_tau.cpp | 2 +- source/source_estate/module_charge/charge.h | 6 ++++++ .../module_charge/charge_mpi.cpp | 21 +++++++++++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/source/source_estate/elecstate_pw_cal_tau.cpp b/source/source_estate/elecstate_pw_cal_tau.cpp index 628dd25aeff..b7816528f8c 100644 --- a/source/source_estate/elecstate_pw_cal_tau.cpp +++ b/source/source_estate/elecstate_pw_cal_tau.cpp @@ -54,7 +54,7 @@ void ElecStatePW::cal_tau(const psi::Psi& psi) castmem_var_d2h_op()(this->charge->kin_r[ii], this->kin_r[ii], this->charge->nrxx); } } - this->parallelK(); + this->charge->kin_r_mpi(); ModuleBase::TITLE("ElecStatePW", "cal_tau"); } diff --git a/source/source_estate/module_charge/charge.h b/source/source_estate/module_charge/charge.h index 79a5f5ca9e5..c49e529fb04 100644 --- a/source/source_estate/module_charge/charge.h +++ b/source/source_estate/module_charge/charge.h @@ -136,6 +136,12 @@ class Charge */ void rho_mpi(); + /** + * @brief Sum kin_r at different pools (k-point/band parallelism). + * Only used when GlobalV::KPAR * bndpar > 1 + */ + void kin_r_mpi(); + /** * @brief Reduce among different pools * If NPROC_IN_POOLs are all the same, use GlobalV::KP_WORLD diff --git a/source/source_estate/module_charge/charge_mpi.cpp b/source/source_estate/module_charge/charge_mpi.cpp index e9c229897b8..442317ecb26 100644 --- a/source/source_estate/module_charge/charge_mpi.cpp +++ b/source/source_estate/module_charge/charge_mpi.cpp @@ -137,4 +137,25 @@ void Charge::rho_mpi() ModuleBase::timer::end("Charge", "rho_mpi"); return; } + +void Charge::kin_r_mpi() +{ + ModuleBase::TITLE("Charge", "kin_r_mpi"); + if (GlobalV::KPAR * PARAM.inp.bndpar <= 1) + { + return; + } + ModuleBase::timer::start("Charge", "kin_r_mpi"); + + if (XC_Functional::get_ked_flag() || PARAM.inp.out_elf[0] > 0) + { + for (int is = 0; is < PARAM.inp.nspin; ++is) + { + reduce_diff_pools(this->kin_r[is]); + } + } + + ModuleBase::timer::end("Charge", "kin_r_mpi"); + return; +} #endif From 5e8bb84d2516c9ad46df72234276ac017b6587ee Mon Sep 17 00:00:00 2001 From: sunliang98 <1700011430@pku.edu.cn> Date: Mon, 4 May 2026 10:53:44 +0800 Subject: [PATCH 2/4] Fix: add #ifdef __MPN --- source/source_estate/elecstate_pw_cal_tau.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/source/source_estate/elecstate_pw_cal_tau.cpp b/source/source_estate/elecstate_pw_cal_tau.cpp index b7816528f8c..a59990600a3 100644 --- a/source/source_estate/elecstate_pw_cal_tau.cpp +++ b/source/source_estate/elecstate_pw_cal_tau.cpp @@ -54,7 +54,9 @@ void ElecStatePW::cal_tau(const psi::Psi& psi) castmem_var_d2h_op()(this->charge->kin_r[ii], this->kin_r[ii], this->charge->nrxx); } } +#ifdef __MPI this->charge->kin_r_mpi(); +#endif ModuleBase::TITLE("ElecStatePW", "cal_tau"); } From 0aae093426aa44e94debb2ea559d62690fc22774 Mon Sep 17 00:00:00 2001 From: sunliang98 <1700011430@pku.edu.cn> Date: Mon, 4 May 2026 10:54:02 +0800 Subject: [PATCH 3/4] Test: Update source/source_estate/test/CMakeLists.txt --- source/source_estate/test/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/source/source_estate/test/CMakeLists.txt b/source/source_estate/test/CMakeLists.txt index e27a241b9c4..aa69f4e29d8 100644 --- a/source/source_estate/test/CMakeLists.txt +++ b/source/source_estate/test/CMakeLists.txt @@ -56,6 +56,7 @@ AddTest( ../elecstate_pw_cal_tau.cpp ../elecstate.cpp ../occupy.cpp + ../module_charge/charge_mpi.cpp ../../source_psi/psi.cpp # ../../source_psi/kernels/psi_memory_op.cpp ../../source_base/module_device/memory_op.cpp From 40e11e4b06eabd8230201f0271bc33883ca40fb0 Mon Sep 17 00:00:00 2001 From: sunliang98 <1700011430@pku.edu.cn> Date: Mon, 4 May 2026 10:54:22 +0800 Subject: [PATCH 4/4] Test: Add a unit test for kin_r_mpi --- .../test_mpi/charge_mpi_test.cpp | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/source/source_estate/test_mpi/charge_mpi_test.cpp b/source/source_estate/test_mpi/charge_mpi_test.cpp index 0ddf8346908..e3214fbf655 100644 --- a/source/source_estate/test_mpi/charge_mpi_test.cpp +++ b/source/source_estate/test_mpi/charge_mpi_test.cpp @@ -201,6 +201,63 @@ TEST_F(ChargeMpiTest, rho_mpi) charge->rho_mpi(); } +TEST_F(ChargeMpiTest, kin_r_mpi) +{ + if (GlobalV::NPROC >= 2 && GlobalV::NPROC % 2 == 0) + { + const bool ked_flag_old = XC_Functional::ked_flag; + XC_Functional::ked_flag = true; + PARAM.input.nspin = 1; + PARAM.input.bndpar = 1; + GlobalV::KPAR = 2; + + Parallel_Global::divide_pools(GlobalV::NPROC, + GlobalV::MY_RANK, + PARAM.input.bndpar, + GlobalV::KPAR, + GlobalV::NPROC_IN_BNDGROUP, + GlobalV::RANK_IN_BPGROUP, + GlobalV::MY_BNDGROUP, + GlobalV::NPROC_IN_POOL, + GlobalV::RANK_IN_POOL, + GlobalV::MY_POOL); + ModulePW::PW_Basis* rhopw = new ModulePW::PW_Basis(); + rhopw->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD); + rhopw->initgrids(lat0, latvec, 40); + rhopw->initparameters(false, 10); + rhopw->setuptransform(); + charge->rhopw = rhopw; + + const int nz = rhopw->nz; + const int nrxx = rhopw->nrxx; + const int nxy = rhopw->nxy; + const int nplane = rhopw->nplane; + charge->nrxx = nrxx; + charge->kin_r = new double*[1]; + charge->kin_r[0] = new double[nrxx]; + + for (int ir = 0; ir < nxy; ++ir) + { + for (int iz = 0; iz < nplane; ++iz) + { + charge->kin_r[0][nplane * ir + iz] + = (rhopw->startz_current + iz + ir * nz) / double(nxy * nz); + } + } + const double refsum = sum_array(charge->kin_r[0], nrxx); + + charge->init_chgmpi(); + charge->kin_r_mpi(); + const double sum = sum_array(charge->kin_r[0], nrxx); + EXPECT_EQ(sum, refsum * GlobalV::KPAR); + + delete[] charge->kin_r[0]; + delete[] charge->kin_r; + delete rhopw; + XC_Functional::ked_flag = ked_flag_old; + } +} + int main(int argc, char** argv) { MPI_Init(&argc, &argv);