Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1013,11 +1013,13 @@ eval_exc_grad_onedft_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
}
}

// Now sort by workload
std::sort( tasks.begin(), tasks.end(), task_comparator );
// Normally task_comparator is used for sorting tasks, but for OneDFT
// we want to keep the tasks in their ordering by iParent (atom) to make sure model derivatives
// match up well with feature derivatives.
(void)task_comparator; // Unused

// Build concatenated eps_w array (eps_on_grid * weights) in task order after sorting
// This will be used for the weight derivative
// Build concatenated eps_w array (per-gridpoint exc energy density, eps_on_grid * weights) in
// task order. This will be used for the weight derivative
std::vector<double> eps_w_all;
eps_w_all.reserve(total_npts);
for (const auto& task : tasks) {
Expand Down
75 changes: 75 additions & 0 deletions tests/onedft_test.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,81 @@ TEST_CASE( "OneDFT", "[onedft]" ) {
}
}

#if defined(GAUXC_HAS_HOST) && defined(GAUXC_HAS_DEVICE)
void test_onedft_grad_host_device( std::string reference_file,
std::string onedft_model_path ) {

using matrix_type = Eigen::MatrixXd;
Molecule mol;
BasisSet<double> basis;
read_hdf5_record( mol, reference_file, "/MOLECULE" );
read_hdf5_record( basis, reference_file, "/BASIS" );

HighFive::File file( reference_file, HighFive::File::ReadOnly );
auto dset = file.getDataSet( "/DENSITY" );
auto dims = dset.getDimensions();
matrix_type P( dims[0], dims[1] );
dset.read( P.data() );

// We can only call OneDFT with UKS (two channels) for now, so we create a dummy Pz matrix to
// satisfy the interface.
matrix_type Ps = P;
matrix_type Pz = matrix_type::Zero( dims[0], dims[1] );

auto mg = MolGridFactory::create_default_molgrid( mol, PruningScheme::Unpruned,
BatchSize(512), RadialQuad::MuraKnowles, AtomicGridSizeDefault::UltraFineGrid );

functional_type func = functional_type( ExchCXX::Backend::builtin,
ExchCXX::Functional::PBE0, ExchCXX::Spin::Unpolarized );

OneDFTSettings onedft_settings;
onedft_settings.model = onedft_model_path;

#ifdef GAUXC_HAS_DEVICE
auto rt = DeviceRuntimeEnvironment( GAUXC_MPI_CODE(MPI_COMM_WORLD,) 0.9 );
#else
auto rt = RuntimeEnvironment( GAUXC_MPI_CODE(MPI_COMM_WORLD) );
#endif

auto eval_grad = [&]( ExecutionSpace ex ) {
LoadBalancerFactory lb_factory( ex, "Default" );
auto lb = lb_factory.get_instance( rt, mol, mg, basis );
MolecularWeightsFactory mw_factory( ex, "Default", MolecularWeightsSettings{} );
auto mw = mw_factory.get_instance();
mw.modify_weights( lb );
XCIntegratorFactory<matrix_type> integrator_factory( ex, "Replicated",
"Default", "Default", "Default" );
auto integrator = integrator_factory.get_instance( func, lb );
return integrator.eval_exc_grad_onedft( Ps, Pz, onedft_settings );
};

auto host_grad = eval_grad( ExecutionSpace::Host );
auto dev_grad = eval_grad( ExecutionSpace::Device );

const size_t n = 3 * mol.size();
REQUIRE( host_grad.size() == n );
REQUIRE( dev_grad.size() == n );

// We want a non-zero gradient to make sure we have something to compare against.
double host_squared_norm = 0.0;
for( size_t i = 0; i < n; ++i ) host_squared_norm += host_grad[i] * host_grad[i];
CHECK( std::sqrt( host_squared_norm ) > 1e-3 );

// The device gradient must match the host gradient component-wise.
for( size_t i = 0; i < n; ++i ) {
CHECK( dev_grad[i] == Approx( host_grad[i] ).margin( 1e-6 ) );
}
}

TEST_CASE( "OneDFT EXC Gradient", "[onedft][grad]" ) {
SECTION( " H2O2 / def2-tzvp / tpss.fun" ) {
test_onedft_grad_host_device(
GAUXC_REF_DATA_PATH "/h2o2_def2-tzvp.hdf5",
GAUXC_ONEDFT_MODEL_PATH "/tpss.fun" );
}
}
#endif

#include <gauxc/xc_integrator/replicated/impl.hpp>
// Include the OneDFT utility header for reorder helper functions
#include "../../src/xc_integrator/integrator_util/onedft_util.hpp"
Expand Down