From 2f235dad421262b08948ebeb900e1d7ccba0755b Mon Sep 17 00:00:00 2001 From: Thijs Vogels Date: Thu, 25 Jun 2026 08:16:32 +0000 Subject: [PATCH] Fix device OneDFT gradient The OneDFT GPU gradients gave incorrect results due to a mismatch in task ordering. GauXC splits grid points into batches (tasks). The OneDFT code path sorts these by atom, because the neural network treats atomic grids as units that belong together. When computing nuclear gradients, we use PyTorch to compute gradients w.r.t. model input features (density, kinetic energy density, ...) and combine them with gradients of those input features w.r.t. the coordinates. When those are combined, the ordering still needs to match up. So this PR restores the correct ordering and adds a test comparing host gradients to device gradients. This test failed before this PR. --- ...replicated_xc_device_integrator_onedft.hpp | 10 ++- tests/onedft_test.cxx | 75 +++++++++++++++++++ 2 files changed, 81 insertions(+), 4 deletions(-) diff --git a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_onedft.hpp b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_onedft.hpp index 387f0bd4..8b7c1277 100644 --- a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_onedft.hpp +++ b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_onedft.hpp @@ -1013,11 +1013,13 @@ eval_exc_grad_onedft_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps, } } - // Now sort by workload - std::sort( tasks.begin(), tasks.end(), task_comparator ); + // Normally task_comparator is used for sorting tasks, but for OneDFT + // we want to keep the tasks in their ordering by iParent (atom) to make sure model derivatives + // match up well with feature derivatives. 
+ (void)task_comparator; // Unused: tasks deliberately stay in iParent (atom) order - // Build concatenated eps_w array (eps_on_grid * weights) in task order after sorting - // This will be used for the weight derivative + // Build concatenated eps_w array (per-gridpoint exc energy density, eps_on_grid * weights) in + // task order. This will be used for the weight derivative std::vector eps_w_all; eps_w_all.reserve(total_npts); for (const auto& task : tasks) { diff --git a/tests/onedft_test.cxx b/tests/onedft_test.cxx index f211297f..df18ff4e 100644 --- a/tests/onedft_test.cxx +++ b/tests/onedft_test.cxx @@ -151,6 +151,81 @@ TEST_CASE( "OneDFT", "[onedft]" ) { } } +#if defined(GAUXC_HAS_HOST) && defined(GAUXC_HAS_DEVICE) +void test_onedft_grad_host_device( std::string reference_file, + std::string onedft_model_path ) { + + using matrix_type = Eigen::MatrixXd; + Molecule mol; + BasisSet basis; + read_hdf5_record( mol, reference_file, "/MOLECULE" ); + read_hdf5_record( basis, reference_file, "/BASIS" ); + + HighFive::File file( reference_file, HighFive::File::ReadOnly ); + auto dset = file.getDataSet( "/DENSITY" ); + auto dims = dset.getDimensions(); + matrix_type P( dims[0], dims[1] ); + dset.read( P.data() ); + + // We can only call OneDFT with UKS (two channels) for now, so we create a dummy (all-zero) Pz matrix to + // satisfy the interface. 
+ matrix_type Ps = P; + matrix_type Pz = matrix_type::Zero( dims[0], dims[1] ); + + auto mg = MolGridFactory::create_default_molgrid( mol, PruningScheme::Unpruned, + BatchSize(512), RadialQuad::MuraKnowles, AtomicGridSizeDefault::UltraFineGrid ); + + functional_type func = functional_type( ExchCXX::Backend::builtin, + ExchCXX::Functional::PBE0, ExchCXX::Spin::Unpolarized ); + + OneDFTSettings onedft_settings; + onedft_settings.model = onedft_model_path; + +#ifdef GAUXC_HAS_DEVICE + auto rt = DeviceRuntimeEnvironment( GAUXC_MPI_CODE(MPI_COMM_WORLD,) 0.9 ); +#else + auto rt = RuntimeEnvironment( GAUXC_MPI_CODE(MPI_COMM_WORLD) ); +#endif + + auto eval_grad = [&]( ExecutionSpace ex ) { + LoadBalancerFactory lb_factory( ex, "Default" ); + auto lb = lb_factory.get_instance( rt, mol, mg, basis ); + MolecularWeightsFactory mw_factory( ex, "Default", MolecularWeightsSettings{} ); + auto mw = mw_factory.get_instance(); + mw.modify_weights( lb ); + XCIntegratorFactory integrator_factory( ex, "Replicated", + "Default", "Default", "Default" ); + auto integrator = integrator_factory.get_instance( func, lb ); + return integrator.eval_exc_grad_onedft( Ps, Pz, onedft_settings ); + }; + + auto host_grad = eval_grad( ExecutionSpace::Host ); + auto dev_grad = eval_grad( ExecutionSpace::Device ); + + const size_t n = 3 * mol.size(); + REQUIRE( host_grad.size() == n ); + REQUIRE( dev_grad.size() == n ); + + // We want a non-zero host gradient (norm > 1e-3) so the comparison below is not trivially satisfied. + double host_squared_norm = 0.0; + for( size_t i = 0; i < n; ++i ) host_squared_norm += host_grad[i] * host_grad[i]; + CHECK( std::sqrt( host_squared_norm ) > 1e-3 ); + + // The device gradient must match the host gradient component-wise (Approx margin 1e-6). 
+ for( size_t i = 0; i < n; ++i ) { + CHECK( dev_grad[i] == Approx( host_grad[i] ).margin( 1e-6 ) ); + } +} + +TEST_CASE( "OneDFT EXC Gradient", "[onedft][grad]" ) { + SECTION( " H2O2 / def2-tzvp / tpss.fun" ) { + test_onedft_grad_host_device( + GAUXC_REF_DATA_PATH "/h2o2_def2-tzvp.hdf5", + GAUXC_ONEDFT_MODEL_PATH "/tpss.fun" ); + } +} +#endif + #include // Include the OneDFT utility header for reorder helper functions #include "../../src/xc_integrator/integrator_util/onedft_util.hpp"