diff --git a/cmake/gmxManageMetatomic.cmake b/cmake/gmxManageMetatomic.cmake index 49e1bb2a2b..e4a60c663f 100644 --- a/cmake/gmxManageMetatomic.cmake +++ b/cmake/gmxManageMetatomic.cmake @@ -111,9 +111,6 @@ if(NOT GMX_METATOMIC STREQUAL "OFF") set(METATOMIC_TORCH_VERSION "0.1.7") set(METATOMIC_TORCH_SHA256 "726f5711b70c4b8cc80d9bc6c3ce6f3449f31d20acc644ab68dab083aa4ea572") - set(VESIN_VERSION "0.4.1") - set(VESIN_GIT_TAG "87dcad999fec47b29ab21be9662ef283edc7530b") - set(DOWNLOAD_METATENSOR_DEFAULT ON) find_package(metatensor_torch ${METATENSOR_TORCH_VERSION} QUIET) if (metatensor_torch_FOUND) @@ -168,23 +165,7 @@ if(NOT GMX_METATOMIC STREQUAL "OFF") find_package(metatomic_torch REQUIRED ${METATOMIC_TORCH_VERSION}) endif() - # always fetch vesin - FetchContent_Declare( - vesin - GIT_REPOSITORY https://github.com/Luthaf/vesin.git - GIT_TAG ${VESIN_GIT_TAG} - ) - - FetchContent_MakeAvailable(vesin) - install(TARGETS vesin EXPORT libgromacs - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} - ) - list(APPEND GMX_COMMON_LIBRARIES - vesin metatensor metatomic_torch metatensor_torch diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index af5b8af9a1..5fcc5a909d 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -34,7 +34,7 @@ /*! \internal \file * \brief * Implements the Metatomic Force Provider class with proper domain decomposition support. - * + * * \author Metatensor developers * \ingroup module_applied_forces */ @@ -43,130 +43,114 @@ #include "metatomic_forceprovider.h" #include -#include #include "gromacs/domdec/localatomset.h" #include "gromacs/mdlib/broadcaststructs.h" +#include "gromacs/mdrunutility/mdmodulesnotifiers.h" #include "gromacs/mdtypes/enerdata.h" #include "gromacs/mdtypes/forceoutput.h" +#include "gromacs/pbcutil/ishift.h" #include "gromacs/utility/arrayref.h" #include "gromacs/utility/exceptions.h" #include "gromacs/utility/logger.h" #include "gromacs/utility/mpicomm.h" +#include "gromacs/utility/stringutil.h" #ifdef DIM # undef DIM #endif -#include #include #include +#if GMX_GPU_CUDA || (GMX_SYCL_ACPP && GMX_ACPP_HAVE_CUDA_TARGET) +# include +#endif + -static torch::Tensor preparePbcType(PbcType* pbcType) +namespace gmx { - torch::Tensor pbcTensor = - torch::tensor({ true, true, true }, torch::TensorOptions().dtype(torch::kBool)); - if (*pbcType == PbcType::XY) - { - pbcTensor[2] = false; - } - else if (*pbcType != PbcType::Xyz) + +static std::optional indexOf(ArrayRef vec, const int32_t val) +{ + auto it = std::find(vec.begin(), vec.end(), val); + if (it == vec.end()) { - GMX_THROW(gmx::InconsistentInputError( - "Option use_pbc was set to true, but PBC type is not supported.")); + return std::nullopt; } - return pbcTensor; + return std::distance(vec.begin(), it); } -static metatensor_torch::TensorBlock computeNeighbors(metatomic_torch::NeighborListOptions request, - long n_atoms, - const float* positions, - const matrix box, - bool periodic, - torch::Device device, - torch::ScalarType dtype) +static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) { - auto cutoff = request->engine_cutoff("nm"); + auto options = torch::TensorOptions().dtype(torch::kBool).device(device); - VesinOptions options; - options.cutoff = cutoff; - options.full = request->full_list(); - options.return_shifts = true; - options.return_distances = false; - options.return_vectors = true; - - VesinNeighborList* vesin_neighbor_list = new VesinNeighborList(); - - double double_box[3][3]; - for (int i = 0; i < 3; i++) + if (*pbcType == PbcType::XY) { - for (int j = 0; j < 3; j++) - { - double_box[i][j] = static_cast(box[i][j]); - } + return torch::tensor({ true, true, false }, options); } - - const size_t total_elements = static_cast(n_atoms) * 3; - std::vector double_positions(total_elements); - - for (size_t i = 0; i < total_elements; i++) + else if (*pbcType == PbcType::No) { - double_positions[i] = static_cast(positions[i]); + return torch::tensor({ false, false, false }, options); } - const double* positions_ptr = double_positions.data(); - - VesinDevice cpu{ VesinCPU, 0 }; - const char* error_message = nullptr; - int status = vesin_neighbors(reinterpret_cast(positions_ptr), - static_cast(n_atoms), - double_box, - &periodic, - cpu, - options, - vesin_neighbor_list, - &error_message); - - if (status != EXIT_SUCCESS) + else if (*pbcType != PbcType::Xyz) { - std::string err_str = "vesin_neighbors failed: "; - if (error_message) - { - err_str += error_message; - } - delete vesin_neighbor_list; - GMX_THROW(gmx::APIError(err_str)); + GMX_THROW(InconsistentInputError("PBC type not supported.")); } + return torch::tensor({ true, true, true }, options); +} + +static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef pairlist, + ArrayRef shiftVectors, + ArrayRef cellShifts, + ArrayRef positions, + torch::Device device, + torch::ScalarType dtype) +{ + const int64_t n_pairs = static_cast(pairlist.size() / 2); + + auto cpu_int_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); + auto cpu_float_options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU); - auto n_pairs = static_cast(vesin_neighbor_list->length); + auto pair_samples_values = torch::zeros({ n_pairs, 5 }, cpu_int_options); + auto pair_samples_ptr = pair_samples_values.accessor(); + + // Full interatomic vectors (rj - ri + shift) + auto vectors_cpu = torch::zeros({ n_pairs, 3, 1 }, cpu_float_options); + auto vectors_accessor = vectors_cpu.accessor(); - auto pair_samples_values = torch::empty({ n_pairs, 5 }, torch::TensorOptions().dtype(torch::kInt32)); - auto pair_samples_ptr = pair_samples_values.accessor(); for (int64_t i = 0; i < n_pairs; i++) { - pair_samples_ptr[i][0] = static_cast(vesin_neighbor_list->pairs[i][0]); - pair_samples_ptr[i][1] = static_cast(vesin_neighbor_list->pairs[i][1]); - pair_samples_ptr[i][2] = vesin_neighbor_list->shifts[i][0]; - pair_samples_ptr[i][3] = vesin_neighbor_list->shifts[i][1]; - pair_samples_ptr[i][4] = vesin_neighbor_list->shifts[i][2]; + int32_t atom_i = pairlist[2 * i]; + int32_t atom_j = pairlist[2 * i + 1]; + + // Access IVec elements (integers) + pair_samples_ptr[i][0] = static_cast(atom_i); + pair_samples_ptr[i][1] = static_cast(atom_j); + pair_samples_ptr[i][2] = cellShifts[i][0]; + pair_samples_ptr[i][3] = cellShifts[i][1]; + pair_samples_ptr[i][4] = cellShifts[i][2]; + + // Calculate r_ij = r_j - r_i + shift + double r_ij_x = + static_cast(positions[atom_j][0] - positions[atom_i][0] + shiftVectors[i][0]); + double r_ij_y = + static_cast(positions[atom_j][1] - positions[atom_i][1] + shiftVectors[i][1]); + double r_ij_z = + static_cast(positions[atom_j][2] - positions[atom_i][2] + shiftVectors[i][2]); + + vectors_accessor[i][0][0] = r_ij_x; + vectors_accessor[i][1][0] = r_ij_y; + vectors_accessor[i][2][0] = r_ij_z; } - auto deleter = [=](void*) - { - vesin_free(vesin_neighbor_list); - delete vesin_neighbor_list; - }; - - auto pair_vectors = torch::from_blob(vesin_neighbor_list->vectors, - { n_pairs, 3, 1 }, - deleter, - torch::TensorOptions().dtype(torch::kFloat64)); - pair_vectors.to(dtype); + auto final_samples_values = pair_samples_values.to(device); + auto final_vectors = vectors_cpu.to(dtype).to(device); auto neighbor_samples = torch::make_intrusive( std::vector{ "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c" }, - pair_samples_values.to(device)); + final_samples_values); auto neighbor_component = torch::make_intrusive( std::vector{ "xyz" }, @@ -178,14 +162,12 @@ static metatensor_torch::TensorBlock computeNeighbors(metatomic_torch::NeighborL torch::zeros({ 1, 1 }, torch::TensorOptions().dtype(torch::kInt32).device(device))); return torch::make_intrusive( - pair_vectors.to(dtype).to(device), + final_vectors, neighbor_samples, std::vector{ neighbor_component }, neighbor_properties); } -namespace gmx -{ struct MetatomicData { @@ -193,9 +175,9 @@ struct MetatomicData metatomic_torch::ModelCapabilities capabilities; std::vector nl_requests; metatomic_torch::ModelEvaluationOptions evaluations_options; - torch::ScalarType dtype; - bool check_consistency; - torch::Device device = torch::kCPU; + torch::ScalarType dtype = torch::kFloat32; + bool check_consistency = true; + torch::Device device = torch::kCPU; }; MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, @@ -207,314 +189,383 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, box_{ { 0.0, 0.0, 0.0 }, { 0.0, 0.0, 0.0 }, { 0.0, 0.0, 0.0 } }, data_(std::make_unique()) { - // ALL ranks load the model, not just the main rank - // This enables each rank to compute forces for its local atoms independently - - GMX_LOG(logger_.info).asParagraph().appendText("Initializing MetatomicForceProvider on all ranks..."); + GMX_LOG(logger_.info).asParagraph().appendText("Initializing MetatomicForceProvider..."); - // Load the model on EVERY rank - try + // Pairlist-based neighbor lists don't work with domain decomposition yet (indices are local) + // Matches NNPot's limitation + if (mpiComm_.isParallel()) { - torch::optional extensions_directory = torch::nullopt; - if (!options_.params_.extensionsDirectory.empty()) + GMX_THROW(NotImplementedError( + "Metatomic does not yet support domain decomposition. " + "Use thread-MPI (gmx mdrun) instead of MPI (mpirun gmx_mpi mdrun).")); + } + + // Only main rank loads model + if (mpiComm_.isMainRank()) + { + try + { + torch::optional extensions_directory = torch::nullopt; + if (!options_.params_.extensionsDirectory.empty()) + { + extensions_directory = options_.params_.extensionsDirectory; + } + + data_->model = metatomic_torch::load_atomistic_model(options_.params_.modelPath_, + extensions_directory); + } + catch (const std::exception& e) { - extensions_directory = options_.params_.extensionsDirectory; + GMX_THROW(APIError("Failed to load metatomic model: " + std::string(e.what()))); } - this->data_->model = metatomic_torch::load_atomistic_model(options_.params_.modelPath_, - extensions_directory); - } - catch (const std::exception& e) - { - GMX_THROW(APIError("Failed to load metatomic model: " + std::string(e.what()))); - } + data_->capabilities = + data_->model.run_method("capabilities").toCustomClass(); - // Query model capabilities on all ranks - data_->capabilities = - data_->model.run_method("capabilities").toCustomClass(); - auto requests_ivalue = data_->model.run_method("requested_neighbor_lists"); - for (const auto& request_ivalue : requests_ivalue.toList()) - { - data_->nl_requests.push_back( - request_ivalue.get().toCustomClass()); - } + // Determine device using capabilities and optional environment variable + torch::optional desiredDevice = torch::nullopt; + if (const char* env = std::getenv("GMX_METATOMIC_DEVICE")) + { + desiredDevice = std::string(env); + } + + const auto deviceType = + metatomic_torch::pick_device(data_->capabilities->supported_devices, desiredDevice); + data_->device = torch::Device(deviceType); - // Determine device - each rank picks its own device - torch::optional desired; - if (const char* env = std::getenv("GMX_METATOMIC_DEVICE")) { GMX_LOG(logger_.info) - .asParagraph() - .appendText("Using device from GMX_METATOMIC_DEVICE environment variable: ") - .appendText(env); - desired = std::string(env); - } else { - desired = options_.params_.device; - } + .asParagraph() + .appendTextFormatted("Metatomic using device: %s", data_->device.str().c_str()); - c10::DeviceType device_type_ = - metatomic_torch::pick_device(data_->capabilities->supported_devices, desired); - data_->device = torch::Device(device_type_); + auto requests_ivalue = data_->model.run_method("requested_neighbor_lists"); + for (const auto& request_ivalue : requests_ivalue.toList()) + { + data_->nl_requests.push_back( + request_ivalue.get().toCustomClass()); + } - data_->model.to(data_->device); + data_->model.to(data_->device); - // Set data type on all ranks - if (data_->capabilities->dtype() == "float64") - { - data_->dtype = torch::kFloat64; - } - else if (data_->capabilities->dtype() == "float32") - { - data_->dtype = torch::kFloat32; - } - else - { - GMX_THROW(APIError("Unsupported dtype from model: " + data_->capabilities->dtype())); - } + if (data_->capabilities->dtype() == "float64") + { + data_->dtype = torch::kFloat64; + } + else if (data_->capabilities->dtype() == "float32") + { + data_->dtype = torch::kFloat32; + } + else + { + GMX_THROW(APIError("Unsupported dtype: " + data_->capabilities->dtype())); + } - // Set up evaluation options on all ranks - data_->evaluations_options = - torch::make_intrusive(); - data_->evaluations_options->set_length_unit("nm"); + data_->evaluations_options = + torch::make_intrusive(); + data_->evaluations_options->set_length_unit("nm"); - auto outputs = data_->capabilities->outputs(); - if (!outputs.contains("energy")) - { - GMX_THROW(APIError("Metatomic model must provide an 'energy' output.")); - } + auto outputs = data_->capabilities->outputs(); + if (!outputs.contains("energy")) + { + GMX_THROW(APIError("Metatomic model must provide 'energy' output.")); + } - auto requested_output = torch::make_intrusive(); - requested_output->per_atom = true; // KEY: Request per-atom energies for proper DD - requested_output->explicit_gradients = {}; + auto requested_output = torch::make_intrusive(); + requested_output->per_atom = false; + requested_output->explicit_gradients = {}; - data_->evaluations_options->outputs.insert("energy", requested_output); + data_->evaluations_options->outputs.insert("energy", requested_output); + data_->check_consistency = options_.params_.checkConsistency; + } - data_->check_consistency = options_.params_.checkConsistency; + const auto& mtaIndices = options_.params_.mtaIndices_; + const int32_t n_atoms = static_cast(mtaIndices.size()); - // Initialize lookup tables - will be populated on first DD - const auto& mtaIndices = options_.params_.mtaIndices_; - const int n_atoms = static_cast(mtaIndices.size()); - idxLookup_.resize(n_atoms, -1); + positions_.resize(n_atoms); atomNumbers_.resize(n_atoms, 0); + inputToLocalIndex_.resize(n_atoms, -1); + inputToGlobalIndex_.resize(n_atoms, -1); GMX_LOG(logger_.info) .asParagraph() - .appendText("MetatomicForceProvider initialization complete on all ranks."); + .appendText("MetatomicForceProvider initialization complete."); } -void MetatomicForceProvider::gatherAtomNumbersIndices() +MetatomicForceProvider::~MetatomicForceProvider() = default; + +/*! \brief Updates the mapping between GROMACS local/global atom indices and the Metatomic model's input atoms. + * + * This function is subscribed to the `MDModulesAtomsRedistributedSignal`. It is called whenever + * atoms are redistributed across MPI ranks (Domain Decomposition) or reordered in memory (sorting). + * + * Its primary responsibilities are: + * 1. **Locate Input Atoms**: It iterates through the local atoms on the current rank to find + * which atoms correspond to the "input atoms" defined for the Metatomic model (via `mtaIndices`). + * 2. **Update Maps**: It populates `inputToLocalIndex_` (mapping model input index -> GROMACS local index) + * and `inputToGlobalIndex_` (mapping model input index -> GROMACS global tag). + * 3. **Gather Atomic Numbers**: It ensures `atomNumbers_` (Z numbers) are correctly associated with + * the current local atoms, performing an MPI reduction if necessary to gather data from + * distributed ranks. + * + * \param[in] signal Contains the new mapping of global atom indices to local buffer indices after redistribution. + */ +void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedistributedSignal& signal) { - // This function updates the lookup tables after domain decomposition - // Each rank knows which ML atoms are local to it - - const auto& mtaIndices = options_.params_.mtaIndices_; - const int n_atoms = static_cast(mtaIndices.size()); - - // Reset lookup table - std::fill(idxLookup_.begin(), idxLookup_.end(), -1); - - // Build reverse lookup: global index -> ML group index - std::unordered_map globalToMlIndex; - globalToMlIndex.reserve(n_atoms); - for (int i = 0; i < n_atoms; ++i) - { - globalToMlIndex[mtaIndices[i]] = i; - } + const auto& mtaIndices = options_.params_.mtaIndices_; + const int32_t numInput = static_cast(mtaIndices.size()); - // Populate lookup for this rank's local atoms - const auto* mtaAtoms = options_.params_.mtaAtoms_.get(); - for (size_t i = 0; i < mtaAtoms->numAtomsLocal(); ++i) + inputToLocalIndex_.assign(numInput, -1); + inputToGlobalIndex_.assign(numInput, -1); + atomNumbers_.assign(numInput, 0); + + if (mpiComm_.isParallel()) { - const int lIdx = mtaAtoms->localIndex()[i]; - const int gIdx = mtaAtoms->globalIndex()[mtaAtoms->collectiveIndex()[i]]; + GMX_RELEASE_ASSERT(signal.globalAtomIndices_.has_value(), + "Global atom indices required for domain decomposition."); + auto globalAtomIndices = signal.globalAtomIndices_.value(); + const int32_t numLocal = signal.x_.size(); - if (auto it = globalToMlIndex.find(gIdx); it != globalToMlIndex.end()) + for (int32_t i = 0; i < static_cast(globalAtomIndices.size()); i++) { - const int mlIdx = it->second; - idxLookup_[mlIdx] = lIdx; - atomNumbers_[mlIdx] = options_.params_.atoms_.atom[gIdx].atomnumber; + int32_t globalIdx = globalAtomIndices[i]; + for (int32_t j = 0; j < numInput; j++) + { + if (options_.params_.mtaAtoms_->globalIndex()[j] == globalIdx) + { + if (i < numLocal) + { + inputToLocalIndex_[j] = i; + inputToGlobalIndex_[j] = globalIdx; + atomNumbers_[j] = options_.params_.atoms_.atom[globalIdx].atomnumber; + } + break; + } + } } + mpiComm_.sumReduce(numInput, atomNumbers_.data()); } - - // For parallel runs, we need all ranks to know ALL atom numbers for the system tensor - // But we only compute forces for local atoms - // XXX: seems buggy in parallel, need more checks - if (mpiComm_.isParallel()) + else { - // Each rank has partial atomNumbers_, sum to get complete list - // This is needed because the System object needs all atom types - mpiComm_.sumReduce(gmx::ArrayRef(atomNumbers_.data(), atomNumbers_.data() + n_atoms)); + const auto* mtaAtoms = options_.params_.mtaAtoms_.get(); + for (int32_t i = 0; i < numInput; i++) + { + int32_t localIndex = mtaAtoms->localIndex()[i]; + int32_t globalIdx = mtaAtoms->globalIndex()[mtaAtoms->collectiveIndex()[i]]; + inputToLocalIndex_[i] = localIndex; + inputToGlobalIndex_[i] = globalIdx; + atomNumbers_[i] = options_.params_.atoms_.atom[globalIdx].atomnumber; + } } -} -MetatomicForceProvider::~MetatomicForceProvider() = default; + GMX_RELEASE_ASSERT(std::count(atomNumbers_.begin(), atomNumbers_.end(), 0) == 0, + "Some atom numbers not set."); +} -void MetatomicForceProvider::gatherAtomPositions(ArrayRef globalPositions) +/*! \brief Updates the internal position buffer with the coordinates of the Metatomic atoms. + * + * This function extracts the coordinates of the atoms relevant to the Metatomic model + * from the full GROMACS local atom array (`pos`). + * + * 1. **Filtering:** `pos` contains all local atoms (solute, solvent, ions). + * This function copies only the atoms defined in `mtaIndices` to `positions_`. + * 2. **Ordering/Packing:** GROMACS reorders atoms dynamically for Domain Decomposition. + * This function uses `inputToLocalIndex_` to collect atoms and pack them into a contiguous, ordered buffer + * + * \param[in] pos The array of all local atom coordinates for the current step. + */ +void MetatomicForceProvider::gatherAtomPositions(ArrayRef pos) { - const int n_atoms = static_cast(options_.params_.mtaIndices_.size()); - positions_.assign(n_atoms, RVec{ 0.0, 0.0, 0.0 }); + const size_t numInput = inputToLocalIndex_.size(); + positions_.assign(numInput, RVec({ 0.0, 0.0, 0.0 })); - // Each rank fills its local atoms' positions - for (int i = 0; i < n_atoms; ++i) + for (size_t i = 0; i < numInput; i++) { - if (idxLookup_[i] != -1) + if (inputToLocalIndex_[i] != -1) { - positions_[i] = globalPositions[idxLookup_[i]]; + positions_[i] = pos[inputToLocalIndex_[i]]; } } - // Sum-reduce positions so all ranks have complete position array - // This is needed because neighbor list computation needs all positions if (mpiComm_.isParallel()) { - real* data_ptr = reinterpret_cast(positions_.data()); - const size_t n_reals = static_cast(n_atoms) * 3; - mpiComm_.sumReduce(gmx::ArrayRef(data_ptr, data_ptr + n_reals)); + mpiComm_.sumReduce(3 * numInput, positions_.data()->as_vec()); } } -void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, ForceProviderOutput* outputs) +void MetatomicForceProvider::setPairlist(const MDModulesPairlistConstructedSignal& signal) { - const int n_atoms = static_cast(options_.params_.mtaIndices_.size()); - - // Gather positions (all ranks need this for neighbor list computation) - this->gatherAtomPositions(inputs.x_); - copy_mat(inputs.box_, box_); + fullPairlist_.assign(signal.excludedPairlist_.begin(), signal.excludedPairlist_.end()); + doPairlist_ = true; +} - auto gromacs_scalar_type = torch::kFloat32; - if (std::is_same_v) +void MetatomicForceProvider::preparePairlistInput() +{ + if (!doPairlist_) { - gromacs_scalar_type = torch::kFloat64; + return; } - auto blob_options = torch::TensorOptions().dtype(gromacs_scalar_type).device(torch::kCPU); - // ALL ranks run the model, each computing forces for its local atoms - auto coerced_positions = makeArrayRef(positions_); + GMX_ASSERT(!fullPairlist_.empty(), "Pairlist empty!"); - auto torch_positions = - torch::from_blob(coerced_positions.data()->as_vec(), { n_atoms, 3 }, blob_options) - .to(data_->dtype) - .to(data_->device) - .set_requires_grad(true); + const int32_t numPairs = gmx::ssize(fullPairlist_); + pairlistForModel_.clear(); + pairlistForModel_.reserve(2 * numPairs); + shiftVectors_.clear(); + shiftVectors_.reserve(numPairs); + cellShifts_.clear(); + cellShifts_.reserve(numPairs); - auto torch_cell = - torch::from_blob(&box_, { 3, 3 }, blob_options).to(data_->dtype).to(data_->device); + for (int32_t i = 0; i < numPairs; i++) + { + const auto [atomPair, shiftIndex] = fullPairlist_[i]; - auto torch_pbc = preparePbcType(options_.params_.pbcType_.get()).to(data_->device); - auto torch_types = - torch::tensor(atomNumbers_, torch::TensorOptions().dtype(torch::kInt32)).to(data_->device); + auto inputIdxA = indexOf(inputToGlobalIndex_, atomPair.first); + auto inputIdxB = indexOf(inputToGlobalIndex_, atomPair.second); - auto system = torch::make_intrusive( - torch_types, torch_positions, torch_cell, torch_pbc); + if (inputIdxA.has_value() && inputIdxB.has_value()) + { + RVec shift; + const IVec unitShift = shiftIndexToXYZ(shiftIndex); + mvmul_ur0(box_, unitShift.toRVec(), shift); + + pairlistForModel_.push_back(static_cast(inputIdxA.value())); + pairlistForModel_.push_back(static_cast(inputIdxB.value())); + shiftVectors_.push_back(shift); + cellShifts_.push_back(unitShift); + } + } - bool periodic = torch::all(torch_pbc).item(); + GMX_RELEASE_ASSERT(pairlistForModel_.size() == shiftVectors_.size() * 2, + "Pairlist/shift size mismatch."); + doPairlist_ = false; +} - // Compute neighbor lists on each rank - for (const auto& request : data_->nl_requests) - { - auto neighbors = computeNeighbors( - request, n_atoms, coerced_positions.data()->as_vec(), box_, periodic, data_->device, data_->dtype); - metatomic_torch::register_autograd_neighbors(system, neighbors, false); - system->add_neighbor_list(request, neighbors); - } +void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, ForceProviderOutput* outputs) +{ + const int32_t n_atoms = static_cast(options_.params_.mtaIndices_.size()); - // Run the model on EVERY rank - metatensor_torch::TensorMap output_map; - try - { - std::vector systems; - systems.push_back(system); + gatherAtomPositions(inputs.x_); + copy_mat(inputs.box_, box_); + preparePairlistInput(); - auto ivalue_output = data_->model.forward( - { c10::IValue(systems), data_->evaluations_options, data_->check_consistency }); - auto dict_output = ivalue_output.toGenericDict(); - output_map = dict_output.at("energy").toCustomClass(); - } - catch (const std::exception& e) - { - GMX_THROW(APIError("[MetatomicPotential] Model evaluation failed: " + std::string(e.what()))); - } + // Force tensor - main rank fills, others have zeros + torch::Tensor forceTensor = + torch::zeros({ n_atoms, 3 }, torch::TensorOptions().dtype(torch::kFloat64).device(data_->device)); - // Extract per-atom energies - auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(output_map, 0); - auto energy_tensor = energy_block->values(); + // Virial tensor for pressure/stress calculations + torch::Tensor virialTensor = torch::zeros({ 3, 3 }, torch::TensorOptions().dtype(torch::kFloat64)); - // Handle energy for domain decomposition - // If per_atom=true, we get per-atom energies [n_atoms, 1] - // So sum over local atoms' energies - double local_energy = 0.0; - - if (energy_tensor.dim() == 2 && energy_tensor.size(0) == n_atoms) + if (mpiComm_.isMainRank()) { - // Per-atom energies - sum only local atoms - auto energy_cpu = energy_tensor.to(torch::kCPU).to(torch::kFloat64); - auto energy_accessor = energy_cpu.accessor(); - - for (int i = 0; i < n_atoms; ++i) + auto gromacs_scalar_type = torch::kFloat32; + if (std::is_same_v) { - if (idxLookup_[i] != -1) // Only count local atoms - { - local_energy += energy_accessor[i][0]; - } + gromacs_scalar_type = torch::kFloat64; } - } - else - { - // Scalar energy - only main rank contributes (or divide among ranks) - if (mpiComm_.isMainRank()) + auto blob_options = torch::TensorOptions().dtype(gromacs_scalar_type).device(data_->device); + + auto torch_positions = torch::from_blob(positions_.data()->as_vec(), { n_atoms, 3 }, blob_options) + .to(data_->dtype) + .to(data_->device) + .set_requires_grad(true); + + auto torch_cell = + torch::from_blob(&box_, { 3, 3 }, blob_options).to(data_->dtype).to(data_->device); + + // Create strain tensor for virial computation (like LAMMPS does) + auto strain = torch::eye( + 3, torch::TensorOptions().dtype(data_->dtype).device(data_->device).requires_grad(true)); + + // Apply strain to cell: strained_cell = cell @ strain + auto strained_cell = torch::matmul(torch_cell, strain); + // Apply strain to positions: r' = r @ strain + auto strained_positions = torch::matmul(torch_positions, strain); + + auto torch_pbc = preparePbcType(options_.params_.pbcType_.get(), data_->device); + auto torch_types = + torch::tensor(atomNumbers_, torch::TensorOptions().dtype(torch::kInt32)).to(data_->device); + + auto system = torch::make_intrusive( + torch_types, strained_positions, strained_cell, torch_pbc); + + // Build neighbor list from GROMACS pairlist + for (const auto& request : data_->nl_requests) { - local_energy = energy_tensor.item(); + auto neighbors = buildNeighborListFromPairlist( + pairlistForModel_, shiftVectors_, cellShifts_, positions_, data_->device, data_->dtype); + // TODO: take from the user / model + metatomic_torch::register_autograd_neighbors(system, neighbors, /*check_consistency*/ true); + system->add_neighbor_list(request, neighbors); } + + metatensor_torch::TensorMap output_map; + try + { + std::vector systems; + systems.push_back(system); + + auto ivalue_output = data_->model.forward( + { c10::IValue(systems), data_->evaluations_options, data_->check_consistency }); + auto dict_output = ivalue_output.toGenericDict(); + output_map = dict_output.at("energy").toCustomClass(); + } + catch (const std::exception& e) + { + GMX_THROW(APIError("[Metatomic] Model evaluation failed: " + std::string(e.what()))); + } + + auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(output_map, 0); + auto energy_tensor = energy_block->values(); + + outputs->enerd_.term[InteractionFunction::MetatomicPotentialEnergy] = + static_cast(energy_tensor.item()); + + // Reset gradients before backward + torch_positions.mutable_grad() = torch::Tensor(); + strain.mutable_grad() = torch::Tensor(); + + // Compute forces and virial via backward propagation + energy_tensor.backward(-torch::ones_like(energy_tensor)); + + auto grad = torch_positions.grad(); + forceTensor = grad.to(torch::kCPU).to(torch::kFloat64); + + // Get virial from strain gradient + virialTensor = strain.grad().to(torch::kCPU).to(torch::kFloat64); } - // Reduce energy across all ranks - double total_energy = local_energy; + // Distribute forces (sumReduce acts as broadcast since non-main ranks have zeros) if (mpiComm_.isParallel()) { - mpiComm_.sumReduce(gmx::ArrayRef(&total_energy, &total_energy + 1)); - } - - // Only main rank sets the energy (GROMACS handles the rest) - if (mpiComm_.isMainRank()) - { - outputs->enerd_.term[InteractionFunction::MetatomicPotentialEnergy] = - static_cast(total_energy); + mpiComm_.sumReduce(n_atoms * 3, static_cast(forceTensor.data_ptr())); + mpiComm_.sumReduce(9, static_cast(virialTensor.data_ptr())); } - // Compute gradients - all ranks do this - energy_tensor.sum().backward(); - auto grad = system->positions().grad(); - auto forceTensor = -grad.to(torch::kCPU).to(data_->dtype); - - // Scatter forces to local atoms ONLY - no broadcast needed! - if (data_->dtype == torch::kFloat64) + // Apply forces to local atoms only + auto forceAccessor = forceTensor.accessor(); + for (int32_t i = 0; i < n_atoms; ++i) { - auto accessor = forceTensor.accessor(); - for (int i = 0; i < n_atoms; ++i) + if (inputToLocalIndex_[i] != -1) { - const int localIndex = idxLookup_[i]; - if (localIndex != -1) // Only update local atoms - { - outputs->forceWithVirial_.force_[localIndex][0] += accessor[i][0]; - outputs->forceWithVirial_.force_[localIndex][1] += accessor[i][1]; - outputs->forceWithVirial_.force_[localIndex][2] += accessor[i][2]; - } + outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][0] += forceAccessor[i][0]; + outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][1] += forceAccessor[i][1]; + outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][2] += forceAccessor[i][2]; } } - else if (data_->dtype == torch::kFloat32) + + // Apply virial contribution + // GROMACS uses a 3x3 virial tensor in forceWithVirial_ + // Copy the tensor data into a GROMACS matrix and use the public API + matrix virialMatrix; + auto virialAccessor = virialTensor.accessor(); + // TODO: technically this is DIM, not 3... + for (int32_t i = 0; i < 3; ++i) { - auto accessor = forceTensor.accessor(); - for (int i = 0; i < n_atoms; ++i) + for (int32_t j = 0; j < 3; ++j) { - const int localIndex = idxLookup_[i]; - if (localIndex != -1) // Only update local atoms - { - outputs->forceWithVirial_.force_[localIndex][0] += static_cast(accessor[i][0]); - outputs->forceWithVirial_.force_[localIndex][1] += static_cast(accessor[i][1]); - outputs->forceWithVirial_.force_[localIndex][2] += static_cast(accessor[i][2]); - } + virialMatrix[i][j] = virialAccessor[i][j]; } } - // Note: Virial still needs proper implementation + outputs->forceWithVirial_.addVirialContribution(virialMatrix); } } // namespace gmx diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h index 0976da4d89..c26ce7b187 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h @@ -50,10 +50,17 @@ namespace gmx struct MetatomicParameters; struct MetatomicData; +struct MDModulesAtomsRedistributedSignal; +struct MDModulesPairlistConstructedSignal; class MDLogger; class MpiComm; +/*! For compatibility with pairlist data structure in MDModulesPairlistConstructedSignal. + * Contains pairs like ((atom1, atom2), shiftIndex). + */ +using PairlistEntry = std::pair, int32_t>; + /*! \brief \internal * MetatomicForceProvider class * @@ -72,28 +79,56 @@ class MetatomicForceProvider final : public IForceProvider * \param[out] fOutput output for force provider */ void calculateForces(const ForceProviderInput& inputs, ForceProviderOutput* outputs) override; - void updateLocalAtoms(); - void gatherAtomPositions(ArrayRef globalPositions); - void gatherAtomNumbersIndices(); + + //! Gather atom numbers and indices. Triggered on AtomsRedistributed signal. + void gatherAtomNumbersIndices(const MDModulesAtomsRedistributedSignal& signal); + + //! Set pairlist from notification and filter to MTA atom pairs. + void setPairlist(const MDModulesPairlistConstructedSignal& signal); private: + //! Gather atom positions for MTA input. + void gatherAtomPositions(ArrayRef globalPositions); + + //! Prepare pairlist input for model + void preparePairlistInput(); + const MetatomicOptions& options_; const MDLogger& logger_; const MpiComm& mpiComm_; - //! vector storing all atom positions + //! vector storing all MTA atom positions std::vector positions_; //! vector storing all atomic numbers - std::vector atomNumbers_; + std::vector atomNumbers_; - //! global index lookup table to map indices from model input to global atom indices - std::vector idxLookup_; + //! lookup table to map model input indices [0...numInput) to local atom indices + std::vector inputToLocalIndex_; + + //! lookup table to map model input indices to global atom indices + std::vector inputToGlobalIndex_; + + //! Full pairlist from MDModules notification + std::vector fullPairlist_; + + //! Interacting pairs of MTA atoms within cutoff, for model input + std::vector pairlistForModel_; + + //! Shift vectors for each atom pair in pairlistForModel_ + std::vector shiftVectors_; + + //! Cell shifts + std::vector cellShifts_; //! local copy of simulation box matrix box_; + //! Data required for metatomic calculations std::unique_ptr data_; + + //! flag to check if pairlist data should be prepared + bool doPairlist_ = false; }; } // namespace gmx diff --git a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp index 1e01e0f09a..6f3270f0ed 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp @@ -47,11 +47,21 @@ #include "gromacs/domdec/localatomset.h" #include "gromacs/domdec/localatomsetmanager.h" #include "gromacs/mdrunutility/mdmodulesnotifiers.h" +#include "gromacs/mdrunutility/plainpairlistranges.h" #include "gromacs/mdtypes/imdmodule.h" +#include "gromacs/utility/basenetwork.h" #include "gromacs/utility/keyvaluetreebuilder.h" #include "metatomic_forceprovider.h" #include "metatomic_options.h" +#ifdef DIM +# undef DIM +#endif + +#include + +#include +#include namespace gmx { @@ -167,6 +177,73 @@ class MetatomicMDModule final : public IMDModule [](MDModulesEnergyOutputToMetatomicPotRequestChecker* energyOutputRequest) { energyOutputRequest->energyOutputToMetatomicPot_ = true; }; notifiers->simulationSetupNotifier_.subscribe(requestEnergyOutput); + + const auto setPlainPairlistRangeFunction = [this](PlainPairlistRanges* ranges) + { + // Temporary: Load model just to peek at cutoff. + // TODO: can the whole model be loaded earlier..? + double max_cutoff{ 0.0 }; + try + { + auto model = metatomic_torch::load_atomistic_model(options_.parameters().modelPath_); + + // Check strict interaction range + auto capabilities = + model.run_method("capabilities").toCustomClass(); + double interaction_range = capabilities->engine_interaction_range("nm"); + + if (interaction_range < 0.0) + { + GMX_THROW(InconsistentInputError( + "interaction_range is negative for this model.")); + } + + if (!std::isfinite(interaction_range)) + { + // Infinite range (global) is only supported on a single rank + if (gmx_node_num() > 1) + { + GMX_THROW(NotImplementedError( + "interaction_range is infinite for this model; " + "using multiple MPI domains is not supported.")); + } + // For infinite range, check if specific NLs were requested + // effectively falling through to the loop below. + } + else + { + max_cutoff = interaction_range; + } + + // Check requested neighbor lists + auto requested_nl = model.run_method("requested_neighbor_lists"); + for (const auto& ivalue : requested_nl.toList()) + { + auto options = + ivalue.get().toCustomClass(); + double cutoff = options->engine_cutoff("nm"); + max_cutoff = std::max(max_cutoff, cutoff); + } + } + catch (const std::exception& e) + { + GMX_THROW(InternalError("Failed to read cutoff from model: " + std::string(e.what()))); + } + + if (max_cutoff <= 0.0 || !std::isfinite(max_cutoff)) + { + // If the model is purely global and requested no specific NL, + // there's no need for a pairlist, so max_cutoff remains 0. + if (max_cutoff == 0.0) + { + GMX_THROW(InconsistentInputError("Metatomic model cutoff is 0.0 or invalid.")); + } + } + // Register the requirement with GROMACS + ranges->addRange(max_cutoff); + }; + // Register the callback + notifiers->simulationSetupNotifier_.subscribe(setPlainPairlistRangeFunction); } /*! \brief Requests to be notified during the simulation. @@ -175,6 +252,7 @@ class MetatomicMDModule final : public IMDModule * * The Metatomic module subscribes to the following notifications: * - Atom redistribution due to domain decomposition + * - Changes in the neighborlist * by taking a const MDModulesAtomsRedistributedSignal as a parameter. */ void subscribeToSimulationRunNotifications(MDModulesNotifiers* notifiers) override @@ -185,10 +263,14 @@ class MetatomicMDModule final : public IMDModule } // After domain decomposition, the force provider needs to know which atoms are local. - const auto notifyDDFunction = [this](const MDModulesAtomsRedistributedSignal& /*signal*/) { - force_provider_->gatherAtomNumbersIndices(); - }; + const auto notifyDDFunction = [this](const MDModulesAtomsRedistributedSignal& signal) + { force_provider_->gatherAtomNumbersIndices(signal); }; notifiers->simulationRunNotifier_.subscribe(notifyDDFunction); + + // subscribe to pairlist construction notification + const auto notifyPairlistFunction = [this](const MDModulesPairlistConstructedSignal& signal) + { force_provider_->setPairlist(signal); }; + notifiers->simulationRunNotifier_.subscribe(notifyPairlistFunction); } void initForceProviders(ForceProviders* forceProviders) override @@ -199,8 +281,7 @@ class MetatomicMDModule final : public IMDModule } force_provider_ = std::make_unique( - options_, options_.logger(), options_.mpiComm() - ); + options_, options_.logger(), options_.mpiComm()); forceProviders->addForceProvider(force_provider_.get(), "Metatomic"); } diff --git a/src/gromacs/applied_forces/metatomic/tests/metatomic_options.cpp b/src/gromacs/applied_forces/metatomic/tests/metatomic_options.cpp index 3edd815ec9..257d830e19 100644 --- a/src/gromacs/applied_forces/metatomic/tests/metatomic_options.cpp +++ b/src/gromacs/applied_forces/metatomic/tests/metatomic_options.cpp @@ -34,9 +34,9 @@ */ /*! \internal \file * \brief - * Tests for functionality of the NNPotOptions + * Tests for functionality of the MetatomicOptions * - * \author Lukas Müllender + * \author Metatensor developers * \ingroup module_applied_forces */ diff --git a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml index d9407106e7..21555a3181 100644 --- a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml +++ b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml @@ -2,10 +2,9 @@ false - model.pt - System - - - + System + + + true diff --git a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputDefaultValuesWhenActive.xml b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputDefaultValuesWhenActive.xml index cf1c69fc50..9f2f99bd32 100644 --- a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputDefaultValuesWhenActive.xml +++ b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputDefaultValuesWhenActive.xml @@ -3,11 +3,11 @@ ; Machine learning potential using metatomic -metatomic-active = true -metatomic-input-group = System -metatomic-model = -metatomic-extensions = -metatomic-device = -metatomic-check-consistency = true +metatomic-active = true +metatomic-input-group = System +metatomic-model = +metatomic-extensions = +metatomic-device = +metatomic-check-consistency = true diff --git a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputNoDefaultValuesWhenInactive.xml b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputNoDefaultValuesWhenInactive.xml index c5c4bbdd90..0f2e8f2a3e 100644 --- a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputNoDefaultValuesWhenInactive.xml +++ b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputNoDefaultValuesWhenInactive.xml @@ -3,6 +3,6 @@ ; Machine learning potential using metatomic -metatomic-active = false +metatomic-active = false