From aee078ea9ae0d61996022a4e0aceb8b7580abc47 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Fri, 30 Jan 2026 08:28:39 +0100 Subject: [PATCH 01/17] feat(mtapot): remove vesin, use DD pairlist As per the NNPot implementation --- .../metatomic/metatomic_forceprovider.cpp | 634 +++++++++--------- .../metatomic/metatomic_forceprovider.h | 44 +- .../metatomic/metatomic_mdmodule.cpp | 14 +- 3 files changed, 371 insertions(+), 321 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index af5b8af9a1..2afd3e3875 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -47,122 +47,153 @@ #include "gromacs/domdec/localatomset.h" #include "gromacs/mdlib/broadcaststructs.h" +#include "gromacs/mdrunutility/mdmodulesnotifiers.h" #include "gromacs/mdtypes/enerdata.h" #include "gromacs/mdtypes/forceoutput.h" +#include "gromacs/pbcutil/ishift.h" #include "gromacs/utility/arrayref.h" #include "gromacs/utility/exceptions.h" #include "gromacs/utility/logger.h" #include "gromacs/utility/mpicomm.h" +#include "gromacs/utility/stringutil.h" #ifdef DIM # undef DIM #endif -#include #include #include -static torch::Tensor preparePbcType(PbcType* pbcType) +namespace gmx { - torch::Tensor pbcTensor = - torch::tensor({ true, true, true }, torch::TensorOptions().dtype(torch::kBool)); - if (*pbcType == PbcType::XY) + +static std::optional indexOf(ArrayRef vec, const int val) +{ + auto it = std::find(vec.begin(), vec.end(), val); + if (it == vec.end()) { - pbcTensor[2] = false; + return std::nullopt; } - else if (*pbcType != PbcType::Xyz) + return std::distance(vec.begin(), it); +} + +static std::tuple> getMdrunActiveDevice() +{ +#if GMX_GPU_CUDA || (GMX_SYCL_ACPP && GMX_ACPP_HAVE_CUDA_TARGET) + GMX_RELEASE_ASSERT(torch::hasCUDA(), "Libtorch not compiled with CUDA support."); + int activeDevice; + if (cudaGetDevice(&activeDevice) != cudaSuccess) + { + GMX_THROW(InternalError("cudaGetDevice failed.")); + } + return { "cuda", activeDevice }; +#elif GMX_GPU_HIP || (GMX_SYCL_ACPP && GMX_ACPP_HAVE_HIP_TARGET) +# ifndef USE_ROCM + GMX_THROW(InternalError("Libtorch not compiled with HIP support.")); +# endif + int activeDevice; + if (hipGetDevice(&activeDevice) != hipSuccess) { - GMX_THROW(gmx::InconsistentInputError( - "Option use_pbc was set to true, but PBC type is not supported.")); + GMX_THROW(InternalError("hipGetDevice failed.")); } - return pbcTensor; + return { "hip", activeDevice }; +#else + return { "cpu", std::nullopt }; +#endif } -static metatensor_torch::TensorBlock computeNeighbors(metatomic_torch::NeighborListOptions request, - long n_atoms, - const float* positions, - const matrix box, - bool periodic, - torch::Device device, - torch::ScalarType dtype) +static torch::Device determineDevice(const MDLogger& logger, const MpiComm& mpiComm) { - auto cutoff = request->engine_cutoff("nm"); + torch::Device device(torch::kCPU); - VesinOptions options; - options.cutoff = cutoff; - options.full = request->full_list(); - options.return_shifts = true; - options.return_distances = false; - options.return_vectors = true; + // Non-main ranks don't run model, return CPU + if (!mpiComm.isMainRank()) + { + return device; + } - VesinNeighborList* vesin_neighbor_list = new VesinNeighborList(); + auto [torchDeviceType, activeDevice] = getMdrunActiveDevice(); - double double_box[3][3]; - for (int i = 0; i < 3; i++) + if (const char* env = std::getenv("GMX_METATOMIC_DEVICE")) + { + const std::string devLC = toLowerCase(env); + if (devLC == "gpu" || devLC == "cuda") + { + if (!torch::cuda::is_available()) + { + GMX_THROW(InternalError(formatString( + "GMX_METATOMIC_DEVICE='%s' but no device available.", env))); + } + GMX_RELEASE_ASSERT(activeDevice.has_value(), "Could not determine active device."); + device = torch::Device(torch::kCUDA, activeDevice.value()); + } + else if (devLC != "cpu") + { + GMX_THROW(InvalidInputError(formatString( + "GMX_METATOMIC_DEVICE invalid value: '%s'.", env))); + } + GMX_LOG(logger.info).asParagraph() + .appendTextFormatted("Using device from GMX_METATOMIC_DEVICE: '%s'.", env); + } + else { - for (int j = 0; j < 3; j++) + if (torch::cuda::is_available() && activeDevice.has_value()) { - double_box[i][j] = static_cast(box[i][j]); + GMX_LOG(logger.info).asParagraph() + .appendText("Using " + toUpperCase(torchDeviceType) + " for Metatomic."); + device = torch::Device(torch::kCUDA, activeDevice.value()); + } + else + { + GMX_LOG(logger.info).asParagraph().appendText("Using CPU for Metatomic."); } } + return device; +} - const size_t total_elements = static_cast(n_atoms) * 3; - std::vector double_positions(total_elements); - - for (size_t i = 0; i < total_elements; i++) +static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) +{ + torch::Tensor pbcTensor = + torch::tensor({ true, true, true }, torch::TensorOptions().dtype(torch::kBool)); + if (*pbcType == PbcType::XY) { - double_positions[i] = static_cast(positions[i]); + pbcTensor[2] = false; } - const double* positions_ptr = double_positions.data(); - - VesinDevice cpu{ VesinCPU, 0 }; - const char* error_message = nullptr; - int status = vesin_neighbors(reinterpret_cast(positions_ptr), - static_cast(n_atoms), - double_box, - &periodic, - cpu, - options, - vesin_neighbor_list, - &error_message); - - if (status != EXIT_SUCCESS) + else if (*pbcType != PbcType::Xyz) { - std::string err_str = "vesin_neighbors failed: "; - if (error_message) - { - err_str += error_message; - } - delete vesin_neighbor_list; - GMX_THROW(gmx::APIError(err_str)); + GMX_THROW(InconsistentInputError("PBC type not supported.")); } + return pbcTensor.to(device); +} - auto n_pairs = static_cast(vesin_neighbor_list->length); +static metatensor_torch::TensorBlock buildNeighborListFromPairlist( + ArrayRef pairlist, + ArrayRef shiftVectors, + torch::Device device, + torch::ScalarType dtype) +{ + const int64_t n_pairs = static_cast(pairlist.size() / 2); - auto pair_samples_values = torch::empty({ n_pairs, 5 }, torch::TensorOptions().dtype(torch::kInt32)); + auto pair_samples_values = torch::zeros({ n_pairs, 5 }, torch::TensorOptions().dtype(torch::kInt32)); auto pair_samples_ptr = pair_samples_values.accessor(); + + auto pair_vectors = torch::zeros({ n_pairs, 3, 1 }, torch::TensorOptions().dtype(torch::kFloat64)); + auto vectors_accessor = pair_vectors.accessor(); + for (int64_t i = 0; i < n_pairs; i++) { - pair_samples_ptr[i][0] = static_cast(vesin_neighbor_list->pairs[i][0]); - pair_samples_ptr[i][1] = static_cast(vesin_neighbor_list->pairs[i][1]); - pair_samples_ptr[i][2] = vesin_neighbor_list->shifts[i][0]; - pair_samples_ptr[i][3] = vesin_neighbor_list->shifts[i][1]; - pair_samples_ptr[i][4] = vesin_neighbor_list->shifts[i][2]; + pair_samples_ptr[i][0] = static_cast(pairlist[2 * i]); + pair_samples_ptr[i][1] = static_cast(pairlist[2 * i + 1]); + pair_samples_ptr[i][2] = 0; + pair_samples_ptr[i][3] = 0; + pair_samples_ptr[i][4] = 0; + + vectors_accessor[i][0][0] = static_cast(shiftVectors[i][0]); + vectors_accessor[i][1][0] = static_cast(shiftVectors[i][1]); + vectors_accessor[i][2][0] = static_cast(shiftVectors[i][2]); } - auto deleter = [=](void*) - { - vesin_free(vesin_neighbor_list); - delete vesin_neighbor_list; - }; - - auto pair_vectors = torch::from_blob(vesin_neighbor_list->vectors, - { n_pairs, 3, 1 }, - deleter, - torch::TensorOptions().dtype(torch::kFloat64)); - pair_vectors.to(dtype); - auto neighbor_samples = torch::make_intrusive( std::vector{ "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c" }, @@ -184,8 +215,6 @@ static metatensor_torch::TensorBlock computeNeighbors(metatomic_torch::NeighborL neighbor_properties); } -namespace gmx -{ struct MetatomicData { @@ -193,8 +222,8 @@ struct MetatomicData metatomic_torch::ModelCapabilities capabilities; std::vector nl_requests; metatomic_torch::ModelEvaluationOptions evaluations_options; - torch::ScalarType dtype; - bool check_consistency; + torch::ScalarType dtype = torch::kFloat32; + bool check_consistency = true; torch::Device device = torch::kCPU; }; @@ -207,314 +236,299 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, box_{ { 0.0, 0.0, 0.0 }, { 0.0, 0.0, 0.0 }, { 0.0, 0.0, 0.0 } }, data_(std::make_unique()) { - // ALL ranks load the model, not just the main rank - // This enables each rank to compute forces for its local atoms independently - - GMX_LOG(logger_.info).asParagraph().appendText("Initializing MetatomicForceProvider on all ranks..."); - - // Load the model on EVERY rank - try - { - torch::optional extensions_directory = torch::nullopt; - if (!options_.params_.extensionsDirectory.empty()) - { - extensions_directory = options_.params_.extensionsDirectory; - } + GMX_LOG(logger_.info).asParagraph().appendText("Initializing MetatomicForceProvider..."); - this->data_->model = metatomic_torch::load_atomistic_model(options_.params_.modelPath_, - extensions_directory); - } - catch (const std::exception& e) + // Pairlist-based neighbor lists don't work with DD yet (indices are local) + // Matches NNPot's limitation + if (mpiComm_.isParallel()) { - GMX_THROW(APIError("Failed to load metatomic model: " + std::string(e.what()))); + GMX_THROW(NotImplementedError( + "Metatomic does not yet support domain decomposition. " + "Use thread-MPI (gmx mdrun) instead of MPI (mpirun gmx_mpi mdrun).")); } - // Query model capabilities on all ranks - data_->capabilities = - data_->model.run_method("capabilities").toCustomClass(); - auto requests_ivalue = data_->model.run_method("requested_neighbor_lists"); - for (const auto& request_ivalue : requests_ivalue.toList()) + data_->device = determineDevice(logger_, mpiComm_); + + // Only main rank loads model + if (mpiComm_.isMainRank()) { - data_->nl_requests.push_back( - request_ivalue.get().toCustomClass()); - } + try + { + torch::optional extensions_directory = torch::nullopt; + if (!options_.params_.extensionsDirectory.empty()) + { + extensions_directory = options_.params_.extensionsDirectory; + } - // Determine device - each rank picks its own device - torch::optional desired; - if (const char* env = std::getenv("GMX_METATOMIC_DEVICE")) { - GMX_LOG(logger_.info) - .asParagraph() - .appendText("Using device from GMX_METATOMIC_DEVICE environment variable: ") - .appendText(env); - desired = std::string(env); - } else { - desired = options_.params_.device; - } + data_->model = metatomic_torch::load_atomistic_model(options_.params_.modelPath_, + extensions_directory); + } + catch (const std::exception& e) + { + GMX_THROW(APIError("Failed to load metatomic model: " + std::string(e.what()))); + } - c10::DeviceType device_type_ = - metatomic_torch::pick_device(data_->capabilities->supported_devices, desired); - data_->device = torch::Device(device_type_); + data_->capabilities = + data_->model.run_method("capabilities").toCustomClass(); - data_->model.to(data_->device); + auto requests_ivalue = data_->model.run_method("requested_neighbor_lists"); + for (const auto& request_ivalue : requests_ivalue.toList()) + { + data_->nl_requests.push_back( + request_ivalue.get().toCustomClass()); + } - // Set data type on all ranks - if (data_->capabilities->dtype() == "float64") - { - data_->dtype = torch::kFloat64; - } - else if (data_->capabilities->dtype() == "float32") - { - data_->dtype = torch::kFloat32; - } - else - { - GMX_THROW(APIError("Unsupported dtype from model: " + data_->capabilities->dtype())); - } + data_->model.to(data_->device); - // Set up evaluation options on all ranks - data_->evaluations_options = - torch::make_intrusive(); - data_->evaluations_options->set_length_unit("nm"); + if (data_->capabilities->dtype() == "float64") + { + data_->dtype = torch::kFloat64; + } + else if (data_->capabilities->dtype() == "float32") + { + data_->dtype = torch::kFloat32; + } + else + { + GMX_THROW(APIError("Unsupported dtype: " + data_->capabilities->dtype())); + } - auto outputs = data_->capabilities->outputs(); - if (!outputs.contains("energy")) - { - GMX_THROW(APIError("Metatomic model must provide an 'energy' output.")); - } + data_->evaluations_options = + torch::make_intrusive(); + data_->evaluations_options->set_length_unit("nm"); - auto requested_output = torch::make_intrusive(); - requested_output->per_atom = true; // KEY: Request per-atom energies for proper DD - requested_output->explicit_gradients = {}; + auto outputs = data_->capabilities->outputs(); + if (!outputs.contains("energy")) + { + GMX_THROW(APIError("Metatomic model must provide 'energy' output.")); + } - data_->evaluations_options->outputs.insert("energy", requested_output); + auto requested_output = torch::make_intrusive(); + requested_output->per_atom = false; + requested_output->explicit_gradients = {}; - data_->check_consistency = options_.params_.checkConsistency; + data_->evaluations_options->outputs.insert("energy", requested_output); + data_->check_consistency = options_.params_.checkConsistency; + } - // Initialize lookup tables - will be populated on first DD const auto& mtaIndices = options_.params_.mtaIndices_; const int n_atoms = static_cast(mtaIndices.size()); - idxLookup_.resize(n_atoms, -1); + + positions_.resize(n_atoms); atomNumbers_.resize(n_atoms, 0); + inputToLocalIndex_.resize(n_atoms, -1); + inputToGlobalIndex_.resize(n_atoms, -1); - GMX_LOG(logger_.info) - .asParagraph() - .appendText("MetatomicForceProvider initialization complete on all ranks."); + GMX_LOG(logger_.info).asParagraph().appendText("MetatomicForceProvider initialization complete."); } -void MetatomicForceProvider::gatherAtomNumbersIndices() +MetatomicForceProvider::~MetatomicForceProvider() = default; + +void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedistributedSignal& signal) { - // This function updates the lookup tables after domain decomposition - // Each rank knows which ML atoms are local to it - const auto& mtaIndices = options_.params_.mtaIndices_; - const int n_atoms = static_cast(mtaIndices.size()); - - // Reset lookup table - std::fill(idxLookup_.begin(), idxLookup_.end(), -1); + const int numInput = static_cast(mtaIndices.size()); - // Build reverse lookup: global index -> ML group index - std::unordered_map globalToMlIndex; - globalToMlIndex.reserve(n_atoms); - for (int i = 0; i < n_atoms; ++i) - { - globalToMlIndex[mtaIndices[i]] = i; - } + inputToLocalIndex_.assign(numInput, -1); + inputToGlobalIndex_.assign(numInput, -1); + atomNumbers_.assign(numInput, 0); - // Populate lookup for this rank's local atoms - const auto* mtaAtoms = options_.params_.mtaAtoms_.get(); - for (size_t i = 0; i < mtaAtoms->numAtomsLocal(); ++i) + if (mpiComm_.isParallel()) { - const int lIdx = mtaAtoms->localIndex()[i]; - const int gIdx = mtaAtoms->globalIndex()[mtaAtoms->collectiveIndex()[i]]; + GMX_RELEASE_ASSERT(signal.globalAtomIndices_.has_value(), + "Global atom indices required for DD."); + auto globalAtomIndices = signal.globalAtomIndices_.value(); + const int numLocal = signal.x_.size(); - if (auto it = globalToMlIndex.find(gIdx); it != globalToMlIndex.end()) + for (int i = 0; i < static_cast(globalAtomIndices.size()); i++) { - const int mlIdx = it->second; - idxLookup_[mlIdx] = lIdx; - atomNumbers_[mlIdx] = options_.params_.atoms_.atom[gIdx].atomnumber; + int globalIdx = globalAtomIndices[i]; + for (int j = 0; j < numInput; j++) + { + if (options_.params_.mtaAtoms_->globalIndex()[j] == globalIdx) + { + if (i < numLocal) + { + inputToLocalIndex_[j] = i; + inputToGlobalIndex_[j] = globalIdx; + atomNumbers_[j] = options_.params_.atoms_.atom[globalIdx].atomnumber; + } + break; + } + } } + mpiComm_.sumReduce(numInput, atomNumbers_.data()); } - - // For parallel runs, we need all ranks to know ALL atom numbers for the system tensor - // But we only compute forces for local atoms - // XXX: seems buggy in parallel, need more checks - if (mpiComm_.isParallel()) + else { - // Each rank has partial atomNumbers_, sum to get complete list - // This is needed because the System object needs all atom types - mpiComm_.sumReduce(gmx::ArrayRef(atomNumbers_.data(), atomNumbers_.data() + n_atoms)); + const auto* mtaAtoms = options_.params_.mtaAtoms_.get(); + for (int i = 0; i < numInput; i++) + { + int localIndex = mtaAtoms->localIndex()[i]; + int globalIdx = mtaAtoms->globalIndex()[mtaAtoms->collectiveIndex()[i]]; + inputToLocalIndex_[i] = localIndex; + inputToGlobalIndex_[i] = globalIdx; + atomNumbers_[i] = options_.params_.atoms_.atom[globalIdx].atomnumber; + } } -} -MetatomicForceProvider::~MetatomicForceProvider() = default; + GMX_RELEASE_ASSERT(std::count(atomNumbers_.begin(), atomNumbers_.end(), 0) == 0, + "Some atom numbers not set."); +} -void MetatomicForceProvider::gatherAtomPositions(ArrayRef globalPositions) +void MetatomicForceProvider::gatherAtomPositions(ArrayRef pos) { - const int n_atoms = static_cast(options_.params_.mtaIndices_.size()); - positions_.assign(n_atoms, RVec{ 0.0, 0.0, 0.0 }); + const size_t numInput = inputToLocalIndex_.size(); + positions_.assign(numInput, RVec({ 0.0, 0.0, 0.0 })); - // Each rank fills its local atoms' positions - for (int i = 0; i < n_atoms; ++i) + for (size_t i = 0; i < numInput; i++) { - if (idxLookup_[i] != -1) + if (inputToLocalIndex_[i] != -1) { - positions_[i] = globalPositions[idxLookup_[i]]; + positions_[i] = pos[inputToLocalIndex_[i]]; } } - // Sum-reduce positions so all ranks have complete position array - // This is needed because neighbor list computation needs all positions if (mpiComm_.isParallel()) { - real* data_ptr = reinterpret_cast(positions_.data()); - const size_t n_reals = static_cast(n_atoms) * 3; - mpiComm_.sumReduce(gmx::ArrayRef(data_ptr, data_ptr + n_reals)); + mpiComm_.sumReduce(3 * numInput, positions_.data()->as_vec()); } } -void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, ForceProviderOutput* outputs) +void MetatomicForceProvider::setPairlist(const MDModulesPairlistConstructedSignal& signal) { - const int n_atoms = static_cast(options_.params_.mtaIndices_.size()); - - // Gather positions (all ranks need this for neighbor list computation) - this->gatherAtomPositions(inputs.x_); - copy_mat(inputs.box_, box_); + fullPairlist_.assign(signal.excludedPairlist_.begin(), signal.excludedPairlist_.end()); + doPairlist_ = true; +} - auto gromacs_scalar_type = torch::kFloat32; - if (std::is_same_v) +void MetatomicForceProvider::preparePairlistInput() +{ + if (!doPairlist_) { - gromacs_scalar_type = torch::kFloat64; + return; } - auto blob_options = torch::TensorOptions().dtype(gromacs_scalar_type).device(torch::kCPU); - - // ALL ranks run the model, each computing forces for its local atoms - auto coerced_positions = makeArrayRef(positions_); - auto torch_positions = - torch::from_blob(coerced_positions.data()->as_vec(), { n_atoms, 3 }, blob_options) - .to(data_->dtype) - .to(data_->device) - .set_requires_grad(true); + GMX_ASSERT(!fullPairlist_.empty(), "Pairlist empty!"); - auto torch_cell = - torch::from_blob(&box_, { 3, 3 }, blob_options).to(data_->dtype).to(data_->device); + const int numPairs = gmx::ssize(fullPairlist_); + pairlistForModel_.clear(); + pairlistForModel_.reserve(2 * numPairs); + shiftVectors_.clear(); + shiftVectors_.reserve(numPairs); - auto torch_pbc = preparePbcType(options_.params_.pbcType_.get()).to(data_->device); - auto torch_types = - torch::tensor(atomNumbers_, torch::TensorOptions().dtype(torch::kInt32)).to(data_->device); + for (int i = 0; i < numPairs; i++) + { + const auto [atomPair, shiftIndex] = fullPairlist_[i]; - auto system = torch::make_intrusive( - torch_types, torch_positions, torch_cell, torch_pbc); + auto inputIdxA = indexOf(inputToGlobalIndex_, atomPair.first); + auto inputIdxB = indexOf(inputToGlobalIndex_, atomPair.second); - bool periodic = torch::all(torch_pbc).item(); + if (inputIdxA.has_value() && inputIdxB.has_value()) + { + RVec shift; + const IVec unitShift = shiftIndexToXYZ(shiftIndex); + mvmul_ur0(box_, unitShift.toRVec(), shift); - // Compute neighbor lists on each rank - for (const auto& request : data_->nl_requests) - { - auto neighbors = computeNeighbors( - request, n_atoms, coerced_positions.data()->as_vec(), box_, periodic, data_->device, data_->dtype); - metatomic_torch::register_autograd_neighbors(system, neighbors, false); - system->add_neighbor_list(request, neighbors); + pairlistForModel_.push_back(static_cast(inputIdxA.value())); + pairlistForModel_.push_back(static_cast(inputIdxB.value())); + shiftVectors_.push_back(shift); + } } - // Run the model on EVERY rank - metatensor_torch::TensorMap output_map; - try - { - std::vector systems; - systems.push_back(system); + GMX_RELEASE_ASSERT(pairlistForModel_.size() == shiftVectors_.size() * 2, + "Pairlist/shift size mismatch."); + doPairlist_ = false; +} - auto ivalue_output = data_->model.forward( - { c10::IValue(systems), data_->evaluations_options, data_->check_consistency }); - auto dict_output = ivalue_output.toGenericDict(); - output_map = dict_output.at("energy").toCustomClass(); - } - catch (const std::exception& e) - { - GMX_THROW(APIError("[MetatomicPotential] Model evaluation failed: " + std::string(e.what()))); - } +void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, ForceProviderOutput* outputs) +{ + const int n_atoms = static_cast(options_.params_.mtaIndices_.size()); - // Extract per-atom energies - auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(output_map, 0); - auto energy_tensor = energy_block->values(); + gatherAtomPositions(inputs.x_); + copy_mat(inputs.box_, box_); + preparePairlistInput(); - // Handle energy for domain decomposition - // If per_atom=true, we get per-atom energies [n_atoms, 1] - // So sum over local atoms' energies - double local_energy = 0.0; - - if (energy_tensor.dim() == 2 && energy_tensor.size(0) == n_atoms) + // Force tensor - main rank fills, others have zeros + torch::Tensor forceTensor = torch::zeros({ n_atoms, 3 }, torch::TensorOptions().dtype(torch::kFloat64)); + + if (mpiComm_.isMainRank()) { - // Per-atom energies - sum only local atoms - auto energy_cpu = energy_tensor.to(torch::kCPU).to(torch::kFloat64); - auto energy_accessor = energy_cpu.accessor(); - - for (int i = 0; i < n_atoms; ++i) + auto gromacs_scalar_type = torch::kFloat32; + if (std::is_same_v) { - if (idxLookup_[i] != -1) // Only count local atoms - { - local_energy += energy_accessor[i][0]; - } + gromacs_scalar_type = torch::kFloat64; } - } - else - { - // Scalar energy - only main rank contributes (or divide among ranks) - if (mpiComm_.isMainRank()) + auto blob_options = torch::TensorOptions().dtype(gromacs_scalar_type).device(torch::kCPU); + + auto torch_positions = + torch::from_blob(positions_.data()->as_vec(), { n_atoms, 3 }, blob_options) + .to(data_->dtype) + .to(data_->device) + .set_requires_grad(true); + + auto torch_cell = + torch::from_blob(&box_, { 3, 3 }, blob_options).to(data_->dtype).to(data_->device); + + auto torch_pbc = preparePbcType(options_.params_.pbcType_.get(), data_->device); + auto torch_types = + torch::tensor(atomNumbers_, torch::TensorOptions().dtype(torch::kInt32)).to(data_->device); + + auto system = torch::make_intrusive( + torch_types, torch_positions, torch_cell, torch_pbc); + + // Build neighbor list from GROMACS pairlist + for (const auto& request : data_->nl_requests) { - local_energy = energy_tensor.item(); + auto neighbors = buildNeighborListFromPairlist( + pairlistForModel_, shiftVectors_, data_->device, data_->dtype); + metatomic_torch::register_autograd_neighbors(system, neighbors, false); + system->add_neighbor_list(request, neighbors); } - } - // Reduce energy across all ranks - double total_energy = local_energy; - if (mpiComm_.isParallel()) - { - mpiComm_.sumReduce(gmx::ArrayRef(&total_energy, &total_energy + 1)); - } - - // Only main rank sets the energy (GROMACS handles the rest) - if (mpiComm_.isMainRank()) - { + metatensor_torch::TensorMap output_map; + try + { + std::vector systems; + systems.push_back(system); + + auto ivalue_output = data_->model.forward( + { c10::IValue(systems), data_->evaluations_options, data_->check_consistency }); + auto dict_output = ivalue_output.toGenericDict(); + output_map = dict_output.at("energy").toCustomClass(); + } + catch (const std::exception& e) + { + GMX_THROW(APIError("[Metatomic] Model evaluation failed: " + std::string(e.what()))); + } + + auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(output_map, 0); + auto energy_tensor = energy_block->values(); + outputs->enerd_.term[InteractionFunction::MetatomicPotentialEnergy] = - static_cast(total_energy); - } + static_cast(energy_tensor.item()); - // Compute gradients - all ranks do this - energy_tensor.sum().backward(); - auto grad = system->positions().grad(); - auto forceTensor = -grad.to(torch::kCPU).to(data_->dtype); + energy_tensor.backward(); + auto grad = system->positions().grad(); + forceTensor = -grad.to(torch::kCPU).to(torch::kFloat64); + } - // Scatter forces to local atoms ONLY - no broadcast needed! - if (data_->dtype == torch::kFloat64) + // Distribute forces (sumReduce acts as broadcast since non-main ranks have zeros) + if (mpiComm_.isParallel()) { - auto accessor = forceTensor.accessor(); - for (int i = 0; i < n_atoms; ++i) - { - const int localIndex = idxLookup_[i]; - if (localIndex != -1) // Only update local atoms - { - outputs->forceWithVirial_.force_[localIndex][0] += accessor[i][0]; - outputs->forceWithVirial_.force_[localIndex][1] += accessor[i][1]; - outputs->forceWithVirial_.force_[localIndex][2] += accessor[i][2]; - } - } + mpiComm_.sumReduce(n_atoms * 3, static_cast(forceTensor.data_ptr())); } - else if (data_->dtype == torch::kFloat32) + + // Apply to local atoms only + auto forceAccessor = forceTensor.accessor(); + for (int i = 0; i < n_atoms; ++i) { - auto accessor = forceTensor.accessor(); - for (int i = 0; i < n_atoms; ++i) + if (inputToLocalIndex_[i] != -1) { - const int localIndex = idxLookup_[i]; - if (localIndex != -1) // Only update local atoms - { - outputs->forceWithVirial_.force_[localIndex][0] += static_cast(accessor[i][0]); - outputs->forceWithVirial_.force_[localIndex][1] += static_cast(accessor[i][1]); - outputs->forceWithVirial_.force_[localIndex][2] += static_cast(accessor[i][2]); - } + outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][0] += forceAccessor[i][0]; + outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][1] += forceAccessor[i][1]; + outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][2] += forceAccessor[i][2]; } } - // Note: Virial still needs proper implementation } } // namespace gmx diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h index 0976da4d89..ccb1cd37be 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h @@ -50,10 +50,17 @@ namespace gmx struct MetatomicParameters; struct MetatomicData; +struct MDModulesAtomsRedistributedSignal; +struct MDModulesPairlistConstructedSignal; class MDLogger; class MpiComm; +/*! For compatibility with pairlist data structure in MDModulesPairlistConstructedSignal. + * Contains pairs like ((atom1, atom2), shiftIndex). + */ +using PairlistEntry = std::pair, int>; + /*! \brief \internal * MetatomicForceProvider class * @@ -72,28 +79,53 @@ class MetatomicForceProvider final : public IForceProvider * \param[out] fOutput output for force provider */ void calculateForces(const ForceProviderInput& inputs, ForceProviderOutput* outputs) override; - void updateLocalAtoms(); - void gatherAtomPositions(ArrayRef globalPositions); - void gatherAtomNumbersIndices(); + + //! Gather atom numbers and indices. Triggered on AtomsRedistributed signal. + void gatherAtomNumbersIndices(const MDModulesAtomsRedistributedSignal& signal); + + //! Set pairlist from notification and filter to MTA atom pairs. + void setPairlist(const MDModulesPairlistConstructedSignal& signal); private: + //! Gather atom positions for MTA input. + void gatherAtomPositions(ArrayRef globalPositions); + + //! Prepare pairlist input for model + void preparePairlistInput(); + const MetatomicOptions& options_; const MDLogger& logger_; const MpiComm& mpiComm_; - //! vector storing all atom positions + //! vector storing all MTA atom positions std::vector positions_; //! vector storing all atomic numbers std::vector atomNumbers_; - //! global index lookup table to map indices from model input to global atom indices - std::vector idxLookup_; + //! lookup table to map model input indices [0...numInput) to local atom indices + std::vector inputToLocalIndex_; + + //! lookup table to map model input indices to global atom indices + std::vector inputToGlobalIndex_; + + //! Full pairlist from MDModules notification + std::vector fullPairlist_; + + //! Interacting pairs of MTA atoms within cutoff, for model input + std::vector pairlistForModel_; + + //! Shift vectors for each atom pair in pairlistForModel_ + std::vector shiftVectors_; //! local copy of simulation box matrix box_; + //! Data required for metatomic calculations std::unique_ptr data_; + + //! flag to check if pairlist data should be prepared + bool doPairlist_ = false; }; } // namespace gmx diff --git a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp index 1e01e0f09a..67652e5ff2 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp @@ -175,6 +175,7 @@ class MetatomicMDModule final : public IMDModule * * The Metatomic module subscribes to the following notifications: * - Atom redistribution due to domain decomposition + * - Changes in the neighborlist * by taking a const MDModulesAtomsRedistributedSignal as a parameter. */ void subscribeToSimulationRunNotifications(MDModulesNotifiers* notifiers) override @@ -185,10 +186,14 @@ class MetatomicMDModule final : public IMDModule } // After domain decomposition, the force provider needs to know which atoms are local. - const auto notifyDDFunction = [this](const MDModulesAtomsRedistributedSignal& /*signal*/) { - force_provider_->gatherAtomNumbersIndices(); - }; + const auto notifyDDFunction = [this](const MDModulesAtomsRedistributedSignal& signal) + { force_provider_->gatherAtomNumbersIndices(signal); }; notifiers->simulationRunNotifier_.subscribe(notifyDDFunction); + + // subscribe to pairlist construction notification + const auto notifyPairlistFunction = [this](const MDModulesPairlistConstructedSignal& signal) + { force_provider_->setPairlist(signal); }; + notifiers->simulationRunNotifier_.subscribe(notifyPairlistFunction); } void initForceProviders(ForceProviders* forceProviders) override @@ -199,8 +204,7 @@ class MetatomicMDModule final : public IMDModule } force_provider_ = std::make_unique( - options_, options_.logger(), options_.mpiComm() - ); + options_, options_.logger(), options_.mpiComm()); forceProviders->addForceProvider(force_provider_.get(), "Metatomic"); } From 776093e82abc0d673f02406cc2feae5da8888f5c Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Fri, 30 Jan 2026 08:45:05 +0100 Subject: [PATCH 02/17] tst(mtaopt): fixup --- cmake/gmxManageMetatomic.cmake | 19 ------------------- .../metatomic/tests/metatomic_options.cpp | 4 ++-- ...MetatomicOptionsTest_DefaultParameters.xml | 9 ++++----- ...ionsTest_OutputDefaultValuesWhenActive.xml | 12 ++++++------ ...Test_OutputNoDefaultValuesWhenInactive.xml | 2 +- 5 files changed, 13 insertions(+), 33 deletions(-) diff --git a/cmake/gmxManageMetatomic.cmake b/cmake/gmxManageMetatomic.cmake index 49e1bb2a2b..e4a60c663f 100644 --- a/cmake/gmxManageMetatomic.cmake +++ b/cmake/gmxManageMetatomic.cmake @@ -111,9 +111,6 @@ if(NOT GMX_METATOMIC STREQUAL "OFF") set(METATOMIC_TORCH_VERSION "0.1.7") set(METATOMIC_TORCH_SHA256 "726f5711b70c4b8cc80d9bc6c3ce6f3449f31d20acc644ab68dab083aa4ea572") - set(VESIN_VERSION "0.4.1") - set(VESIN_GIT_TAG "87dcad999fec47b29ab21be9662ef283edc7530b") - set(DOWNLOAD_METATENSOR_DEFAULT ON) find_package(metatensor_torch ${METATENSOR_TORCH_VERSION} QUIET) if (metatensor_torch_FOUND) @@ -168,23 +165,7 @@ if(NOT GMX_METATOMIC STREQUAL "OFF") find_package(metatomic_torch REQUIRED ${METATOMIC_TORCH_VERSION}) endif() - # always fetch vesin - FetchContent_Declare( - vesin - GIT_REPOSITORY https://github.com/Luthaf/vesin.git - GIT_TAG ${VESIN_GIT_TAG} - ) - - FetchContent_MakeAvailable(vesin) - install(TARGETS vesin EXPORT libgromacs - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} - ) - list(APPEND GMX_COMMON_LIBRARIES - vesin metatensor metatomic_torch metatensor_torch diff --git a/src/gromacs/applied_forces/metatomic/tests/metatomic_options.cpp b/src/gromacs/applied_forces/metatomic/tests/metatomic_options.cpp index 3edd815ec9..257d830e19 100644 --- a/src/gromacs/applied_forces/metatomic/tests/metatomic_options.cpp +++ b/src/gromacs/applied_forces/metatomic/tests/metatomic_options.cpp @@ -34,9 +34,9 @@ */ /*! \internal \file * \brief - * Tests for functionality of the NNPotOptions + * Tests for functionality of the MetatomicOptions * - * \author Lukas Müllender + * \author Metatensor developers * \ingroup module_applied_forces */ diff --git a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml index d9407106e7..21555a3181 100644 --- a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml +++ b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_DefaultParameters.xml @@ -2,10 +2,9 @@ false - model.pt - System - - - + System + + + true diff --git a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputDefaultValuesWhenActive.xml b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputDefaultValuesWhenActive.xml index cf1c69fc50..9f2f99bd32 100644 --- a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputDefaultValuesWhenActive.xml +++ b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputDefaultValuesWhenActive.xml @@ -3,11 +3,11 @@ ; Machine learning potential using metatomic -metatomic-active = true -metatomic-input-group = System -metatomic-model = -metatomic-extensions = -metatomic-device = -metatomic-check-consistency = true +metatomic-active = true +metatomic-input-group = System +metatomic-model = +metatomic-extensions = +metatomic-device = +metatomic-check-consistency = true diff --git a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputNoDefaultValuesWhenInactive.xml b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputNoDefaultValuesWhenInactive.xml index c5c4bbdd90..0f2e8f2a3e 100644 --- a/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputNoDefaultValuesWhenInactive.xml +++ b/src/gromacs/applied_forces/metatomic/tests/refdata/MetatomicOptionsTest_OutputNoDefaultValuesWhenInactive.xml @@ -3,6 +3,6 @@ ; Machine learning potential using metatomic -metatomic-active = false +metatomic-active = false From 61f7e88ee994249c67bde6b243576bf6f30bebe5 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Fri, 30 Jan 2026 09:17:36 +0100 Subject: [PATCH 03/17] bug(pairlists): subscribe correctly --- .../applied_forces/metatomic/metatomic_mdmodule.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp index 67652e5ff2..36fcf578c4 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp @@ -47,6 +47,7 @@ #include "gromacs/domdec/localatomset.h" #include "gromacs/domdec/localatomsetmanager.h" #include "gromacs/mdrunutility/mdmodulesnotifiers.h" +#include "gromacs/mdrunutility/plainpairlistranges.h" #include "gromacs/mdtypes/imdmodule.h" #include "gromacs/utility/keyvaluetreebuilder.h" @@ -167,6 +168,14 @@ class MetatomicMDModule final : public IMDModule [](MDModulesEnergyOutputToMetatomicPotRequestChecker* energyOutputRequest) { energyOutputRequest->energyOutputToMetatomicPot_ = true; }; notifiers->simulationSetupNotifier_.subscribe(requestEnergyOutput); + + const auto setPlainPairlistRangeFunction = [this](PlainPairlistRanges* ranges) + { + // XXX: take the real cutoff.. + float cutoff = 0.576; + ranges->addRange(cutoff); + }; + notifiers->simulationSetupNotifier_.subscribe(setPlainPairlistRangeFunction); } /*! \brief Requests to be notified during the simulation. From 9c302910d102a5eeaa68ae35fce2354ed5b7e80b Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Fri, 30 Jan 2026 11:34:04 +0100 Subject: [PATCH 04/17] feat(mtagro): actually use the DD pairlist --- .../metatomic/metatomic_forceprovider.cpp | 87 +++++++++++-------- .../metatomic/metatomic_mdmodule.cpp | 37 +++++++- 2 files changed, 85 insertions(+), 39 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 2afd3e3875..523dd97169 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -34,7 +34,7 @@ /*! \internal \file * \brief * Implements the Metatomic Force Provider class with proper domain decomposition support. - * + * * \author Metatensor developers * \ingroup module_applied_forces */ @@ -43,6 +43,7 @@ #include "metatomic_forceprovider.h" #include + #include #include "gromacs/domdec/localatomset.h" @@ -122,25 +123,26 @@ static torch::Device determineDevice(const MDLogger& logger, const MpiComm& mpiC { if (!torch::cuda::is_available()) { - GMX_THROW(InternalError(formatString( - "GMX_METATOMIC_DEVICE='%s' but no device available.", env))); + GMX_THROW(InternalError( + formatString("GMX_METATOMIC_DEVICE='%s' but no device available.", env))); } GMX_RELEASE_ASSERT(activeDevice.has_value(), "Could not determine active device."); device = torch::Device(torch::kCUDA, activeDevice.value()); } else if (devLC != "cpu") { - GMX_THROW(InvalidInputError(formatString( - "GMX_METATOMIC_DEVICE invalid value: '%s'.", env))); + GMX_THROW(InvalidInputError(formatString("GMX_METATOMIC_DEVICE invalid value: '%s'.", env))); } - GMX_LOG(logger.info).asParagraph() + GMX_LOG(logger.info) + .asParagraph() .appendTextFormatted("Using device from GMX_METATOMIC_DEVICE: '%s'.", env); } else { if (torch::cuda::is_available() && activeDevice.has_value()) { - GMX_LOG(logger.info).asParagraph() + GMX_LOG(logger.info) + .asParagraph() .appendText("Using " + toUpperCase(torchDeviceType) + " for Metatomic."); device = torch::Device(torch::kCUDA, activeDevice.value()); } @@ -167,31 +169,42 @@ static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) return pbcTensor.to(device); } -static metatensor_torch::TensorBlock buildNeighborListFromPairlist( - ArrayRef pairlist, - ArrayRef shiftVectors, - torch::Device device, - torch::ScalarType dtype) +static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef pairlist, + ArrayRef shiftVectors, + ArrayRef positions, + torch::Device device, + torch::ScalarType dtype) { const int64_t n_pairs = static_cast(pairlist.size() / 2); auto pair_samples_values = torch::zeros({ n_pairs, 5 }, torch::TensorOptions().dtype(torch::kInt32)); auto pair_samples_ptr = pair_samples_values.accessor(); + // Full interatomic vectors (rj - ri + shift), not just shifts. auto pair_vectors = torch::zeros({ n_pairs, 3, 1 }, torch::TensorOptions().dtype(torch::kFloat64)); - auto vectors_accessor = pair_vectors.accessor(); + + auto vectors_cpu = torch::zeros({ n_pairs, 3, 1 }, torch::TensorOptions().dtype(torch::kFloat64)); + auto vectors_accessor = vectors_cpu.accessor(); for (int64_t i = 0; i < n_pairs; i++) { - pair_samples_ptr[i][0] = static_cast(pairlist[2 * i]); - pair_samples_ptr[i][1] = static_cast(pairlist[2 * i + 1]); + int atom_i = pairlist[2 * i]; + int atom_j = pairlist[2 * i + 1]; + + pair_samples_ptr[i][0] = static_cast(atom_i); + pair_samples_ptr[i][1] = static_cast(atom_j); pair_samples_ptr[i][2] = 0; pair_samples_ptr[i][3] = 0; pair_samples_ptr[i][4] = 0; - vectors_accessor[i][0][0] = static_cast(shiftVectors[i][0]); - vectors_accessor[i][1][0] = static_cast(shiftVectors[i][1]); - vectors_accessor[i][2][0] = static_cast(shiftVectors[i][2]); + // Calculate r_ij = r_j - r_i + shift + double r_ij_x = positions[atom_j][0] - positions[atom_i][0] + shiftVectors[i][0]; + double r_ij_y = positions[atom_j][1] - positions[atom_i][1] + shiftVectors[i][1]; + double r_ij_z = positions[atom_j][2] - positions[atom_i][2] + shiftVectors[i][2]; + + vectors_accessor[i][0][0] = r_ij_x; + vectors_accessor[i][1][0] = r_ij_y; + vectors_accessor[i][2][0] = r_ij_z; } auto neighbor_samples = torch::make_intrusive( @@ -209,7 +222,7 @@ static metatensor_torch::TensorBlock buildNeighborListFromPairlist( torch::zeros({ 1, 1 }, torch::TensorOptions().dtype(torch::kInt32).device(device))); return torch::make_intrusive( - pair_vectors.to(dtype).to(device), + vectors_cpu.to(dtype).to(device), neighbor_samples, std::vector{ neighbor_component }, neighbor_properties); @@ -222,9 +235,9 @@ struct MetatomicData metatomic_torch::ModelCapabilities capabilities; std::vector nl_requests; metatomic_torch::ModelEvaluationOptions evaluations_options; - torch::ScalarType dtype = torch::kFloat32; + torch::ScalarType dtype = torch::kFloat32; bool check_consistency = true; - torch::Device device = torch::kCPU; + torch::Device device = torch::kCPU; }; MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, @@ -303,7 +316,7 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, GMX_THROW(APIError("Metatomic model must provide 'energy' output.")); } - auto requested_output = torch::make_intrusive(); + auto requested_output = torch::make_intrusive(); requested_output->per_atom = false; requested_output->explicit_gradients = {}; @@ -319,7 +332,9 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, inputToLocalIndex_.resize(n_atoms, -1); inputToGlobalIndex_.resize(n_atoms, -1); - GMX_LOG(logger_.info).asParagraph().appendText("MetatomicForceProvider initialization complete."); + GMX_LOG(logger_.info) + .asParagraph() + .appendText("MetatomicForceProvider initialization complete."); } MetatomicForceProvider::~MetatomicForceProvider() = default; @@ -337,8 +352,8 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist { GMX_RELEASE_ASSERT(signal.globalAtomIndices_.has_value(), "Global atom indices required for DD."); - auto globalAtomIndices = signal.globalAtomIndices_.value(); - const int numLocal = signal.x_.size(); + auto globalAtomIndices = signal.globalAtomIndices_.value(); + const int numLocal = signal.x_.size(); for (int i = 0; i < static_cast(globalAtomIndices.size()); i++) { @@ -364,8 +379,8 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist const auto* mtaAtoms = options_.params_.mtaAtoms_.get(); for (int i = 0; i < numInput; i++) { - int localIndex = mtaAtoms->localIndex()[i]; - int globalIdx = mtaAtoms->globalIndex()[mtaAtoms->collectiveIndex()[i]]; + int localIndex = mtaAtoms->localIndex()[i]; + int globalIdx = mtaAtoms->globalIndex()[mtaAtoms->collectiveIndex()[i]]; inputToLocalIndex_[i] = localIndex; inputToGlobalIndex_[i] = globalIdx; atomNumbers_[i] = options_.params_.atoms_.atom[globalIdx].atomnumber; @@ -425,7 +440,7 @@ void MetatomicForceProvider::preparePairlistInput() if (inputIdxA.has_value() && inputIdxB.has_value()) { - RVec shift; + RVec shift; const IVec unitShift = shiftIndexToXYZ(shiftIndex); mvmul_ur0(box_, unitShift.toRVec(), shift); @@ -449,7 +464,8 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F preparePairlistInput(); // Force tensor - main rank fills, others have zeros - torch::Tensor forceTensor = torch::zeros({ n_atoms, 3 }, torch::TensorOptions().dtype(torch::kFloat64)); + torch::Tensor forceTensor = + torch::zeros({ n_atoms, 3 }, torch::TensorOptions().dtype(torch::kFloat64)); if (mpiComm_.isMainRank()) { @@ -460,11 +476,10 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F } auto blob_options = torch::TensorOptions().dtype(gromacs_scalar_type).device(torch::kCPU); - auto torch_positions = - torch::from_blob(positions_.data()->as_vec(), { n_atoms, 3 }, blob_options) - .to(data_->dtype) - .to(data_->device) - .set_requires_grad(true); + auto torch_positions = torch::from_blob(positions_.data()->as_vec(), { n_atoms, 3 }, blob_options) + .to(data_->dtype) + .to(data_->device) + .set_requires_grad(true); auto torch_cell = torch::from_blob(&box_, { 3, 3 }, blob_options).to(data_->dtype).to(data_->device); @@ -480,7 +495,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F for (const auto& request : data_->nl_requests) { auto neighbors = buildNeighborListFromPairlist( - pairlistForModel_, shiftVectors_, data_->device, data_->dtype); + pairlistForModel_, shiftVectors_, positions_, data_->device, data_->dtype); metatomic_torch::register_autograd_neighbors(system, neighbors, false); system->add_neighbor_list(request, neighbors); } @@ -508,7 +523,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F static_cast(energy_tensor.item()); energy_tensor.backward(); - auto grad = system->positions().grad(); + auto grad = system->positions().grad(); forceTensor = -grad.to(torch::kCPU).to(torch::kFloat64); } diff --git a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp index 36fcf578c4..4627248423 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp @@ -53,6 +53,12 @@ #include "metatomic_forceprovider.h" #include "metatomic_options.h" +#ifdef DIM +# undef DIM +#endif + +#include +#include namespace gmx { @@ -171,9 +177,34 @@ class MetatomicMDModule final : public IMDModule const auto setPlainPairlistRangeFunction = [this](PlainPairlistRanges* ranges) { - // XXX: take the real cutoff.. - float cutoff = 0.576; - ranges->addRange(cutoff); + // Temporary: Load model just to peek at cutoff. + // Optimization: move this to MetatomicOptions::checkNNPotModel equivalent later. + double req_cutoff = 0.0; + try + { + auto model = metatomic_torch::load_atomistic_model(options_.parameters().modelPath_); + auto requests = model.run_method("requested_neighbor_lists"); + for (const auto& req : requests.toList()) + { + auto opts = req.get().toCustomClass(); + double c = opts->engine_cutoff("nm"); + if (c > req_cutoff) + req_cutoff = c; + } + } + catch (const std::exception& e) + { + GMX_THROW(InternalError("Failed to read cutoff from model: " + std::string(e.what()))); + } + + if (req_cutoff <= 0.0) + { + // Fallback or error + GMX_THROW(InconsistentInputError( + "Metatomic model requested 0.0 or negative cutoff.")); + } + + ranges->addRange(req_cutoff); }; notifiers->simulationSetupNotifier_.subscribe(setPlainPairlistRangeFunction); } From 28a581c225069e03805027bb512a8c3ddc78b443 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Sun, 1 Feb 2026 03:29:50 +0100 Subject: [PATCH 05/17] chore(mtagro): try to prevent CUDA errors --- .../applied_forces/metatomic/metatomic_forceprovider.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 523dd97169..85a2e6c255 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -65,6 +65,10 @@ #include #include +#if GMX_GPU_CUDA || (GMX_SYCL_ACPP && GMX_ACPP_HAVE_CUDA_TARGET) +#include +#endif + namespace gmx { From 6727ce34bc99fea47d463246254fc18673610d13 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Sun, 1 Feb 2026 21:11:17 +0100 Subject: [PATCH 06/17] feat(mtagro): actually compute the virial --- .../metatomic/metatomic_forceprovider.cpp | 43 ++++++++++++++++--- 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 85a2e6c255..1c04634f3f 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -59,7 +59,7 @@ #include "gromacs/utility/stringutil.h" #ifdef DIM -# undef DIM +#undef DIM #endif #include @@ -471,6 +471,9 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F torch::Tensor forceTensor = torch::zeros({ n_atoms, 3 }, torch::TensorOptions().dtype(torch::kFloat64)); + // Virial tensor for pressure/stress calculations + torch::Tensor virialTensor = torch::zeros({ 3, 3 }, torch::TensorOptions().dtype(torch::kFloat64)); + if (mpiComm_.isMainRank()) { auto gromacs_scalar_type = torch::kFloat32; @@ -488,12 +491,19 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F auto torch_cell = torch::from_blob(&box_, { 3, 3 }, blob_options).to(data_->dtype).to(data_->device); + // Create strain tensor for virial computation (like LAMMPS does) + auto strain = torch::eye( + 3, torch::TensorOptions().dtype(data_->dtype).device(data_->device).requires_grad(true)); + + // Apply strain to cell: strained_cell = cell @ strain + auto strained_cell = torch::matmul(torch_cell, strain); + auto torch_pbc = preparePbcType(options_.params_.pbcType_.get(), data_->device); auto torch_types = torch::tensor(atomNumbers_, torch::TensorOptions().dtype(torch::kInt32)).to(data_->device); auto system = torch::make_intrusive( - torch_types, torch_positions, torch_cell, torch_pbc); + torch_types, torch_positions, strained_cell, torch_pbc); // Build neighbor list from GROMACS pairlist for (const auto& request : data_->nl_requests) @@ -526,18 +536,28 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F outputs->enerd_.term[InteractionFunction::MetatomicPotentialEnergy] = static_cast(energy_tensor.item()); - energy_tensor.backward(); - auto grad = system->positions().grad(); - forceTensor = -grad.to(torch::kCPU).to(torch::kFloat64); + // Reset gradients before backward + torch_positions.mutable_grad() = torch::Tensor(); + strain.mutable_grad() = torch::Tensor(); + + // Compute forces and virial via backward propagation + energy_tensor.backward(-torch::ones_like(energy_tensor)); + + auto grad = torch_positions.grad(); + forceTensor = grad.to(torch::kCPU).to(torch::kFloat64); + + // Get virial from strain gradient + virialTensor = strain.grad().to(torch::kCPU).to(torch::kFloat64); } // Distribute forces (sumReduce acts as broadcast since non-main ranks have zeros) if (mpiComm_.isParallel()) { mpiComm_.sumReduce(n_atoms * 3, static_cast(forceTensor.data_ptr())); + mpiComm_.sumReduce(9, static_cast(virialTensor.data_ptr())); } - // Apply to local atoms only + // Apply forces to local atoms only auto forceAccessor = forceTensor.accessor(); for (int i = 0; i < n_atoms; ++i) { @@ -548,6 +568,17 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F outputs->forceWithVirial_.force_[inputToLocalIndex_[i]][2] += forceAccessor[i][2]; } } + + // Apply virial contribution + // GROMACS uses a 3x3 virial tensor in forceWithVirial_ + auto virialAccessor = virialTensor.accessor(); + for (int i = 0; i < 3; ++i) + { + for (int j = 0; j < 3; ++j) + { + outputs->forceWithVirial_.virial_[i][j] += virialAccessor[i][j]; + } + } } } // namespace gmx From 54583a93c5c2771c1fc4929798ce0ebb092f5ba7 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 2 Feb 2026 13:33:29 +0100 Subject: [PATCH 07/17] bug(mtagro): use the public API for virial --- .../metatomic/metatomic_forceprovider.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 1c04634f3f..a009ef20f3 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -59,14 +59,14 @@ #include "gromacs/utility/stringutil.h" #ifdef DIM -#undef DIM +# undef DIM #endif #include #include #if GMX_GPU_CUDA || (GMX_SYCL_ACPP && GMX_ACPP_HAVE_CUDA_TARGET) -#include +# include #endif @@ -571,14 +571,18 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F // Apply virial contribution // GROMACS uses a 3x3 virial tensor in forceWithVirial_ - auto virialAccessor = virialTensor.accessor(); + // Copy the tensor data into a GROMACS matrix and use the public API + matrix virialMatrix; + auto virialAccessor = virialTensor.accessor(); + // TODO: technically this is DIM, not 3... for (int i = 0; i < 3; ++i) { for (int j = 0; j < 3; ++j) { - outputs->forceWithVirial_.virial_[i][j] += virialAccessor[i][j]; + virialMatrix[i][j] = virialAccessor[i][j]; } } + outputs->forceWithVirial_.addVirialContribution(virialMatrix); } } // namespace gmx From 9c0e5ead036f803f573191f2776ef467e870b53e Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 2 Feb 2026 13:59:08 +0100 Subject: [PATCH 08/17] chore(gromta): import LAMMPS interaction logic --- .../metatomic/metatomic_mdmodule.cpp | 47 ++++++++++++++----- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp index 4627248423..e5201be178 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp @@ -49,6 +49,7 @@ #include "gromacs/mdrunutility/mdmodulesnotifiers.h" #include "gromacs/mdrunutility/plainpairlistranges.h" #include "gromacs/mdtypes/imdmodule.h" +#include "gromacs/utility/basenetwork.h" #include "gromacs/utility/keyvaluetreebuilder.h" #include "metatomic_forceprovider.h" @@ -57,6 +58,8 @@ # undef DIM #endif +#include + #include #include @@ -178,18 +181,41 @@ class MetatomicMDModule final : public IMDModule const auto setPlainPairlistRangeFunction = [this](PlainPairlistRanges* ranges) { // Temporary: Load model just to peek at cutoff. - // Optimization: move this to MetatomicOptions::checkNNPotModel equivalent later. - double req_cutoff = 0.0; + // TODO: can the whole model be loaded earlier..? + double max_cutoff{ 0.0 }; try { auto model = metatomic_torch::load_atomistic_model(options_.parameters().modelPath_); - auto requests = model.run_method("requested_neighbor_lists"); - for (const auto& req : requests.toList()) + auto capabilities = + model.run_method("capabilities").toCustomClass(); + double interaction_range = capabilities->engine_interaction_range("nm"); + if (interaction_range < 0.0) + { + GMX_THROW(InconsistentInputError( + "interaction_range is negative for this model.")); + } + + if (!std::isfinite(interaction_range)) + { + if (gmx_node_num() > 1) + { + GMX_THROW(NotImplementedError( + "interaction_range is infinite for this model; " + "using multiple MPI domains is not supported.")); + } + + auto requested_nl = model.run_method("requested_neighbor_lists"); + for (const auto& ivalue : requested_nl.toList()) + { + auto options = + ivalue.get().toCustomClass(); + double cutoff = options->engine_cutoff("nm"); + max_cutoff = std::max(max_cutoff, cutoff); + } + } + else { - auto opts = req.get().toCustomClass(); - double c = opts->engine_cutoff("nm"); - if (c > req_cutoff) - req_cutoff = c; + max_cutoff = interaction_range; } } catch (const std::exception& e) @@ -197,14 +223,13 @@ class MetatomicMDModule final : public IMDModule GMX_THROW(InternalError("Failed to read cutoff from model: " + std::string(e.what()))); } - if (req_cutoff <= 0.0) + if (max_cutoff <= 0.0 || !std::isfinite(max_cutoff)) { - // Fallback or error GMX_THROW(InconsistentInputError( "Metatomic model requested 0.0 or negative cutoff.")); } - ranges->addRange(req_cutoff); + ranges->addRange(max_cutoff); }; notifiers->simulationSetupNotifier_.subscribe(setPlainPairlistRangeFunction); } From 8f82210a1a6ad21bc29e30a596de00231ca17887 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 2 Feb 2026 14:13:04 +0100 Subject: [PATCH 09/17] bug(mtagro): apply strain on atoms --- .../applied_forces/metatomic/metatomic_forceprovider.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index a009ef20f3..6e3356e953 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -497,13 +497,15 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F // Apply strain to cell: strained_cell = cell @ strain auto strained_cell = torch::matmul(torch_cell, strain); + // Apply strain to positions: r' = r @ strain + auto strained_positions = torch::matmul(torch_positions, strain); auto torch_pbc = preparePbcType(options_.params_.pbcType_.get(), data_->device); auto torch_types = torch::tensor(atomNumbers_, torch::TensorOptions().dtype(torch::kInt32)).to(data_->device); auto system = torch::make_intrusive( - torch_types, torch_positions, strained_cell, torch_pbc); + torch_types, strained_positions, strained_cell, torch_pbc); // Build neighbor list from GROMACS pairlist for (const auto& request : data_->nl_requests) From eddad9db1e119a4a47d1abfba506c6ffb30022a1 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 2 Feb 2026 14:26:43 +0100 Subject: [PATCH 10/17] chore(gromta): restructure with comments --- .../metatomic/metatomic_mdmodule.cpp | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp index e5201be178..6f3270f0ed 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_mdmodule.cpp @@ -181,14 +181,17 @@ class MetatomicMDModule final : public IMDModule const auto setPlainPairlistRangeFunction = [this](PlainPairlistRanges* ranges) { // Temporary: Load model just to peek at cutoff. - // TODO: can the whole model be loaded earlier..? + // TODO: can the whole model be loaded earlier..? double max_cutoff{ 0.0 }; try { auto model = metatomic_torch::load_atomistic_model(options_.parameters().modelPath_); + + // Check strict interaction range auto capabilities = model.run_method("capabilities").toCustomClass(); double interaction_range = capabilities->engine_interaction_range("nm"); + if (interaction_range < 0.0) { GMX_THROW(InconsistentInputError( @@ -197,26 +200,30 @@ class MetatomicMDModule final : public IMDModule if (!std::isfinite(interaction_range)) { + // Infinite range (global) is only supported on a single rank if (gmx_node_num() > 1) { GMX_THROW(NotImplementedError( "interaction_range is infinite for this model; " "using multiple MPI domains is not supported.")); } - - auto requested_nl = model.run_method("requested_neighbor_lists"); - for (const auto& ivalue : requested_nl.toList()) - { - auto options = - ivalue.get().toCustomClass(); - double cutoff = options->engine_cutoff("nm"); - max_cutoff = std::max(max_cutoff, cutoff); - } + // For infinite range, check if specific NLs were requested + // effectively falling through to the loop below. } else { max_cutoff = interaction_range; } + + // Check requested neighbor lists + auto requested_nl = model.run_method("requested_neighbor_lists"); + for (const auto& ivalue : requested_nl.toList()) + { + auto options = + ivalue.get().toCustomClass(); + double cutoff = options->engine_cutoff("nm"); + max_cutoff = std::max(max_cutoff, cutoff); + } } catch (const std::exception& e) { @@ -225,12 +232,17 @@ class MetatomicMDModule final : public IMDModule if (max_cutoff <= 0.0 || !std::isfinite(max_cutoff)) { - GMX_THROW(InconsistentInputError( - "Metatomic model requested 0.0 or negative cutoff.")); + // If the model is purely global and requested no specific NL, + // there's no need for a pairlist, so max_cutoff remains 0. + if (max_cutoff == 0.0) + { + GMX_THROW(InconsistentInputError("Metatomic model cutoff is 0.0 or invalid.")); + } } - + // Register the requirement with GROMACS ranges->addRange(max_cutoff); }; + // Register the callback notifiers->simulationSetupNotifier_.subscribe(setPlainPairlistRangeFunction); } From 644a4b74acd04e0f88f7c560eb3157e51ca58858 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 2 Feb 2026 15:30:16 +0100 Subject: [PATCH 11/17] chore(mtagro): use fixed width integers --- .../metatomic/metatomic_forceprovider.cpp | 48 +++++++++---------- .../metatomic/metatomic_forceprovider.h | 10 ++-- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 6e3356e953..0c69efc8dd 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -73,7 +73,7 @@ namespace gmx { -static std::optional indexOf(ArrayRef vec, const int val) +static std::optional indexOf(ArrayRef vec, const int32_t val) { auto it = std::find(vec.begin(), vec.end(), val); if (it == vec.end()) @@ -83,11 +83,11 @@ static std::optional indexOf(ArrayRef vec, const int val) return std::distance(vec.begin(), it); } -static std::tuple> getMdrunActiveDevice() +static std::tuple> getMdrunActiveDevice() { #if GMX_GPU_CUDA || (GMX_SYCL_ACPP && GMX_ACPP_HAVE_CUDA_TARGET) GMX_RELEASE_ASSERT(torch::hasCUDA(), "Libtorch not compiled with CUDA support."); - int activeDevice; + int32_t activeDevice; if (cudaGetDevice(&activeDevice) != cudaSuccess) { GMX_THROW(InternalError("cudaGetDevice failed.")); @@ -97,7 +97,7 @@ static std::tuple> getMdrunActiveDevice() # ifndef USE_ROCM GMX_THROW(InternalError("Libtorch not compiled with HIP support.")); # endif - int activeDevice; + int32_t activeDevice; if (hipGetDevice(&activeDevice) != hipSuccess) { GMX_THROW(InternalError("hipGetDevice failed.")); @@ -173,7 +173,7 @@ static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) return pbcTensor.to(device); } -static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef pairlist, +static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef pairlist, ArrayRef shiftVectors, ArrayRef positions, torch::Device device, @@ -192,8 +192,8 @@ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef(atom_i); pair_samples_ptr[i][1] = static_cast(atom_j); @@ -329,7 +329,7 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, } const auto& mtaIndices = options_.params_.mtaIndices_; - const int n_atoms = static_cast(mtaIndices.size()); + const int32_t n_atoms = static_cast(mtaIndices.size()); positions_.resize(n_atoms); atomNumbers_.resize(n_atoms, 0); @@ -346,7 +346,7 @@ MetatomicForceProvider::~MetatomicForceProvider() = default; void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedistributedSignal& signal) { const auto& mtaIndices = options_.params_.mtaIndices_; - const int numInput = static_cast(mtaIndices.size()); + const int32_t numInput = static_cast(mtaIndices.size()); inputToLocalIndex_.assign(numInput, -1); inputToGlobalIndex_.assign(numInput, -1); @@ -357,12 +357,12 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist GMX_RELEASE_ASSERT(signal.globalAtomIndices_.has_value(), "Global atom indices required for DD."); auto globalAtomIndices = signal.globalAtomIndices_.value(); - const int numLocal = signal.x_.size(); + const int32_t numLocal = signal.x_.size(); - for (int i = 0; i < static_cast(globalAtomIndices.size()); i++) + for (int32_t i = 0; i < static_cast(globalAtomIndices.size()); i++) { - int globalIdx = globalAtomIndices[i]; - for (int j = 0; j < numInput; j++) + int32_t globalIdx = globalAtomIndices[i]; + for (int32_t j = 0; j < numInput; j++) { if (options_.params_.mtaAtoms_->globalIndex()[j] == globalIdx) { @@ -381,10 +381,10 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist else { const auto* mtaAtoms = options_.params_.mtaAtoms_.get(); - for (int i = 0; i < numInput; i++) + for (int32_t i = 0; i < numInput; i++) { - int localIndex = mtaAtoms->localIndex()[i]; - int globalIdx = mtaAtoms->globalIndex()[mtaAtoms->collectiveIndex()[i]]; + int32_t localIndex = mtaAtoms->localIndex()[i]; + int32_t globalIdx = mtaAtoms->globalIndex()[mtaAtoms->collectiveIndex()[i]]; inputToLocalIndex_[i] = localIndex; inputToGlobalIndex_[i] = globalIdx; atomNumbers_[i] = options_.params_.atoms_.atom[globalIdx].atomnumber; @@ -429,13 +429,13 @@ void MetatomicForceProvider::preparePairlistInput() GMX_ASSERT(!fullPairlist_.empty(), "Pairlist empty!"); - const int numPairs = gmx::ssize(fullPairlist_); + const int32_t numPairs = gmx::ssize(fullPairlist_); pairlistForModel_.clear(); pairlistForModel_.reserve(2 * numPairs); shiftVectors_.clear(); shiftVectors_.reserve(numPairs); - for (int i = 0; i < numPairs; i++) + for (int32_t i = 0; i < numPairs; i++) { const auto [atomPair, shiftIndex] = fullPairlist_[i]; @@ -448,8 +448,8 @@ void MetatomicForceProvider::preparePairlistInput() const IVec unitShift = shiftIndexToXYZ(shiftIndex); mvmul_ur0(box_, unitShift.toRVec(), shift); - pairlistForModel_.push_back(static_cast(inputIdxA.value())); - pairlistForModel_.push_back(static_cast(inputIdxB.value())); + pairlistForModel_.push_back(static_cast(inputIdxA.value())); + pairlistForModel_.push_back(static_cast(inputIdxB.value())); shiftVectors_.push_back(shift); } } @@ -461,7 +461,7 @@ void MetatomicForceProvider::preparePairlistInput() void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, ForceProviderOutput* outputs) { - const int n_atoms = static_cast(options_.params_.mtaIndices_.size()); + const int32_t n_atoms = static_cast(options_.params_.mtaIndices_.size()); gatherAtomPositions(inputs.x_); copy_mat(inputs.box_, box_); @@ -561,7 +561,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F // Apply forces to local atoms only auto forceAccessor = forceTensor.accessor(); - for (int i = 0; i < n_atoms; ++i) + for (int32_t i = 0; i < n_atoms; ++i) { if (inputToLocalIndex_[i] != -1) { @@ -577,9 +577,9 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F matrix virialMatrix; auto virialAccessor = virialTensor.accessor(); // TODO: technically this is DIM, not 3... - for (int i = 0; i < 3; ++i) + for (int32_t i = 0; i < 3; ++i) { - for (int j = 0; j < 3; ++j) + for (int32_t j = 0; j < 3; ++j) { virialMatrix[i][j] = virialAccessor[i][j]; } diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h index ccb1cd37be..ee378e42a4 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h @@ -59,7 +59,7 @@ class MpiComm; /*! For compatibility with pairlist data structure in MDModulesPairlistConstructedSignal. * Contains pairs like ((atom1, atom2), shiftIndex). */ -using PairlistEntry = std::pair, int>; +using PairlistEntry = std::pair, int32_t>; /*! \brief \internal * MetatomicForceProvider class @@ -101,19 +101,19 @@ class MetatomicForceProvider final : public IForceProvider std::vector positions_; //! vector storing all atomic numbers - std::vector atomNumbers_; + std::vector atomNumbers_; //! lookup table to map model input indices [0...numInput) to local atom indices - std::vector inputToLocalIndex_; + std::vector inputToLocalIndex_; //! lookup table to map model input indices to global atom indices - std::vector inputToGlobalIndex_; + std::vector inputToGlobalIndex_; //! Full pairlist from MDModules notification std::vector fullPairlist_; //! Interacting pairs of MTA atoms within cutoff, for model input - std::vector pairlistForModel_; + std::vector pairlistForModel_; //! Shift vectors for each atom pair in pairlistForModel_ std::vector shiftVectors_; From 376e568d41581e3b4e4edda4a4187e38897f76cc Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 2 Feb 2026 15:36:37 +0100 Subject: [PATCH 12/17] chore(mtagro): cleanup with pick device --- .../metatomic/metatomic_forceprovider.cpp | 113 ++++-------------- 1 file changed, 25 insertions(+), 88 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 0c69efc8dd..0c3a888cf7 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -44,8 +44,6 @@ #include -#include - #include "gromacs/domdec/localatomset.h" #include "gromacs/mdlib/broadcaststructs.h" #include "gromacs/mdrunutility/mdmodulesnotifiers.h" @@ -83,81 +81,6 @@ static std::optional indexOf(ArrayRef vec, const int32 return std::distance(vec.begin(), it); } -static std::tuple> getMdrunActiveDevice() -{ -#if GMX_GPU_CUDA || (GMX_SYCL_ACPP && GMX_ACPP_HAVE_CUDA_TARGET) - GMX_RELEASE_ASSERT(torch::hasCUDA(), "Libtorch not compiled with CUDA support."); - int32_t activeDevice; - if (cudaGetDevice(&activeDevice) != cudaSuccess) - { - GMX_THROW(InternalError("cudaGetDevice failed.")); - } - return { "cuda", activeDevice }; -#elif GMX_GPU_HIP || (GMX_SYCL_ACPP && GMX_ACPP_HAVE_HIP_TARGET) -# ifndef USE_ROCM - GMX_THROW(InternalError("Libtorch not compiled with HIP support.")); -# endif - int32_t activeDevice; - if (hipGetDevice(&activeDevice) != hipSuccess) - { - GMX_THROW(InternalError("hipGetDevice failed.")); - } - return { "hip", activeDevice }; -#else - return { "cpu", std::nullopt }; -#endif -} - -static torch::Device determineDevice(const MDLogger& logger, const MpiComm& mpiComm) -{ - torch::Device device(torch::kCPU); - - // Non-main ranks don't run model, return CPU - if (!mpiComm.isMainRank()) - { - return device; - } - - auto [torchDeviceType, activeDevice] = getMdrunActiveDevice(); - - if (const char* env = std::getenv("GMX_METATOMIC_DEVICE")) - { - const std::string devLC = toLowerCase(env); - if (devLC == "gpu" || devLC == "cuda") - { - if (!torch::cuda::is_available()) - { - GMX_THROW(InternalError( - formatString("GMX_METATOMIC_DEVICE='%s' but no device available.", env))); - } - GMX_RELEASE_ASSERT(activeDevice.has_value(), "Could not determine active device."); - device = torch::Device(torch::kCUDA, activeDevice.value()); - } - else if (devLC != "cpu") - { - GMX_THROW(InvalidInputError(formatString("GMX_METATOMIC_DEVICE invalid value: '%s'.", env))); - } - GMX_LOG(logger.info) - .asParagraph() - .appendTextFormatted("Using device from GMX_METATOMIC_DEVICE: '%s'.", env); - } - else - { - if (torch::cuda::is_available() && activeDevice.has_value()) - { - GMX_LOG(logger.info) - .asParagraph() - .appendText("Using " + toUpperCase(torchDeviceType) + " for Metatomic."); - device = torch::Device(torch::kCUDA, activeDevice.value()); - } - else - { - GMX_LOG(logger.info).asParagraph().appendText("Using CPU for Metatomic."); - } - } - return device; -} - static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) { torch::Tensor pbcTensor = @@ -255,7 +178,7 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, { GMX_LOG(logger_.info).asParagraph().appendText("Initializing MetatomicForceProvider..."); - // Pairlist-based neighbor lists don't work with DD yet (indices are local) + // Pairlist-based neighbor lists don't work with domain decomposition yet (indices are local) // Matches NNPot's limitation if (mpiComm_.isParallel()) { @@ -264,8 +187,6 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, "Use thread-MPI (gmx mdrun) instead of MPI (mpirun gmx_mpi mdrun).")); } - data_->device = determineDevice(logger_, mpiComm_); - // Only main rank loads model if (mpiComm_.isMainRank()) { @@ -288,6 +209,21 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, data_->capabilities = data_->model.run_method("capabilities").toCustomClass(); + // Determine device using capabilities and optional environment variable + torch::optional desiredDevice = torch::nullopt; + if (const char* env = std::getenv("GMX_METATOMIC_DEVICE")) + { + desiredDevice = std::string(env); + } + + const auto deviceType = + metatomic_torch::pick_device(data_->capabilities->supported_devices, desiredDevice); + data_->device = torch::Device(deviceType); + + GMX_LOG(logger_.info) + .asParagraph() + .appendTextFormatted("Metatomic using device: %s", data_->device.str().c_str()); + auto requests_ivalue = data_->model.run_method("requested_neighbor_lists"); for (const auto& request_ivalue : requests_ivalue.toList()) { @@ -328,8 +264,8 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, data_->check_consistency = options_.params_.checkConsistency; } - const auto& mtaIndices = options_.params_.mtaIndices_; - const int32_t n_atoms = static_cast(mtaIndices.size()); + const auto& mtaIndices = options_.params_.mtaIndices_; + const int32_t n_atoms = static_cast(mtaIndices.size()); positions_.resize(n_atoms); atomNumbers_.resize(n_atoms, 0); @@ -345,8 +281,8 @@ MetatomicForceProvider::~MetatomicForceProvider() = default; void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedistributedSignal& signal) { - const auto& mtaIndices = options_.params_.mtaIndices_; - const int32_t numInput = static_cast(mtaIndices.size()); + const auto& mtaIndices = options_.params_.mtaIndices_; + const int32_t numInput = static_cast(mtaIndices.size()); inputToLocalIndex_.assign(numInput, -1); inputToGlobalIndex_.assign(numInput, -1); @@ -356,7 +292,7 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist { GMX_RELEASE_ASSERT(signal.globalAtomIndices_.has_value(), "Global atom indices required for DD."); - auto globalAtomIndices = signal.globalAtomIndices_.value(); + auto globalAtomIndices = signal.globalAtomIndices_.value(); const int32_t numLocal = signal.x_.size(); for (int32_t i = 0; i < static_cast(globalAtomIndices.size()); i++) @@ -383,8 +319,8 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist const auto* mtaAtoms = options_.params_.mtaAtoms_.get(); for (int32_t i = 0; i < numInput; i++) { - int32_t localIndex = mtaAtoms->localIndex()[i]; - int32_t globalIdx = mtaAtoms->globalIndex()[mtaAtoms->collectiveIndex()[i]]; + int32_t localIndex = mtaAtoms->localIndex()[i]; + int32_t globalIdx = mtaAtoms->globalIndex()[mtaAtoms->collectiveIndex()[i]]; inputToLocalIndex_[i] = localIndex; inputToGlobalIndex_[i] = globalIdx; atomNumbers_[i] = options_.params_.atoms_.atom[globalIdx].atomnumber; @@ -512,7 +448,8 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F { auto neighbors = buildNeighborListFromPairlist( pairlistForModel_, shiftVectors_, positions_, data_->device, data_->dtype); - metatomic_torch::register_autograd_neighbors(system, neighbors, false); + // TODO: take from the user / model + metatomic_torch::register_autograd_neighbors(system, neighbors, /*check_consistency*/ true); system->add_neighbor_list(request, neighbors); } From 18797cc5561a19dba70ef889c6f70a6e608a70ea Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 2 Feb 2026 15:48:17 +0100 Subject: [PATCH 13/17] chore(mtagro): handle cell shifts --- .../metatomic/metatomic_forceprovider.cpp | 14 +++++++++----- .../metatomic/metatomic_forceprovider.h | 3 +++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 0c3a888cf7..3c4b8560ff 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -98,6 +98,7 @@ static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef pairlist, ArrayRef shiftVectors, + ArrayRef cellShifts, ArrayRef positions, torch::Device device, torch::ScalarType dtype) @@ -120,9 +121,9 @@ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef(atom_i); pair_samples_ptr[i][1] = static_cast(atom_j); - pair_samples_ptr[i][2] = 0; - pair_samples_ptr[i][3] = 0; - pair_samples_ptr[i][4] = 0; + pair_samples_ptr[i][2] = cellShifts[i][0]; + pair_samples_ptr[i][3] = cellShifts[i][1]; + pair_samples_ptr[i][4] = cellShifts[i][2]; // Calculate r_ij = r_j - r_i + shift double r_ij_x = positions[atom_j][0] - positions[atom_i][0] + shiftVectors[i][0]; @@ -291,7 +292,7 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist if (mpiComm_.isParallel()) { GMX_RELEASE_ASSERT(signal.globalAtomIndices_.has_value(), - "Global atom indices required for DD."); + "Global atom indices required for domain decomposition."); auto globalAtomIndices = signal.globalAtomIndices_.value(); const int32_t numLocal = signal.x_.size(); @@ -370,6 +371,8 @@ void MetatomicForceProvider::preparePairlistInput() pairlistForModel_.reserve(2 * numPairs); shiftVectors_.clear(); shiftVectors_.reserve(numPairs); + cellShifts_.clear(); + cellShifts_.reserve(numPairs); for (int32_t i = 0; i < numPairs; i++) { @@ -387,6 +390,7 @@ void MetatomicForceProvider::preparePairlistInput() pairlistForModel_.push_back(static_cast(inputIdxA.value())); pairlistForModel_.push_back(static_cast(inputIdxB.value())); shiftVectors_.push_back(shift); + cellShifts_.push_back(unitShift); } } @@ -447,7 +451,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F for (const auto& request : data_->nl_requests) { auto neighbors = buildNeighborListFromPairlist( - pairlistForModel_, shiftVectors_, positions_, data_->device, data_->dtype); + pairlistForModel_, shiftVectors_, cellShifts_, positions_, data_->device, data_->dtype); // TODO: take from the user / model metatomic_torch::register_autograd_neighbors(system, neighbors, /*check_consistency*/ true); system->add_neighbor_list(request, neighbors); diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h index ee378e42a4..e3b897e0d2 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h @@ -118,6 +118,9 @@ class MetatomicForceProvider final : public IForceProvider //! Shift vectors for each atom pair in pairlistForModel_ std::vector shiftVectors_; + //! Cell shifts + std::vector cellShifts_; + //! local copy of simulation box matrix box_; From dc779628ac653fd1946b108edaa9a4f1c7c23c46 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 2 Feb 2026 15:52:09 +0100 Subject: [PATCH 14/17] chore(gromta): conventionally pass device --- .../metatomic/metatomic_forceprovider.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 3c4b8560ff..83d8de3115 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -84,7 +84,7 @@ static std::optional indexOf(ArrayRef vec, const int32 static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) { torch::Tensor pbcTensor = - torch::tensor({ true, true, true }, torch::TensorOptions().dtype(torch::kBool)); + torch::tensor({ true, true, true }, torch::TensorOptions().dtype(torch::kBool)).device(device); if (*pbcType == PbcType::XY) { pbcTensor[2] = false; @@ -105,13 +105,16 @@ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef(pairlist.size() / 2); - auto pair_samples_values = torch::zeros({ n_pairs, 5 }, torch::TensorOptions().dtype(torch::kInt32)); + auto pair_samples_values = + torch::zeros({ n_pairs, 5 }, torch::TensorOptions().dtype(torch::kInt32)).device(device); auto pair_samples_ptr = pair_samples_values.accessor(); // Full interatomic vectors (rj - ri + shift), not just shifts. - auto pair_vectors = torch::zeros({ n_pairs, 3, 1 }, torch::TensorOptions().dtype(torch::kFloat64)); + auto pair_vectors = + torch::zeros({ n_pairs, 3, 1 }, torch::TensorOptions().dtype(torch::kFloat64)).device(device); - auto vectors_cpu = torch::zeros({ n_pairs, 3, 1 }, torch::TensorOptions().dtype(torch::kFloat64)); + auto vectors_cpu = + torch::zeros({ n_pairs, 3, 1 }, torch::TensorOptions().dtype(torch::kFloat64)).device(device); auto vectors_accessor = vectors_cpu.accessor(); for (int64_t i = 0; i < n_pairs; i++) @@ -409,7 +412,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F // Force tensor - main rank fills, others have zeros torch::Tensor forceTensor = - torch::zeros({ n_atoms, 3 }, torch::TensorOptions().dtype(torch::kFloat64)); + torch::zeros({ n_atoms, 3 }, torch::TensorOptions().dtype(torch::kFloat64)).device(data_->device); // Virial tensor for pressure/stress calculations torch::Tensor virialTensor = torch::zeros({ 3, 3 }, torch::TensorOptions().dtype(torch::kFloat64)); @@ -421,7 +424,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F { gromacs_scalar_type = torch::kFloat64; } - auto blob_options = torch::TensorOptions().dtype(gromacs_scalar_type).device(torch::kCPU); + auto blob_options = torch::TensorOptions().dtype(gromacs_scalar_type).device(data_->device); auto torch_positions = torch::from_blob(positions_.data()->as_vec(), { n_atoms, 3 }, blob_options) .to(data_->dtype) From 5b22fbc3dac61ca71f1b29c6a6f4fc044f298ff0 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 2 Feb 2026 15:56:47 +0100 Subject: [PATCH 15/17] chore(mtagro): add a bit of docs.. --- .../metatomic/metatomic_forceprovider.cpp | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index 83d8de3115..b6f49aae3b 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -283,6 +283,22 @@ MetatomicForceProvider::MetatomicForceProvider(const MetatomicOptions& options, MetatomicForceProvider::~MetatomicForceProvider() = default; +/*! \brief Updates the mapping between GROMACS local/global atom indices and the Metatomic model's input atoms. + * + * This function is subscribed to the `MDModulesAtomsRedistributedSignal`. It is called whenever + * atoms are redistributed across MPI ranks (Domain Decomposition) or reordered in memory (sorting). + * + * Its primary responsibilities are: + * 1. **Locate Input Atoms**: It iterates through the local atoms on the current rank to find + * which atoms correspond to the "input atoms" defined for the Metatomic model (via `mtaIndices`). + * 2. **Update Maps**: It populates `inputToLocalIndex_` (mapping model input index -> GROMACS local index) + * and `inputToGlobalIndex_` (mapping model input index -> GROMACS global tag). + * 3. **Gather Atomic Numbers**: It ensures `atomNumbers_` (Z numbers) are correctly associated with + * the current local atoms, performing an MPI reduction if necessary to gather data from + * distributed ranks. + * + * \param[in] signal Contains the new mapping of global atom indices to local buffer indices after redistribution. + */ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedistributedSignal& signal) { const auto& mtaIndices = options_.params_.mtaIndices_; @@ -335,6 +351,18 @@ void MetatomicForceProvider::gatherAtomNumbersIndices(const MDModulesAtomsRedist "Some atom numbers not set."); } +/*! \brief Updates the internal position buffer with the coordinates of the Metatomic atoms. + * + * This function extracts the coordinates of the atoms relevant to the Metatomic model + * from the full GROMACS local atom array (`pos`). + * + * 1. **Filtering:** `pos` contains all local atoms (solute, solvent, ions). + * This function copies only the atoms defined in `mtaIndices` to `positions_`. + * 2. **Ordering/Packing:** GROMACS reorders atoms dynamically for Domain Decomposition. + * This function uses `inputToLocalIndex_` to collect atoms and pack them into a contiguous, ordered buffer + * + * \param[in] pos The array of all local atom coordinates for the current step. + */ void MetatomicForceProvider::gatherAtomPositions(ArrayRef pos) { const size_t numInput = inputToLocalIndex_.size(); From 64317393d3a984d7a00548703e831501f33d9bd3 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 2 Feb 2026 17:15:35 +0100 Subject: [PATCH 16/17] chore(pbc): rework and handle no pbc --- .../metatomic/metatomic_forceprovider.cpp | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index b6f49aae3b..e0330cc097 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -83,38 +83,40 @@ static std::optional indexOf(ArrayRef vec, const int32 static torch::Tensor preparePbcType(PbcType* pbcType, torch::Device device) { - torch::Tensor pbcTensor = - torch::tensor({ true, true, true }, torch::TensorOptions().dtype(torch::kBool)).device(device); + auto options = torch::TensorOptions().dtype(torch::kBool).device(device); + if (*pbcType == PbcType::XY) { - pbcTensor[2] = false; + return torch::tensor({ true, true, false }, options); + } + else if (*pbcType == PbcType::No) + { + return torch::tensor({ false, false, false }, options); } else if (*pbcType != PbcType::Xyz) { GMX_THROW(InconsistentInputError("PBC type not supported.")); } - return pbcTensor.to(device); + return torch::tensor({ true, true, true }, options); } static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef pairlist, ArrayRef shiftVectors, - ArrayRef cellShifts, + ArrayRef cellShifts, ArrayRef positions, torch::Device device, torch::ScalarType dtype) { const int64_t n_pairs = static_cast(pairlist.size() / 2); - auto pair_samples_values = - torch::zeros({ n_pairs, 5 }, torch::TensorOptions().dtype(torch::kInt32)).device(device); - auto pair_samples_ptr = pair_samples_values.accessor(); + auto cpu_int_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); + auto cpu_float_options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU); - // Full interatomic vectors (rj - ri + shift), not just shifts. - auto pair_vectors = - torch::zeros({ n_pairs, 3, 1 }, torch::TensorOptions().dtype(torch::kFloat64)).device(device); + auto pair_samples_values = torch::zeros({ n_pairs, 5 }, cpu_int_options); + auto pair_samples_ptr = pair_samples_values.accessor(); - auto vectors_cpu = - torch::zeros({ n_pairs, 3, 1 }, torch::TensorOptions().dtype(torch::kFloat64)).device(device); + // Full interatomic vectors (rj - ri + shift) + auto vectors_cpu = torch::zeros({ n_pairs, 3, 1 }, cpu_float_options); auto vectors_accessor = vectors_cpu.accessor(); for (int64_t i = 0; i < n_pairs; i++) From b4468ad2302d445b1e49de991a721c34f5f347ad Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 2 Feb 2026 17:17:44 +0100 Subject: [PATCH 17/17] chore(gromta): correctly identify shifts .. as integers --- .../metatomic/metatomic_forceprovider.cpp | 19 +++++++++++++------ .../metatomic/metatomic_forceprovider.h | 2 +- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp index e0330cc097..5fcc5a909d 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.cpp @@ -124,6 +124,7 @@ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef(atom_i); pair_samples_ptr[i][1] = static_cast(atom_j); pair_samples_ptr[i][2] = cellShifts[i][0]; @@ -131,19 +132,25 @@ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef(positions[atom_j][0] - positions[atom_i][0] + shiftVectors[i][0]); + double r_ij_y = + static_cast(positions[atom_j][1] - positions[atom_i][1] + shiftVectors[i][1]); + double r_ij_z = + static_cast(positions[atom_j][2] - positions[atom_i][2] + shiftVectors[i][2]); vectors_accessor[i][0][0] = r_ij_x; vectors_accessor[i][1][0] = r_ij_y; vectors_accessor[i][2][0] = r_ij_z; } + auto final_samples_values = pair_samples_values.to(device); + auto final_vectors = vectors_cpu.to(dtype).to(device); + auto neighbor_samples = torch::make_intrusive( std::vector{ "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c" }, - pair_samples_values.to(device)); + final_samples_values); auto neighbor_component = torch::make_intrusive( std::vector{ "xyz" }, @@ -155,7 +162,7 @@ static metatensor_torch::TensorBlock buildNeighborListFromPairlist(ArrayRef( - vectors_cpu.to(dtype).to(device), + final_vectors, neighbor_samples, std::vector{ neighbor_component }, neighbor_properties); @@ -442,7 +449,7 @@ void MetatomicForceProvider::calculateForces(const ForceProviderInput& inputs, F // Force tensor - main rank fills, others have zeros torch::Tensor forceTensor = - torch::zeros({ n_atoms, 3 }, torch::TensorOptions().dtype(torch::kFloat64)).device(data_->device); + torch::zeros({ n_atoms, 3 }, torch::TensorOptions().dtype(torch::kFloat64).device(data_->device)); // Virial tensor for pressure/stress calculations torch::Tensor virialTensor = torch::zeros({ 3, 3 }, torch::TensorOptions().dtype(torch::kFloat64)); diff --git a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h index e3b897e0d2..c26ce7b187 100644 --- a/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h +++ b/src/gromacs/applied_forces/metatomic/metatomic_forceprovider.h @@ -119,7 +119,7 @@ class MetatomicForceProvider final : public IForceProvider std::vector shiftVectors_; //! Cell shifts - std::vector cellShifts_; + std::vector cellShifts_; //! local copy of simulation box matrix box_;