Batched PW EXX optimization#7469
Draft
zhubonan wants to merge 17 commits into
Draft
Conversation
# Conflicts: # source/source_basis/module_pw/pw_transform.cpp
Add the exx_full_q_cache input keyword and default-on explicit full-q reciprocal cache for symmetry-reduced PW EXX. Route q-state loads through cache-aware helpers, keep KPAR owner-local cache construction with broadcasts, and make cache rebuilds lazy across temporary psi swaps. Record MgO64 GPU cache-on/off timing validation for PR notes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Reminder
Linked Issue
Fix #...
Unit Tests and/or Case Tests for my changes
All ABACUS runtime validations below used
OMP_NUM_THREADS=1unless otherwise noted.Build/test checks:
cmake --build /tmp/abacus-batch-exx-pr-cpu-mpi --target abacus_pw_para -j2 cmake --build /tmp/abacus-batch-exx-pr-p0p1-gpu-build --target abacus_pw_gpu -j2 ctest --test-dir build-codex-batch-exx-pr-test -R 'MODULE_PW_basis_pw_k_serial$'CUDA CI-style smoke after the latest FFT metadata fix:
OMP_NUM_THREADS=12 bash ../integrate/Autotest.sh \ -n 2 \ -a /tmp/abacus-ci-full-cuda-noelpa-run-build/abacus_basic_gpu \ -f CASES_GPU.txt \ -r '^scf_bpcg$'Result: passed.
Additional local GPU suites run with
mpirun -np 2equivalent Autotest settings:tests/11_PW_GPUtests/12_NAO_Gamma_GPUtests/13_NAO_multik_GPUtests/15_rtTDDFT_GPUtests/16_SDFT_GPUValidation Inputs
Si8 CPU stress reduced-EXX-grid test:
Q-tile variant:
K-point grid:
Si8 structure: conventional 8-atom Si cell,
LATTICE_CONSTANT 1.889766, lattice vectors5.43090 0 0,0 5.43090 0,0 0 5.43090, with the 8 diamond-cubic fractional positions.GPU timing tests used RTX 5090,
precision=single,K_POINTS Gamma 2 2 2, and PW HSE. The patched develop baseline isofficial/developplus only the private EXX FFT precision fix needed to avoid the known CUDAmemory_op.cufailure in the legacy GPU EXX path.Si8 GPU inputs:
MgO64 GPU input summary:
Q-tile GPU variants additionally used:
Key Timing and Correctness Results
CPU stress, Si8,
mpirun -np 2:ecutexx=4024 24 24,npw=2312-1718.7220983469073872-40.3665181109-240.63656430.23 s3.85 secutexx=4024 24 24,npw=2312-1718.7220983469057956-40.3665181110-240.63656422.36 s1.52 secutexx=10036 36 36,npw=9139-1720.9038698769220446-40.3594664072-239.15867899.54 s11.93 secutexx=10036 36 36,npw=9139-1720.9038698769215898-40.3594664072-239.15867883.04 s5.21 sGPU timing summary:
2x2x2,ecutwfc=50,nbands=320,symmetry=1,exxace=12298 sconstruct_ace 2029.78 s-61302.76778492656 eV,E_exx=-1963.6735019729 eVexx_batch_fft_size=8,exx_full_q_cache=1343.94 sconstruct_ace 307.12 s-61302.76734772023 eV,E_exx=-1963.6731245711 eV2x2x2,ecutwfc=50,nbands=32,symmetry=-1,exxace=118.92 sconstruct_ace 15.10 s-857.2818154074363 eV,E_exx=-45.1454174940 eVexx_batch_fft_size=88.13 sconstruct_ace 4.26 s-857.2818340449073 eV,E_exx=-45.1454128819 eV2x2x2,ecutwfc=50,nbands=32,symmetry=-1,exxace=0729.15 sact_op 592.99 s-857.2821719114121 eV,E_exx=-45.1453093445 eVexx_batch_fft_size=835.84 sact_op_batch 23.61 s;cal_exx_energy_batch 7.98 s-857.2819593448013 eV,E_exx=-45.1455110661 eVGPU q-tile checks on the current branch:
9.60 sconstruct_ace 4.59 s;act_op_batch 4.41 s-857.2818340449073 eV,E_exx=-45.1454128819 eV, gap1.2563273446 eVq=4,band=87.89 sconstruct_ace 3.54 s;act_op_qtile 3.45 s+2.34e-5 eV,E_exxdelta-6.59e-7 eV36.05 sact_op_batch 23.42 s;cal_exx_energy_batch 7.86 s-857.2819593448013 eV,E_exx=-45.1455110661 eV, gap1.2562612512 eVq=4,band=827.66 sact_op_qtile 17.32 s;cal_exx_energy_qtile 6.58 s-2.87e-6 eV,E_exxdelta+4.07e-7 eV345.63 sconstruct_ace 313.24 s;build_full_q_cache 4.36 s-61302.76701807489 eV,E_exx=-1963.6730981641 eV, gap6.3900535284 eVq=2,band=8263.47 sconstruct_ace 231.30 s;q_tile_pair 21.33 s;build_full_q_cache 4.16 s-3.65e-4 eV,E_exxdelta+7.11e-5 eVq=2,band=8262.15 sconstruct_ace 229.93 s;q_tile_pair 21.40 s-4.78e-4 eV,E_exxdelta-3.74e-4 eVKey observations:
6.68xfaster in total wall time and6.61xfaster in ACE construction than patched develop.2.33xfaster in total wall time and3.54xfaster in ACE construction than patched develop.20.35xfaster in total wall time and25.12xfaster in Hamiltonian EXX apply than patched develop.24%versus the corresponding no-q-tile branch path.3499.14 MBfor MgO64 cache-on and0 MBfor cache-off.What's changed?
This PR ports and completes the batched/q-tile PW EXX implementation for the
developbranch.Main behavior changes:
exx_batch_fft_size=8.exx_separate_loop=1; no-ACE KPAR is blocked for now.ecutexx > 0with separate EXX reciprocal and real-space FFT grids.ecutexx=0keeps the previous effectiveecutrhobehavior.poolnproc > 1redistribution between wavefunction and EXX grids.poolnproc > 1, because GPU PW FFT does not support intra-pool MPI distribution and would be very slow/unsupported.(k,q)potentials.exx_full_q_cache, default1, for symmetry-reduced PW EXX. This materializes explicit full-q reciprocal wavefunctions to avoid repeated symmetry remaps. Users can setexx_full_q_cache 0for the lower-memory reduced-q remap-on-demand path.Q-tile layout
Compared with
develop, which applies EXX one target(k,n)and one source(q,m)state at a time, this PR groups real-space EXX work into tiles:[n_local][ir], controlled byexx_band_tile_size;[q_local][m_local][ir], controlled byexx_q_tile_size * exx_band_tile_size;[q_local][m_local];Parallel_Common::bcast_dev.For symmetry-reduced runs,
exx_full_q_cache 1keeps the reduced k-point SCF problem but stores explicit full-q EXX wavefunctions. This trades memory for less repeated symmetry rotation/remap work.exx_full_q_cache 0keeps the memory-saving reduced-q path.Batched FFT implementation
This PR adds batch-aware FFT setup to
PW_BasisandPW_Basis_Kthroughsetuptransform(batch_fft_size), while keeping the defaultsetuptransform()behavior equivalent to the existing single-transform path.The new transform helpers accept contiguous batches of reciprocal or real-space states:
PW_Basis::recip_to_real_batch/PW_Basis::real_to_recip_batchfor charge and density-like grids;PW_Basis_K::recip_to_real_batch/PW_Basis_K::real_to_recip_batchfor k-dependent wavefunction grids.PW EXX uses these helpers to transform source wavefunction tiles, density products, and EXX energy batches in one batched FFT call instead of launching one FFT per band pair. Batch FFT setup is currently scoped to EXX-owned bases:
rhopw_dev,wfcpw_exx, andwfcpw_exx_fullq. Standard PW DFT bases still use the default setup path in this PR. The infrastructure may be reused later for standard PW DFT calculations, but that should be a separate benchmarked follow-up.Any changes of core modules? (ignore if not applicable)
Core PW modules are changed.
OperatorEXXPWand related PW EXX kernels now support batched source/state transforms, q-tile apply/energy, fullecutexxgrid honor, full-q cache loads, and ACE KPAR q-tile communication.PW_BasisandPW_Basis_Kgain batch transform setup and batch reciprocal/real transform helpers. The default non-batch call path remains available.exx_full_q_cache,exx_batch_fft_size,exx_use_q_tile,exx_band_tile_size,exx_q_tile_size, andecutexxdocumentation/echo.