From 29f4fe2f72c103df600d230355089ffb8dbcf3e7 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 4 Feb 2026 06:12:45 +0100 Subject: [PATCH 01/19] chore(mtagro): start by documenting a bit more --- .../metatomic/metatomic_forceprovider.cpp | 156 ++++++++++++++---- 1 file changed, 120 insertions(+), 36 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 3cbeb9e964..e677341749 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -44,6 +44,9 @@ #include +#include +#include + #include "gromacs/domdec/localatomset.h" #include "gromacs/mdlib/broadcaststructs.h" #include "gromacs/mdrunutility/mdmodulesnotifiers.h" @@ -71,6 +74,11 @@ namespace gmx { +/*! \brief Normalizes the variant string for Metatomic output selection. + * + * \param[in] variant_string The raw variant string from options. + * \return A torch::optional containing the string if valid, or nullopt if empty/"no". + */ static torch::optional normalize_variant(std::string variant_string) { if (variant_string == "no" || variant_string.empty()) @@ -83,7 +91,14 @@ static torch::optional normalize_variant(std::string variant_string } } - +/*! \brief Finds the index of a value in a vector. + * + * Performs a linear search to locate a specific value within a vector. + * + * \param[in] vec The vector to search. + * \param[in] val The value to find. + * \return The index of the value if found, otherwise std::nullopt. + */ static std::optional indexOf(ArrayRef vec, const int32_t val) { auto it = std::find(vec.begin(), vec.end(), val); @@ -94,6 +109,12 @@ static std::optional indexOf(ArrayRef vec, const int32 return std::distance(vec.begin(), it); } +/*! \brief Converts GROMACS PbcType to a boolean tensor for Metatomic. + * + * \param[in] pbcType The GROMACS periodic boundary condition type. + * \param[in] device The torch device where the tensor should reside. + * \return A boolean tensor of shape {3} indicating periodicity in X, Y, Z. + */ static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) { auto options = torch::TensorOptions().dtype(torch::kBool).device(device); @@ -108,11 +129,25 @@ static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) } else if (*pbcType != PbcType::Xyz) { - GMX_THROW(InconsistentInputError("PBC type not supported.")); + GMX_THROW(InconsistentInputError("PBC type not supported by Metatomic interface.")); } return torch::tensor({ true, true, true }, options); } +/*! \brief Constructs a Metatensor TensorBlock representing the neighbor list. + * + * This function takes the filtered pairlist (atoms participating in the model interaction) + * and constructs the corresponding neighbor list in the format required by Metatensor/Torch. + * It computes the interatomic vectors, applying periodic boundary shifts where necessary. + * + * \param[in] pairlist Flat array of atom pairs (indices into the model's atom list). + * \param[in] shiftVectors Geometric shift vectors (RVec) for each pair. + * \param[in] cellShifts Integer cell shift indices for each pair (for metadata). + * \param[in] positions Positions of the atoms (ordered by model index). + * \param[in] device The torch device for the output tensors. + * \param[in] dtype The torch scalar type (float32/float64). + * \return A TensorBlockHolder containing the neighbor list data. + */ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef pairlist, ArrayRef shiftVectors, ArrayRef cellShifts, @@ -122,22 +157,24 @@ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef(pairlist.size() / 2); + // Prepare CPU tensors first to facilitate efficient element access auto cpu_int_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); auto cpu_float_options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU); + // Samples: [first_atom, second_atom, cell_shift_a, cell_shift_b, cell_shift_c] auto pair_samples_values = torch::zeros({ n_pairs, 5 }, cpu_int_options); auto pair_samples_ptr = pair_samples_values.accessor(); - // Full interatomic vectors (rj - ri + shift) + // Values: Full interatomic vectors (rj - ri + shift) auto vectors_cpu = torch::zeros({ n_pairs, 3, 1 }, cpu_float_options); auto vectors_accessor = vectors_cpu.accessor(); for (int64_t i = 0; i < n_pairs; i++) { - int32_t atom_i = pairlist[2 * i]; - int32_t atom_j = pairlist[2 * i + 1]; + const int32_t atom_i = pairlist[2 * i]; + const int32_t atom_j = pairlist[2 * i + 1]; - // Access IVec elements (integers) + // Fill sample metadata pair_samples_ptr[i][0] = static_cast(atom_i); pair_samples_ptr[i][1] = static_cast(atom_j); pair_samples_ptr[i][2] = cellShifts[i][0]; @@ -145,11 +182,11 @@ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef(positions[atom_j][0] - positions[atom_i][0] + shiftVectors[i][0]); - double r_ij_y = + const double r_ij_y = static_cast(positions[atom_j][1] - positions[atom_i][1] + shiftVectors[i][1]); - double r_ij_z = + const double r_ij_z = static_cast(positions[atom_j][2] - positions[atom_i][2] + shiftVectors[i][2]); vectors_accessor[i][0][0] = r_ij_x; @@ -157,6 +194,7 @@ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRefcapabilities = data_->model.run_method("capabilities").toCustomClass(); - // Determine device using capabilities and optional environment variable + // Determine computation device torch::optional desiredDevice = torch::nullopt; if (const char* env = std::getenv("GMX_METATOMIC_DEVICE")) { @@ -250,6 +294,7 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, .asParagraph() .appendTextFormatted("Metatomic using device: %s", data_->device.str().c_str()); + // Process neighbor list requests from the model auto requests_ivalue = data_->model.run_method("requested_neighbor_lists"); for (const auto& request_ivalue : requests_ivalue.toList()) { @@ -259,6 +304,7 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, data_->model.to(data_->device); + // Configure precision if (data_->capabilities->dtype() == "float64") { data_->dtype = torch::kFloat64; @@ -269,23 +315,26 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, } else { - GMX_THROW(APIError("Unsupported dtype: " + data_->capabilities->dtype())); + GMX_THROW(APIError("Unsupported dtype from model capabilities: " + + data_->capabilities->dtype())); } data_->evaluations_options = torch::make_intrusive(); data_->evaluations_options->set_length_unit("nm"); + // Validate energy output existence auto outputs = data_->capabilities->outputs(); auto v_energy = normalize_variant(options_.params_.variant); auto energy_key = pick_output("energy", outputs, v_energy); if (!outputs.contains(energy_key)) { - GMX_THROW(APIError("the model at '" + options_.params_.modelPath_ - + "' does not provide " - "an '" - + energy_key + "' output, we can not use the metatomic interface.")); + GMX_THROW( + APIError(formatString("The model at '%s' does not provide an '%s' output. " + "Metatomic interface cannot proceed.", + options_.params_.modelPath_.c_str(), + energy_key.c_str()))); } auto requested_output = torch::make_intrusive(); @@ -298,6 +347,7 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, data_->check_consistency = options_.params_.checkConsistency; } + // Initialize vectors for atom mapping const auto& mtaIndices = options_.params_.mtaIndices_; const int32_t n_atoms = static_cast(mtaIndices.size()); @@ -334,10 +384,12 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist const auto& mtaIndices = options_.params_.mtaIndices_; const int32_t numInput = static_cast(mtaIndices.size()); + // Reset mappings inputToLocalIndex_.assign(numInput, -1); inputToGlobalIndex_.assign(numInput, -1); atomNumbers_.assign(numInput, 0); + // GROMACS domain decomposition logic if (mpiComm_.isParallel()) { GMX_RELEASE_ASSERT(signal.globalAtomIndices_.has_value(), @@ -350,6 +402,7 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist int32_t globalIdx = globalAtomIndices[i]; for (int32_t j = 0; j < numInput; j++) { + // Match current local atom to one of the requested Metatomic input atoms if (options_.params_.mtaAtoms_->globalIndex()[j] == globalIdx) { if (i < numLocal) @@ -362,10 +415,12 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist } } } + // Reduce atomic numbers across ranks to ensure the main rank has the full set mpiComm_.sumReduce(numInput, atomNumbers_.data()); } else { + // Thread-MPI or Serial execution const auto* mtaAtoms = options_.params_.mtaAtoms_.get(); for (int32_t i = 0; i < numInput; i++) { @@ -414,10 +469,21 @@ void MetatomicForceProvider::gatherAtomPositions(ArrayRef pos) void MetatomicForceProvider::setPairlist(const MDModulesPairlistConstructedSignal& signal) { + // Capture the pairlist signal. Processing is deferred to + // calculateForces/preparePairlistInput to keep this callback fast. fullPairlist_.assign(signal.excludedPairlist_.begin(), signal.excludedPairlist_.end()); doPairlist_ = true; } +/*! \brief Converts the GROMACS neighbor list to a model-compatible list. + * + * This function iterates over the full GROMACS excluded pairlist (which contains pairs in + * GROMACS local atom indices). It filters this list to retain only pairs where *both* atoms + * are part of the Metatomic model's input set. + * + * It populates `pairlistForModel_` (using model-relative indices), `shiftVectors_`, + * and `cellShifts_`. + */ void MetatomicForceProvider::preparePairlistInput() { if (!doPairlist_) @@ -425,7 +491,10 @@ void MetatomicForceProvider::preparePairlistInput() return; } - GMX_ASSERT(!fullPairlist_.empty(), "Pairlist empty!"); + // Although the assert catches empty pairlists, in a real simulation with a very large cutoff, + // this might happen legitimately if only 1 atom exists. However, for standard MD, it indicates + // an issue. Assert here to catch initialization ordering bugs. + GMX_ASSERT(!fullPairlist_.empty(), "Pairlist for Metatomic is empty!"); const int32_t numPairs = gmx::ssize(fullPairlist_); pairlistForModel_.clear(); @@ -439,19 +508,29 @@ void MetatomicForceProvider::preparePairlistInput() { const auto [atomPair, shiftIndex] = fullPairlist_[i]; - auto inputIdxA = indexOf(inputToGlobalIndex_, atomPair.first); - auto inputIdxB = indexOf(inputToGlobalIndex_, atomPair.second); + // GROMACS pairlists use local atom indices. + // Map these local indices back to the model's input indices [0, N_model_atoms). + // `inputToLocalIndex_` maps ModelIdx -> LocalIdx. + // indexOf reverses the map: Find ModelIdx k such that inputToLocalIndex_[k] == LocalIdx. + auto inputIdxA = indexOf(inputToLocalIndex_, atomPair.first); - if (inputIdxA.has_value() && inputIdxB.has_value()) + if (inputIdxA.has_value()) { - RVec shift; - const IVec unitShift = shiftIndexToXYZ(shiftIndex); - mvmul_ur0(box_, unitShift.toRVec(), shift); - - pairlistForModel_.push_back(static_cast(inputIdxA.value())); - pairlistForModel_.push_back(static_cast(inputIdxB.value())); - shiftVectors_.push_back(shift); - cellShifts_.push_back(unitShift); + auto inputIdxB = indexOf(inputToLocalIndex_, atomPair.second); + + if (inputIdxB.has_value()) + { + // Both atoms belong to the Metatomic subsystem. + // Calculate the shift vector due to PBC. + RVec shift; + const IVec unitShift = shiftIndexToXYZ(shiftIndex); + mvmul_ur0(box_, unitShift.toRVec(), shift); + + pairlistForModel_.push_back(static_cast(inputIdxA.value())); + pairlistForModel_.push_back(static_cast(inputIdxB.value())); + shiftVectors_.push_back(shift); + cellShifts_.push_back(unitShift); + } } } @@ -464,11 +543,12 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F { const int32_t n_atoms = static_cast(options_.params_.mtaIndices_.size()); + // Update positions and box for the current step gatherAtomPositions(inputs.x_); copy_mat(inputs.box_, box_); preparePairlistInput(); - // Force tensor - main rank fills, others have zeros + // Force tensor - main rank fills this, others hold zeros until reduction torch::Tensor forceTensor = torch::zeros( { n_atoms, 3 }, torch::TensorOptions().dtype(torch::kFloat64).device(data_->device)); @@ -477,6 +557,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F if (mpiComm_.isMainRank()) { + // Select appropriate precision for GROMACS data conversion auto gromacs_scalar_type = torch::kFloat32; if (std::is_same_v) { @@ -492,7 +573,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F auto torch_cell = torch::from_blob(&box_, { 3, 3 }, cpu_blob_options).to(data_->dtype).to(data_->device); - // Create strain tensor for virial computation (like LAMMPS does) + // Create strain tensor (identity matrix) for virial computation via autodiff auto strain = torch::eye( 3, torch::TensorOptions().dtype(data_->dtype).device(data_->device).requires_grad(true)); @@ -523,6 +604,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F std::vector systems; systems.push_back(system); + // Forward pass auto ivalue_output = data_->model.forward( { systems, data_->evaluations_options, data_->check_consistency }); auto dict_output = ivalue_output.toGenericDict(); @@ -533,6 +615,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F GMX_THROW(APIError("[Metatomic] Model evaluation failed: " + std::string(e.what()))); } + // Extract Energy auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(output_map, 0); auto energy_tensor = energy_block->values(); @@ -543,7 +626,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F torch_positions.mutable_grad() = torch::Tensor(); strain.mutable_grad() = torch::Tensor(); - // Compute forces and virial via backward propagation + // Backward pass: Compute forces (-dE/dr) and virial (-dE/dStrain) energy_tensor.backward(-torch::ones_like(energy_tensor)); auto grad = torch_positions.grad(); @@ -553,17 +636,18 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F virialTensor = strain.grad().to(torch::kCPU).to(torch::kFloat64); } - // Distribute forces (sumReduce acts as broadcast since non-main ranks have zeros) + // Distribute results to all ranks if necessary (sumReduce broadcasts if ranks > 1) if (mpiComm_.isParallel()) { mpiComm_.sumReduce(n_atoms * 3, static_cast(forceTensor.data_ptr())); mpiComm_.sumReduce(9, static_cast(virialTensor.data_ptr())); } - // Apply forces to local atoms only + // Accumulate forces into the GROMACS force output auto forceAccessor = forceTensor.accessor(); for (int32_t i = 0; i < n_atoms; ++i) { + // Only apply force if this atom is local to this rank if (inputToLocalIndex_[i] != -1) { outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][0] += forceAccessor[i][0]; From 2f12d375278b6dad92e234b611667389be38373e Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 4 Feb 2026 10:13:47 +0100 Subject: [PATCH 02/19] chore(mtagro): load models on each rank --- .../metatomic/metatomic_forceprovider.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index e677341749..6e2a73ef5d 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -251,15 +251,16 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, // Matches NNPot's limitation if (mpiComm_.isParallel()) { - GMX_THROW( - NotImplementedError("Metatomic does not yet support domain decomposition (MPI). " + GMX_LOG(logger_.warning) + + .asParagraph() + + .appendText( + "Metatomic support domain decomposition is EXPERIMENTAL (MPI). " "Please use thread-MPI (gmx mdrun -ntmpi X) instead of MPI " - "(mpirun -np X gmx_mpi mdrun).")); + "(mpirun -np X gmx_mpi mdrun)."); } - // Only the main rank loads the model to avoid file contention and redundant loading on the same node. - if (mpiComm_.isMainRank()) - { try { torch::optional extensions_directory = torch::nullopt; @@ -345,7 +346,7 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, data_->evaluations_options->outputs.insert(energy_key, requested_output); data_->check_consistency = options_.params_.checkConsistency; - } + // Initialize vectors for atom mapping const auto& mtaIndices = options_.params_.mtaIndices_; From 82c8be591c0f1f06449868a0d4a1ce04653701ce Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 4 Feb 2026 10:24:27 +0100 Subject: [PATCH 03/19] chore(mtagro): start checking dom-dec --- .../metatomic/metatomic_forceprovider.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 6e2a73ef5d..12495d754d 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -418,6 +418,22 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist } // Reduce atomic numbers across ranks to ensure the main rank has the full set mpiComm_.sumReduce(numInput, atomNumbers_.data()); + + // Debug logging for domain decomposition distribution + int32_t localCount = 0; + for (const int32_t idx : inputToLocalIndex_) + { + if (idx != -1) + { + localCount++; + } + } + GMX_LOG(logger_.info) + .asParagraph() + .appendTextFormatted("Rank %d: Mapped %d / %d Metatomic atoms (Home+Halo).", + mpiComm_.rank(), + localCount, + numInput); } else { From 05c7e9d2493cc26f0a016aa12b09fb3df9642887 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 4 Feb 2026 11:13:45 +0100 Subject: [PATCH 04/19] chore(mtagro): even more debugging --- .../applied_forces/metatomic/metatomic_forceprovider.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 12495d754d..b9bebc7a53 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -406,6 +406,15 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist // Match current local atom to one of the requested Metatomic input atoms if (options_.params_.mtaAtoms_->globalIndex()[j] == globalIdx) { + // [DEBUG] Print distribution info + GMX_LOG(logger_.info) + .asParagraph().appendTextFormatted( + "Rank %d: Found ModelAtom %d (Global %d) at Local %d (%s)\n", + mpiComm_.rank(), + j, + globalIdx, + i, + (i < numLocal) ? "HOME" : "HALO"); if (i < numLocal) { inputToLocalIndex_[j] = i; From 6ed12a742b13e003f54eac33d4c2af9b333d59d0 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 4 Feb 2026 17:30:50 +0100 Subject: [PATCH 05/19] chore(mtagro): switch to fprintf since GMX_LOG is rank 0 only, and cerr is not allowed as per GROMACS style guides --- .../metatomic/metatomic_forceprovider.cpp | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index b9bebc7a53..1a808b433b 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -406,15 +406,13 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist // Match current local atom to one of the requested Metatomic input atoms if (options_.params_.mtaAtoms_->globalIndex()[j] == globalIdx) { - // [DEBUG] Print distribution info - GMX_LOG(logger_.info) - .asParagraph().appendTextFormatted( - "Rank %d: Found ModelAtom %d (Global %d) at Local %d (%s)\n", - mpiComm_.rank(), - j, - globalIdx, - i, - (i < numLocal) ? "HOME" : "HALO"); + std::fprintf(stderr, + "Rank %d: Found ModelAtom %d (Global %d) at Local %d (%s)\n", + mpiComm_.rank(), + j, + globalIdx, + i, + (i < numLocal) ? "HOME" : "HALO"); if (i < numLocal) { inputToLocalIndex_[j] = i; @@ -437,12 +435,8 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist localCount++; } } - GMX_LOG(logger_.info) - .asParagraph() - .appendTextFormatted("Rank %d: Mapped %d / %d Metatomic atoms (Home+Halo).", - mpiComm_.rank(), - localCount, - numInput); + fprintf(stderr, "Rank %d: Mapped %d / %d Metatomic atoms (Home+Halo).\n", + mpiComm_.rank(), localCount, numInput); } else { From 040e89fa86b896c4f0d95717164b2b0a270af82e Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 4 Feb 2026 17:40:40 +0100 Subject: [PATCH 06/19] chore(mtagro): start checking signal pairs --- .../applied_forces/metatomic/metatomic_forceprovider.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 1a808b433b..6b9ddf3ba4 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -548,6 +548,9 @@ void MetatomicForceProvider::preparePairlistInput() pairlistForModel_.push_back(static_cast(inputIdxA.value())); pairlistForModel_.push_back(static_cast(inputIdxB.value())); + std::fprintf(stderr, "Rank %d: Signal pair (Local %d, %d) -> Model (%ld, %ld)\n", + mpiComm_.rank(), atomPair.first, atomPair.second, + inputIdxA.value(), inputIdxB.value()); shiftVectors_.push_back(shift); cellShifts_.push_back(unitShift); } From a0488fd31cd3fb53faef8768b1ed1484ff6306bf Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Wed, 4 Feb 2026 17:40:51 +0100 Subject: [PATCH 07/19] feat(mtagro): try augmenting pairs --- .../metatomic/metatomic_forceprovider.cpp | 120 ++++++++++++++---- .../metatomic/metatomic_forceprovider.h | 3 +- 2 files changed, 96 insertions(+), 27 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 6b9ddf3ba4..35f41325ed 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -53,6 +53,8 @@ #include "gromacs/mdtypes/enerdata.h" #include "gromacs/mdtypes/forceoutput.h" #include "gromacs/pbcutil/ishift.h" +#include "gromacs/pbcutil/pbc.h" +#include "gromacs/selection/nbsearch.h" #include "gromacs/utility/arrayref.h" #include "gromacs/utility/exceptions.h" #include "gromacs/utility/logger.h" @@ -91,24 +93,6 @@ static torch::optional normalize_variant(std::string variant_string } } -/*! \brief Finds the index of a value in a vector. - * - * Performs a linear search to locate a specific value within a vector. - * - * \param[in] vec The vector to search. - * \param[in] val The value to find. - * \return The index of the value if found, otherwise std::nullopt. - */ -static std::optional indexOf(ArrayRef vec, const int32_t val) -{ - auto it = std::find(vec.begin(), vec.end(), val); - if (it == vec.end()) - { - return std::nullopt; - } - return std::distance(vec.begin(), it); -} - /*! \brief Converts GROMACS PbcType to a boolean tensor for Metatomic. * * \param[in] pbcType The GROMACS periodic boundary condition type. @@ -389,6 +373,7 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist inputToLocalIndex_.assign(numInput, -1); inputToGlobalIndex_.assign(numInput, -1); atomNumbers_.assign(numInput, 0); + localToModelIndex_.assign(signal.x_.size(), -1); // GROMACS domain decomposition logic if (mpiComm_.isParallel()) @@ -406,6 +391,7 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist // Match current local atom to one of the requested Metatomic input atoms if (options_.params_.mtaAtoms_->globalIndex()[j] == globalIdx) { + localToModelIndex_[i] = j; std::fprintf(stderr, "Rank %d: Found ModelAtom %d (Global %d) at Local %d (%s)\n", mpiComm_.rank(), @@ -532,13 +518,13 @@ void MetatomicForceProvider::preparePairlistInput() // Map these local indices back to the model's input indices [0, N_model_atoms). // `inputToLocalIndex_` maps ModelIdx -> LocalIdx. // indexOf reverses the map: Find ModelIdx k such that inputToLocalIndex_[k] == LocalIdx. - auto inputIdxA = indexOf(inputToLocalIndex_, atomPair.first); + const int32_t inputIdxA = localToModelIndex_[atomPair.first]; - if (inputIdxA.has_value()) + if (inputIdxA != -1) { - auto inputIdxB = indexOf(inputToLocalIndex_, atomPair.second); + const int32_t inputIdxB = localToModelIndex_[atomPair.second]; - if (inputIdxB.has_value()) + if (inputIdxB != -1) { // Both atoms belong to the Metatomic subsystem. // Calculate the shift vector due to PBC. @@ -546,11 +532,11 @@ void MetatomicForceProvider::preparePairlistInput() const IVec unitShift = shiftIndexToXYZ(shiftIndex); mvmul_ur0(box_, unitShift.toRVec(), shift); - pairlistForModel_.push_back(static_cast(inputIdxA.value())); - pairlistForModel_.push_back(static_cast(inputIdxB.value())); - std::fprintf(stderr, "Rank %d: Signal pair (Local %d, %d) -> Model (%ld, %ld)\n", + pairlistForModel_.push_back(inputIdxA); + pairlistForModel_.push_back(inputIdxB); + std::fprintf(stderr, "Rank %d: Signal pair (Local %d, %d) -> Model (%d, %d)\n", mpiComm_.rank(), atomPair.first, atomPair.second, - inputIdxA.value(), inputIdxB.value()); + inputIdxA, inputIdxB); shiftVectors_.push_back(shift); cellShifts_.push_back(unitShift); } @@ -562,6 +548,87 @@ void MetatomicForceProvider::preparePairlistInput() doPairlist_ = false; } +void MetatomicForceProvider::augmentGhostPairs(const ArrayRef x, const matrix box) +{ + const int32_t nHome = options_.params_.mtaAtoms_->localIndex().size(); + const int32_t nTotal = x.size(); + + if (nTotal <= nHome) + { + return; + } + + t_pbc pbc; + set_pbc(&pbc, *options_.params_.pbcType_, box); + + const auto ghostCoords = x.subArray(nHome, nTotal - nHome); + + gmx::AnalysisNeighborhood nb; + nb.setCutoff(data_->nl_requests[0]->cutoff()); + + gmx::AnalysisNeighborhoodPositions ghostPositions(as_rvec_array(ghostCoords.data()), + ghostCoords.size()); + + gmx::AnalysisNeighborhoodSearch search = nb.initSearch(&pbc, ghostPositions); + gmx::AnalysisNeighborhoodPairSearch ghostSearch = search.startSelfPairSearch(); + gmx::AnalysisNeighborhoodPair pair; + + while (ghostSearch.findNextPair(&pair)) + { + const int32_t localIdxA = pair.refIndex() + nHome; + const int32_t localIdxB = pair.testIndex() + nHome; + + const int32_t inputIdxA = localToModelIndex_[localIdxA]; + const int32_t inputIdxB = localToModelIndex_[localIdxB]; + + if (inputIdxA != -1 && inputIdxB != -1) + { + rvec rij_raw, shift; + rvec_sub(x[localIdxA], x[localIdxB], rij_raw); + + // PBC shift calculation: S = r_ij_corrected - (x_j - x_i) + // XXX: there's got to be a better way........ + rvec_sub(pair.dx(), rij_raw, shift); + + // Explicit 3x3 inversion for box matrix to find integer shifts + double det = box[0][0] * (box[1][1] * box[2][2] - box[1][2] * box[2][1]) - + box[0][1] * (box[1][0] * box[2][2] - box[1][2] * box[2][0]) + + box[0][2] * (box[1][0] * box[2][1] - box[1][1] * box[2][0]); + + double invDet = 1.0 / det; + rvec unitShiftRvec; + unitShiftRvec[0] = invDet * (shift[0] * (box[1][1] * box[2][2] - box[1][2] * box[2][1]) + + shift[1] * (box[0][2] * box[2][1] - box[0][1] * box[2][2]) + + shift[2] * (box[0][1] * box[1][2] - box[0][2] * box[1][1])); + unitShiftRvec[1] = invDet * (shift[0] * (box[1][2] * box[2][0] - box[1][0] * box[2][2]) + + shift[1] * (box[0][0] * box[2][2] - box[0][2] * box[2][0]) + + shift[2] * (box[0][2] * box[1][0] - box[0][0] * box[1][2])); + unitShiftRvec[2] = invDet * (shift[0] * (box[1][0] * box[2][1] - box[1][1] * box[2][0]) + + shift[1] * (box[0][1] * box[2][0] - box[0][0] * box[2][1]) + + shift[2] * (box[0][0] * box[1][1] - box[0][1] * box[1][0])); + + IVec unitShift; + unitShift[0] = static_cast(std::round(unitShiftRvec[0])); + unitShift[1] = static_cast(std::round(unitShiftRvec[1])); + unitShift[2] = static_cast(std::round(unitShiftRvec[2])); + + pairlistForModel_.push_back(inputIdxA); + pairlistForModel_.push_back(inputIdxB); + shiftVectors_.push_back(RVec(shift)); + cellShifts_.push_back(unitShift); + + std::fprintf(stderr, + "[Augmented] Rank %d: Halo pair (Local %d, %d) -> Model (%d, %d)\n", + mpiComm_.rank(), + localIdxA, + localIdxB, + inputIdxA, + inputIdxB); + } + } +} + + void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, ForceProviderOutput* outputs) { const int32_t n_atoms = static_cast(options_.params_.mtaIndices_.size()); @@ -570,6 +637,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F gatherAtomPositions(inputs.x_); copy_mat(inputs.box_, box_); preparePairlistInput(); + augmentGhostPairs(inputs.x_, inputs.box_); // Force tensor - main rank fills this, others hold zeros until reduction torch::Tensor forceTensor = torch::zeros( diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h index c26ce7b187..58dd985528 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h @@ -85,6 +85,7 @@ class MetatomicForceProvider final : public IForceProvider //! Set pairlist from notification and filter to MTA atom pairs. void setPairlist(const MDModulesPairlistConstructedSignal& signal); + void augmentGhostPairs(const ArrayRef x, const matrix box); private: //! Gather atom positions for MTA input. @@ -105,7 +106,7 @@ class MetatomicForceProvider final : public IForceProvider //! lookup table to map model input indices [0...numInput) to local atom indices std::vector inputToLocalIndex_; - + std::vector localToModelIndex_; //! lookup table to map model input indices to global atom indices std::vector inputToGlobalIndex_; From 8c774a1dabe0468e2eae7a49cdbae190d22e5833 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Sat, 14 Feb 2026 23:54:01 +0100 Subject: [PATCH 08/19] chore(mtatimer): initialize --- .../metatomic/metatomic_timer.h | 141 ++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 src/gromacs/applied_forces/metatomic/metatomic_timer.h diff --git a/src/gromacs/applied_forces/metatomic/metatomic_timer.h b/src/gromacs/applied_forces/metatomic/metatomic_timer.h new file mode 100644 index 0000000000..370efcfeb6 --- /dev/null +++ b/src/gromacs/applied_forces/metatomic/metatomic_timer.h @@ -0,0 +1,141 @@ +/* + * This file is part of the GROMACS molecular simulation package. + * + * Copyright 2024- The GROMACS Authors + * and the project initiators Erik Lindahl, Berk Hess and David van der Spoel. + * Consult the AUTHORS/COPYING files and https://www.gromacs.org for details. + * + * GROMACS is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public License + * as published by the Free Software Foundation; either version 2.1 + * of the License, or (at your option) any later version. + * + * GROMACS is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with GROMACS; if not, see + * https://www.gnu.org/licenses, or write to the Free Software Foundation, + * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + * + * If you want to redistribute modifications to GROMACS, please + * consider that scientific software is very special. Version + * control is crucial - bugs must be traceable. We will be happy to + * consider code for inclusion in the official distribution, but + * derived work must not be called official GROMACS. Details are found + * in the README & COPYING files - if they are missing, get the + * official version at https://www.gromacs.org. + * + * To help us fund GROMACS development, we humbly ask that you cite + * the research papers on the package. Check out https://www.gromacs.org. + */ +/*! \internal \file + * \brief + * Scoped timer for Metatomic force provider profiling. + * + * RAII timer that prints nested timing information to stderr with MPI rank. + * Enable with MetatomicTimer::enable(true) before use. + * + * \author Metatensor developers + * \ingroup module_applied_forces + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "gromacs/utility/mpicomm.h" + +namespace gmx +{ + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static std::mutex METATOMIC_TIMER_MUTEX = {}; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static int64_t METATOMIC_TIMER_DEPTH = -1; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static uint64_t METATOMIC_TIMER_COUNTER = 0; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +static bool METATOMIC_TIMER_ENABLED = false; + +/*! \internal \brief RAII scoped timer for Metatomic profiling. + * + * Prints hierarchical timing info to stderr. Timers nest automatically. + * Thread-safe via a global mutex. + */ +class MetatomicTimer +{ +public: + //! Enable or disable all timers globally. + static void enable(bool toggle) + { + auto guard_ = std::lock_guard(METATOMIC_TIMER_MUTEX); + METATOMIC_TIMER_ENABLED = toggle; + } + + //! Construct a timer with the given label. Starts timing if enabled. + MetatomicTimer(std::string name, const MpiComm& mpiComm) : + enabled_(false), name_(std::move(name)), mpiComm_(mpiComm) + { + auto guard_ = std::lock_guard(METATOMIC_TIMER_MUTEX); + if (METATOMIC_TIMER_ENABLED) + { + METATOMIC_TIMER_DEPTH += 1; + METATOMIC_TIMER_COUNTER += 1; + + this->enabled_ = true; + this->starting_counter_ = METATOMIC_TIMER_COUNTER; + this->start_ = std::chrono::high_resolution_clock::now(); + auto indent = std::string(METATOMIC_TIMER_DEPTH * 3, ' '); + + if (METATOMIC_TIMER_DEPTH == 0) + { + std::cerr << "\n"; + } + std::cerr << "\n" << indent << this->name_ << " ..."; + } + } + + ~MetatomicTimer() + { + auto guard_ = std::lock_guard(METATOMIC_TIMER_MUTEX); + + if (METATOMIC_TIMER_ENABLED && this->enabled_) + { + auto stop = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast(stop - start_).count(); + + if (METATOMIC_TIMER_COUNTER != starting_counter_) + { + auto indent = std::string(METATOMIC_TIMER_DEPTH * 3, ' '); + std::cerr << "\n" << indent << this->name_; + } + + std::cerr << " took " << elapsed / 1e6 << "ms (rank " << mpiComm_.rank() << ")" + << std::flush; + METATOMIC_TIMER_DEPTH -= 1; + } + } + + // Non-copyable, non-movable + MetatomicTimer(const MetatomicTimer&) = delete; + MetatomicTimer& operator=(const MetatomicTimer&) = delete; + MetatomicTimer(MetatomicTimer&&) = delete; + MetatomicTimer& operator=(MetatomicTimer&&) = delete; + +private: + bool enabled_; + std::string name_; + const MpiComm& mpiComm_; + uint64_t starting_counter_ = 0; + std::chrono::high_resolution_clock::time_point start_; +}; + +} // namespace gmx From 120bf2aa7d92e4a94a66cec674960bfe025cef9d Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Sat, 14 Feb 2026 23:54:11 +0100 Subject: [PATCH 09/19] chore(localtopo): ensure pairlist assertions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. The bonded interaction building (make_bondeds_zone) runs for all zones but won't find any cross-zone bondeds when !hasInterAtomicInteractions() — it's a no-op for the extra zones 2. The exclusion building (make_exclusions_zone) correctly builds exclusion entries for all i-zone atoms, satisfying the pairlist assertion Added nzone_bondeds = std::max(nzone_bondeds, numIZonesForExclusions) to ensure the exclusion building loop covers all i-zones when intermolecularExclusionGroup is present. Without this, 3D DD (e.g., 2x2x2 with 8 ranks) has numIZones=4 but nzone_bondeds=1, so exclusion lists are only built for zone 0 atoms while the nbnxm assertion expects them for zones 0-3. --- src/gromacs/domdec/localtopology.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/gromacs/domdec/localtopology.cpp b/src/gromacs/domdec/localtopology.cpp index 85d5ee547d..f1ca846fd7 100644 --- a/src/gromacs/domdec/localtopology.cpp +++ b/src/gromacs/domdec/localtopology.cpp @@ -846,6 +846,14 @@ static int make_local_bondeds_excls(const gmx_domdec_t& dd, /* We only use exclusions from i-zones to i- and j-zones */ const int numIZonesForExclusions = (dd.haveExclusions ? zones.numIZones() : 0); + /* When intermolecular exclusions are present (e.g. from embedded/ML potentials) + * but there are no inter-atomic bonded interactions spanning zones, the outer loop + * must still cover all i-zones so that exclusion lists are built for all i-zone atoms. + * Without this, the exclusion list would only cover zone 0 atoms while the nbnxm + * pairlist construction expects exclusions for all i-zone atoms. + */ + nzone_bondeds = std::max(nzone_bondeds, numIZonesForExclusions); + const gmx_reverse_top_t& rt = *dd.reverse_top; const real cutoffSquared = gmx::square(cutoff); From d0d49de7cd0f1b71132c9778b3e1039abaa92dfa Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Sat, 14 Feb 2026 23:55:25 +0100 Subject: [PATCH 10/19] feat(mta): rework for mpi with timer.. 1. localToModelIndex_ sized to numLocalPlusHalo instead of signal.x_.size() (was OOB write) 2. augmentGhostPairs rewritten to correctly identify halo MTA atoms by iterating localToModelIndex_ from numLocalAtoms_ onward, instead of incorrectly slicing the full coordinate array 3. Shift vector computed as pair.dx() - (positions_[B] - positions_[A]) in model-space, then rounded to integer cell shifts and recomputed from box vectors for consistency 4. Deduplication of pairs using std::set to handle overlap between signal pairs and augmented halo-halo pairs 5. Timer instrumentation via MetatomicTimer RAII class around key phases --- .../metatomic/metatomic_forceprovider.cpp | 404 ++++++++++++------ .../metatomic/metatomic_forceprovider.h | 4 + 2 files changed, 282 insertions(+), 126 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 35f41325ed..f3549eda02 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -42,10 +42,13 @@ #include "metatomic_forceprovider.h" +#include #include #include #include +#include +#include #include "gromacs/domdec/localatomset.h" #include "gromacs/mdlib/broadcaststructs.h" @@ -61,6 +64,8 @@ #include "gromacs/utility/mpicomm.h" #include "gromacs/utility/stringutil.h" +#include "metatomic_timer.h" + #ifdef DIM # undef DIM #endif @@ -231,6 +236,12 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, { GMX_LOG(logger_.info).asParagraph().appendText("Initializing MetatomicForceProvider..."); + // Enable profiling via environment variable GMX_METATOMIC_TIMER=1 + if (const char* timerEnv = std::getenv("GMX_METATOMIC_TIMER")) + { + MetatomicTimer::enable(std::string(timerEnv) == "1"); + } + // Pairlist-based neighbor lists don't work with domain decomposition yet (indices are local) // Matches NNPot's limitation if (mpiComm_.isParallel()) @@ -241,96 +252,94 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, .appendText( "Metatomic support domain decomposition is EXPERIMENTAL (MPI). " - "Please use thread-MPI (gmx mdrun -ntmpi X) instead of MPI " - "(mpirun -np X gmx_mpi mdrun)."); + "Please use thread-MPI (gmx mdrun -ntmpi X) instead of MPI " + "(mpirun -np X gmx_mpi mdrun)."); } - try - { - torch::optional extensions_directory = torch::nullopt; - if (!options_.params_.extensionsDirectory.empty()) - { - extensions_directory = options_.params_.extensionsDirectory; - } - - data_->model = metatomic_torch::load_atomistic_model(options_.params_.modelPath_, - extensions_directory); - } - catch (const std::exception& e) + try + { + torch::optional extensions_directory = torch::nullopt; + if (!options_.params_.extensionsDirectory.empty()) { - GMX_THROW(APIError("Failed to load metatomic model: " + std::string(e.what()))); + extensions_directory = options_.params_.extensionsDirectory; } - data_->capabilities = - data_->model.run_method("capabilities").toCustomClass(); + data_->model = metatomic_torch::load_atomistic_model(options_.params_.modelPath_, + extensions_directory); + } + catch (const std::exception& e) + { + GMX_THROW(APIError("Failed to load metatomic model: " + std::string(e.what()))); + } - // Determine computation device - torch::optional desiredDevice = torch::nullopt; - if (const char* env = std::getenv("GMX_METATOMIC_DEVICE")) - { - desiredDevice = std::string(env); - } + data_->capabilities = + data_->model.run_method("capabilities").toCustomClass(); - const auto deviceType = - metatomic_torch::pick_device(data_->capabilities->supported_devices, desiredDevice); - data_->device = torch::Device(deviceType); + // Determine computation device + torch::optional desiredDevice = torch::nullopt; + if (const char* env = std::getenv("GMX_METATOMIC_DEVICE")) + { + desiredDevice = std::string(env); + } - GMX_LOG(logger_.info) - .asParagraph() - .appendTextFormatted("Metatomic using device: %s", data_->device.str().c_str()); + const auto deviceType = + metatomic_torch::pick_device(data_->capabilities->supported_devices, desiredDevice); + data_->device = torch::Device(deviceType); - // Process neighbor list requests from the model - auto requests_ivalue = data_->model.run_method("requested_neighbor_lists"); - for (const auto& request_ivalue : requests_ivalue.toList()) - { - data_->nl_requests.push_back( - request_ivalue.get().toCustomClass()); - } + GMX_LOG(logger_.info) + .asParagraph() + .appendTextFormatted("Metatomic using device: %s", data_->device.str().c_str()); - data_->model.to(data_->device); + // Process neighbor list requests from the model + auto requests_ivalue = data_->model.run_method("requested_neighbor_lists"); + for (const auto& request_ivalue : requests_ivalue.toList()) + { + data_->nl_requests.push_back( + request_ivalue.get().toCustomClass()); + } - // Configure precision - if (data_->capabilities->dtype() == "float64") - { - data_->dtype = torch::kFloat64; - } - else if (data_->capabilities->dtype() == "float32") - { - data_->dtype = torch::kFloat32; - } - else - { - GMX_THROW(APIError("Unsupported dtype from model capabilities: " - + data_->capabilities->dtype())); - } + data_->model.to(data_->device); + + // Configure precision + if (data_->capabilities->dtype() == "float64") + { + data_->dtype = torch::kFloat64; + } + else if (data_->capabilities->dtype() == "float32") + { + data_->dtype = torch::kFloat32; + } + else + { + GMX_THROW(APIError("Unsupported dtype from model capabilities: " + data_->capabilities->dtype())); + } - data_->evaluations_options = - torch::make_intrusive(); - data_->evaluations_options->set_length_unit("nm"); + data_->evaluations_options = torch::make_intrusive(); + data_->evaluations_options->set_length_unit("nm"); - // Validate energy output existence - auto outputs = data_->capabilities->outputs(); - auto v_energy = normalize_variant(options_.params_.variant); - auto energy_key = pick_output("energy", outputs, v_energy); + // Validate energy output existence + auto outputs = data_->capabilities->outputs(); + auto v_energy = normalize_variant(options_.params_.variant); + auto energy_key = pick_output("energy", outputs, v_energy); - if (!outputs.contains(energy_key)) - { - GMX_THROW( - APIError(formatString("The model at '%s' does not provide an '%s' output. " - "Metatomic interface cannot proceed.", - options_.params_.modelPath_.c_str(), - energy_key.c_str()))); - } + if (!outputs.contains(energy_key)) + { + GMX_THROW( + APIError(formatString("The model at '%s' does not provide an '%s' output. " + "Metatomic interface cannot proceed.", + options_.params_.modelPath_.c_str(), + energy_key.c_str()))); + } + + auto requested_output = torch::make_intrusive(); + // TODO: take from the user + requested_output->per_atom = false; + requested_output->explicit_gradients = {}; + requested_output->set_unit("kJ/mol"); - auto requested_output = torch::make_intrusive(); - // TODO: take from the user - requested_output->per_atom = false; - requested_output->explicit_gradients = {}; - requested_output->set_unit("kJ/mol"); + data_->evaluations_options->outputs.insert(energy_key, requested_output); + data_->check_consistency = options_.params_.checkConsistency; - data_->evaluations_options->outputs.insert(energy_key, requested_output); - data_->check_consistency = options_.params_.checkConsistency; - // Initialize vectors for atom mapping const auto& mtaIndices = options_.params_.mtaIndices_; @@ -373,7 +382,6 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist inputToLocalIndex_.assign(numInput, -1); inputToGlobalIndex_.assign(numInput, -1); atomNumbers_.assign(numInput, 0); - localToModelIndex_.assign(signal.x_.size(), -1); // GROMACS domain decomposition logic if (mpiComm_.isParallel()) @@ -382,8 +390,13 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist "Global atom indices required for domain decomposition."); auto globalAtomIndices = signal.globalAtomIndices_.value(); const int32_t numLocal = signal.x_.size(); + const int32_t numLocalPlusHalo = globalAtomIndices.size(); + + // Size to include both home and halo atoms + localToModelIndex_.assign(numLocalPlusHalo, -1); + numLocalAtoms_ = numLocal; - for (int32_t i = 0; i < static_cast(globalAtomIndices.size()); i++) + for (int32_t i = 0; i < numLocalPlusHalo; i++) { int32_t globalIdx = globalAtomIndices[i]; for (int32_t j = 0; j < numInput; j++) @@ -413,21 +426,36 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist mpiComm_.sumReduce(numInput, atomNumbers_.data()); // Debug logging for domain decomposition distribution - int32_t localCount = 0; - for (const int32_t idx : inputToLocalIndex_) + int32_t homeCount = 0; + int32_t haloCount = 0; + for (int32_t i = 0; i < numLocalPlusHalo; i++) { - if (idx != -1) + if (localToModelIndex_[i] != -1) { - localCount++; + if (i < numLocal) + { + homeCount++; + } + else + { + haloCount++; + } } } - fprintf(stderr, "Rank %d: Mapped %d / %d Metatomic atoms (Home+Halo).\n", - mpiComm_.rank(), localCount, numInput); + std::fprintf(stderr, + "Rank %d: Mapped %d HOME + %d HALO = %d / %d Metatomic atoms.\n", + mpiComm_.rank(), + homeCount, + haloCount, + homeCount + haloCount, + numInput); } else { // Thread-MPI or Serial execution const auto* mtaAtoms = options_.params_.mtaAtoms_.get(); + localToModelIndex_.clear(); + numLocalAtoms_ = 0; for (int32_t i = 0; i < numInput; i++) { int32_t localIndex = mtaAtoms->localIndex()[i]; @@ -534,9 +562,13 @@ void MetatomicForceProvider::preparePairlistInput() pairlistForModel_.push_back(inputIdxA); pairlistForModel_.push_back(inputIdxB); - std::fprintf(stderr, "Rank %d: Signal pair (Local %d, %d) -> Model (%d, %d)\n", - mpiComm_.rank(), atomPair.first, atomPair.second, - inputIdxA, inputIdxB); + std::fprintf(stderr, + "Rank %d: Signal pair (Local %d, %d) -> Model (%d, %d)\n", + mpiComm_.rank(), + atomPair.first, + atomPair.second, + inputIdxA, + inputIdxB); shiftVectors_.push_back(shift); cellShifts_.push_back(unitShift); } @@ -550,10 +582,32 @@ void MetatomicForceProvider::preparePairlistInput() void MetatomicForceProvider::augmentGhostPairs(const ArrayRef x, const matrix box) { - const int32_t nHome = options_.params_.mtaAtoms_->localIndex().size(); - const int32_t nTotal = x.size(); + if (!mpiComm_.isParallel()) + { + return; + } + + // Identify halo MTA atoms: atoms in localToModelIndex_ that have a valid model index + // but are NOT home atoms on this rank (i.e. inputToLocalIndex_[modelIdx] == -1). + // These are atoms in the halo zone (local index >= numLocalAtoms_). + std::vector haloLocalIndices; + std::vector haloCoords; + + for (int32_t i = numLocalAtoms_; i < static_cast(localToModelIndex_.size()); i++) + { + if (localToModelIndex_[i] != -1) + { + haloLocalIndices.push_back(i); + haloCoords.push_back(x[i]); + } + } + + std::fprintf(stderr, + "Rank %d: augmentGhostPairs found %zu halo MTA atoms\n", + mpiComm_.rank(), + haloCoords.size()); - if (nTotal <= nHome) + if (haloCoords.size() < 2) { return; } @@ -561,83 +615,178 @@ void MetatomicForceProvider::augmentGhostPairs(const ArrayRef x, con t_pbc pbc; set_pbc(&pbc, *options_.params_.pbcType_, box); - const auto ghostCoords = x.subArray(nHome, nTotal - nHome); - gmx::AnalysisNeighborhood nb; nb.setCutoff(data_->nl_requests[0]->cutoff()); - gmx::AnalysisNeighborhoodPositions ghostPositions(as_rvec_array(ghostCoords.data()), - ghostCoords.size()); + gmx::AnalysisNeighborhoodPositions ghostPositions(as_rvec_array(haloCoords.data()), haloCoords.size()); gmx::AnalysisNeighborhoodSearch search = nb.initSearch(&pbc, ghostPositions); gmx::AnalysisNeighborhoodPairSearch ghostSearch = search.startSelfPairSearch(); gmx::AnalysisNeighborhoodPair pair; + int32_t augmentedCount = 0; while (ghostSearch.findNextPair(&pair)) { - const int32_t localIdxA = pair.refIndex() + nHome; - const int32_t localIdxB = pair.testIndex() + nHome; + const int32_t localIdxA = haloLocalIndices[pair.refIndex()]; + const int32_t localIdxB = haloLocalIndices[pair.testIndex()]; const int32_t inputIdxA = localToModelIndex_[localIdxA]; const int32_t inputIdxB = localToModelIndex_[localIdxB]; + // Both should be valid since we pre-filtered, but guard anyway if (inputIdxA != -1 && inputIdxB != -1) { - rvec rij_raw, shift; - rvec_sub(x[localIdxA], x[localIdxB], rij_raw); - - // PBC shift calculation: S = r_ij_corrected - (x_j - x_i) - // XXX: there's got to be a better way........ - rvec_sub(pair.dx(), rij_raw, shift); - - // Explicit 3x3 inversion for box matrix to find integer shifts - double det = box[0][0] * (box[1][1] * box[2][2] - box[1][2] * box[2][1]) - - box[0][1] * (box[1][0] * box[2][2] - box[1][2] * box[2][0]) + - box[0][2] * (box[1][0] * box[2][1] - box[1][1] * box[2][0]); - - double invDet = 1.0 / det; - rvec unitShiftRvec; - unitShiftRvec[0] = invDet * (shift[0] * (box[1][1] * box[2][2] - box[1][2] * box[2][1]) + - shift[1] * (box[0][2] * box[2][1] - box[0][1] * box[2][2]) + - shift[2] * (box[0][1] * box[1][2] - box[0][2] * box[1][1])); - unitShiftRvec[1] = invDet * (shift[0] * (box[1][2] * box[2][0] - box[1][0] * box[2][2]) + - shift[1] * (box[0][0] * box[2][2] - box[0][2] * box[2][0]) + - shift[2] * (box[0][2] * box[1][0] - box[0][0] * box[1][2])); - unitShiftRvec[2] = invDet * (shift[0] * (box[1][0] * box[2][1] - box[1][1] * box[2][0]) + - shift[1] * (box[0][1] * box[2][0] - box[0][0] * box[2][1]) + - shift[2] * (box[0][0] * box[1][1] - box[0][1] * box[1][0])); + // pair.dx() returns the PBC-correct vector from ref to test. + // buildNeighborListFromPairlist computes: r_ij = positions_[B] - positions_[A] + shift + // We need: shift = pair.dx() - (positions_[B] - positions_[A]) + rvec modelDiff; + rvec_sub(positions_[inputIdxB].as_vec(), positions_[inputIdxA].as_vec(), modelDiff); + + rvec shift; + rvec_sub(pair.dx(), modelDiff, shift); + + // Compute integer cell shifts via box matrix inversion + double det = box[0][0] * (box[1][1] * box[2][2] - box[1][2] * box[2][1]) + - box[0][1] * (box[1][0] * box[2][2] - box[1][2] * box[2][0]) + + box[0][2] * (box[1][0] * box[2][1] - box[1][1] * box[2][0]); IVec unitShift; - unitShift[0] = static_cast(std::round(unitShiftRvec[0])); - unitShift[1] = static_cast(std::round(unitShiftRvec[1])); - unitShift[2] = static_cast(std::round(unitShiftRvec[2])); + if (std::abs(det) > 1e-10) + { + double invDet = 1.0 / det; + rvec unitShiftRvec; + unitShiftRvec[0] = invDet + * (shift[0] * (box[1][1] * box[2][2] - box[1][2] * box[2][1]) + + shift[1] * (box[0][2] * box[2][1] - box[0][1] * box[2][2]) + + shift[2] * (box[0][1] * box[1][2] - box[0][2] * box[1][1])); + unitShiftRvec[1] = invDet + * (shift[0] * (box[1][2] * box[2][0] - box[1][0] * box[2][2]) + + shift[1] * (box[0][0] * box[2][2] - box[0][2] * box[2][0]) + + shift[2] * (box[0][2] * box[1][0] - box[0][0] * box[1][2])); + unitShiftRvec[2] = invDet + * (shift[0] * (box[1][0] * box[2][1] - box[1][1] * box[2][0]) + + shift[1] * (box[0][1] * box[2][0] - box[0][0] * box[2][1]) + + shift[2] * (box[0][0] * box[1][1] - box[0][1] * box[1][0])); + + unitShift[0] = static_cast(std::round(unitShiftRvec[0])); + unitShift[1] = static_cast(std::round(unitShiftRvec[1])); + unitShift[2] = static_cast(std::round(unitShiftRvec[2])); + } + else + { + unitShift = { 0, 0, 0 }; + } + + // Recompute shift from integer cell shifts for consistency with preparePairlistInput + RVec finalShift; + mvmul_ur0(box, unitShift.toRVec(), finalShift); pairlistForModel_.push_back(inputIdxA); pairlistForModel_.push_back(inputIdxB); - shiftVectors_.push_back(RVec(shift)); + shiftVectors_.push_back(finalShift); cellShifts_.push_back(unitShift); + augmentedCount++; std::fprintf(stderr, - "[Augmented] Rank %d: Halo pair (Local %d, %d) -> Model (%d, %d)\n", + "[Augmented] Rank %d: Halo pair (Local %d, %d) -> Model (%d, %d) " + "shift=(%d,%d,%d)\n", mpiComm_.rank(), localIdxA, localIdxB, inputIdxA, - inputIdxB); + inputIdxB, + unitShift[0], + unitShift[1], + unitShift[2]); } } + + std::fprintf(stderr, "Rank %d: augmentGhostPairs added %d halo-halo pairs\n", mpiComm_.rank(), augmentedCount); } void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, ForceProviderOutput* outputs) { + MetatomicTimer totalTimer("calculateForces", mpiComm_); + const int32_t n_atoms = static_cast(options_.params_.mtaIndices_.size()); // Update positions and box for the current step - gatherAtomPositions(inputs.x_); + { + MetatomicTimer timer("gatherAtomPositions", mpiComm_); + gatherAtomPositions(inputs.x_); + } copy_mat(inputs.box_, box_); - preparePairlistInput(); - augmentGhostPairs(inputs.x_, inputs.box_); + + { + MetatomicTimer timer("preparePairlistInput", mpiComm_); + preparePairlistInput(); + } + + const int32_t signalPairs = static_cast(pairlistForModel_.size() / 2); + + { + MetatomicTimer timer("augmentGhostPairs", mpiComm_); + augmentGhostPairs(inputs.x_, inputs.box_); + } + + const int32_t totalPairsBeforeDedup = static_cast(pairlistForModel_.size() / 2); + + // Deduplicate pairs: the signal may already include some halo-halo pairs + // that augmentGhostPairs also finds. Metatensor requires unique labels. + { + MetatomicTimer timer("deduplicatePairs", mpiComm_); + + using PairKey = std::tuple; + std::set seen; + std::vector dedupPairlist; + std::vector dedupShifts; + std::vector dedupCellShifts; + + const int32_t nPairs = static_cast(pairlistForModel_.size() / 2); + dedupPairlist.reserve(pairlistForModel_.size()); + dedupShifts.reserve(nPairs); + dedupCellShifts.reserve(nPairs); + + for (int32_t i = 0; i < nPairs; i++) + { + int32_t a = pairlistForModel_[2 * i]; + int32_t b = pairlistForModel_[2 * i + 1]; + PairKey key(a, b, cellShifts_[i][0], cellShifts_[i][1], cellShifts_[i][2]); + + if (seen.insert(key).second) + { + dedupPairlist.push_back(a); + dedupPairlist.push_back(b); + dedupShifts.push_back(shiftVectors_[i]); + dedupCellShifts.push_back(cellShifts_[i]); + } + } + + const int32_t removed = nPairs - static_cast(dedupShifts.size()); + if (removed > 0) + { + std::fprintf(stderr, "Rank %d: Removed %d duplicate pairs\n", mpiComm_.rank(), removed); + } + + pairlistForModel_ = std::move(dedupPairlist); + shiftVectors_ = std::move(dedupShifts); + cellShifts_ = std::move(dedupCellShifts); + } + + const int32_t totalPairs = static_cast(pairlistForModel_.size() / 2); + std::fprintf(stderr, + "Rank %d Step %ld: %d signal + %d augmented - %d dupes = %d unique pairs, " + "%d model atoms, homenr=%d, x.size=%zu\n", + mpiComm_.rank(), + inputs.step_, + signalPairs, + totalPairsBeforeDedup - signalPairs, + totalPairsBeforeDedup - totalPairs, + totalPairs, + n_atoms, + inputs.homenr_, + inputs.x_.size()); // Force tensor - main rank fills this, others hold zeros until reduction torch::Tensor forceTensor = torch::zeros( @@ -648,6 +797,8 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F if (mpiComm_.isMainRank()) { + MetatomicTimer modelTimer("model inference (main rank)", mpiComm_); + // Select appropriate precision for GROMACS data conversion auto gromacs_scalar_type = torch::kFloat32; if (std::is_same_v) @@ -730,6 +881,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F // Distribute results to all ranks if necessary (sumReduce broadcasts if ranks > 1) if (mpiComm_.isParallel()) { + MetatomicTimer mpiTimer("MPI force/virial reduction", mpiComm_); mpiComm_.sumReduce(n_atoms * 3, static_cast(forceTensor.data_ptr())); mpiComm_.sumReduce(9, static_cast(virialTensor.data_ptr())); } diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h index 58dd985528..580bda87c7 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h @@ -106,10 +106,14 @@ class MetatomicForceProvider final : public IForceProvider //! lookup table to map model input indices [0...numInput) to local atom indices std::vector inputToLocalIndex_; + //! reverse map: local atom index -> model input index (sized to home+halo) std::vector localToModelIndex_; //! lookup table to map model input indices to global atom indices std::vector inputToGlobalIndex_; + //! Number of home atoms on this rank (from last DD redistribution) + int32_t numLocalAtoms_ = 0; + //! Full pairlist from MDModules notification std::vector fullPairlist_; From 444cefe4e411f113757955d56771905feac615a5 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Sun, 15 Feb 2026 00:19:51 +0100 Subject: [PATCH 11/19] chore(mta): stupidest approach, scatter gather --- .../metatomic/metatomic_forceprovider.cpp | 198 ++++++++++-------- .../metatomic/metatomic_forceprovider.h | 3 + 2 files changed, 110 insertions(+), 91 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index f3549eda02..8f633f1260 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -47,8 +47,6 @@ #include #include -#include -#include #include "gromacs/domdec/localatomset.h" #include "gromacs/mdlib/broadcaststructs.h" @@ -704,100 +702,128 @@ void MetatomicForceProvider::augmentGhostPairs(const ArrayRef x, con std::fprintf(stderr, "Rank %d: augmentGhostPairs added %d halo-halo pairs\n", mpiComm_.rank(), augmentedCount); } - -void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, ForceProviderOutput* outputs) +void MetatomicForceProvider::buildFullPairlist(const matrix box) { - MetatomicTimer totalTimer("calculateForces", mpiComm_); - - const int32_t n_atoms = static_cast(options_.params_.mtaIndices_.size()); + pairlistForModel_.clear(); + shiftVectors_.clear(); + cellShifts_.clear(); - // Update positions and box for the current step + const int32_t n_atoms = static_cast(positions_.size()); + if (n_atoms < 2) { - MetatomicTimer timer("gatherAtomPositions", mpiComm_); - gatherAtomPositions(inputs.x_); + return; } - copy_mat(inputs.box_, box_); - { - MetatomicTimer timer("preparePairlistInput", mpiComm_); - preparePairlistInput(); - } + t_pbc pbc; + set_pbc(&pbc, *options_.params_.pbcType_, box); - const int32_t signalPairs = static_cast(pairlistForModel_.size() / 2); + gmx::AnalysisNeighborhood nb; + nb.setCutoff(data_->nl_requests[0]->cutoff()); - { - MetatomicTimer timer("augmentGhostPairs", mpiComm_); - augmentGhostPairs(inputs.x_, inputs.box_); - } + gmx::AnalysisNeighborhoodPositions allPositions(as_rvec_array(positions_.data()), n_atoms); - const int32_t totalPairsBeforeDedup = static_cast(pairlistForModel_.size() / 2); + gmx::AnalysisNeighborhoodSearch search = nb.initSearch(&pbc, allPositions); + gmx::AnalysisNeighborhoodPairSearch pairSearch = search.startSelfPairSearch(); + gmx::AnalysisNeighborhoodPair pair; - // Deduplicate pairs: the signal may already include some halo-halo pairs - // that augmentGhostPairs also finds. Metatensor requires unique labels. + while (pairSearch.findNextPair(&pair)) { - MetatomicTimer timer("deduplicatePairs", mpiComm_); + const int32_t atomA = pair.refIndex(); + const int32_t atomB = pair.testIndex(); - using PairKey = std::tuple; - std::set seen; - std::vector dedupPairlist; - std::vector dedupShifts; - std::vector dedupCellShifts; + // pair.dx() = PBC-correct vector from ref to test + // buildNeighborListFromPairlist computes: r_ij = positions_[B] - positions_[A] + shift + // So: shift = pair.dx() - (positions_[B] - positions_[A]) + rvec modelDiff; + rvec_sub(positions_[atomB].as_vec(), positions_[atomA].as_vec(), modelDiff); - const int32_t nPairs = static_cast(pairlistForModel_.size() / 2); - dedupPairlist.reserve(pairlistForModel_.size()); - dedupShifts.reserve(nPairs); - dedupCellShifts.reserve(nPairs); + rvec shift; + rvec_sub(pair.dx(), modelDiff, shift); - for (int32_t i = 0; i < nPairs; i++) - { - int32_t a = pairlistForModel_[2 * i]; - int32_t b = pairlistForModel_[2 * i + 1]; - PairKey key(a, b, cellShifts_[i][0], cellShifts_[i][1], cellShifts_[i][2]); + // Compute integer cell shifts via box matrix inversion + double det = box[0][0] * (box[1][1] * box[2][2] - box[1][2] * box[2][1]) + - box[0][1] * (box[1][0] * box[2][2] - box[1][2] * box[2][0]) + + box[0][2] * (box[1][0] * box[2][1] - box[1][1] * box[2][0]); - if (seen.insert(key).second) - { - dedupPairlist.push_back(a); - dedupPairlist.push_back(b); - dedupShifts.push_back(shiftVectors_[i]); - dedupCellShifts.push_back(cellShifts_[i]); - } + IVec unitShift; + if (std::abs(det) > 1e-10) + { + double invDet = 1.0 / det; + rvec unitShiftRvec; + unitShiftRvec[0] = invDet + * (shift[0] * (box[1][1] * box[2][2] - box[1][2] * box[2][1]) + + shift[1] * (box[0][2] * box[2][1] - box[0][1] * box[2][2]) + + shift[2] * (box[0][1] * box[1][2] - box[0][2] * box[1][1])); + unitShiftRvec[1] = invDet + * (shift[0] * (box[1][2] * box[2][0] - box[1][0] * box[2][2]) + + shift[1] * (box[0][0] * box[2][2] - box[0][2] * box[2][0]) + + shift[2] * (box[0][2] * box[1][0] - box[0][0] * box[1][2])); + unitShiftRvec[2] = invDet + * (shift[0] * (box[1][0] * box[2][1] - box[1][1] * box[2][0]) + + shift[1] * (box[0][1] * box[2][0] - box[0][0] * box[2][1]) + + shift[2] * (box[0][0] * box[1][1] - box[0][1] * box[1][0])); + + unitShift[0] = static_cast(std::round(unitShiftRvec[0])); + unitShift[1] = static_cast(std::round(unitShiftRvec[1])); + unitShift[2] = static_cast(std::round(unitShiftRvec[2])); } - - const int32_t removed = nPairs - static_cast(dedupShifts.size()); - if (removed > 0) + else { - std::fprintf(stderr, "Rank %d: Removed %d duplicate pairs\n", mpiComm_.rank(), removed); + unitShift = { 0, 0, 0 }; } - pairlistForModel_ = std::move(dedupPairlist); - shiftVectors_ = std::move(dedupShifts); - cellShifts_ = std::move(dedupCellShifts); + // Recompute shift from integer cell shifts for exact consistency + RVec finalShift; + mvmul_ur0(box, unitShift.toRVec(), finalShift); + + pairlistForModel_.push_back(atomA); + pairlistForModel_.push_back(atomB); + shiftVectors_.push_back(finalShift); + cellShifts_.push_back(unitShift); + } +} + + +void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, ForceProviderOutput* outputs) +{ + MetatomicTimer totalTimer("calculateForces", mpiComm_); + + const int32_t n_atoms = static_cast(options_.params_.mtaIndices_.size()); + + // Update positions and box for the current step + { + MetatomicTimer timer("gatherAtomPositions", mpiComm_); + gatherAtomPositions(inputs.x_); + } + copy_mat(inputs.box_, box_); + + // Build the full neighbor list on the main rank from gathered positions. + // Every rank has all positions after gatherAtomPositions()/sumReduce, so rank 0 + // can find ALL pairs via AnalysisNeighborhoodSearch. This avoids the fundamental + // problem of per-rank signal pairs only covering that rank's local view. + { + MetatomicTimer timer("buildFullPairlist", mpiComm_); + buildFullPairlist(inputs.box_); } const int32_t totalPairs = static_cast(pairlistForModel_.size() / 2); std::fprintf(stderr, - "Rank %d Step %ld: %d signal + %d augmented - %d dupes = %d unique pairs, " - "%d model atoms, homenr=%d, x.size=%zu\n", + "Rank %d Step %ld: %d pairs, %d model atoms, homenr=%d, x.size=%zu\n", mpiComm_.rank(), inputs.step_, - signalPairs, - totalPairsBeforeDedup - signalPairs, - totalPairsBeforeDedup - totalPairs, totalPairs, n_atoms, inputs.homenr_, inputs.x_.size()); - // Force tensor - main rank fills this, others hold zeros until reduction - torch::Tensor forceTensor = torch::zeros( - { n_atoms, 3 }, torch::TensorOptions().dtype(torch::kFloat64).device(data_->device)); + // Every rank runs the model independently with the same gathered positions. + // This avoids MPI force/virial reduction entirely. + torch::Tensor forceTensor; + torch::Tensor virialTensor; + double energy = 0.0; - // Virial tensor for pressure/stress calculations - torch::Tensor virialTensor = torch::zeros({ 3, 3 }, torch::TensorOptions().dtype(torch::kFloat64)); - - if (mpiComm_.isMainRank()) { - MetatomicTimer modelTimer("model inference (main rank)", mpiComm_); + MetatomicTimer modelTimer("model inference", mpiComm_); // Select appropriate precision for GROMACS data conversion auto gromacs_scalar_type = torch::kFloat32; @@ -861,8 +887,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(output_map, 0); auto energy_tensor = energy_block->values(); - outputs->enerd_.term[InteractionFunction::MetatomicPotentialEnergy] = - static_cast(energy_tensor.sum().item()); + energy = energy_tensor.sum().item(); // Reset gradients before backward torch_positions.mutable_grad() = torch::Tensor(); @@ -871,26 +896,15 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F // Backward pass: Compute forces (-dE/dr) and virial (-dE/dStrain) energy_tensor.backward(-torch::ones_like(energy_tensor)); - auto grad = torch_positions.grad(); - forceTensor = grad.to(torch::kCPU).to(torch::kFloat64); - - // Get virial from strain gradient + forceTensor = torch_positions.grad().to(torch::kCPU).to(torch::kFloat64); virialTensor = strain.grad().to(torch::kCPU).to(torch::kFloat64); } - // Distribute results to all ranks if necessary (sumReduce broadcasts if ranks > 1) - if (mpiComm_.isParallel()) - { - MetatomicTimer mpiTimer("MPI force/virial reduction", mpiComm_); - mpiComm_.sumReduce(n_atoms * 3, static_cast(forceTensor.data_ptr())); - mpiComm_.sumReduce(9, static_cast(virialTensor.data_ptr())); - } - - // Accumulate forces into the GROMACS force output + // Accumulate forces into the GROMACS force output. + // Each rank applies forces only to its home atoms (no double-counting). auto forceAccessor = forceTensor.accessor(); for (int32_t i = 0; i < n_atoms; ++i) { - // Only apply force if this atom is local to this rank if (inputToLocalIndex_[i] != -1) { outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][0] += forceAccessor[i][0]; @@ -899,20 +913,22 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F } } - // Apply virial contribution - // GROMACS uses a 3x3 virial tensor in forceWithVirial_ - // Copy the tensor data into a GROMACS matrix and use the public API - matrix virialMatrix; - auto virialAccessor = virialTensor.accessor(); - // TODO: technically this is DIM, not 3... - for (int32_t i = 0; i < 3; ++i) + // Energy and virial: only main rank contributes since GROMACS sums across ranks. + if (mpiComm_.isMainRank() || !mpiComm_.isParallel()) { - for (int32_t j = 0; j < 3; ++j) + outputs->enerd_.term[InteractionFunction::MetatomicPotentialEnergy] = static_cast(energy); + + matrix virialMatrix; + auto virialAccessor = virialTensor.accessor(); + for (int32_t i = 0; i < 3; ++i) { - virialMatrix[i][j] = virialAccessor[i][j]; + for (int32_t j = 0; j < 3; ++j) + { + virialMatrix[i][j] = virialAccessor[i][j]; + } } + outputs->forceWithVirial_.addVirialContribution(virialMatrix); } - outputs->forceWithVirial_.addVirialContribution(virialMatrix); } } // namespace gmx diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h index 580bda87c7..0dd9b430cc 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h @@ -94,6 +94,9 @@ class MetatomicForceProvider final : public IForceProvider //! Prepare pairlist input for model void preparePairlistInput(); + //! Build full neighbor list on main rank from gathered positions + void buildFullPairlist(const matrix box); + const MetatomicOptions& options_; const MDLogger& logger_; const MpiComm& mpiComm_; From b8289f2d05bf2fb0f09722dec8d2ca40985c5c13 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Sun, 15 Feb 2026 08:17:05 +0100 Subject: [PATCH 12/19] feat(mta): fixup for correct parallel impl --- .../metatomic/metatomic_forceprovider.cpp | 802 +++++++----------- .../metatomic/metatomic_forceprovider.h | 66 +- .../metatomic_forceprovider_stub.cpp | 14 +- .../metatomic/metatomic_mdmodule.cpp | 11 +- 4 files changed, 363 insertions(+), 530 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 8f633f1260..9887953394 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -33,7 +33,14 @@ */ /*! \internal \file * \brief - * Implements the Metatomic Force Provider class with proper domain decomposition support. + * Implements the Metatomic Force Provider class with per-rank model evaluation. + * + * Each rank evaluates the model on its local (home + halo) MTA atoms, + * using the GROMACS pairlist (excludedPairlist) as the neighbor list. + * Each pair is assigned to exactly one rank by GROMACS, so summing all + * per-atom energies (home + halo) gives the correct pair contribution + * without double counting. Forces from backward() are combined via MPI + * all-reduce on a global force buffer. * * \author Metatensor developers * \ingroup module_applied_forces @@ -42,11 +49,12 @@ #include "metatomic_forceprovider.h" -#include #include +#include #include -#include +#include +#include #include "gromacs/domdec/localatomset.h" #include "gromacs/mdlib/broadcaststructs.h" @@ -55,7 +63,6 @@ #include "gromacs/mdtypes/forceoutput.h" #include "gromacs/pbcutil/ishift.h" #include "gromacs/pbcutil/pbc.h" -#include "gromacs/selection/nbsearch.h" #include "gromacs/utility/arrayref.h" #include "gromacs/utility/exceptions.h" #include "gromacs/utility/logger.h" @@ -79,11 +86,7 @@ namespace gmx { -/*! \brief Normalizes the variant string for Metatomic output selection. - * - * \param[in] variant_string The raw variant string from options. - * \return A torch::optional containing the string if valid, or nullopt if empty/"no". - */ +/*! \brief Normalizes the variant string for Metatomic output selection. */ static torch::optional normalize_variant(std::string variant_string) { if (variant_string == "no" || variant_string.empty()) @@ -96,12 +99,7 @@ static torch::optional normalize_variant(std::string variant_string } } -/*! \brief Converts GROMACS PbcType to a boolean tensor for Metatomic. - * - * \param[in] pbcType The GROMACS periodic boundary condition type. - * \param[in] device The torch device where the tensor should reside. - * \return A boolean tensor of shape {3} indicating periodicity in X, Y, Z. - */ +/*! \brief Converts GROMACS PbcType to a boolean tensor for Metatomic. */ static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) { auto options = torch::TensorOptions().dtype(torch::kBool).device(device); @@ -121,20 +119,7 @@ static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) return torch::tensor({ true, true, true }, options); } -/*! \brief Constructs a Metatensor TensorBlock representing the neighbor list. - * - * This function takes the filtered pairlist (atoms participating in the model interaction) - * and constructs the corresponding neighbor list in the format required by Metatensor/Torch. - * It computes the interatomic vectors, applying periodic boundary shifts where necessary. - * - * \param[in] pairlist Flat array of atom pairs (indices into the model's atom list). - * \param[in] shiftVectors Geometric shift vectors (RVec) for each pair. - * \param[in] cellShifts Integer cell shift indices for each pair (for metadata). - * \param[in] positions Positions of the atoms (ordered by model index). - * \param[in] device The torch device for the output tensors. - * \param[in] dtype The torch scalar type (float32/float64). - * \return A TensorBlockHolder containing the neighbor list data. - */ +/*! \brief Constructs a Metatensor TensorBlock representing the neighbor list. */ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef pairlist, ArrayRef shiftVectors, ArrayRef cellShifts, @@ -144,15 +129,12 @@ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef(pairlist.size() / 2); - // Prepare CPU tensors first to facilitate efficient element access auto cpu_int_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); auto cpu_float_options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU); - // Samples: [first_atom, second_atom, cell_shift_a, cell_shift_b, cell_shift_c] auto pair_samples_values = torch::zeros({ n_pairs, 5 }, cpu_int_options); auto pair_samples_ptr = pair_samples_values.accessor(); - // Values: Full interatomic vectors (rj - ri + shift) auto vectors_cpu = torch::zeros({ n_pairs, 3, 1 }, cpu_float_options); auto vectors_accessor = vectors_cpu.accessor(); @@ -161,14 +143,12 @@ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef(atom_i); - pair_samples_ptr[i][1] = static_cast(atom_j); + pair_samples_ptr[i][0] = atom_i; + pair_samples_ptr[i][1] = atom_j; pair_samples_ptr[i][2] = cellShifts[i][0]; pair_samples_ptr[i][3] = cellShifts[i][1]; pair_samples_ptr[i][4] = cellShifts[i][2]; - // Calculate r_ij = r_j - r_i + shift const double r_ij_x = static_cast(positions[atom_j][0] - positions[atom_i][0] + shiftVectors[i][0]); const double r_ij_y = @@ -181,7 +161,6 @@ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef extensions_directory = torch::nullopt; @@ -273,7 +233,6 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, data_->capabilities = data_->model.run_method("capabilities").toCustomClass(); - // Determine computation device torch::optional desiredDevice = torch::nullopt; if (const char* env = std::getenv("GMX_METATOMIC_DEVICE")) { @@ -288,17 +247,43 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, .asParagraph() .appendTextFormatted("Metatomic using device: %s", data_->device.str().c_str()); - // Process neighbor list requests from the model + { + double interactionRange = data_->capabilities->engine_interaction_range("nm"); + std::string fname = "metatomic_debug_rank_" + std::to_string(mpiComm_.rank()) + ".log"; + FILE* fp = std::fopen(fname.c_str(), "w"); + if (fp) + { + std::fprintf(fp, + "=== Metatomic init (rank %d) ===\n" + "interaction_range(nm)=%.6f\n", + mpiComm_.rank(), + interactionRange); + std::fclose(fp); + } + } + auto requests_ivalue = data_->model.run_method("requested_neighbor_lists"); for (const auto& request_ivalue : requests_ivalue.toList()) { - data_->nl_requests.push_back( - request_ivalue.get().toCustomClass()); + auto nl_opt = request_ivalue.get().toCustomClass(); + { + std::string fname = "metatomic_debug_rank_" + std::to_string(mpiComm_.rank()) + ".log"; + FILE* fp = std::fopen(fname.c_str(), "a"); + if (fp) + { + std::fprintf(fp, + "NL request: cutoff()=%.6f, engine_cutoff(nm)=%.6f, full_list=%s\n", + nl_opt->cutoff(), + nl_opt->engine_cutoff("nm"), + nl_opt->full_list() ? "true" : "false"); + std::fclose(fp); + } + } + data_->nl_requests.push_back(nl_opt); } data_->model.to(data_->device); - // Configure precision if (data_->capabilities->dtype() == "float64") { data_->dtype = torch::kFloat64; @@ -315,7 +300,6 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, data_->evaluations_options = torch::make_intrusive(); data_->evaluations_options->set_length_unit("nm"); - // Validate energy output existence auto outputs = data_->capabilities->outputs(); auto v_energy = normalize_variant(options_.params_.variant); auto energy_key = pick_output("energy", outputs, v_energy); @@ -330,23 +314,19 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, } auto requested_output = torch::make_intrusive(); - // TODO: take from the user - requested_output->per_atom = false; + // per_atom=true so the model returns per-atom energies (needed for + // correct energy decomposition when using the GROMACS pairlist in DD) + requested_output->per_atom = true; requested_output->explicit_gradients = {}; requested_output->set_unit("kJ/mol"); data_->evaluations_options->outputs.insert(energy_key, requested_output); data_->check_consistency = options_.params_.checkConsistency; - - // Initialize vectors for atom mapping + // Allocate global force buffer sized to total MTA atoms const auto& mtaIndices = options_.params_.mtaIndices_; - const int32_t n_atoms = static_cast(mtaIndices.size()); - - positions_.resize(n_atoms); - atomNumbers_.resize(n_atoms, 0); - inputToLocalIndex_.resize(n_atoms, -1); - inputToGlobalIndex_.resize(n_atoms, -1); + const int32_t n_total = static_cast(mtaIndices.size()); + globalForceBuffer_.resize(n_total, RVec({ 0.0, 0.0, 0.0 })); GMX_LOG(logger_.info) .asParagraph() @@ -355,33 +335,16 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, MetatomicForceProvider::~MetatomicForceProvider() = default; -/*! \brief Updates the mapping between GROMACS local/global atom indices and the Metatomic model's input atoms. - * - * This function is subscribed to the `MDModulesAtomsRedistributedSignal`. It is called whenever - * atoms are redistributed across MPI ranks (Domain Decomposition) or reordered in memory (sorting). - * - * Its primary responsibilities are: - * 1. **Locate Input Atoms**: It iterates through the local atoms on the current rank to find - * which atoms correspond to the "input atoms" defined for the Metatomic model (via `mtaIndices`). - * 2. **Update Maps**: It populates `inputToLocalIndex_` (mapping model input index -> GROMACS local index) - * and `inputToGlobalIndex_` (mapping model input index -> GROMACS global tag). - * 3. **Gather Atomic Numbers**: It ensures `atomNumbers_` (Z numbers) are correctly associated with - * the current local atoms, performing an MPI reduction if necessary to gather data from - * distributed ranks. - * - * \param[in] signal Contains the new mapping of global atom indices to local buffer indices after redistribution. - */ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedistributedSignal& signal) { - const auto& mtaIndices = options_.params_.mtaIndices_; - const int32_t numInput = static_cast(mtaIndices.size()); + const auto& mtaIndices = options_.params_.mtaIndices_; + const int32_t numTotalMta = static_cast(mtaIndices.size()); - // Reset mappings - inputToLocalIndex_.assign(numInput, -1); - inputToGlobalIndex_.assign(numInput, -1); - atomNumbers_.assign(numInput, 0); + mtaToGmxLocal_.clear(); + mtaToGlobalMta_.clear(); + atomNumbers_.clear(); + gmxLocalToMtaIdx_.clear(); - // GROMACS domain decomposition logic if (mpiComm_.isParallel()) { GMX_RELEASE_ASSERT(signal.globalAtomIndices_.has_value(), @@ -390,396 +353,184 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist const int32_t numLocal = signal.x_.size(); const int32_t numLocalPlusHalo = globalAtomIndices.size(); - // Size to include both home and halo atoms - localToModelIndex_.assign(numLocalPlusHalo, -1); - numLocalAtoms_ = numLocal; + // Build a map from global atom index to MTA index for fast lookup + std::unordered_map globalToMtaIdx; + for (int32_t j = 0; j < numTotalMta; j++) + { + globalToMtaIdx[static_cast(mtaIndices[j])] = j; + } + + // Separate home and halo MTA atoms, deduplicating periodic ghosts. + // Each unique MTA atom gets one model index. Periodic ghost images + // are NOT added as separate model atoms, but their GROMACS local + // buffer indices ARE recorded in gmxLocalToMtaIdx_ so that + // setPairlist can resolve pairlist entries referencing any image. + std::vector homeGmxLocal; + std::vector homeGlobalMta; + std::vector haloGmxLocal; + std::vector haloGlobalMta; + + // First pass: assign model indices to unique MTA atoms + std::unordered_map mtaIdxToModelIdx; + int32_t numDuplicatesSkipped = 0; for (int32_t i = 0; i < numLocalPlusHalo; i++) { int32_t globalIdx = globalAtomIndices[i]; - for (int32_t j = 0; j < numInput; j++) + auto it = globalToMtaIdx.find(globalIdx); + if (it != globalToMtaIdx.end()) { - // Match current local atom to one of the requested Metatomic input atoms - if (options_.params_.mtaAtoms_->globalIndex()[j] == globalIdx) + int32_t mtaIdx = it->second; + + if (mtaIdxToModelIdx.count(mtaIdx)) + { + // Periodic ghost: record mapping but don't create new model atom + numDuplicatesSkipped++; + } + else { - localToModelIndex_[i] = j; - std::fprintf(stderr, - "Rank %d: Found ModelAtom %d (Global %d) at Local %d (%s)\n", - mpiComm_.rank(), - j, - globalIdx, - i, - (i < numLocal) ? "HOME" : "HALO"); if (i < numLocal) { - inputToLocalIndex_[j] = i; - inputToGlobalIndex_[j] = globalIdx; - atomNumbers_[j] = options_.params_.atoms_.atom[globalIdx].atomnumber; + // Will be assigned model index = homeGmxLocal.size() (filled later) + homeGmxLocal.push_back(i); + homeGlobalMta.push_back(mtaIdx); + } + else + { + haloGmxLocal.push_back(i); + haloGlobalMta.push_back(mtaIdx); } - break; + // Placeholder: model index will be set after we know numHomeMta_ + mtaIdxToModelIdx[mtaIdx] = -1; } } } - // Reduce atomic numbers across ranks to ensure the main rank has the full set - mpiComm_.sumReduce(numInput, atomNumbers_.data()); - // Debug logging for domain decomposition distribution - int32_t homeCount = 0; - int32_t haloCount = 0; - for (int32_t i = 0; i < numLocalPlusHalo; i++) + // Assign final model indices: home [0, numHome), halo [numHome, numLocal) + int32_t modelIdx = 0; + for (int32_t k = 0; k < static_cast(homeGmxLocal.size()); k++) { - if (localToModelIndex_[i] != -1) - { - if (i < numLocal) - { - homeCount++; - } - else - { - haloCount++; - } - } + mtaIdxToModelIdx[homeGlobalMta[k]] = modelIdx++; } - std::fprintf(stderr, - "Rank %d: Mapped %d HOME + %d HALO = %d / %d Metatomic atoms.\n", - mpiComm_.rank(), - homeCount, - haloCount, - homeCount + haloCount, - numInput); - } - else - { - // Thread-MPI or Serial execution - const auto* mtaAtoms = options_.params_.mtaAtoms_.get(); - localToModelIndex_.clear(); - numLocalAtoms_ = 0; - for (int32_t i = 0; i < numInput; i++) + for (int32_t k = 0; k < static_cast(haloGmxLocal.size()); k++) { - int32_t localIndex = mtaAtoms->localIndex()[i]; - int32_t globalIdx = mtaAtoms->globalIndex()[mtaAtoms->collectiveIndex()[i]]; - inputToLocalIndex_[i] = localIndex; - inputToGlobalIndex_[i] = globalIdx; - atomNumbers_[i] = options_.params_.atoms_.atom[globalIdx].atomnumber; + mtaIdxToModelIdx[haloGlobalMta[k]] = modelIdx++; } - } - - GMX_RELEASE_ASSERT(std::count(atomNumbers_.begin(), atomNumbers_.end(), 0) == 0, - "Some atom numbers not set."); -} - -/*! \brief Updates the internal position buffer with the coordinates of the Metatomic atoms. - * - * This function extracts the coordinates of the atoms relevant to the Metatomic model - * from the full GROMACS local atom array (`pos`). - * - * 1. **Filtering:** `pos` contains all local atoms (solute, solvent, ions). - * This function copies only the atoms defined in `mtaIndices` to `positions_`. - * 2. **Ordering/Packing:** GROMACS reorders atoms dynamically for Domain Decomposition. - * This function uses `inputToLocalIndex_` to collect atoms and pack them into a contiguous, ordered buffer - * - * \param[in] pos The array of all local atom coordinates for the current step. - */ -void MetatomicForceProvider::gatherAtomPositions(ArrayRef pos) -{ - const size_t numInput = inputToLocalIndex_.size(); - positions_.assign(numInput, RVec({ 0.0, 0.0, 0.0 })); - for (size_t i = 0; i < numInput; i++) - { - if (inputToLocalIndex_[i] != -1) + // Second pass: build complete gmxLocal → modelIdx mapping for ALL images + for (int32_t i = 0; i < numLocalPlusHalo; i++) { - positions_[i] = pos[inputToLocalIndex_[i]]; + int32_t globalIdx = globalAtomIndices[i]; + auto it = globalToMtaIdx.find(globalIdx); + if (it != globalToMtaIdx.end()) + { + gmxLocalToMtaIdx_[i] = mtaIdxToModelIdx[it->second]; + } } - } - - if (mpiComm_.isParallel()) - { - mpiComm_.sumReduce(3 * numInput, positions_.data()->as_vec()); - } -} - -void MetatomicForceProvider::setPairlist(const MDModulesPairlistConstructedSignal& signal) -{ - // Capture the pairlist signal. Processing is deferred to - // calculateForces/preparePairlistInput to keep this callback fast. - fullPairlist_.assign(signal.excludedPairlist_.begin(), signal.excludedPairlist_.end()); - doPairlist_ = true; -} - -/*! \brief Converts the GROMACS neighbor list to a model-compatible list. - * - * This function iterates over the full GROMACS excluded pairlist (which contains pairs in - * GROMACS local atom indices). It filters this list to retain only pairs where *both* atoms - * are part of the Metatomic model's input set. - * - * It populates `pairlistForModel_` (using model-relative indices), `shiftVectors_`, - * and `cellShifts_`. - */ -void MetatomicForceProvider::preparePairlistInput() -{ - if (!doPairlist_) - { - return; - } - // Although the assert catches empty pairlists, in a real simulation with a very large cutoff, - // this might happen legitimately if only 1 atom exists. However, for standard MD, it indicates - // an issue. Assert here to catch initialization ordering bugs. - GMX_ASSERT(!fullPairlist_.empty(), "Pairlist for Metatomic is empty!"); - - const int32_t numPairs = gmx::ssize(fullPairlist_); - pairlistForModel_.clear(); - pairlistForModel_.reserve(2 * numPairs); - shiftVectors_.clear(); - shiftVectors_.reserve(numPairs); - cellShifts_.clear(); - cellShifts_.reserve(numPairs); - - for (int32_t i = 0; i < numPairs; i++) - { - const auto [atomPair, shiftIndex] = fullPairlist_[i]; - - // GROMACS pairlists use local atom indices. - // Map these local indices back to the model's input indices [0, N_model_atoms). - // `inputToLocalIndex_` maps ModelIdx -> LocalIdx. - // indexOf reverses the map: Find ModelIdx k such that inputToLocalIndex_[k] == LocalIdx. - const int32_t inputIdxA = localToModelIndex_[atomPair.first]; - - if (inputIdxA != -1) { - const int32_t inputIdxB = localToModelIndex_[atomPair.second]; - - if (inputIdxB != -1) + std::string fname = "metatomic_debug_rank_" + std::to_string(mpiComm_.rank()) + ".log"; + FILE* fp = std::fopen(fname.c_str(), "a"); + if (fp) { - // Both atoms belong to the Metatomic subsystem. - // Calculate the shift vector due to PBC. - RVec shift; - const IVec unitShift = shiftIndexToXYZ(shiftIndex); - mvmul_ur0(box_, unitShift.toRVec(), shift); - - pairlistForModel_.push_back(inputIdxA); - pairlistForModel_.push_back(inputIdxB); - std::fprintf(stderr, - "Rank %d: Signal pair (Local %d, %d) -> Model (%d, %d)\n", - mpiComm_.rank(), - atomPair.first, - atomPair.second, - inputIdxA, - inputIdxB); - shiftVectors_.push_back(shift); - cellShifts_.push_back(unitShift); + std::fprintf(fp, + "gatherAtoms: home=%zu halo=%zu duplicatesSkipped=%d gmxLocalEntries=%zu\n", + homeGmxLocal.size(), + haloGmxLocal.size(), + numDuplicatesSkipped, + gmxLocalToMtaIdx_.size()); + std::fclose(fp); } } - } - - GMX_RELEASE_ASSERT(pairlistForModel_.size() == shiftVectors_.size() * 2, - "Pairlist/shift size mismatch."); - doPairlist_ = false; -} -void MetatomicForceProvider::augmentGhostPairs(const ArrayRef x, const matrix box) -{ - if (!mpiComm_.isParallel()) - { - return; - } + numHomeMta_ = static_cast(homeGmxLocal.size()); + numLocalMta_ = numHomeMta_ + static_cast(haloGmxLocal.size()); - // Identify halo MTA atoms: atoms in localToModelIndex_ that have a valid model index - // but are NOT home atoms on this rank (i.e. inputToLocalIndex_[modelIdx] == -1). - // These are atoms in the halo zone (local index >= numLocalAtoms_). - std::vector haloLocalIndices; - std::vector haloCoords; + // Assign local model indices: home -> [0, numHomeMta_), halo -> [numHomeMta_, numLocalMta_) + mtaToGmxLocal_.resize(numLocalMta_); + mtaToGlobalMta_.resize(numLocalMta_); + atomNumbers_.resize(numLocalMta_); - for (int32_t i = numLocalAtoms_; i < static_cast(localToModelIndex_.size()); i++) - { - if (localToModelIndex_[i] != -1) + for (int32_t k = 0; k < numHomeMta_; k++) { - haloLocalIndices.push_back(i); - haloCoords.push_back(x[i]); + int32_t gmxLocal = homeGmxLocal[k]; + int32_t globalIdx = globalAtomIndices[gmxLocal]; + + mtaToGmxLocal_[k] = gmxLocal; + mtaToGlobalMta_[k] = homeGlobalMta[k]; + atomNumbers_[k] = options_.params_.atoms_.atom[globalIdx].atomnumber; } - } - std::fprintf(stderr, - "Rank %d: augmentGhostPairs found %zu halo MTA atoms\n", - mpiComm_.rank(), - haloCoords.size()); + for (int32_t k = 0; k < static_cast(haloGmxLocal.size()); k++) + { + int32_t modelIdx = numHomeMta_ + k; + int32_t gmxLocal = haloGmxLocal[k]; + int32_t globalIdx = globalAtomIndices[gmxLocal]; - if (haloCoords.size() < 2) - { - return; + mtaToGmxLocal_[modelIdx] = gmxLocal; + mtaToGlobalMta_[modelIdx] = haloGlobalMta[k]; + atomNumbers_[modelIdx] = options_.params_.atoms_.atom[globalIdx].atomnumber; + } } - - t_pbc pbc; - set_pbc(&pbc, *options_.params_.pbcType_, box); - - gmx::AnalysisNeighborhood nb; - nb.setCutoff(data_->nl_requests[0]->cutoff()); - - gmx::AnalysisNeighborhoodPositions ghostPositions(as_rvec_array(haloCoords.data()), haloCoords.size()); - - gmx::AnalysisNeighborhoodSearch search = nb.initSearch(&pbc, ghostPositions); - gmx::AnalysisNeighborhoodPairSearch ghostSearch = search.startSelfPairSearch(); - gmx::AnalysisNeighborhoodPair pair; - - int32_t augmentedCount = 0; - while (ghostSearch.findNextPair(&pair)) + else { - const int32_t localIdxA = haloLocalIndices[pair.refIndex()]; - const int32_t localIdxB = haloLocalIndices[pair.testIndex()]; + // Serial / thread-MPI: all MTA atoms are home, no halos + const auto* mtaAtoms = options_.params_.mtaAtoms_.get(); + numHomeMta_ = numTotalMta; + numLocalMta_ = numTotalMta; - const int32_t inputIdxA = localToModelIndex_[localIdxA]; - const int32_t inputIdxB = localToModelIndex_[localIdxB]; + mtaToGmxLocal_.resize(numTotalMta); + mtaToGlobalMta_.resize(numTotalMta); + atomNumbers_.resize(numTotalMta); - // Both should be valid since we pre-filtered, but guard anyway - if (inputIdxA != -1 && inputIdxB != -1) + for (int32_t i = 0; i < numTotalMta; i++) { - // pair.dx() returns the PBC-correct vector from ref to test. - // buildNeighborListFromPairlist computes: r_ij = positions_[B] - positions_[A] + shift - // We need: shift = pair.dx() - (positions_[B] - positions_[A]) - rvec modelDiff; - rvec_sub(positions_[inputIdxB].as_vec(), positions_[inputIdxA].as_vec(), modelDiff); - - rvec shift; - rvec_sub(pair.dx(), modelDiff, shift); - - // Compute integer cell shifts via box matrix inversion - double det = box[0][0] * (box[1][1] * box[2][2] - box[1][2] * box[2][1]) - - box[0][1] * (box[1][0] * box[2][2] - box[1][2] * box[2][0]) - + box[0][2] * (box[1][0] * box[2][1] - box[1][1] * box[2][0]); - - IVec unitShift; - if (std::abs(det) > 1e-10) - { - double invDet = 1.0 / det; - rvec unitShiftRvec; - unitShiftRvec[0] = invDet - * (shift[0] * (box[1][1] * box[2][2] - box[1][2] * box[2][1]) - + shift[1] * (box[0][2] * box[2][1] - box[0][1] * box[2][2]) - + shift[2] * (box[0][1] * box[1][2] - box[0][2] * box[1][1])); - unitShiftRvec[1] = invDet - * (shift[0] * (box[1][2] * box[2][0] - box[1][0] * box[2][2]) - + shift[1] * (box[0][0] * box[2][2] - box[0][2] * box[2][0]) - + shift[2] * (box[0][2] * box[1][0] - box[0][0] * box[1][2])); - unitShiftRvec[2] = invDet - * (shift[0] * (box[1][0] * box[2][1] - box[1][1] * box[2][0]) - + shift[1] * (box[0][1] * box[2][0] - box[0][0] * box[2][1]) - + shift[2] * (box[0][0] * box[1][1] - box[0][1] * box[1][0])); - - unitShift[0] = static_cast(std::round(unitShiftRvec[0])); - unitShift[1] = static_cast(std::round(unitShiftRvec[1])); - unitShift[2] = static_cast(std::round(unitShiftRvec[2])); - } - else - { - unitShift = { 0, 0, 0 }; - } - - // Recompute shift from integer cell shifts for consistency with preparePairlistInput - RVec finalShift; - mvmul_ur0(box, unitShift.toRVec(), finalShift); - - pairlistForModel_.push_back(inputIdxA); - pairlistForModel_.push_back(inputIdxB); - shiftVectors_.push_back(finalShift); - cellShifts_.push_back(unitShift); - augmentedCount++; + int32_t localIndex = mtaAtoms->localIndex()[i]; + int32_t globalIdx = mtaAtoms->globalIndex()[mtaAtoms->collectiveIndex()[i]]; - std::fprintf(stderr, - "[Augmented] Rank %d: Halo pair (Local %d, %d) -> Model (%d, %d) " - "shift=(%d,%d,%d)\n", - mpiComm_.rank(), - localIdxA, - localIdxB, - inputIdxA, - inputIdxB, - unitShift[0], - unitShift[1], - unitShift[2]); + mtaToGmxLocal_[i] = localIndex; + mtaToGlobalMta_[i] = i; + atomNumbers_[i] = options_.params_.atoms_.atom[globalIdx].atomnumber; + gmxLocalToMtaIdx_[localIndex] = i; } } - std::fprintf(stderr, "Rank %d: augmentGhostPairs added %d halo-halo pairs\n", mpiComm_.rank(), augmentedCount); + GMX_RELEASE_ASSERT(std::count(atomNumbers_.begin(), atomNumbers_.end(), 0) == 0, + "Some atom numbers not set."); } -void MetatomicForceProvider::buildFullPairlist(const matrix box) +void MetatomicForceProvider::gatherAtomPositions(ArrayRef pos) { - pairlistForModel_.clear(); - shiftVectors_.clear(); - cellShifts_.clear(); - - const int32_t n_atoms = static_cast(positions_.size()); - if (n_atoms < 2) + positions_.resize(numLocalMta_); + for (int32_t i = 0; i < numLocalMta_; i++) { - return; + positions_[i] = pos[mtaToGmxLocal_[i]]; } +} - t_pbc pbc; - set_pbc(&pbc, *options_.params_.pbcType_, box); - - gmx::AnalysisNeighborhood nb; - nb.setCutoff(data_->nl_requests[0]->cutoff()); - - gmx::AnalysisNeighborhoodPositions allPositions(as_rvec_array(positions_.data()), n_atoms); - - gmx::AnalysisNeighborhoodSearch search = nb.initSearch(&pbc, allPositions); - gmx::AnalysisNeighborhoodPairSearch pairSearch = search.startSelfPairSearch(); - gmx::AnalysisNeighborhoodPair pair; - - while (pairSearch.findNextPair(&pair)) - { - const int32_t atomA = pair.refIndex(); - const int32_t atomB = pair.testIndex(); - - // pair.dx() = PBC-correct vector from ref to test - // buildNeighborListFromPairlist computes: r_ij = positions_[B] - positions_[A] + shift - // So: shift = pair.dx() - (positions_[B] - positions_[A]) - rvec modelDiff; - rvec_sub(positions_[atomB].as_vec(), positions_[atomA].as_vec(), modelDiff); - - rvec shift; - rvec_sub(pair.dx(), modelDiff, shift); - - // Compute integer cell shifts via box matrix inversion - double det = box[0][0] * (box[1][1] * box[2][2] - box[1][2] * box[2][1]) - - box[0][1] * (box[1][0] * box[2][2] - box[1][2] * box[2][0]) - + box[0][2] * (box[1][0] * box[2][1] - box[1][1] * box[2][0]); - - IVec unitShift; - if (std::abs(det) > 1e-10) - { - double invDet = 1.0 / det; - rvec unitShiftRvec; - unitShiftRvec[0] = invDet - * (shift[0] * (box[1][1] * box[2][2] - box[1][2] * box[2][1]) - + shift[1] * (box[0][2] * box[2][1] - box[0][1] * box[2][2]) - + shift[2] * (box[0][1] * box[1][2] - box[0][2] * box[1][1])); - unitShiftRvec[1] = invDet - * (shift[0] * (box[1][2] * box[2][0] - box[1][0] * box[2][2]) - + shift[1] * (box[0][0] * box[2][2] - box[0][2] * box[2][0]) - + shift[2] * (box[0][2] * box[1][0] - box[0][0] * box[1][2])); - unitShiftRvec[2] = invDet - * (shift[0] * (box[1][0] * box[2][1] - box[1][1] * box[2][0]) - + shift[1] * (box[0][1] * box[2][0] - box[0][0] * box[2][1]) - + shift[2] * (box[0][0] * box[1][1] - box[0][1] * box[1][0])); - - unitShift[0] = static_cast(std::round(unitShiftRvec[0])); - unitShift[1] = static_cast(std::round(unitShiftRvec[1])); - unitShift[2] = static_cast(std::round(unitShiftRvec[2])); - } - else +void MetatomicForceProvider::setPairlist(const MDModulesPairlistConstructedSignal& signal) +{ + pairlistMta_.clear(); + cellShiftsMta_.clear(); + + // Use gmxLocalToMtaIdx_ which maps ALL GROMACS local buffer indices + // (including periodic ghost images) to their MTA model index. + // + // Sign convention: GROMACS shifts atom I (first): d = x[I]+shift - x[J]. + // Metatensor shifts atom J (second): r_ij = x[J]+cell·box - x[I]. + // So metatensor cell shift = -GROMACS cell shift. + for (const auto& entry : signal.excludedPairlist_) + { + const auto& [atomPair, shiftIndex] = entry; + auto itA = gmxLocalToMtaIdx_.find(atomPair.first); + auto itB = gmxLocalToMtaIdx_.find(atomPair.second); + if (itA != gmxLocalToMtaIdx_.end() && itB != gmxLocalToMtaIdx_.end()) { - unitShift = { 0, 0, 0 }; + pairlistMta_.push_back(itA->second); + pairlistMta_.push_back(itB->second); + const IVec gmxShift = shiftIndexToXYZ(shiftIndex); + cellShiftsMta_.push_back(IVec(-gmxShift[XX], -gmxShift[YY], -gmxShift[ZZ])); } - - // Recompute shift from integer cell shifts for exact consistency - RVec finalShift; - mvmul_ur0(box, unitShift.toRVec(), finalShift); - - pairlistForModel_.push_back(atomA); - pairlistForModel_.push_back(atomB); - shiftVectors_.push_back(finalShift); - cellShifts_.push_back(unitShift); } } @@ -788,36 +539,29 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F { MetatomicTimer totalTimer("calculateForces", mpiComm_); - const int32_t n_atoms = static_cast(options_.params_.mtaIndices_.size()); + const int32_t numTotalMta = static_cast(options_.params_.mtaIndices_.size()); - // Update positions and box for the current step + // Fill local positions (no MPI communication) { MetatomicTimer timer("gatherAtomPositions", mpiComm_); gatherAtomPositions(inputs.x_); } copy_mat(inputs.box_, box_); - // Build the full neighbor list on the main rank from gathered positions. - // Every rank has all positions after gatherAtomPositions()/sumReduce, so rank 0 - // can find ALL pairs via AnalysisNeighborhoodSearch. This avoids the fundamental - // problem of per-rank signal pairs only covering that rank's local view. + // Compute shift vectors from stored cell shifts and current box + std::vector shiftVectors; { - MetatomicTimer timer("buildFullPairlist", mpiComm_); - buildFullPairlist(inputs.box_); + MetatomicTimer timer("prepareNL", mpiComm_); + shiftVectors.reserve(cellShiftsMta_.size()); + for (const auto& cs : cellShiftsMta_) + { + RVec shift; + mvmul_ur0(inputs.box_, cs.toRVec(), shift); + shiftVectors.push_back(shift); + } } - const int32_t totalPairs = static_cast(pairlistForModel_.size() / 2); - std::fprintf(stderr, - "Rank %d Step %ld: %d pairs, %d model atoms, homenr=%d, x.size=%zu\n", - mpiComm_.rank(), - inputs.step_, - totalPairs, - n_atoms, - inputs.homenr_, - inputs.x_.size()); - - // Every rank runs the model independently with the same gathered positions. - // This avoids MPI force/virial reduction entirely. + // Model inference torch::Tensor forceTensor; torch::Tensor virialTensor; double energy = 0.0; @@ -825,7 +569,6 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F { MetatomicTimer modelTimer("model inference", mpiComm_); - // Select appropriate precision for GROMACS data conversion auto gromacs_scalar_type = torch::kFloat32; if (std::is_same_v) { @@ -841,13 +584,10 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F auto torch_cell = torch::from_blob(&box_, { 3, 3 }, cpu_blob_options).to(data_->dtype).to(data_->device); - // Create strain tensor (identity matrix) for virial computation via autodiff auto strain = torch::eye( 3, torch::TensorOptions().dtype(data_->dtype).device(data_->device).requires_grad(true)); - // Apply strain to cell: strained_cell = cell @ strain - auto strained_cell = torch::matmul(torch_cell, strain); - // Apply strain to positions: r' = r @ strain + auto strained_cell = torch::matmul(torch_cell, strain); auto strained_positions = torch::matmul(torch_positions, strain); auto torch_pbc = preparePbcType(options_.params_.pbcType_.get(), data_->device); @@ -857,22 +597,64 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F auto system = torch::make_intrusive( torch_types, strained_positions, strained_cell, torch_pbc); - // Build neighbor list from GROMACS pairlist + // Build neighbor list from GROMACS pairlist (each pair on exactly one rank) for (const auto& request : data_->nl_requests) { + std::vector finalPairlist; + std::vector finalShiftVectors; + std::vector finalCellShifts; + + if (request->full_list()) + { + // Full list: add both (i,j) and (j,i) for each pair + const int64_t nHalf = static_cast(pairlistMta_.size() / 2); + finalPairlist.reserve(4 * nHalf); + finalShiftVectors.reserve(2 * nHalf); + finalCellShifts.reserve(2 * nHalf); + + for (int64_t k = 0; k < nHalf; k++) + { + int32_t atomA = pairlistMta_[2 * k]; + int32_t atomB = pairlistMta_[2 * k + 1]; + + finalPairlist.push_back(atomA); + finalPairlist.push_back(atomB); + finalShiftVectors.push_back(shiftVectors[k]); + finalCellShifts.push_back(cellShiftsMta_[k]); + + finalPairlist.push_back(atomB); + finalPairlist.push_back(atomA); + finalShiftVectors.push_back( + RVec(-shiftVectors[k][XX], -shiftVectors[k][YY], -shiftVectors[k][ZZ])); + finalCellShifts.push_back( + IVec(-cellShiftsMta_[k][XX], -cellShiftsMta_[k][YY], -cellShiftsMta_[k][ZZ])); + } + } + else + { + // Half list: use pairlist as-is + finalPairlist = pairlistMta_; + finalShiftVectors = shiftVectors; + finalCellShifts = cellShiftsMta_; + } + auto neighbors = buildNeighborListFromPairlist( - pairlistForModel_, shiftVectors_, cellShifts_, positions_, data_->device, data_->dtype); + finalPairlist, finalShiftVectors, finalCellShifts, positions_, data_->device, data_->dtype); metatomic_torch::register_autograd_neighbors(system, neighbors, data_->check_consistency); system->add_neighbor_list(request, neighbors); } + // No selected_atoms: each pair is on exactly one rank, so summing + // all per-atom energies (home + halo) gives the correct pair energy. + // GROMACS sums across ranks via global_stat. + data_->evaluations_options->set_selected_atoms(torch::nullopt); + metatensor_torch::TensorMap output_map; try { std::vector systems; systems.push_back(system); - // Forward pass auto ivalue_output = data_->model.forward( { systems, data_->evaluations_options, data_->check_consistency }); auto dict_output = ivalue_output.toGenericDict(); @@ -883,52 +665,98 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F GMX_THROW(APIError("[Metatomic] Model evaluation failed: " + std::string(e.what()))); } - // Extract Energy auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(output_map, 0); auto energy_tensor = energy_block->values(); energy = energy_tensor.sum().item(); - // Reset gradients before backward + // Diagnostic: log pairlist size, per-rank energy and MPI sum + { + double mpiSumEnergy = energy; + if (mpiComm_.isParallel()) + { + mpiComm_.sumReduce(1, &mpiSumEnergy); + } + std::string fname = + "metatomic_debug_rank_" + std::to_string(mpiComm_.rank()) + ".log"; + FILE* fp = std::fopen(fname.c_str(), "a"); + if (fp) + { + std::fprintf(fp, + "pairlistPairs=%zu, numLocalMta=%d, numHomeMta=%d, " + "energy: perRank=%.6f, mpiSum=%.6f\n", + pairlistMta_.size() / 2, + numLocalMta_, + numHomeMta_, + energy, + mpiSumEnergy); + std::fclose(fp); + } + } + torch_positions.mutable_grad() = torch::Tensor(); strain.mutable_grad() = torch::Tensor(); - // Backward pass: Compute forces (-dE/dr) and virial (-dE/dStrain) energy_tensor.backward(-torch::ones_like(energy_tensor)); forceTensor = torch_positions.grad().to(torch::kCPU).to(torch::kFloat64); virialTensor = strain.grad().to(torch::kCPU).to(torch::kFloat64); } - // Accumulate forces into the GROMACS force output. - // Each rank applies forces only to its home atoms (no double-counting). + // Force distribution via all-reduce auto forceAccessor = forceTensor.accessor(); - for (int32_t i = 0; i < n_atoms; ++i) + + if (mpiComm_.isParallel()) { - if (inputToLocalIndex_[i] != -1) + // Scatter local forces into global buffer, all-reduce, then apply + globalForceBuffer_.assign(numTotalMta, RVec({ 0.0, 0.0, 0.0 })); + + for (int32_t i = 0; i < numLocalMta_; i++) { - outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][0] += forceAccessor[i][0]; - outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][1] += forceAccessor[i][1]; - outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][2] += forceAccessor[i][2]; + int32_t globalMtaIdx = mtaToGlobalMta_[i]; + globalForceBuffer_[globalMtaIdx][0] = static_cast(forceAccessor[i][0]); + globalForceBuffer_[globalMtaIdx][1] = static_cast(forceAccessor[i][1]); + globalForceBuffer_[globalMtaIdx][2] = static_cast(forceAccessor[i][2]); } - } - // Energy and virial: only main rank contributes since GROMACS sums across ranks. - if (mpiComm_.isMainRank() || !mpiComm_.isParallel()) + mpiComm_.sumReduce(3 * numTotalMta, globalForceBuffer_.data()->as_vec()); + + // Apply forces only to home MTA atoms from the reduced buffer + for (int32_t i = 0; i < numHomeMta_; i++) + { + int32_t gmxIdx = mtaToGmxLocal_[i]; + int32_t globalMtaIdx = mtaToGlobalMta_[i]; + outputs->forceWithVirial_.force_[gmxIdx][0] += globalForceBuffer_[globalMtaIdx][0]; + outputs->forceWithVirial_.force_[gmxIdx][1] += globalForceBuffer_[globalMtaIdx][1]; + outputs->forceWithVirial_.force_[gmxIdx][2] += globalForceBuffer_[globalMtaIdx][2]; + } + } + else { - outputs->enerd_.term[InteractionFunction::MetatomicPotentialEnergy] = static_cast(energy); + // Serial: apply forces directly + for (int32_t i = 0; i < numLocalMta_; i++) + { + int32_t gmxIdx = mtaToGmxLocal_[i]; + outputs->forceWithVirial_.force_[gmxIdx][0] += static_cast(forceAccessor[i][0]); + outputs->forceWithVirial_.force_[gmxIdx][1] += static_cast(forceAccessor[i][1]); + outputs->forceWithVirial_.force_[gmxIdx][2] += static_cast(forceAccessor[i][2]); + } + } + + // Energy: every rank contributes its portion, GROMACS sums globally via global_stat + outputs->enerd_.term[InteractionFunction::MetatomicPotentialEnergy] = static_cast(energy); - matrix virialMatrix; - auto virialAccessor = virialTensor.accessor(); - for (int32_t i = 0; i < 3; ++i) + // Virial: every rank contributes its portion, GROMACS sums globally + matrix virialMatrix; + auto virialAccessor = virialTensor.accessor(); + for (int32_t i = 0; i < 3; ++i) + { + for (int32_t j = 0; j < 3; ++j) { - for (int32_t j = 0; j < 3; ++j) - { - virialMatrix[i][j] = virialAccessor[i][j]; - } + virialMatrix[i][j] = virialAccessor[i][j]; } - outputs->forceWithVirial_.addVirialContribution(virialMatrix); } + outputs->forceWithVirial_.addVirialContribution(virialMatrix); } } // namespace gmx diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h index 0dd9b430cc..8d8b98ca74 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h @@ -41,6 +41,8 @@ #pragma once +#include + #include "gromacs/mdtypes/iforceprovider.h" #include "metatomic_options.h" @@ -56,15 +58,14 @@ struct MDModulesPairlistConstructedSignal; class MDLogger; class MpiComm; -/*! For compatibility with pairlist data structure in MDModulesPairlistConstructedSignal. - * Contains pairs like ((atom1, atom2), shiftIndex). - */ -using PairlistEntry = std::pair, int32_t>; - /*! \brief \internal * MetatomicForceProvider class * * Implements the IForceProvider interface for the Metatomic force provider. + * Each rank evaluates the model on its local (home + halo) MTA atoms. + * The neighbor list comes from the GROMACS pairlist (excludedPairlist), + * which assigns each pair to exactly one rank — no double counting. + * Forces are combined via MPI all-reduce on a global force buffer. */ class MetatomicForceProvider final : public IForceProvider { @@ -83,60 +84,49 @@ class MetatomicForceProvider final : public IForceProvider //! Gather atom numbers and indices. Triggered on AtomsRedistributed signal. void gatherAtomNumbersIndices(const MDModulesAtomsRedistributedSignal& signal); - //! Set pairlist from notification and filter to MTA atom pairs. + //! Store GROMACS pairlist and convert to MTA model indices. void setPairlist(const MDModulesPairlistConstructedSignal& signal); - void augmentGhostPairs(const ArrayRef x, const matrix box); private: - //! Gather atom positions for MTA input. - void gatherAtomPositions(ArrayRef globalPositions); - - //! Prepare pairlist input for model - void preparePairlistInput(); - - //! Build full neighbor list on main rank from gathered positions - void buildFullPairlist(const matrix box); + //! Gather atom positions for MTA input (local only, no MPI). + void gatherAtomPositions(ArrayRef positions); const MetatomicOptions& options_; const MDLogger& logger_; const MpiComm& mpiComm_; - //! vector storing all MTA atom positions + //! vector storing local MTA atom positions (home + halo) std::vector positions_; - //! vector storing all atomic numbers + //! vector storing local MTA atomic numbers (home + halo) std::vector atomNumbers_; - //! lookup table to map model input indices [0...numInput) to local atom indices - std::vector inputToLocalIndex_; - //! reverse map: local atom index -> model input index (sized to home+halo) - std::vector localToModelIndex_; - //! lookup table to map model input indices to global atom indices - std::vector inputToGlobalIndex_; - - //! Number of home atoms on this rank (from last DD redistribution) - int32_t numLocalAtoms_ = 0; + //! Number of home MTA atoms on this rank + int32_t numHomeMta_ = 0; + //! Number of home + halo MTA atoms on this rank + int32_t numLocalMta_ = 0; - //! Full pairlist from MDModules notification - std::vector fullPairlist_; + //! Maps local model index [0, numLocalMta_) -> GROMACS local buffer index + std::vector mtaToGmxLocal_; + //! Maps local model index [0, numLocalMta_) -> global MTA index [0, N_total_mta) + std::vector mtaToGlobalMta_; + //! Maps ANY GROMACS local buffer index (including periodic ghosts) -> MTA model index + //! Used by setPairlist to resolve pairlist entries that reference ghost images. + std::unordered_map gmxLocalToMtaIdx_; - //! Interacting pairs of MTA atoms within cutoff, for model input - std::vector pairlistForModel_; + //! Global force buffer sized [N_total_mta] for MPI all-reduce + std::vector globalForceBuffer_; - //! Shift vectors for each atom pair in pairlistForModel_ - std::vector shiftVectors_; - - //! Cell shifts - std::vector cellShifts_; + //! Pairlist from GROMACS (MTA model indices), flat [A0,B0,A1,B1,...] + std::vector pairlistMta_; + //! Integer cell shifts for each pair in pairlistMta_ + std::vector cellShiftsMta_; //! local copy of simulation box matrix box_; //! Data required for metatomic calculations std::unique_ptr data_; - - //! flag to check if pairlist data should be prepared - bool doPairlist_ = false; }; } // namespace gmx diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider_stub.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider_stub.cpp index 302dc2559d..22fa70e6cc 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider_stub.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider_stub.cpp @@ -80,11 +80,17 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& /*inputs* { } -void MetatomicForceProvider::updateLocalAtoms() {} -void MetatomicForceProvider::gatherAtomPositions(ArrayRef globalPositions) { - (void)globalPositions; +void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedistributedSignal& /*signal*/) +{ +} + +void MetatomicForceProvider::setPairlist(const MDModulesPairlistConstructedSignal& /*signal*/) +{ +} + +void MetatomicForceProvider::gatherAtomPositions(ArrayRef /*positions*/) +{ } -void MetatomicForceProvider::gatherAtomNumbersIndices() {} CLANG_DIAGNOSTIC_RESET diff --git a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp index 6f3270f0ed..41c3b6dde2 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp @@ -239,7 +239,16 @@ class MetatomicMDModule final : public IMDModule GMX_THROW(InconsistentInputError("Metatomic model cutoff is 0.0 or invalid.")); } } - // Register the requirement with GROMACS + // TODO: For multi-layer GNN models (MACE, NequIP, etc.) the + // interaction_range should be n_layers * cutoff so that DD halos + // are deep enough for message-passing. Many models currently + // report interaction_range == cutoff, which makes DD give wrong + // energies because halo atoms lack complete neighborhoods. + // Unlike LAMMPS (which adds a ~2 Å neighbor skin on top of the + // cutoff), GROMACS caps the DD range at rlist, so we cannot add + // extra range here. The model must report the correct + // interaction_range, or the user must increase rcoulomb/rvdw in + // the .mdp so that rlist >= interaction_range. ranges->addRange(max_cutoff); }; // Register the callback From 95ee966aec8cf00183e8a585b11b7add0a5065ac Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Sun, 15 Feb 2026 08:19:53 +0100 Subject: [PATCH 13/19] chore(doc): add some more details --- .../metatomic/metatomic_forceprovider.cpp | 59 ++++++++++++++---- .../metatomic/metatomic_forceprovider.h | 62 +++++++++++++++---- 2 files changed, 95 insertions(+), 26 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 9887953394..aaaea357ee 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -35,12 +35,22 @@ * \brief * Implements the Metatomic Force Provider class with per-rank model evaluation. * - * Each rank evaluates the model on its local (home + halo) MTA atoms, - * using the GROMACS pairlist (excludedPairlist) as the neighbor list. - * Each pair is assigned to exactly one rank by GROMACS, so summing all - * per-atom energies (home + halo) gives the correct pair contribution - * without double counting. Forces from backward() are combined via MPI - * all-reduce on a global force buffer. + * Uses the GROMACS plain pairlist (excludedPairlist) as the neighbor list + * source. MTA-MTA nonbonded pairs are excluded from the classical force + * calculation via intermolecularExclusionGroup, causing them to appear in + * excludedPairlist_ instead of pairlist_. Each pair is assigned to exactly + * one rank by the GROMACS nbnxm pairlist builder. + * + * Key design points: + * - Neighbor list: built from excludedPairlist_, not AnalysisNeighborhood. + * Cell shifts are negated (GROMACS shifts first atom, metatensor shifts + * second atom). Models requesting full_list get both (i,j) and (j,i). + * - Energy: selected_atoms = nullopt (sum all per-atom energies, home + + * halo). No double counting because each pair is on one rank. + * - Forces: all-reduce on a global buffer because ForceWithVirial is not + * communicated by dd_move_f. Only home atom forces are applied. + * - Ghost deduplication: periodic ghost images share the same model index + * but all GROMACS local indices are mapped via gmxLocalToMtaIdx_. * * \author Metatensor developers * \ingroup module_applied_forces @@ -119,7 +129,21 @@ static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) return torch::tensor({ true, true, true }, options); } -/*! \brief Constructs a Metatensor TensorBlock representing the neighbor list. */ +/*! \brief Constructs a Metatensor TensorBlock representing the neighbor list. + * + * Builds the metatensor neighbor list from flat pairlist arrays. The + * displacement vector for pair k is: + * r_ij = positions[j] - positions[i] + shiftVectors[k] + * which matches the metatensor convention when cellShifts follow the + * metatensor sign convention (shift applied to second atom j). + * + * \param[in] pairlist Flat [i0,j0, i1,j1, ...] model-index pairs + * \param[in] shiftVectors Real-space shift vectors (box * cellShifts) + * \param[in] cellShifts Integer cell shifts (metatensor convention) + * \param[in] positions Atom positions indexed by model index + * \param[in] device Torch device for output tensors + * \param[in] dtype Torch scalar type for output tensors + */ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef pairlist, ArrayRef shiftVectors, ArrayRef cellShifts, @@ -597,7 +621,10 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F auto system = torch::make_intrusive( torch_types, strained_positions, strained_cell, torch_pbc); - // Build neighbor list from GROMACS pairlist (each pair on exactly one rank) + // Build neighbor list from GROMACS pairlist (each pair on exactly one rank). + // The stored pairlistMta_ is a half list. If the model requests + // full_list, we double it by adding the reverse pair (j,i) with + // negated cell shifts for each (i,j). for (const auto& request : data_->nl_requests) { std::vector finalPairlist; @@ -606,7 +633,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F if (request->full_list()) { - // Full list: add both (i,j) and (j,i) for each pair + // Full list needed (e.g. message-passing GNNs like MACE, NequIP) const int64_t nHalf = static_cast(pairlistMta_.size() / 2); finalPairlist.reserve(4 * nHalf); finalShiftVectors.reserve(2 * nHalf); @@ -703,12 +730,16 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F virialTensor = strain.grad().to(torch::kCPU).to(torch::kFloat64); } - // Force distribution via all-reduce + // Force distribution via all-reduce. + // backward() produces forces on ALL local atoms (home + halo). Since + // ForceWithVirial forces are NOT communicated by dd_move_f (which only + // handles ForceWithShiftForces), we must all-reduce ourselves. Each rank + // scatters its local forces into a global buffer indexed by global MTA + // index. After all-reduce, each rank reads back only its home atoms. auto forceAccessor = forceTensor.accessor(); if (mpiComm_.isParallel()) { - // Scatter local forces into global buffer, all-reduce, then apply globalForceBuffer_.assign(numTotalMta, RVec({ 0.0, 0.0, 0.0 })); for (int32_t i = 0; i < numLocalMta_; i++) @@ -743,10 +774,12 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F } } - // Energy: every rank contributes its portion, GROMACS sums globally via global_stat + // Energy: each rank's energy is the sum of per-atom energies for all local + // atoms (home + halo) from the pairs assigned to this rank. GROMACS + // global_stat sums across ranks to get the system total. outputs->enerd_.term[InteractionFunction::MetatomicPotentialEnergy] = static_cast(energy); - // Virial: every rank contributes its portion, GROMACS sums globally + // Virial: same decomposition as energy — per-rank portion, summed by GROMACS. matrix virialMatrix; auto virialAccessor = virialTensor.accessor(); for (int32_t i = 0; i < 3; ++i) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h index 8d8b98ca74..5b4f4e5d6e 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h @@ -62,10 +62,39 @@ class MpiComm; * MetatomicForceProvider class * * Implements the IForceProvider interface for the Metatomic force provider. + * + * ## Domain decomposition strategy + * * Each rank evaluates the model on its local (home + halo) MTA atoms. - * The neighbor list comes from the GROMACS pairlist (excludedPairlist), + * The neighbor list comes from the GROMACS plain pairlist (excludedPairlist), * which assigns each pair to exactly one rank — no double counting. - * Forces are combined via MPI all-reduce on a global force buffer. + * + * **Atoms**: In DD, GROMACS partitions atoms into "home" atoms (owned by this + * rank) and "halo" atoms (copies from neighboring ranks needed for short-range + * interactions). The same global atom may appear as multiple periodic ghost + * images in the halo. We deduplicate these so each atom has one model index, + * but record ALL GROMACS local buffer indices in gmxLocalToMtaIdx_ so that + * the pairlist (which may reference any image) can be resolved. + * + * **Pairs**: MTA-MTA pairs are excluded from classical nonbonded interactions + * via intermolecularExclusionGroup (set by addEmbeddedNBExclusions). The + * GROMACS pairlist builder reports these excluded pairs in excludedPairlist_, + * filtered to the plainPairlistRange (= model cutoff). Each pair appears on + * exactly one rank. + * + * **Energy**: With per_atom=true, the model decomposes energy per atom. + * selected_atoms is always nullopt: we sum ALL per-atom energies (home + halo) + * on each rank. Since each pair is on one rank, the per-pair energy + * (V_ij/2 on atom i + V_ij/2 on atom j) sums to V_ij on that rank. + * GROMACS global_stat sums across ranks for the total. + * + * **Forces**: backward() produces forces on all local atoms (home + halo). + * Since ForceWithVirial is not communicated by dd_move_f, we scatter forces + * into a global buffer and MPI all-reduce, then apply only to home atoms. + * + * **Shift convention**: GROMACS shifts atom I (first): d = x[I]+shift - x[J]. + * Metatensor convention: r_ij = x[J] + cell_shift*box - x[I]. Therefore + * metatensor cell shifts = negated GROMACS cell shifts. */ class MetatomicForceProvider final : public IForceProvider { @@ -95,31 +124,38 @@ class MetatomicForceProvider final : public IForceProvider const MDLogger& logger_; const MpiComm& mpiComm_; - //! vector storing local MTA atom positions (home + halo) + //! Positions of local MTA atoms, indexed by model index [0, numLocalMta_). + //! Home atoms occupy [0, numHomeMta_), halo atoms [numHomeMta_, numLocalMta_). std::vector positions_; - //! vector storing local MTA atomic numbers (home + halo) + //! Atomic numbers of local MTA atoms, same indexing as positions_. std::vector atomNumbers_; - //! Number of home MTA atoms on this rank + //! Number of home (owned by this rank) MTA atoms. int32_t numHomeMta_ = 0; - //! Number of home + halo MTA atoms on this rank + //! Number of unique local MTA atoms (home + halo, after deduplication). int32_t numLocalMta_ = 0; - //! Maps local model index [0, numLocalMta_) -> GROMACS local buffer index + //! Maps model index [0, numLocalMta_) -> GROMACS local buffer index (first occurrence). + //! Used for position gathering and force scattering. std::vector mtaToGmxLocal_; - //! Maps local model index [0, numLocalMta_) -> global MTA index [0, N_total_mta) + //! Maps model index [0, numLocalMta_) -> global MTA index [0, N_total_mta). + //! Used for scatter/gather in the global force buffer during MPI all-reduce. std::vector mtaToGlobalMta_; - //! Maps ANY GROMACS local buffer index (including periodic ghosts) -> MTA model index - //! Used by setPairlist to resolve pairlist entries that reference ghost images. + //! Maps ANY GROMACS local buffer index -> MTA model index. + //! Includes ALL periodic ghost images of each atom (not just the first). + //! Needed because excludedPairlist_ entries can reference any image. std::unordered_map gmxLocalToMtaIdx_; - //! Global force buffer sized [N_total_mta] for MPI all-reduce + //! Global force buffer [N_total_mta] for MPI all-reduce of forces. + //! Each rank scatters its local forces here, all-reduce sums them, + //! then home forces are read back. std::vector globalForceBuffer_; - //! Pairlist from GROMACS (MTA model indices), flat [A0,B0,A1,B1,...] + //! Pairlist in MTA model indices, flat [i0,j0, i1,j1, ...]. + //! Built from GROMACS excludedPairlist_ with negated cell shifts. std::vector pairlistMta_; - //! Integer cell shifts for each pair in pairlistMta_ + //! Cell shifts for each pair (metatensor convention: shift applied to second atom). std::vector cellShiftsMta_; //! local copy of simulation box From 497ee846d4cea8db429f88ce0223104c524afbaa Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Sun, 15 Feb 2026 14:59:42 +0100 Subject: [PATCH 14/19] chore(check): disable check consistency --- .../tests/refdata/MetatomicOptionsTest_DefaultParameters.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml index 21555a3181..6ba0b09913 100644 --- a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml +++ b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml @@ -6,5 +6,5 @@ - true + false From 24e97e914d90c80f8d6cbd38e73305b383c081ba Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Sun, 15 Feb 2026 15:21:20 +0100 Subject: [PATCH 15/19] feat(mtagro): as fast as it gets --- .../metatomic/metatomic_forceprovider.cpp | 244 ++++++++---------- .../metatomic/metatomic_forceprovider.h | 6 + .../metatomic/metatomic_timer.h | 81 +++--- 3 files changed, 160 insertions(+), 171 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index aaaea357ee..1318e8072a 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -129,87 +129,6 @@ static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) return torch::tensor({ true, true, true }, options); } -/*! \brief Constructs a Metatensor TensorBlock representing the neighbor list. - * - * Builds the metatensor neighbor list from flat pairlist arrays. The - * displacement vector for pair k is: - * r_ij = positions[j] - positions[i] + shiftVectors[k] - * which matches the metatensor convention when cellShifts follow the - * metatensor sign convention (shift applied to second atom j). - * - * \param[in] pairlist Flat [i0,j0, i1,j1, ...] model-index pairs - * \param[in] shiftVectors Real-space shift vectors (box * cellShifts) - * \param[in] cellShifts Integer cell shifts (metatensor convention) - * \param[in] positions Atom positions indexed by model index - * \param[in] device Torch device for output tensors - * \param[in] dtype Torch scalar type for output tensors - */ -static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef pairlist, - ArrayRef shiftVectors, - ArrayRef cellShifts, - ArrayRef positions, - torch::Device device, - torch::ScalarType dtype) -{ - const int64_t n_pairs = static_cast(pairlist.size() / 2); - - auto cpu_int_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); - auto cpu_float_options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU); - - auto pair_samples_values = torch::zeros({ n_pairs, 5 }, cpu_int_options); - auto pair_samples_ptr = pair_samples_values.accessor(); - - auto vectors_cpu = torch::zeros({ n_pairs, 3, 1 }, cpu_float_options); - auto vectors_accessor = vectors_cpu.accessor(); - - for (int64_t i = 0; i < n_pairs; i++) - { - const int32_t atom_i = pairlist[2 * i]; - const int32_t atom_j = pairlist[2 * i + 1]; - - pair_samples_ptr[i][0] = atom_i; - pair_samples_ptr[i][1] = atom_j; - pair_samples_ptr[i][2] = cellShifts[i][0]; - pair_samples_ptr[i][3] = cellShifts[i][1]; - pair_samples_ptr[i][4] = cellShifts[i][2]; - - const double r_ij_x = - static_cast(positions[atom_j][0] - positions[atom_i][0] + shiftVectors[i][0]); - const double r_ij_y = - static_cast(positions[atom_j][1] - positions[atom_i][1] + shiftVectors[i][1]); - const double r_ij_z = - static_cast(positions[atom_j][2] - positions[atom_i][2] + shiftVectors[i][2]); - - vectors_accessor[i][0][0] = r_ij_x; - vectors_accessor[i][1][0] = r_ij_y; - vectors_accessor[i][2][0] = r_ij_z; - } - - auto final_samples_values = pair_samples_values.to(device); - auto final_vectors = vectors_cpu.to(dtype).to(device); - - auto neighbor_samples = torch::make_intrusive( - std::vector{ - "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c" }, - final_samples_values); - - auto neighbor_component = torch::make_intrusive( - std::vector{ "xyz" }, - torch::tensor({ 0, 1, 2 }, torch::TensorOptions().dtype(torch::kInt32).device(device)) - .reshape({ 3, 1 })); - - auto neighbor_properties = torch::make_intrusive( - std::vector{ "distance" }, - torch::zeros({ 1, 1 }, torch::TensorOptions().dtype(torch::kInt32).device(device))); - - return torch::make_intrusive( - final_vectors, - neighbor_samples, - std::vector{ neighbor_component }, - neighbor_properties); -} - - /*! \brief Internal data structure for Metatomic runtime states. */ struct MetatomicData { @@ -220,6 +139,14 @@ struct MetatomicData torch::ScalarType dtype = torch::kFloat32; bool check_consistency = false; torch::Device device = torch::kCPU; + + //! Cached NL Labels that are identical every step (created once in constructor). + metatensor_torch::Labels cachedNLComponent; + metatensor_torch::Labels cachedNLProperties; + //! Cached sample column names (avoids heap-allocating string vector every step). + std::vector nlSampleNames = { + "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c" + }; }; MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, @@ -267,6 +194,16 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, metatomic_torch::pick_device(data_->capabilities->supported_devices, desiredDevice); data_->device = torch::Device(deviceType); + // Cache NL Labels that are constant across steps (avoids per-step + // string vector + tensor allocation for component and properties). + auto devIntOpts = torch::TensorOptions().dtype(torch::kInt32).device(data_->device); + data_->cachedNLComponent = torch::make_intrusive( + std::vector{ "xyz" }, + torch::tensor({ 0, 1, 2 }, devIntOpts).reshape({ 3, 1 })); + data_->cachedNLProperties = torch::make_intrusive( + std::vector{ "distance" }, + torch::zeros({ 1, 1 }, devIntOpts)); + GMX_LOG(logger_.info) .asParagraph() .appendTextFormatted("Metatomic using device: %s", data_->device.str().c_str()); @@ -572,19 +509,6 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F } copy_mat(inputs.box_, box_); - // Compute shift vectors from stored cell shifts and current box - std::vector shiftVectors; - { - MetatomicTimer timer("prepareNL", mpiComm_); - shiftVectors.reserve(cellShiftsMta_.size()); - for (const auto& cs : cellShiftsMta_) - { - RVec shift; - mvmul_ur0(inputs.box_, cs.toRVec(), shift); - shiftVectors.push_back(shift); - } - } - // Model inference torch::Tensor forceTensor; torch::Tensor virialTensor; @@ -593,6 +517,8 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F { MetatomicTimer modelTimer("model inference", mpiComm_); + MetatomicTimer tensorPrepTimer("tensorPrep", mpiComm_); + auto gromacs_scalar_type = torch::kFloat32; if (std::is_same_v) { @@ -621,61 +547,101 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F auto system = torch::make_intrusive( torch_types, strained_positions, strained_cell, torch_pbc); - // Build neighbor list from GROMACS pairlist (each pair on exactly one rank). - // The stored pairlistMta_ is a half list. If the model requests - // full_list, we double it by adding the reverse pair (j,i) with - // negated cell shifts for each (i,j). + tensorPrepTimer.stop(); + + // Build NL directly into raw buffers, then wrap with from_blob. + // Shift vectors are computed inline (no separate prepareNL pass). + // Component and properties Labels are cached (identical every step). + MetatomicTimer buildNLTimer("buildNL", mpiComm_); + for (const auto& request : data_->nl_requests) { - std::vector finalPairlist; - std::vector finalShiftVectors; - std::vector finalCellShifts; + const int64_t nHalf = static_cast(pairlistMta_.size() / 2); + const bool full = request->full_list(); + const int64_t nPairs = full ? 2 * nHalf : nHalf; - if (request->full_list()) - { - // Full list needed (e.g. message-passing GNNs like MACE, NequIP) - const int64_t nHalf = static_cast(pairlistMta_.size() / 2); - finalPairlist.reserve(4 * nHalf); - finalShiftVectors.reserve(2 * nHalf); - finalCellShifts.reserve(2 * nHalf); + nlSamplesBuffer_.resize(nPairs * 5); + nlVectorsBuffer_.resize(nPairs * 3); - for (int64_t k = 0; k < nHalf; k++) + for (int64_t k = 0; k < nHalf; k++) + { + const int32_t ai = pairlistMta_[2 * k]; + const int32_t aj = pairlistMta_[2 * k + 1]; + + // Compute shift vector from cell shift and current box + RVec shift; + mvmul_ur0(inputs.box_, cellShiftsMta_[k].toRVec(), shift); + + // Displacement: r_ij = pos[j] - pos[i] + shift (metatensor convention) + const double dx = static_cast(positions_[aj][0] - positions_[ai][0] + shift[0]); + const double dy = static_cast(positions_[aj][1] - positions_[ai][1] + shift[1]); + const double dz = static_cast(positions_[aj][2] - positions_[ai][2] + shift[2]); + + const int64_t fwd = full ? 2 * k : k; + nlSamplesBuffer_[5 * fwd + 0] = ai; + nlSamplesBuffer_[5 * fwd + 1] = aj; + nlSamplesBuffer_[5 * fwd + 2] = cellShiftsMta_[k][0]; + nlSamplesBuffer_[5 * fwd + 3] = cellShiftsMta_[k][1]; + nlSamplesBuffer_[5 * fwd + 4] = cellShiftsMta_[k][2]; + nlVectorsBuffer_[3 * fwd + 0] = dx; + nlVectorsBuffer_[3 * fwd + 1] = dy; + nlVectorsBuffer_[3 * fwd + 2] = dz; + + if (full) { - int32_t atomA = pairlistMta_[2 * k]; - int32_t atomB = pairlistMta_[2 * k + 1]; - - finalPairlist.push_back(atomA); - finalPairlist.push_back(atomB); - finalShiftVectors.push_back(shiftVectors[k]); - finalCellShifts.push_back(cellShiftsMta_[k]); - - finalPairlist.push_back(atomB); - finalPairlist.push_back(atomA); - finalShiftVectors.push_back( - RVec(-shiftVectors[k][XX], -shiftVectors[k][YY], -shiftVectors[k][ZZ])); - finalCellShifts.push_back( - IVec(-cellShiftsMta_[k][XX], -cellShiftsMta_[k][YY], -cellShiftsMta_[k][ZZ])); + // Reverse pair (j,i) with negated shifts and displacement + const int64_t rev = 2 * k + 1; + nlSamplesBuffer_[5 * rev + 0] = aj; + nlSamplesBuffer_[5 * rev + 1] = ai; + nlSamplesBuffer_[5 * rev + 2] = -cellShiftsMta_[k][0]; + nlSamplesBuffer_[5 * rev + 3] = -cellShiftsMta_[k][1]; + nlSamplesBuffer_[5 * rev + 4] = -cellShiftsMta_[k][2]; + nlVectorsBuffer_[3 * rev + 0] = -dx; + nlVectorsBuffer_[3 * rev + 1] = -dy; + nlVectorsBuffer_[3 * rev + 2] = -dz; } } - else - { - // Half list: use pairlist as-is - finalPairlist = pairlistMta_; - finalShiftVectors = shiftVectors; - finalCellShifts = cellShiftsMta_; - } - auto neighbors = buildNeighborListFromPairlist( - finalPairlist, finalShiftVectors, finalCellShifts, positions_, data_->device, data_->dtype); + MetatomicTimer fromBlobTimer("fromBlob", mpiComm_); + auto samples_tensor = torch::from_blob( + nlSamplesBuffer_.data(), { nPairs, 5 }, + torch::TensorOptions().dtype(torch::kInt32)).to(data_->device); + auto vectors_tensor = torch::from_blob( + nlVectorsBuffer_.data(), { nPairs, 3, 1 }, + torch::TensorOptions().dtype(torch::kFloat64)).to(data_->dtype).to(data_->device); + fromBlobTimer.stop(); + + MetatomicTimer labelsTimer("makeSampleLabels", mpiComm_); + auto neighbor_samples = torch::make_intrusive( + data_->nlSampleNames, samples_tensor); + labelsTimer.stop(); + + MetatomicTimer blockTimer("makeTensorBlock", mpiComm_); + auto neighbors = torch::make_intrusive( + vectors_tensor, + neighbor_samples, + std::vector{ data_->cachedNLComponent }, + data_->cachedNLProperties); + blockTimer.stop(); + + MetatomicTimer autogradTimer("registerAutograd", mpiComm_); metatomic_torch::register_autograd_neighbors(system, neighbors, data_->check_consistency); + autogradTimer.stop(); + + MetatomicTimer addNLTimer("addNeighborList", mpiComm_); system->add_neighbor_list(request, neighbors); + addNLTimer.stop(); } + buildNLTimer.stop(); + // No selected_atoms: each pair is on exactly one rank, so summing // all per-atom energies (home + halo) gives the correct pair energy. // GROMACS sums across ranks via global_stat. data_->evaluations_options->set_selected_atoms(torch::nullopt); + MetatomicTimer forwardTimer("forward", mpiComm_); + metatensor_torch::TensorMap output_map; try { @@ -692,6 +658,8 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F GMX_THROW(APIError("[Metatomic] Model evaluation failed: " + std::string(e.what()))); } + forwardTimer.stop(); + auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(output_map, 0); auto energy_tensor = energy_block->values(); @@ -721,13 +689,21 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F } } + MetatomicTimer backwardTimer("backward", mpiComm_); + torch_positions.mutable_grad() = torch::Tensor(); strain.mutable_grad() = torch::Tensor(); energy_tensor.backward(-torch::ones_like(energy_tensor)); + backwardTimer.stop(); + + MetatomicTimer toCPUTimer("toCPU", mpiComm_); + forceTensor = torch_positions.grad().to(torch::kCPU).to(torch::kFloat64); virialTensor = strain.grad().to(torch::kCPU).to(torch::kFloat64); + + toCPUTimer.stop(); } // Force distribution via all-reduce. @@ -736,6 +712,8 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F // handles ForceWithShiftForces), we must all-reduce ourselves. Each rank // scatters its local forces into a global buffer indexed by global MTA // index. After all-reduce, each rank reads back only its home atoms. + MetatomicTimer forceScatterTimer("forceScatter", mpiComm_); + auto forceAccessor = forceTensor.accessor(); if (mpiComm_.isParallel()) @@ -774,6 +752,8 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F } } + forceScatterTimer.stop(); + // Energy: each rank's energy is the sum of per-atom energies for all local // atoms (home + halo) from the pairs assigned to this rank. GROMACS // global_stat sums across ranks to get the system total. diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h index 5b4f4e5d6e..0448c0fed1 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h @@ -158,6 +158,12 @@ class MetatomicForceProvider final : public IForceProvider //! Cell shifts for each pair (metatensor convention: shift applied to second atom). std::vector cellShiftsMta_; + //! Pre-allocated raw buffers for NL construction. + //! Avoids per-step torch::zeros allocations and accessor overhead. + //! Filled directly, then wrapped with torch::from_blob (zero-cost). + std::vector nlSamplesBuffer_; //!< flat [n_pairs * 5]: i, j, cs_a, cs_b, cs_c + std::vector nlVectorsBuffer_; //!< flat [n_pairs * 3]: dx, dy, dz + //! local copy of simulation box matrix box_; diff --git a/src/gromacs/applied_forces/metatomic/metatomic_timer.h b/src/gromacs/applied_forces/metatomic/metatomic_timer.h index 370efcfeb6..867efb3c95 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_timer.h +++ b/src/gromacs/applied_forces/metatomic/metatomic_timer.h @@ -35,8 +35,8 @@ * \brief * Scoped timer for Metatomic force provider profiling. * - * RAII timer that prints nested timing information to stderr with MPI rank. - * Enable with MetatomicTimer::enable(true) before use. + * RAII timer that writes nested timing information to per-rank files + * (metatomic_timer_rank_N.log). Enable with GMX_METATOMIC_TIMER=1. * * \author Metatensor developers * \ingroup module_applied_forces @@ -46,7 +46,7 @@ #include #include -#include +#include #include #include @@ -60,14 +60,15 @@ static std::mutex METATOMIC_TIMER_MUTEX = {}; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) static int64_t METATOMIC_TIMER_DEPTH = -1; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -static uint64_t METATOMIC_TIMER_COUNTER = 0; -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) static bool METATOMIC_TIMER_ENABLED = false; /*! \internal \brief RAII scoped timer for Metatomic profiling. * - * Prints hierarchical timing info to stderr. Timers nest automatically. - * Thread-safe via a global mutex. + * Writes hierarchical timing info to a per-rank file + * (metatomic_timer_rank_N.log). Timers nest automatically via a + * global depth counter. Thread-safe via a global mutex. + * + * Enable with GMX_METATOMIC_TIMER=1 environment variable. */ class MetatomicTimer { @@ -87,55 +88,57 @@ class MetatomicTimer if (METATOMIC_TIMER_ENABLED) { METATOMIC_TIMER_DEPTH += 1; - METATOMIC_TIMER_COUNTER += 1; - - this->enabled_ = true; - this->starting_counter_ = METATOMIC_TIMER_COUNTER; - this->start_ = std::chrono::high_resolution_clock::now(); - auto indent = std::string(METATOMIC_TIMER_DEPTH * 3, ' '); - - if (METATOMIC_TIMER_DEPTH == 0) - { - std::cerr << "\n"; - } - std::cerr << "\n" << indent << this->name_ << " ..."; + this->enabled_ = true; + this->start_ = std::chrono::high_resolution_clock::now(); } } + //! Stop the timer early (before scope exit). Safe to call multiple times. + void stop() + { + auto guard_ = std::lock_guard(METATOMIC_TIMER_MUTEX); + recordAndDisable_(); + } + ~MetatomicTimer() { auto guard_ = std::lock_guard(METATOMIC_TIMER_MUTEX); + recordAndDisable_(); + } + + // Non-copyable, non-movable + MetatomicTimer(const MetatomicTimer&) = delete; + MetatomicTimer& operator=(const MetatomicTimer&) = delete; + MetatomicTimer(MetatomicTimer&&) = delete; + MetatomicTimer& operator=(MetatomicTimer&&) = delete; +private: + //! Record elapsed time and mark as done. Must be called under lock. + void recordAndDisable_() + { if (METATOMIC_TIMER_ENABLED && this->enabled_) { - auto stop = std::chrono::high_resolution_clock::now(); - auto elapsed = - std::chrono::duration_cast(stop - start_).count(); + auto stop = std::chrono::high_resolution_clock::now(); + auto elapsed = std::chrono::duration_cast(stop - start_).count(); + auto indent = std::string(METATOMIC_TIMER_DEPTH * 2, ' '); - if (METATOMIC_TIMER_COUNTER != starting_counter_) + std::string fname = "metatomic_timer_rank_" + std::to_string(mpiComm_.rank()) + ".log"; + FILE* fp = std::fopen(fname.c_str(), "a"); + if (fp) { - auto indent = std::string(METATOMIC_TIMER_DEPTH * 3, ' '); - std::cerr << "\n" << indent << this->name_; + std::fprintf(fp, "%s%s: %.3f ms\n", indent.c_str(), name_.c_str(), elapsed / 1e3); + std::fclose(fp); } - std::cerr << " took " << elapsed / 1e6 << "ms (rank " << mpiComm_.rank() << ")" - << std::flush; + this->enabled_ = false; METATOMIC_TIMER_DEPTH -= 1; } } - // Non-copyable, non-movable - MetatomicTimer(const MetatomicTimer&) = delete; - MetatomicTimer& operator=(const MetatomicTimer&) = delete; - MetatomicTimer(MetatomicTimer&&) = delete; - MetatomicTimer& operator=(MetatomicTimer&&) = delete; - -private: - bool enabled_; - std::string name_; - const MpiComm& mpiComm_; - uint64_t starting_counter_ = 0; - std::chrono::high_resolution_clock::time_point start_; + bool enabled_; + std::string name_; + const MpiComm& mpiComm_; + std::chrono::high_resolution_clock::time_point start_; }; } // namespace gmx From 9d979121ed0c9a8b83d7da36c87243ac79d64432 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Sun, 15 Feb 2026 15:56:12 +0100 Subject: [PATCH 16/19] chore(mtagro): fix rebase error --- .../applied_forces/metatomic/metatomic_forceprovider.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 1318e8072a..9980b5c840 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -526,7 +526,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F } auto cpu_blob_options = torch::TensorOptions().dtype(gromacs_scalar_type).device(torch::kCPU); - auto torch_positions = torch::from_blob(positions_.data()->as_vec(), { n_atoms, 3 }, cpu_blob_options) + auto torch_positions = torch::from_blob(positions_.data()->as_vec(), { static_cast(numLocalMta_), 3 }, cpu_blob_options) .to(data_->dtype) .to(data_->device) .set_requires_grad(true); From 279374095a608c43481299ca73812ee3eece3522 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Sun, 15 Feb 2026 15:56:26 +0100 Subject: [PATCH 17/19] chore(runpath): rework to use rpath correctly --- cmake/gmxManageMetatomic.cmake | 19 +++++++++++++++++++ cmake/gmxManageNNPot.cmake | 16 ++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/cmake/gmxManageMetatomic.cmake b/cmake/gmxManageMetatomic.cmake index e4a60c663f..03d3746a42 100644 --- a/cmake/gmxManageMetatomic.cmake +++ b/cmake/gmxManageMetatomic.cmake @@ -90,6 +90,25 @@ if(NOT GMX_METATOMIC STREQUAL "OFF") endif() set(GMX_TORCH ON) + + # Ensure the torch library directory is in RPATH so that libtorch.so, + # libc10.so, etc. can be found at runtime. CMAKE_INSTALL_RPATH_USE_LINK_PATH + # doesn't always extract paths from imported targets, so we add it explicitly. + # Guard prevents duplicate additions if gmxManageNNPot already added it. + if(TORCH_INSTALL_PREFIX AND NOT _torch_rpath_added) + list(APPEND CMAKE_INSTALL_RPATH "${TORCH_INSTALL_PREFIX}/lib") + set(_torch_rpath_added TRUE) + + # Use RPATH instead of RUNPATH. Modern linkers default to RUNPATH + # (--enable-new-dtags), but RUNPATH doesn't propagate to transitive + # dependencies: gmx -> libgromacs.so -> libtorch.so -> libc10.so. + include(CheckLinkerFlag) + check_linker_flag(CXX "-Wl,--disable-new-dtags" _linker_supports_disable_new_dtags) + if(_linker_supports_disable_new_dtags) + add_link_options("-Wl,--disable-new-dtags") + endif() + endif() + elseif(GMX_METATOMIC STREQUAL "TORCH") message(FATAL_ERROR "Torch not found. Please install libtorch and add its installation prefix" " to CMAKE_PREFIX_PATH or set Torch_DIR to a directory containing " diff --git a/cmake/gmxManageNNPot.cmake b/cmake/gmxManageNNPot.cmake index ecaef22d6d..560d67d5bd 100644 --- a/cmake/gmxManageNNPot.cmake +++ b/cmake/gmxManageNNPot.cmake @@ -90,6 +90,22 @@ if(NOT GMX_NNPOT STREQUAL "OFF") endif() set(GMX_TORCH ON) + + # Ensure the torch library directory is in RPATH so that libtorch.so, + # libc10.so, etc. can be found at runtime. + if(TORCH_INSTALL_PREFIX AND NOT _torch_rpath_added) + list(APPEND CMAKE_INSTALL_RPATH "${TORCH_INSTALL_PREFIX}/lib") + set(_torch_rpath_added TRUE) + + # Use RPATH instead of RUNPATH. RUNPATH doesn't propagate to + # transitive dependencies (gmx -> libgromacs -> libtorch -> libc10). + include(CheckLinkerFlag) + check_linker_flag(CXX "-Wl,--disable-new-dtags" _linker_supports_disable_new_dtags) + if(_linker_supports_disable_new_dtags) + add_link_options("-Wl,--disable-new-dtags") + endif() + endif() + elseif(GMX_NNPOT STREQUAL "TORCH") message(FATAL_ERROR "Torch not found. Please install libtorch and add its installation prefix" " to CMAKE_PREFIX_PATH or set Torch_DIR to a directory containing " From 2ecd385d398b60ba7357fc8044f36b3d94cc1fdf Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Sun, 15 Feb 2026 16:33:06 +0100 Subject: [PATCH 18/19] feat(tmpi): bring back some semblance of perf --- .../applied_forces/metatomic/metatomic_forceprovider.cpp | 9 +++++++++ src/gromacs/applied_forces/metatomic/metatomic_timer.h | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 9980b5c840..f7aefacead 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -165,6 +165,15 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, MetatomicTimer::enable(std::string(timerEnv) == "1"); } + // With thread-MPI, each rank is a thread sharing the same process. + // PyTorch's internal OpenMP would spawn N threads per rank, causing + // massive oversubscription (e.g. 12 ranks × 12 OMP threads = 144 + // threads on 12 cores). Force single-threaded torch operations. + if (GMX_THREAD_MPI && mpiComm_.isParallel()) + { + at::set_num_threads(1); + } + try { torch::optional extensions_directory = torch::nullopt; diff --git a/src/gromacs/applied_forces/metatomic/metatomic_timer.h b/src/gromacs/applied_forces/metatomic/metatomic_timer.h index 867efb3c95..e4f056a53f 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_timer.h +++ b/src/gromacs/applied_forces/metatomic/metatomic_timer.h @@ -58,7 +58,7 @@ namespace gmx // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) static std::mutex METATOMIC_TIMER_MUTEX = {}; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -static int64_t METATOMIC_TIMER_DEPTH = -1; +static thread_local int64_t METATOMIC_TIMER_DEPTH = -1; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) static bool METATOMIC_TIMER_ENABLED = false; From 5c7ea8652a6670912aa50a104c268c9407d3b7a7 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Sun, 15 Feb 2026 17:45:17 +0100 Subject: [PATCH 19/19] chore(ci): start testing a CI build --- .github/workflows/build_cmake.yml | 79 ------------------------------ .github/workflows/metatomic-ci.yml | 79 ++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 79 deletions(-) delete mode 100644 .github/workflows/build_cmake.yml create mode 100644 .github/workflows/metatomic-ci.yml diff --git a/.github/workflows/build_cmake.yml b/.github/workflows/build_cmake.yml deleted file mode 100644 index 5e74952eab..0000000000 --- a/.github/workflows/build_cmake.yml +++ /dev/null @@ -1,79 +0,0 @@ -name: CMake Build Matrix - -on: [push, pull_request] - -env: - CMAKE_VERSION: 3.28.0 # Oldest supported - NINJA_VERSION: 1.12.1 # Latest - BUILD_TYPE: Release - CCACHE_VERSION: 4.10.2 # Latest - NINJA_STATUS: "[%f/%t %o/sec] " - -jobs: - build: - name: ${{ matrix.config.name }} - runs-on: ${{ matrix.config.os }} - strategy: - fail-fast: false - matrix: - config: - - { - name: "Windows MSVC 2022", artifact: "Windows-MSVC-2022.7z", - os: windows-2022, - cc: "cl", cxx: "cl", - environment_script: "C:/Program Files/Microsoft Visual Studio/2022/Enterprise/VC/Auxiliary/Build/vcvars64.bat", - gpu_var: "Off", - openmp_var: "On" - } - - { - name: "macOS Latest Clang", artifact: "macOS.7z", - # In a release branch, we should fix this for the lifetime - # of the branch. - os: macos-latest, - cc: "clang", cxx: "clang++", - gpu_var: "Off", - openmp_var: "Off" - } - - { - name: "macOS Latest GCC 14 with OpenCL", artifact: "macOS-gcc-OpenCL.7z", - # In a release branch, we should fix this for the lifetime - # of the branch. - os: macos-latest, - cc: "gcc-14", cxx: "g++-14", - gpu_var: "OpenCL", - openmp_var: "On" - } - - env: - CC: ${{ matrix.config.cc }} - CXX: ${{ matrix.config.cxx }} - CI_JOB_ID: ${{ github.sha }} # Tell CMake it's running in CI - OPENMP_VAR: ${{ matrix.config.openmp_var }} - GPU_VAR: ${{ matrix.config.gpu_var }} - ENVIRONMENT_SCRIPT: ${{ matrix.config.environment_script }} - - steps: - - uses: actions/checkout@v4 - with: - show-progress: false - - - name: Download Ninja, CMake, and CCache - run: cmake -P .github/scripts/download-ninja-cmake.cmake - - - name: ccache cache files - uses: actions/cache@v4 - with: - path: .ccache - key: ${{ matrix.config.name }}-ccache-${{ github.sha }} - restore-keys: | - ${{ matrix.config.name }}-ccache- - - - name: Configure - run: cmake -P .github/scripts/configure.cmake - - - name: Build - run: cmake -P .github/scripts/build.cmake - - - name: Run tests - run: cmake -P .github/scripts/test.cmake - diff --git a/.github/workflows/metatomic-ci.yml b/.github/workflows/metatomic-ci.yml new file mode 100644 index 0000000000..2cc7ad5f0f --- /dev/null +++ b/.github/workflows/metatomic-ci.yml @@ -0,0 +1,79 @@ +name: Metatomic Integration Tests + +on: + workflow_dispatch: + inputs: + pixi_envs_branch: + description: "pixi_envs branch for test definitions" + required: false + default: "main" + push: + branches: [metatomic] + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - mpi: thread-mpi + build_task: gromk-tmpi + gmx_bin: gmx + - mpi: real-mpi + build_task: gromk-mpi + gmx_bin: gmx_mpi + + name: metatomic (${{ matrix.mpi }}) + + steps: + - name: Checkout GROMACS + uses: actions/checkout@v4 + with: + path: gromacs + + - name: Checkout pixi_envs + uses: actions/checkout@v4 + with: + repository: HaoZeke/pixi_envs + ref: ${{ github.event.inputs.pixi_envs_branch || 'main' }} + path: pixi_envs + + - name: Place GROMACS inside pixi_envs layout + run: ln -s ${{ github.workspace }}/gromacs pixi_envs/orgs/metatensor/gromacs/gromacs + + - name: Install pixi + uses: prefix-dev/setup-pixi@v0.8.1 + with: + manifest-path: pixi_envs/orgs/metatensor/gromacs/pixi.toml + environments: metatomic-cpu + + - name: Build GROMACS (${{ matrix.mpi }}) + working-directory: pixi_envs/orgs/metatensor/gromacs + run: pixi run -e metatomic-cpu ${{ matrix.build_task }} Release + + - name: Show cmake error log on failure + if: failure() + working-directory: pixi_envs/orgs/metatensor/gromacs + run: | + for f in gromacs/build-*/CMakeFiles/CMakeError.log; do + echo "=== $f ===" + cat "$f" 2>/dev/null || true + done + for f in gromacs/build-*/CMakeFiles/CMakeOutput.log; do + echo "=== $f (last 100 lines) ===" + tail -100 "$f" 2>/dev/null || true + done + + - name: Generate test model + working-directory: pixi_envs/orgs/metatensor/gromacs + run: cd mta_test && pixi run -e metatomic-cpu python create_model.py + + - name: Run tests (skip dd8/dd12) + working-directory: pixi_envs/orgs/metatensor/gromacs + env: + GMX_BIN: ${{ matrix.gmx_bin }} + run: | + pixi run -e metatomic-cpu pytest mta_test/ -v \ + -m "not dd8 and not dd12"